Files
weibo_signin/backend/auth_service/app/utils/security.py
2026-03-09 14:05:00 +08:00

149 lines
4.5 KiB
Python

"""
Security utilities for password hashing and JWT token management
"""
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import bcrypt
import jwt
import redis.asyncio as aioredis

from shared.config import shared_settings
# Auth-specific defaults
BCRYPT_ROUNDS = 12  # cost factor handed to bcrypt.gensalt() in hash_password()
REFRESH_TOKEN_TTL = int(timedelta(days=7).total_seconds())  # 604800 seconds == 7 days
# Process-wide async Redis client; stays None until get_redis() first builds it.
_redis_client: Optional[aioredis.Redis] = None


async def get_redis() -> aioredis.Redis:
    """
    Return the shared async Redis client, creating it on the first call.

    The client is built from ``shared_settings.REDIS_URL`` with
    ``decode_responses=True``, so replies come back decoded as ``str``.
    Later calls return the same cached instance.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    # No await happens between the None check and this assignment.
    _redis_client = aioredis.from_url(shared_settings.REDIS_URL, decode_responses=True)
    return _redis_client
def hash_password(password: str) -> str:
    """
    Hash ``password`` with bcrypt and return the result as a UTF-8 string.

    A fresh salt is generated on every call, using ``BCRYPT_ROUNDS`` as the
    cost factor, so hashing the same password twice yields different strings.
    Check candidates with verify_password(), not by string comparison.
    """
    # NOTE(review): bcrypt only looks at the first 72 bytes of its input;
    # depending on the installed bcrypt version, longer passwords are either
    # silently truncated or rejected — confirm the desired policy.
    salt = bcrypt.gensalt(rounds=BCRYPT_ROUNDS)
    password_bytes = password.encode('utf-8')
    digest_bytes = bcrypt.hashpw(password_bytes, salt)
    return digest_bytes.decode('utf-8')
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a stored bcrypt hash.

    Returns True only when bcrypt confirms a match. Any failure is reported
    as False rather than raised, for example:
    - a malformed or missing hash;
    - text that cannot be UTF-8 encoded.
    """
    try:
        candidate = plain_password.encode('utf-8')
        stored = hashed_password.encode('utf-8')
        return bcrypt.checkpw(candidate, stored)
    except Exception:
        # Deliberate best-effort: a bad/corrupt hash is treated as "no match".
        return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: Claims to embed. The dict is shallow-copied, so the caller's
            dict is not mutated. Any ``"exp"`` key in it is overwritten.
        expires_delta: Token lifetime. When None (the default), the lifetime
            is ``shared_settings.JWT_EXPIRATION_HOURS`` hours. An explicit
            zero or negative delta is honoured, so the token expires
            immediately.

    Returns:
        The encoded token, signed with ``shared_settings.JWT_SECRET_KEY``
        using ``shared_settings.JWT_ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and a truthiness
    # test would silently replace a zero lifetime with the long default one.
    if expires_delta is None:
        expires_delta = timedelta(hours=shared_settings.JWT_EXPIRATION_HOURS)
    to_encode = data.copy()
    # Timezone-aware UTC "now": datetime.utcnow() is deprecated (Python 3.12)
    # and returns a naive datetime.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(
        to_encode,
        shared_settings.JWT_SECRET_KEY,
        algorithm=shared_settings.JWT_ALGORITHM,
    )
def decode_access_token(token: str) -> Optional[dict]:
    """
    Verify and decode a JWT.

    The signature is checked against ``shared_settings.JWT_SECRET_KEY``, and
    only ``shared_settings.JWT_ALGORITHM`` is accepted.

    Returns the claims dict on success, or None when PyJWT rejects the token
    with any ``InvalidTokenError``, including an expired signature.
    """
    # NOTE(review): no "type" claim is checked here. Tokens from
    # generate_password_reset_token use the same signer and key, so they also
    # decode successfully. Confirm that callers needing an access token
    # reject them.
    try:
        return jwt.decode(
            token,
            shared_settings.JWT_SECRET_KEY,
            algorithms=[shared_settings.JWT_ALGORITHM],
        )
    except jwt.InvalidTokenError:
        # ExpiredSignatureError subclasses InvalidTokenError, so this covers it too.
        return None
