"""Unit tests for the environment-variable loader (``src.env_loader``)."""

import os
from pathlib import Path
from unittest.mock import patch

import pytest

from src.env_loader import (
    load_env_file,
    load_env_with_fallback,
    get_env,
    get_env_bool,
    get_env_int,
    get_env_float,
    validate_required_env_vars,
)


class TestLoadEnvFile:
    """Tests for loading a single .env file."""

    def test_load_env_file_success(self, tmp_path: Path):
        """A well-formed file loads plain, quoted, and space-containing values."""
        env_file = tmp_path / ".env"
        env_file.write_text(
            "\n"
            "# This is a comment\n"
            "KEY1=value1\n"
            'KEY2="value2"\n'
            "KEY3='value3'\n"
            "KEY4=value with spaces\n"
            "# Another comment\n"
            "KEY5=123\n",
            encoding="utf-8",
        )

        # Start from an empty environment so only the file's values are visible;
        # patch.dict restores the real environment when the block exits.
        with patch.dict(os.environ, {}, clear=True):
            result = load_env_file(str(env_file))

            assert result is True
            assert os.getenv("KEY1") == "value1"
            assert os.getenv("KEY2") == "value2"
            assert os.getenv("KEY3") == "value3"
            assert os.getenv("KEY4") == "value with spaces"
            assert os.getenv("KEY5") == "123"

    def test_load_env_file_not_found(self, tmp_path: Path):
        """A missing file yields False."""
        # Use a path inside an empty temp dir so the file is guaranteed absent,
        # independent of whatever happens to be in the current working directory.
        missing = tmp_path / "nonexistent.env"
        assert load_env_file(str(missing)) is False

    def test_load_env_file_skip_existing(self, tmp_path: Path):
        """Variables already present in the environment are not overwritten."""
        env_file = tmp_path / ".env"
        env_file.write_text("KEY1=from_file\nKEY2=from_file", encoding="utf-8")

        with patch.dict(os.environ, {"KEY1": "from_env"}, clear=True):
            load_env_file(str(env_file))

            # KEY1 keeps its pre-existing value (the environment takes precedence).
            assert os.getenv("KEY1") == "from_env"
            # KEY2 was not set beforehand, so it is filled in from the file.
            assert os.getenv("KEY2") == "from_file"

    def test_load_env_file_skip_invalid_lines(self, tmp_path: Path):
        """Lines without '=' are ignored while valid lines still load."""
        env_file = tmp_path / ".env"
        env_file.write_text(
            "\n"
            "VALID_KEY=valid_value\n"
            "invalid line without equals\n"
            "ANOTHER_VALID=another_value\n",
            encoding="utf-8",
        )

        with patch.dict(os.environ, {}, clear=True):
            result = load_env_file(str(env_file))

            assert result is True
            assert os.getenv("VALID_KEY") == "valid_value"
            assert os.getenv("ANOTHER_VALID") == "another_value"

    def test_load_env_file_empty_lines(self, tmp_path: Path):
        """Blank lines between entries do not disturb parsing."""
        env_file = tmp_path / ".env"
        env_file.write_text(
            "\n"
            "KEY1=value1\n"
            "\n"
            "KEY2=value2\n"
            "\n"
            "\n"
            "KEY3=value3\n"
            "\n",
            encoding="utf-8",
        )

        with patch.dict(os.environ, {}, clear=True):
            result = load_env_file(str(env_file))

            assert result is True
            assert os.getenv("KEY1") == "value1"
            assert os.getenv("KEY2") == "value2"
            assert os.getenv("KEY3") == "value3"


class TestLoadEnvWithFallback:
    """Tests for loading several .env files in priority order."""

    def test_load_multiple_files(self, tmp_path: Path, monkeypatch):
        """Earlier files in the list win; later files only fill in missing keys."""
        env_file1 = tmp_path / ".env.local"
        env_file1.write_text("KEY1=local\nKEY2=local", encoding="utf-8")

        env_file2 = tmp_path / ".env"
        env_file2.write_text("KEY1=default\nKEY3=default", encoding="utf-8")

        # The file names below are relative; monkeypatch.chdir restores the
        # original working directory at teardown even if an assertion fails.
        monkeypatch.chdir(tmp_path)

        with patch.dict(os.environ, {}, clear=True):
            result = load_env_with_fallback([".env.local", ".env"])

            assert result is True
            # KEY1 comes from .env.local (higher priority).
            assert os.getenv("KEY1") == "local"
            # KEY2 only exists in .env.local.
            assert os.getenv("KEY2") == "local"
            # KEY3 only exists in .env.
            assert os.getenv("KEY3") == "default"

    def test_load_no_files_found(self, tmp_path: Path, monkeypatch):
        """False is returned when none of the candidate files exist."""
        # Run from an empty directory so the relative names cannot accidentally
        # match real files in the developer's working directory.
        monkeypatch.chdir(tmp_path)

        result = load_env_with_fallback(["nonexistent1.env", "nonexistent2.env"])
        assert result is False


