二次重构,加入预设模板
This commit is contained in:
261
run_analysis_en.py
Normal file
261
run_analysis_en.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
AI-Driven Data Analysis Framework
|
||||
==================================
|
||||
Generic pipeline: works with any CSV file.
|
||||
AI decides everything at runtime based on data characteristics.
|
||||
|
||||
Flow: Load CSV → AI understands metadata → AI plans analysis → AI executes tasks → AI generates report
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from src.env_loader import load_env_with_fallback
|
||||
load_env_with_fallback(['.env'])
|
||||
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from openai import OpenAI
|
||||
|
||||
from src.engines.ai_data_understanding import ai_understand_data_with_dal
|
||||
from src.engines.requirement_understanding import understand_requirement
|
||||
from src.engines.analysis_planning import plan_analysis
|
||||
from src.engines.task_execution import execute_task
|
||||
from src.engines.report_generation import generate_report
|
||||
from src.tools.tool_manager import ToolManager
|
||||
from src.tools.base import _global_registry
|
||||
from src.models import DataProfile, AnalysisResult
|
||||
from src.config import get_config
|
||||
|
||||
# Register all tools
|
||||
from src.tools import register_tool
|
||||
from src.tools.query_tools import (
|
||||
GetColumnDistributionTool, GetValueCountsTool,
|
||||
GetTimeSeriesTool, GetCorrelationTool
|
||||
)
|
||||
from src.tools.stats_tools import (
|
||||
CalculateStatisticsTool, PerformGroupbyTool,
|
||||
DetectOutliersTool, CalculateTrendTool
|
||||
)
|
||||
from src.tools.viz_tools import (
|
||||
CreateBarChartTool, CreateLineChartTool,
|
||||
CreatePieChartTool, CreateHeatmapTool
|
||||
)
|
||||
|
||||
for tool_cls in [
|
||||
GetColumnDistributionTool, GetValueCountsTool, GetTimeSeriesTool, GetCorrelationTool,
|
||||
CalculateStatisticsTool, PerformGroupbyTool, DetectOutliersTool, CalculateTrendTool,
|
||||
CreateBarChartTool, CreateLineChartTool, CreatePieChartTool, CreateHeatmapTool
|
||||
]:
|
||||
register_tool(tool_cls())
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
|
||||
def run_analysis(
    data_file: str,
    user_requirement: Optional[str] = None,
    template_file: Optional[str] = None,
    output_dir: str = "analysis_output"
) -> bool:
    """
    Run the full AI-driven analysis pipeline.

    Stages:
        1. AI data understanding — AI sees metadata only, never raw rows.
        2. Requirement understanding — natural language + optional template.
        3. AI analysis planning — produces a prioritized task list.
        4. AI task execution — runs tasks against the registered tools,
           reusing the data-access layer (DAL) obtained in Stage 1.
        5. Report generation — template-driven if a template exists,
           otherwise free-form; saved to ``<output_dir>/analysis_report.md``.

    Args:
        data_file: Path to any CSV file
        user_requirement: Natural language requirement (optional)
        template_file: Report template path (optional)
        output_dir: Output directory

    Returns:
        True (currently unconditional — see NOTE before the return).
    """
    os.makedirs(output_dir, exist_ok=True)
    config = get_config()

    print("\n" + "=" * 70)
    print("AI-Driven Data Analysis")
    print("=" * 70)
    print(f"Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Data: {data_file}")
    if template_file:
        print(f"Template: {template_file}")
    print("=" * 70)

    # ── Stage 1: AI Data Understanding ──
    print("\n[1/5] AI Understanding Data...")
    print("      (AI sees only metadata — never raw rows)")
    # Returns both the data profile and a DAL handle; the DAL is reused in
    # Stage 4 so the CSV is loaded only once.
    profile, dal = ai_understand_data_with_dal(data_file)
    print(f"  Type: {profile.inferred_type}")
    print(f"  Quality: {profile.quality_score}/100")
    print(f"  Columns: {profile.column_count}, Rows: {profile.row_count}")
    print(f"  Summary: {profile.summary[:120]}...")

    # ── Stage 2: Requirement Understanding ──
    print("\n[2/5] Understanding Requirements...")
    # Falls back to a generic "analyze the data comprehensively" requirement
    # (Chinese literal is part of the runtime contract — do not translate).
    requirement = understand_requirement(
        user_input=user_requirement or "对数据进行全面分析",
        data_profile=profile,
        template_path=template_file
    )
    print(f"  Objectives: {len(requirement.objectives)}")
    for obj in requirement.objectives:
        print(f"    - {obj.name} (priority: {obj.priority})")

    # ── Stage 3: AI Analysis Planning ──
    print("\n[3/5] AI Planning Analysis...")
    tool_manager = ToolManager(_global_registry)
    tools = tool_manager.select_tools(profile)
    print(f"  Available tools: {len(tools)}")

    analysis_plan = plan_analysis(profile, requirement, available_tools=tools)
    print(f"  Tasks planned: {len(analysis_plan.tasks)}")
    # Display tasks highest-priority first (same order they will run in).
    for task in sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True):
        print(f"    [{task.priority}] {task.name}")
        if task.required_tools:
            print(f"      tools: {', '.join(task.required_tools)}")

    # ── Stage 4: AI Task Execution ──
    print("\n[4/5] AI Executing Tasks...")
    # Reuse DAL from Stage 1 — no need to load data again
    results: List[AnalysisResult] = []

    # Execute in descending priority order; failures are recorded but do not
    # stop the remaining tasks.
    sorted_tasks = sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True)
    for i, task in enumerate(sorted_tasks, 1):
        print(f"\n  Task {i}/{len(sorted_tasks)}: {task.name}")
        result = execute_task(task, tools, dal)
        results.append(result)

        if result.success:
            print(f"    ✓ Done ({result.execution_time:.1f}s), insights: {len(result.insights)}")
            # Show at most two insights per task to keep console output short.
            for insight in result.insights[:2]:
                print(f"      - {insight[:80]}")
        else:
            print(f"    ✗ Failed: {result.error}")

    successful = sum(1 for r in results if r.success)
    print(f"\n  Results: {successful}/{len(results)} tasks succeeded")

    # ── Stage 5: Report Generation ──
    print("\n[5/5] Generating Report...")
    report_path = os.path.join(output_dir, "analysis_report.md")

    # Template-driven report when a template file exists; otherwise free-form.
    if template_file and os.path.exists(template_file):
        report = _generate_template_report(profile, results, template_file, config)
    else:
        report = generate_report(results, requirement, profile)

    # Save report
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report)

    print(f"  Report saved: {report_path}")
    print(f"  Report length: {len(report)} chars")

    # ── Done ──
    print("\n" + "=" * 70)
    print("Analysis Complete!")
    print(f"End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Output: {report_path}")
    print("=" * 70)

    # NOTE(review): returns True even when some or all tasks failed, so the
    # CLI's `sys.exit(0 if success else 1)` can never exit non-zero — confirm
    # whether task failures should be reflected in the return value.
    return True
