Files
vibe_data_ana/run_analysis_en.py
2026-03-09 10:37:35 +08:00

307 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI-Driven Data Analysis Framework
==================================
Generic pipeline: works with any CSV file.
AI decides everything at runtime based on data characteristics.
Flow: Load CSV → AI understands metadata → AI plans analysis → AI executes tasks → AI generates report
"""
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from src.env_loader import load_env_with_fallback
load_env_with_fallback(['.env'])
import logging
import json
from datetime import datetime
from typing import Optional, List
from openai import OpenAI
from src.engines.ai_data_understanding import ai_understand_data_with_dal
from src.engines.requirement_understanding import understand_requirement
from src.engines.analysis_planning import plan_analysis
from src.engines.task_execution import execute_task
from src.engines.report_generation import generate_report, _convert_chart_paths_in_report
from src.tools.tool_manager import ToolManager
from src.tools.base import _global_registry
from src.models import DataProfile, AnalysisResult
from src.config import get_config
# Register all tools
from src.tools import register_tool
from src.tools.query_tools import (
GetColumnDistributionTool, GetValueCountsTool,
GetTimeSeriesTool, GetCorrelationTool
)
from src.tools.stats_tools import (
CalculateStatisticsTool, PerformGroupbyTool,
DetectOutliersTool, CalculateTrendTool
)
from src.tools.viz_tools import (
CreateBarChartTool, CreateLineChartTool,
CreatePieChartTool, CreateHeatmapTool
)
# Register every built-in query / stats / visualization tool with the global
# registry so the planner and executor can discover them at runtime.
_ALL_TOOL_CLASSES = (
    GetColumnDistributionTool, GetValueCountsTool, GetTimeSeriesTool, GetCorrelationTool,
    CalculateStatisticsTool, PerformGroupbyTool, DetectOutliersTool, CalculateTrendTool,
    CreateBarChartTool, CreateLineChartTool, CreatePieChartTool, CreateHeatmapTool,
)
for _cls in _ALL_TOOL_CLASSES:
    register_tool(_cls())

# Keep library/debug chatter out of the console; the pipeline prints its own progress.
logging.basicConfig(level=logging.WARNING)
def run_analysis(
    data_file: str,
    user_requirement: Optional[str] = None,
    template_file: Optional[str] = None,
    output_dir: str = "analysis_output"
) -> bool:
    """
    Run the full AI-driven analysis pipeline.

    Each run creates a timestamped subdirectory under output_dir:
        output_dir/run_20260309_143025/
        ├── analysis_report.md
        └── charts/
            ├── bar_chart.png
            └── ...

    Args:
        data_file: Path to any CSV file.
        user_requirement: Natural language requirement (optional).
        template_file: Report template path (optional).
        output_dir: Base output directory.

    Returns:
        True when the report was produced and at least one analysis task
        succeeded (or no tasks were planned); False when every task failed.
    """
    # Each run gets its own timestamped subdirectory so outputs never collide.
    run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = os.path.join(output_dir, f"run_{run_timestamp}")
    os.makedirs(run_dir, exist_ok=True)
    config = get_config()

    print("\n" + "=" * 70)
    print("AI-Driven Data Analysis")
    print("=" * 70)
    print(f"Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Data: {data_file}")
    print(f"Output: {run_dir}")
    if template_file:
        print(f"Template: {template_file}")
    print("=" * 70)

    # ── Stage 1: AI Data Understanding ──
    print("\n[1/5] AI Understanding Data...")
    print(" (AI sees only metadata — never raw rows)")
    profile, dal = ai_understand_data_with_dal(data_file)
    print(f" Type: {profile.inferred_type}")
    print(f" Quality: {profile.quality_score}/100")
    print(f" Columns: {profile.column_count}, Rows: {profile.row_count}")
    print(f" Summary: {profile.summary[:120]}...")

    # ── Stage 2: Requirement Understanding ──
    print("\n[2/5] Understanding Requirements...")
    requirement = understand_requirement(
        user_input=user_requirement or "对数据进行全面分析",
        data_profile=profile,
        template_path=template_file
    )
    print(f" Objectives: {len(requirement.objectives)}")
    for obj in requirement.objectives:
        print(f" - {obj.name} (priority: {obj.priority})")

    # ── Stage 3: AI Analysis Planning ──
    print("\n[3/5] AI Planning Analysis...")
    tool_manager = ToolManager(_global_registry)
    tools = tool_manager.select_tools(profile)
    print(f" Available tools: {len(tools)}")
    analysis_plan = plan_analysis(profile, requirement, available_tools=tools)
    print(f" Tasks planned: {len(analysis_plan.tasks)}")
    for task in sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True):
        print(f" [{task.priority}] {task.name}")
        if task.required_tools:
            print(f" tools: {', '.join(task.required_tools)}")

    # ── Stage 4: AI Task Execution ──
    print("\n[4/5] AI Executing Tasks...")
    # Reuse the DAL from Stage 1 — no need to load the data again.
    dal.set_output_dir(run_dir)
    results: List[AnalysisResult] = []
    sorted_tasks = sorted(analysis_plan.tasks, key=lambda t: t.priority, reverse=True)
    for i, task in enumerate(sorted_tasks, 1):
        print(f"\n Task {i}/{len(sorted_tasks)}: {task.name}")
        result = execute_task(task, tools, dal)
        results.append(result)
        if result.success:
            print(f" ✓ Done ({result.execution_time:.1f}s), insights: {len(result.insights)}")
            for insight in result.insights[:2]:
                print(f" - {insight[:80]}")
        else:
            print(f" ✗ Failed: {result.error}")
    successful = sum(1 for r in results if r.success)
    print(f"\n Results: {successful}/{len(results)} tasks succeeded")

    # ── Stage 5: Report Generation ──
    print("\n[5/5] Generating Report...")
    report_path = os.path.join(run_dir, "analysis_report.md")
    if template_file and os.path.exists(template_file):
        report = _generate_template_report(profile, results, template_file, config, run_dir)
    else:
        report = generate_report(results, requirement, profile, output_path=run_dir)
    # Convert absolute chart paths to relative (./charts/xxx.png) before saving.
    report = _convert_chart_paths_in_report(report, run_dir)
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f" Report saved: {report_path}")
    print(f" Report length: {len(report)} chars")

    # ── Done ──
    print("\n" + "=" * 70)
    print("Analysis Complete!")
    print(f"End: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Output: {run_dir}")
    print("=" * 70)
    # BUG FIX: this previously always returned True, so the CLI exited 0 even
    # when every analysis task failed. Report failure when tasks were planned
    # and none of them succeeded.
    return successful > 0 or not results
def _generate_template_report(
    profile: DataProfile,
    results: List[AnalysisResult],
    template_path: str,
    config,
    run_dir: str = ""
) -> str:
    """Use AI to fill a template with data from task execution results."""
    client = OpenAI(api_key=config.llm.api_key, base_url=config.llm.base_url)
    with open(template_path, 'r', encoding='utf-8') as f:
        template = f.read()

    # Fold each successful task's output into one prompt-sized summary;
    # per-task data is capped at 1000 chars to keep the prompt bounded.
    collected = {}
    insight_pool = []
    for outcome in results:
        if not outcome.success:
            continue
        collected[outcome.task_name] = {
            'data': json.dumps(outcome.data, ensure_ascii=False, default=str)[:1000],
            'insights': outcome.insights,
        }
        insight_pool.extend(outcome.insights)

    payload = json.dumps(collected, ensure_ascii=False, indent=1)
    if len(payload) > 12000:
        payload = payload[:12000] + "\n... (truncated)"

    # Pre-render multi-line fragments so the f-string below stays simple.
    insights_block = chr(10).join(f"- {i}" for i in insight_pool[:20])
    charts_block = _collect_chart_paths(results, run_dir)

    prompt = f"""你是一位专业的数据分析师。请根据以下分析结果,按照模板格式生成完整的报告。
