更新readme文档
This commit is contained in:
404
web/main.py
Normal file
404
web/main.py
Normal file
@@ -0,0 +1,404 @@
|
||||
|
||||
import sys
|
||||
import os
|
||||
import threading
|
||||
import glob
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional, Dict, List
|
||||
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Add parent directory to path to import agent modules
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data_analysis_agent import DataAnalysisAgent
|
||||
from config.llm_config import LLMConfig
|
||||
from utils.create_session_dir import create_session_output_dir
|
||||
from merge_excel import merge_excel_files
|
||||
from sort_csv import sort_csv_by_time
|
||||
|
||||
# Application entry point for the web UI of the data-analysis agent.
app = FastAPI(title="IOV Data Analysis Agent")

# CORS: wide-open for local development.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — tighten the origin list before any public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# --- Session Management ---
|
||||
|
||||
class SessionData:
    """Mutable per-session state for one analysis run."""

    def __init__(self, session_id: str):
        # Identifier issued by SessionManager.create_session().
        self.session_id = session_id
        # True while the background analysis task is executing.
        self.is_running = False
        # Directory on disk holding this session's generated artifacts.
        self.output_dir: Optional[str] = None
        # Path of the generated report file, once the agent produces one.
        self.generated_report: Optional[str] = None
        # Path of the captured process log for /api/status polling.
        self.log_file: Optional[str] = None
        # Analysis results retained for the figure gallery endpoint.
        self.analysis_results: List[Dict] = []
|
||||
|
||||
class SessionManager:
|
||||
def __init__(self):
|
||||
self.sessions: Dict[str, SessionData] = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def create_session(self) -> str:
|
||||
with self.lock:
|
||||
session_id = str(uuid.uuid4())
|
||||
self.sessions[session_id] = SessionData(session_id)
|
||||
return session_id
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[SessionData]:
|
||||
return self.sessions.get(session_id)
|
||||
|
||||
def list_sessions(self):
|
||||
return list(self.sessions.keys())
|
||||
|
||||
# Single process-wide registry shared by all endpoints.
session_manager = SessionManager()

# Create the served directories up front — StaticFiles is mounted on them
# below and (presumably) requires the directories to exist at startup.
os.makedirs("web/static", exist_ok=True)
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

app.mount("/static", StaticFiles(directory="web/static"), name="static")
# Session output directories live under outputs/, so generated reports and
# figures are web-reachable at their on-disk relative paths.
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
||||
|
||||
# --- Helper Functions ---
|
||||
|
||||
# sys.stdout is process-global, so only one task may redirect it at a time.
# This lock serializes analysis runs to keep per-session logs clean.
_ANALYSIS_LOCK = threading.Lock()


def run_analysis_task(session_id: str, files: list, user_requirement: str):
    """Run the analysis agent for one session in a background thread.

    Creates the session output directory, mirrors all stdout to a
    per-session ``process.log`` (by temporarily replacing ``sys.stdout``),
    runs the agent, and records the report path and analysis results on
    the session. Results are also persisted to ``results.json`` so the
    figure gallery survives a server restart.

    Args:
        session_id: Id previously issued by ``session_manager``.
        files: Absolute paths of the CSV files to analyze.
        user_requirement: Natural-language analysis requirement.
    """
    session = session_manager.get_session(session_id)
    if not session:
        print(f"Error: Session {session_id} not found in background task.")
        return

    session.is_running = True
    try:
        # Map the session to a dedicated output directory; trust the
        # utility's return value as the canonical location.
        base_output_dir = "outputs"
        session_output_dir = create_session_output_dir(base_output_dir, user_requirement)
        session.output_dir = session_output_dir

        # Truncate/initialize the log so /api/status never reads a stale run.
        session.log_file = os.path.join(session_output_dir, "process.log")
        with open(session.log_file, "w", encoding="utf-8") as f:
            f.write(f"--- Session {session_id} Started ---\n")

        class FileLogger:
            """Tee stream: forwards writes to the real stdout and the log file."""

            def __init__(self, filename):
                self.terminal = sys.__stdout__
                # Line-buffered so /api/status polling sees fresh output.
                self.log = open(filename, "a", encoding="utf-8", buffering=1)

            def write(self, message):
                self.terminal.write(message)
                self.log.write(message)

            def flush(self):
                self.terminal.flush()
                self.log.flush()

            def close(self):
                self.log.close()

        # The agent logs via print(); capturing it requires hijacking the
        # process-global sys.stdout, so hold the lock for the whole run.
        with _ANALYSIS_LOCK:
            logger = FileLogger(session.log_file)
            sys.stdout = logger
            try:
                llm_config = LLMConfig()
                agent = DataAnalysisAgent(llm_config, force_max_rounds=False, output_dir=base_output_dir)

                result = agent.analyze(
                    user_input=user_requirement,
                    files=files,
                    session_output_dir=session_output_dir
                )

                session.generated_report = result.get("report_file_path", None)
                session.analysis_results = result.get("analysis_results", [])

                # Persist results for /api/figures after restarts. default=str
                # stringifies non-JSON values (e.g. datetimes) instead of failing.
                results_path = os.path.join(session_output_dir, "results.json")
                with open(results_path, "w", encoding="utf-8") as f:
                    json.dump(session.analysis_results, f, default=str)
            except Exception as e:
                # Goes through the redirected stdout, i.e. into the session log.
                print(f"Error during analysis: {e}")
            finally:
                sys.stdout = logger.terminal
                logger.close()

    except Exception as e:
        print(f"System Error: {e}")
    finally:
        session.is_running = False
