317 lines
9.1 KiB
Python
317 lines
9.1 KiB
Python
|
|
"""Task execution engine using ReAct pattern."""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import json
|
||
|
|
import re
|
||
|
|
import time
|
||
|
|
from typing import List, Dict, Any, Optional
|
||
|
|
from openai import OpenAI
|
||
|
|
|
||
|
|
from src.models.analysis_plan import AnalysisTask
|
||
|
|
from src.models.analysis_result import AnalysisResult
|
||
|
|
from src.tools.base import AnalysisTool
|
||
|
|
from src.data_access import DataAccessLayer
|
||
|
|
|
||
|
|
|
||
|
|
def execute_task(
    task: AnalysisTask,
    tools: List[AnalysisTool],
    data_access: DataAccessLayer,
    max_iterations: int = 10
) -> AnalysisResult:
    """
    Execute analysis task using ReAct pattern.

    ReAct loop: Thought -> Action -> Observation -> repeat

    Args:
        task: Analysis task to execute
        tools: Available analysis tools
        data_access: Data access layer for executing tools
        max_iterations: Maximum number of iterations

    Returns:
        AnalysisResult with execution results

    Requirements: FR-5.1
    """
    start_time = time.time()

    # Without an API key the ReAct loop cannot run; fall back to the
    # simple non-AI execution path.
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        return _fallback_task_execution(task, tools, data_access)

    client = OpenAI(api_key=api_key)

    # Full thought/action/observation trail, and visualization paths
    # collected from tool results.
    history = []
    visualizations = []

    try:
        for iteration in range(max_iterations):
            # Thought: AI decides next action
            thought_prompt = _build_thought_prompt(task, tools, history)

            thought_response = client.chat.completions.create(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": "You are a data analyst executing analysis tasks. Use the ReAct pattern: think, act, observe."},
                    {"role": "user", "content": thought_prompt}
                ],
                temperature=0.7,
                max_tokens=1000
            )

            thought = _parse_thought_response(thought_response.choices[0].message.content)
            history.append({"type": "thought", "content": thought})

            # Check if task is complete
            if thought.get('is_completed', False):
                break

            # Action: Execute selected tool
            tool_name = thought.get('selected_tool')
            tool_params = thought.get('tool_params', {})

            if tool_name:
                tool = _find_tool(tools, tool_name)
                if tool:
                    action_result = call_tool(tool, data_access, **tool_params)
                    history.append({
                        "type": "action",
                        "tool": tool_name,
                        "params": tool_params
                    })

                    # Observation: Record result
                    history.append({
                        "type": "observation",
                        "result": action_result
                    })

                    # Track visualizations. call_tool wraps tool output as
                    # {'success': ..., 'data': ...}, so the path must also be
                    # looked up inside the wrapped payload — the old top-level
                    # check alone could never match a wrapped result.
                    payload = action_result.get('data')
                    if 'visualization_path' in action_result:
                        visualizations.append(action_result['visualization_path'])
                    elif isinstance(payload, dict) and 'visualization_path' in payload:
                        visualizations.append(payload['visualization_path'])
                else:
                    # Record feedback for the model instead of silently
                    # looping when it hallucinates a tool name.
                    history.append({
                        "type": "observation",
                        "result": {"success": False, "error": f"Unknown tool: {tool_name}"}
                    })

        # Extract insights from history
        insights = extract_insights(history, client)

        execution_time = time.time() - start_time

        # Use the most recent observation as the task's result data; the
        # final history entry is typically a completion "thought" which has
        # no 'result' key.
        last_result = {}
        for entry in reversed(history):
            if entry.get('type') == 'observation':
                last_result = entry.get('result', {})
                break

        return AnalysisResult(
            task_id=task.id,
            task_name=task.name,
            success=True,
            data=last_result,
            visualizations=visualizations,
            insights=insights,
            execution_time=execution_time
        )

    except Exception as e:
        execution_time = time.time() - start_time
        return AnalysisResult(
            task_id=task.id,
            task_name=task.name,
            success=False,
            error=str(e),
            execution_time=execution_time
        )
