"""Unit tests for the configuration management module."""

import os
import json

import pytest

from pathlib import Path
from unittest.mock import patch

from src.config import (
    LLMConfig,
    PerformanceConfig,
    OutputConfig,
    Config,
    get_config,
    set_config,
    load_config_from_env,
    load_config_from_file
)

class TestLLMConfig:
    """Tests for the LLM configuration."""

    def test_default_config(self):
        """Constructing with only an API key yields the documented defaults."""
        cfg = LLMConfig(api_key="test_key")

        assert cfg.provider == "openai"
        assert cfg.api_key == "test_key"
        assert cfg.base_url == "https://api.openai.com/v1"
        assert cfg.model == "gpt-4"
        assert cfg.timeout == 120
        assert cfg.max_retries == 3
        assert cfg.temperature == 0.7
        assert cfg.max_tokens is None

    def test_custom_config(self):
        """Every explicitly supplied constructor argument is stored verbatim."""
        overrides = {
            "provider": "gemini",
            "api_key": "gemini_key",
            "base_url": "https://gemini.api",
            "model": "gemini-pro",
            "timeout": 60,
            "max_retries": 5,
            "temperature": 0.5,
            "max_tokens": 1000,
        }
        cfg = LLMConfig(**overrides)

        # Each override should round-trip unchanged onto the instance.
        for attr, expected in overrides.items():
            assert getattr(cfg, attr) == expected

    def test_empty_api_key(self):
        """An empty API key is rejected at construction time."""
        with pytest.raises(ValueError, match="API key 不能为空"):
            LLMConfig(api_key="")

    def test_invalid_provider(self):
        """An unknown provider name is rejected."""
        with pytest.raises(ValueError, match="不支持的 LLM provider"):
            LLMConfig(api_key="test", provider="invalid")

    def test_invalid_timeout(self):
        """A non-positive timeout is rejected."""
        with pytest.raises(ValueError, match="timeout 必须大于 0"):
            LLMConfig(api_key="test", timeout=0)

    def test_invalid_max_retries(self):
        """A negative retry count is rejected."""
        with pytest.raises(ValueError, match="max_retries 不能为负数"):
            LLMConfig(api_key="test", max_retries=-1)
class TestPerformanceConfig:
    """Tests for the performance configuration."""

    def test_default_config(self):
        """A bare instance carries the documented default limits."""
        cfg = PerformanceConfig()

        assert cfg.agent_max_rounds == 20
        assert cfg.agent_timeout == 300
        assert cfg.tool_max_query_rows == 10000
        assert cfg.tool_execution_timeout == 60
        assert cfg.data_max_rows == 1000000
        assert cfg.data_sample_threshold == 1000000
        assert cfg.max_concurrent_tasks == 1

    def test_custom_config(self):
        """All limits can be overridden via keyword arguments."""
        overrides = {
            "agent_max_rounds": 10,
            "agent_timeout": 600,
            "tool_max_query_rows": 5000,
            "tool_execution_timeout": 30,
            "data_max_rows": 500000,
            "data_sample_threshold": 500000,
            "max_concurrent_tasks": 2,
        }
        cfg = PerformanceConfig(**overrides)

        # Each override should round-trip unchanged onto the instance.
        for attr, expected in overrides.items():
            assert getattr(cfg, attr) == expected

    def test_invalid_agent_max_rounds(self):
        """agent_max_rounds must be strictly positive."""
        with pytest.raises(ValueError, match="agent_max_rounds 必须大于 0"):
            PerformanceConfig(agent_max_rounds=0)

    def test_invalid_tool_max_query_rows(self):
        """tool_max_query_rows must be strictly positive."""
        with pytest.raises(ValueError, match="tool_max_query_rows 必须大于 0"):
            PerformanceConfig(tool_max_query_rows=-1)
