427 lines
13 KiB
Python
427 lines
13 KiB
Python
"""单元测试:错误处理机制。"""
|
||
|
||
import pytest
|
||
import pandas as pd
|
||
import time
|
||
from pathlib import Path
|
||
from unittest.mock import Mock, patch, MagicMock
|
||
import tempfile
|
||
import os
|
||
|
||
from src.error_handling import (
|
||
load_data_with_retry,
|
||
call_llm_with_fallback,
|
||
execute_tool_safely,
|
||
execute_task_with_recovery,
|
||
validate_tool_params,
|
||
validate_tool_result,
|
||
DataLoadError,
|
||
AICallError,
|
||
ToolExecutionError
|
||
)
|
||
|
||
|
||
class TestLoadDataWithRetry:
    """Tests for data-loading error handling (``load_data_with_retry``)."""

    def test_load_valid_csv(self, tmp_path):
        """A well-formed CSV loads with every row and column intact."""
        csv_file = tmp_path / "test.csv"
        df = pd.DataFrame({
            'col1': [1, 2, 3],
            'col2': ['a', 'b', 'c'],
        })
        df.to_csv(csv_file, index=False)

        result = load_data_with_retry(str(csv_file))

        assert len(result) == 3
        assert len(result.columns) == 2
        assert list(result.columns) == ['col1', 'col2']

    def test_load_gbk_encoded_file(self, tmp_path):
        """A GBK-encoded CSV with Chinese headers and values is decoded correctly."""
        csv_file = tmp_path / "test_gbk.csv"
        df = pd.DataFrame({
            '列1': [1, 2, 3],
            '列2': ['中文', '测试', '数据'],
        })
        df.to_csv(csv_file, index=False, encoding='gbk')

        result = load_data_with_retry(str(csv_file))

        assert len(result) == 3
        assert '列1' in result.columns
        assert '列2' in result.columns
        # Check the cell values too, not just the header, so a partially
        # mis-decoded file cannot pass.
        assert list(result['列2']) == ['中文', '测试', '数据']

    def test_load_file_not_exists(self, tmp_path):
        """A missing path raises DataLoadError saying the file does not exist."""
        # Build the path under tmp_path so the outcome cannot depend on a stray
        # "nonexistent.csv" in the current working directory.
        missing_file = tmp_path / "nonexistent.csv"

        with pytest.raises(DataLoadError, match="文件不存在"):
            load_data_with_retry(str(missing_file))

    def test_load_empty_file(self, tmp_path):
        """A zero-byte file raises DataLoadError saying the file is empty."""
        csv_file = tmp_path / "empty.csv"
        csv_file.touch()

        with pytest.raises(DataLoadError, match="文件为空"):
            load_data_with_retry(str(csv_file))

    def test_load_large_file_sampling(self, tmp_path):
        """A file with more rows than ``sample_size`` is sampled down to that size."""
        csv_file = tmp_path / "large.csv"
        n_rows = 1_500_000
        sample_size = 1_000_000
        # Build exactly the rows that get written; materialising a larger frame
        # and slicing it only costs extra time and memory.
        df = pd.DataFrame({
            'col1': range(n_rows),
            'col2': range(n_rows),
        })
        df.to_csv(csv_file, index=False)

        result = load_data_with_retry(str(csv_file), sample_size=sample_size)

        assert len(result) == sample_size

    def test_load_different_separator(self, tmp_path):
        """A semicolon-delimited file is split into its separate columns."""
        csv_file = tmp_path / "semicolon.csv"
        csv_file.write_text("col1;col2\n1;a\n2;b\n", encoding='utf-8')

        result = load_data_with_retry(str(csv_file))

        assert len(result) == 2
        # Two named columns (rather than a single "col1;col2" column) show
        # that the separator was detected.
        assert len(result.columns) == 2
        assert list(result.columns) == ['col1', 'col2']