|
||||
|
||||
# --- Pydantic Models ---
|
||||
|
||||
class StartRequest(BaseModel):
    """Request payload for POST /api/start."""
    # Natural-language analysis requirement forwarded to the agent.
    requirement: str
|
||||
|
||||
# --- API Endpoints ---
|
||||
|
||||
@app.get("/")
async def read_root():
    """Serve the single-page frontend."""
    index_page = "web/static/index.html"
    return FileResponse(index_page)
|
||||
|
||||
@app.post("/api/upload")
async def upload_files(files: list[UploadFile] = File(...)):
    """Save uploaded files into the shared ``uploads/`` directory.

    Returns the count and the saved paths so the client can confirm
    what was stored.
    """
    import shutil

    saved_files = []
    for file in files:
        # SECURITY: use only the basename — a client-supplied filename
        # such as "../../etc/x" must not escape the uploads directory.
        safe_name = os.path.basename(file.filename or "")
        if not safe_name:
            continue
        file_location = f"uploads/{safe_name}"
        with open(file_location, "wb") as file_object:
            # Stream in chunks instead of reading the whole upload into memory.
            shutil.copyfileobj(file.file, file_object)
        saved_files.append(file_location)
    return {"info": f"Saved {len(saved_files)} files", "paths": saved_files}
|
||||
|
||||
@app.post("/api/start")
async def start_analysis(request: StartRequest, background_tasks: BackgroundTasks):
    """Create a session and launch the analysis as a background task."""
    session_id = session_manager.create_session()

    # Analyze every uploaded CSV; fall back to the bundled cleaned_data.csv
    # when nothing has been uploaded yet.
    csv_files = glob.glob("uploads/*.csv")
    if not csv_files:
        if not os.path.exists("cleaned_data.csv"):
            raise HTTPException(status_code=400, detail="No CSV files found")
        csv_files = ["cleaned_data.csv"]

    # The agent requires absolute paths.
    absolute_files = [os.path.abspath(path) for path in csv_files]

    background_tasks.add_task(run_analysis_task, session_id, absolute_files, request.requirement)
    return {"status": "started", "session_id": session_id}
|
||||
|
||||
@app.get("/api/status")
async def get_status(session_id: str = Query(..., description="Session ID")):
    """Report whether a session is running, plus its captured log so far."""
    session = session_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    log_content = ""
    log_path = session.log_file
    if log_path and os.path.exists(log_path):
        with open(log_path, "r", encoding="utf-8") as f:
            log_content = f.read()

    report = session.generated_report
    return {
        "is_running": session.is_running,
        "log": log_content,
        "has_report": report is not None,
        "report_path": report,
    }
|
||||
|
||||
@app.get("/api/report")
async def get_report(session_id: str = Query(..., description="Session ID")):
    """Return the report content with image links rewritten to web paths."""
    session = session_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    report_path = session.generated_report
    if not report_path or not os.path.exists(report_path):
        return {"content": "Report not ready."}

    with open(report_path, "r", encoding="utf-8") as f:
        content = f.read()

    # The report references images as "](./name.png"; rewrite those links so
    # the browser resolves them under the mounted /outputs static route.
    relative_session_path = os.path.relpath(session.output_dir, os.getcwd())
    web_base_path = f"/{relative_session_path}"
    content = content.replace("](./", f"]({web_base_path}/")

    return {"content": content, "base_path": web_base_path}
|
||||
|
||||
@app.get("/api/figures")
async def get_figures(session_id: str = Query(..., description="Session ID")):
    """List figures produced by a session, each with a web-resolvable URL.

    Source preference order:
      1. in-memory analysis results,
      2. the persisted results.json (survives a server restart),
      3. a raw scan for ``*.png`` files in the session directory.
    """
    session = session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    results = session.analysis_results

    # If memory is empty (e.g. the server restarted), fall back to the
    # results.json persisted at the end of the analysis run.
    if not results and session.output_dir:
        json_path = os.path.join(session.output_dir, "results.json")
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as f:
                results = json.load(f)

    figures = []

    # Hoist the loop-invariant web base path: the session directory does
    # not change while we iterate.
    web_base = None
    if session.output_dir:
        web_base = "/" + os.path.relpath(session.output_dir, os.getcwd())

    # 'collect_figures' actions are the reliable source of figures as per
    # the agent design; other result items are ignored here.
    if results:
        for item in results:
            if item.get("action") != "collect_figures":
                continue
            for fig in item.get("collected_figures", []):
                if web_base is not None:
                    fname = fig.get("filename")
                    fig["web_url"] = f"{web_base}/{fname}"
                figures.append(fig)

    # Fallback: auto-discover PNGs when no figures were recorded.
    if not figures and session.output_dir:
        for png_path in glob.glob(os.path.join(session.output_dir, "*.png")):
            fname = os.path.basename(png_path)
            figures.append({
                "filename": fname,
                "description": "Auto-discovered image",
                "analysis": "No analysis available",
                "web_url": f"{web_base}/{fname}",
            })

    return {"figures": figures}
