Files
weibo_signin/backend/auth_service/app/utils/security.py
2026-03-09 16:10:29 +08:00

195 lines
5.9 KiB
Python

"""
Security utilities for password hashing and JWT token management
"""
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict

import bcrypt
import jwt

from shared.config import shared_settings
# Auth-specific defaults
BCRYPT_ROUNDS = 12  # cost factor passed to bcrypt.gensalt(rounds=...)
REFRESH_TOKEN_TTL = 7 * 24 * 60 * 60  # refresh-token lifetime: 7 days, in seconds

# Async Redis client, created lazily by get_redis(); None until then
# (and always None while Redis is disabled). Redis is optional.
_redis_client: Optional[object] = None

# In-process refresh-token store used when Redis is not in use (local development).
# Maps "refresh_token:<sha256 hex>" -> (user_id, naive-UTC expiry datetime).
_memory_store: Dict[str, tuple[str, datetime]] = {}
async def get_redis():
    """Return the shared async Redis client, or ``None`` to use the in-memory store.

    Returns ``None`` immediately when ``shared_settings.USE_REDIS`` is false.
    Otherwise a client is built on first use from ``shared_settings.REDIS_URL``
    (with ``decode_responses=True``) and cached in the module-level
    ``_redis_client`` for later calls.

    ``redis.asyncio.from_url`` does not open a connection by itself, so the new
    client is verified with ``PING`` before it is cached. If importing redis,
    building the client, or the ping fails, a warning is printed and ``None``
    is returned so callers fall back to the in-memory store. Nothing is cached
    on failure, so the next call tries again.
    """
    global _redis_client
    if not shared_settings.USE_REDIS:
        return None
    if _redis_client is None:
        try:
            import redis.asyncio as aioredis

            client = aioredis.from_url(
                shared_settings.REDIS_URL, decode_responses=True
            )
            # from_url() is lazy: force a real round-trip so an unreachable
            # server is detected here rather than in the caller's first command.
            await client.ping()
        except Exception as e:
            print(f"[警告] Redis 连接失败: {e},将使用内存存储")
            return None
        _redis_client = client
    return _redis_client
