"""
|
||
AI-Driven Data Analysis Framework
|
||
==================================
|
||
Generic pipeline: works with any CSV file.
|
||
AI decides everything at runtime based on data characteristics.
|
||
|
||
Flow: Load CSV → AI understands metadata → AI plans analysis → AI executes tasks → AI generates report
|
||
"""
|
||
import sys
|
||
import os
|
||
sys.path.insert(0, os.path.dirname(__file__))
|
||
|
||
from src.env_loader import load_env_with_fallback
|
||
load_env_with_fallback(['.env'])
|
||
|
||
import logging
|
||
import json
|
||
from datetime import datetime
|
||
from typing import Optional, List
|
||
from openai import OpenAI
|
||
|
||
from src.engines.ai_data_understanding import ai_understand_data_with_dal
|
||
from src.engines.requirement_understanding import understand_requirement
|
||
from src.engines.analysis_planning import plan_analysis
|
||
from src.engines.task_execution import execute_task
|
||
from src.engines.report_generation import generate_report
|
||
from src.tools.tool_manager import ToolManager
|
||
from src.tools.base import _global_registry
|
||
from src.models import DataProfile, AnalysisResult
|
||
from src.config import get_config
|
||
|
||
# Register all tools
|
||
from src.tools import register_tool
|
||
from src.tools.query_tools import (
|
||
GetColumnDistributionTool, GetValueCountsTool,
|
||
GetTimeSeriesTool, GetCorrelationTool
|
||
)
|
||
from src.tools.stats_tools import (
|
||
CalculateStatisticsTool, PerformGroupbyTool,
|
||
DetectOutliersTool, CalculateTrendTool
|
||
)
|
||
from src.tools.viz_tools import (
|
||
CreateBarChartTool, CreateLineChartTool,
|
||
CreatePieChartTool, CreateHeatmapTool
|
||
)
|
||
|
||
for tool_cls in [
|
||
GetColumnDistributionTool, GetValueCountsTool, GetTimeSeriesTool, GetCorrelationTool,
|
||
CalculateStatisticsTool, PerformGroupbyTool, DetectOutliersTool, CalculateTrendTool,
|
||
CreateBarChartTool, CreateLineChartTool, CreatePieChartTool, CreateHeatmapTool
|
||
]:
|
||
register_tool(tool_cls())
|
||
|
||
logging.basicConfig(level=logging.WARNING)
|
||
|
||
|
||
def run_analysis(
    data_file: str,
    user_requirement: Optional[str] = None,
    template_file: Optional[str] = None,
    output_dir: str = "analysis_output"
) -> bool:
    """
    Run the full AI-driven analysis pipeline.

    Args:
        data_file: Path to any CSV file
        user_requirement: Natural language requirement (optional)
        template_file: Report template path (optional)
        output_dir: Output directory

    Returns:
        True when the pipeline completed and at least one planned task
        succeeded (an empty plan also counts as success); False when every
        planned task failed.  NOTE: this used to be a hard-coded
        ``return True``, which made the caller's ``sys.exit(0 if success
        else 1)`` logic dead code.
    """
    os.makedirs(output_dir, exist_ok=True)
    config = get_config()

    print("\n" + "=" * 70)
    print("AI-Driven Data Analysis")
    print("=" * 70)
    print(f"Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Data: {data_file}")
    if template_file:
        print(f"Template: {template_file}")
    print("=" * 70)

    # ── Stage 1: AI Data Understanding ──
    print("\n[1/5] AI Understanding Data...")
    print("   (AI sees only metadata — never raw rows)")
    profile, dal = ai_understand_data_with_dal(data_file)
    print(f"   Type: {profile.inferred_type}")
    print(f"   Quality: {profile.quality_score}/100")
    print(f"   Columns: {profile.column_count}, Rows: {profile.row_count}")
    print(f"   Summary: {profile.summary[:120]}...")

    # ── Stage 2: Requirement Understanding ──
    print("\n[2/5] Understanding Requirements...")
    # Default requirement string is intentionally Chinese ("perform a
    # comprehensive analysis of the data") — downstream engines accept it.
    requirement = understand_requirement(
        user_input=user_requirement or "对数据进行全面分析",
        data_profile=profile,
        template_path=template_file
    )
    print(f"   Objectives: {len(requirement.objectives)}")
    for obj in requirement.objectives:
        print(f"     - {obj.name} (priority: {obj.priority})")

    # ── Stage 3: AI Analysis Planning ──
    print("\n[3/5] AI Planning Analysis...")
    tool_manager = ToolManager(_global_registry)
    tools = tool_manager.select_tools(profile)
    print(f"   Available tools: {len(tools)}")

    analysis_plan = plan_analysis(profile, requirement, available_tools=tools)
    print(f"   Tasks planned: {len(analysis_plan.tasks)}")
    for task in sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True):
        print(f"     [{task.priority}] {task.name}")
        if task.required_tools:
            print(f"         tools: {', '.join(task.required_tools)}")

    # ── Stage 4: AI Task Execution ──
    print("\n[4/5] AI Executing Tasks...")
    # Reuse DAL from Stage 1 — no need to load data again
    results: List[AnalysisResult] = []

    # Highest-priority tasks run first.
    sorted_tasks = sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True)
    for i, task in enumerate(sorted_tasks, 1):
        print(f"\n   Task {i}/{len(sorted_tasks)}: {task.name}")
        result = execute_task(task, tools, dal)
        results.append(result)

        if result.success:
            print(f"   ✓ Done ({result.execution_time:.1f}s), insights: {len(result.insights)}")
            for insight in result.insights[:2]:
                print(f"     - {insight[:80]}")
        else:
            print(f"   ✗ Failed: {result.error}")

    successful = sum(1 for r in results if r.success)
    print(f"\n   Results: {successful}/{len(results)} tasks succeeded")

    # ── Stage 5: Report Generation ──
    print("\n[5/5] Generating Report...")
    report_path = os.path.join(output_dir, "analysis_report.md")

    if template_file and os.path.exists(template_file):
        report = _generate_template_report(profile, results, template_file, config)
    else:
        report = generate_report(results, requirement, profile)

    # Save report
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report)

    print(f"   Report saved: {report_path}")
    print(f"   Report length: {len(report)} chars")

    # ── Done ──
    print("\n" + "=" * 70)
    print("Analysis Complete!")
    print(f"End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Output: {report_path}")
    print("=" * 70)

    # Report failure via the return value when every planned task failed,
    # so the CLI exit status actually reflects the run outcome.
    return successful > 0 if results else True
