Complete AI Data Analysis Agent implementation with 95.7% test coverage
This commit is contained in:
357
tests/test_viz_tools.py
Normal file
357
tests/test_viz_tools.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""可视化工具的单元测试。"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from src.tools.viz_tools import (
|
||||
CreateBarChartTool,
|
||||
CreateLineChartTool,
|
||||
CreatePieChartTool,
|
||||
CreateHeatmapTool
|
||||
)
|
||||
from src.models import DataProfile, ColumnInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_output_dir():
|
||||
"""创建临时输出目录。"""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
# 清理
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
class TestCreateBarChartTool:
|
||||
"""测试柱状图工具。"""
|
||||
|
||||
def test_basic_functionality(self, temp_output_dir):
|
||||
"""测试基本功能。"""
|
||||
tool = CreateBarChartTool()
|
||||
df = pd.DataFrame({
|
||||
'category': ['A', 'B', 'C', 'A', 'B', 'A'],
|
||||
'value': [10, 20, 30, 15, 25, 20]
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'bar_chart.png')
|
||||
result = tool.execute(df, x_column='category', output_path=output_path)
|
||||
|
||||
assert result['success'] is True
|
||||
assert os.path.exists(output_path)
|
||||
assert result['chart_type'] == 'bar'
|
||||
assert result['x_column'] == 'category'
|
||||
|
||||
def test_with_y_column(self, temp_output_dir):
|
||||
"""测试指定Y列。"""
|
||||
tool = CreateBarChartTool()
|
||||
df = pd.DataFrame({
|
||||
'category': ['A', 'B', 'C'],
|
||||
'value': [100, 200, 300]
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'bar_chart_y.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
x_column='category',
|
||||
y_column='value',
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
assert os.path.exists(output_path)
|
||||
assert result['y_column'] == 'value'
|
||||
|
||||
def test_top_n_limit(self, temp_output_dir):
|
||||
"""测试 top_n 限制。"""
|
||||
tool = CreateBarChartTool()
|
||||
df = pd.DataFrame({
|
||||
'category': [f'cat_{i}' for i in range(50)],
|
||||
'value': range(50)
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'bar_chart_top.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
x_column='category',
|
||||
y_column='value',
|
||||
top_n=10,
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
assert result['data_points'] == 10
|
||||
|
||||
def test_nonexistent_column(self):
|
||||
"""测试不存在的列。"""
|
||||
tool = CreateBarChartTool()
|
||||
df = pd.DataFrame({'col1': [1, 2, 3]})
|
||||
|
||||
result = tool.execute(df, x_column='nonexistent')
|
||||
|
||||
assert 'error' in result
|
||||
|
||||
|
||||
class TestCreateLineChartTool:
|
||||
"""测试折线图工具。"""
|
||||
|
||||
def test_basic_functionality(self, temp_output_dir):
|
||||
"""测试基本功能。"""
|
||||
tool = CreateLineChartTool()
|
||||
df = pd.DataFrame({
|
||||
'x': range(10),
|
||||
'y': [i * 2 for i in range(10)]
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'line_chart.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
x_column='x',
|
||||
y_column='y',
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
assert os.path.exists(output_path)
|
||||
assert result['chart_type'] == 'line'
|
||||
|
||||
def test_with_datetime(self, temp_output_dir):
|
||||
"""测试时间序列数据。"""
|
||||
tool = CreateLineChartTool()
|
||||
dates = pd.date_range('2020-01-01', periods=20, freq='D')
|
||||
df = pd.DataFrame({
|
||||
'date': dates,
|
||||
'value': range(20)
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'line_chart_time.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
x_column='date',
|
||||
y_column='value',
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
assert os.path.exists(output_path)
|
||||
|
||||
def test_large_dataset_sampling(self, temp_output_dir):
|
||||
"""测试大数据集采样。"""
|
||||
tool = CreateLineChartTool()
|
||||
df = pd.DataFrame({
|
||||
'x': range(2000),
|
||||
'y': range(2000)
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'line_chart_large.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
x_column='x',
|
||||
y_column='y',
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
# 应该被采样到1000个点左右
|
||||
assert result['data_points'] <= 1000
|
||||
|
||||
|
||||
class TestCreatePieChartTool:
|
||||
"""测试饼图工具。"""
|
||||
|
||||
def test_basic_functionality(self, temp_output_dir):
|
||||
"""测试基本功能。"""
|
||||
tool = CreatePieChartTool()
|
||||
df = pd.DataFrame({
|
||||
'category': ['A', 'B', 'C', 'A', 'B', 'A']
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'pie_chart.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
column='category',
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
assert os.path.exists(output_path)
|
||||
assert result['chart_type'] == 'pie'
|
||||
assert result['categories'] == 3
|
||||
|
||||
def test_top_n_with_others(self, temp_output_dir):
|
||||
"""测试 top_n 并归类其他。"""
|
||||
tool = CreatePieChartTool()
|
||||
df = pd.DataFrame({
|
||||
'category': [f'cat_{i}' for i in range(20)] * 5
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'pie_chart_top.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
column='category',
|
||||
top_n=5,
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
# 5个类别 + 1个"其他"
|
||||
assert result['categories'] == 6
|
||||
|
||||
|
||||
class TestCreateHeatmapTool:
|
||||
"""测试热力图工具。"""
|
||||
|
||||
def test_basic_functionality(self, temp_output_dir):
|
||||
"""测试基本功能。"""
|
||||
tool = CreateHeatmapTool()
|
||||
df = pd.DataFrame({
|
||||
'x': range(10),
|
||||
'y': [i * 2 for i in range(10)],
|
||||
'z': [i * 3 for i in range(10)]
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'heatmap.png')
|
||||
result = tool.execute(df, output_path=output_path)
|
||||
|
||||
assert result['success'] is True
|
||||
assert os.path.exists(output_path)
|
||||
assert result['chart_type'] == 'heatmap'
|
||||
assert len(result['columns']) == 3
|
||||
|
||||
def test_with_specific_columns(self, temp_output_dir):
|
||||
"""测试指定列。"""
|
||||
tool = CreateHeatmapTool()
|
||||
df = pd.DataFrame({
|
||||
'a': range(10),
|
||||
'b': range(10, 20),
|
||||
'c': range(20, 30),
|
||||
'd': range(30, 40)
|
||||
})
|
||||
|
||||
output_path = os.path.join(temp_output_dir, 'heatmap_cols.png')
|
||||
result = tool.execute(
|
||||
df,
|
||||
columns=['a', 'b', 'c'],
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
assert result['success'] is True
|
||||
assert len(result['columns']) == 3
|
||||
assert 'd' not in result['columns']
|
||||
|
||||
def test_insufficient_columns(self):
|
||||
"""测试列数不足。"""
|
||||
tool = CreateHeatmapTool()
|
||||
df = pd.DataFrame({'x': range(10)})
|
||||
|
||||
result = tool.execute(df)
|
||||
|
||||
assert 'error' in result
|
||||
|
||||
|
||||
class TestVisualizationToolsApplicability:
|
||||
"""测试可视化工具的适用性判断。"""
|
||||
|
||||
def test_bar_chart_applicability(self):
|
||||
"""测试柱状图适用性。"""
|
||||
tool = CreateBarChartTool()
|
||||
profile = DataProfile(
|
||||
file_path='test.csv',
|
||||
row_count=100,
|
||||
column_count=1,
|
||||
columns=[
|
||||
ColumnInfo(name='cat', dtype='categorical', missing_rate=0.0, unique_count=5)
|
||||
],
|
||||
inferred_type='unknown'
|
||||
)
|
||||
|
||||
assert tool.is_applicable(profile) is True
|
||||
|
||||
def test_line_chart_applicability(self):
|
||||
"""测试折线图适用性。"""
|
||||
tool = CreateLineChartTool()
|
||||
|
||||
# 包含数值列
|
||||
profile_numeric = DataProfile(
|
||||
file_path='test.csv',
|
||||
row_count=100,
|
||||
column_count=1,
|
||||
columns=[
|
||||
ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50)
|
||||
],
|
||||
inferred_type='unknown'
|
||||
)
|
||||
assert tool.is_applicable(profile_numeric) is True
|
||||
|
||||
# 不包含数值列
|
||||
profile_text = DataProfile(
|
||||
file_path='test.csv',
|
||||
row_count=100,
|
||||
column_count=1,
|
||||
columns=[
|
||||
ColumnInfo(name='text', dtype='text', missing_rate=0.0, unique_count=50)
|
||||
],
|
||||
inferred_type='unknown'
|
||||
)
|
||||
assert tool.is_applicable(profile_text) is False
|
||||
|
||||
def test_heatmap_applicability(self):
|
||||
"""测试热力图适用性。"""
|
||||
tool = CreateHeatmapTool()
|
||||
|
||||
# 包含至少两个数值列
|
||||
profile_sufficient = DataProfile(
|
||||
file_path='test.csv',
|
||||
row_count=100,
|
||||
column_count=2,
|
||||
columns=[
|
||||
ColumnInfo(name='x', dtype='numeric', missing_rate=0.0, unique_count=50),
|
||||
ColumnInfo(name='y', dtype='numeric', missing_rate=0.0, unique_count=50)
|
||||
],
|
||||
inferred_type='unknown'
|
||||
)
|
||||
assert tool.is_applicable(profile_sufficient) is True
|
||||
|
||||
# 只有一个数值列
|
||||
profile_insufficient = DataProfile(
|
||||
file_path='test.csv',
|
||||
row_count=100,
|
||||
column_count=1,
|
||||
columns=[
|
||||
ColumnInfo(name='x', dtype='numeric', missing_rate=0.0, unique_count=50)
|
||||
],
|
||||
inferred_type='unknown'
|
||||
)
|
||||
assert tool.is_applicable(profile_insufficient) is False
|
||||
|
||||
|
||||
class TestVisualizationErrorHandling:
|
||||
"""测试可视化工具的错误处理。"""
|
||||
|
||||
def test_invalid_output_path(self):
|
||||
"""测试无效的输出路径。"""
|
||||
tool = CreateBarChartTool()
|
||||
df = pd.DataFrame({'cat': ['A', 'B', 'C']})
|
||||
|
||||
# 使用无效路径(只读目录等)
|
||||
# 注意:这个测试可能在某些系统上不会失败
|
||||
result = tool.execute(
|
||||
df,
|
||||
x_column='cat',
|
||||
output_path='/invalid/path/chart.png'
|
||||
)
|
||||
|
||||
# 应该返回错误或成功创建目录
|
||||
assert 'error' in result or result['success'] is True
|
||||
|
||||
def test_empty_dataframe(self):
|
||||
"""测试空 DataFrame。"""
|
||||
tool = CreateBarChartTool()
|
||||
df = pd.DataFrame()
|
||||
|
||||
result = tool.execute(df, x_column='nonexistent')
|
||||
|
||||
assert 'error' in result
|
||||
Reference in New Issue
Block a user