Files
vibe_data_ana/tests/test_requirement_understanding.py

329 lines
10 KiB
Python
Raw Normal View History

"""Unit tests for requirement understanding engine."""
import pytest
import tempfile
import os
from src.engines.requirement_understanding import (
understand_requirement,
parse_template,
check_data_requirement_match,
_fallback_requirement_understanding
)
from src.models.data_profile import DataProfile, ColumnInfo
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
@pytest.fixture
def sample_data_profile():
    """Build the five-column, ticket-typed DataProfile shared by these tests.

    The profile contains one datetime column, two categorical columns, one
    numeric column and one text column, and maps the ``time``/``status``/
    ``type`` key fields onto the columns of the same role.
    """

    def make_column(name, dtype, missing_rate, unique_count, samples, stats=None):
        # Every column gets its own statistics dict; an empty one unless given.
        return ColumnInfo(
            name=name,
            dtype=dtype,
            missing_rate=missing_rate,
            unique_count=unique_count,
            sample_values=samples,
            statistics={} if stats is None else stats,
        )

    column_list = [
        make_column('created_at', 'datetime', 0.0, 1000, ['2024-01-01', '2024-01-02']),
        make_column('status', 'categorical', 0.1, 5, ['open', 'closed', 'pending']),
        make_column('type', 'categorical', 0.0, 10, ['bug', 'feature']),
        make_column(
            'priority', 'numeric', 0.0, 5, [1, 2, 3, 4, 5],
            stats={'mean': 3.0, 'std': 1.2},
        ),
        make_column('description', 'text', 0.05, 950, ['Issue 1', 'Issue 2']),
    ]

    return DataProfile(
        file_path='test.csv',
        row_count=1000,
        column_count=5,
        columns=column_list,
        inferred_type='ticket',
        key_fields={'time': 'created_at', 'status': 'status', 'type': 'type'},
        quality_score=85.0,
        summary='Ticket data with 1000 rows and 5 columns',
    )
def test_understand_health_requirement(sample_data_profile):
    """Check that a health-related request yields a health objective."""
    user_input = "我想了解工单的健康度"

    # The fallback path is called directly so the test needs no API access.
    spec = _fallback_requirement_understanding(user_input, sample_data_profile, None)

    # Basic shape of the returned specification.
    assert isinstance(spec, RequirementSpec)
    assert spec.user_input == user_input
    assert len(spec.objectives) > 0

    # At least one objective must carry the health keyword in its name.
    matching = [objective for objective in spec.objectives if '健康' in objective.name]
    assert len(matching) > 0

    # The first health objective must list metrics and a priority in 1..5.
    first_match = matching[0]
    assert len(first_match.metrics) > 0
    assert 1 <= first_match.priority <= 5
def test_understand_trend_requirement(sample_data_profile):
    """Check that a trend request yields a trend objective with metrics."""
    spec = _fallback_requirement_understanding("分析趋势", sample_data_profile, None)

    # At least one objective name must contain the trend keyword.
    matching = [objective for objective in spec.objectives if '趋势' in objective.name]
    assert len(matching) > 0

    # The first trend objective must list at least one metric.
    first_match = matching[0]
    assert len(first_match.metrics) > 0
def test_understand_distribution_requirement(sample_data_profile):
    """Check that a distribution request yields a distribution objective."""
    spec = _fallback_requirement_understanding("查看分布情况", sample_data_profile, None)

    # At least one objective name must contain the distribution keyword.
    matching = [objective for objective in spec.objectives if '分布' in objective.name]
    assert len(matching) > 0
def test_understand_generic_requirement(sample_data_profile):
    """Check that a request with no specific keyword still gets an objective."""

    def is_default_name(name):
        # A default objective is recognised by either keyword in its name.
        return '综合' in name or 'analysis' in name.lower()

    spec = _fallback_requirement_understanding("帮我分析一下", sample_data_profile, None)

    # Some objective must always be produced.
    assert len(spec.objectives) > 0

    # One of the objectives must be the default (comprehensive) one.
    assert any(is_default_name(objective.name) for objective in spec.objectives)
def test_parse_template_with_sections():
    """Parse a markdown template from disk and check sections, metrics and charts."""
    template_content = """# 分析报告
## 数据概览
这是数据概览部分
## 趋势分析
指标: 增长率, 变化趋势
图表: 时间序列图
## 分布分析
指标: 类别分布
图表: 柱状图, 饼图
"""
    # delete=False keeps the file on disk after the handle closes so that
    # parse_template can open it by path; it is removed in the finally below.
    handle = tempfile.NamedTemporaryFile(
        mode='w', suffix='.md', delete=False, encoding='utf-8'
    )
    with handle:
        handle.write(template_content)
        template_path = handle.name

    try:
        parsed = parse_template(template_path)

        # Section headings.
        section_names = parsed['sections']
        assert len(section_names) >= 3
        assert '分析报告' in section_names
        assert '数据概览' in section_names

        # Metrics declared in the template.
        assert len(parsed['required_metrics']) >= 2

        # Charts declared in the template.
        assert len(parsed['required_charts']) >= 2
    finally:
        os.unlink(template_path)
