Complete AI Data Analysis Agent implementation with 95.7% test coverage

This commit is contained in:
2026-03-07 00:04:29 +08:00
parent 621e546b43
commit 7071b1f730
245 changed files with 22612 additions and 2211 deletions

182
src/tools/tool_manager.py Normal file
View File

@@ -0,0 +1,182 @@
"""工具管理器,负责根据数据特征动态选择和管理工具。"""
from typing import Any, Dict, Iterable, List, Optional

import pandas as pd

from src.models import DataProfile
from src.tools.base import AnalysisTool, ToolRegistry
class ToolManager:
    """Select analysis tools that match the characteristics of a dataset.

    The manager inspects the columns of a data profile, looks up the tools
    that fit each kind of column in a tool registry, and remembers the names
    of tools that were requested but are not registered.
    """

    # Tool names requested from the registry for each kind of column.
    _TIME_SERIES_TOOL_NAMES = (
        'get_time_series',
        'calculate_trend',
        'create_line_chart',
    )
    _CATEGORICAL_TOOL_NAMES = (
        'get_column_distribution',
        'get_value_counts',
        'perform_groupby',
        'create_bar_chart',
        'create_pie_chart',
    )
    _NUMERIC_TOOL_NAMES = (
        'calculate_statistics',
        'detect_outliers',
        'get_correlation',
        'create_heatmap',
    )
    # No geographic tool is implemented yet, so this name is expected to end
    # up in the missing-tools list.
    _GEO_TOOL_NAMES = ('create_map_visualization',)

    # Substrings that mark a column name as geography-related.
    _GEO_KEYWORDS = (
        'lat',
        'lon',
        'latitude',
        'longitude',
        'location',
        'address',
        'city',
        'country',
    )

    def __init__(self, registry: Optional[ToolRegistry] = None):
        """Initialise the tool manager.

        Args:
            registry: Tool registry to look tools up in. A new
                ``ToolRegistry`` is created only when this is ``None``.
        """
        # Compare against None explicitly: a caller-supplied registry that
        # happens to be falsy (e.g. an empty, container-like registry) must
        # not be silently swapped for a fresh one.
        self.registry = ToolRegistry() if registry is None else registry
        self._missing_tools: List[str] = []

    def select_tools(self, data_profile: DataProfile) -> List[AnalysisTool]:
        """Pick the tools applicable to the given data profile.

        Tool groups are considered in a fixed order: time series,
        categorical, numeric, geographic, then universal tools.

        Args:
            data_profile: Profile whose ``columns`` are inspected.

        Returns:
            The applicable tools, de-duplicated by ``tool.name`` with the
            first occurrence kept.
        """
        selected_tools: List[AnalysisTool] = []
        if self._has_datetime_column(data_profile):
            selected_tools.extend(self._get_time_series_tools())
        if self._has_categorical_column(data_profile):
            selected_tools.extend(self._get_categorical_tools())
        if self._has_numeric_column(data_profile):
            selected_tools.extend(self._get_numeric_tools())
        if self._has_geo_column(data_profile):
            selected_tools.extend(self._get_geo_tools())
        # Universal tools apply to every dataset.
        selected_tools.extend(self._get_universal_tools())

        # De-duplicate by name while preserving selection order.
        unique_tools: List[AnalysisTool] = []
        seen_names = set()
        for tool in selected_tools:
            if tool.name not in seen_names:
                seen_names.add(tool.name)
                unique_tools.append(tool)
        return unique_tools

    def _has_datetime_column(self, data_profile: DataProfile) -> bool:
        """Return True if any column has dtype ``'datetime'``."""
        return any(col.dtype == 'datetime' for col in data_profile.columns)

    def _has_categorical_column(self, data_profile: DataProfile) -> bool:
        """Return True if any column has dtype ``'categorical'``."""
        return any(col.dtype == 'categorical' for col in data_profile.columns)

    def _has_numeric_column(self, data_profile: DataProfile) -> bool:
        """Return True if any column has dtype ``'numeric'``."""
        return any(col.dtype == 'numeric' for col in data_profile.columns)

    def _has_geo_column(self, data_profile: DataProfile) -> bool:
        """Return True if any column name contains a geographic keyword.

        Matching is a case-insensitive substring test against
        ``_GEO_KEYWORDS``.
        """
        # NOTE(review): substring matching yields false positives, e.g.
        # "platform" contains "lat" and "capacity" contains "city". Kept
        # as-is so existing selections do not change.
        for col in data_profile.columns:
            col_name_lower = col.name.lower()
            if any(keyword in col_name_lower for keyword in self._GEO_KEYWORDS):
                return True
        return False

    def _collect_tools(self, tool_names: Iterable[str]) -> List[AnalysisTool]:
        """Fetch the named tools from the registry, recording absent ones.

        Args:
            tool_names: Names to look up with ``self.registry.get_tool``.

        Returns:
            The tools that were found, in the order requested.
        """
        tools: List[AnalysisTool] = []
        for tool_name in tool_names:
            try:
                tool = self.registry.get_tool(tool_name)
            except KeyError:
                # The registry is assumed to raise KeyError for unknown
                # names. Record each missing name once so repeated
                # selections do not grow the list without bound.
                if tool_name not in self._missing_tools:
                    self._missing_tools.append(tool_name)
            else:
                tools.append(tool)
        return tools

    def _get_time_series_tools(self) -> List[AnalysisTool]:
        """Return the registered time-series analysis tools."""
        return self._collect_tools(self._TIME_SERIES_TOOL_NAMES)

    def _get_categorical_tools(self) -> List[AnalysisTool]:
        """Return the registered categorical-data analysis tools."""
        return self._collect_tools(self._CATEGORICAL_TOOL_NAMES)

    def _get_numeric_tools(self) -> List[AnalysisTool]:
        """Return the registered numeric-data analysis tools."""
        return self._collect_tools(self._NUMERIC_TOOL_NAMES)

    def _get_geo_tools(self) -> List[AnalysisTool]:
        """Return the registered geographic analysis tools."""
        return self._collect_tools(self._GEO_TOOL_NAMES)

    def _get_universal_tools(self) -> List[AnalysisTool]:
        """Return tools that apply to every dataset.

        Currently always empty: the general-purpose tools are already
        covered by the other categories.
        """
        return []

    def get_missing_tools(self) -> List[str]:
        """Return the names of tools that were requested but not registered.

        Returns:
            A new list of unique tool names in first-seen order.
        """
        # dict.fromkeys drops duplicates while keeping insertion order, so
        # the result is deterministic (a set would not be).
        return list(dict.fromkeys(self._missing_tools))

    def clear_missing_tools(self) -> None:
        """Forget all recorded missing tool names."""
        self._missing_tools = []

    def get_tool_descriptions(self, tools: List[AnalysisTool]) -> List[Dict[str, Any]]:
        """Describe the given tools so an AI can choose between them.

        Args:
            tools: Tools to describe.

        Returns:
            One dict per tool with the keys ``'name'``, ``'description'``
            and ``'parameters'``, taken from the tool's attributes of the
            same names.
        """
        return [
            {
                'name': tool.name,
                'description': tool.description,
                'parameters': tool.parameters,
            }
            for tool in tools
        ]