256 lines
8.7 KiB
Python
256 lines
8.7 KiB
Python
"""环境变量加载器的单元测试。"""
|
||
|
||
import os
|
||
import pytest
|
||
from pathlib import Path
|
||
from unittest.mock import patch
|
||
|
||
from src.env_loader import (
|
||
load_env_file,
|
||
load_env_with_fallback,
|
||
get_env,
|
||
get_env_bool,
|
||
get_env_int,
|
||
get_env_float,
|
||
validate_required_env_vars
|
||
)
|
||
|
||
|
||
class TestLoadEnvFile:
    """Tests for loading a single .env file."""

    def test_load_env_file_success(self, tmp_path):
        """A well-formed file loads: returns True and exports every key,
        with surrounding quotes stripped and inner spaces preserved."""
        env_file = tmp_path / ".env"
        env_file.write_text(
            "\n"
            "# This is a comment\n"
            "KEY1=value1\n"
            'KEY2="value2"\n'
            "KEY3='value3'\n"
            "KEY4=value with spaces\n"
            "\n"
            "# Another comment\n"
            "KEY5=123\n",
            encoding="utf-8",
        )

        # Start from an empty environment so only the file can supply values.
        with patch.dict(os.environ, {}, clear=True):
            result = load_env_file(str(env_file))

            assert result is True
            assert os.getenv("KEY1") == "value1"
            assert os.getenv("KEY2") == "value2"
            assert os.getenv("KEY3") == "value3"
            assert os.getenv("KEY4") == "value with spaces"
            assert os.getenv("KEY5") == "123"

    def test_load_env_file_not_found(self, tmp_path):
        """Loading a file that does not exist returns False."""
        # Use a path inside the per-test temp dir: a cwd-relative name could
        # accidentally exist and would then be loaded into the real environment.
        missing = tmp_path / "nonexistent.env"

        result = load_env_file(str(missing))

        assert result is False

    def test_load_env_file_skip_existing(self, tmp_path):
        """Variables already present in the environment are not overwritten."""
        env_file = tmp_path / ".env"
        env_file.write_text("KEY1=from_file\nKEY2=from_file", encoding="utf-8")

        # Pre-set KEY1 so we can check that the environment takes precedence.
        with patch.dict(os.environ, {"KEY1": "from_env"}, clear=True):
            result = load_env_file(str(env_file))

            assert result is True
            # KEY1 keeps its original value (environment wins over the file).
            assert os.getenv("KEY1") == "from_env"
            # KEY2 was not set, so it is loaded from the file.
            assert os.getenv("KEY2") == "from_file"

    def test_load_env_file_skip_invalid_lines(self, tmp_path):
        """Lines without '=' are skipped; valid lines around them still load."""
        env_file = tmp_path / ".env"
        env_file.write_text(
            "\n"
            "VALID_KEY=valid_value\n"
            "invalid line without equals\n"
            "ANOTHER_VALID=another_value\n",
            encoding="utf-8",
        )

        with patch.dict(os.environ, {}, clear=True):
            result = load_env_file(str(env_file))

            assert result is True
            assert os.getenv("VALID_KEY") == "valid_value"
            assert os.getenv("ANOTHER_VALID") == "another_value"

    def test_load_env_file_empty_lines(self, tmp_path):
        """Blank lines (including consecutive ones) are ignored."""
        env_file = tmp_path / ".env"
        env_file.write_text(
            "\n"
            "KEY1=value1\n"
            "\n"
            "KEY2=value2\n"
            "\n"
            "\n"
            "KEY3=value3\n",
            encoding="utf-8",
        )

        with patch.dict(os.environ, {}, clear=True):
            result = load_env_file(str(env_file))

            assert result is True
            assert os.getenv("KEY1") == "value1"
            assert os.getenv("KEY2") == "value2"
            assert os.getenv("KEY3") == "value3"
|
||
|
||
|
||
class TestLoadEnvWithFallback:
    """Tests for loading several .env files in priority order."""

    def test_load_multiple_files(self, tmp_path, monkeypatch):
        """Earlier files in the list win; later files only fill missing keys."""
        (tmp_path / ".env.local").write_text("KEY1=local\nKEY2=local", encoding="utf-8")
        (tmp_path / ".env").write_text("KEY1=default\nKEY3=default", encoding="utf-8")

        # The candidate names are relative, so run from tmp_path. monkeypatch
        # restores the previous working directory automatically, even if an
        # assertion fails, replacing a hand-written try/finally around os.chdir.
        monkeypatch.chdir(tmp_path)

        with patch.dict(os.environ, {}, clear=True):
            result = load_env_with_fallback([".env.local", ".env"])

            assert result is True
            # KEY1 is in both files; .env.local has the higher priority.
            assert os.getenv("KEY1") == "local"
            # KEY2 only exists in .env.local.
            assert os.getenv("KEY2") == "local"
            # KEY3 only exists in .env.
            assert os.getenv("KEY3") == "default"

    def test_load_no_files_found(self, tmp_path, monkeypatch):
        """Returns False when none of the candidate files exist."""
        # Run from an empty directory with a cleared environment so stray files
        # in the real cwd can neither affect the result nor leak variables.
        monkeypatch.chdir(tmp_path)

        with patch.dict(os.environ, {}, clear=True):
            result = load_env_with_fallback(["nonexistent1.env", "nonexistent2.env"])

        assert result is False