|
||
|
||
|
||
def _generate_template_report(
    profile: DataProfile,
    results: List[AnalysisResult],
    template_path: str,
    config
) -> str:
    """Use AI to fill a template with data from task execution results.

    Args:
        profile: Data profile produced by Stage 1 (metadata only).
        results: Per-task execution results; only successful ones are used.
        template_path: Path to the Markdown report template.
        config: Application config providing ``config.llm`` credentials/model.

    Returns:
        The filled report as Markdown, prefixed with an HTML comment header
        recording generation metadata.
    """
    client = OpenAI(api_key=config.llm.api_key, base_url=config.llm.base_url)

    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    # Collect all data from task results (truncated per task to bound
    # prompt size; default=str stringifies non-JSON-native values).
    all_data = {}
    all_insights = []
    for r in results:
        if r.success:
            all_data[r.task_name] = {
                'data': json.dumps(r.data, ensure_ascii=False, default=str)[:1000],
                'insights': r.insights
            }
            all_insights.extend(r.insights)

    data_json = json.dumps(all_data, ensure_ascii=False, indent=1)
    if len(data_json) > 12000:
        data_json = data_json[:12000] + "\n... (truncated)"

    # Prompt is intentionally in Chinese; it instructs the model to fill the
    # template using only the real analysis numbers (no fabrication) and to
    # keep Markdown formatting.
    prompt = f"""你是一位专业的数据分析师。请根据以下分析结果,按照模板格式生成完整的报告。

## 数据概况
- 类型: {profile.inferred_type}
- 行数: {profile.row_count}, 列数: {profile.column_count}
- 质量: {profile.quality_score}/100
- 摘要: {profile.summary[:500]}

## 关键字段
{json.dumps(profile.key_fields, ensure_ascii=False, indent=2)}

## 分析结果
{data_json}

## 关键洞察
{chr(10).join(f"- {i}" for i in all_insights[:20])}

## 报告模板
```markdown
{template}
```

## 要求
1. 用实际数据填充模板中所有占位符
2. 根据数据中的字段,智能映射到模板分类
3. 所有数字必须来自分析结果,不要编造
4. 如果某个模板分类在数据中没有对应,标注"本期无数据"
5. 保持Markdown格式
"""

    print("   AI filling template with analysis results...")
    response = client.chat.completions.create(
        model=config.llm.model,
        messages=[
            {"role": "system", "content": "你是数据分析专家。根据分析结果填充报告模板,所有数字必须来自真实数据。输出纯Markdown。"},
            {"role": "user", "content": prompt}
        ],
        temperature=0.3,
        max_tokens=4000
    )

    # message.content is Optional in the OpenAI SDK; guard against None so
    # the `header + report` concatenation below cannot raise TypeError.
    report = response.choices[0].message.content or ""

    header = f"""<!--
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Data: {profile.file_path} ({profile.row_count} rows x {profile.column_count} cols)
Quality: {profile.quality_score}/100
Template: {template_path}
AI never accessed raw data rows — only aggregated tool results
-->

"""
    return header + report
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import argparse
|
||
parser = argparse.ArgumentParser(description="AI-Driven Data Analysis")
|
||
parser.add_argument("--data", default="cleaned_data.csv", help="CSV file path")
|
||
parser.add_argument("--requirement", default=None, help="Analysis requirement (natural language)")
|
||
parser.add_argument("--template", default=None, help="Report template path")
|
||
parser.add_argument("--output", default="analysis_output", help="Output directory")
|
||
args = parser.parse_args()
|
||
|
||
success = run_analysis(
|
||
data_file=args.data,
|
||
user_requirement=args.requirement,
|
||
template_file=args.template,
|
||
output_dir=args.output
|
||
)
|
||
sys.exit(0 if success else 1)
|