|
||||
|
||||
|
||||
def _generate_template_report(
    profile: DataProfile,
    results: List[AnalysisResult],
    template_path: str,
    config
) -> str:
    """Use AI to fill a template with data from task execution results.

    Args:
        profile: Data profile produced in Stage 1 (metadata only — the LLM
            never sees raw rows, only aggregated tool results).
        results: Task execution results; only successful ones are included.
        template_path: Path to the Markdown report template to be filled.
        config: Application config object; ``config.llm`` supplies the API
            key, base URL and model name.

    Returns:
        The AI-filled Markdown report, prefixed with an HTML comment header
        recording provenance (timestamp, data file, quality, template path).
    """
    client = OpenAI(api_key=config.llm.api_key, base_url=config.llm.base_url)

    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    # Collect all data from successful task results. Per-task data is capped
    # at 1000 chars so one verbose task cannot crowd out the others.
    all_data = {}
    all_insights = []
    for r in results:
        if r.success:
            all_data[r.task_name] = {
                # default=str: deliberately stringify non-JSON types
                # (timestamps, numpy scalars) rather than fail.
                'data': json.dumps(r.data, ensure_ascii=False, default=str)[:1000],
                'insights': r.insights
            }
            all_insights.extend(r.insights)

    # Overall cap on the combined results payload to bound prompt size.
    data_json = json.dumps(all_data, ensure_ascii=False, indent=1)
    if len(data_json) > 12000:
        data_json = data_json[:12000] + "\n... (truncated)"

    prompt = f"""你是一位专业的数据分析师。请根据以下分析结果,按照模板格式生成完整的报告。

## 数据概况
- 类型: {profile.inferred_type}
- 行数: {profile.row_count}, 列数: {profile.column_count}
- 质量: {profile.quality_score}/100
- 摘要: {profile.summary[:500]}

## 关键字段
{json.dumps(profile.key_fields, ensure_ascii=False, indent=2)}

## 分析结果
{data_json}

## 关键洞察
{chr(10).join(f"- {i}" for i in all_insights[:20])}

## 报告模板
```markdown
{template}
```

## 要求
1. 用实际数据填充模板中所有占位符
2. 根据数据中的字段,智能映射到模板分类
3. 所有数字必须来自分析结果,不要编造
4. 如果某个模板分类在数据中没有对应,标注"本期无数据"
5. 保持Markdown格式
"""

    print("  AI filling template with analysis results...")
    response = client.chat.completions.create(
        model=config.llm.model,
        messages=[
            {"role": "system", "content": "你是数据分析专家。根据分析结果填充报告模板,所有数字必须来自真实数据。输出纯Markdown。"},
            {"role": "user", "content": prompt}
        ],
        temperature=0.3,
        max_tokens=4000
    )

    # The SDK types message.content as Optional[str]; guard against a None
    # completion so the `header + report` concatenation cannot raise TypeError.
    report = response.choices[0].message.content or ""

    header = f"""<!--
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Data: {profile.file_path} ({profile.row_count} rows x {profile.column_count} cols)
Quality: {profile.quality_score}/100
Template: {template_path}
AI never accessed raw data rows — only aggregated tool results
-->

"""
    return header + report
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="AI-Driven Data Analysis")
|
||||
parser.add_argument("--data", default="cleaned_data.csv", help="CSV file path")
|
||||
parser.add_argument("--requirement", default=None, help="Analysis requirement (natural language)")
|
||||
parser.add_argument("--template", default=None, help="Report template path")
|
||||
parser.add_argument("--output", default="analysis_output", help="Output directory")
|
||||
args = parser.parse_args()
|
||||
|
||||
success = run_analysis(
|
||||
data_file=args.data,
|
||||
user_requirement=args.requirement,
|
||||
template_file=args.template,
|
||||
output_dir=args.output
|
||||
)
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user