# Files / vibe_data_ana/tests/test_viz_tools.py
# (358 lines, 11 KiB, Python)

"""可视化工具的单元测试。"""
import pytest
import pandas as pd
import numpy as np
import os
from pathlib import Path
import tempfile
import shutil
from src.tools.viz_tools import (
CreateBarChartTool,
CreateLineChartTool,
CreatePieChartTool,
CreateHeatmapTool
)
from src.models import DataProfile, ColumnInfo
@pytest.fixture
def temp_output_dir():
    """Provide a fresh scratch directory (as a ``str`` path) for chart output.

    The directory is created with ``tempfile.mkdtemp`` before the test runs
    and removed recursively after the test finishes.
    """
    scratch_dir = tempfile.mkdtemp()
    yield scratch_dir
    # Teardown: best-effort removal; errors during deletion are ignored so a
    # stray locked file never turns into a test failure.
    shutil.rmtree(scratch_dir, ignore_errors=True)
class TestCreateBarChartTool:
    """Tests for ``CreateBarChartTool``."""

    def test_basic_functionality(self, temp_output_dir):
        """Only ``x_column`` given: the chart is saved and the result echoes its settings."""
        chart_tool = CreateBarChartTool()
        frame = pd.DataFrame(
            {
                'category': ['A', 'B', 'C', 'A', 'B', 'A'],
                'value': [10, 20, 30, 15, 25, 20],
            }
        )
        target = os.path.join(temp_output_dir, 'bar_chart.png')

        outcome = chart_tool.execute(frame, x_column='category', output_path=target)

        assert outcome['success'] is True
        assert os.path.exists(target)
        assert outcome['chart_type'] == 'bar'
        assert outcome['x_column'] == 'category'

    def test_with_y_column(self, temp_output_dir):
        """An explicit ``y_column`` is accepted and reported back."""
        chart_tool = CreateBarChartTool()
        frame = pd.DataFrame({'category': ['A', 'B', 'C'], 'value': [100, 200, 300]})
        target = os.path.join(temp_output_dir, 'bar_chart_y.png')

        outcome = chart_tool.execute(
            frame,
            x_column='category',
            y_column='value',
            output_path=target,
        )

        assert outcome['success'] is True
        assert os.path.exists(target)
        assert outcome['y_column'] == 'value'

    def test_top_n_limit(self, temp_output_dir):
        """``top_n`` caps the number of plotted bars (50 categories -> 10)."""
        chart_tool = CreateBarChartTool()
        frame = pd.DataFrame(
            {
                'category': [f'cat_{i}' for i in range(50)],
                'value': range(50),
            }
        )
        target = os.path.join(temp_output_dir, 'bar_chart_top.png')

        outcome = chart_tool.execute(
            frame,
            x_column='category',
            y_column='value',
            top_n=10,
            output_path=target,
        )

        assert outcome['success'] is True
        assert outcome['data_points'] == 10

    def test_nonexistent_column(self):
        """Referencing a column that is not in the frame yields an error entry."""
        chart_tool = CreateBarChartTool()
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        outcome = chart_tool.execute(frame, x_column='nonexistent')

        assert 'error' in outcome
class TestCreateLineChartTool:
    """Tests for ``CreateLineChartTool``."""

    def test_basic_functionality(self, temp_output_dir):
        """A simple numeric x/y series is plotted and saved."""
        line_tool = CreateLineChartTool()
        xs = range(10)
        frame = pd.DataFrame({'x': xs, 'y': [2 * v for v in xs]})
        target = os.path.join(temp_output_dir, 'line_chart.png')

        outcome = line_tool.execute(
            frame,
            x_column='x',
            y_column='y',
            output_path=target,
        )

        assert outcome['success'] is True
        assert os.path.exists(target)
        assert outcome['chart_type'] == 'line'

    def test_with_datetime(self, temp_output_dir):
        """A datetime x axis (20 consecutive days) is supported."""
        line_tool = CreateLineChartTool()
        stamps = pd.date_range('2020-01-01', periods=20, freq='D')
        frame = pd.DataFrame({'date': stamps, 'value': range(20)})
        target = os.path.join(temp_output_dir, 'line_chart_time.png')

        outcome = line_tool.execute(
            frame,
            x_column='date',
            y_column='value',
            output_path=target,
        )

        assert outcome['success'] is True
        assert os.path.exists(target)

    def test_large_dataset_sampling(self, temp_output_dir):
        """A 2000-row series is reported with at most 1000 plotted points."""
        line_tool = CreateLineChartTool()
        n_rows = 2000
        frame = pd.DataFrame({'x': range(n_rows), 'y': range(n_rows)})
        target = os.path.join(temp_output_dir, 'line_chart_large.png')

        outcome = line_tool.execute(
            frame,
            x_column='x',
            y_column='y',
            output_path=target,
        )

        assert outcome['success'] is True
        # The tool is expected to down-sample long series to roughly 1000 points.
        assert outcome['data_points'] <= 1000