|
||||
|
||||
@app.get("/api/export")
async def export_report(session_id: str = Query(..., description="Session ID")):
    """Bundle a session's output directory into a downloadable zip file."""
    session = session_manager.get_session(session_id)
    if not session or not session.output_dir:
        raise HTTPException(status_code=404, detail="Session not found")

    # Function-local imports are only needed for this endpoint.
    # (Removed an unused `import tempfile` — the zip is written to outputs/.)
    import zipfile
    from datetime import datetime

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_filename = f"report_{timestamp}.zip"

    export_dir = "outputs"
    os.makedirs(export_dir, exist_ok=True)
    zip_path = os.path.join(export_dir, zip_filename)

    # Only package known artifact types. The zip itself lives in outputs/
    # root, not in the session directory, so it is never self-included.
    wanted = ('.md', '.png', '.csv', '.log', '.json', '.yaml')
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, filenames in os.walk(session.output_dir):
            for name in filenames:
                if name.endswith(wanted):
                    abs_path = os.path.join(root, name)
                    rel_path = os.path.relpath(abs_path, session.output_dir)
                    zf.write(abs_path, arcname=rel_path)

    # NOTE(review): exported zips accumulate in outputs/ indefinitely —
    # consider deleting them after download or adding a cleanup policy.
    return FileResponse(
        path=zip_path,
        filename=zip_filename,
        media_type='application/zip'
    )
|
||||
|
||||
# --- Tools API ---
|
||||
|
||||
class ToolRequest(BaseModel):
    """Request payload shared by the utility-tool endpoints (/api/tools/*)."""
    # Directory scanned by the merge tool for input files.
    source_dir: Optional[str] = "uploads"
    # Filename of the merged output produced by the merge tool.
    output_filename: Optional[str] = "merged_output.csv"
    # File to sort; required by /api/tools/sort, unused by merge.
    target_file: Optional[str] = None
|
||||
|
||||
@app.post("/api/tools/merge")
async def tool_merge_excel(req: ToolRequest):
    """Trigger the Excel merge tool.

    Runs the blocking merge in a thread-pool executor so the event loop
    stays responsive. Raises 500 on failure or when no output appears.
    """
    source = req.source_dir
    output = req.output_filename

    import asyncio
    try:
        # get_running_loop() is the modern form inside a coroutine
        # (get_event_loop() is deprecated for this use).
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, lambda: merge_excel_files(source, output))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    output_abs = os.path.abspath(output)
    # Previously a missing output file fell through and returned None
    # (an empty 200 response); report it as an explicit server error.
    if not os.path.exists(output_abs):
        raise HTTPException(status_code=500, detail="Merge produced no output file")
    return {"status": "success", "message": "Merge completed", "output_file": output_abs}
|
||||
|
||||
@app.post("/api/tools/sort")
async def tool_sort_csv(req: ToolRequest):
    """Trigger the CSV sort tool (sorts ``req.target_file`` by time)."""
    # Validate BEFORE the try-block: previously the HTTPException(400) was
    # raised inside it, caught by `except Exception`, and re-wrapped as a
    # 500 — the client never saw the intended 400.
    target = req.target_file
    if not target:
        raise HTTPException(status_code=400, detail="Target file required")

    import asyncio
    try:
        # Run the blocking sort off the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, lambda: sort_csv_by_time(target))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"status": "success", "message": f"Sorted {target} by time"}
|
||||
|
||||
# --- Help API ---
|
||||
|
||||
@app.get("/api/help/troubleshooting")
async def get_troubleshooting_guide():
    """Return the content of troubleshooting_guide.md.

    The default location is a machine-specific artifact path; it can be
    overridden with the TROUBLESHOOTING_GUIDE_PATH environment variable
    (backward compatible: without the variable, behavior is unchanged).
    """
    default_path = "~/.gemini/antigravity/brain/3ff617fe-5f27-4ab8-b61b-c634f2e75255/troubleshooting_guide.md"
    guide_path = os.path.expanduser(os.environ.get("TROUBLESHOOTING_GUIDE_PATH", default_path))

    if not os.path.exists(guide_path):
        return {"content": "# Troubleshooting Guide Not Found\n\nCould not locate the guide artifact."}

    try:
        with open(guide_path, "r", encoding="utf-8") as f:
            return {"content": f.read()}
    except Exception as e:
        # Best-effort: surface the error in the returned markdown rather
        # than failing the request.
        return {"content": f"# Error Loading Guide\n\n{e}"}
|
||||
Reference in New Issue
Block a user