"""Unit tests for the configuration management module (``src.config``)."""

import json
import os
from pathlib import Path
from unittest.mock import patch

import pytest

from src.config import (
    LLMConfig,
    PerformanceConfig,
    OutputConfig,
    Config,
    get_config,
    set_config,
    load_config_from_env,
    load_config_from_file,
)


def _write_json(path, data):
    """Serialize *data* as JSON into *path* (a ``pathlib.Path``).

    ``json.dumps`` escapes non-ASCII by default, so the bytes written do not
    depend on the platform's locale encoding.
    """
    path.write_text(json.dumps(data), encoding="utf-8")


class TestLLMConfig:
    """Tests for ``LLMConfig``."""

    def test_default_config(self):
        """Only ``api_key`` is given; every other field uses its default."""
        config = LLMConfig(api_key="test_key")
        assert config.provider == "openai"
        assert config.api_key == "test_key"
        assert config.base_url == "https://api.openai.com/v1"
        assert config.model == "gpt-4"
        assert config.timeout == 120
        assert config.max_retries == 3
        assert config.temperature == 0.7
        assert config.max_tokens is None

    def test_custom_config(self):
        """Every field can be overridden through the constructor."""
        config = LLMConfig(
            provider="gemini",
            api_key="gemini_key",
            base_url="https://gemini.api",
            model="gemini-pro",
            timeout=60,
            max_retries=5,
            temperature=0.5,
            max_tokens=1000,
        )
        assert config.provider == "gemini"
        assert config.api_key == "gemini_key"
        assert config.base_url == "https://gemini.api"
        assert config.model == "gemini-pro"
        assert config.timeout == 60
        assert config.max_retries == 5
        assert config.temperature == 0.5
        assert config.max_tokens == 1000

    def test_empty_api_key(self):
        """An empty API key is rejected."""
        with pytest.raises(ValueError, match="API key 不能为空"):
            LLMConfig(api_key="")

    def test_invalid_provider(self):
        """An unknown provider name is rejected."""
        with pytest.raises(ValueError, match="不支持的 LLM provider"):
            LLMConfig(api_key="test", provider="invalid")

    def test_invalid_timeout(self):
        """A zero timeout is rejected."""
        with pytest.raises(ValueError, match="timeout 必须大于 0"):
            LLMConfig(api_key="test", timeout=0)

    def test_invalid_max_retries(self):
        """A negative retry count is rejected."""
        with pytest.raises(ValueError, match="max_retries 不能为负数"):
            LLMConfig(api_key="test", max_retries=-1)


class TestPerformanceConfig:
    """Tests for ``PerformanceConfig``."""

    def test_default_config(self):
        """All fields fall back to their defaults."""
        config = PerformanceConfig()
        assert config.agent_max_rounds == 20
        assert config.agent_timeout == 300
        assert config.tool_max_query_rows == 10000
        assert config.tool_execution_timeout == 60
        assert config.data_max_rows == 1000000
        assert config.data_sample_threshold == 1000000
        assert config.max_concurrent_tasks == 1

    def test_custom_config(self):
        """Every field can be overridden through the constructor."""
        config = PerformanceConfig(
            agent_max_rounds=10,
            agent_timeout=600,
            tool_max_query_rows=5000,
            tool_execution_timeout=30,
            data_max_rows=500000,
            data_sample_threshold=500000,
            max_concurrent_tasks=2,
        )
        assert config.agent_max_rounds == 10
        assert config.agent_timeout == 600
        assert config.tool_max_query_rows == 5000
        assert config.tool_execution_timeout == 30
        assert config.data_max_rows == 500000
        assert config.data_sample_threshold == 500000
        assert config.max_concurrent_tasks == 2

    def test_invalid_agent_max_rounds(self):
        """A zero round limit is rejected."""
        with pytest.raises(ValueError, match="agent_max_rounds 必须大于 0"):
            PerformanceConfig(agent_max_rounds=0)

    def test_invalid_tool_max_query_rows(self):
        """A negative row limit is rejected."""
        with pytest.raises(ValueError, match="tool_max_query_rows 必须大于 0"):
            PerformanceConfig(tool_max_query_rows=-1)


