329 lines
10 KiB
Python
329 lines
10 KiB
Python
"""Unit tests for requirement understanding engine."""
|
|
|
|
import pytest
|
|
import tempfile
|
|
import os
|
|
|
|
from src.engines.requirement_understanding import (
|
|
understand_requirement,
|
|
parse_template,
|
|
check_data_requirement_match,
|
|
_fallback_requirement_understanding
|
|
)
|
|
from src.models.data_profile import DataProfile, ColumnInfo
|
|
from src.models.requirement_spec import RequirementSpec, AnalysisObjective
|
|
|
|
|
|
@pytest.fixture
def sample_data_profile():
    """Build the ticket-style ``DataProfile`` shared by most tests below.

    Columns: ``created_at`` (datetime), ``status`` and ``type``
    (categorical), ``priority`` (numeric) and ``description`` (text).
    ``key_fields`` maps the time/status/type roles onto those columns.
    """
    columns = [
        ColumnInfo(
            name='created_at',
            dtype='datetime',
            missing_rate=0.0,
            unique_count=1000,
            sample_values=['2024-01-01', '2024-01-02'],
            statistics={},
        ),
        ColumnInfo(
            name='status',
            dtype='categorical',
            missing_rate=0.1,
            unique_count=5,
            sample_values=['open', 'closed', 'pending'],
            statistics={},
        ),
        ColumnInfo(
            name='type',
            dtype='categorical',
            missing_rate=0.0,
            unique_count=10,
            sample_values=['bug', 'feature'],
            statistics={},
        ),
        ColumnInfo(
            name='priority',
            dtype='numeric',
            missing_rate=0.0,
            unique_count=5,
            sample_values=[1, 2, 3, 4, 5],
            statistics={'mean': 3.0, 'std': 1.2},
        ),
        ColumnInfo(
            name='description',
            dtype='text',
            missing_rate=0.05,
            unique_count=950,
            sample_values=['Issue 1', 'Issue 2'],
            statistics={},
        ),
    ]

    return DataProfile(
        file_path='test.csv',
        row_count=1000,
        # Derived from the list above (5) so the two can never drift apart.
        column_count=len(columns),
        columns=columns,
        inferred_type='ticket',
        key_fields={'time': 'created_at', 'status': 'status', 'type': 'type'},
        quality_score=85.0,
        summary='Ticket data with 1000 rows and 5 columns',
    )
|
|
|
|
|
|
def test_understand_health_requirement(sample_data_profile):
    """A request about ticket "health" (健康度) yields a health objective with metrics."""
    user_input = "我想了解工单的健康度"

    # Call the fallback path directly so the test does not depend on an API.
    requirement = _fallback_requirement_understanding(
        user_input, sample_data_profile, None
    )

    # Basic shape of the returned spec.
    assert isinstance(requirement, RequirementSpec)
    assert requirement.user_input == user_input
    assert requirement.objectives

    # The first objective whose name mentions health must exist...
    health_obj = next(
        (obj for obj in requirement.objectives if '健康' in obj.name),
        None,
    )
    assert health_obj is not None

    # ...and carry metrics plus a priority inside the 1-5 scale.
    assert health_obj.metrics
    assert 1 <= health_obj.priority <= 5
|
|
|
|
|
|
def test_understand_trend_requirement(sample_data_profile):
    """A "trend" (趋势) request yields a trend objective that has metrics."""
    requirement = _fallback_requirement_understanding(
        "分析趋势", sample_data_profile, None
    )

    trend_obj = next(
        (obj for obj in requirement.objectives if '趋势' in obj.name),
        None,
    )
    assert trend_obj is not None
    assert trend_obj.metrics
|
|
|
|
|
|
def test_understand_distribution_requirement(sample_data_profile):
    """A "distribution" (分布) request yields at least one distribution objective."""
    requirement = _fallback_requirement_understanding(
        "查看分布情况", sample_data_profile, None
    )

    objective_names = [obj.name for obj in requirement.objectives]
    assert any('分布' in name for name in objective_names)
|
|
|
|
|
|
def test_understand_generic_requirement(sample_data_profile):
    """A request with no recognised keyword still gets a catch-all objective."""
    requirement = _fallback_requirement_understanding(
        "帮我分析一下", sample_data_profile, None
    )
    objectives = requirement.objectives

    # At least one objective must be produced...
    assert objectives

    # ...and one of them should be the generic/comprehensive default.
    objective_names = [obj.name for obj in objectives]
    assert any(
        '综合' in name or 'analysis' in name.lower()
        for name in objective_names
    )
