2026-03-09 14:05:00 +08:00
|
|
|
"""
|
|
|
|
|
Security utilities for password hashing and JWT token management
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import hashlib
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict

import bcrypt
import jwt

from shared.config import shared_settings
|
|
|
|
|
|
|
|
|
|
# ---- Auth-specific defaults ----

# Cost factor passed to bcrypt.gensalt(rounds=...) when hashing passwords.
BCRYPT_ROUNDS = 12

# Refresh-token lifetime in seconds (7 days).
REFRESH_TOKEN_TTL = 7 * 24 * 60 * 60

# Async Redis client, created lazily by get_redis(); optional, only used
# when shared_settings.USE_REDIS is enabled.
_redis_client: Optional[object] = None

# In-memory fallback store for local development when Redis is not used.
# Maps "refresh_token:<sha256 hex>" -> (user_id, naive-UTC expiry datetime).
_memory_store: Dict[str, tuple[str, datetime]] = {}
|
2026-03-09 14:05:00 +08:00
|
|
|
|
|
|
|
|
|
2026-03-09 16:10:29 +08:00
|
|
|
async def get_redis() -> Optional[object]:
    """Return the shared async Redis client, or ``None`` if Redis is not in use.

    Returns ``None`` immediately when ``shared_settings.USE_REDIS`` is false.
    Otherwise, on first use, a client is created from
    ``shared_settings.REDIS_URL`` (with ``decode_responses=True``) and cached in
    the module-level ``_redis_client`` for later calls.

    If importing ``redis.asyncio`` or creating the client raises, a warning is
    logged and ``None`` is returned so callers fall back to the in-memory
    store. Because nothing is cached on failure, creation is retried on the
    next call.

    NOTE(review): ``redis.asyncio.from_url`` presumably does not open a
    connection by itself, so an unreachable server is likely not detected here
    but only when a command is awaited by the caller. Confirm whether the
    memory fallback is expected to cover that case.
    """
    global _redis_client

    if not shared_settings.USE_REDIS:
        return None

    if _redis_client is None:
        try:
            # Imported here so the redis package is only needed when enabled.
            import redis.asyncio as aioredis

            _redis_client = aioredis.from_url(
                shared_settings.REDIS_URL, decode_responses=True
            )
        except Exception as e:
            # Best-effort: degrade to in-memory storage instead of failing.
            # Use logging rather than print so deployments can route/filter it.
            logging.getLogger(__name__).warning(
                "[警告] Redis 连接失败: %s,将使用内存存储", e
            )
            return None

    return _redis_client
|
|
|
|
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Hash ``password`` with bcrypt and return the result as a UTF-8 string.

    A new salt is generated on every call with ``BCRYPT_ROUNDS`` as the cost
    factor, so hashing the same password twice gives different strings.

    NOTE(review): bcrypt only considers the first 72 bytes of the encoded
    password. Depending on the installed bcrypt version, longer inputs are
    either truncated or rejected. Confirm that callers bound the length.
    """
    salt_value = bcrypt.gensalt(rounds=BCRYPT_ROUNDS)
    return bcrypt.hashpw(password.encode('utf-8'), salt_value).decode('utf-8')
|
|
|
|
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored bcrypt ``hashed_password``.

    Returns True when they match. Any error raised while encoding the inputs
    or checking the hash (for example a malformed stored hash) counts as a
    mismatch: the function returns False instead of propagating the error.
    """
    try:
        candidate = plain_password.encode('utf-8')
        stored = hashed_password.encode('utf-8')
        return bcrypt.checkpw(candidate, stored)
    except Exception:
        # Deliberately broad: verification fails closed and never raises.
        return False
|
|
|
|
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed. The dict is shallow-copied, so the caller's
            object is not modified. Any ``exp`` key in it is overwritten.
        expires_delta: Token lifetime measured from now. When ``None``, the
            lifetime is ``shared_settings.JWT_EXPIRATION_HOURS`` hours. A zero
            or negative delta is honoured as given: ``exp`` is set at or before
            the current time.

    Returns:
        The encoded token, signed with ``shared_settings.JWT_SECRET_KEY`` using
        ``shared_settings.JWT_ALGORITHM``.
    """
    to_encode = data.copy()

    # Compare against None explicitly. timedelta(0) is falsy, so a plain
    # truthiness test would silently replace it with the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(hours=shared_settings.JWT_EXPIRATION_HOURS)

    # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta

    return jwt.encode(
        to_encode,
        shared_settings.JWT_SECRET_KEY,
        algorithm=shared_settings.JWT_ALGORITHM,
    )