class TestOutputConfig:
    """Tests for the output configuration."""

    def test_default_config(self):
        """A bare instance uses the 'output' tree and INFO-level logging."""
        cfg = OutputConfig()

        assert cfg.output_dir == "output"
        assert cfg.log_dir == "output"
        assert cfg.chart_dir == str(Path("output") / "charts")
        assert cfg.report_filename == "analysis_report.md"
        assert cfg.log_level == "INFO"
        assert cfg.log_to_file is True
        assert cfg.log_to_console is True

    def test_custom_config(self):
        """All directories, filenames and log switches can be overridden."""
        cfg = OutputConfig(
            output_dir="results",
            log_dir="logs",
            chart_dir="charts",
            report_filename="report.md",
            log_level="DEBUG",
            log_to_file=False,
            log_to_console=True,
        )

        assert cfg.output_dir == "results"
        assert cfg.log_dir == "logs"
        assert cfg.chart_dir == "charts"
        assert cfg.report_filename == "report.md"
        assert cfg.log_level == "DEBUG"
        assert cfg.log_to_file is False
        assert cfg.log_to_console is True

    def test_invalid_log_level(self):
        """An unrecognized log level is rejected."""
        with pytest.raises(ValueError, match="不支持的 log_level"):
            OutputConfig(log_level="INVALID")

    def test_get_paths(self):
        """The path accessors wrap the configured directories in Path objects."""
        cfg = OutputConfig(output_dir="results", log_dir="logs", chart_dir="charts")

        assert cfg.get_output_path() == Path("results")
        assert cfg.get_log_path() == Path("logs")
        assert cfg.get_chart_path() == Path("charts")
        # The report lives inside the output directory under its default name.
        assert cfg.get_report_path() == Path("results") / "analysis_report.md"
