Complete AI Data Analysis Agent implementation with 95.7% test coverage
This commit is contained in:
182
src/tools/tool_manager.py
Normal file
182
src/tools/tool_manager.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""工具管理器,负责根据数据特征动态选择和管理工具。"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
import pandas as pd
|
||||
|
||||
from src.tools.base import AnalysisTool, ToolRegistry
|
||||
from src.models import DataProfile
|
||||
|
||||
|
||||
class ToolManager:
    """Select analysis tools that match the characteristics of a dataset.

    Tools are looked up by name in a tool registry. Names that the registry
    cannot resolve are remembered so callers can report them afterwards via
    ``get_missing_tools``.
    """

    # Tool names requested for each category of column.
    _TIME_SERIES_TOOL_NAMES = (
        'get_time_series', 'calculate_trend', 'create_line_chart',
    )
    _CATEGORICAL_TOOL_NAMES = (
        'get_column_distribution', 'get_value_counts', 'perform_groupby',
        'create_bar_chart', 'create_pie_chart',
    )
    _NUMERIC_TOOL_NAMES = (
        'calculate_statistics', 'detect_outliers', 'get_correlation',
        'create_heatmap',
    )
    # No geographic tool is implemented yet, so this name is expected to end
    # up in the missing-tools list whenever a geographic column is detected.
    _GEO_TOOL_NAMES = ('create_map_visualization',)

    # Substrings of a lower-cased column name that mark it as geographic.
    _GEO_KEYWORDS = (
        'lat', 'lon', 'latitude', 'longitude',
        'location', 'address', 'city', 'country',
    )

    def __init__(self, registry: ToolRegistry = None):
        """Initialise the tool manager.

        Args:
            registry: Tool registry to look tools up in. When ``None`` a new
                ``ToolRegistry`` is created.
        """
        # Compare against None explicitly: an object that evaluates as falsy
        # (for example one that defines __len__ and is currently empty) is
        # still the registry the caller asked for and must not be replaced.
        self.registry = registry if registry is not None else ToolRegistry()
        self._missing_tools: List[str] = []

    def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]:
        """Select the tools applicable to the given data profile.

        Categories are checked in a fixed order (datetime, categorical,
        numeric, geographic, universal) and the result keeps that order.

        Args:
            data_profile: Profile whose ``columns`` are inspected; each column
                is read through its ``name`` and ``dtype`` attributes.

        Returns:
            The applicable tools, de-duplicated by ``tool.name`` with the
            first occurrence kept.
        """
        candidates: List[AnalysisTool] = []

        if self._has_datetime_column(data_profile):
            candidates.extend(self._get_time_series_tools())
        if self._has_categorical_column(data_profile):
            candidates.extend(self._get_categorical_tools())
        if self._has_numeric_column(data_profile):
            candidates.extend(self._get_numeric_tools())
        if self._has_geo_column(data_profile):
            candidates.extend(self._get_geo_tools())

        # Universal tools apply to every dataset.
        candidates.extend(self._get_universal_tools())

        # De-duplicate by name; setdefault keeps the first tool seen for a
        # name and dicts preserve insertion order.
        by_name: Dict[str, AnalysisTool] = {}
        for tool in candidates:
            by_name.setdefault(tool.name, tool)
        return list(by_name.values())

    def _has_dtype(self, data_profile: DataProfile, dtype: str) -> bool:
        """Return True if any column of the profile has exactly this dtype."""
        return any(col.dtype == dtype for col in data_profile.columns)

    def _has_datetime_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a datetime column."""
        return self._has_dtype(data_profile, 'datetime')

    def _has_categorical_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a categorical column."""
        return self._has_dtype(data_profile, 'categorical')

    def _has_numeric_column(self, data_profile: DataProfile) -> bool:
        """Return True if the profile contains a numeric column."""
        return self._has_dtype(data_profile, 'numeric')

    def _has_geo_column(self, data_profile: DataProfile) -> bool:
        """Return True if any column name contains a geographic keyword.

        Matching is a case-insensitive substring test on the column name.
        """
        # NOTE(review): substring matching also fires on unrelated names such
        # as "platform" (contains "lat") or "capacity" (contains "city");
        # confirm whether stricter matching is wanted before changing it.
        return any(
            keyword in col.name.lower()
            for col in data_profile.columns
            for keyword in self._GEO_KEYWORDS
        )

    def _collect_tools(self, tool_names) -> List[AnalysisTool]:
        """Resolve tool names through the registry, recording the missing ones.

        A ``KeyError`` raised by ``registry.get_tool`` is treated as "tool not
        registered": the name is added to the missing-tools list (once) and
        the lookup continues with the next name.

        Args:
            tool_names: Iterable of tool names to resolve, in order.

        Returns:
            The tools that were found, in the order requested.
        """
        found = []
        for tool_name in tool_names:
            try:
                tool = self.registry.get_tool(tool_name)
            except KeyError:
                # Record each missing name only once so repeated selections
                # do not make the list grow without bound.
                if tool_name not in self._missing_tools:
                    self._missing_tools.append(tool_name)
            else:
                found.append(tool)
        return found

    def _get_time_series_tools(self) -> List[AnalysisTool]:
        """Return the registered time-series analysis tools."""
        return self._collect_tools(self._TIME_SERIES_TOOL_NAMES)

    def _get_categorical_tools(self) -> List[AnalysisTool]:
        """Return the registered categorical-data analysis tools."""
        return self._collect_tools(self._CATEGORICAL_TOOL_NAMES)

    def _get_numeric_tools(self) -> List[AnalysisTool]:
        """Return the registered numeric-data analysis tools."""
        return self._collect_tools(self._NUMERIC_TOOL_NAMES)

    def _get_geo_tools(self) -> List[AnalysisTool]:
        """Return the registered geographic analysis tools."""
        return self._collect_tools(self._GEO_TOOL_NAMES)

    def _get_universal_tools(self) -> List[AnalysisTool]:
        """Return tools that apply to every dataset.

        Currently always empty: the universal tools are already covered by
        the other categories.
        """
        return []

    def get_missing_tools(self) -> List[str]:
        """Return the names of tools that could not be found in the registry.

        Returns:
            A new list of unique tool names, in the order they were first
            found to be missing.
        """
        # dict.fromkeys removes duplicates while keeping first-seen order,
        # which makes the result deterministic (unlike going through a set).
        return list(dict.fromkeys(self._missing_tools))

    def clear_missing_tools(self) -> None:
        """Forget all recorded missing tool names."""
        self._missing_tools = []

    def get_tool_descriptions(self, tools: List[AnalysisTool]) -> List[Dict[str, Any]]:
        """Describe tools so that an AI can choose between them.

        Args:
            tools: Tools to describe.

        Returns:
            One dict per tool with the keys ``'name'``, ``'description'`` and
            ``'parameters'``, in the same order as ``tools``.
        """
        return [
            {
                'name': tool.name,
                'description': tool.description,
                'parameters': tool.parameters,
            }
            for tool in tools
        ]
|
||||
Reference in New Issue
Block a user