|
|
|
|
|
|
|
|
|
|
def decode_access_token(token: str) -> Optional[dict]:
    """Decode ``token`` and return its claims, or ``None`` if it is not valid.

    The signature is verified with ``shared_settings.JWT_SECRET_KEY``, and only
    ``shared_settings.JWT_ALGORITHM`` is accepted. Expired tokens, and any other
    failure PyJWT reports as ``InvalidTokenError``, yield ``None``.

    NOTE(review): tokens from ``generate_password_reset_token`` are produced by
    ``create_access_token`` with the same key, so they also decode successfully
    here. Presumably callers check the ``type`` claim; confirm.
    """
    try:
        return jwt.decode(
            token,
            shared_settings.JWT_SECRET_KEY,
            algorithms=[shared_settings.JWT_ALGORITHM],
        )
    except jwt.InvalidTokenError:
        # ExpiredSignatureError subclasses InvalidTokenError, so this single
        # handler covers both expired and otherwise-invalid tokens.
        return None
|
|
|
|
|
|
|
|
|
|
def generate_password_reset_token(email: str) -> str:
    """Return a signed JWT that authorises a password reset for ``email``.

    The token carries the claims ``{"email": email, "type": "password_reset"}``
    and expires one hour after creation.

    NOTE(review): nothing here records the token, so it remains usable until
    it expires. Confirm whether one-time use is enforced elsewhere.
    """
    reset_claims = {"email": email, "type": "password_reset"}
    return create_access_token(data=reset_claims, expires_delta=timedelta(hours=1))
|
|
|
|
|
|
|
|
|
|
# Password strength validation
def validate_password_strength(
    password: str,
    *,
    min_length: int = 8,
    max_bytes: Optional[int] = 72,
    special_chars: str = "!@#$%^&*()_+-=[]{}|;:,.<>?",
) -> tuple[bool, str]:
    """Check that ``password`` satisfies the password policy.

    Rules are checked in order, and only the first failure is reported:
    1. minimum length in characters;
    2. maximum length in UTF-8 bytes;
    3. at least one uppercase letter;
    4. at least one lowercase letter;
    5. at least one digit;
    6. at least one character from ``special_chars``.

    Args:
        password: Candidate password.
        min_length: Minimum number of characters.
        max_bytes: Maximum UTF-8 encoded length, or ``None`` to disable the
            check. The default of 72 reflects bcrypt (used by
            ``hash_password``), which only considers the first 72 bytes of a
            password. Depending on the bcrypt version, longer inputs are
            truncated or rejected at hashing time.
        special_chars: Characters that satisfy the special-character rule.

    Returns:
        ``(is_valid, message)``. ``message`` describes the first rule that
        failed, or is ``"Password is strong"`` when all rules pass.
    """
    if len(password) < min_length:
        return False, f"Password must be at least {min_length} characters long"

    # 'surrogatepass' keeps lone surrogates (possible in decoded JSON input)
    # from raising here; they are counted as 3 bytes each.
    if max_bytes is not None and len(password.encode("utf-8", "surrogatepass")) > max_bytes:
        return False, f"Password must be at most {max_bytes} bytes long"

    if not any(c.isupper() for c in password):
        return False, "Password must contain at least one uppercase letter"

    if not any(c.islower() for c in password):
        return False, "Password must contain at least one lowercase letter"

    if not any(c.isdigit() for c in password):
        return False, "Password must contain at least one digit"

    if not any(c in special_chars for c in password):
        return False, "Password must contain at least one special character"

    return True, "Password is strong"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --------------- Refresh Token helpers ---------------