class TestCreatePieChartTool:
    """Tests for ``CreatePieChartTool``."""

    def test_basic_functionality(self, temp_output_dir):
        """Three distinct values produce a three-slice pie chart."""
        pie_tool = CreatePieChartTool()
        frame = pd.DataFrame({'category': ['A', 'B', 'C', 'A', 'B', 'A']})
        target = os.path.join(temp_output_dir, 'pie_chart.png')

        outcome = pie_tool.execute(frame, column='category', output_path=target)

        assert outcome['success'] is True
        assert os.path.exists(target)
        assert outcome['chart_type'] == 'pie'
        assert outcome['categories'] == 3

    def test_top_n_with_others(self, temp_output_dir):
        """With ``top_n=5`` the remaining categories are folded into one extra slice."""
        pie_tool = CreatePieChartTool()
        labels = [f'cat_{i}' for i in range(20)]
        frame = pd.DataFrame({'category': labels * 5})
        target = os.path.join(temp_output_dir, 'pie_chart_top.png')

        outcome = pie_tool.execute(
            frame,
            column='category',
            top_n=5,
            output_path=target,
        )

        assert outcome['success'] is True
        # five kept categories plus one aggregated "others" slice
        assert outcome['categories'] == 6
class TestCreateHeatmapTool:
    """Tests for ``CreateHeatmapTool``."""

    def test_basic_functionality(self, temp_output_dir):
        """Without ``columns``, all three numeric columns are used."""
        heat_tool = CreateHeatmapTool()
        base = range(10)
        frame = pd.DataFrame(
            {
                'x': base,
                'y': [2 * v for v in base],
                'z': [3 * v for v in base],
            }
        )
        target = os.path.join(temp_output_dir, 'heatmap.png')

        outcome = heat_tool.execute(frame, output_path=target)

        assert outcome['success'] is True
        assert os.path.exists(target)
        assert outcome['chart_type'] == 'heatmap'
        assert len(outcome['columns']) == 3

    def test_with_specific_columns(self, temp_output_dir):
        """An explicit ``columns`` list restricts which columns are included."""
        heat_tool = CreateHeatmapTool()
        # a: 0..9, b: 10..19, c: 20..29, d: 30..39
        frame = pd.DataFrame(
            {name: range(10 * k, 10 * k + 10) for k, name in enumerate('abcd')}
        )
        target = os.path.join(temp_output_dir, 'heatmap_cols.png')
        selected = ['a', 'b', 'c']

        outcome = heat_tool.execute(frame, columns=selected, output_path=target)

        assert outcome['success'] is True
        assert len(outcome['columns']) == 3
        assert 'd' not in outcome['columns']

    def test_insufficient_columns(self):
        """A single column is not enough for a heatmap and yields an error entry."""
        heat_tool = CreateHeatmapTool()
        frame = pd.DataFrame({'x': range(10)})

        outcome = heat_tool.execute(frame)

        assert 'error' in outcome
class TestVisualizationToolsApplicability:
    """Tests for each chart tool's ``is_applicable`` decision."""

    @staticmethod
    def _column(name, dtype, unique_count):
        """Build a ``ColumnInfo`` with no missing values."""
        return ColumnInfo(
            name=name,
            dtype=dtype,
            missing_rate=0.0,
            unique_count=unique_count,
        )

    @staticmethod
    def _profile(*columns):
        """Build a 100-row ``DataProfile`` for ``test.csv`` from the given columns."""
        return DataProfile(
            file_path='test.csv',
            row_count=100,
            column_count=len(columns),
            columns=list(columns),
            inferred_type='unknown',
        )

    def test_bar_chart_applicability(self):
        """A single categorical column is enough for a bar chart."""
        bar_tool = CreateBarChartTool()
        profile = self._profile(self._column('cat', 'categorical', 5))

        assert bar_tool.is_applicable(profile) is True

    def test_line_chart_applicability(self):
        """A line chart needs a numeric column; a text-only profile is rejected."""
        line_tool = CreateLineChartTool()

        numeric_profile = self._profile(self._column('value', 'numeric', 50))
        assert line_tool.is_applicable(numeric_profile) is True

        text_profile = self._profile(self._column('text', 'text', 50))
        assert line_tool.is_applicable(text_profile) is False

    def test_heatmap_applicability(self):
        """A heatmap needs at least two numeric columns."""
        heat_tool = CreateHeatmapTool()

        two_numeric = self._profile(
            self._column('x', 'numeric', 50),
            self._column('y', 'numeric', 50),
        )
        assert heat_tool.is_applicable(two_numeric) is True

        one_numeric = self._profile(self._column('x', 'numeric', 50))
        assert heat_tool.is_applicable(one_numeric) is False
class TestVisualizationErrorHandling:
    """Error-handling tests for the visualization tools."""

    def test_invalid_output_path(self, temp_output_dir):
        """An output path that cannot be written must produce an error result.

        The parent "directory" of the target is actually a regular file, so
        neither creating the parent directory nor writing the chart can
        succeed, regardless of platform or privileges (a hard-coded absolute
        path such as ``/invalid/...`` is writable when running as root).
        Building the path inside the per-test temp directory also guarantees
        that nothing is ever written outside of it.
        """
        tool = CreateBarChartTool()
        df = pd.DataFrame({'cat': ['A', 'B', 'C']})

        blocker = os.path.join(temp_output_dir, 'not_a_directory')
        with open(blocker, 'w', encoding='utf-8') as fh:
            fh.write('placeholder')
        output_path = os.path.join(blocker, 'chart.png')

        result = tool.execute(df, x_column='cat', output_path=output_path)

        assert 'error' in result
        assert not os.path.exists(output_path)

    def test_empty_dataframe(self):
        """An empty DataFrame (no such column) yields an error entry."""
        tool = CreateBarChartTool()
        df = pd.DataFrame()
        result = tool.execute(df, x_column='nonexistent')
        assert 'error' in result