67 lines
2.0 KiB
Python
67 lines
2.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
In-memory registry for long-lived analysis sessions.
|
|
"""
|
|
|
|
import os
|
|
import threading
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Optional
|
|
|
|
from config.llm_config import LLMConfig
|
|
from data_analysis_agent import DataAnalysisAgent
|
|
|
|
|
|
@dataclass
class RuntimeSession:
    """State for one long-lived analysis session held by ``SessionManager``.

    An instance is built the first time ``SessionManager.get_or_create`` sees
    a session id and is then returned unchanged for follow-up requests, so the
    same ``agent`` is reused across topics.
    """

    # Registry key under which SessionManager stores this runtime.
    session_id: str
    # Owner of the session, as supplied when the runtime was created.
    user_id: str
    # Per-session output directory, as supplied when the runtime was created.
    session_output_dir: str
    # Files uploaded for the session (presumably paths — verify against callers).
    # Stored by reference, not copied.
    uploaded_files: list[str]
    # Optional template; None when the caller did not provide one.
    template_path: Optional[str]
    # Agent kept alive for the lifetime of the session.
    agent: DataAnalysisAgent
    # Never modified in this module; presumably set by callers once the agent
    # has been primed with the session's data — verify against callers.
    initialized: bool = False
    # Per-session mutex (a fresh lock per instance); not acquired in this
    # module. Excluded from repr/eq: a lock holds no session state, compares by
    # identity, and its repr only adds lock-state/address noise.
    lock: threading.Lock = field(
        default_factory=threading.Lock,
        repr=False,
        compare=False,
    )
|
|
|
|
|
|
class SessionManager:
    """Keeps session-scoped agents alive across follow-up topics.

    An in-memory registry of ``RuntimeSession`` objects keyed by session id.
    Access to the registry is guarded by a single lock, so ``get_or_create``
    and ``close`` may be called from multiple threads.
    """

    def __init__(
        self,
        outputs_dir: str,
        *,
        max_rounds: int = 20,
        force_max_rounds: bool = False,
    ):
        """Create an empty registry.

        Args:
            outputs_dir: Base output directory passed to every agent this
                manager creates; stored as an absolute path.
            max_rounds: ``max_rounds`` value passed to each new
                ``DataAnalysisAgent`` (previously hard-coded to 20).
            force_max_rounds: ``force_max_rounds`` value passed to each new
                ``DataAnalysisAgent`` (previously hard-coded to False).
        """
        self.outputs_dir = os.path.abspath(outputs_dir)
        self.max_rounds = max_rounds
        self.force_max_rounds = force_max_rounds
        self._sessions: Dict[str, RuntimeSession] = {}
        # Guards ``_sessions``. Each RuntimeSession carries its own ``lock``,
        # which this class never acquires.
        self._lock = threading.Lock()

    def get_or_create(
        self,
        session_id: str,
        user_id: str,
        session_output_dir: str,
        uploaded_files: list[str],
        template_path: Optional[str],
    ) -> RuntimeSession:
        """Return the registered runtime for ``session_id``, creating it on first use.

        Args:
            session_id: Registry key.
            user_id: Owner recorded on a newly created runtime.
            session_output_dir: Output directory recorded on a newly created
                runtime.
            uploaded_files: File list recorded on a newly created runtime
                (stored by reference, not copied).
            template_path: Optional template recorded on a newly created
                runtime.

        Returns:
            The existing ``RuntimeSession`` for ``session_id`` if one is
            registered; otherwise a new one, with a fresh ``DataAnalysisAgent``,
            which is registered before it is returned.

        Note:
            For an already-registered session every argument except
            ``session_id`` is ignored. In particular, ``user_id`` is not
            checked against the stored owner.
        """
        with self._lock:
            runtime = self._sessions.get(session_id)
            if runtime is not None:
                return runtime

            # Built while holding the registry lock so two concurrent first
            # requests for the same id cannot create two agents. The trade-off
            # is that other get_or_create/close calls wait during construction.
            agent = DataAnalysisAgent(
                llm_config=LLMConfig(),
                # NOTE(review): agents receive the manager-wide outputs_dir, not
                # session_output_dir — confirm the agent namespaces its output.
                output_dir=self.outputs_dir,
                max_rounds=self.max_rounds,
                force_max_rounds=self.force_max_rounds,
            )
            runtime = RuntimeSession(
                session_id=session_id,
                user_id=user_id,
                session_output_dir=session_output_dir,
                uploaded_files=uploaded_files,
                template_path=template_path,
                agent=agent,
            )
            self._sessions[session_id] = runtime
            return runtime

    def close(self, session_id: str) -> None:
        """Unregister ``session_id`` and shut down its agent.

        Unknown session ids are ignored. The session is removed from the
        registry first. ``close_session()`` then runs after the registry lock
        is released, so a slow shutdown does not block other sessions. Any
        exception from ``close_session()`` propagates to the caller, but the
        session stays unregistered.
        """
        with self._lock:
            runtime = self._sessions.pop(session_id, None)
        # NOTE(review): the runtime's own ``lock`` is not taken here — confirm
        # callers never close a session while a request is still using it.
        if runtime is not None:
            runtime.agent.close_session()
|