def hash_password(password: str) -> str:
    """Return the bcrypt hash of *password*, decoded as a UTF-8 string.

    Each call generates a fresh salt using cost factor ``BCRYPT_ROUNDS``;
    the salt is embedded in the returned hash string.
    """
    password_bytes = password.encode("utf-8")
    digest = bcrypt.hashpw(password_bytes, bcrypt.gensalt(rounds=BCRYPT_ROUNDS))
    return digest.decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt *hashed_password*.

    Returns the result of ``bcrypt.checkpw``. Any exception raised while
    encoding the inputs or checking the hash (e.g. a malformed stored hash)
    is treated as a mismatch and yields ``False``.
    """
    try:
        candidate = plain_password.encode("utf-8")
        stored = hashed_password.encode("utf-8")
        return bcrypt.checkpw(candidate, stored)
    except Exception:
        # Deliberate best-effort: never let a bad hash or bad input raise.
        return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed. The dict is shallow-copied, so the caller's
            object is not modified; an existing "exp" key is overwritten.
        expires_delta: Token lifetime. ``None`` (the default) means
            ``shared_settings.JWT_EXPIRATION_HOURS`` hours. Any ``timedelta``
            passed explicitly, including ``timedelta(0)``, is used as given.

    Returns:
        The encoded token, signed with ``shared_settings.JWT_SECRET_KEY``
        using ``shared_settings.JWT_ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy and must not
    # be replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(hours=shared_settings.JWT_EXPIRATION_HOURS)
    to_encode = data.copy()
    # Timezone-aware UTC instead of the deprecated naive datetime.utcnow();
    # the "exp" claim encodes to the same instant.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(
        to_encode,
        shared_settings.JWT_SECRET_KEY,
        algorithm=shared_settings.JWT_ALGORITHM,
    )
def decode_access_token(token: str) -> Optional[dict]:
    """Validate *token* and return its claims.

    The signature is checked with ``shared_settings.JWT_SECRET_KEY`` and only
    ``shared_settings.JWT_ALGORITHM`` is accepted.

    Returns:
        The decoded payload, or ``None`` if ``jwt.decode`` rejects the token
        (expired, bad signature, malformed, ...).

    NOTE(review): no "type" claim is checked here, so tokens produced by
    generate_password_reset_token (same key and algorithm) also decode
    successfully — confirm callers distinguish them.
    """
    try:
        claims = jwt.decode(
            token,
            shared_settings.JWT_SECRET_KEY,
            algorithms=[shared_settings.JWT_ALGORITHM],
        )
    except jwt.InvalidTokenError:
        # ExpiredSignatureError subclasses InvalidTokenError, so this single
        # handler covers expiry as well as every other validation failure.
        return None
    return claims
def generate_password_reset_token(email: str) -> str:
    """Return a signed JWT authorising a password reset for *email*.

    The token carries ``{"email": email, "type": "password_reset"}`` and is
    issued through create_access_token with a one-hour lifetime.
    """
    reset_ttl = timedelta(hours=1)
    claims = {"email": email, "type": "password_reset"}
    return create_access_token(claims, reset_ttl)
# Password strength validation
def validate_password_strength(password: str) -> tuple[bool, str]:
    """Check *password* against the composition policy.

    Rules, evaluated in order (the first failure is reported):
    at least 8 characters, an uppercase letter, a lowercase letter, a digit,
    and one character from ``!@#$%^&*()_+-=[]{}|;:,.<>?``.

    Returns:
        ``(True, "Password is strong")`` when every rule passes, otherwise
        ``(False, <message describing the first failed rule>)``.
    """
    specials = "!@#$%^&*()_+-=[]{}|;:,.<>?"
    if len(password) < 8:
        return False, "Password must be at least 8 characters long"
    character_rules = (
        (lambda ch: ch.isupper(), "Password must contain at least one uppercase letter"),
        (lambda ch: ch.islower(), "Password must contain at least one lowercase letter"),
        (lambda ch: ch.isdigit(), "Password must contain at least one digit"),
        (lambda ch: ch in specials, "Password must contain at least one special character"),
    )
    for satisfied_by, failure_message in character_rules:
        if not any(satisfied_by(ch) for ch in password):
            return False, failure_message
    return True, "Password is strong"
# --------------- Refresh Token helpers ---------------
def _hash_token(token: str) -> str:
"""SHA-256 hash of a refresh token for safe Redis key storage."""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _clean_expired_tokens():
    """Remove in-memory refresh tokens whose expiry is before the current time.

    Expiries are compared against naive ``datetime.utcnow()``, matching how
    entries are written. The store is mutated in place.
    """
    cutoff = datetime.utcnow()
    # Collect first: deleting while iterating the dict would raise.
    stale_keys = [key for key, entry in _memory_store.items() if entry[1] < cutoff]
    for key in stale_keys:
        _memory_store.pop(key)
async def create_refresh_token(user_id: str) -> str:
    """Issue a new opaque refresh token for *user_id* and return it.

    A URL-safe random token (from 48 random bytes) is generated. Only its
    SHA-256 hash is stored, under ``refresh_token:<hash>``, for
    ``REFRESH_TOKEN_TTL`` seconds. Storage is Redis when get_redis() returns
    a client; otherwise it is the module's in-memory store, which is pruned
    of expired entries first. The raw token itself is never stored.
    """
    raw_token = secrets.token_urlsafe(48)
    storage_key = f"refresh_token:{_hash_token(raw_token)}"
    redis_client = await get_redis()
    if not redis_client:
        # In-memory fallback: prune stale entries, then record an absolute expiry.
        _clean_expired_tokens()
        _memory_store[storage_key] = (
            user_id,
            datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_TTL),
        )
    else:
        # Redis enforces the TTL itself.
        await redis_client.setex(storage_key, REFRESH_TOKEN_TTL, user_id)
    return raw_token
async def verify_refresh_token(token: str) -> Optional[str]:
    """Look up *token* and return the user_id it was issued for.

    The token is hashed and looked up under ``refresh_token:<hash>``.
    With Redis, the stored value is returned as-is (``None`` when absent).
    With the in-memory store, expired entries are pruned first, and the
    user_id is returned only if the entry's expiry is still in the future.
    Returns ``None`` when the token is unknown or expired.
    """
    storage_key = f"refresh_token:{_hash_token(token)}"
    redis_client = await get_redis()
    if redis_client:
        return await redis_client.get(storage_key)

    _clean_expired_tokens()
    entry = _memory_store.get(storage_key)
    if entry is None:
        return None
    owner_id, expires_at = entry
    return owner_id if expires_at > datetime.utcnow() else None
async def revoke_refresh_token(token: str) -> None:
    """Delete the stored hash of *token* so the token can no longer be used.

    Removes ``refresh_token:<hash>`` from Redis, or from the in-memory store
    when Redis is not in use. Revoking an unknown token is a no-op.
    Intended for use during token rotation.
    """
    storage_key = "refresh_token:" + _hash_token(token)
    redis_client = await get_redis()
    if not redis_client:
        _memory_store.pop(storage_key, None)
        return
    await redis_client.delete(storage_key)