621 lines
25 KiB
Python
621 lines
25 KiB
Python
|
|
"""工具系统的基于属性的测试。"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
from hypothesis import given, strategies as st, settings, assume
|
|||
|
|
from typing import Dict, Any
|
|||
|
|
|
|||
|
|
from src.tools.base import AnalysisTool, ToolRegistry
|
|||
|
|
from src.tools.query_tools import (
|
|||
|
|
GetColumnDistributionTool,
|
|||
|
|
GetValueCountsTool,
|
|||
|
|
GetTimeSeriesTool,
|
|||
|
|
GetCorrelationTool
|
|||
|
|
)
|
|||
|
|
from src.tools.stats_tools import (
|
|||
|
|
CalculateStatisticsTool,
|
|||
|
|
PerformGroupbyTool,
|
|||
|
|
DetectOutliersTool,
|
|||
|
|
CalculateTrendTool
|
|||
|
|
)
|
|||
|
|
from src.models import DataProfile, ColumnInfo
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Hypothesis 策略用于生成测试数据
|
|||
|
|
|
|||
|
|
@st.composite
def column_info_strategy(draw):
    """Build a random ``ColumnInfo`` instance.

    The dtype is drawn first because it decides whether a ``statistics``
    payload is attached: numeric columns get a finite ``mean``, every other
    dtype gets an empty dict. The remaining fields are drawn in the same order
    as the ``ColumnInfo`` keyword arguments.
    """
    kind = draw(st.sampled_from(['numeric', 'categorical', 'datetime', 'text']))
    letters = st.characters(whitelist_categories=('Lu', 'Ll'))

    col_name = draw(st.text(min_size=1, max_size=20, alphabet=letters))
    missing = draw(st.floats(min_value=0.0, max_value=1.0))
    n_unique = draw(st.integers(min_value=1, max_value=1000))
    samples = draw(st.lists(st.integers(), min_size=1, max_size=5))

    if kind == 'numeric':
        stats = {'mean': draw(st.floats(allow_nan=False, allow_infinity=False))}
    else:
        stats = {}

    return ColumnInfo(
        name=col_name,
        dtype=kind,
        missing_rate=missing,
        unique_count=n_unique,
        sample_values=samples,
        statistics=stats,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@st.composite
def data_profile_strategy(draw):
    """Build a random ``DataProfile`` made of 1-10 random columns.

    ``column_count`` is derived from the drawn column list so the profile is
    internally consistent. ``key_fields`` is always empty.
    """
    cols = draw(st.lists(column_info_strategy(), min_size=1, max_size=10))
    path = draw(st.text(min_size=1, max_size=50))
    n_rows = draw(st.integers(min_value=1, max_value=10000))
    kind = draw(st.sampled_from(['ticket', 'sales', 'user', 'unknown']))
    score = draw(st.floats(min_value=0.0, max_value=100.0))
    text_summary = draw(st.text(max_size=100))

    return DataProfile(
        file_path=path,
        row_count=n_rows,
        column_count=len(cols),
        columns=cols,
        inferred_type=kind,
        key_fields={},
        quality_score=score,
        summary=text_summary,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@st.composite
def dataframe_strategy(draw, min_rows=10, max_rows=100, min_cols=2, max_cols=10):
    """Generate a random DataFrame whose columns are int, float or str.

    Only the shape, the per-column types, an RNG seed and a small Unicode
    letter alphabet are drawn from Hypothesis. The cell values come from a
    NumPy generator seeded with the drawn seed, so every example is still
    fully reproducible from Hypothesis' choices.

    Why: drawing every cell through Hypothesis (as a per-cell ``st.lists``
    would) costs thousands of draws for the 200-500 row frames requested by
    ``test_tool_output_filtering``. That can exceed Hypothesis' bounded
    per-example data budget, causing overrun examples and the
    ``data_too_large`` health check.

    Args:
        draw: Hypothesis draw function supplied by ``st.composite``.
        min_rows, max_rows: inclusive bounds on the number of rows.
        min_cols, max_cols: inclusive bounds on the number of columns.

    Returns:
        A ``pd.DataFrame`` with columns ``col_0`` .. ``col_{n-1}``. Values are:
        - int: integers in [-1000, 1000];
        - float: finite floats in [-1000.0, 1000.0);
        - str: 1-10 Unicode letters (categories Lu/Ll).
    """
    n_rows = draw(st.integers(min_value=min_rows, max_value=max_rows))
    n_cols = draw(st.integers(min_value=min_cols, max_value=max_cols))
    col_types = draw(st.lists(
        st.sampled_from(['int', 'float', 'str']),
        min_size=n_cols,
        max_size=n_cols,
    ))
    seed = draw(st.integers(min_value=0, max_value=2**32 - 1))
    # A small drawn alphabet keeps Unicode coverage for string columns while
    # the (potentially many) string cells themselves are built by the RNG.
    alphabet = draw(st.lists(
        st.characters(whitelist_categories=('Lu', 'Ll')),
        min_size=1,
        max_size=30,
        unique=True,
    ))

    rng = np.random.default_rng(seed)
    data = {}
    for i, col_type in enumerate(col_types):
        if col_type == 'int':
            values = rng.integers(-1000, 1000, size=n_rows, endpoint=True).tolist()
        elif col_type == 'float':
            values = rng.uniform(-1000.0, 1000.0, size=n_rows).tolist()
        else:  # str
            lengths = rng.integers(1, 10, size=n_rows, endpoint=True)
            values = [''.join(rng.choice(alphabet, size=length)) for length in lengths]
        data[f'col_{i}'] = values

    return pd.DataFrame(data)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 获取所有工具类用于测试
|
|||
|
|
ALL_TOOLS = [
    # Query tools (src.tools.query_tools)
    GetColumnDistributionTool,
    GetValueCountsTool,
    GetTimeSeriesTool,
    GetCorrelationTool,
    # Statistics tools (src.tools.stats_tools)
    CalculateStatisticsTool,
    PerformGroupbyTool,
    DetectOutliersTool,
    CalculateTrendTool,
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Feature: true-ai-agent, Property 10: 工具接口一致性
|
|||
|
|
@given(tool_class=st.sampled_from(ALL_TOOLS))
@settings(max_examples=20)
def test_tool_interface_consistency(tool_class):
    """
    Property 10: every tool implements the standard interface (name,
    description, parameters, execute, is_applicable). ``execute`` must accept
    a DataFrame plus keyword arguments and return its aggregated result as a
    dict.

    Validates: FR-4.1
    """
    import inspect

    tool = tool_class()

    # The tool must derive from AnalysisTool.
    assert isinstance(tool, AnalysisTool), f"{tool_class.__name__} 不是 AnalysisTool 的子类"

    # name: a non-empty string.
    assert hasattr(tool, 'name'), f"{tool_class.__name__} 缺少 name 属性"
    assert isinstance(tool.name, str), f"{tool_class.__name__}.name 不是字符串"
    assert len(tool.name) > 0, f"{tool_class.__name__}.name 是空字符串"

    # description: a non-empty string.
    assert hasattr(tool, 'description'), f"{tool_class.__name__} 缺少 description 属性"
    assert isinstance(tool.description, str), f"{tool_class.__name__}.description 不是字符串"
    assert len(tool.description) > 0, f"{tool_class.__name__}.description 是空字符串"

    # parameters: a dict shaped like a JSON Schema object description.
    assert hasattr(tool, 'parameters'), f"{tool_class.__name__} 缺少 parameters 属性"
    assert isinstance(tool.parameters, dict), f"{tool_class.__name__}.parameters 不是字典"
    params = tool.parameters
    assert 'type' in params, f"{tool_class.__name__}.parameters 缺少 'type' 字段"
    assert params['type'] == 'object', f"{tool_class.__name__}.parameters.type 不是 'object'"

    # execute and is_applicable must exist and be callable.
    assert hasattr(tool, 'execute'), f"{tool_class.__name__} 缺少 execute 方法"
    assert callable(tool.execute), f"{tool_class.__name__}.execute 不可调用"
    assert hasattr(tool, 'is_applicable'), f"{tool_class.__name__} 缺少 is_applicable 方法"
    assert callable(tool.is_applicable), f"{tool_class.__name__}.is_applicable 不可调用"

    test_df = pd.DataFrame({
        'col_0': [1, 2, 3, 4, 5],
        'col_1': ['a', 'b', 'c', 'd', 'e'],
    })

    # execute must accept the DataFrame as a positional argument. The bare
    # call below cannot prove this by itself: it has to tolerate TypeError
    # (raised for a missing required parameter), which would also hide a
    # signature that rejects the DataFrame. bind_partial only checks the
    # positional slot and ignores missing required keyword parameters.
    try:
        inspect.signature(tool.execute).bind_partial(test_df)
    except TypeError as exc:
        pytest.fail(f"{tool_class.__name__}.execute 不接受 DataFrame 作为位置参数:{exc}")

    # Calling with no tool parameters may legitimately fail parameter
    # validation; any other exception type propagates and fails the test.
    try:
        tool.execute(test_df)
    except (KeyError, ValueError, TypeError):
        pass

    # To check the return type, call each known tool with valid arguments.
    numeric_df = pd.DataFrame({
        'col_0': [1, 2, 3, 4, 5],
        'col_1': [2, 4, 6, 8, 10],
    })
    time_df = pd.DataFrame({
        'time': pd.date_range('2020-01-01', periods=5),
        'value': [1, 2, 3, 4, 5],
    })
    valid_calls = {
        'get_column_distribution': (test_df, {'column': 'col_0'}),
        'get_value_counts': (test_df, {'column': 'col_0'}),
        'calculate_statistics': (test_df, {'column': 'col_0'}),
        'perform_groupby': (test_df, {'group_by': 'col_1'}),
        'detect_outliers': (test_df, {'column': 'col_0'}),
        'get_correlation': (numeric_df, {}),
        'get_time_series': (time_df, {'time_column': 'time'}),
        'calculate_trend': (time_df, {'time_column': 'time', 'value_column': 'value'}),
    }
    if tool.name not in valid_calls:
        # Unknown tool: no valid argument set is known, so skip the return-type check.
        return

    frame, kwargs = valid_calls[tool.name]
    result = tool.execute(frame, **kwargs)

    # The return value must be a dict.
    assert isinstance(result, dict), f"{tool_class.__name__}.execute 返回值不是字典,而是 {type(result)}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Feature: true-ai-agent, Property 19: 工具输出过滤
|
|||
|
|
@given(
    tool_class=st.sampled_from(ALL_TOOLS),
    df=dataframe_strategy(min_rows=200, max_rows=500)
)
@settings(max_examples=20, deadline=None)
def test_tool_output_filtering(tool_class, df):
    """
    Property 19: for any tool execution, the returned data must be aggregated
    (statistics, group counts, chart data). A single result must not contain
    more than 100 rows, and it must never contain the full raw table.

    Validates: constraint 5.3
    """
    tool = tool_class()

    # dataframe_strategy(min_rows=200) already guarantees this; it is kept as
    # an explicit precondition of the property.
    assume(len(df) >= 200)

    numeric_cols = df.select_dtypes(include=[np.number]).columns

    # Call the tool with arguments suited to its kind. ``result`` stays None
    # when the generated frame lacks the columns the tool needs.
    result = None
    try:
        if tool.name == 'get_column_distribution':
            result = tool.execute(df, column=df.columns[0], top_n=10)
        elif tool.name == 'get_value_counts':
            result = tool.execute(df, column=df.columns[0])
        elif tool.name in ('calculate_statistics', 'detect_outliers'):
            if len(numeric_cols) > 0:
                result = tool.execute(df, column=numeric_cols[0])
        elif tool.name == 'perform_groupby':
            result = tool.execute(df, group_by=df.columns[0])
        elif tool.name == 'get_correlation':
            if len(numeric_cols) >= 2:
                result = tool.execute(df)
        elif tool.name in ('get_time_series', 'calculate_trend'):
            df_with_time = df.copy()
            df_with_time['time_col'] = pd.date_range('2020-01-01', periods=len(df))
            if tool.name == 'get_time_series':
                result = tool.execute(df_with_time, time_column='time_col')
            elif len(numeric_cols) > 0:
                result = tool.execute(
                    df_with_time, time_column='time_col', value_column=numeric_cols[0]
                )
    except (KeyError, ValueError, TypeError):
        # The tool rejected data it does not apply to; discard this example.
        assume(False)

    # No result means the tool was not applicable to this frame.
    if result is None:
        assume(False)

    # Check the type before probing for 'error': on a str result, the
    # membership test would be a substring match and could skip the example.
    assert isinstance(result, dict), f"工具 {tool.name} 返回值不是字典"

    # The tool correctly reported that the data does not apply.
    if 'error' in result:
        assume(False)

    def count_data_rows(obj, max_depth=5, count_mappings=True):
        """Return the size of the largest row-like collection inside ``obj``.

        Lists, tuples, NumPy arrays and pandas Series/DataFrames count one
        row per element. When ``count_mappings`` is true, a dict also counts
        one row per key, so ``{value: count}`` maps are caught. Containers are
        searched recursively up to ``max_depth`` levels deep.
        """
        if max_depth <= 0:
            return 0
        if isinstance(obj, (pd.DataFrame, pd.Series, np.ndarray)):
            return len(obj) if obj.ndim > 0 else 0
        if isinstance(obj, dict):
            own = len(obj) if count_mappings else 0
            children = list(obj.values())
        elif isinstance(obj, (list, tuple)):
            own = len(obj)
            children = obj
        else:
            return 0
        nested = max(
            (count_data_rows(child, max_depth - 1, count_mappings) for child in children),
            default=0,
        )
        return max(own, nested)

    max_rows_in_result = count_data_rows(result)

    # A single result must not carry more than 100 rows of data.
    assert max_rows_in_result <= 100, (
        f"工具 {tool.name} 返回了 {max_rows_in_result} 行数据,"
        f"超过了100行的限制。原始数据有 {len(df)} 行。"
    )

    # Aggregated output must be much smaller than the input: at most 60% of
    # the original row count. Time-series tools may return up to 100 points,
    # i.e. 50% of a 200-row input.
    if max_rows_in_result > 0:
        compression_ratio = max_rows_in_result / len(df)
        assert compression_ratio <= 0.6, (
            f"工具 {tool.name} 的输出压缩比 {compression_ratio:.2%} 太高,"
            f"可能返回了过多的原始数据而不是聚合结果"
        )

    # Results that carry sequence-shaped data should advertise aggregation.
    aggregation_indicators = [
        'count', 'sum', 'mean', 'median', 'std', 'min', 'max',
        'distribution', 'groups', 'correlation', 'statistics',
        'time_series', 'aggregation', 'value_counts'
    ]
    has_aggregation = any(
        indicator in str(result).lower()
        for indicator in aggregation_indicators
    )

    # Dict keys alone are excluded here: they are often plain metadata
    # (e.g. slope/direction of a trend) rather than data rows.
    if count_data_rows(result, count_mappings=False) > 0:
        assert has_aggregation, (
            f"工具 {tool.name} 的结果似乎不包含聚合信息,"
            f"可能返回了原始数据而不是聚合结果"
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Feature: true-ai-agent, Property 9: 工具选择适配性
|
|||
|
|
@given(data_profile=data_profile_strategy())
@settings(max_examples=20)
def test_tool_selection_adaptability(data_profile):
    """
    Property 9: for any data profile, the tool set chosen by the tool manager
    must match the data's characteristics:
    - time fields enable time-series tools;
    - categorical fields enable distribution tools;
    - numeric fields enable statistics tools;
    - no geographic fields means no geographic tools.

    Validates: tool dynamism acceptance .1, tool dynamism acceptance .2, FR-4.2
    """
    from src.tools.tool_manager import ToolManager

    # Build a manager over a registry holding every known tool.
    registry = ToolRegistry()
    for tool_class in ALL_TOOLS:
        registry.register(tool_class())
    manager = ToolManager(registry)

    selected_tool_names = [tool.name for tool in manager.select_tools(data_profile)]

    def selected_any(candidates):
        """True when at least one of ``candidates`` was selected."""
        return any(name in selected_tool_names for name in candidates)

    def profile_has(dtype):
        """True when some column of the profile has the given dtype."""
        return any(col.dtype == dtype for col in data_profile.columns)

    # Time fields -> at least one time-series tool.
    if profile_has('datetime'):
        assert selected_any(['get_time_series', 'calculate_trend', 'create_line_chart']), (
            f"数据包含时间字段,但没有选择时间序列工具。"
            f"选中的工具:{selected_tool_names}"
        )

    # Categorical fields -> at least one distribution / grouping tool.
    if profile_has('categorical'):
        assert selected_any(['get_column_distribution', 'get_value_counts', 'perform_groupby',
                             'create_bar_chart', 'create_pie_chart']), (
            f"数据包含分类字段,但没有选择分类分析工具。"
            f"选中的工具:{selected_tool_names}"
        )

    # Numeric fields -> at least one statistics tool.
    if profile_has('numeric'):
        assert selected_any(['calculate_statistics', 'detect_outliers', 'get_correlation',
                             'create_heatmap']), (
            f"数据包含数值字段,但没有选择统计分析工具。"
            f"选中的工具:{selected_tool_names}"
        )

    # No geographic-looking column name -> no geographic tool.
    geo_keywords = ['lat', 'lon', 'latitude', 'longitude', 'location', 'address', 'city', 'country']
    looks_geographic = any(
        keyword in col.name.lower()
        for col in data_profile.columns
        for keyword in geo_keywords
    )
    if not looks_geographic:
        assert not selected_any(['create_map_visualization']), (
            f"数据不包含地理字段,但选择了地理工具。"
            f"选中的工具:{selected_tool_names}"
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Feature: true-ai-agent, Property 11: 工具适用性判断
|
|||
|
|
@given(
    tool_class=st.sampled_from(ALL_TOOLS),
    data_profile=data_profile_strategy()
)
@settings(max_examples=20)
def test_tool_applicability_judgment(tool_class, data_profile):
    """
    Property 11: for any tool and data profile, the tool's ``is_applicable``
    method must correctly decide whether the tool fits the data (e.g.
    time-series tools fit only data containing time fields).

    Validates: FR-4.3
    """
    tool = tool_class()
    is_applicable = tool.is_applicable(data_profile)

    # The verdict must be a real bool.
    assert isinstance(is_applicable, bool), (
        f"工具 {tool.name} 的 is_applicable 方法返回了非布尔值:{type(is_applicable)}"
    )

    # Column facts used by the per-tool rules below.
    dtypes = [col.dtype for col in data_profile.columns]
    has_datetime = 'datetime' in dtypes
    numeric_count = dtypes.count('numeric')
    has_numeric = numeric_count > 0

    if tool.name == 'calculate_trend':
        # A trend needs both a time axis and a numeric value column.
        if has_datetime and has_numeric:
            assert is_applicable, (
                f"工具 {tool.name} 应该适用于包含时间字段和数值字段的数据,"
                f"但 is_applicable 返回 False"
            )
    elif tool.name == 'get_time_series':
        # A time series only needs a time axis.
        if has_datetime:
            assert is_applicable, (
                f"工具 {tool.name} 应该适用于包含时间字段的数据,"
                f"但 is_applicable 返回 False"
            )
    elif tool.name in ('calculate_statistics', 'detect_outliers'):
        # Statistics tools need at least one numeric column.
        if has_numeric:
            assert is_applicable, (
                f"工具 {tool.name} 应该适用于包含数值字段的数据,"
                f"但 is_applicable 返回 False"
            )
    elif tool.name in ('get_correlation', 'create_heatmap'):
        # Pairwise tools are applicable exactly when two or more numeric columns exist.
        if numeric_count >= 2:
            assert is_applicable, (
                f"工具 {tool.name} 应该适用于包含至少两个数值字段的数据,"
                f"但 is_applicable 返回 False"
            )
        else:
            assert not is_applicable, (
                f"工具 {tool.name} 不应该适用于数值字段少于2个的数据,"
                f"但 is_applicable 返回 True"
            )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Feature: true-ai-agent, Property 12: 工具需求识别
|
|||
|
|
@given(data_profile=data_profile_strategy())
@settings(max_examples=20)
def test_tool_requirement_identification(data_profile):
    """
    Property 12: for any analysis task and available tool set, if a tool the
    task needs is not available, the tool manager must identify the missing
    tool and record the requirement.

    Validates: tool dynamism acceptance .3, FR-4.2
    """
    from src.tools.tool_manager import ToolManager

    # An empty registry simulates an environment in which every tool is missing.
    manager = ToolManager(ToolRegistry())
    manager.clear_missing_tools()

    # select_tools is called for its side effect of recording missing tools.
    manager.select_tools(data_profile)
    missing_tools = manager.get_missing_tools()

    # (column dtype, tools that dtype calls for, failure message prefix)
    expectations = [
        (
            'datetime',
            ['get_time_series', 'calculate_trend', 'create_line_chart'],
            "数据包含时间字段,但没有识别出缺失的时间序列工具。",
        ),
        (
            'categorical',
            ['get_column_distribution', 'get_value_counts', 'perform_groupby',
             'create_bar_chart', 'create_pie_chart'],
            "数据包含分类字段,但没有识别出缺失的分类分析工具。",
        ),
        (
            'numeric',
            ['calculate_statistics', 'detect_outliers', 'get_correlation', 'create_heatmap'],
            "数据包含数值字段,但没有识别出缺失的统计分析工具。",
        ),
    ]

    for dtype, wanted_tools, message in expectations:
        if not any(col.dtype == dtype for col in data_profile.columns):
            continue
        assert any(name in missing_tools for name in wanted_tools), (
            message + f"缺失工具列表:{missing_tools}"
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 额外测试:验证所有工具都正确实现了接口
|
|||
|
|
def test_all_tools_implement_interface():
    """Check that every class in ``ALL_TOOLS`` satisfies the AnalysisTool interface.

    Each tool must be an ``AnalysisTool`` instance, expose all required
    members, and have callable ``execute`` and ``is_applicable`` methods.
    """
    required_members = ('name', 'description', 'parameters', 'execute', 'is_applicable')

    for tool in (cls() for cls in ALL_TOOLS):
        assert isinstance(tool, AnalysisTool)
        for member in required_members:
            assert hasattr(tool, member)
        assert callable(tool.execute)
        assert callable(tool.is_applicable)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 额外测试:验证工具注册表功能
|
|||
|
|
def test_tool_registry_with_all_tools():
    """Register every tool in a fresh ``ToolRegistry`` and read each one back.

    Checks three things:
    - tool names are unique, so registering them cannot collapse entries;
    - the registry lists exactly ``len(ALL_TOOLS)`` tools;
    - looking a tool up by its ``name`` returns an ``AnalysisTool`` carrying that name.
    """
    tool_names = [tool_class().name for tool_class in ALL_TOOLS]
    # Duplicate names would make the length check below fail confusingly.
    assert len(set(tool_names)) == len(tool_names), f"工具名称重复:{tool_names}"

    registry = ToolRegistry()
    for tool_class in ALL_TOOLS:
        registry.register(tool_class())

    # Every tool must be registered.
    registered_tools = registry.list_tools()
    assert len(registered_tools) == len(ALL_TOOLS)

    # Every tool must be retrievable by name.
    for expected_name in tool_names:
        retrieved_tool = registry.get_tool(expected_name)
        # Fail with a clear message rather than an AttributeError on None.
        assert retrieved_tool is not None, f"注册表中找不到工具 {expected_name}"
        assert retrieved_tool.name == expected_name
        assert isinstance(retrieved_tool, AnalysisTool)
|