class TestOutputConfig:
    """Tests for ``OutputConfig``."""

    def test_default_config(self):
        """All fields fall back to their defaults."""
        config = OutputConfig()
        assert config.output_dir == "output"
        assert config.log_dir == "output"
        assert config.chart_dir == str(Path("output") / "charts")
        assert config.report_filename == "analysis_report.md"
        assert config.log_level == "INFO"
        assert config.log_to_file is True
        assert config.log_to_console is True

    def test_custom_config(self):
        """Every field can be overridden through the constructor."""
        config = OutputConfig(
            output_dir="results",
            log_dir="logs",
            chart_dir="charts",
            report_filename="report.md",
            log_level="DEBUG",
            log_to_file=False,
            log_to_console=True,
        )
        assert config.output_dir == "results"
        assert config.log_dir == "logs"
        assert config.chart_dir == "charts"
        assert config.report_filename == "report.md"
        assert config.log_level == "DEBUG"
        assert config.log_to_file is False
        assert config.log_to_console is True

    def test_invalid_log_level(self):
        """An unknown log level is rejected."""
        with pytest.raises(ValueError, match="不支持的 log_level"):
            OutputConfig(log_level="INVALID")

    def test_get_paths(self):
        """The path accessors return ``Path`` objects built from the fields."""
        config = OutputConfig(
            output_dir="results",
            log_dir="logs",
            chart_dir="charts",
        )
        assert config.get_output_path() == Path("results")
        assert config.get_log_path() == Path("logs")
        assert config.get_chart_path() == Path("charts")
        assert config.get_report_path() == Path("results/analysis_report.md")


class TestConfig:
    """Tests for the top-level ``Config`` aggregate."""

    def test_default_config(self):
        """Sub-configs other than ``llm`` are filled with their defaults."""
        config = Config(llm=LLMConfig(api_key="test_key"))
        assert config.llm.api_key == "test_key"
        assert config.performance.agent_max_rounds == 20
        assert config.output.output_dir == "output"
        assert config.code_repo_enable_reuse is True

    def test_from_env(self):
        """``Config.from_env`` reads the OpenAI and tuning variables."""
        env_vars = {
            "LLM_PROVIDER": "openai",
            "OPENAI_API_KEY": "env_test_key",
            "OPENAI_BASE_URL": "https://test.api",
            "OPENAI_MODEL": "gpt-3.5-turbo",
            "AGENT_MAX_ROUNDS": "15",
            "AGENT_OUTPUT_DIR": "test_output",
            "TOOL_MAX_QUERY_ROWS": "5000",
            "CODE_REPO_ENABLE_REUSE": "false",
        }
        with patch.dict(os.environ, env_vars, clear=True):
            config = Config.from_env()
        assert config.llm.provider == "openai"
        assert config.llm.api_key == "env_test_key"
        assert config.llm.base_url == "https://test.api"
        assert config.llm.model == "gpt-3.5-turbo"
        assert config.performance.agent_max_rounds == 15
        assert config.performance.tool_max_query_rows == 5000
        assert config.output.output_dir == "test_output"
        assert config.code_repo_enable_reuse is False

    def test_from_env_gemini(self):
        """``Config.from_env`` reads the GEMINI_* variables for gemini."""
        env_vars = {
            "LLM_PROVIDER": "gemini",
            "GEMINI_API_KEY": "gemini_key",
            "GEMINI_BASE_URL": "https://gemini.api",
            "GEMINI_MODEL": "gemini-pro",
        }
        with patch.dict(os.environ, env_vars, clear=True):
            config = Config.from_env()
        assert config.llm.provider == "gemini"
        assert config.llm.api_key == "gemini_key"
        assert config.llm.base_url == "https://gemini.api"
        assert config.llm.model == "gemini-pro"

    def test_from_dict(self):
        """``Config.from_dict`` maps nested dicts onto the sub-configs."""
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "dict_test_key",
                "base_url": "https://dict.api",
                "model": "gpt-4",
                "timeout": 90,
                "max_retries": 2,
                "temperature": 0.5,
                "max_tokens": 2000,
            },
            "performance": {
                "agent_max_rounds": 25,
                "tool_max_query_rows": 8000,
            },
            "output": {
                "output_dir": "dict_output",
                "log_level": "DEBUG",
            },
            "code_repo_enable_reuse": False,
        }
        config = Config.from_dict(config_dict)
        assert config.llm.api_key == "dict_test_key"
        assert config.llm.base_url == "https://dict.api"
        assert config.llm.timeout == 90
        assert config.llm.max_retries == 2
        assert config.llm.temperature == 0.5
        assert config.llm.max_tokens == 2000
        assert config.performance.agent_max_rounds == 25
        assert config.performance.tool_max_query_rows == 8000
        assert config.output.output_dir == "dict_output"
        assert config.output.log_level == "DEBUG"
        assert config.code_repo_enable_reuse is False

    def test_from_file(self, tmp_path):
        """``Config.from_file`` loads a JSON file."""
        config_file = tmp_path / "test_config.json"
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "file_test_key",
                "model": "gpt-4",
            },
            "performance": {
                "agent_max_rounds": 30,
            },
        }
        _write_json(config_file, config_dict)

        config = Config.from_file(str(config_file))
        assert config.llm.api_key == "file_test_key"
        assert config.llm.model == "gpt-4"
        assert config.performance.agent_max_rounds == 30

    def test_from_file_not_found(self, tmp_path):
        """Loading a missing file raises ``FileNotFoundError``."""
        # Use a path inside tmp_path so the test cannot be defeated by a
        # stray "nonexistent.json" in whatever directory pytest runs from.
        missing = tmp_path / "nonexistent.json"
        with pytest.raises(FileNotFoundError):
            Config.from_file(str(missing))

    def test_to_dict(self):
        """``to_dict`` serializes sub-configs and masks the API key."""
        config = Config(
            llm=LLMConfig(api_key="test_key", model="gpt-4"),
            performance=PerformanceConfig(agent_max_rounds=15),
            output=OutputConfig(output_dir="test_output"),
        )
        config_dict = config.to_dict()
        assert config_dict["llm"]["api_key"] == "***"  # the API key must be masked
        assert config_dict["llm"]["model"] == "gpt-4"
        assert config_dict["performance"]["agent_max_rounds"] == 15
        assert config_dict["output"]["output_dir"] == "test_output"

    def test_save_to_file(self, tmp_path):
        """``save_to_file`` writes JSON with the API key masked."""
        config_file = tmp_path / "saved_config.json"
        config = Config(
            llm=LLMConfig(api_key="test_key"),
            performance=PerformanceConfig(agent_max_rounds=15),
        )
        config.save_to_file(str(config_file))

        assert config_file.exists()
        # Read with the default encoding, mirroring whatever save_to_file uses.
        with open(config_file, "r") as f:
            saved_dict = json.load(f)
        assert saved_dict["llm"]["api_key"] == "***"
        assert saved_dict["performance"]["agent_max_rounds"] == 15

    def test_validate_success(self):
        """A config with an API key validates."""
        config = Config(llm=LLMConfig(api_key="test_key"))
        assert config.validate() is True

    def test_validate_missing_api_key(self):
        """``validate`` reports False once the API key is cleared."""
        config = Config(llm=LLMConfig(api_key="test_key"))
        # Clear the key after construction; the constructor itself would reject "".
        config.llm.api_key = ""
        assert config.validate() is False