|
|
|
|
|
|
def test_parse_template_with_sections():
    """``parse_template`` extracts headings, metrics (指标) and charts (图表).

    The template is written into a throw-away directory. The directory's
    context manager deletes it even if writing or parsing raises, so no stray
    file is left behind.
    """
    template_content = (
        "# 分析报告\n"
        "\n"
        "## 数据概览\n"
        "这是数据概览部分\n"
        "\n"
        "## 趋势分析\n"
        "指标: 增长率, 变化趋势\n"
        "图表: 时间序列图\n"
        "\n"
        "## 分布分析\n"
        "指标: 类别分布\n"
        "图表: 柱状图, 饼图\n"
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        template_path = os.path.join(tmp_dir, 'template.md')
        with open(template_path, 'w', encoding='utf-8') as f:
            f.write(template_content)

        template_req = parse_template(template_path)

        # Every heading in the template (the H1 title and the H2 sections)
        # is expected among the parsed sections.
        sections = template_req['sections']
        assert len(sections) >= 3
        for heading in ('分析报告', '数据概览', '趋势分析', '分布分析'):
            assert heading in sections

        # "指标:" lines list three metrics in total; require at least two.
        assert len(template_req['required_metrics']) >= 2

        # "图表:" lines list three charts in total; require at least two.
        assert len(template_req['required_charts']) >= 2
|
|
|
|
|
|
def test_parse_nonexistent_template():
    """A missing template file parses to an empty requirements structure.

    The path is built inside a freshly created empty directory. That
    guarantees the file is absent no matter which working directory pytest
    runs from. A bare relative ``'nonexistent.md'`` would break if such a
    file happened to exist in the CWD.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        missing_path = os.path.join(tmp_dir, 'nonexistent.md')
        template_req = parse_template(missing_path)

    # Every collection in the result should be empty.
    assert template_req['sections'] == []
    assert template_req['required_metrics'] == []
    assert template_req['required_charts'] == []
|
|
|
|
|
|
def test_check_data_satisfies_requirement(sample_data_profile):
    """A status-distribution objective can be met by the fixture's ``status`` column."""
    status_objective = AnalysisObjective(
        name="状态分析",
        description="分析状态字段的分布",
        metrics=["状态分布"],
        priority=5,
    )
    requirement = RequirementSpec(
        user_input="分析状态分布",
        objectives=[status_objective],
    )

    match_result = check_data_requirement_match(requirement, sample_data_profile)

    # The data covers the request, so analysis may go ahead.
    assert match_result['can_proceed'] is True
    assert len(match_result['satisfied_objectives']) > 0
|
|
|
|
|
|
def test_check_data_missing_fields(sample_data_profile):
    """A geographic objective (no location columns in the fixture) still gets a full report.

    Only the shape of the result is checked: it must be a dict exposing both
    the missing-field and unsatisfied-objective entries.
    """
    geo_objective = AnalysisObjective(
        name="地理分析",
        description="分析地理位置分布",
        metrics=["地理分布", "区域统计"],
        priority=5,
    )
    requirement = RequirementSpec(
        user_input="分析地理分布",
        objectives=[geo_objective],
    )

    match_result = check_data_requirement_match(requirement, sample_data_profile)

    assert isinstance(match_result, dict)
    assert match_result.keys() >= {'missing_fields', 'unsatisfied_objectives'}
|
|
|
|
|
|
def test_check_time_based_requirement(sample_data_profile):
    """A time-series objective can proceed because the fixture has a datetime column."""
    time_objective = AnalysisObjective(
        name="时间分析",
        description="分析随时间的变化",
        metrics=["时间序列", "趋势"],
        priority=5,
    )
    requirement = RequirementSpec(
        user_input="分析时间趋势",
        objectives=[time_objective],
    )

    match_result = check_data_requirement_match(requirement, sample_data_profile)

    assert match_result['can_proceed'] is True
|
|
|
|
|
|
def test_check_status_based_requirement(sample_data_profile):
    """A status objective is satisfied because the fixture has a ``status`` column."""
    status_objective = AnalysisObjective(
        name="状态分析",
        description="分析状态字段",
        metrics=["状态分布", "状态变化"],
        priority=5,
    )
    requirement = RequirementSpec(
        user_input="分析状态",
        objectives=[status_objective],
    )

    match_result = check_data_requirement_match(requirement, sample_data_profile)

    assert match_result['can_proceed'] is True
    assert len(match_result['satisfied_objectives']) > 0
|
|
|
|
|
|
def test_requirement_with_template(sample_data_profile):
    """Passing a template path attaches the path and its parsed requirements to the spec.

    The template lives in a throw-away directory whose context manager
    removes it even if writing or understanding fails, so nothing leaks.
    """
    template_content = (
        "# 工单分析报告\n"
        "\n"
        "## 状态分析\n"
        "指标: 状态分布, 完成率\n"
        "\n"
        "## 类型分析\n"
        "指标: 类型分布\n"
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        template_path = os.path.join(tmp_dir, 'template.md')
        with open(template_path, 'w', encoding='utf-8') as f:
            f.write(template_content)

        requirement = _fallback_requirement_understanding(
            "按模板分析",
            sample_data_profile,
            template_path,
        )

        # The template path is recorded and its parse result is attached.
        assert requirement.template_path == template_path
        assert requirement.template_requirements is not None

        # The attached parse result exposes the template-structure keys.
        assert 'sections' in requirement.template_requirements
        assert 'required_metrics' in requirement.template_requirements
|
|
|
|
|
|
def test_multiple_objectives_priority():
    """A request naming several analyses yields several objectives, each with priority 1-5."""
    columns = [
        ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100),
        ColumnInfo(name='col2', dtype='categorical', missing_rate=0.0, unique_count=5),
        ColumnInfo(name='col3', dtype='datetime', missing_rate=0.0, unique_count=100),
    ]
    data_profile = DataProfile(
        file_path='test.csv',
        row_count=100,
        column_count=len(columns),
        columns=columns,
        inferred_type='unknown',
        quality_score=90.0,
    )

    # The input mentions both health (健康度) and trend (趋势).
    requirement = _fallback_requirement_understanding(
        "完整分析,包括健康度和趋势",
        data_profile,
        None,
    )
    objectives = requirement.objectives

    assert len(objectives) >= 2
    assert all(1 <= obj.priority <= 5 for obj in objectives)
|