"""可视化工具的单元测试。""" import pytest import pandas as pd import numpy as np import os from pathlib import Path import tempfile import shutil from src.tools.viz_tools import ( CreateBarChartTool, CreateLineChartTool, CreatePieChartTool, CreateHeatmapTool ) from src.models import DataProfile, ColumnInfo @pytest.fixture def temp_output_dir(): """创建临时输出目录。""" temp_dir = tempfile.mkdtemp() yield temp_dir # 清理 shutil.rmtree(temp_dir, ignore_errors=True) class TestCreateBarChartTool: """测试柱状图工具。""" def test_basic_functionality(self, temp_output_dir): """测试基本功能。""" tool = CreateBarChartTool() df = pd.DataFrame({ 'category': ['A', 'B', 'C', 'A', 'B', 'A'], 'value': [10, 20, 30, 15, 25, 20] }) output_path = os.path.join(temp_output_dir, 'bar_chart.png') result = tool.execute(df, x_column='category', output_path=output_path) assert result['success'] is True assert os.path.exists(output_path) assert result['chart_type'] == 'bar' assert result['x_column'] == 'category' def test_with_y_column(self, temp_output_dir): """测试指定Y列。""" tool = CreateBarChartTool() df = pd.DataFrame({ 'category': ['A', 'B', 'C'], 'value': [100, 200, 300] }) output_path = os.path.join(temp_output_dir, 'bar_chart_y.png') result = tool.execute( df, x_column='category', y_column='value', output_path=output_path ) assert result['success'] is True assert os.path.exists(output_path) assert result['y_column'] == 'value' def test_top_n_limit(self, temp_output_dir): """测试 top_n 限制。""" tool = CreateBarChartTool() df = pd.DataFrame({ 'category': [f'cat_{i}' for i in range(50)], 'value': range(50) }) output_path = os.path.join(temp_output_dir, 'bar_chart_top.png') result = tool.execute( df, x_column='category', y_column='value', top_n=10, output_path=output_path ) assert result['success'] is True assert result['data_points'] == 10 def test_nonexistent_column(self): """测试不存在的列。""" tool = CreateBarChartTool() df = pd.DataFrame({'col1': [1, 2, 3]}) result = tool.execute(df, x_column='nonexistent') assert 'error' in result 
class TestCreateLineChartTool:
    """Tests for the line chart tool."""

    def test_basic_functionality(self, temp_output_dir):
        """A numeric x/y pair produces a PNG and reports the chart type."""
        tool = CreateLineChartTool()
        df = pd.DataFrame({
            'x': range(10),
            'y': [i * 2 for i in range(10)],
        })
        output_path = os.path.join(temp_output_dir, 'line_chart.png')

        result = tool.execute(
            df,
            x_column='x',
            y_column='y',
            output_path=output_path,
        )

        assert result['success'] is True
        assert os.path.exists(output_path)
        assert result['chart_type'] == 'line'

    def test_with_datetime(self, temp_output_dir):
        """A datetime x axis (time series) is plotted successfully."""
        tool = CreateLineChartTool()
        dates = pd.date_range('2020-01-01', periods=20, freq='D')
        df = pd.DataFrame({
            'date': dates,
            'value': range(20),
        })
        output_path = os.path.join(temp_output_dir, 'line_chart_time.png')

        result = tool.execute(
            df,
            x_column='date',
            y_column='value',
            output_path=output_path,
        )

        assert result['success'] is True
        assert os.path.exists(output_path)

    def test_large_dataset_sampling(self, temp_output_dir):
        """Large inputs are down-sampled before plotting."""
        tool = CreateLineChartTool()
        df = pd.DataFrame({
            'x': range(2000),
            'y': range(2000),
        })
        output_path = os.path.join(temp_output_dir, 'line_chart_large.png')

        result = tool.execute(
            df,
            x_column='x',
            y_column='y',
            output_path=output_path,
        )

        assert result['success'] is True
        # 2000 rows should be sampled down to at most ~1000 points.
        assert result['data_points'] <= 1000


class TestCreatePieChartTool:
    """Tests for the pie chart tool."""

    def test_basic_functionality(self, temp_output_dir):
        """Category frequencies are plotted, one slice per distinct value."""
        tool = CreatePieChartTool()
        df = pd.DataFrame({
            'category': ['A', 'B', 'C', 'A', 'B', 'A'],
        })
        output_path = os.path.join(temp_output_dir, 'pie_chart.png')

        result = tool.execute(
            df,
            column='category',
            output_path=output_path,
        )

        assert result['success'] is True
        assert os.path.exists(output_path)
        assert result['chart_type'] == 'pie'
        assert result['categories'] == 3

    def test_top_n_with_others(self, temp_output_dir):
        """``top_n`` keeps the N largest slices and folds the rest into one."""
        tool = CreatePieChartTool()
        df = pd.DataFrame({
            'category': [f'cat_{i}' for i in range(20)] * 5,
        })
        output_path = os.path.join(temp_output_dir, 'pie_chart_top.png')

        result = tool.execute(
            df,
            column='category',
            top_n=5,
            output_path=output_path,
        )

        assert result['success'] is True
        # 5 kept categories + 1 aggregated "others" slice.
        assert result['categories'] == 6


