"""Unit tests for requirement understanding engine.""" import pytest import tempfile import os from src.engines.requirement_understanding import ( understand_requirement, parse_template, check_data_requirement_match, _fallback_requirement_understanding ) from src.models.data_profile import DataProfile, ColumnInfo from src.models.requirement_spec import RequirementSpec, AnalysisObjective @pytest.fixture def sample_data_profile(): """Create a sample data profile for testing.""" return DataProfile( file_path='test.csv', row_count=1000, column_count=5, columns=[ ColumnInfo( name='created_at', dtype='datetime', missing_rate=0.0, unique_count=1000, sample_values=['2024-01-01', '2024-01-02'], statistics={} ), ColumnInfo( name='status', dtype='categorical', missing_rate=0.1, unique_count=5, sample_values=['open', 'closed', 'pending'], statistics={} ), ColumnInfo( name='type', dtype='categorical', missing_rate=0.0, unique_count=10, sample_values=['bug', 'feature'], statistics={} ), ColumnInfo( name='priority', dtype='numeric', missing_rate=0.0, unique_count=5, sample_values=[1, 2, 3, 4, 5], statistics={'mean': 3.0, 'std': 1.2} ), ColumnInfo( name='description', dtype='text', missing_rate=0.05, unique_count=950, sample_values=['Issue 1', 'Issue 2'], statistics={} ) ], inferred_type='ticket', key_fields={'time': 'created_at', 'status': 'status', 'type': 'type'}, quality_score=85.0, summary='Ticket data with 1000 rows and 5 columns' ) def test_understand_health_requirement(sample_data_profile): """Test understanding "健康度" requirement.""" user_input = "我想了解工单的健康度" # Use fallback to avoid API dependency requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None) # Verify basic structure assert isinstance(requirement, RequirementSpec) assert requirement.user_input == user_input assert len(requirement.objectives) > 0 # Verify health-related objective exists health_objectives = [obj for obj in requirement.objectives if '健康' in obj.name] assert len(health_objectives) > 0 # Verify objective has metrics health_obj = health_objectives[0] assert len(health_obj.metrics) > 0 assert health_obj.priority >= 1 and health_obj.priority <= 5 def test_understand_trend_requirement(sample_data_profile): """Test understanding trend analysis requirement.""" user_input = "分析趋势" requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None) # Verify trend objective exists trend_objectives = [obj for obj in requirement.objectives if '趋势' in obj.name] assert len(trend_objectives) > 0 # Verify metrics trend_obj = trend_objectives[0] assert len(trend_obj.metrics) > 0 def test_understand_distribution_requirement(sample_data_profile): """Test understanding distribution analysis requirement.""" user_input = "查看分布情况" requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None) # Verify distribution objective exists dist_objectives = [obj for obj in requirement.objectives if '分布' in obj.name] assert len(dist_objectives) > 0 def test_understand_generic_requirement(sample_data_profile): """Test understanding generic requirement without specific keywords.""" user_input = "帮我分析一下" requirement = _fallback_requirement_understanding(user_input, sample_data_profile, None) # Should still generate at least one objective assert len(requirement.objectives) > 0 # Should have default objective assert any('综合' in obj.name or 'analysis' in obj.name.lower() for obj in requirement.objectives) def test_parse_template_with_sections(): """Test parsing template with sections.""" template_content = """# 分析报告 ## 数据概览 这是数据概览部分 ## 趋势分析 指标: 增长率, 变化趋势 图表: 时间序列图 ## 分布分析 指标: 类别分布 图表: 柱状图, 饼图 """ with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f: f.write(template_content) template_path = f.name try: template_req = parse_template(template_path) # Verify sections assert len(template_req['sections']) >= 3 assert '分析报告' in template_req['sections'] assert '数据概览' in template_req['sections'] # Verify metrics assert len(template_req['required_metrics']) >= 2 # Verify charts assert len(template_req['required_charts']) >= 2 finally: os.unlink(template_path) def test_parse_nonexistent_template(): """Test parsing non-existent template.""" template_req = parse_template('nonexistent.md') # Should return empty structure assert template_req['sections'] == [] assert template_req['required_metrics'] == [] assert template_req['required_charts'] == [] def test_check_data_satisfies_requirement(sample_data_profile): """Test checking when data satisfies requirement.""" # Create requirement that data can satisfy requirement = RequirementSpec( user_input="分析状态分布", objectives=[ AnalysisObjective( name="状态分析", description="分析状态字段的分布", metrics=["状态分布"], priority=5 ) ] ) match_result = check_data_requirement_match(requirement, sample_data_profile) # Should be satisfied assert match_result['can_proceed'] is True assert len(match_result['satisfied_objectives']) > 0 def test_check_data_missing_fields(sample_data_profile): """Test checking when data is missing required fields.""" # Create requirement that needs fields not in data requirement = RequirementSpec( user_input="分析地理分布", objectives=[ AnalysisObjective( name="地理分析", description="分析地理位置分布", metrics=["地理分布", "区域统计"], priority=5 ) ] ) match_result = check_data_requirement_match(requirement, sample_data_profile) # Verify structure assert isinstance(match_result, dict) assert 'missing_fields' in match_result assert 'unsatisfied_objectives' in match_result def test_check_time_based_requirement(sample_data_profile): """Test checking time-based requirement.""" requirement = RequirementSpec( user_input="分析时间趋势", objectives=[ AnalysisObjective( name="时间分析", description="分析随时间的变化", metrics=["时间序列", "趋势"], priority=5 ) ] ) match_result = check_data_requirement_match(requirement, sample_data_profile) # Should be satisfied since we have datetime column assert match_result['can_proceed'] is True def test_check_status_based_requirement(sample_data_profile): """Test checking status-based requirement.""" requirement = RequirementSpec( user_input="分析状态", objectives=[ AnalysisObjective( name="状态分析", description="分析状态字段", metrics=["状态分布", "状态变化"], priority=5 ) ] ) match_result = check_data_requirement_match(requirement, sample_data_profile) # Should be satisfied since we have status column assert match_result['can_proceed'] is True assert len(match_result['satisfied_objectives']) > 0 def test_requirement_with_template(sample_data_profile): """Test requirement understanding with template.""" template_content = """# 工单分析报告 ## 状态分析 指标: 状态分布, 完成率 ## 类型分析 指标: 类型分布 """ with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding='utf-8') as f: f.write(template_content) template_path = f.name try: requirement = _fallback_requirement_understanding( "按模板分析", sample_data_profile, template_path ) # Verify template is included assert requirement.template_path == template_path assert requirement.template_requirements is not None # Verify template requirements structure assert 'sections' in requirement.template_requirements assert 'required_metrics' in requirement.template_requirements finally: os.unlink(template_path) def test_multiple_objectives_priority(): """Test that multiple objectives have proper priorities.""" data_profile = DataProfile( file_path='test.csv', row_count=100, column_count=3, columns=[ ColumnInfo(name='col1', dtype='numeric', missing_rate=0.0, unique_count=100), ColumnInfo(name='col2', dtype='categorical', missing_rate=0.0, unique_count=5), ColumnInfo(name='col3', dtype='datetime', missing_rate=0.0, unique_count=100) ], inferred_type='unknown', quality_score=90.0 ) requirement = _fallback_requirement_understanding( "完整分析,包括健康度和趋势", data_profile, None ) # Should have multiple objectives assert len(requirement.objectives) >= 2 # All priorities should be valid for obj in requirement.objectives: assert 1 <= obj.priority <= 5