class TestCallLLMWithFallback:
    """Tests for AI-call error handling (``call_llm_with_fallback``).

    NOTE(review): the back-off test expects a real delay of at least 1 s, and
    the other multi-attempt tests presumably sleep as well. If suite runtime
    matters, consider patching the sleep used by ``src.error_handling``.
    Confirm how that module imports it before patching.
    """

    def test_successful_call(self):
        """A call that succeeds first time returns its result with no retry."""
        mock_func = Mock(return_value={'result': 'success'})

        result = call_llm_with_fallback(mock_func, prompt="test")

        assert result == {'result': 'success'}
        assert mock_func.call_count == 1

    def test_retry_on_timeout(self):
        """Timeouts are retried until a call succeeds within ``max_retries``."""
        mock_func = Mock(side_effect=[
            TimeoutError("timeout"),
            TimeoutError("timeout"),
            {'result': 'success'},
        ])

        result = call_llm_with_fallback(mock_func, max_retries=3, prompt="test")

        assert result == {'result': 'success'}
        assert mock_func.call_count == 3

    def test_exponential_backoff(self):
        """A failed attempt is followed by a back-off of at least 2**0 = 1 second."""
        mock_func = Mock(side_effect=[
            Exception("error"),
            {'result': 'success'},
        ])

        # perf_counter() is monotonic and high-resolution. Unlike time.time(),
        # the measured interval cannot jump when the system clock is adjusted.
        start = time.perf_counter()
        result = call_llm_with_fallback(mock_func, max_retries=3, prompt="test")
        elapsed = time.perf_counter() - start

        assert elapsed >= 1.0
        assert result == {'result': 'success'}
        assert mock_func.call_count == 2

    def test_fallback_on_failure(self):
        """When every attempt fails, the fallback function's result is returned."""
        mock_func = Mock(side_effect=Exception("error"))
        fallback_func = Mock(return_value={'result': 'fallback'})

        result = call_llm_with_fallback(
            mock_func,
            fallback_func=fallback_func,
            max_retries=2,
            prompt="test",
        )

        assert result == {'result': 'fallback'}
        # max_retries bounds the total number of attempts.
        assert mock_func.call_count == 2
        assert fallback_func.call_count == 1

    def test_no_fallback_raises_error(self):
        """Without a fallback, exhausting all attempts raises AICallError."""
        mock_func = Mock(side_effect=Exception("error"))

        with pytest.raises(AICallError, match="AI 调用失败"):
            call_llm_with_fallback(mock_func, max_retries=2, prompt="test")

    def test_fallback_also_fails(self):
        """If the fallback raises too, AICallError reports that both failed."""
        mock_func = Mock(side_effect=Exception("error"))
        fallback_func = Mock(side_effect=Exception("fallback error"))

        with pytest.raises(AICallError, match="AI 调用和降级策略都失败"):
            call_llm_with_fallback(
                mock_func,
                fallback_func=fallback_func,
                max_retries=2,
                prompt="test",
            )


class TestExecuteToolSafely:
    """Tests for guarded tool execution (``execute_tool_safely``)."""

    @staticmethod
    def _build_tool(parameters, execute=None):
        """Return a Mock tool named "test_tool" with the given parameter schema.

        ``execute`` is attached only when supplied. Otherwise the Mock keeps
        its auto-created ``execute`` attribute.
        """
        tool = Mock()
        tool.name = "test_tool"
        tool.parameters = parameters
        if execute is not None:
            tool.execute = execute
        return tool

    def test_successful_execution(self):
        """A valid tool over non-empty data reports success and its output."""
        tool = self._build_tool(
            {'required': [], 'properties': {}},
            execute=Mock(return_value={'data': 'result'}),
        )
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        outcome = execute_tool_safely(tool, frame)

        assert outcome['success'] is True
        assert outcome['data'] == {'data': 'result'}
        assert outcome['tool'] == 'test_tool'

    def test_missing_execute_method(self):
        """A tool without an ``execute`` attribute yields a failure result."""
        # spec=[] gives a Mock with no attributes, so ``execute`` is absent.
        tool = Mock(spec=[])
        tool.name = "bad_tool"
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        outcome = execute_tool_safely(tool, frame)

        assert outcome['success'] is False
        assert 'execute 方法' in outcome['error']

    def test_parameter_validation_failure(self):
        """A missing required parameter is reported as a validation failure."""
        schema = {
            'required': ['column'],
            'properties': {'column': {'type': 'string'}},
        }
        tool = self._build_tool(schema, execute=Mock(return_value={'data': 'result'}))
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        # No params are passed, so the required "column" is missing.
        outcome = execute_tool_safely(tool, frame)

        assert outcome['success'] is False
        assert '参数验证失败' in outcome['error']

    def test_empty_data(self):
        """An empty DataFrame is rejected with an empty-data error."""
        tool = self._build_tool({'required': [], 'properties': {}})

        outcome = execute_tool_safely(tool, pd.DataFrame())

        assert outcome['success'] is False
        assert '数据为空' in outcome['error']

    def test_execution_exception(self):
        """An exception raised by ``execute`` is captured in the error message."""
        tool = self._build_tool(
            {'required': [], 'properties': {}},
            execute=Mock(side_effect=Exception("execution error")),
        )
        frame = pd.DataFrame({'col1': [1, 2, 3]})

        outcome = execute_tool_safely(tool, frame)

        assert outcome['success'] is False
        assert 'execution error' in outcome['error']