class TestGlobalConfig:
    """Tests for the module-level global config accessors."""

    @pytest.fixture(autouse=True)
    def reset_global_config(self):
        """Clear the global config before and after every test in this class.

        These tests install configs via ``set_config`` and
        ``load_config_from_*``; without the teardown reset that state would
        leak into any test that runs later in the same session.
        """
        set_config(None)
        yield
        set_config(None)

    def test_get_config(self):
        """With no global config set, ``get_config`` builds one from the env."""
        env_vars = {"OPENAI_API_KEY": "global_test_key"}
        with patch.dict(os.environ, env_vars, clear=True):
            config = get_config()
        assert config is not None
        assert config.llm.api_key == "global_test_key"

    def test_set_config(self):
        """``get_config`` returns the config installed by ``set_config``."""
        custom_config = Config(llm=LLMConfig(api_key="custom_key"))
        set_config(custom_config)

        config = get_config()
        assert config.llm.api_key == "custom_key"

    def test_load_config_from_env(self):
        """``load_config_from_env`` returns and installs an env-based config."""
        env_vars = {
            "OPENAI_API_KEY": "env_global_key",
            "AGENT_MAX_ROUNDS": "25",
        }
        with patch.dict(os.environ, env_vars, clear=True):
            config = load_config_from_env()
        assert config.llm.api_key == "env_global_key"
        assert config.performance.agent_max_rounds == 25

        # Check the global only after the env patch is undone: otherwise a
        # lazy env fallback inside get_config() could re-read the patched
        # variables and hide a failure to update the global.
        global_config = get_config()
        assert global_config.llm.api_key == "env_global_key"

    def test_load_config_from_file(self, tmp_path):
        """``load_config_from_file`` returns and installs a file-based config."""
        config_file = tmp_path / "global_config.json"
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "file_global_key",
                "model": "gpt-4",
            },
        }
        _write_json(config_file, config_dict)

        config = load_config_from_file(str(config_file))
        assert config.llm.api_key == "file_global_key"

        # The global config must now be the file-based one.
        global_config = get_config()
        assert global_config.llm.api_key == "file_global_key"