## 数据概况
- 类型: {profile.inferred_type}
- 行数: {profile.row_count}, 列数: {profile.column_count}
- 质量: {profile.quality_score}/100
- 摘要: {profile.summary[:500]}
## 关键字段
{json.dumps(profile.key_fields, ensure_ascii=False, indent=2)}
## 分析结果
{payload}
## 关键洞察
{insights_block}
## 报告模板
```markdown
{template}
```
## 图表文件
以下是分析过程中生成的图表文件,请在报告适当位置嵌入:
{charts_block}
## 要求
1. 用实际数据填充模板中所有占位符
2. 根据数据中的字段,智能映射到模板分类
3. 所有数字必须来自分析结果,不要编造
4. 如果某个模板分类在数据中没有对应,标注"本期无数据"
5. 保持Markdown格式
6. 在报告中嵌入图表,使用 ![描述](图表路径) 格式,让报告图文结合
"""

    print(" AI filling template with analysis results...")
    completion = client.chat.completions.create(
        model=config.llm.model,
        messages=[
            {"role": "system", "content": "你是数据分析专家。根据分析结果填充报告模板所有数字必须来自真实数据。输出纯Markdown。"},
            {"role": "user", "content": prompt},
        ],
        temperature=0.3,
        max_tokens=4000,
    )
    body = completion.choices[0].message.content

    # Prepend a provenance comment so the generated Markdown is self-describing.
    preamble = f"""<!--
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Data: {profile.file_path} ({profile.row_count} rows x {profile.column_count} cols)
Quality: {profile.quality_score}/100
Template: {template_path}
AI never accessed raw data rows — only aggregated tool results
-->
"""
    return preamble + body
def _collect_chart_paths(results: List[AnalysisResult], run_dir: str = "") -> str:
"""Collect all chart paths from task results for embedding in reports.
Returns paths relative to run_dir (e.g. ./charts/bar_chart.png)."""
paths = []
for r in results:
if not r.success:
continue
# From visualizations list
for viz in (r.visualizations or []):
if viz and viz not in paths:
paths.append(viz)
# From data dict (chart_path in tool results)
if isinstance(r.data, dict):
for key, val in r.data.items():
if isinstance(val, dict) and val.get('chart_path'):
cp = val['chart_path']
if cp not in paths:
paths.append(cp)
if not paths:
return "(无图表)"
# Convert to relative paths
from src.engines.report_generation import _to_relative_chart_path
rel_paths = [_to_relative_chart_path(p, run_dir) for p in paths]
return "\n".join(f"- {p}" for p in rel_paths)
if __name__ == "__main__":
    import argparse

    # CLI entry point: every option maps 1:1 onto a run_analysis() argument.
    cli = argparse.ArgumentParser(description="AI-Driven Data Analysis")
    cli.add_argument("--data", default="cleaned_data.csv", help="CSV file path")
    cli.add_argument("--requirement", default=None, help="Analysis requirement (natural language)")
    cli.add_argument("--template", default=None, help="Report template path")
    cli.add_argument("--output", default="analysis_output", help="Output directory")
    opts = cli.parse_args()

    ok = run_analysis(
        data_file=opts.data,
        user_requirement=opts.requirement,
        template_file=opts.template,
        output_dir=opts.output
    )
    # Exit code mirrors the pipeline's boolean result.
    sys.exit(0 if ok else 1)