class TestValidateToolParams:
    """Tests for tool parameter validation (``validate_tool_params``)."""

    @staticmethod
    def _tool_with_schema(required, properties):
        """Return a Mock tool whose ``parameters`` holds the given schema."""
        tool = Mock()
        tool.parameters = {'required': required, 'properties': properties}
        return tool

    def test_valid_params(self):
        """Supplying the required string parameter passes validation."""
        tool = self._tool_with_schema(['column'], {'column': {'type': 'string'}})

        verdict = validate_tool_params(tool, {'column': 'col1'})

        assert verdict['valid'] is True

    def test_missing_required_param(self):
        """Omitting a required parameter fails with a missing-parameter error."""
        tool = self._tool_with_schema(['column'], {})

        verdict = validate_tool_params(tool, {})

        assert verdict['valid'] is False
        assert '缺少必需参数' in verdict['error']

    def test_wrong_param_type(self):
        """Passing an int where a string is declared fails the type check."""
        tool = self._tool_with_schema([], {'column': {'type': 'string'}})

        verdict = validate_tool_params(tool, {'column': 123})

        assert verdict['valid'] is False
        assert '应为字符串类型' in verdict['error']


class TestValidateToolResult:
    """Tests for tool result validation (``validate_tool_result``)."""

    def test_valid_result(self):
        """A dict result is accepted."""
        verdict = validate_tool_result({'data': 'test'})
        assert verdict['valid'] is True

    def test_none_result(self):
        """A ``None`` result is rejected with an error that mentions None."""
        verdict = validate_tool_result(None)
        assert verdict['valid'] is False
        assert 'None' in verdict['error']

    def test_wrong_type_result(self):
        """A plain string result is rejected with a type error."""
        verdict = validate_tool_result("string result")
        assert verdict['valid'] is False
        assert '类型错误' in verdict['error']


class TestExecuteTaskWithRecovery:
    """Tests for task execution with recovery (``execute_task_with_recovery``)."""

    @staticmethod
    def _make_task(task_id, name="Test Task", dependencies=None):
        """Return a Mock task with ``id``, ``name`` and ``dependencies`` set."""
        task = Mock()
        task.id = task_id
        task.name = name
        task.dependencies = [] if dependencies is None else list(dependencies)
        return task

    @staticmethod
    def _make_plan(*tasks):
        """Return a Mock plan whose ``tasks`` list holds the given tasks."""
        plan = Mock()
        plan.tasks = list(tasks)
        return plan

    def test_successful_execution(self):
        """A task with no dependencies runs once and is marked completed."""
        task = self._make_task("task1")
        plan = self._make_plan(task)
        execute = Mock(return_value=Mock(success=True))

        execute_task_with_recovery(task, plan, execute)

        assert task.status == 'completed'
        assert execute.call_count == 1

    def test_skip_on_missing_dependency(self):
        """A task whose dependency is absent from the plan is skipped, not run."""
        task = self._make_task("task2", dependencies=["task1"])
        plan = self._make_plan(task)
        execute = Mock()

        execute_task_with_recovery(task, plan, execute)

        assert task.status == 'skipped'
        assert execute.call_count == 0

    def test_skip_on_failed_dependency(self):
        """A task whose dependency has failed is skipped, not run."""
        dep_task = Mock()
        dep_task.id = "task1"
        dep_task.status = 'failed'

        task = self._make_task("task2", dependencies=["task1"])
        plan = self._make_plan(dep_task, task)
        execute = Mock()

        execute_task_with_recovery(task, plan, execute)

        assert task.status == 'skipped'
        assert execute.call_count == 0

    def test_mark_failed_on_exception(self):
        """An exception from the executor marks the task failed instead of propagating."""
        task = self._make_task("task1")
        plan = self._make_plan(task)
        execute = Mock(side_effect=Exception("execution error"))

        execute_task_with_recovery(task, plan, execute)

        assert task.status == 'failed'

    def test_continue_on_task_failure(self):
        """One task failing does not prevent an independent task from completing."""
        task1 = self._make_task("task1", name="Task 1")
        task2 = self._make_task("task2", name="Task 2")
        plan = self._make_plan(task1, task2)

        # The first task fails.
        failing_execute = Mock(side_effect=Exception("error"))
        execute_task_with_recovery(task1, plan, failing_execute)

        assert task1.status == 'failed'

        # The second, independent task should still run to completion.
        succeeding_execute = Mock(return_value=Mock(success=True))
        execute_task_with_recovery(task2, plan, succeeding_execute)

        assert task2.status == 'completed'