Files

642 lines
24 KiB
Python
Raw Permalink Normal View History

2026-01-09 16:52:45 +08:00
import sys
import os
import threading
import glob
import uuid
import json
2026-01-31 20:27:17 +08:00
from datetime import datetime
2026-01-09 16:52:45 +08:00
from typing import Optional, Dict, List
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
# Add parent directory to path to import agent modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data_analysis_agent import DataAnalysisAgent
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
2026-01-31 20:27:17 +08:00
from config.llm_config import LLMConfig
from utils.create_session_dir import create_session_output_dir
2026-01-09 16:52:45 +08:00
app = FastAPI(title="IOV Data Analysis Agent")

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let any
# site make credentialed requests; revisit if authentication is ever added.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# --- Session Management ---
class SessionData:
    """Mutable in-memory state of one analysis session.

    Groups the run flag, output artifact paths, gallery results, the agent kept
    for follow-up requests, progress counters and history metadata.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.is_running = False

        # Output artifacts; populated once an analysis has produced them.
        self.output_dir: Optional[str] = None
        self.generated_report: Optional[str] = None
        self.log_file: Optional[str] = None
        self.reusable_script: Optional[str] = None  # path of the reusable analysis script
        self.analysis_results: List[Dict] = []  # entries backing the figure gallery

        # Agent instance retained so follow-up requests can reuse it.
        self.agent: Optional[DataAnalysisAgent] = None

        # Progress tracking.
        self.current_round: int = 0
        self.max_rounds: int = 20
        self.progress_percentage: float = 0.0
        self.status_message: str = "等待开始"

        # History metadata.
        self.created_at: str = ""
        self.last_updated: str = ""
        self.user_requirement: str = ""
        self.file_list: List[str] = []
2026-01-22 22:26:04 +08:00
2026-01-09 16:52:45 +08:00
class SessionManager:
    """In-memory registry of analysis sessions keyed by session id.

    Sessions from earlier server runs are rebuilt lazily from their
    ``outputs/session_<id>`` directory the first time they are looked up.
    """

    def __init__(self):
        self.sessions: Dict[str, SessionData] = {}
        # Guards mutation of self.sessions; both request handlers and background
        # tasks look sessions up (and may cache reconstructed ones).
        self.lock = threading.Lock()

    def create_session(self) -> str:
        """Register a new empty session and return its id (a random UUID4 string)."""
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = SessionData(session_id)
        return session_id

    def get_session(self, session_id: str) -> Optional[SessionData]:
        """Return the session for ``session_id``, or None if it is unknown.

        When the id is not in memory, tries to rebuild the session from
        ``outputs/session_<id>`` (history sessions, e.g. after a restart).
        """
        session = self.sessions.get(session_id)
        if session is not None:
            return session
        # The id comes from the client and becomes part of a filesystem path.
        # Accept only a single path component so "../" tricks cannot escape
        # outputs/ (Windows path APIs collapse ".." lexically).
        if not session_id or "/" in session_id or "\\" in session_id:
            return None
        output_dir = os.path.join("outputs", f"session_{session_id}")
        if os.path.isdir(output_dir):
            return self._reconstruct_session(session_id, output_dir)
        return None

    def _reconstruct_session(self, session_id: str, output_dir: str) -> SessionData:
        """Rebuild a completed (history) session from the files in ``output_dir``.

        Missing or unreadable artifacts are skipped. The rebuilt object is cached;
        if another thread cached the same id first, that instance is returned.
        """
        session = SessionData(session_id)
        session.output_dir = output_dir
        session.is_running = False
        session.current_round = session.max_rounds
        session.progress_percentage = 100.0
        session.status_message = "已完成 (历史记录)"

        log_path = os.path.join(output_dir, "process.log")
        if os.path.exists(log_path):
            session.log_file = log_path

        session.generated_report = self._find_report(output_dir)

        # Look for the reusable script under the file names it may have been saved as.
        for script_name in ("data_analysis_script.py", "script.py", "analysis_script.py"):
            script_path = os.path.join(output_dir, script_name)
            if os.path.exists(script_path):
                session.reusable_script = script_path
                break

        results_json = os.path.join(output_dir, "results.json")
        if os.path.exists(results_json):
            try:
                with open(results_json, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
            except (OSError, ValueError):
                loaded = None  # best effort: an unreadable file just means no gallery
            # Consumers iterate the results as a list of dicts; ignore anything else.
            if isinstance(loaded, list):
                session.analysis_results = loaded

        try:
            created = datetime.fromtimestamp(os.stat(output_dir).st_ctime)
            session.created_at = created.strftime("%Y-%m-%d %H:%M:%S")
        except (OSError, OverflowError, ValueError):
            pass  # leave created_at empty if the directory cannot be stat'ed

        with self.lock:
            return self.sessions.setdefault(session_id, session)

    @staticmethod
    def _find_report(output_dir: str) -> Optional[str]:
        """Return the report markdown file in ``output_dir``, or None if there is none.

        Prefers a file whose name contains "report" or "报告"; otherwise falls back
        to the first .md file. Names are sorted so the choice is deterministic.
        """
        md_files = sorted(glob.glob(os.path.join(output_dir, "*.md")))
        for md_path in md_files:
            name = os.path.basename(md_path).lower()
            if "report" in name or "报告" in name:
                return md_path
        return md_files[0] if md_files else None

    def list_sessions(self) -> List[str]:
        """Return the ids of all sessions currently held in memory."""
        with self.lock:
            return list(self.sessions)

    def delete_session(self, session_id: str) -> bool:
        """Remove a session from memory and reset its agent.

        Returns True if the session was present, False otherwise. Files on disk
        are not touched.
        """
        with self.lock:
            session = self.sessions.pop(session_id, None)
        if session is None:
            return False
        # Reset outside the lock so a slow agent cannot block other lookups.
        if session.agent:
            session.agent.reset()
        return True

    def get_session_info(self, session_id: str) -> Optional[Dict]:
        """Return a JSON-serialisable summary of a session, or None if unknown.

        ``user_requirement`` is truncated to 100 characters (plus "...").
        """
        session = self.get_session(session_id)
        if not session:
            return None
        requirement = session.user_requirement
        if len(requirement) > 100:
            requirement = requirement[:100] + "..."
        return {
            "session_id": session.session_id,
            "is_running": session.is_running,
            "progress": session.progress_percentage,
            "status": session.status_message,
            "current_round": session.current_round,
            "max_rounds": session.max_rounds,
            "created_at": session.created_at,
            "last_updated": session.last_updated,
            "user_requirement": requirement,
            "script_path": session.reusable_script,  # path of the reusable script, if any
        }
2026-01-09 16:52:45 +08:00
session_manager = SessionManager()

# Create every directory the app serves from or writes into before mounting.
for _required_dir in ("web/static", "uploads", "outputs"):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Front-end assets and generated analysis outputs are served as static files.
app.mount("/static", StaticFiles(directory="web/static"), name="static")
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
# --- Helper Functions ---
2026-01-22 22:26:04 +08:00
# Analyses are serialised. The agent's print() output is captured by swapping
# the process-wide sys.stdout, so two analyses running at the same time would
# write into each other's logs and then restore the wrong stream.
_stdout_capture_lock = threading.Lock()


class _TeeToFile:
    """Minimal stdout replacement that echoes to ``terminal`` and appends to a log file."""

    def __init__(self, filename: str, terminal):
        self.terminal = terminal
        # Line-buffered so /api/status sees new log lines while the analysis runs.
        self.log = open(filename, "a", encoding="utf-8", buffering=1)

    def write(self, message):
        if self.terminal is not None:
            self.terminal.write(message)
        return self.log.write(message)

    def flush(self):
        if self.terminal is not None:
            self.terminal.flush()
        self.log.flush()

    def close(self):
        self.log.close()


def run_analysis_task(session_id: str, files: list, user_requirement: str, is_followup: bool = False):
    """Run the analysis agent for one session (intended as a background task).

    A fresh run creates a new DataAnalysisAgent (stored on the session) and,
    if needed, the session output directory. A follow-up (``is_followup=True``)
    reuses ``session.agent`` and the existing directory, and appends to the log.

    While the agent runs, stdout is tee'd into ``<output_dir>/process.log``.
    sys.stdout is process-wide, so runs are serialised with
    ``_stdout_capture_lock``. A queued run keeps ``is_running`` True while it waits.

    Args:
        session_id: Id of a session known to ``session_manager``.
        files: Absolute paths of the input CSV files (unused for follow-ups).
        user_requirement: The analysis request, or the follow-up message.
        is_followup: Continue the existing agent instead of starting a new one.
    """
    import traceback

    session = session_manager.get_session(session_id)
    if not session:
        print(f"Error: Session {session_id} not found in background task.")
        return

    session.is_running = True
    # Record the history metadata reported by SessionManager.get_session_info().
    if not session.created_at:
        session.created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if not is_followup:
        session.user_requirement = user_requirement
        session.file_list = list(files)
    try:
        base_output_dir = "outputs"
        # A follow-up keeps using the directory created by the first run.
        if not session.output_dir:
            session.output_dir = create_session_output_dir(base_output_dir, user_requirement)
        session_output_dir = session.output_dir

        session.log_file = os.path.join(session_output_dir, "process.log")
        with open(session.log_file, "a" if is_followup else "w", encoding="utf-8") as f:
            if is_followup:
                f.write(f"\n--- Follow-up Session {session_id} Continued ---\n")
            else:
                f.write(f"--- Session {session_id} Started ---\n")

        # The agent reports progress via print(), so tee stdout into the session log.
        with _stdout_capture_lock:
            previous_stdout = sys.stdout
            logger = _TeeToFile(session.log_file, previous_stdout)
            sys.stdout = logger
            try:
                if is_followup:
                    agent = session.agent
                    if not agent:
                        print("Error: Agent not initialized for follow-up.")
                        return
                    result = agent.analyze(
                        user_input=user_requirement,
                        files=None,
                        session_output_dir=session_output_dir,
                        reset_session=False,
                        max_rounds=10,
                    )
                else:
                    agent = DataAnalysisAgent(LLMConfig(), force_max_rounds=False, output_dir=base_output_dir)
                    session.agent = agent
                    result = agent.analyze(
                        user_input=user_requirement,
                        files=files,
                        session_output_dir=session_output_dir,
                        reset_session=True,
                    )

                session.generated_report = result.get("report_file_path", None)
                session.analysis_results = result.get("analysis_results", [])
                session.reusable_script = result.get("reusable_script_path", None)

                # Persist results; history sessions and /api/figures reload them from disk.
                with open(os.path.join(session_output_dir, "results.json"), "w", encoding="utf-8") as f:
                    json.dump(session.analysis_results, f, default=str)
            except Exception as e:
                print(f"Error during analysis: {e}")
                # stdout is tee'd here, so the traceback also lands in the session log.
                traceback.print_exc(file=sys.stdout)
            finally:
                # Restore whatever stream was active before, not unconditionally sys.__stdout__.
                sys.stdout = previous_stdout
                logger.close()
    except Exception as e:
        print(f"System Error: {e}")
    finally:
        session.last_updated = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        session.is_running = False
# --- Pydantic Models ---
class StartRequest(BaseModel):
    # Request body of POST /api/start.
    requirement: str  # analysis request passed to the agent as user input
2026-01-22 22:26:04 +08:00
class ChatRequest(BaseModel):
    # Request body of POST /api/chat (follow-up on an existing session).
    session_id: str  # id of the session whose agent should handle the follow-up
    message: str  # follow-up request passed to the agent as user input
2026-01-09 16:52:45 +08:00
# --- API Endpoints ---
@app.get("/")
async def read_root():
    """Serve the single-page front end."""
    index_page = "web/static/index.html"
    return FileResponse(index_page)
@app.post("/api/upload")
async def upload_files(files: list[UploadFile] = File(...)):
    """Save uploaded files into ``uploads/`` and return their relative paths.

    Only the base name of each client-supplied filename is used, so names such
    as ``../../x`` cannot write outside the uploads directory.

    Raises:
        HTTPException(400): an upload has an empty or unusable filename
            (nothing is saved in that case).
    """
    import shutil

    # Validate every name first so a bad upload does not leave a partial batch.
    targets = []
    for file in files:
        # Some clients send full Windows paths; normalise before taking the base name.
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid file name")
        targets.append((file, f"uploads/{safe_name}"))

    saved_files = []
    for file, file_location in targets:
        with open(file_location, "wb") as file_object:
            # Stream in chunks rather than reading the whole upload into memory.
            shutil.copyfileobj(file.file, file_object)
        saved_files.append(file_location)
    return {"info": f"Saved {len(saved_files)} files", "paths": saved_files}
@app.post("/api/start")
async def start_analysis(request: StartRequest, background_tasks: BackgroundTasks):
    """Start a new analysis over every CSV in ``uploads/`` (else ``cleaned_data.csv``).

    Returns the new session id; the analysis itself runs as a background task.

    Raises:
        HTTPException(400): no input CSV file is available.
    """
    # Resolve inputs before creating the session so a rejected request does not
    # leave an orphan session in the registry.
    files = sorted(glob.glob("uploads/*.csv"))
    if not files:
        if not os.path.exists("cleaned_data.csv"):
            raise HTTPException(status_code=400, detail="No CSV files found")
        files = ["cleaned_data.csv"]
    files = [os.path.abspath(f) for f in files]  # the agent only receives absolute paths

    session_id = session_manager.create_session()
    background_tasks.add_task(run_analysis_task, session_id, files, request.requirement, is_followup=False)
    return {"status": "started", "session_id": session_id}
2026-01-22 22:26:04 +08:00
@app.post("/api/chat")
async def chat_analysis(request: ChatRequest, background_tasks: BackgroundTasks):
    """Queue a follow-up request against an existing session's agent.

    Raises:
        HTTPException(404): unknown session.
        HTTPException(400): an analysis is already running for this session.
    """
    session = session_manager.get_session(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if session.is_running:
        raise HTTPException(status_code=400, detail="Analysis already in progress")
    # Mark the session busy now rather than when the background task starts.
    # Otherwise a second request arriving in between would pass the check above
    # and start a concurrent run on the same agent.
    session.is_running = True
    background_tasks.add_task(run_analysis_task, request.session_id, [], request.message, is_followup=True)
    return {"status": "started"}
2026-01-09 16:52:45 +08:00
@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
    """Return a session's run state, full log text and report/script paths.

    Raises:
        HTTPException(404): unknown session.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    log_content = ""
    if session.log_file and os.path.exists(session.log_file):
        # The log is appended to while an analysis runs (and history logs may
        # predate this code), so undecodable bytes must not fail the poll.
        with open(session.log_file, "r", encoding="utf-8", errors="replace") as f:
            log_content = f.read()

    return {
        "is_running": session.is_running,
        "log": log_content,
        "has_report": session.generated_report is not None,
        "report_path": session.generated_report,
        "script_path": session.reusable_script,  # path of the reusable script, if any
    }
