"""Unit tests for task execution engine."""
|
|
|
|
import pytest
|
|
import pandas as pd
|
|
|
|
from src.engines.task_execution import (
|
|
execute_task,
|
|
call_tool,
|
|
extract_insights,
|
|
_fallback_task_execution,
|
|
_find_tool
|
|
)
|
|
from src.models.analysis_plan import AnalysisTask
|
|
from src.data_access import DataAccessLayer
|
|
from src.tools.stats_tools import CalculateStatisticsTool
|
|
from src.tools.query_tools import GetValueCountsTool
|
|
|
|
|
|
@pytest.fixture
def sample_data():
    """Create sample data for testing."""
    records = {
        'value': list(range(1, 11)),
        'category': ['A', 'B'] * 5,
        'score': [v * 10 for v in range(1, 11)],
    }
    return pd.DataFrame(records)
|
|
|
|
|
|
@pytest.fixture
def sample_tools():
    """Create sample tools for testing."""
    # One statistics tool and one value-counts tool cover both tool
    # lookups exercised by the tests below.
    return [CalculateStatisticsTool(), GetValueCountsTool()]
|
|
|
|
|
|
def test_fallback_execution_success(sample_data, sample_tools):
    """Test successful fallback execution."""
    stats_task = AnalysisTask(
        id="task_1",
        name="Calculate Statistics",
        description="Calculate basic statistics",
        priority=5,
        required_tools=['calculate_statistics'],
    )
    access = DataAccessLayer(sample_data)

    outcome = _fallback_task_execution(stats_task, sample_tools, access)

    # Identity fields echo the task; success is a bool either way and
    # timing is always non-negative.
    assert outcome.task_id == "task_1"
    assert outcome.task_name == "Calculate Statistics"
    assert isinstance(outcome.success, bool)
    assert outcome.execution_time >= 0
|
|
|
|
|
|
def test_fallback_execution_no_tools(sample_data):
    """Test fallback execution with no tools."""
    missing_tool_task = AnalysisTask(
        id="task_1",
        name="Test Task",
        description="Test",
        priority=3,
        required_tools=['nonexistent_tool'],
    )
    access = DataAccessLayer(sample_data)

    # Empty tool list: the fallback has nothing to run with.
    outcome = _fallback_task_execution(missing_tool_task, [], access)

    assert not outcome.success
    assert outcome.error is not None
|
|
|
|
|
|
def test_call_tool_success(sample_data, sample_tools):
    """Test successful tool calling."""
    stats_tool = sample_tools[0]  # CalculateStatisticsTool
    access = DataAccessLayer(sample_data)

    payload = call_tool(stats_tool, access, column='value')

    # call_tool wraps every outcome in a dict with a 'success' flag.
    assert isinstance(payload, dict)
    assert 'success' in payload
|
|
|
|
|
|
def test_call_tool_with_invalid_params(sample_data, sample_tools):
    """Test tool calling with invalid parameters.

    A bad column name must not raise; call_tool should return its usual
    result dict so callers can inspect the failure.
    """
    tool = sample_tools[0]
    data_access = DataAccessLayer(sample_data)

    result = call_tool(tool, data_access, column='nonexistent_column')

    # Should handle error gracefully: same dict contract as the success
    # path (see test_call_tool_success), including the 'success' flag.
    # Previously this test asserted nothing about graceful handling.
    assert isinstance(result, dict)
    assert 'success' in result
|
|
|
|
|
|
def test_extract_insights_simple():
    """Test simple insight extraction."""
    # Minimal thought/action/observation trace of a single tool run.
    trace = [
        {'type': 'thought', 'content': 'Starting analysis'},
        {'type': 'action', 'tool': 'calculate_statistics', 'params': {}},
        {'type': 'observation', 'result': {'data': {'mean': 5.5, 'std': 2.87}}},
    ]

    extracted = extract_insights(trace, client=None)

    # With client=None the local extraction path still yields insights.
    assert isinstance(extracted, list)
    assert len(extracted) > 0
|
|
|
|
|
|
def test_extract_insights_empty_history():
    """Test insight extraction with empty history."""
    # An empty trace must still produce a list (possibly empty), not fail.
    assert isinstance(extract_insights([], client=None), list)
|
|
|
|
|
|
def test_find_tool_exists(sample_tools):
    """Test finding an existing tool."""
    found = _find_tool(sample_tools, 'calculate_statistics')

    # Lookup is by tool name, so the match must carry that exact name.
    assert found is not None
    assert found.name == 'calculate_statistics'
|
|
|
|
|
|
def test_find_tool_not_exists(sample_tools):
    """Test finding a non-existent tool."""
    # Unknown names resolve to None rather than raising.
    assert _find_tool(sample_tools, 'nonexistent_tool') is None
|
|
|
|
|
|
def test_execution_result_structure(sample_data, sample_tools):
    """Test that execution result has correct structure."""
    task = AnalysisTask(
        id="task_1",
        name="Test Task",
        description="Test",
        priority=3,
        required_tools=['calculate_statistics'],
    )
    access = DataAccessLayer(sample_data)

    outcome = _fallback_task_execution(task, sample_tools, access)

    # Every field of the execution-result contract must be present.
    expected_fields = (
        'task_id',
        'task_name',
        'success',
        'data',
        'visualizations',
        'insights',
        'error',
        'execution_time',
    )
    for field in expected_fields:
        assert hasattr(outcome, field)
|
|
|
|
|
|
def test_execution_with_multiple_tools(sample_data, sample_tools):
    """Test execution with multiple required tools."""
    multi_task = AnalysisTask(
        id="task_1",
        name="Multi-tool Task",
        description="Use multiple tools",
        priority=3,
        required_tools=['calculate_statistics', 'get_value_counts'],
    )
    access = DataAccessLayer(sample_data)

    outcome = _fallback_task_execution(multi_task, sample_tools, access)

    # Should execute first available tool and still return a result object.
    assert outcome is not None
|
|
|
|
|
|
def test_execution_time_tracking(sample_data, sample_tools):
    """Test that execution time is tracked."""
    timed_task = AnalysisTask(
        id="task_1",
        name="Test Task",
        description="Test",
        priority=3,
        required_tools=['calculate_statistics'],
    )
    access = DataAccessLayer(sample_data)

    outcome = _fallback_task_execution(timed_task, sample_tools, access)

    # Non-negative, and well under 10s for this tiny fixture (sanity bound).
    assert 0 <= outcome.execution_time < 10
|
|
|
|
|
|
def test_execution_with_empty_data():
    """Test execution with empty data."""
    task = AnalysisTask(
        id="task_1",
        name="Test Task",
        description="Test",
        priority=3,
        required_tools=['calculate_statistics'],
    )
    # Zero-row, zero-column frame: the degenerate input case.
    access = DataAccessLayer(pd.DataFrame())

    outcome = _fallback_task_execution(task, [CalculateStatisticsTool()], access)

    # Should handle gracefully — a result object comes back either way.
    assert outcome is not None
|