|
||
|
|
|
||
|
|
|
||
|
|
def _build_thought_prompt(
|
||
|
|
task: AnalysisTask,
|
||
|
|
tools: List[AnalysisTool],
|
||
|
|
history: List[Dict[str, Any]]
|
||
|
|
) -> str:
|
||
|
|
"""Build prompt for thought step."""
|
||
|
|
tool_descriptions = "\n".join([
|
||
|
|
f"- {tool.name}: {tool.description}"
|
||
|
|
for tool in tools
|
||
|
|
])
|
||
|
|
|
||
|
|
history_str = "\n".join([
|
||
|
|
f"{i+1}. {h['type']}: {str(h.get('content', h.get('result', '')))[:200]}"
|
||
|
|
for i, h in enumerate(history[-5:]) # Last 5 steps
|
||
|
|
])
|
||
|
|
|
||
|
|
prompt = f"""Task: {task.description}
|
||
|
|
Expected Output: {task.expected_output}
|
||
|
|
|
||
|
|
Available Tools:
|
||
|
|
{tool_descriptions}
|
||
|
|
|
||
|
|
Execution History:
|
||
|
|
{history_str if history else "No history yet"}
|
||
|
|
|
||
|
|
Think about:
|
||
|
|
1. What is the current state?
|
||
|
|
2. What should I do next?
|
||
|
|
3. Which tool should I use?
|
||
|
|
4. Is the task completed?
|
||
|
|
|
||
|
|
Respond in JSON format:
|
||
|
|
{{
|
||
|
|
"reasoning": "Your reasoning",
|
||
|
|
"is_completed": false,
|
||
|
|
"selected_tool": "tool_name",
|
||
|
|
"tool_params": {{"param": "value"}}
|
||
|
|
}}
|
||
|
|
"""
|
||
|
|
|
||
|
|
return prompt
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_thought_response(response_text: str) -> Dict[str, Any]:
|
||
|
|
"""Parse thought response from AI."""
|
||
|
|
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||
|
|
if json_match:
|
||
|
|
try:
|
||
|
|
return json.loads(json_match.group())
|
||
|
|
except json.JSONDecodeError:
|
||
|
|
pass
|
||
|
|
|
||
|
|
return {
|
||
|
|
'reasoning': response_text,
|
||
|
|
'is_completed': False,
|
||
|
|
'selected_tool': None,
|
||
|
|
'tool_params': {}
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def call_tool(
    tool: AnalysisTool,
    data_access: DataAccessLayer,
    **kwargs
) -> Dict[str, Any]:
    """
    Call analysis tool and return result.

    Args:
        tool: Tool to execute
        data_access: Data access layer
        **kwargs: Tool parameters

    Returns:
        Tool execution result: {'success': True, 'data': ...} on success,
        {'success': False, 'error': ...} on failure.

    Requirements: FR-5.2
    """
    try:
        payload = data_access.execute_tool(tool, **kwargs)
    except Exception as exc:
        return {'success': False, 'error': str(exc)}
    return {'success': True, 'data': payload}
|
||
|
|
|
||
|
|
|
||
|
|
def extract_insights(
    history: List[Dict[str, Any]],
    client: Optional[OpenAI] = None
) -> List[str]:
    """
    Extract insights from execution history.

    Args:
        history: Execution history (thought/action/observation entries)
        client: OpenAI client (optional); when None, a simple heuristic
            summary of observed tool data is used instead of the AI.

    Returns:
        List of insight strings (at most 5 in the heuristic path)

    Requirements: FR-5.4
    """
    if not client:
        # Simple extraction without AI: summarize observed tool data.
        # Use .get() so a malformed entry without 'type' is skipped
        # rather than raising KeyError.
        insights = []
        for entry in history:
            if entry.get('type') == 'observation':
                result = entry.get('result', {})
                if isinstance(result, dict) and 'data' in result:
                    insights.append(f"Found data: {str(result['data'])[:100]}")
        return insights[:5]  # Limit to 5

    # AI-driven insight extraction; history is truncated to keep the
    # prompt within a reasonable size.
    history_str = json.dumps(history, indent=2, ensure_ascii=False)[:3000]

    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "Extract key insights from analysis execution history."},
                {"role": "user", "content": f"Execution history:\n{history_str}\n\nExtract 3-5 key insights as a JSON array of strings."}
            ],
            temperature=0.7,
            max_tokens=500
        )

        insights_text = response.choices[0].message.content
        json_match = re.search(r'\[.*\]', insights_text, re.DOTALL)
        if json_match:
            parsed = json.loads(json_match.group())
            # Guard against the model returning a non-list or a list of
            # non-string items.
            if isinstance(parsed, list):
                return [str(item) for item in parsed]
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; best-effort fallback below.
        pass

    return ["Analysis completed successfully"]
|
||
|
|
|
||
|
|
|
||
|
|
def _find_tool(tools: List[AnalysisTool], tool_name: str) -> Optional[AnalysisTool]:
|
||
|
|
"""Find tool by name."""
|
||
|
|
for tool in tools:
|
||
|
|
if tool.name == tool_name:
|
||
|
|
return tool
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
def _fallback_task_execution(
    task: AnalysisTask,
    tools: List[AnalysisTool],
    data_access: DataAccessLayer
) -> AnalysisResult:
    """Simple fallback execution without AI: run the first available required tool."""
    started = time.time()

    try:
        # Execute the first required tool that is actually available.
        for required_name in task.required_tools:
            matched = _find_tool(tools, required_name)
            if matched is None:
                continue
            outcome = call_tool(matched, data_access)
            return AnalysisResult(
                task_id=task.id,
                task_name=task.name,
                success=outcome.get('success', False),
                data=outcome.get('data', {}),
                insights=[f"Executed {required_name}"],
                execution_time=time.time() - started
            )

        # None of the required tools were found.
        return AnalysisResult(
            task_id=task.id,
            task_name=task.name,
            success=False,
            error="No applicable tools found",
            execution_time=time.time() - started
        )

    except Exception as exc:
        return AnalysisResult(
            task_id=task.id,
            task_name=task.name,
            success=False,
            error=str(exc),
            execution_time=time.time() - started
        )
|