|
||
|
||
|
||
class TestGetEnv:
    """Tests for reading a plain string environment variable."""

    def test_get_env_exists(self):
        """A variable that is set is returned as-is."""
        fake_env = {"TEST_KEY": "test_value"}
        with patch.dict(os.environ, fake_env):
            value = get_env("TEST_KEY")
            assert value == "test_value"

    def test_get_env_not_exists(self):
        """An unset variable with no default yields None."""
        with patch.dict(os.environ, clear=True):
            missing = get_env("NONEXISTENT_KEY")
            assert missing is None

    def test_get_env_with_default(self):
        """An unset variable yields the supplied default."""
        with patch.dict(os.environ, clear=True):
            fallback = get_env("NONEXISTENT_KEY", "default")
            assert fallback == "default"
|
||
|
||
|
||
class TestGetEnvBool:
    """Tests for reading boolean environment variables."""

    # Parametrizing (instead of looping inside one test) turns every spelling
    # into its own test case, so a failure names the exact offending input.
    @pytest.mark.parametrize(
        "value",
        ["true", "True", "TRUE", "yes", "Yes", "YES", "1", "on", "On", "ON"],
    )
    def test_get_env_bool_true_values(self, value):
        """Each accepted truthy spelling, in any letter case, parses as True."""
        with patch.dict(os.environ, {"TEST_BOOL": value}):
            assert get_env_bool("TEST_BOOL") is True

    @pytest.mark.parametrize(
        "value",
        ["false", "False", "FALSE", "no", "No", "NO", "0", "off", "Off", "OFF"],
    )
    def test_get_env_bool_false_values(self, value):
        """Each accepted falsy spelling, in any letter case, parses as False."""
        with patch.dict(os.environ, {"TEST_BOOL": value}):
            assert get_env_bool("TEST_BOOL") is False

    def test_get_env_bool_default(self):
        """Unset variables return False, or the explicitly supplied default."""
        with patch.dict(os.environ, {}, clear=True):
            assert get_env_bool("NONEXISTENT_BOOL") is False
            assert get_env_bool("NONEXISTENT_BOOL", True) is True
|
||
|
||
|
||
class TestGetEnvInt:
    """Tests for reading integer environment variables."""

    def test_get_env_int_valid(self):
        """A positive decimal string is converted to int."""
        with patch.dict(os.environ, {"TEST_INT": "123"}):
            parsed = get_env_int("TEST_INT")
            assert parsed == 123

    def test_get_env_int_negative(self):
        """A leading minus sign yields a negative int."""
        with patch.dict(os.environ, {"TEST_INT": "-456"}):
            parsed = get_env_int("TEST_INT")
            assert parsed == -456

    def test_get_env_int_invalid(self):
        """Non-numeric text falls back to 0, or to the supplied default."""
        with patch.dict(os.environ, {"TEST_INT": "not_a_number"}):
            implicit = get_env_int("TEST_INT")
            assert implicit == 0
            explicit = get_env_int("TEST_INT", 999)
            assert explicit == 999

    def test_get_env_int_default(self):
        """Unset variables return 0, or the supplied default."""
        with patch.dict(os.environ, clear=True):
            implicit = get_env_int("NONEXISTENT_INT")
            assert implicit == 0
            explicit = get_env_int("NONEXISTENT_INT", 42)
            assert explicit == 42
|
||
|
||
|
||
class TestGetEnvFloat:
    """Tests for reading floating-point environment variables."""

    def test_get_env_float_valid(self):
        """A decimal string such as "3.14" is converted to float."""
        with patch.dict(os.environ, {"TEST_FLOAT": "3.14"}):
            parsed = get_env_float("TEST_FLOAT")
            assert parsed == 3.14

    def test_get_env_float_negative(self):
        """A leading minus sign yields a negative float."""
        with patch.dict(os.environ, {"TEST_FLOAT": "-2.5"}):
            parsed = get_env_float("TEST_FLOAT")
            assert parsed == -2.5

    def test_get_env_float_invalid(self):
        """Non-numeric text falls back to 0.0, or to the supplied default."""
        with patch.dict(os.environ, {"TEST_FLOAT": "not_a_number"}):
            implicit = get_env_float("TEST_FLOAT")
            assert implicit == 0.0
            explicit = get_env_float("TEST_FLOAT", 9.99)
            assert explicit == 9.99

    def test_get_env_float_default(self):
        """Unset variables return 0.0, or the supplied default."""
        with patch.dict(os.environ, clear=True):
            implicit = get_env_float("NONEXISTENT_FLOAT")
            assert implicit == 0.0
            explicit = get_env_float("NONEXISTENT_FLOAT", 1.5)
            assert explicit == 1.5
|
||
|
||
|
||
class TestValidateRequiredEnvVars:
    """Tests for checking that required environment variables are set."""

    def test_validate_all_present(self):
        """Validation passes when every required name is set."""
        present = {f"KEY{i}": f"value{i}" for i in (1, 2, 3)}
        with patch.dict(os.environ, present):
            outcome = validate_required_env_vars(["KEY1", "KEY2", "KEY3"])
            assert outcome is True

    def test_validate_some_missing(self):
        """Validation fails when only some of the required names are set."""
        with patch.dict(os.environ, {"KEY1": "value1"}, clear=True):
            outcome = validate_required_env_vars(["KEY1", "KEY2", "KEY3"])
            assert outcome is False

    def test_validate_all_missing(self):
        """Validation fails when none of the required names are set."""
        with patch.dict(os.environ, clear=True):
            outcome = validate_required_env_vars(["KEY1", "KEY2"])
            assert outcome is False

    def test_validate_empty_list(self):
        """An empty requirement list is trivially satisfied."""
        nothing_required = []
        outcome = validate_required_env_vars(nothing_required)
        assert outcome is True
|