"""Unit tests for the error-handling helpers exposed by ``src.error_handling``."""

# NOTE: os, tempfile, Path, patch, MagicMock and ToolExecutionError are not
# referenced by the tests below; the imports are kept so nothing that relies
# on them elsewhere breaks.
import os
import tempfile
import time
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import pandas as pd
import pytest

from src.error_handling import (
    AICallError,
    DataLoadError,
    ToolExecutionError,
    call_llm_with_fallback,
    execute_task_with_recovery,
    execute_tool_safely,
    load_data_with_retry,
    validate_tool_params,
    validate_tool_result,
)


def _sample_frame():
    """Return a fresh three-row, single-column DataFrame used as tool input."""
    return pd.DataFrame({'col1': [1, 2, 3]})


def _make_tool(required=(), properties=None, execute=None):
    """Build a mock tool named ``test_tool`` with the given parameter schema.

    Args:
        required: Names placed in the schema's ``required`` list.
        properties: Mapping placed in the schema's ``properties`` entry
            (empty when omitted).
        execute: Optional callable installed as ``tool.execute``; when omitted
            the attribute is left to ``Mock``'s auto-created child.
    """
    tool = Mock()
    # ``name`` is a Mock constructor argument, so it has to be assigned after
    # the mock is created rather than passed as a keyword.
    tool.name = "test_tool"
    tool.parameters = {
        'required': list(required),
        'properties': dict(properties or {}),
    }
    if execute is not None:
        tool.execute = execute
    return tool


def _make_task(task_id, name="Test Task", dependencies=(), status=None):
    """Build a mock task with the given id, name and dependency ids.

    ``status`` is only assigned when provided, so tests can observe the value
    written by the code under test.
    """
    task = Mock()
    task.id = task_id
    task.name = name
    task.dependencies = list(dependencies)
    if status is not None:
        task.status = status
    return task


def _make_plan(*tasks):
    """Build a mock plan whose ``tasks`` attribute lists the given tasks."""
    plan = Mock()
    plan.tasks = list(tasks)
    return plan


class TestLoadDataWithRetry:
    """Tests for ``load_data_with_retry``."""

    def test_load_valid_csv(self, tmp_path):
        """A well-formed CSV loads with every row and column intact."""
        csv_file = tmp_path / "test.csv"
        frame = pd.DataFrame({
            'col1': [1, 2, 3],
            'col2': ['a', 'b', 'c'],
        })
        frame.to_csv(csv_file, index=False)

        result = load_data_with_retry(str(csv_file))

        assert len(result) == 3
        assert len(result.columns) == 2
        assert list(result.columns) == ['col1', 'col2']

    def test_load_gbk_encoded_file(self, tmp_path):
        """A GBK-encoded CSV with Chinese headers loads correctly."""
        csv_file = tmp_path / "test_gbk.csv"
        frame = pd.DataFrame({
            '列1': [1, 2, 3],
            '列2': ['中文', '测试', '数据'],
        })
        frame.to_csv(csv_file, index=False, encoding='gbk')

        result = load_data_with_retry(str(csv_file))

        assert len(result) == 3
        assert '列1' in result.columns
        assert '列2' in result.columns

    def test_load_file_not_exists(self, tmp_path):
        """A missing file raises ``DataLoadError``."""
        # A path inside the per-test temp directory cannot exist by accident,
        # unlike a bare relative name resolved against the working directory.
        missing = tmp_path / "nonexistent.csv"

        with pytest.raises(DataLoadError, match="文件不存在"):
            load_data_with_retry(str(missing))

    def test_load_empty_file(self, tmp_path):
        """A zero-byte file raises ``DataLoadError``."""
        csv_file = tmp_path / "empty.csv"
        csv_file.touch()

        with pytest.raises(DataLoadError, match="文件为空"):
            load_data_with_retry(str(csv_file))

    def test_load_large_file_sampling(self, tmp_path):
        """A file larger than ``sample_size`` comes back with ``sample_size`` rows."""
        total_rows = 1500000
        sample_size = 1000000
        csv_file = tmp_path / "large.csv"
        # Build exactly the rows that are written instead of allocating a
        # larger frame and discarding part of it.
        frame = pd.DataFrame({
            'col1': range(total_rows),
            'col2': range(total_rows),
        })
        frame.to_csv(csv_file, index=False)

        result = load_data_with_retry(str(csv_file), sample_size=sample_size)

        assert len(result) == sample_size

    def test_load_different_separator(self, tmp_path):
        """A semicolon-separated file is parsed into two rows and two columns."""
        csv_file = tmp_path / "semicolon.csv"
        # Explicit encoding keeps the fixture independent of the locale default.
        csv_file.write_text("col1;col2\n1;a\n2;b\n", encoding="utf-8")

        result = load_data_with_retry(str(csv_file))

        assert len(result) == 2
        assert len(result.columns) == 2