2026-01-22 22:26:04 +08:00
@app.get("/api/export")
async def export_session(session_id: str = Query(..., description="Session ID")):
    """Zip a session's entire output directory and return it as a download.

    The archive is written to ``outputs/export_<session_id>.zip``. That is
    outside the session directory, so it never contains itself. It is
    overwritten on every export.

    Raises:
        HTTPException(404): unknown session, or no output directory to export.
    """
    import zipfile

    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.output_dir or not os.path.exists(session.output_dir):
        raise HTTPException(status_code=404, detail="No data available for export")

    archive_path = os.path.join("outputs", f"export_{session_id}.zip")
    # Built with zipfile instead of shutil.make_archive. Before Python 3.10.6,
    # make_archive chdir()s into root_dir; the working directory is process-wide
    # and analysis threads rely on relative paths such as "outputs".
    with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, names in os.walk(session.output_dir):
            for name in names:
                abs_path = os.path.join(root, name)
                zf.write(abs_path, arcname=os.path.relpath(abs_path, session.output_dir))

    return FileResponse(archive_path, media_type='application/zip', filename=f"analysis_export_{session_id}.zip")
2026-01-09 16:52:45 +08:00
@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
    """Return the session's markdown report with relative image links made absolute.

    Relative links are rewritten to point under the statically served session
    directory (``/<output_dir relative to cwd>/...``).

    Raises:
        HTTPException(404): unknown session.
    """
    import re

    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.generated_report or not os.path.exists(session.generated_report):
        return {"content": "Report not ready."}

    with open(session.generated_report, "r", encoding="utf-8") as f:
        content = f.read()

    # URL path of the session directory; URLs always use "/" even on Windows.
    relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
    web_base_path = "/" + relative_session_path.replace(os.sep, "/")

    # 1. Explicit relative links such as ](./image.png).
    content = content.replace("](./", f"]({web_base_path}/")

    # 2. Bare relative image paths such as ![x](image.png). Absolute URLs,
    #    root-relative paths and data: URIs are left untouched.
    def replace_link(match):
        alt, url = match.group(1), match.group(2)
        if url.startswith(("http", "/", "data:")):
            return match.group(0)
        # Remove leading "./" segments only. str.lstrip("./") would also strip
        # the leading dots of names like ".hidden.png" or "../img.png".
        while url.startswith("./"):
            url = url[2:]
        return f"![{alt}]({web_base_path}/{url})"

    content = re.sub(r'!\[(.*?)\]\((.*?)\)', replace_link, content)
    return {"content": content, "base_path": web_base_path}
