Files
weibo_signin/backend/auth_service/app/utils/security.py

195 lines
5.9 KiB
Python
Raw Normal View History

2026-03-09 14:05:00 +08:00
"""
Security utilities for password hashing and JWT token management
"""
import hashlib
import secrets
import string
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict

import bcrypt
import jwt

from shared.config import shared_settings
# Auth-specific defaults
BCRYPT_ROUNDS = 12  # bcrypt cost factor handed to bcrypt.gensalt()
REFRESH_TOKEN_TTL = 60 * 60 * 24 * 7  # refresh-token lifetime in seconds (7 days)

# Async Redis client, created lazily by get_redis() (optional).
_redis_client: Optional[object] = None
# In-memory refresh-token store used when Redis is not in use (local development):
# maps "refresh_token:<sha256 hex>" -> (user_id, naive-UTC expiry datetime).
_memory_store: Dict[str, tuple[str, datetime]] = {}
async def get_redis():
    """Return the shared async Redis client, or ``None`` to use the in-memory store.

    Returns ``None`` when ``shared_settings.USE_REDIS`` is false, or when the
    client cannot be created or the server does not answer a PING. The client
    is stored in the module-level ``_redis_client`` only after a successful
    PING, so a failed attempt is retried on the next call.
    """
    global _redis_client

    if not shared_settings.USE_REDIS:
        return None

    if _redis_client is None:
        try:
            import redis.asyncio as aioredis

            client = aioredis.from_url(
                shared_settings.REDIS_URL, decode_responses=True
            )
            # from_url() is lazy and opens no connection. Without an explicit
            # round-trip, an unreachable server would only surface later as an
            # exception inside setex/get/delete, and the in-memory fallback
            # below would never be used.
            await client.ping()
        except Exception as e:
            # NOTE(review): while Redis stays unreachable, every call re-attempts
            # the connection (and pays its timeout) before falling back.
            print(f"[警告] Redis 连接失败: {e},将使用内存存储")
            return None
        _redis_client = client

    return _redis_client
def hash_password(password: str) -> str:
    """Return the bcrypt hash of ``password``, decoded to a UTF-8 string.

    A fresh salt with cost factor ``BCRYPT_ROUNDS`` is generated on every call,
    so hashing the same password twice yields different strings.
    """
    salt = bcrypt.gensalt(rounds=BCRYPT_ROUNDS)
    raw = password.encode('utf-8')
    return bcrypt.hashpw(raw, salt).decode('utf-8')
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored bcrypt hash.

    Any exception raised while encoding the inputs or checking the hash
    (for example a malformed stored hash) is treated as a mismatch, so this
    function returns ``False`` instead of raising.
    """
    try:
        candidate = plain_password.encode('utf-8')
        stored = hashed_password.encode('utf-8')
        matched = bcrypt.checkpw(candidate, stored)
    except Exception:
        return False
    return matched
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified.
        expires_delta: Lifetime of the token. When ``None`` (the default), the
            token expires after ``shared_settings.JWT_EXPIRATION_HOURS`` hours.
            Any explicit value, including ``timedelta(0)``, is honoured.

    Returns:
        The encoded token, signed with ``shared_settings.JWT_SECRET_KEY`` using
        ``shared_settings.JWT_ALGORITHM``.
    """
    # Compare against None rather than truthiness: timedelta(0) is falsy and
    # would otherwise be silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(hours=shared_settings.JWT_EXPIRATION_HOURS)

    # Timezone-aware UTC instead of the deprecated datetime.utcnow(). The
    # resulting "exp" instant is the same.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(
        to_encode,
        shared_settings.JWT_SECRET_KEY,
        algorithm=shared_settings.JWT_ALGORITHM,
    )
def decode_access_token(token: str) -> Optional[dict]:
    """Verify ``token`` and return its claims, or ``None`` if it is not acceptable.

    The signature is checked with ``shared_settings.JWT_SECRET_KEY``, and only
    ``shared_settings.JWT_ALGORITHM`` is accepted. An expired token, or any
    other ``jwt.InvalidTokenError``, yields ``None``.

    NOTE(review): ``generate_password_reset_token`` signs its tokens through
    ``create_access_token`` with the same key, so reset tokens (whose ``type``
    claim is ``"password_reset"``) also decode successfully here. Callers that
    need a specific token kind should inspect the claims.
    """
    accepted_algorithms = [shared_settings.JWT_ALGORITHM]
    try:
        claims = jwt.decode(
            token, shared_settings.JWT_SECRET_KEY, algorithms=accepted_algorithms
        )
    except jwt.InvalidTokenError:
        # Also covers jwt.ExpiredSignatureError, which subclasses InvalidTokenError.
        return None
    return claims
