Files
vibe_data_ana/tests/test_tools_properties.py

621 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""工具系统的基于属性的测试。"""
import pytest
import pandas as pd
import numpy as np
from hypothesis import given, strategies as st, settings, assume
from typing import Dict, Any
from src.tools.base import AnalysisTool, ToolRegistry
from src.tools.query_tools import (
GetColumnDistributionTool,
GetValueCountsTool,
GetTimeSeriesTool,
GetCorrelationTool
)
from src.tools.stats_tools import (
CalculateStatisticsTool,
PerformGroupbyTool,
DetectOutliersTool,
CalculateTrendTool
)
from src.models import DataProfile, ColumnInfo
# Hypothesis 策略用于生成测试数据
@st.composite
def column_info_strategy(draw):
"""生成随机的 ColumnInfo 实例。"""
dtype = draw(st.sampled_from(['numeric', 'categorical', 'datetime', 'text']))
return ColumnInfo(
name=draw(st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=('Lu', 'Ll')))),
dtype=dtype,
missing_rate=draw(st.floats(min_value=0.0, max_value=1.0)),
unique_count=draw(st.integers(min_value=1, max_value=1000)),
sample_values=draw(st.lists(st.integers(), min_size=1, max_size=5)),
statistics={'mean': draw(st.floats(allow_nan=False, allow_infinity=False))} if dtype == 'numeric' else {}
)
@st.composite
def data_profile_strategy(draw):
"""生成随机的 DataProfile 实例。"""
columns = draw(st.lists(column_info_strategy(), min_size=1, max_size=10))
return DataProfile(
file_path=draw(st.text(min_size=1, max_size=50)),
row_count=draw(st.integers(min_value=1, max_value=10000)),
column_count=len(columns),
columns=columns,
inferred_type=draw(st.sampled_from(['ticket', 'sales', 'user', 'unknown'])),
key_fields={},
quality_score=draw(st.floats(min_value=0.0, max_value=100.0)),
summary=draw(st.text(max_size=100))
)
@st.composite
def dataframe_strategy(draw, min_rows=10, max_rows=100, min_cols=2, max_cols=10):
"""生成随机的 DataFrame 实例。"""
n_rows = draw(st.integers(min_value=min_rows, max_value=max_rows))
n_cols = draw(st.integers(min_value=min_cols, max_value=max_cols))
data = {}
for i in range(n_cols):
col_type = draw(st.sampled_from(['int', 'float', 'str']))
col_name = f'col_{i}'
if col_type == 'int':
data[col_name] = draw(st.lists(
st.integers(min_value=-1000, max_value=1000),
min_size=n_rows,
max_size=n_rows
))
elif col_type == 'float':
data[col_name] = draw(st.lists(
st.floats(min_value=-1000.0, max_value=1000.0, allow_nan=False, allow_infinity=False),
min_size=n_rows,
max_size=n_rows
))
else: # str
data[col_name] = draw(st.lists(
st.text(min_size=1, max_size=10, alphabet=st.characters(whitelist_categories=('Lu', 'Ll'))),
min_size=n_rows,
max_size=n_rows
))
return pd.DataFrame(data)
# 获取所有工具类用于测试
ALL_TOOLS = [
GetColumnDistributionTool,
GetValueCountsTool,
GetTimeSeriesTool,
GetCorrelationTool,
CalculateStatisticsTool,
PerformGroupbyTool,
DetectOutliersTool,
CalculateTrendTool
]
# Feature: true-ai-agent, Property 10: 工具接口一致性
@given(tool_class=st.sampled_from(ALL_TOOLS))
@settings(max_examples=20)
def test_tool_interface_consistency(tool_class):
"""
属性 10对于任何工具它应该实现标准接口name, description, parameters,
execute, is_applicable并且 execute 方法应该接受 DataFrame 和参数,
返回字典格式的聚合结果。
验证需求FR-4.1
"""
# 创建工具实例
tool = tool_class()
# 验证:工具应该是 AnalysisTool 的子类
assert isinstance(tool, AnalysisTool), f"{tool_class.__name__} 不是 AnalysisTool 的子类"
# 验证:工具应该有 name 属性,且返回字符串
assert hasattr(tool, 'name'), f"{tool_class.__name__} 缺少 name 属性"
assert isinstance(tool.name, str), f"{tool_class.__name__}.name 不是字符串"
assert len(tool.name) > 0, f"{tool_class.__name__}.name 是空字符串"
# 验证:工具应该有 description 属性,且返回字符串
assert hasattr(tool, 'description'), f"{tool_class.__name__} 缺少 description 属性"
assert isinstance(tool.description, str), f"{tool_class.__name__}.description 不是字符串"
assert len(tool.description) > 0, f"{tool_class.__name__}.description 是空字符串"
# 验证:工具应该有 parameters 属性,且返回字典
assert hasattr(tool, 'parameters'), f"{tool_class.__name__} 缺少 parameters 属性"
assert isinstance(tool.parameters, dict), f"{tool_class.__name__}.parameters 不是字典"
# 验证parameters 应该符合 JSON Schema 格式
params = tool.parameters
assert 'type' in params, f"{tool_class.__name__}.parameters 缺少 'type' 字段"
assert params['type'] == 'object', f"{tool_class.__name__}.parameters.type 不是 'object'"
# 验证:工具应该有 execute 方法
assert hasattr(tool, 'execute'), f"{tool_class.__name__} 缺少 execute 方法"
assert callable(tool.execute), f"{tool_class.__name__}.execute 不可调用"
# 验证:工具应该有 is_applicable 方法
assert hasattr(tool, 'is_applicable'), f"{tool_class.__name__} 缺少 is_applicable 方法"
assert callable(tool.is_applicable), f"{tool_class.__name__}.is_applicable 不可调用"
# 验证execute 方法应该接受 DataFrame 和关键字参数
# 创建一个简单的测试 DataFrame
test_df = pd.DataFrame({
'col_0': [1, 2, 3, 4, 5],
'col_1': ['a', 'b', 'c', 'd', 'e']
})
# 尝试调用 execute可能会失败但不应该因为签名问题
try:
# 使用空参数调用(可能会因为缺少必需参数而失败,这是预期的)
result = tool.execute(test_df)
except (KeyError, ValueError, TypeError) as e:
# 这些异常是可以接受的(参数验证失败)
pass
# 验证execute 方法应该返回字典
# 我们需要提供有效的参数来测试返回类型
# 根据工具类型提供适当的参数
if tool.name == 'get_column_distribution':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'get_value_counts':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'calculate_statistics':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'perform_groupby':
result = tool.execute(test_df, group_by='col_1')
elif tool.name == 'detect_outliers':
result = tool.execute(test_df, column='col_0')
elif tool.name == 'get_correlation':
test_df_numeric = pd.DataFrame({
'col_0': [1, 2, 3, 4, 5],
'col_1': [2, 4, 6, 8, 10]
})
result = tool.execute(test_df_numeric)
elif tool.name == 'get_time_series':
test_df_time = pd.DataFrame({
'time': pd.date_range('2020-01-01', periods=5),
'value': [1, 2, 3, 4, 5]
})
result = tool.execute(test_df_time, time_column='time')
elif tool.name == 'calculate_trend':
test_df_trend = pd.DataFrame({
'time': pd.date_range('2020-01-01', periods=5),
'value': [1, 2, 3, 4, 5]
})
result = tool.execute(test_df_trend, time_column='time', value_column='value')
else:
# 未知工具,跳过返回类型验证
return
# 验证:返回值应该是字典
assert isinstance(result, dict), f"{tool_class.__name__}.execute 返回值不是字典,而是 {type(result)}"
# Feature: true-ai-agent, Property 19: 工具输出过滤
@given(
tool_class=st.sampled_from(ALL_TOOLS),
df=dataframe_strategy(min_rows=200, max_rows=500)
)
@settings(max_examples=20, deadline=None)
def test_tool_output_filtering(tool_class, df):
"""
属性 19对于任何工具的执行结果返回的数据应该是聚合后的如统计值、
分组计数、图表数据单次返回的数据行数不应超过100行并且不应包含
完整的原始数据表。
验证需求约束条件5.3
"""
# 创建工具实例
tool = tool_class()
# 确保 DataFrame 有足够的行数来测试过滤
assume(len(df) >= 200)
# 根据工具类型准备适当的参数和数据
result = None
try:
if tool.name == 'get_column_distribution':
# 使用第一列
col_name = df.columns[0]
result = tool.execute(df, column=col_name, top_n=10)
elif tool.name == 'get_value_counts':
col_name = df.columns[0]
result = tool.execute(df, column=col_name)
elif tool.name == 'calculate_statistics':
# 找到数值列
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
result = tool.execute(df, column=numeric_cols[0])
elif tool.name == 'perform_groupby':
# 使用第一列作为分组列
result = tool.execute(df, group_by=df.columns[0])
elif tool.name == 'detect_outliers':
# 找到数值列
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
result = tool.execute(df, column=numeric_cols[0])
elif tool.name == 'get_correlation':
# 需要至少两个数值列
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) >= 2:
result = tool.execute(df)
elif tool.name == 'get_time_series':
# 创建带时间列的 DataFrame
df_with_time = df.copy()
df_with_time['time_col'] = pd.date_range('2020-01-01', periods=len(df))
result = tool.execute(df_with_time, time_column='time_col')
elif tool.name == 'calculate_trend':
# 创建带时间列和数值列的 DataFrame
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 0:
df_with_time = df.copy()
df_with_time['time_col'] = pd.date_range('2020-01-01', periods=len(df))
result = tool.execute(df_with_time, time_column='time_col', value_column=numeric_cols[0])
except (KeyError, ValueError, TypeError) as e:
# 工具可能因为数据不适用而失败,这是可以接受的
# 跳过此测试用例
assume(False)
# 如果没有结果(工具不适用),跳过验证
if result is None:
assume(False)
# 如果结果包含错误,跳过验证(工具正确地拒绝了不适用的数据)
if 'error' in result:
assume(False)
# 验证:结果应该是字典
assert isinstance(result, dict), f"工具 {tool.name} 返回值不是字典"
# 验证:结果不应包含完整的原始数据
# 检查结果中的所有值
def count_data_rows(obj, max_depth=5):
"""递归计数结果中的数据行数"""
if max_depth <= 0:
return 0
if isinstance(obj, list):
# 如果是列表,检查长度
return len(obj)
elif isinstance(obj, dict):
# 如果是字典,递归检查所有值
max_count = 0
for value in obj.values():
count = count_data_rows(value, max_depth - 1)
max_count = max(max_count, count)
return max_count
else:
return 0
# 计算结果中的最大数据行数
max_rows_in_result = count_data_rows(result)
# 验证单次返回的数据行数不应超过100行
assert max_rows_in_result <= 100, (
f"工具 {tool.name} 返回了 {max_rows_in_result} 行数据,"
f"超过了100行的限制。原始数据有 {len(df)} 行。"
)
# 验证:结果应该是聚合数据,而不是原始数据
# 检查结果的大小是否明显小于原始数据
# 聚合结果的行数应该远小于原始数据行数
if max_rows_in_result > 0:
compression_ratio = max_rows_in_result / len(df)
# 聚合结果应该至少压缩到原始数据的60%以下
# 对于200+行的数据,聚合结果应该显著更小)
# 注意时间序列工具可能返回最多100个数据点所以对于200行数据压缩比是50%
assert compression_ratio <= 0.6, (
f"工具 {tool.name} 的输出压缩比 {compression_ratio:.2%} 太高,"
f"可能返回了过多的原始数据而不是聚合结果"
)
# 验证:结果应该包含聚合信息而不是原始行数据
# 检查结果中是否包含典型的聚合字段
aggregation_indicators = [
'count', 'sum', 'mean', 'median', 'std', 'min', 'max',
'distribution', 'groups', 'correlation', 'statistics',
'time_series', 'aggregation', 'value_counts'
]
has_aggregation = any(
indicator in str(result).lower()
for indicator in aggregation_indicators
)
# 如果结果有数据,应该包含聚合指标
if max_rows_in_result > 0:
assert has_aggregation, (
f"工具 {tool.name} 的结果似乎不包含聚合信息,"
f"可能返回了原始数据而不是聚合结果"
)
# Feature: true-ai-agent, Property 9: 工具选择适配性
@given(data_profile=data_profile_strategy())
@settings(max_examples=20)
def test_tool_selection_adaptability(data_profile):
"""
属性 9对于任何数据画像工具管理器选择的工具集应该与数据特征匹配
包含时间字段时启用时间序列工具,包含分类字段时启用分布分析工具,
包含数值字段时启用统计工具,不包含地理字段时不启用地理工具。
验证需求:工具动态性验收.1, 工具动态性验收.2, FR-4.2
"""
from src.tools.tool_manager import ToolManager
# 创建工具管理器并注册所有工具
registry = ToolRegistry()
for tool_class in ALL_TOOLS:
registry.register(tool_class())
manager = ToolManager(registry)
# 选择工具
selected_tools = manager.select_tools(data_profile)
selected_tool_names = [tool.name for tool in selected_tools]
# 验证:如果包含时间字段,应该启用时间序列工具
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
time_series_tools = ['get_time_series', 'calculate_trend', 'create_line_chart']
if has_datetime:
# 至少应该有一个时间序列工具被选中
has_time_tool = any(tool_name in selected_tool_names for tool_name in time_series_tools)
assert has_time_tool, (
f"数据包含时间字段,但没有选择时间序列工具。"
f"选中的工具:{selected_tool_names}"
)
# 验证:如果包含分类字段,应该启用分布分析工具
has_categorical = any(col.dtype == 'categorical' for col in data_profile.columns)
categorical_tools = ['get_column_distribution', 'get_value_counts', 'perform_groupby',
'create_bar_chart', 'create_pie_chart']
if has_categorical:
# 至少应该有一个分类工具被选中
has_cat_tool = any(tool_name in selected_tool_names for tool_name in categorical_tools)
assert has_cat_tool, (
f"数据包含分类字段,但没有选择分类分析工具。"
f"选中的工具:{selected_tool_names}"
)
# 验证:如果包含数值字段,应该启用统计工具
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
numeric_tools = ['calculate_statistics', 'detect_outliers', 'get_correlation', 'create_heatmap']
if has_numeric:
# 至少应该有一个数值工具被选中
has_num_tool = any(tool_name in selected_tool_names for tool_name in numeric_tools)
assert has_num_tool, (
f"数据包含数值字段,但没有选择统计分析工具。"
f"选中的工具:{selected_tool_names}"
)
# 验证:如果不包含地理字段,不应该启用地理工具
geo_keywords = ['lat', 'lon', 'latitude', 'longitude', 'location', 'address', 'city', 'country']
has_geo = any(
any(keyword in col.name.lower() for keyword in geo_keywords)
for col in data_profile.columns
)
geo_tools = ['create_map_visualization']
if not has_geo:
# 不应该有地理工具被选中
has_geo_tool = any(tool_name in selected_tool_names for tool_name in geo_tools)
assert not has_geo_tool, (
f"数据不包含地理字段,但选择了地理工具。"
f"选中的工具:{selected_tool_names}"
)
# Feature: true-ai-agent, Property 11: 工具适用性判断
@given(
tool_class=st.sampled_from(ALL_TOOLS),
data_profile=data_profile_strategy()
)
@settings(max_examples=20)
def test_tool_applicability_judgment(tool_class, data_profile):
"""
属性 11对于任何工具和数据画像工具的 is_applicable 方法应该正确判断
该工具是否适用于当前数据(例如时间序列工具只适用于包含时间字段的数据)。
验证需求FR-4.3
"""
# 创建工具实例
tool = tool_class()
# 调用 is_applicable 方法
is_applicable = tool.is_applicable(data_profile)
# 验证:返回值应该是布尔值
assert isinstance(is_applicable, bool), (
f"工具 {tool.name} 的 is_applicable 方法返回了非布尔值:{type(is_applicable)}"
)
# 验证:适用性判断应该与数据特征一致
# 根据工具类型检查适用性逻辑
if tool.name in ['get_time_series', 'calculate_trend']:
# 时间序列工具应该只适用于包含时间字段的数据
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
# calculate_trend 还需要数值列
if tool.name == 'calculate_trend':
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
if has_datetime and has_numeric:
# 如果有时间字段和数值字段,工具应该适用
assert is_applicable, (
f"工具 {tool.name} 应该适用于包含时间字段和数值字段的数据,"
f"但 is_applicable 返回 False"
)
else:
# get_time_series 只需要时间字段
if has_datetime:
# 如果有时间字段,工具应该适用
assert is_applicable, (
f"工具 {tool.name} 应该适用于包含时间字段的数据,"
f"但 is_applicable 返回 False"
)
elif tool.name in ['calculate_statistics', 'detect_outliers']:
# 统计工具应该只适用于包含数值字段的数据
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
if has_numeric:
# 如果有数值字段,工具应该适用
assert is_applicable, (
f"工具 {tool.name} 应该适用于包含数值字段的数据,"
f"但 is_applicable 返回 False"
)
elif tool.name == 'get_correlation':
# 相关性工具需要至少两个数值字段
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
has_enough_numeric = len(numeric_cols) >= 2
if has_enough_numeric:
# 如果有足够的数值字段,工具应该适用
assert is_applicable, (
f"工具 {tool.name} 应该适用于包含至少两个数值字段的数据,"
f"但 is_applicable 返回 False"
)
else:
# 如果数值字段不足,工具不应该适用
assert not is_applicable, (
f"工具 {tool.name} 不应该适用于数值字段少于2个的数据"
f"但 is_applicable 返回 True"
)
elif tool.name == 'create_heatmap':
# 热力图工具需要至少两个数值字段
numeric_cols = [col for col in data_profile.columns if col.dtype == 'numeric']
has_enough_numeric = len(numeric_cols) >= 2
if has_enough_numeric:
# 如果有足够的数值字段,工具应该适用
assert is_applicable, (
f"工具 {tool.name} 应该适用于包含至少两个数值字段的数据,"
f"但 is_applicable 返回 False"
)
else:
# 如果数值字段不足,工具不应该适用
assert not is_applicable, (
f"工具 {tool.name} 不应该适用于数值字段少于2个的数据"
f"但 is_applicable 返回 True"
)
# Feature: true-ai-agent, Property 12: 工具需求识别
@given(data_profile=data_profile_strategy())
@settings(max_examples=20)
def test_tool_requirement_identification(data_profile):
"""
属性 12对于任何分析任务和可用工具集如果任务需要的工具不在可用工具集中
工具管理器应该能够识别缺失的工具并记录需求。
验证需求:工具动态性验收.3, FR-4.2
"""
from src.tools.tool_manager import ToolManager
# 创建一个空的工具注册表(模拟缺失工具的情况)
empty_registry = ToolRegistry()
manager = ToolManager(empty_registry)
# 清空缺失工具列表
manager.clear_missing_tools()
# 尝试选择工具
selected_tools = manager.select_tools(data_profile)
# 获取缺失的工具列表
missing_tools = manager.get_missing_tools()
# 验证:如果数据有特定特征,应该识别出相应的缺失工具
has_datetime = any(col.dtype == 'datetime' for col in data_profile.columns)
has_categorical = any(col.dtype == 'categorical' for col in data_profile.columns)
has_numeric = any(col.dtype == 'numeric' for col in data_profile.columns)
# 如果有时间字段,应该识别出缺失的时间序列工具
if has_datetime:
time_tools = ['get_time_series', 'calculate_trend', 'create_line_chart']
has_missing_time_tool = any(tool in missing_tools for tool in time_tools)
assert has_missing_time_tool, (
f"数据包含时间字段,但没有识别出缺失的时间序列工具。"
f"缺失工具列表:{missing_tools}"
)
# 如果有分类字段,应该识别出缺失的分类工具
if has_categorical:
cat_tools = ['get_column_distribution', 'get_value_counts', 'perform_groupby',
'create_bar_chart', 'create_pie_chart']
has_missing_cat_tool = any(tool in missing_tools for tool in cat_tools)
assert has_missing_cat_tool, (
f"数据包含分类字段,但没有识别出缺失的分类分析工具。"
f"缺失工具列表:{missing_tools}"
)
# 如果有数值字段,应该识别出缺失的统计工具
if has_numeric:
num_tools = ['calculate_statistics', 'detect_outliers', 'get_correlation', 'create_heatmap']
has_missing_num_tool = any(tool in missing_tools for tool in num_tools)
assert has_missing_num_tool, (
f"数据包含数值字段,但没有识别出缺失的统计分析工具。"
f"缺失工具列表:{missing_tools}"
)
# 额外测试:验证所有工具都正确实现了接口
def test_all_tools_implement_interface():
"""验证所有工具类都正确实现了 AnalysisTool 接口。"""
for tool_class in ALL_TOOLS:
tool = tool_class()
# 检查工具是 AnalysisTool 的实例
assert isinstance(tool, AnalysisTool)
# 检查所有必需的方法都存在
assert hasattr(tool, 'name')
assert hasattr(tool, 'description')
assert hasattr(tool, 'parameters')
assert hasattr(tool, 'execute')
assert hasattr(tool, 'is_applicable')
# 检查方法是可调用的
assert callable(tool.execute)
assert callable(tool.is_applicable)
# 额外测试:验证工具注册表功能
def test_tool_registry_with_all_tools():
"""测试 ToolRegistry 与所有工具的正确工作。"""
registry = ToolRegistry()
# 注册所有工具
for tool_class in ALL_TOOLS:
tool = tool_class()
registry.register(tool)
# 验证所有工具都已注册
registered_tools = registry.list_tools()
assert len(registered_tools) == len(ALL_TOOLS)
# 验证我们可以检索每个工具
for tool_class in ALL_TOOLS:
tool = tool_class()
retrieved_tool = registry.get_tool(tool.name)
assert retrieved_tool.name == tool.name
assert isinstance(retrieved_tool, AnalysisTool)