def _hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of ``token``.

    Refresh tokens are stored under this digest rather than their raw value,
    so the storage keys do not directly expose usable tokens.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
|
|
|
|
|
2026-03-09 16:10:29 +08:00
|
|
|
def _clean_expired_tokens():
    """Remove expired refresh tokens from the in-memory store.

    The store is mutated in place. An entry is dropped only when its expiry is
    strictly earlier than now, so one expiring exactly now is kept. Every
    entry is scanned on each call.
    """
    # Naive UTC, matching the expiry timestamps written by create_refresh_token.
    cutoff = datetime.utcnow()
    for storage_key, (_, expires_at) in list(_memory_store.items()):
        if expires_at < cutoff:
            del _memory_store[storage_key]
|
|
|
|
|
|
|
|
|
|
|
2026-03-09 14:05:00 +08:00
|
|
|
async def create_refresh_token(user_id: str) -> str:
    """Issue a new opaque refresh token for ``user_id``.

    A cryptographically random token is generated. Only its SHA-256 hash is
    persisted, under the key ``refresh_token:<hash>``, with a lifetime of
    ``REFRESH_TOKEN_TTL`` seconds (7 days):
    - in Redis via ``SETEX`` when a client is available;
    - otherwise in the in-memory store, after purging expired entries.

    The raw token is returned to the caller and is not stored anywhere.

    NOTE(review): the in-memory store is a module-level dict, so it is
    presumably per process and lost on restart. It is only suitable for local
    development; confirm.
    """
    raw_token = secrets.token_urlsafe(48)
    storage_key = f"refresh_token:{_hash_token(raw_token)}"

    redis_conn = await get_redis()
    if not redis_conn:
        # No Redis: purge stale entries, then remember (user_id, expiry) locally.
        _clean_expired_tokens()
        expires_at = datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_TTL)
        _memory_store[storage_key] = (user_id, expires_at)
    else:
        await redis_conn.setex(storage_key, REFRESH_TOKEN_TTL, user_id)

    return raw_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def verify_refresh_token(token: str) -> Optional[str]:
    """Resolve a refresh token to the user_id it was issued for.

    The token's SHA-256 hash is looked up in Redis when a client is available.
    Otherwise it is looked up in the in-memory store, after purging expired
    entries.

    Returns the stored user_id, or None if the token is unknown or expired.
    """
    lookup_key = f"refresh_token:{_hash_token(token)}"

    redis_conn = await get_redis()
    if redis_conn:
        # The key was written with SETEX, so Redis expires it. A missing key
        # comes back as None.
        return await redis_conn.get(lookup_key)

    _clean_expired_tokens()
    entry = _memory_store.get(lookup_key)
    if entry is None:
        return None

    owner_id, expires_at = entry
    # Re-check expiry: the purge keeps entries expiring exactly at its cutoff,
    # and time may have advanced since then.
    return owner_id if expires_at > datetime.utcnow() else None
|
2026-03-09 14:05:00 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def revoke_refresh_token(token: str) -> None:
    """Delete a refresh token from Redis or the in-memory store.

    Used during token rotation. Revoking a token that is unknown or already
    gone is a no-op.

    NOTE(review): verification and revocation are separate calls, so two
    concurrent refreshes could both verify the same token before either
    revokes it. Confirm whether rotation needs an atomic get-and-delete.
    """
    storage_key = f"refresh_token:{_hash_token(token)}"

    redis_conn = await get_redis()
    if redis_conn:
        await redis_conn.delete(storage_key)
    else:
        _memory_store.pop(storage_key, None)
|