class TestConfig:
    """Tests for the top-level system configuration."""

    def test_default_config(self):
        """Sub-configs fall back to their defaults when only llm is supplied."""
        cfg = Config(llm=LLMConfig(api_key="test_key"))

        assert cfg.llm.api_key == "test_key"
        assert cfg.performance.agent_max_rounds == 20
        assert cfg.output.output_dir == "output"
        assert cfg.code_repo_enable_reuse is True

    def test_from_env(self):
        """Config.from_env() reads LLM, performance and output settings."""
        env_vars = {
            "LLM_PROVIDER": "openai",
            "OPENAI_API_KEY": "env_test_key",
            "OPENAI_BASE_URL": "https://test.api",
            "OPENAI_MODEL": "gpt-3.5-turbo",
            "AGENT_MAX_ROUNDS": "15",
            "AGENT_OUTPUT_DIR": "test_output",
            "TOOL_MAX_QUERY_ROWS": "5000",
            "CODE_REPO_ENABLE_REUSE": "false",
        }

        with patch.dict(os.environ, env_vars, clear=True):
            cfg = Config.from_env()

            assert cfg.llm.provider == "openai"
            assert cfg.llm.api_key == "env_test_key"
            assert cfg.llm.base_url == "https://test.api"
            assert cfg.llm.model == "gpt-3.5-turbo"
            # Numeric env values are parsed into ints.
            assert cfg.performance.agent_max_rounds == 15
            assert cfg.performance.tool_max_query_rows == 5000
            assert cfg.output.output_dir == "test_output"
            # "false" is parsed into a real boolean.
            assert cfg.code_repo_enable_reuse is False

    def test_from_env_gemini(self):
        """Config.from_env() honors the Gemini provider's env variables."""
        env_vars = {
            "LLM_PROVIDER": "gemini",
            "GEMINI_API_KEY": "gemini_key",
            "GEMINI_BASE_URL": "https://gemini.api",
            "GEMINI_MODEL": "gemini-pro",
        }

        with patch.dict(os.environ, env_vars, clear=True):
            cfg = Config.from_env()

            assert cfg.llm.provider == "gemini"
            assert cfg.llm.api_key == "gemini_key"
            assert cfg.llm.base_url == "https://gemini.api"
            assert cfg.llm.model == "gemini-pro"

    def test_from_dict(self):
        """Config.from_dict() populates every nested section."""
        payload = {
            "llm": {
                "provider": "openai",
                "api_key": "dict_test_key",
                "base_url": "https://dict.api",
                "model": "gpt-4",
                "timeout": 90,
                "max_retries": 2,
                "temperature": 0.5,
                "max_tokens": 2000,
            },
            "performance": {
                "agent_max_rounds": 25,
                "tool_max_query_rows": 8000,
            },
            "output": {
                "output_dir": "dict_output",
                "log_level": "DEBUG",
            },
            "code_repo_enable_reuse": False,
        }

        cfg = Config.from_dict(payload)

        assert cfg.llm.api_key == "dict_test_key"
        assert cfg.llm.base_url == "https://dict.api"
        assert cfg.llm.timeout == 90
        assert cfg.llm.max_retries == 2
        assert cfg.llm.temperature == 0.5
        assert cfg.llm.max_tokens == 2000
        assert cfg.performance.agent_max_rounds == 25
        assert cfg.performance.tool_max_query_rows == 8000
        assert cfg.output.output_dir == "dict_output"
        assert cfg.output.log_level == "DEBUG"
        assert cfg.code_repo_enable_reuse is False

    def test_from_file(self, tmp_path):
        """Config.from_file() loads a JSON config from disk."""
        config_file = tmp_path / "test_config.json"
        config_file.write_text(json.dumps({
            "llm": {
                "provider": "openai",
                "api_key": "file_test_key",
                "model": "gpt-4",
            },
            "performance": {
                "agent_max_rounds": 30,
            },
        }))

        cfg = Config.from_file(str(config_file))

        assert cfg.llm.api_key == "file_test_key"
        assert cfg.llm.model == "gpt-4"
        assert cfg.performance.agent_max_rounds == 30

    def test_from_file_not_found(self):
        """Loading a missing config file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            Config.from_file("nonexistent.json")

    def test_to_dict(self):
        """Config.to_dict() serializes all sections and masks the API key."""
        cfg = Config(
            llm=LLMConfig(api_key="test_key", model="gpt-4"),
            performance=PerformanceConfig(agent_max_rounds=15),
            output=OutputConfig(output_dir="test_output"),
        )

        serialized = cfg.to_dict()

        # The secret must never appear in the serialized form.
        assert serialized["llm"]["api_key"] == "***"
        assert serialized["llm"]["model"] == "gpt-4"
        assert serialized["performance"]["agent_max_rounds"] == 15
        assert serialized["output"]["output_dir"] == "test_output"

    def test_save_to_file(self, tmp_path):
        """save_to_file() writes masked JSON that round-trips via json.loads."""
        config_file = tmp_path / "saved_config.json"
        cfg = Config(
            llm=LLMConfig(api_key="test_key"),
            performance=PerformanceConfig(agent_max_rounds=15),
        )

        cfg.save_to_file(str(config_file))

        assert config_file.exists()
        saved = json.loads(config_file.read_text())
        # The API key is masked on disk as well.
        assert saved["llm"]["api_key"] == "***"
        assert saved["performance"]["agent_max_rounds"] == 15

    def test_validate_success(self):
        """validate() passes for a fully populated config."""
        cfg = Config(llm=LLMConfig(api_key="test_key"))

        assert cfg.validate() is True

    def test_validate_missing_api_key(self):
        """validate() fails once the API key has been cleared."""
        cfg = Config(llm=LLMConfig(api_key="test_key"))
        cfg.llm.api_key = ""  # bypass constructor validation on purpose

        assert cfg.validate() is False
class TestGlobalConfig:
    """Tests for the global (module-level) configuration accessors.

    These tests mutate module-level state via ``set_config`` /
    ``load_config_*``; without cleanup, the configuration set by one test
    leaks into every test that runs after it (the original class only reset
    state at the *start* of ``test_get_config``). ``teardown_method`` now
    resets the global config after each test.
    """

    def teardown_method(self):
        """Reset the global config so state never leaks between tests."""
        set_config(None)

    def test_get_config(self):
        """get_config() lazily builds a config from the environment."""
        # Reset first so get_config() is forced to construct a fresh config.
        set_config(None)

        env_vars = {
            "OPENAI_API_KEY": "global_test_key"
        }

        with patch.dict(os.environ, env_vars, clear=True):
            config = get_config()

            assert config is not None
            assert config.llm.api_key == "global_test_key"

    def test_set_config(self):
        """set_config() replaces the instance returned by get_config()."""
        custom_config = Config(
            llm=LLMConfig(api_key="custom_key")
        )

        set_config(custom_config)

        config = get_config()
        assert config.llm.api_key == "custom_key"

    def test_load_config_from_env(self):
        """load_config_from_env() builds from env vars and updates the global."""
        env_vars = {
            "OPENAI_API_KEY": "env_global_key",
            "AGENT_MAX_ROUNDS": "25"
        }

        with patch.dict(os.environ, env_vars, clear=True):
            config = load_config_from_env()

            assert config.llm.api_key == "env_global_key"
            assert config.performance.agent_max_rounds == 25

            # The global config must reflect the freshly loaded one.
            global_config = get_config()
            assert global_config.llm.api_key == "env_global_key"

    def test_load_config_from_file(self, tmp_path):
        """load_config_from_file() parses JSON and updates the global config."""
        config_file = tmp_path / "global_config.json"

        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "file_global_key",
                "model": "gpt-4"
            }
        }

        with open(config_file, 'w') as f:
            json.dump(config_dict, f)

        config = load_config_from_file(str(config_file))

        assert config.llm.api_key == "file_global_key"

        # The global config must reflect the freshly loaded one.
        global_config = get_config()
        assert global_config.llm.api_key == "file_global_key"