Files
vibe_data_ana/tests/test_config.py

431 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""配置管理模块的单元测试。"""
import os
import json
import pytest
from pathlib import Path
from unittest.mock import patch
from src.config import (
LLMConfig,
PerformanceConfig,
OutputConfig,
Config,
get_config,
set_config,
load_config_from_env,
load_config_from_file
)
class TestLLMConfig:
    """Unit tests for ``LLMConfig``: defaults, overrides and validation."""

    def test_default_config(self):
        """Only the API key is supplied; the remaining fields are checked
        against their expected default values."""
        cfg = LLMConfig(api_key="test_key")

        expected_defaults = {
            "provider": "openai",
            "api_key": "test_key",
            "base_url": "https://api.openai.com/v1",
            "model": "gpt-4",
            "timeout": 120,
            "max_retries": 3,
            "temperature": 0.7,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected
        # Identity check kept separate: the default must be exactly None.
        assert cfg.max_tokens is None

    def test_custom_config(self):
        """Every field passed to the constructor is stored unchanged."""
        overrides = {
            "provider": "gemini",
            "api_key": "gemini_key",
            "base_url": "https://gemini.api",
            "model": "gemini-pro",
            "timeout": 60,
            "max_retries": 5,
            "temperature": 0.5,
            "max_tokens": 1000,
        }
        cfg = LLMConfig(**overrides)

        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected

    def test_empty_api_key(self):
        """An empty API key is rejected with ``ValueError``."""
        expected_message = "API key 不能为空"
        with pytest.raises(ValueError, match=expected_message):
            LLMConfig(api_key="")

    def test_invalid_provider(self):
        """An unknown provider name is rejected with ``ValueError``."""
        expected_message = "不支持的 LLM provider"
        with pytest.raises(ValueError, match=expected_message):
            LLMConfig(api_key="test", provider="invalid")

    def test_invalid_timeout(self):
        """A timeout of zero is rejected with ``ValueError``."""
        expected_message = "timeout 必须大于 0"
        with pytest.raises(ValueError, match=expected_message):
            LLMConfig(api_key="test", timeout=0)

    def test_invalid_max_retries(self):
        """A negative retry count is rejected with ``ValueError``."""
        expected_message = "max_retries 不能为负数"
        with pytest.raises(ValueError, match=expected_message):
            LLMConfig(api_key="test", max_retries=-1)
class TestPerformanceConfig:
    """Unit tests for ``PerformanceConfig``: defaults, overrides and validation."""

    def test_default_config(self):
        """A config built with no arguments exposes the expected defaults."""
        cfg = PerformanceConfig()

        expected_defaults = {
            "agent_max_rounds": 20,
            "agent_timeout": 300,
            "tool_max_query_rows": 10000,
            "tool_execution_timeout": 60,
            "data_max_rows": 1000000,
            "data_sample_threshold": 1000000,
            "max_concurrent_tasks": 1,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected

    def test_custom_config(self):
        """Every field passed to the constructor is stored unchanged."""
        overrides = {
            "agent_max_rounds": 10,
            "agent_timeout": 600,
            "tool_max_query_rows": 5000,
            "tool_execution_timeout": 30,
            "data_max_rows": 500000,
            "data_sample_threshold": 500000,
            "max_concurrent_tasks": 2,
        }
        cfg = PerformanceConfig(**overrides)

        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected

    def test_invalid_agent_max_rounds(self):
        """``agent_max_rounds=0`` is rejected with ``ValueError``."""
        expected_message = "agent_max_rounds 必须大于 0"
        with pytest.raises(ValueError, match=expected_message):
            PerformanceConfig(agent_max_rounds=0)

    def test_invalid_tool_max_query_rows(self):
        """A negative ``tool_max_query_rows`` is rejected with ``ValueError``."""
        expected_message = "tool_max_query_rows 必须大于 0"
        with pytest.raises(ValueError, match=expected_message):
            PerformanceConfig(tool_max_query_rows=-1)
class TestOutputConfig:
    """Unit tests for ``OutputConfig``: defaults, overrides, validation, paths."""

    def test_default_config(self):
        """A config built with no arguments exposes the expected defaults."""
        cfg = OutputConfig()

        expected_defaults = {
            "output_dir": "output",
            "log_dir": "output",
            "chart_dir": str(Path("output") / "charts"),
            "report_filename": "analysis_report.md",
            "log_level": "INFO",
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected
        # Boolean flags are checked by identity, not mere truthiness.
        assert cfg.log_to_file is True
        assert cfg.log_to_console is True

    def test_custom_config(self):
        """Every field passed to the constructor is stored unchanged."""
        cfg = OutputConfig(
            output_dir="results",
            log_dir="logs",
            chart_dir="charts",
            report_filename="report.md",
            log_level="DEBUG",
            log_to_file=False,
            log_to_console=True,
        )

        expected_strings = {
            "output_dir": "results",
            "log_dir": "logs",
            "chart_dir": "charts",
            "report_filename": "report.md",
            "log_level": "DEBUG",
        }
        for field_name, expected in expected_strings.items():
            assert getattr(cfg, field_name) == expected
        assert cfg.log_to_file is False
        assert cfg.log_to_console is True

    def test_invalid_log_level(self):
        """An unknown log level is rejected with ``ValueError``."""
        expected_message = "不支持的 log_level"
        with pytest.raises(ValueError, match=expected_message):
            OutputConfig(log_level="INVALID")

    def test_get_paths(self):
        """The ``get_*_path`` helpers return ``Path`` objects equal to the
        configured directories, and the report path equals the output
        directory joined with the default report filename."""
        cfg = OutputConfig(
            output_dir="results",
            log_dir="logs",
            chart_dir="charts",
        )

        output_path = cfg.get_output_path()
        log_path = cfg.get_log_path()
        chart_path = cfg.get_chart_path()
        report_path = cfg.get_report_path()

        assert output_path == Path("results")
        assert log_path == Path("logs")
        assert chart_path == Path("charts")
        assert report_path == Path("results/analysis_report.md")
class TestConfig:
    """Unit tests for the top-level ``Config`` object.

    Covers construction, loading from environment variables / dicts / JSON
    files, serialisation (with API-key masking) and ``validate()``.
    """

    def test_default_config(self):
        """Only the LLM section is supplied; other sections use defaults."""
        config = Config(
            llm=LLMConfig(api_key="test_key")
        )

        assert config.llm.api_key == "test_key"
        assert config.performance.agent_max_rounds == 20
        assert config.output.output_dir == "output"
        assert config.code_repo_enable_reuse is True

    def test_from_env(self):
        """``Config.from_env`` picks up OpenAI settings from the environment."""
        env_vars = {
            "LLM_PROVIDER": "openai",
            "OPENAI_API_KEY": "env_test_key",
            "OPENAI_BASE_URL": "https://test.api",
            "OPENAI_MODEL": "gpt-3.5-turbo",
            "AGENT_MAX_ROUNDS": "15",
            "AGENT_OUTPUT_DIR": "test_output",
            "TOOL_MAX_QUERY_ROWS": "5000",
            "CODE_REPO_ENABLE_REUSE": "false",
        }

        # clear=True hides the developer's real environment from the test.
        with patch.dict(os.environ, env_vars, clear=True):
            config = Config.from_env()

            assert config.llm.provider == "openai"
            assert config.llm.api_key == "env_test_key"
            assert config.llm.base_url == "https://test.api"
            assert config.llm.model == "gpt-3.5-turbo"
            assert config.performance.agent_max_rounds == 15
            assert config.performance.tool_max_query_rows == 5000
            assert config.output.output_dir == "test_output"
            assert config.code_repo_enable_reuse is False

    def test_from_env_gemini(self):
        """``Config.from_env`` picks up Gemini settings from the environment."""
        env_vars = {
            "LLM_PROVIDER": "gemini",
            "GEMINI_API_KEY": "gemini_key",
            "GEMINI_BASE_URL": "https://gemini.api",
            "GEMINI_MODEL": "gemini-pro",
        }

        with patch.dict(os.environ, env_vars, clear=True):
            config = Config.from_env()

            assert config.llm.provider == "gemini"
            assert config.llm.api_key == "gemini_key"
            assert config.llm.base_url == "https://gemini.api"
            assert config.llm.model == "gemini-pro"

    def test_from_dict(self):
        """``Config.from_dict`` maps nested dict sections onto the config."""
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "dict_test_key",
                "base_url": "https://dict.api",
                "model": "gpt-4",
                "timeout": 90,
                "max_retries": 2,
                "temperature": 0.5,
                "max_tokens": 2000,
            },
            "performance": {
                "agent_max_rounds": 25,
                "tool_max_query_rows": 8000,
            },
            "output": {
                "output_dir": "dict_output",
                "log_level": "DEBUG",
            },
            "code_repo_enable_reuse": False,
        }

        config = Config.from_dict(config_dict)

        assert config.llm.api_key == "dict_test_key"
        assert config.llm.base_url == "https://dict.api"
        assert config.llm.timeout == 90
        assert config.llm.max_retries == 2
        assert config.llm.temperature == 0.5
        assert config.llm.max_tokens == 2000
        assert config.performance.agent_max_rounds == 25
        assert config.performance.tool_max_query_rows == 8000
        assert config.output.output_dir == "dict_output"
        assert config.output.log_level == "DEBUG"
        assert config.code_repo_enable_reuse is False

    def test_from_file(self, tmp_path):
        """``Config.from_file`` loads a JSON file written to ``tmp_path``."""
        config_file = tmp_path / "test_config.json"
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "file_test_key",
                "model": "gpt-4",
            },
            "performance": {
                "agent_max_rounds": 30,
            },
        }
        # Explicit encoding keeps the fixture independent of the platform
        # default locale encoding.
        config_file.write_text(json.dumps(config_dict), encoding="utf-8")

        config = Config.from_file(str(config_file))

        assert config.llm.api_key == "file_test_key"
        assert config.llm.model == "gpt-4"
        assert config.performance.agent_max_rounds == 30

    def test_from_file_not_found(self, tmp_path):
        """Loading a missing config file raises ``FileNotFoundError``."""
        # Use a path inside the per-test temporary directory so the outcome
        # does not depend on the current working directory or on a stray
        # "nonexistent.json" lying around in it.
        missing_file = tmp_path / "nonexistent.json"

        with pytest.raises(FileNotFoundError):
            Config.from_file(str(missing_file))

    def test_to_dict(self):
        """``to_dict`` serialises all sections and masks the API key."""
        config = Config(
            llm=LLMConfig(
                api_key="test_key",
                model="gpt-4",
            ),
            performance=PerformanceConfig(
                agent_max_rounds=15,
            ),
            output=OutputConfig(
                output_dir="test_output",
            ),
        )

        config_dict = config.to_dict()

        assert config_dict["llm"]["api_key"] == "***"  # the API key must be hidden
        assert config_dict["llm"]["model"] == "gpt-4"
        assert config_dict["performance"]["agent_max_rounds"] == 15
        assert config_dict["output"]["output_dir"] == "test_output"

    def test_save_to_file(self, tmp_path):
        """``save_to_file`` writes JSON to disk with the API key masked."""
        config_file = tmp_path / "saved_config.json"
        config = Config(
            llm=LLMConfig(api_key="test_key"),
            performance=PerformanceConfig(agent_max_rounds=15),
        )

        config.save_to_file(str(config_file))

        assert config_file.exists()
        saved_dict = json.loads(config_file.read_text(encoding="utf-8"))
        assert saved_dict["llm"]["api_key"] == "***"
        assert saved_dict["performance"]["agent_max_rounds"] == 15

    def test_validate_success(self):
        """``validate()`` returns ``True`` for a config with an API key."""
        config = Config(
            llm=LLMConfig(api_key="test_key")
        )

        assert config.validate() is True

    def test_validate_missing_api_key(self):
        """``validate()`` returns ``False`` when the API key is missing."""
        config = Config(
            llm=LLMConfig(api_key="test_key")
        )
        # Blank the key after construction: the test needs an already-built
        # config whose key is empty.
        config.llm.api_key = ""

        assert config.validate() is False
class TestGlobalConfig:
    """Unit tests for the module-level (global) configuration helpers.

    These tests mutate the global config via ``set_config``,
    ``load_config_from_env`` and ``load_config_from_file``. The global is
    reset before and after every test so that state never leaks between
    tests and the results do not depend on execution order.
    """

    def setup_method(self):
        """Start each test with no global config installed."""
        set_config(None)

    def teardown_method(self):
        """Drop whatever global config the test installed."""
        set_config(None)

    def test_get_config(self):
        """With the global config reset (see ``setup_method``), ``get_config``
        returns a config carrying the API key from the environment."""
        env_vars = {
            "OPENAI_API_KEY": "global_test_key",
        }

        # clear=True hides the developer's real environment from the test.
        with patch.dict(os.environ, env_vars, clear=True):
            config = get_config()

            assert config is not None
            assert config.llm.api_key == "global_test_key"

    def test_set_config(self):
        """``set_config`` installs the object later returned by ``get_config``."""
        custom_config = Config(
            llm=LLMConfig(api_key="custom_key")
        )

        set_config(custom_config)

        config = get_config()
        assert config.llm.api_key == "custom_key"

    def test_load_config_from_env(self):
        """``load_config_from_env`` returns the config and updates the global."""
        env_vars = {
            "OPENAI_API_KEY": "env_global_key",
            "AGENT_MAX_ROUNDS": "25",
        }

        with patch.dict(os.environ, env_vars, clear=True):
            config = load_config_from_env()

            assert config.llm.api_key == "env_global_key"
            assert config.performance.agent_max_rounds == 25

            # The global config must now reflect the loaded values.
            global_config = get_config()
            assert global_config.llm.api_key == "env_global_key"

    def test_load_config_from_file(self, tmp_path):
        """``load_config_from_file`` returns the config and updates the global."""
        config_file = tmp_path / "global_config.json"
        config_dict = {
            "llm": {
                "provider": "openai",
                "api_key": "file_global_key",
                "model": "gpt-4",
            },
        }
        # Explicit encoding keeps the fixture independent of the platform
        # default locale encoding.
        config_file.write_text(json.dumps(config_dict), encoding="utf-8")

        config = load_config_from_file(str(config_file))

        assert config.llm.api_key == "file_global_key"

        # The global config must now reflect the loaded values.
        global_config = get_config()
        assert global_config.llm.api_key == "file_global_key"