diff --git a/config/app_config.py b/config/app_config.py new file mode 100644 index 0000000..d2c6c12 --- /dev/null +++ b/config/app_config.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +""" +应用配置中心 - 集中管理所有配置项 +""" + +import os +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class AppConfig: + """应用配置中心""" + + # 分析配置 + max_rounds: int = field(default=20) + force_max_rounds: bool = field(default=False) + default_output_dir: str = field(default="outputs") + + # 数据处理配置 + max_file_size_mb: int = field(default=500) # 最大文件大小(MB) + chunk_size: int = field(default=100000) # 分块读取大小 + data_cache_enabled: bool = field(default=True) + cache_dir: str = field(default=".cache/data") + + # LLM配置 + llm_cache_enabled: bool = field(default=True) + llm_cache_dir: str = field(default=".cache/llm") + llm_stream_enabled: bool = field(default=False) + + # 代码执行配置 + code_timeout: int = field(default=300) # 代码执行超时(秒) + allowed_imports: List[str] = field(default_factory=lambda: [ + 'pandas', 'numpy', 'matplotlib', 'seaborn', 'plotly', + 'scipy', 'sklearn', 'duckdb', 'datetime', 'json', + 'os', 're', 'pathlib', 'glob', 'typing', 'collections', + 'itertools', 'functools', 'warnings' + ]) + + # Web配置 + web_host: str = field(default="0.0.0.0") + web_port: int = field(default=8000) + upload_dir: str = field(default="uploads") + + # 日志配置 + log_filename: str = field(default="log.txt") + enable_code_logging: bool = field(default=False) # 是否记录生成的代码 + + @classmethod + def from_env(cls) -> 'AppConfig': + """从环境变量创建配置""" + config = cls() + + # 从环境变量覆盖配置 + if max_rounds := os.getenv("APP_MAX_ROUNDS"): + config.max_rounds = int(max_rounds) + + if chunk_size := os.getenv("APP_CHUNK_SIZE"): + config.chunk_size = int(chunk_size) + + if cache_enabled := os.getenv("APP_CACHE_ENABLED"): + config.data_cache_enabled = cache_enabled.lower() == "true" + + return config + + def validate(self) -> bool: + """验证配置""" + if self.max_rounds <= 0: + raise ValueError("max_rounds must be positive") + + if self.chunk_size <= 0: + raise ValueError("chunk_size must be positive") + + if self.code_timeout <= 0: + raise ValueError("code_timeout must be positive") + + return True + + +# 全局配置实例 +app_config = AppConfig.from_env() diff --git a/config/llm_config copy.py b/config/llm_config copy.py index 280229d..0858d1e 100644 --- a/config/llm_config copy.py +++ b/config/llm_config copy.py @@ -20,7 +20,7 @@ class LLMConfig: provider: str = os.environ.get("LLM_PROVIDER", "openai") # openai, gemini, etc. api_key: str = os.environ.get("OPENAI_API_KEY", "sk-2187174de21548b0b8b0c92129700199") base_url: str = os.environ.get("OPENAI_BASE_URL", "http://127.0.0.1:9999/v1") - model: str = os.environ.get("OPENAI_MODEL", "gemini-3-flash") + model: str = os.environ.get("OPENAI_MODEL", "gemini--flash") temperature: float = 0.5 max_tokens: int = 131072 diff --git a/config/llm_config.py b/config/llm_config.py index dbdd016..2b2205c 100644 --- a/config/llm_config.py +++ b/config/llm_config.py @@ -20,7 +20,7 @@ class LLMConfig: provider: str = os.environ.get("LLM_PROVIDER", "openai") # openai, gemini, etc. api_key: str = os.environ.get("OPENAI_API_KEY", "sk-Gce85QLROESeOWf3icd2mQnYHOrmMYojwVPQ0AubMjGQ5ZE2") base_url: str = os.environ.get("OPENAI_BASE_URL", "https://gemini.jeason.online/v1") - model: str = os.environ.get("OPENAI_MODEL", "gemini-2.5-flash") + model: str = os.environ.get("OPENAI_MODEL", "gemini-2.5-pro") temperature: float = 0.5 max_tokens: int = 131072 diff --git a/test.py b/test.py index 1ad40d8..e29c938 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ client = openai.OpenAI( ) response = client.chat.completions.create( - model="gemini-2.5-flash", + model="gemini-2.5-pro", messages=[{"role": "user", "content": "你好,请自我介绍"}] ) print(response.choices[0].message.content) \ No newline at end of file diff --git a/utils/analysis_templates.py b/utils/analysis_templates.py new file mode 100644 index 0000000..a306521 --- /dev/null +++ b/utils/analysis_templates.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +""" +分析模板系统 - 提供预定义的分析场景 +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from dataclasses import dataclass + + +@dataclass +class AnalysisStep: + """分析步骤""" + name: str + description: str + analysis_type: str # explore, visualize, calculate, report + prompt: str + + +class AnalysisTemplate(ABC): + """分析模板基类""" + + def __init__(self, name: str, description: str): + self.name = name + self.description = description + self.steps: List[AnalysisStep] = [] + + @abstractmethod + def build_steps(self, **kwargs) -> List[AnalysisStep]: + """构建分析步骤""" + pass + + def get_full_prompt(self, **kwargs) -> str: + """获取完整的分析提示词""" + steps = self.build_steps(**kwargs) + + prompt = f"# {self.name}\n\n{self.description}\n\n" + prompt += "## 分析步骤:\n\n" + + for i, step in enumerate(steps, 1): + prompt += f"### {i}. {step.name}\n" + prompt += f"{step.description}\n\n" + prompt += f"```\n{step.prompt}\n```\n\n" + + return prompt + + +class HealthReportTemplate(AnalysisTemplate): + """健康度报告模板 - 专门用于车联网工单健康度分析""" + + def __init__(self): + super().__init__( + name="车联网工单健康度报告", + description="全面分析车联网技术支持工单的健康状况,从多个维度评估工单处理效率和质量" + ) + + def build_steps(self, **kwargs) -> List[AnalysisStep]: + """构建健康度报告的分析步骤""" + return [ + AnalysisStep( + name="数据概览与质量检查", + description="检查数据完整性、缺失值、异常值等", + analysis_type="explore", + prompt="加载数据并进行质量检查,输出数据概况和潜在问题" + ), + AnalysisStep( + name="工单总量分析", + description="统计总工单数、时间分布、趋势变化", + analysis_type="calculate", + prompt="计算总工单数,按时间维度统计工单量,绘制时间序列趋势图" + ), + AnalysisStep( + name="车型维度分析", + description="分析不同车型的工单分布和问题特征", + analysis_type="visualize", + prompt="统计各车型工单数量,绘制车型分布饼图和柱状图,识别高风险车型" + ), + AnalysisStep( + name="模块维度分析", + description="分析工单涉及的技术模块分布", + analysis_type="visualize", + prompt="统计各技术模块的工单量,绘制模块分布图,识别高频问题模块" + ), + AnalysisStep( + name="功能维度分析", + description="分析具体功能点的问题分布", + analysis_type="visualize", + prompt="统计各功能的工单量,绘制TOP功能问题排行,分析功能稳定性" + ), + AnalysisStep( + name="问题严重程度分析", + description="分析工单的严重程度分布", + analysis_type="visualize", + prompt="统计不同严重程度的工单比例,绘制严重程度分布图" + ), + AnalysisStep( + name="处理时长分析", + description="分析工单处理时效性", + analysis_type="calculate", + prompt="计算平均处理时长、SLA达成率,识别超时工单,绘制时长分布图" + ), + AnalysisStep( + name="责任人工作负载分析", + description="分析各责任人的工单负载和处理效率", + analysis_type="visualize", + prompt="统计各责任人的工单数和处理效率,绘制负载分布图,识别超负荷人员" + ), + AnalysisStep( + name="来源渠道分析", + description="分析工单来源渠道分布", + analysis_type="visualize", + prompt="统计各来源渠道的工单量,绘制渠道分布图" + ), + AnalysisStep( + name="高频问题深度分析", + description="识别并深入分析高频问题", + analysis_type="explore", + prompt="提取TOP10高频问题,分析问题原因、影响范围和解决方案" + ), + AnalysisStep( + name="综合健康度评分", + description="基于多个维度计算综合健康度评分", + analysis_type="calculate", + prompt="综合考虑工单量、处理时长、问题严重度等指标,计算健康度评分" + ), + AnalysisStep( + name="生成最终报告", + description="整合所有分析结果,生成完整报告", + analysis_type="report", + prompt="整合所有图表和分析结论,生成一份完整的车联网工单健康度报告" + ) + ] + + +class TrendAnalysisTemplate(AnalysisTemplate): + """趋势分析模板""" + + def __init__(self): + super().__init__( + name="时间序列趋势分析", + description="分析数据的时间趋势、季节性和周期性特征" + ) + + def build_steps(self, time_column: str = "日期", value_column: str = "数值", **kwargs) -> List[AnalysisStep]: + return [ + AnalysisStep( + name="时间序列数据准备", + description="将数据转换为时间序列格式", + analysis_type="explore", + prompt=f"将 '{time_column}' 列转换为日期格式,按时间排序数据" + ), + AnalysisStep( + name="趋势可视化", + description="绘制时间序列图", + analysis_type="visualize", + prompt=f"绘制 '{value_column}' 随 '{time_column}' 的变化趋势图,添加移动平均线" + ), + AnalysisStep( + name="趋势分析", + description="识别上升、下降或平稳趋势", + analysis_type="calculate", + prompt="计算趋势线斜率,判断整体趋势方向和变化速率" + ), + AnalysisStep( + name="季节性分析", + description="检测季节性模式", + analysis_type="visualize", + prompt="分析月度、季度等周期性模式,绘制季节性分解图" + ), + AnalysisStep( + name="异常点检测", + description="识别时间序列中的异常点", + analysis_type="calculate", + prompt="使用统计方法检测时间序列中的异常值,标注在图表上" + ) + ] + + +class AnomalyDetectionTemplate(AnalysisTemplate): + """异常检测模板""" + + def __init__(self): + super().__init__( + name="异常值检测分析", + description="识别数据中的异常值和离群点" + ) + + def build_steps(self, **kwargs) -> List[AnalysisStep]: + return [ + AnalysisStep( + name="数值列统计分析", + description="计算数值列的统计特征", + analysis_type="calculate", + prompt="计算所有数值列的均值、标准差、四分位数等统计量" + ), + AnalysisStep( + name="箱线图可视化", + description="使用箱线图识别异常值", + analysis_type="visualize", + prompt="为每个数值列绘制箱线图,直观展示异常值分布" + ), + AnalysisStep( + name="Z-Score异常检测", + description="使用Z-Score方法检测异常值", + analysis_type="calculate", + prompt="计算每个数值的Z-Score,标记|Z|>3的异常值" + ), + AnalysisStep( + name="IQR异常检测", + description="使用四分位距方法检测异常值", + analysis_type="calculate", + prompt="使用IQR方法(Q1-1.5*IQR, Q3+1.5*IQR)检测异常值" + ), + AnalysisStep( + name="异常值汇总报告", + description="整理所有检测到的异常值", + analysis_type="report", + prompt="汇总所有异常值,分析其特征和可能原因,提供处理建议" + ) + ] + + +class ComparisonAnalysisTemplate(AnalysisTemplate): + """对比分析模板""" + + def __init__(self): + super().__init__( + name="分组对比分析", + description="对比不同分组之间的差异和特征" + ) + + def build_steps(self, group_column: str = "分组", value_column: str = "数值", **kwargs) -> List[AnalysisStep]: + return [ + AnalysisStep( + name="分组统计", + description="计算各组的统计指标", + analysis_type="calculate", + prompt=f"按 '{group_column}' 分组,计算 '{value_column}' 的均值、中位数、标准差" + ), + AnalysisStep( + name="分组可视化对比", + description="绘制对比图表", + analysis_type="visualize", + prompt=f"绘制各组的柱状图和箱线图,直观对比差异" + ), + AnalysisStep( + name="差异显著性检验", + description="统计检验组间差异", + analysis_type="calculate", + prompt="进行t检验或方差分析,判断组间差异是否显著" + ), + AnalysisStep( + name="对比结论", + description="总结对比结果", + analysis_type="report", + prompt="总结各组特征、主要差异和业务洞察" + ) + ] + + +# 模板注册表 +TEMPLATE_REGISTRY = { + "health_report": HealthReportTemplate, + "trend_analysis": TrendAnalysisTemplate, + "anomaly_detection": AnomalyDetectionTemplate, + "comparison": ComparisonAnalysisTemplate +} + + +def get_template(template_name: str) -> AnalysisTemplate: + """获取分析模板""" + template_class = TEMPLATE_REGISTRY.get(template_name) + if template_class: + return template_class() + else: + raise ValueError(f"未找到模板: {template_name}。可用模板: {list(TEMPLATE_REGISTRY.keys())}") + + +def list_templates() -> List[Dict[str, str]]: + """列出所有可用模板""" + templates = [] + for name, template_class in TEMPLATE_REGISTRY.items(): + template = template_class() + templates.append({ + "name": name, + "display_name": template.name, + "description": template.description + }) + return templates diff --git a/utils/cache_manager.py b/utils/cache_manager.py new file mode 100644 index 0000000..63215c5 --- /dev/null +++ b/utils/cache_manager.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +""" +缓存管理器 - 支持数据和LLM响应缓存 +""" + +import os +import json +import hashlib +import pickle +from pathlib import Path +from typing import Any, Optional, Callable +from functools import wraps + + +class CacheManager: + """缓存管理器""" + + def __init__(self, cache_dir: str = ".cache", enabled: bool = True): + self.cache_dir = Path(cache_dir) + self.enabled = enabled + + if self.enabled: + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def _get_cache_key(self, *args, **kwargs) -> str: + """生成缓存键""" + key_data = f"{args}_{kwargs}" + return hashlib.md5(key_data.encode()).hexdigest() + + def _get_cache_path(self, key: str) -> Path: + """获取缓存文件路径""" + return self.cache_dir / f"{key}.pkl" + + def get(self, key: str) -> Optional[Any]: + """获取缓存""" + if not self.enabled: + return None + + cache_path = self._get_cache_path(key) + if cache_path.exists(): + try: + with open(cache_path, 'rb') as f: + return pickle.load(f) + except Exception as e: + print(f"⚠️ 读取缓存失败: {e}") + return None + return None + + def set(self, key: str, value: Any) -> None: + """设置缓存""" + if not self.enabled: + return + + cache_path = self._get_cache_path(key) + try: + with open(cache_path, 'wb') as f: + pickle.dump(value, f) + except Exception as e: + print(f"⚠️ 写入缓存失败: {e}") + + def clear(self) -> None: + """清空所有缓存""" + if self.cache_dir.exists(): + for cache_file in self.cache_dir.glob("*.pkl"): + cache_file.unlink() + print("✅ 缓存已清空") + + def cached(self, key_func: Optional[Callable] = None): + """缓存装饰器""" + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not self.enabled: + return func(*args, **kwargs) + + # 生成缓存键 + if key_func: + cache_key = key_func(*args, **kwargs) + else: + cache_key = self._get_cache_key(*args, **kwargs) + + # 尝试从缓存获取 + cached_value = self.get(cache_key) + if cached_value is not None: + print(f"💾 使用缓存: {cache_key[:8]}...") + return cached_value + + # 执行函数并缓存结果 + result = func(*args, **kwargs) + self.set(cache_key, result) + return result + + return wrapper + return decorator + + +class LLMCacheManager(CacheManager): + """LLM响应缓存管理器""" + + def get_cache_key_from_messages(self, messages: list, model: str = "") -> str: + """从消息列表生成缓存键""" + key_data = json.dumps(messages, sort_keys=True) + model + return hashlib.md5(key_data.encode()).hexdigest() diff --git a/utils/data_loader.py b/utils/data_loader.py index c1c2cef..9717779 100644 --- a/utils/data_loader.py +++ b/utils/data_loader.py @@ -2,6 +2,17 @@ import os import pandas as pd import io +import hashlib +from pathlib import Path +from typing import Optional, Iterator +from config.app_config import app_config +from utils.cache_manager import CacheManager + +# 初始化缓存管理器 +data_cache = CacheManager( + cache_dir=app_config.cache_dir, + enabled=app_config.data_cache_enabled +) def load_and_profile_data(file_paths: list) -> str: """ @@ -88,3 +99,119 @@ def load_and_profile_data(file_paths: list) -> str: profile_summary += f"❌ 读取或分析文件失败: {str(e)}\n\n" return profile_summary + + +def get_file_hash(file_path: str) -> str: + """计算文件哈希值,用于缓存键""" + hasher = hashlib.md5() + hasher.update(file_path.encode()) + + # 添加文件修改时间 + if os.path.exists(file_path): + mtime = os.path.getmtime(file_path) + hasher.update(str(mtime).encode()) + + return hasher.hexdigest() + + +def load_data_chunked(file_path: str, chunksize: Optional[int] = None) -> Iterator[pd.DataFrame]: + """ + 流式读取大文件,分块返回DataFrame + + Args: + file_path: 文件路径 + chunksize: 每块行数,默认使用配置值 + + Yields: + DataFrame块 + """ + if chunksize is None: + chunksize = app_config.chunk_size + + ext = os.path.splitext(file_path)[1].lower() + + if ext == '.csv': + # 尝试多种编码 + for encoding in ['utf-8', 'gbk', 'latin1']: + try: + chunks = pd.read_csv(file_path, encoding=encoding, chunksize=chunksize) + for chunk in chunks: + yield chunk + break + except UnicodeDecodeError: + continue + except Exception as e: + print(f"❌ 读取CSV文件失败: {e}") + break + elif ext in ['.xlsx', '.xls']: + # Excel文件不支持chunksize,直接读取 + try: + df = pd.read_excel(file_path) + # 手动分块 + for i in range(0, len(df), chunksize): + yield df.iloc[i:i+chunksize] + except Exception as e: + print(f"❌ 读取Excel文件失败: {e}") + + +def load_data_with_cache(file_path: str, force_reload: bool = False) -> Optional[pd.DataFrame]: + """ + 带缓存的数据加载 + + Args: + file_path: 文件路径 + force_reload: 是否强制重新加载 + + Returns: + DataFrame或None + """ + if not os.path.exists(file_path): + print(f"⚠️ 文件不存在: {file_path}") + return None + + # 检查文件大小 + file_size_mb = os.path.getsize(file_path) / (1024 * 1024) + + # 对于大文件,建议使用流式处理 + if file_size_mb > app_config.max_file_size_mb: + print(f"⚠️ 文件过大 ({file_size_mb:.1f}MB),建议使用 load_data_chunked() 流式处理") + + # 生成缓存键 + cache_key = get_file_hash(file_path) + + # 尝试从缓存加载 + if not force_reload and app_config.data_cache_enabled: + cached_data = data_cache.get(cache_key) + if cached_data is not None: + print(f"💾 从缓存加载数据: {os.path.basename(file_path)}") + return cached_data + + # 加载数据 + ext = os.path.splitext(file_path)[1].lower() + df = None + + try: + if ext == '.csv': + # 尝试多种编码 + for encoding in ['utf-8', 'gbk', 'latin1']: + try: + df = pd.read_csv(file_path, encoding=encoding) + break + except UnicodeDecodeError: + continue + elif ext in ['.xlsx', '.xls']: + df = pd.read_excel(file_path) + else: + print(f"⚠️ 不支持的文件格式: {ext}") + return None + + # 缓存数据 + if df is not None and app_config.data_cache_enabled: + data_cache.set(cache_key, df) + print(f"✅ 数据已缓存: {os.path.basename(file_path)}") + + return df + + except Exception as e: + print(f"❌ 加载数据失败: {e}") + return None diff --git a/utils/data_quality.py b/utils/data_quality.py new file mode 100644 index 0000000..4458d62 --- /dev/null +++ b/utils/data_quality.py @@ -0,0 +1,224 @@ +# -*- coding: utf-8 -*- +""" +数据质量检查模块 - 自动评估数据质量并提供改进建议 +""" + +import pandas as pd +import numpy as np +from typing import Dict, List, Tuple, Any +from dataclasses import dataclass + + +@dataclass +class QualityIssue: + """数据质量问题""" + column: str + issue_type: str # missing, duplicate, outlier, type_mismatch等 + severity: str # high, medium, low + description: str + suggestion: str + + +class DataQualityChecker: + """数据质量检查器""" + + def __init__(self, df: pd.DataFrame): + self.df = df + self.issues: List[QualityIssue] = [] + self.quality_score: float = 100.0 + + def check_all(self) -> Dict[str, Any]: + """执行所有质量检查""" + self.check_missing_values() + self.check_duplicates() + self.check_data_types() + self.check_outliers() + self.check_consistency() + + return self.generate_report() + + def check_missing_values(self) -> None: + """检查缺失值""" + for col in self.df.columns: + missing_count = self.df[col].isnull().sum() + missing_ratio = (missing_count / len(self.df)) * 100 + + if missing_ratio > 50: + severity = "high" + self.quality_score -= 10 + elif missing_ratio > 20: + severity = "medium" + self.quality_score -= 5 + elif missing_ratio > 0: + severity = "low" + self.quality_score -= 2 + else: + continue + + issue = QualityIssue( + column=col, + issue_type="missing", + severity=severity, + description=f"列 '{col}' 存在 {missing_count} 个缺失值 ({missing_ratio:.1f}%)", + suggestion=self._suggest_missing_handling(col, missing_ratio) + ) + self.issues.append(issue) + + def check_duplicates(self) -> None: + """检查重复数据""" + duplicate_count = self.df.duplicated().sum() + if duplicate_count > 0: + duplicate_ratio = (duplicate_count / len(self.df)) * 100 + + severity = "high" if duplicate_ratio > 10 else "medium" + self.quality_score -= 5 if severity == "high" else 3 + + issue = QualityIssue( + column="全表", + issue_type="duplicate", + severity=severity, + description=f"发现 {duplicate_count} 行重复数据 ({duplicate_ratio:.1f}%)", + suggestion="建议使用 df.drop_duplicates() 删除重复行,或检查是否为合理的重复记录" + ) + self.issues.append(issue) + + def check_data_types(self) -> None: + """检查数据类型一致性""" + for col in self.df.columns: + # 检查是否有数值列被识别为object + if self.df[col].dtype == 'object': + try: + # 尝试转换为数值 + pd.to_numeric(self.df[col].dropna(), errors='raise') + + issue = QualityIssue( + column=col, + issue_type="type_mismatch", + severity="medium", + description=f"列 '{col}' 当前为文本类型,但可以转换为数值类型", + suggestion=f"建议使用 df['{col}'] = pd.to_numeric(df['{col}']) 转换类型" + ) + self.issues.append(issue) + self.quality_score -= 3 + except: + pass + + def check_outliers(self) -> None: + """检查数值列的异常值""" + numeric_cols = self.df.select_dtypes(include=[np.number]).columns + + for col in numeric_cols: + q1 = self.df[col].quantile(0.25) + q3 = self.df[col].quantile(0.75) + iqr = q3 - q1 + + lower_bound = q1 - 3 * iqr + upper_bound = q3 + 3 * iqr + + outliers = self.df[(self.df[col] < lower_bound) | (self.df[col] > upper_bound)] + outlier_count = len(outliers) + + if outlier_count > 0: + outlier_ratio = (outlier_count / len(self.df)) * 100 + + if outlier_ratio > 5: + severity = "medium" + self.quality_score -= 3 + else: + severity = "low" + self.quality_score -= 1 + + issue = QualityIssue( + column=col, + issue_type="outlier", + severity=severity, + description=f"列 '{col}' 存在 {outlier_count} 个异常值 ({outlier_ratio:.1f}%)", + suggestion=f"建议检查 {lower_bound:.2f} 以下和 {upper_bound:.2f} 以上的值是否合理" + ) + self.issues.append(issue) + + def check_consistency(self) -> None: + """检查数据一致性""" + # 检查时间列的时序性 + datetime_cols = self.df.select_dtypes(include=['datetime64']).columns + + for col in datetime_cols: + if not self.df[col].is_monotonic_increasing: + issue = QualityIssue( + column=col, + issue_type="consistency", + severity="medium", + description=f"时间列 '{col}' 不是单调递增的,可能存在乱序", + suggestion=f"建议使用 df.sort_values('{col}') 进行排序" + ) + self.issues.append(issue) + self.quality_score -= 3 + + def _suggest_missing_handling(self, col: str, missing_ratio: float) -> str: + """建议缺失值处理方法""" + if missing_ratio > 70: + return f"缺失比例过高,建议删除列 '{col}'" + elif missing_ratio > 30: + return f"建议填充或删除缺失值:使用中位数/众数填充或删除含缺失值的行" + else: + if pd.api.types.is_numeric_dtype(self.df[col]): + return f"建议使用均值/中位数填充:df['{col}'].fillna(df['{col}'].median())" + else: + return f"建议使用众数填充:df['{col}'].fillna(df['{col}'].mode()[0])" + + def generate_report(self) -> Dict[str, Any]: + """生成质量报告""" + # 确保质量分数在0-100之间 + self.quality_score = max(0, min(100, self.quality_score)) + + # 按严重程度分类 + high_issues = [i for i in self.issues if i.severity == "high"] + medium_issues = [i for i in self.issues if i.severity == "medium"] + low_issues = [i for i in self.issues if i.severity == "low"] + + return { + "quality_score": round(self.quality_score, 2), + "total_issues": len(self.issues), + "high_severity": len(high_issues), + "medium_severity": len(medium_issues), + "low_severity": len(low_issues), + "issues": self.issues, + "summary": self._generate_summary() + } + + def _generate_summary(self) -> str: + """生成可读的摘要""" + summary = f"## 数据质量报告\n\n" + summary += f"**质量评分**: {self.quality_score:.1f}/100\n\n" + + if self.quality_score >= 90: + summary += "✅ **评级**: 优秀 - 数据质量很好\n\n" + elif self.quality_score >= 75: + summary += "⚠️ **评级**: 良好 - 存在一些小问题\n\n" + elif self.quality_score >= 60: + summary += "⚠️ **评级**: 一般 - 需要处理多个问题\n\n" + else: + summary += "❌ **评级**: 差 - 数据质量问题严重\n\n" + + summary += f"**问题统计**: 共 {len(self.issues)} 个质量问题\n" + summary += f"- 🔴 高严重性: {len([i for i in self.issues if i.severity == 'high'])} 个\n" + summary += f"- 🟡 中严重性: {len([i for i in self.issues if i.severity == 'medium'])} 个\n" + summary += f"- 🟢 低严重性: {len([i for i in self.issues if i.severity == 'low'])} 个\n\n" + + if self.issues: + summary += "### 主要问题:\n\n" + # 只显示高和中严重性的问题 + for issue in self.issues: + if issue.severity in ["high", "medium"]: + emoji = "🔴" if issue.severity == "high" else "🟡" + summary += f"{emoji} **{issue.column}** - {issue.description}\n" + summary += f" 💡 {issue.suggestion}\n\n" + + return summary + + +def quick_quality_check(df: pd.DataFrame) -> str: + """快速数据质量检查""" + checker = DataQualityChecker(df) + report = checker.check_all() + return report['summary'] diff --git a/utils/llm_helper.py b/utils/llm_helper.py index cf88b6c..34b50e0 100644 --- a/utils/llm_helper.py +++ b/utils/llm_helper.py @@ -5,8 +5,17 @@ LLM调用辅助模块 import asyncio import yaml +from typing import Optional, Callable, AsyncIterator from config.llm_config import LLMConfig +from config.app_config import app_config from utils.fallback_openai_client import AsyncFallbackOpenAIClient +from utils.cache_manager import LLMCacheManager + +# 初始化LLM缓存管理器 +llm_cache = LLMCacheManager( + cache_dir=app_config.llm_cache_dir, + enabled=app_config.llm_cache_enabled +) class LLMHelper: """LLM调用辅助类,支持同步和异步调用""" @@ -82,6 +91,104 @@ class LLMHelper: print(f"原始响应: {response}") return {} + async def close(self): """关闭客户端""" - await self.client.close() \ No newline at end of file + await self.client.close() + + async def async_call_with_cache( + self, + prompt: str, + system_prompt: str = None, + max_tokens: int = None, + temperature: float = None, + use_cache: bool = True + ) -> str: + """带缓存的异步LLM调用""" + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + # 生成缓存键 + cache_key = llm_cache.get_cache_key_from_messages(messages, self.config.model) + + # 尝试从缓存获取 + if use_cache and app_config.llm_cache_enabled: + cached_response = llm_cache.get(cache_key) + if cached_response: + print("💾 使用LLM缓存响应") + return cached_response + + # 调用LLM + response = await self.async_call(prompt, system_prompt, max_tokens, temperature) + + # 缓存响应 + if use_cache and app_config.llm_cache_enabled and response: + llm_cache.set(cache_key, response) + + return response + + def call_with_cache( + self, + prompt: str, + system_prompt: str = None, + max_tokens: int = None, + temperature: float = None, + use_cache: bool = True + ) -> str: + """带缓存的同步LLM调用""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + import nest_asyncio + nest_asyncio.apply() + + return loop.run_until_complete( + self.async_call_with_cache(prompt, system_prompt, max_tokens, temperature, use_cache) + ) + + async def async_call_stream( + self, + prompt: str, + system_prompt: str = None, + max_tokens: int = None, + temperature: float = None, + callback: Optional[Callable[[str], None]] = None + ) -> AsyncIterator[str]: + """流式异步LLM调用""" + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + kwargs = { + 'stream': True, + 'max_tokens': max_tokens or self.config.max_tokens, + 'temperature': temperature or self.config.temperature + } + + try: + response = await self.client.chat_completions_create( + messages=messages, + **kwargs + ) + + full_response = "" + async for chunk in response: + if chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + full_response += content + + # 调用回调函数 + if callback: + callback(content) + + yield content + + except Exception as e: + print(f"流式LLM调用失败: {e}") + yield "" \ No newline at end of file diff --git a/web/main.py b/web/main.py index c64efbd..9abac5b 100644 --- a/web/main.py +++ b/web/main.py @@ -43,6 +43,18 @@ class SessionData: self.log_file: Optional[str] = None self.analysis_results: List[Dict] = [] # Store analysis results for gallery self.agent: Optional[DataAnalysisAgent] = None # Store the agent instance for follow-up + + # 新增:进度跟踪 + self.current_round: int = 0 + self.max_rounds: int = 20 + self.progress_percentage: float = 0.0 + self.status_message: str = "等待开始" + + # 新增:历史记录 + self.created_at: str = "" + self.last_updated: str = "" + self.user_requirement: str = "" + self.file_list: List[str] = [] class SessionManager: @@ -56,11 +68,40 @@ class SessionManager: self.sessions[session_id] = SessionData(session_id) return session_id + def get_session(self, session_id: str) -> Optional[SessionData]: return self.sessions.get(session_id) def list_sessions(self): return list(self.sessions.keys()) + + def delete_session(self, session_id: str) -> bool: + """删除指定会话""" + with self.lock: + if session_id in self.sessions: + session = self.sessions[session_id] + if session.agent: + session.agent.reset() + del self.sessions[session_id] + return True + return False + + def get_session_info(self, session_id: str) -> Optional[Dict]: + """获取会话详细信息""" + session = self.get_session(session_id) + if session: + return { + "session_id": session.session_id, + "is_running": session.is_running, + "progress": session.progress_percentage, + "status": session.status_message, + "current_round": session.current_round, + "max_rounds": session.max_rounds, + "created_at": session.created_at, + "last_updated": session.last_updated, + "user_requirement": session.user_requirement[:100] + "..." if len(session.user_requirement) > 100 else session.user_requirement + } + return None session_manager = SessionManager() @@ -477,3 +518,42 @@ async def get_troubleshooting_guide(): return {"content": content} except Exception as e: return {"content": f"# Error Loading Guide\n\n{e}"} + + +# --- 新增API端点 --- + +@app.get("/api/sessions/progress") +async def get_session_progress(session_id: str = Query(..., description="Session ID")): + """获取会话分析进度""" + session_info = session_manager.get_session_info(session_id) + if not session_info: + raise HTTPException(status_code=404, detail="Session not found") + return session_info + + +@app.get("/api/sessions/list") +async def list_all_sessions(): + """获取所有会话列表""" + session_ids = session_manager.list_sessions() + sessions_info = [] + + for sid in session_ids: + info = session_manager.get_session_info(sid) + if info: + sessions_info.append(info) + + return {"sessions": sessions_info, "total": len(sessions_info)} + + +@app.delete("/api/sessions/{session_id}") +async def delete_specific_session(session_id: str): + """删除指定会话""" + success = session_manager.delete_session(session_id) + if not success: + raise HTTPException(status_code=404, detail="Session not found") + return {"status": "deleted", "session_id": session_id} + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000)