class TestCallLLMWithFallback:
    """Tests for ``call_llm_with_fallback``."""

    def test_successful_call(self):
        """A call that succeeds first time is made once and its value returned."""
        mock_func = Mock(return_value={'result': 'success'})

        result = call_llm_with_fallback(mock_func, prompt="test")

        assert result == {'result': 'success'}
        assert mock_func.call_count == 1

    def test_retry_on_timeout(self):
        """Two timeouts followed by a success take three calls in total."""
        mock_func = Mock(side_effect=[
            TimeoutError("timeout"),
            TimeoutError("timeout"),
            {'result': 'success'},
        ])

        result = call_llm_with_fallback(mock_func, max_retries=3, prompt="test")

        assert result == {'result': 'success'}
        assert mock_func.call_count == 3

    def test_exponential_backoff(self):
        """One failure before success delays the call by at least one second."""
        mock_func = Mock(side_effect=[
            Exception("error"),
            {'result': 'success'},
        ])

        # perf_counter is unaffected by system clock adjustments, so the
        # elapsed measurement cannot go backwards.
        started = time.perf_counter()
        result = call_llm_with_fallback(mock_func, max_retries=3, prompt="test")
        elapsed = time.perf_counter() - started

        # The first retry should wait at least 1 second (2 ** 0).
        assert elapsed >= 1.0
        assert result == {'result': 'success'}

    def test_fallback_on_failure(self):
        """When every attempt fails, the fallback is called once and its value returned."""
        mock_func = Mock(side_effect=Exception("error"))
        fallback_func = Mock(return_value={'result': 'fallback'})

        result = call_llm_with_fallback(
            mock_func,
            fallback_func=fallback_func,
            max_retries=2,
            prompt="test",
        )

        assert result == {'result': 'fallback'}
        assert mock_func.call_count == 2
        assert fallback_func.call_count == 1

    def test_no_fallback_raises_error(self):
        """Without a fallback, exhausting the retries raises ``AICallError``."""
        mock_func = Mock(side_effect=Exception("error"))

        with pytest.raises(AICallError, match="AI 调用失败"):
            call_llm_with_fallback(mock_func, max_retries=2, prompt="test")

    def test_fallback_also_fails(self):
        """A failing fallback after failing retries raises ``AICallError``."""
        mock_func = Mock(side_effect=Exception("error"))
        fallback_func = Mock(side_effect=Exception("fallback error"))

        with pytest.raises(AICallError, match="AI 调用和降级策略都失败"):
            call_llm_with_fallback(
                mock_func,
                fallback_func=fallback_func,
                max_retries=2,
                prompt="test",
            )


class TestExecuteToolSafely:
    """Tests for ``execute_tool_safely``."""

    def test_successful_execution(self):
        """A tool that runs cleanly yields a success record wrapping its output."""
        mock_tool = _make_tool(execute=Mock(return_value={'data': 'result'}))

        result = execute_tool_safely(mock_tool, _sample_frame())

        assert result['success'] is True
        assert result['data'] == {'data': 'result'}
        assert result['tool'] == 'test_tool'

    def test_missing_execute_method(self):
        """A tool without an ``execute`` attribute is reported as a failure."""
        # spec=[] makes the mock raise AttributeError for any attribute that
        # has not been explicitly assigned, so ``execute`` is really absent.
        mock_tool = Mock(spec=[])
        mock_tool.name = "bad_tool"

        result = execute_tool_safely(mock_tool, _sample_frame())

        assert result['success'] is False
        assert 'execute 方法' in result['error']

    def test_parameter_validation_failure(self):
        """Omitting a required parameter is reported as a validation failure."""
        mock_tool = _make_tool(
            required=['column'],
            properties={'column': {'type': 'string'}},
            execute=Mock(return_value={'data': 'result'}),
        )

        # The required ``column`` argument is deliberately not supplied.
        result = execute_tool_safely(mock_tool, _sample_frame())

        assert result['success'] is False
        assert '参数验证失败' in result['error']

    def test_empty_data(self):
        """An empty DataFrame is reported as a failure."""
        mock_tool = _make_tool()

        result = execute_tool_safely(mock_tool, pd.DataFrame())

        assert result['success'] is False
        assert '数据为空' in result['error']

    def test_execution_exception(self):
        """An exception raised by the tool is captured in the error message."""
        mock_tool = _make_tool(execute=Mock(side_effect=Exception("execution error")))

        result = execute_tool_safely(mock_tool, _sample_frame())

        assert result['success'] is False
        assert 'execution error' in result['error']


