"""Tool manager that dynamically selects and manages tools based on data traits."""

from typing import Any, Dict, List, Optional

import pandas as pd

from src.models import DataProfile
from src.tools.base import AnalysisTool, ToolRegistry


class ToolManager:
    """Select the analysis tools that fit the characteristics of a data profile.

    Tools are looked up by name in ``self.registry``. A name whose lookup
    raises ``KeyError`` is remembered so callers can later ask which tools
    were wanted but unavailable (see ``get_missing_tools``).
    """

    # Tool names requested for each kind of column found in the profile.
    _TIME_SERIES_TOOL_NAMES = (
        'get_time_series',
        'calculate_trend',
        'create_line_chart',
    )
    _CATEGORICAL_TOOL_NAMES = (
        'get_column_distribution',
        'get_value_counts',
        'perform_groupby',
        'create_bar_chart',
        'create_pie_chart',
    )
    _NUMERIC_TOOL_NAMES = (
        'calculate_statistics',
        'detect_outliers',
        'get_correlation',
        'create_heatmap',
    )
    _GEO_TOOL_NAMES = ('create_map_visualization',)

    # Substrings that mark a column name as geographic.
    _GEO_KEYWORDS = (
        'lat',
        'lon',
        'latitude',
        'longitude',
        'location',
        'address',
        'city',
        'country',
    )

    def __init__(self, registry: Optional[ToolRegistry] = None):
        """Initialise the tool manager.

        Args:
            registry: Tool registry to look tools up in. A new
                ``ToolRegistry`` is created only when this is ``None``.
        """
        # Compare against None explicitly: a caller-supplied registry must be
        # kept even if it evaluates as falsy (e.g. an empty registry, should
        # ToolRegistry define __len__/__bool__ -- not visible from this file).
        self.registry = registry if registry is not None else ToolRegistry()
        # Names of requested tools the registry could not provide, each
        # recorded once in first-seen order.
        self._missing_tools: List[str] = []

    def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]:
        """Select the tools applicable to the given data profile.

        Args:
            data_profile: Profile whose ``columns`` are inspected.

        Returns:
            The applicable tools, de-duplicated by ``tool.name`` with the
            first occurrence kept. The list is empty when no column matches
            any category, because there are currently no universal tools.
        """
        candidates: List[AnalysisTool] = []

        # Datetime columns
        if self._has_datetime_column(data_profile):
            candidates.extend(self._get_time_series_tools())

        # Categorical columns
        if self._has_categorical_column(data_profile):
            candidates.extend(self._get_categorical_tools())

        # Numeric columns
        if self._has_numeric_column(data_profile):
            candidates.extend(self._get_numeric_tools())

        # Geographic columns
        if self._has_geo_column(data_profile):
            candidates.extend(self._get_geo_tools())

        # Universal tools (applicable to all data)
        candidates.extend(self._get_universal_tools())

        # De-duplicate by name while preserving selection order.
        unique_tools: List[AnalysisTool] = []
        seen_names = set()
        for tool in candidates:
            if tool.name not in seen_names:
                seen_names.add(tool.name)
                unique_tools.append(tool)

        return unique_tools

    @staticmethod
    def _has_column_of_dtype(data_profile: DataProfile, dtype: str) -> bool:
        """Return True if any profile column has exactly the given dtype label."""
        return any(col.dtype == dtype for col in data_profile.columns)

    def _has_datetime_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a datetime column."""
        return self._has_column_of_dtype(data_profile, 'datetime')

    def _has_categorical_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a categorical column."""
        return self._has_column_of_dtype(data_profile, 'categorical')

    def _has_numeric_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a numeric column."""
        return self._has_column_of_dtype(data_profile, 'numeric')

    def _has_geo_column(self, data_profile: DataProfile) -> bool:
        """Return True if any column name contains a geo-related keyword.

        The check is a case-insensitive substring match on the column name.
        """
        # NOTE(review): substring matching also fires on unrelated names that
        # merely contain a keyword (e.g. 'lat' inside 'platform'). Behaviour
        # is kept as-is; confirm whether token-based matching is wanted.
        return any(
            keyword in col.name.lower()
            for col in data_profile.columns
            for keyword in self._GEO_KEYWORDS
        )

    def _collect_tools(self, tool_names) -> List[AnalysisTool]:
        """Look up each name in the registry and return the tools found.

        A ``KeyError`` from ``self.registry.get_tool`` is treated as "tool not
        available": the name is recorded in ``self._missing_tools`` (once) and
        the lookup continues with the next name.
        """
        found: List[AnalysisTool] = []
        for tool_name in tool_names:
            try:
                tool = self.registry.get_tool(tool_name)
            except KeyError:
                # Record each missing name only once so repeated selections
                # do not grow the list without bound.
                if tool_name not in self._missing_tools:
                    self._missing_tools.append(tool_name)
            else:
                found.append(tool)
        return found

    def _get_time_series_tools(self) -> List[AnalysisTool]:
        """Return the available time-series analysis tools."""
        return self._collect_tools(self._TIME_SERIES_TOOL_NAMES)

    def _get_categorical_tools(self) -> List[AnalysisTool]:
        """Return the available categorical-data analysis tools."""
        return self._collect_tools(self._CATEGORICAL_TOOL_NAMES)

    def _get_numeric_tools(self) -> List[AnalysisTool]:
        """Return the available numeric-data analysis tools."""
        return self._collect_tools(self._NUMERIC_TOOL_NAMES)

    def _get_geo_tools(self) -> List[AnalysisTool]:
        """Return the available geographic-data analysis tools."""
        # Per the original author's note, no geo tool is implemented yet, so
        # this lookup is expected to end up in the missing-tools list.
        return self._collect_tools(self._GEO_TOOL_NAMES)

    def _get_universal_tools(self) -> List[AnalysisTool]:
        """Return the universal tools (applicable to all data).

        Currently always empty: per the original author's note, the universal
        tools are already included in the other categories.
        """
        return []

    def get_missing_tools(self) -> List[str]:
        """Return the names of tools that were requested but not available.

        Returns:
            Unique tool names, in the order they were first found missing.
        """
        # dict.fromkeys de-duplicates while keeping insertion order, which
        # makes the result deterministic (a set would not be).
        return list(dict.fromkeys(self._missing_tools))

    def clear_missing_tools(self) -> None:
        """Forget all recorded missing tools."""
        self._missing_tools = []

    def get_tool_descriptions(self, tools: List[AnalysisTool]) -> List[Dict[str, Any]]:
        """Build description records for the given tools (for AI selection).

        Args:
            tools: Tools to describe.

        Returns:
            One dict per tool, in input order, with the keys ``'name'``,
            ``'description'`` and ``'parameters'`` taken from the tool's
            attributes of the same names.
        """
        return [
            {
                'name': tool.name,
                'description': tool.description,
                'parameters': tool.parameters,
            }
            for tool in tools
        ]