def generate_password_reset_token(email: str) -> str:
    """Return a one-hour JWT for resetting the password of ``email``.

    The token carries the claims ``{"email": email, "type": "password_reset"}``
    and is produced by ``create_access_token``. Nothing is persisted here, so
    within this module the token stays valid until it expires.
    """
    claims = {"email": email, "type": "password_reset"}
    lifetime = timedelta(hours=1)
    return create_access_token(claims, lifetime)
# Password strength validation
def validate_password_strength(password: str, *, min_length: int = 8) -> tuple[bool, str]:
    """Check that ``password`` satisfies the strength policy.

    The policy requires at least ``min_length`` characters and at least one
    uppercase letter, one lowercase letter, one digit and one special
    character. Any ASCII punctuation character (``string.punctuation``)
    counts as special.

    Args:
        password: The candidate password.
        min_length: Minimum accepted length (keyword-only, default 8).

    Returns:
        ``(True, "Password is strong")`` when every rule passes. Otherwise
        ``(False, message)`` for the first failing rule, where rules are checked
        in the order length, uppercase, lowercase, digit, special character.
    """
    if len(password) < min_length:
        return False, f"Password must be at least {min_length} characters long"
    if not any(c.isupper() for c in password):
        return False, "Password must contain at least one uppercase letter"
    if not any(c.islower() for c in password):
        return False, "Password must contain at least one lowercase letter"
    if not any(c.isdigit() for c in password):
        return False, "Password must contain at least one digit"
    # The earlier hand-written symbol list omitted ordinary symbols such as
    # / ~ ' " \ and `, so passwords using only those were wrongly rejected.
    # string.punctuation is a strict superset of that list.
    if not any(c in string.punctuation for c in password):
        return False, "Password must contain at least one special character"
    return True, "Password is strong"
# --------------- Refresh Token helpers ---------------
def _hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of ``token``.

    Only this digest is used in storage keys, so raw refresh tokens are never stored.
    """
    digest = hashlib.sha256()
    digest.update(token.encode("utf-8"))
    return digest.hexdigest()
def _clean_expired_tokens():
    """Remove every entry of the in-memory token store whose expiry is earlier than now."""
    cutoff = datetime.utcnow()
    # Collect the keys first; the dict cannot be resized while iterating over it.
    stale_keys = [
        key for key, (_, expires_at) in _memory_store.items() if expires_at < cutoff
    ]
    for key in stale_keys:
        _memory_store.pop(key)
async def create_refresh_token(user_id: str) -> str:
    """Issue a new random refresh token for ``user_id``.

    Only the SHA-256 hash of the token is stored, under the key
    ``refresh_token:<hash>``, for ``REFRESH_TOKEN_TTL`` seconds. Storage is
    Redis (SETEX) when ``get_redis()`` returns a client; otherwise it is the
    process-local ``_memory_store``, which is purged of expired entries first.

    Returns:
        The raw token string (``secrets.token_urlsafe(48)``). It cannot be
        recovered from storage later.
    """
    raw_token = secrets.token_urlsafe(48)
    storage_key = f"refresh_token:{_hash_token(raw_token)}"

    redis_client = await get_redis()
    if not redis_client:
        # In-memory fallback: store (user_id, naive-UTC expiry).
        _clean_expired_tokens()
        _memory_store[storage_key] = (
            user_id,
            datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_TTL),
        )
    else:
        await redis_client.setex(storage_key, REFRESH_TOKEN_TTL, user_id)

    return raw_token
async def verify_refresh_token(token: str) -> Optional[str]:
    """Return the user_id a refresh token was issued to, or ``None``.

    The token is hashed and ``refresh_token:<hash>`` is read from Redis when
    ``get_redis()`` returns a client. Otherwise it is read from
    ``_memory_store`` after expired entries are purged, and the entry's expiry
    is checked again against the current time.
    """
    storage_key = f"refresh_token:{_hash_token(token)}"

    redis_client = await get_redis()
    if redis_client:
        return await redis_client.get(storage_key)

    _clean_expired_tokens()
    entry = _memory_store.get(storage_key)
    if entry is None:
        return None
    owner, expires_at = entry
    return owner if expires_at > datetime.utcnow() else None
async def revoke_refresh_token(token: str) -> None:
    """Invalidate a refresh token (e.g. when rotating it).

    Deletes ``refresh_token:<hash>`` from Redis when ``get_redis()`` returns a
    client, otherwise from ``_memory_store``. A token that is not present is
    silently ignored.
    """
    storage_key = f"refresh_token:{_hash_token(token)}"

    redis_client = await get_redis()
    if not redis_client:
        _memory_store.pop(storage_key, None)
        return
    await redis_client.delete(storage_key)