class TestValidateToolParams:
    """Tests for ``validate_tool_params``."""

    def test_valid_params(self):
        """A string value for a required string parameter is accepted."""
        mock_tool = _make_tool(
            required=['column'],
            properties={'column': {'type': 'string'}},
        )

        result = validate_tool_params(mock_tool, {'column': 'col1'})

        assert result['valid'] is True

    def test_missing_required_param(self):
        """A missing required parameter is rejected."""
        mock_tool = _make_tool(required=['column'])

        result = validate_tool_params(mock_tool, {})

        assert result['valid'] is False
        assert '缺少必需参数' in result['error']

    def test_wrong_param_type(self):
        """An integer passed for a string parameter is rejected."""
        mock_tool = _make_tool(properties={'column': {'type': 'string'}})

        result = validate_tool_params(mock_tool, {'column': 123})

        assert result['valid'] is False
        assert '应为字符串类型' in result['error']


class TestValidateToolResult:
    """Tests for ``validate_tool_result``."""

    def test_valid_result(self):
        """A dict result is accepted."""
        result = validate_tool_result({'data': 'test'})

        assert result['valid'] is True

    def test_none_result(self):
        """A ``None`` result is rejected with an error mentioning ``None``."""
        result = validate_tool_result(None)

        assert result['valid'] is False
        assert 'None' in result['error']

    def test_wrong_type_result(self):
        """A plain string result is rejected as a type error."""
        result = validate_tool_result("string result")

        assert result['valid'] is False
        assert '类型错误' in result['error']


class TestExecuteTaskWithRecovery:
    """Tests for ``execute_task_with_recovery``."""

    def test_successful_execution(self):
        """A task with no dependencies runs once and is marked completed."""
        mock_task = _make_task("task1")
        mock_plan = _make_plan(mock_task)
        mock_execute = Mock(return_value=Mock(success=True))

        execute_task_with_recovery(mock_task, mock_plan, mock_execute)

        assert mock_task.status == 'completed'
        assert mock_execute.call_count == 1

    def test_skip_on_missing_dependency(self):
        """A task whose dependency is not in the plan is skipped, not run."""
        mock_task = _make_task("task2", dependencies=["task1"])
        mock_plan = _make_plan(mock_task)
        mock_execute = Mock()

        execute_task_with_recovery(mock_task, mock_plan, mock_execute)

        assert mock_task.status == 'skipped'
        assert mock_execute.call_count == 0

    def test_skip_on_failed_dependency(self):
        """A task whose dependency has failed is skipped, not run."""
        mock_dep_task = _make_task("task1", status='failed')
        mock_task = _make_task("task2", dependencies=["task1"])
        mock_plan = _make_plan(mock_dep_task, mock_task)
        mock_execute = Mock()

        execute_task_with_recovery(mock_task, mock_plan, mock_execute)

        assert mock_task.status == 'skipped'
        assert mock_execute.call_count == 0

    def test_mark_failed_on_exception(self):
        """An exception from the executor marks the task as failed."""
        mock_task = _make_task("task1")
        mock_plan = _make_plan(mock_task)
        mock_execute = Mock(side_effect=Exception("execution error"))

        execute_task_with_recovery(mock_task, mock_plan, mock_execute)

        assert mock_task.status == 'failed'

    def test_continue_on_task_failure(self):
        """One task failing does not prevent an independent task from completing."""
        mock_task1 = _make_task("task1", name="Task 1")
        mock_task2 = _make_task("task2", name="Task 2")
        mock_plan = _make_plan(mock_task1, mock_task2)

        # The first task's executor raises.
        failing_execute = Mock(side_effect=Exception("error"))
        execute_task_with_recovery(mock_task1, mock_plan, failing_execute)
        assert mock_task1.status == 'failed'

        # The second task has no dependencies, so it should still run.
        passing_execute = Mock(return_value=Mock(success=True))
        execute_task_with_recovery(mock_task2, mock_plan, passing_execute)
        assert mock_task2.status == 'completed'