# vibe_data_ana/src/tools/tool_manager.py -- Python, 183 lines, 6.1 KiB (as listed by the repository viewer)

"""工具管理器,负责根据数据特征动态选择和管理工具。"""
from typing import Any, Dict, Iterable, List, Optional

import pandas as pd

from src.models import DataProfile
from src.tools.base import AnalysisTool, ToolRegistry
class ToolManager:
    """Select analysis tools that match the characteristics of a dataset.

    The manager inspects a data profile (its ``columns`` and each column's
    ``name`` / ``dtype``), asks the registry for the tools associated with
    every detected category, and remembers the names of tools the registry
    could not provide. Missing names accumulate across calls until
    ``clear_missing_tools`` is invoked.
    """

    # Tool names requested from the registry for each data category.
    _TIME_SERIES_TOOLS = ('get_time_series', 'calculate_trend', 'create_line_chart')
    _CATEGORICAL_TOOLS = ('get_column_distribution', 'get_value_counts', 'perform_groupby',
                          'create_bar_chart', 'create_pie_chart')
    _NUMERIC_TOOLS = ('calculate_statistics', 'detect_outliers', 'get_correlation', 'create_heatmap')
    # No geo tool is implemented at present, so this name is normally
    # recorded as missing.
    _GEO_TOOLS = ('create_map_visualization',)

    # Substrings of a (lower-cased) column name that mark it as geographic.
    _GEO_KEYWORDS = ('lat', 'lon', 'latitude', 'longitude', 'location', 'address', 'city', 'country')

    def __init__(self, registry: Optional[ToolRegistry] = None):
        """Initialise the tool manager.

        Args:
            registry: Tool registry to query. When ``None``, a new
                ``ToolRegistry`` is created.
        """
        # Compare against None explicitly: a caller-supplied registry that
        # merely evaluates as falsy must not be silently replaced.
        self.registry = registry if registry is not None else ToolRegistry()
        self._missing_tools: List[str] = []

    def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]:
        """Choose the tools applicable to the given data profile.

        Args:
            data_profile: Profile whose ``columns`` are inspected.

        Returns:
            The applicable tools, de-duplicated by ``tool.name`` and kept in
            first-seen order (time series, categorical, numeric, geo,
            universal).
        """
        candidates: List[AnalysisTool] = []

        if self._has_datetime_column(data_profile):
            candidates.extend(self._get_time_series_tools())
        if self._has_categorical_column(data_profile):
            candidates.extend(self._get_categorical_tools())
        if self._has_numeric_column(data_profile):
            candidates.extend(self._get_numeric_tools())
        if self._has_geo_column(data_profile):
            candidates.extend(self._get_geo_tools())

        # Universal tools apply to every dataset.
        candidates.extend(self._get_universal_tools())

        # Drop duplicates while keeping the first occurrence of each name.
        unique_tools: List[AnalysisTool] = []
        seen_names = set()
        for tool in candidates:
            if tool.name not in seen_names:
                seen_names.add(tool.name)
                unique_tools.append(tool)
        return unique_tools

    @staticmethod
    def _has_column_of_type(data_profile: DataProfile, dtype: str) -> bool:
        """Return True if any profiled column has exactly the given dtype."""
        return any(col.dtype == dtype for col in data_profile.columns)

    def _has_datetime_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a datetime column."""
        return self._has_column_of_type(data_profile, 'datetime')

    def _has_categorical_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a categorical column."""
        return self._has_column_of_type(data_profile, 'categorical')

    def _has_numeric_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a numeric column."""
        return self._has_column_of_type(data_profile, 'numeric')

    def _has_geo_column(self, data_profile: DataProfile) -> bool:
        """Return True if any column name contains a geo-related keyword.

        The check is a case-insensitive substring match on the column name.
        """
        # NOTE(review): plain substring matching also accepts unrelated names
        # (e.g. "platform" contains "lat"); kept as-is to preserve behavior.
        return any(
            keyword in col.name.lower()
            for col in data_profile.columns
            for keyword in self._GEO_KEYWORDS
        )

    def _collect_tools(self, tool_names: Iterable[str]) -> List[AnalysisTool]:
        """Fetch the named tools from the registry.

        A ``KeyError`` from ``registry.get_tool`` is treated as "tool not
        registered": the name is appended to the missing-tools record and the
        lookup continues with the next name.
        """
        found: List[AnalysisTool] = []
        for tool_name in tool_names:
            try:
                tool = self.registry.get_tool(tool_name)
            except KeyError:
                self._missing_tools.append(tool_name)
            else:
                found.append(tool)
        return found

    def _get_time_series_tools(self) -> List[AnalysisTool]:
        """Return the registered time-series analysis tools."""
        return self._collect_tools(self._TIME_SERIES_TOOLS)

    def _get_categorical_tools(self) -> List[AnalysisTool]:
        """Return the registered categorical-data analysis tools."""
        return self._collect_tools(self._CATEGORICAL_TOOLS)

    def _get_numeric_tools(self) -> List[AnalysisTool]:
        """Return the registered numeric-data analysis tools."""
        return self._collect_tools(self._NUMERIC_TOOLS)

    def _get_geo_tools(self) -> List[AnalysisTool]:
        """Return the registered geographic-data analysis tools."""
        return self._collect_tools(self._GEO_TOOLS)

    def _get_universal_tools(self) -> List[AnalysisTool]:
        """Return tools that apply to all data.

        Currently always empty: the universal tools are already covered by
        the other categories.
        """
        return []

    def get_missing_tools(self) -> List[str]:
        """Return the names of tools that were requested but not registered.

        Returns:
            A new list of unique tool names, in the order they were first
            found to be missing.
        """
        # dict.fromkeys de-duplicates while preserving insertion order, which
        # makes the result deterministic (a set would not be).
        return list(dict.fromkeys(self._missing_tools))

    def clear_missing_tools(self) -> None:
        """Forget all recorded missing tool names."""
        self._missing_tools = []

    def get_tool_descriptions(self, tools: List[AnalysisTool]) -> List[Dict[str, Any]]:
        """Build description dicts for the given tools (e.g. for AI selection).

        Args:
            tools: Tools to describe.

        Returns:
            One dict per tool with the keys ``'name'``, ``'description'`` and
            ``'parameters'``, in the same order as ``tools``.
        """
        return [
            {
                'name': tool.name,
                'description': tool.description,
                'parameters': tool.parameters,
            }
            for tool in tools
        ]