class TestGetEnv:
    """Tests for reading string environment variables."""

    def test_get_env_exists(self):
        """An existing variable's value is returned."""
        with patch.dict(os.environ, {"TEST_KEY": "test_value"}):
            assert get_env("TEST_KEY") == "test_value"

    def test_get_env_not_exists(self):
        """A missing variable yields None when no default is given."""
        with patch.dict(os.environ, {}, clear=True):
            assert get_env("NONEXISTENT_KEY") is None

    def test_get_env_with_default(self):
        """A missing variable yields the supplied default."""
        with patch.dict(os.environ, {}, clear=True):
            assert get_env("NONEXISTENT_KEY", "default") == "default"


class TestGetEnvBool:
    """Tests for reading boolean environment variables."""

    @pytest.mark.parametrize(
        "value",
        ["true", "True", "TRUE", "yes", "Yes", "YES", "1", "on", "On", "ON"],
    )
    def test_get_env_bool_true_values(self, value):
        """Each accepted truthy spelling parses as True."""
        with patch.dict(os.environ, {"TEST_BOOL": value}):
            assert get_env_bool("TEST_BOOL") is True

    @pytest.mark.parametrize(
        "value",
        ["false", "False", "FALSE", "no", "No", "NO", "0", "off", "Off", "OFF"],
    )
    def test_get_env_bool_false_values(self, value):
        """Each accepted falsy spelling parses as False."""
        with patch.dict(os.environ, {"TEST_BOOL": value}):
            assert get_env_bool("TEST_BOOL") is False

    def test_get_env_bool_default(self):
        """A missing variable yields False, or the supplied default."""
        with patch.dict(os.environ, {}, clear=True):
            assert get_env_bool("NONEXISTENT_BOOL") is False
            assert get_env_bool("NONEXISTENT_BOOL", True) is True


class TestGetEnvInt:
    """Tests for reading integer environment variables."""

    def test_get_env_int_valid(self):
        """A numeric string parses as an int."""
        with patch.dict(os.environ, {"TEST_INT": "123"}):
            assert get_env_int("TEST_INT") == 123

    def test_get_env_int_negative(self):
        """A negative numeric string parses as a negative int."""
        with patch.dict(os.environ, {"TEST_INT": "-456"}):
            assert get_env_int("TEST_INT") == -456

    def test_get_env_int_invalid(self):
        """A non-numeric value falls back to the default (0 unless given)."""
        with patch.dict(os.environ, {"TEST_INT": "not_a_number"}):
            assert get_env_int("TEST_INT") == 0
            assert get_env_int("TEST_INT", 999) == 999

    def test_get_env_int_default(self):
        """A missing variable yields 0, or the supplied default."""
        with patch.dict(os.environ, {}, clear=True):
            assert get_env_int("NONEXISTENT_INT") == 0
            assert get_env_int("NONEXISTENT_INT", 42) == 42


class TestGetEnvFloat:
    """Tests for reading floating-point environment variables."""

    def test_get_env_float_valid(self):
        """A decimal string parses as a float."""
        with patch.dict(os.environ, {"TEST_FLOAT": "3.14"}):
            assert get_env_float("TEST_FLOAT") == 3.14

    def test_get_env_float_negative(self):
        """A negative decimal string parses as a negative float."""
        with patch.dict(os.environ, {"TEST_FLOAT": "-2.5"}):
            assert get_env_float("TEST_FLOAT") == -2.5

    def test_get_env_float_invalid(self):
        """A non-numeric value falls back to the default (0.0 unless given)."""
        with patch.dict(os.environ, {"TEST_FLOAT": "not_a_number"}):
            assert get_env_float("TEST_FLOAT") == 0.0
            assert get_env_float("TEST_FLOAT", 9.99) == 9.99

    def test_get_env_float_default(self):
        """A missing variable yields 0.0, or the supplied default."""
        with patch.dict(os.environ, {}, clear=True):
            assert get_env_float("NONEXISTENT_FLOAT") == 0.0
            assert get_env_float("NONEXISTENT_FLOAT", 1.5) == 1.5


class TestValidateRequiredEnvVars:
    """Tests for checking that required environment variables are present."""

    def test_validate_all_present(self):
        """True when every required variable is set."""
        with patch.dict(
            os.environ, {"KEY1": "value1", "KEY2": "value2", "KEY3": "value3"}
        ):
            assert validate_required_env_vars(["KEY1", "KEY2", "KEY3"]) is True

    def test_validate_some_missing(self):
        """False when at least one required variable is missing."""
        with patch.dict(os.environ, {"KEY1": "value1"}, clear=True):
            assert validate_required_env_vars(["KEY1", "KEY2", "KEY3"]) is False

    def test_validate_all_missing(self):
        """False when none of the required variables are set."""
        with patch.dict(os.environ, {}, clear=True):
            assert validate_required_env_vars(["KEY1", "KEY2"]) is False

    def test_validate_empty_list(self):
        """An empty requirement list is trivially satisfied."""
        assert validate_required_env_vars([]) is True