def test_parse_nonexistent_template():
    """Check that parsing a missing template path returns empty lists."""
    parsed = parse_template('nonexistent.md')

    # Every field of the result is expected to be an empty list.
    for field in ('sections', 'required_metrics', 'required_charts'):
        assert parsed[field] == []
def test_check_data_satisfies_requirement(sample_data_profile):
    """Check the match result for a requirement the sample data can satisfy."""
    # A single status-distribution objective; the fixture has a status column.
    status_objective = AnalysisObjective(
        name="状态分析",
        description="分析状态字段的分布",
        metrics=["状态分布"],
        priority=5,
    )
    spec = RequirementSpec(user_input="分析状态分布", objectives=[status_objective])

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # The analysis may proceed and at least one objective is satisfied.
    assert outcome['can_proceed'] is True
    assert len(outcome['satisfied_objectives']) > 0
def test_check_data_missing_fields(sample_data_profile):
    """Check the match-result structure for a requirement the data lacks fields for."""
    # A geography objective; the fixture profile has no geographic column.
    geo_objective = AnalysisObjective(
        name="地理分析",
        description="分析地理位置分布",
        metrics=["地理分布", "区域统计"],
        priority=5,
    )
    spec = RequirementSpec(user_input="分析地理分布", objectives=[geo_objective])

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # Only the structure of the result is verified here.
    assert isinstance(outcome, dict)
    for expected_key in ('missing_fields', 'unsatisfied_objectives'):
        assert expected_key in outcome
def test_check_time_based_requirement(sample_data_profile):
    """Check that a time-series requirement can proceed on the sample data."""
    time_objective = AnalysisObjective(
        name="时间分析",
        description="分析随时间的变化",
        metrics=["时间序列", "趋势"],
        priority=5,
    )
    spec = RequirementSpec(user_input="分析时间趋势", objectives=[time_objective])

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # The fixture profile includes a datetime column, so this should proceed.
    assert outcome['can_proceed'] is True
def test_check_status_based_requirement(sample_data_profile):
    """Check that a status requirement is satisfied by the sample data."""
    status_objective = AnalysisObjective(
        name="状态分析",
        description="分析状态字段",
        metrics=["状态分布", "状态变化"],
        priority=5,
    )
    spec = RequirementSpec(user_input="分析状态", objectives=[status_objective])

    outcome = check_data_requirement_match(spec, sample_data_profile)

    # The fixture profile includes a status column, so this should be satisfied.
    assert outcome['can_proceed'] is True
    assert len(outcome['satisfied_objectives']) > 0
def test_requirement_with_template(sample_data_profile):
    """Check that the fallback path records the template path and its parsed requirements."""
    template_content = """# 工单分析报告
## 状态分析
指标: 状态分布, 完成率
## 类型分析
指标: 类型分布
"""
    # delete=False keeps the file available by path after the handle closes;
    # it is removed explicitly in the finally clause below.
    handle = tempfile.NamedTemporaryFile(
        mode='w', suffix='.md', delete=False, encoding='utf-8'
    )
    with handle:
        handle.write(template_content)
        template_path = handle.name

    try:
        spec = _fallback_requirement_understanding(
            "按模板分析", sample_data_profile, template_path
        )

        # The template path is carried through and its requirements are set.
        assert spec.template_path == template_path
        assert spec.template_requirements is not None

        # The parsed template requirements expose the expected keys.
        for expected_key in ('sections', 'required_metrics'):
            assert expected_key in spec.template_requirements
    finally:
        os.unlink(template_path)
def test_multiple_objectives_priority():
    """Check that a multi-topic request yields several objectives with priorities in 1..5."""
    # (name, dtype, unique_count) for each column; none has missing values.
    column_specs = [
        ('col1', 'numeric', 100),
        ('col2', 'categorical', 5),
        ('col3', 'datetime', 100),
    ]
    profile = DataProfile(
        file_path='test.csv',
        row_count=100,
        column_count=3,
        columns=[
            ColumnInfo(name=col_name, dtype=col_dtype, missing_rate=0.0, unique_count=uniques)
            for col_name, col_dtype, uniques in column_specs
        ],
        inferred_type='unknown',
        quality_score=90.0,
    )

    spec = _fallback_requirement_understanding("完整分析,包括健康度和趋势", profile, None)

    # The request mentions two topics, so at least two objectives are expected.
    assert len(spec.objectives) >= 2

    # Every objective must carry a priority within the valid range.
    for objective in spec.objectives:
        assert 1 <= objective.priority <= 5