def generate_password_reset_token(email: str) -> str:
    """
    Issue a JWT for the password-reset flow that expires after one hour.

    The token carries ``email`` and ``type="password_reset"`` claims and is
    signed by create_access_token().
    """
    reset_claims = {"email": email, "type": "password_reset"}
    one_hour = timedelta(hours=1)
    return create_access_token(reset_claims, one_hour)
# Password strength validation
# Characters that count as "special" by default (unchanged from the original policy).
PASSWORD_SPECIAL_CHARACTERS = "!@#$%^&*()_+-=[]{}|;:,.<>?"


def validate_password_strength(
    password: str,
    *,
    min_length: int = 8,
    special_chars: str = PASSWORD_SPECIAL_CHARACTERS,
) -> tuple[bool, str]:
    """
    Check that ``password`` satisfies the strength policy.

    Rules are checked in this order, and only the first failure is reported:
    1. At least ``min_length`` characters.
    2. At least one uppercase letter.
    3. At least one lowercase letter.
    4. At least one digit.
    5. At least one character from ``special_chars``.

    Args:
        password: Candidate password.
        min_length: Minimum number of characters (default 8).
        special_chars: Characters accepted as "special". The default set does
            NOT include common symbols such as "~", "/" or a space. Pass a
            broader set (e.g. ``string.punctuation``) to accept them.

    Returns:
        ``(True, "Password is strong")`` if every rule passes, otherwise
        ``(False, message)`` describing the first rule that failed.
    """
    if len(password) < min_length:
        return False, f"Password must be at least {min_length} characters long"
    if not any(c.isupper() for c in password):
        return False, "Password must contain at least one uppercase letter"
    if not any(c.islower() for c in password):
        return False, "Password must contain at least one lowercase letter"
    if not any(c.isdigit() for c in password):
        return False, "Password must contain at least one digit"
    if not any(c in special_chars for c in password):
        return False, "Password must contain at least one special character"
    return True, "Password is strong"
# --------------- Refresh Token helpers ---------------
def _hash_token(token: str) -> str:
    """
    Return the hex SHA-256 digest of ``token`` (UTF-8 encoded).

    The refresh-token helpers build their Redis keys from this digest, so the
    raw token is never written to Redis.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
async def create_refresh_token(user_id: str) -> str:
    """
    Mint a new opaque refresh token for ``user_id``.

    The token is ``secrets.token_urlsafe(48)``. Only its SHA-256 hash is
    stored: Redis key ``refresh_token:<hash>`` maps to ``user_id`` and
    expires after ``REFRESH_TOKEN_TTL`` seconds. The raw token is returned to
    the caller and is not persisted.
    """
    raw_token = secrets.token_urlsafe(48)
    redis_key = f"refresh_token:{_hash_token(raw_token)}"
    client = await get_redis()
    await client.setex(redis_key, REFRESH_TOKEN_TTL, user_id)
    return raw_token
async def verify_refresh_token(token: str) -> Optional[str]:
    """
    Look up the user that owns a refresh token.

    Returns the ``user_id`` stored by create_refresh_token(), or None when no
    Redis key exists for the token's hash. The token is only read here, not
    consumed — rotation must call revoke_refresh_token() separately.
    """
    redis_key = "refresh_token:" + _hash_token(token)
    client = await get_redis()
    return await client.get(redis_key)
async def revoke_refresh_token(token: str) -> bool:
    """
    Delete a refresh token's Redis entry (used during rotation).

    Returns:
        True if an entry was actually removed. False if nothing was stored
        under the token's hash, e.g. it already expired or another request
        revoked it first. Existing callers that ignore the return value
        behave exactly as before.
    """
    redis_key = f"refresh_token:{_hash_token(token)}"
    client = await get_redis()
    # DEL returns the number of keys removed. verify_refresh_token() followed
    # by this call is not atomic, so two concurrent refreshes can both pass
    # verification. Only the caller that sees True here should issue a new
    # token pair.
    removed = await client.delete(redis_key)
    return bool(removed)