123
This commit is contained in:
148
backend/auth_service/app/utils/security.py
Normal file
148
backend/auth_service/app/utils/security.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Security utilities for password hashing and JWT token management
|
||||
"""
|
||||
|
||||
import bcrypt
|
||||
import hashlib
|
||||
import jwt
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from shared.config import shared_settings
|
||||
|
||||
# Auth-specific defaults
|
||||
BCRYPT_ROUNDS = 12
|
||||
REFRESH_TOKEN_TTL = 7 * 24 * 3600 # 7 days in seconds
|
||||
|
||||
# Lazy-initialised async Redis client
|
||||
_redis_client: Optional[aioredis.Redis] = None
|
||||
|
||||
|
||||
async def get_redis() -> aioredis.Redis:
|
||||
"""Return a shared async Redis connection."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = aioredis.from_url(
|
||||
shared_settings.REDIS_URL, decode_responses=True
|
||||
)
|
||||
return _redis_client
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""
|
||||
Hash a password using bcrypt
|
||||
Returns the hashed password as a string
|
||||
"""
|
||||
salt = bcrypt.gensalt(rounds=BCRYPT_ROUNDS)
|
||||
hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
|
||||
return hashed.decode('utf-8')
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Verify a plain text password against a hashed password
|
||||
Returns True if passwords match, False otherwise
|
||||
"""
|
||||
try:
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode('utf-8'),
|
||||
hashed_password.encode('utf-8')
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
Create a JWT access token
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(hours=shared_settings.JWT_EXPIRATION_HOURS)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, shared_settings.JWT_SECRET_KEY, algorithm=shared_settings.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def decode_access_token(token: str) -> Optional[dict]:
|
||||
"""
|
||||
Decode and validate a JWT access token
|
||||
Returns the payload if valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, shared_settings.JWT_SECRET_KEY, algorithms=[shared_settings.JWT_ALGORITHM])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
def generate_password_reset_token(email: str) -> str:
|
||||
"""
|
||||
Generate a secure token for password reset
|
||||
"""
|
||||
data = {"email": email, "type": "password_reset"}
|
||||
return create_access_token(data, timedelta(hours=1))
|
||||
|
||||
# Password strength validation
|
||||
def validate_password_strength(password: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate password meets strength requirements
|
||||
Returns (is_valid, error_message)
|
||||
"""
|
||||
if len(password) < 8:
|
||||
return False, "Password must be at least 8 characters long"
|
||||
|
||||
if not any(c.isupper() for c in password):
|
||||
return False, "Password must contain at least one uppercase letter"
|
||||
|
||||
if not any(c.islower() for c in password):
|
||||
return False, "Password must contain at least one lowercase letter"
|
||||
|
||||
if not any(c.isdigit() for c in password):
|
||||
return False, "Password must contain at least one digit"
|
||||
|
||||
if not any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in password):
|
||||
return False, "Password must contain at least one special character"
|
||||
|
||||
return True, "Password is strong"
|
||||
|
||||
|
||||
# --------------- Refresh Token helpers ---------------
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
"""SHA-256 hash of a refresh token for safe Redis key storage."""
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def create_refresh_token(user_id: str) -> str:
|
||||
"""
|
||||
Generate a cryptographically random refresh token, store its hash in Redis
|
||||
with a 7-day TTL, and return the raw token string.
|
||||
"""
|
||||
token = secrets.token_urlsafe(48)
|
||||
token_hash = _hash_token(token)
|
||||
r = await get_redis()
|
||||
await r.setex(f"refresh_token:{token_hash}", REFRESH_TOKEN_TTL, user_id)
|
||||
return token
|
||||
|
||||
|
||||
async def verify_refresh_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
Verify a refresh token by looking up its hash in Redis.
|
||||
Returns the associated user_id if valid, None otherwise.
|
||||
"""
|
||||
token_hash = _hash_token(token)
|
||||
r = await get_redis()
|
||||
user_id = await r.get(f"refresh_token:{token_hash}")
|
||||
return user_id
|
||||
|
||||
|
||||
async def revoke_refresh_token(token: str) -> None:
|
||||
"""Delete a refresh token from Redis (used during rotation)."""
|
||||
token_hash = _hash_token(token)
|
||||
r = await get_redis()
|
||||
await r.delete(f"refresh_token:{token_hash}")
|
||||
Reference in New Issue
Block a user