class TestCreateHeatmapTool:
    """Tests for the heatmap tool."""

    def test_basic_functionality(self, temp_output_dir):
        """Without ``columns``, every numeric column is included."""
        tool = CreateHeatmapTool()
        df = pd.DataFrame({
            'x': range(10),
            'y': [i * 2 for i in range(10)],
            'z': [i * 3 for i in range(10)],
        })
        output_path = os.path.join(temp_output_dir, 'heatmap.png')

        result = tool.execute(df, output_path=output_path)

        assert result['success'] is True
        assert os.path.exists(output_path)
        assert result['chart_type'] == 'heatmap'
        assert len(result['columns']) == 3

    def test_with_specific_columns(self, temp_output_dir):
        """An explicit ``columns`` list restricts the heatmap to those columns."""
        tool = CreateHeatmapTool()
        df = pd.DataFrame({
            'a': range(10),
            'b': range(10, 20),
            'c': range(20, 30),
            'd': range(30, 40),
        })
        output_path = os.path.join(temp_output_dir, 'heatmap_cols.png')

        result = tool.execute(
            df,
            columns=['a', 'b', 'c'],
            output_path=output_path,
        )

        assert result['success'] is True
        assert len(result['columns']) == 3
        assert 'd' not in result['columns']

    def test_insufficient_columns(self):
        """A single numeric column is not enough for a heatmap."""
        tool = CreateHeatmapTool()
        df = pd.DataFrame({'x': range(10)})

        result = tool.execute(df)

        assert 'error' in result


def _make_profile(columns):
    """Build a minimal ``DataProfile`` for applicability checks.

    ``row_count`` is fixed at 100 and ``column_count`` is derived from
    ``columns`` so the two can never disagree.
    """
    return DataProfile(
        file_path='test.csv',
        row_count=100,
        column_count=len(columns),
        columns=list(columns),
        inferred_type='unknown',
    )


class TestVisualizationToolsApplicability:
    """Tests for each tool's ``is_applicable`` decision."""

    def test_bar_chart_applicability(self):
        """A categorical column makes the bar chart applicable."""
        tool = CreateBarChartTool()
        profile = _make_profile([
            ColumnInfo(name='cat', dtype='categorical', missing_rate=0.0, unique_count=5),
        ])

        assert tool.is_applicable(profile) is True

    def test_line_chart_applicability(self):
        """The line chart requires at least one numeric column."""
        tool = CreateLineChartTool()

        # Has a numeric column -> applicable.
        profile_numeric = _make_profile([
            ColumnInfo(name='value', dtype='numeric', missing_rate=0.0, unique_count=50),
        ])
        assert tool.is_applicable(profile_numeric) is True

        # Text only -> not applicable.
        profile_text = _make_profile([
            ColumnInfo(name='text', dtype='text', missing_rate=0.0, unique_count=50),
        ])
        assert tool.is_applicable(profile_text) is False

    def test_heatmap_applicability(self):
        """The heatmap requires at least two numeric columns."""
        tool = CreateHeatmapTool()

        # Two numeric columns -> applicable.
        profile_sufficient = _make_profile([
            ColumnInfo(name='x', dtype='numeric', missing_rate=0.0, unique_count=50),
            ColumnInfo(name='y', dtype='numeric', missing_rate=0.0, unique_count=50),
        ])
        assert tool.is_applicable(profile_sufficient) is True

        # Only one numeric column -> not applicable.
        profile_insufficient = _make_profile([
            ColumnInfo(name='x', dtype='numeric', missing_rate=0.0, unique_count=50),
        ])
        assert tool.is_applicable(profile_insufficient) is False


class TestVisualizationErrorHandling:
    """Tests for error handling in the visualization tools."""

    def test_invalid_output_path(self, temp_output_dir):
        """An unwritable output path is reported as an error, not a crash.

        The parent component of the target path is a regular file, so no
        directory can be created there on any OS and regardless of privileges.
        A hard-coded absolute path such as '/invalid/path' is not reliable:
        it can actually be created when tests run as root, which both makes the
        outcome platform-dependent and leaves files outside the temp directory.
        """
        tool = CreateBarChartTool()
        df = pd.DataFrame({'cat': ['A', 'B', 'C']})

        blocker = os.path.join(temp_output_dir, 'not_a_directory')
        with open(blocker, 'w', encoding='utf-8'):
            pass  # create an empty regular file to sit where a directory is needed
        output_path = os.path.join(blocker, 'chart.png')

        result = tool.execute(
            df,
            x_column='cat',
            output_path=output_path,
        )

        assert 'error' in result
        assert not os.path.exists(output_path)

    def test_empty_dataframe(self):
        """An empty DataFrame (no columns at all) is rejected with an error."""
        tool = CreateBarChartTool()
        df = pd.DataFrame()

        result = tool.execute(df, x_column='nonexistent')

        assert 'error' in result