@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
    """List the session's figures for the gallery, each with a ``web_url``.

    Figures come from ``collect_figures`` entries in the analysis results (in
    memory, else ``results.json``). If none are recorded, every PNG in the
    output directory is listed instead.

    Raises:
        HTTPException(404): unknown session.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    results = session.analysis_results
    # Results can be empty in memory (e.g. after a restart) while the file exists.
    if not results and session.output_dir:
        json_path = os.path.join(session.output_dir, "results.json")
        if os.path.exists(json_path):
            try:
                with open(json_path, "r", encoding="utf-8") as f:
                    results = json.load(f)
            except (OSError, ValueError):
                results = []  # unreadable/corrupt file: behave as if nothing was recorded
    if not isinstance(results, list):
        results = []

    # URL prefix of the session directory (computed once; "/"-separated).
    web_base = None
    if session.output_dir:
        web_base = "/" + os.path.relpath(session.output_dir, os.getcwd()).replace(os.sep, "/")

    figures = []
    for item in results:
        if not isinstance(item, dict) or item.get("action") != "collect_figures":
            continue
        for fig in item.get("collected_figures", []):
            # Copy so serving the gallery does not mutate the session's stored results.
            fig = dict(fig)
            if web_base is not None:
                fig["web_url"] = f"{web_base}/{fig.get('filename')}"
            figures.append(fig)

    # Fallback: nothing recorded, so list the PNGs present in the output directory.
    if not figures and web_base is not None:
        for png_path in sorted(glob.glob(os.path.join(session.output_dir, "*.png"))):
            fname = os.path.basename(png_path)
            figures.append({
                "filename": fname,
                "description": "Auto-discovered image",
                "analysis": "No analysis available",
                "web_url": f"{web_base}/{fname}",
            })
    return {"figures": figures}
@app.get("/api/export_report")
async def export_report(session_id: str = Query(..., description="Session ID")):
    """Zip the session's report artifacts (.md/.png/.csv/.log/.json/.yaml) for download.

    This handler has its own path. ``/api/export`` is already served by
    ``export_session``, and with duplicate routes the first registration wins.

    Raises:
        HTTPException(404): unknown session or no output directory.
    """
    import zipfile

    session = session_manager.get_session(session_id)
    if not session or not session.output_dir:
        raise HTTPException(status_code=404, detail="Session not found")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_filename = f"report_{timestamp}.zip"  # name offered to the browser
    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    # The on-disk name includes the session id so exports of different sessions
    # made in the same second do not overwrite each other.
    zip_path = os.path.join(export_dir, f"report_{session_id}_{timestamp}.zip")

    wanted_suffixes = ('.md', '.png', '.csv', '.log', '.json', '.yaml')
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, names in os.walk(session.output_dir):
            for name in names:
                if name.endswith(wanted_suffixes):
                    abs_path = os.path.join(root, name)
                    zf.write(abs_path, arcname=os.path.relpath(abs_path, session.output_dir))

    return FileResponse(
        path=zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )
2026-01-31 20:27:17 +08:00
@app.get("/api/download_script")
async def download_script(session_id: str = Query(..., description="Session ID")):
    """Download the reusable Python script generated for a session."""
    session = session_manager.get_session(session_id)
    script_path = session.reusable_script if session else None
    if not script_path:
        raise HTTPException(status_code=404, detail="Script not found")
    if not os.path.exists(script_path):
        raise HTTPException(status_code=404, detail="Script file missing on server")

    return FileResponse(
        path=script_path,
        filename=os.path.basename(script_path),
        media_type='text/x-python',
    )
2026-01-09 16:52:45 +08:00
2026-01-31 20:27:17 +08:00
# --- Tools API ---
2026-01-09 16:52:45 +08:00
# --- Session management API endpoints (progress, list, delete) ---
@app.get("/api/sessions/progress")
async def get_session_progress(session_id: str = Query(..., description="Session ID")):
    """Return the progress and metadata summary of one session."""
    info = session_manager.get_session_info(session_id)
    if info:
        return info
    raise HTTPException(status_code=404, detail="Session not found")
@app.get("/api/sessions/list")
async def list_all_sessions():
    """List summaries of all sessions currently held in memory."""
    summaries = [
        info
        for info in map(session_manager.get_session_info, session_manager.list_sessions())
        if info
    ]
    return {"sessions": summaries, "total": len(summaries)}
@app.delete("/api/sessions/{session_id}")
async def delete_specific_session(session_id: str):
    """Remove a session from the in-memory registry.

    Files under ``outputs/`` are left in place, so a session whose output
    directory exists can later be rebuilt from disk by a lookup.

    Raises:
        HTTPException(404): the session is not currently in memory.
    """
    if not session_manager.delete_session(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    return {"status": "deleted", "session_id": session_id}
# --- History API ---
@app.get("/api/history")
async def get_history():
    """
    List past analysis sessions (``session_*`` directories under ``outputs``),
    newest first. Any scanning error yields an empty list.
    """
    output_base = "outputs"
    if not os.path.exists(output_base):
        return {"history": []}

    try:
        dated_entries = []
        for entry in os.scandir(output_base):
            if not (entry.is_dir() and entry.name.startswith("session_")):
                continue
            session_id = entry.name.replace("session_", "")
            try:
                # Names like session_20250101_120000 carry their own timestamp.
                parsed = datetime.strptime(session_id, "%Y%m%d_%H%M%S")
                display_time = parsed.strftime("%Y-%m-%d %H:%M:%S")
                sort_key = parsed.timestamp()
            except ValueError:
                # Otherwise fall back to the directory's ctime.
                sort_key = entry.stat().st_ctime
                display_time = datetime.fromtimestamp(sort_key).strftime("%Y-%m-%d %H:%M:%S")
            dated_entries.append((sort_key, {
                "id": session_id,
                "timestamp": display_time,
                "name": f"Session {display_time}",
            }))

        # Newest first; the sort is stable, so ties keep scan order.
        dated_entries.sort(key=lambda pair: pair[0], reverse=True)
        return {"history": [record for _, record in dated_entries]}
    except Exception as e:
        print(f"Error scanning history: {e}")
        return {"history": []}
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)