192 lines
7.0 KiB
Python
192 lines
7.0 KiB
Python
"""
|
|
Authentication service business logic
|
|
Handles user registration, login, and user management operations
|
|
"""
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, or_
|
|
from sqlalchemy.exc import IntegrityError
|
|
from fastapi import HTTPException, status
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from shared.models import User
|
|
from ..schemas.user import UserCreate, UserLogin
|
|
from ..utils.security import hash_password, validate_password_strength, verify_password
|
|
|
|
# Module-wide logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
class AuthService:
    """Service class for authentication and user management.

    Wraps an ``AsyncSession`` and implements user lookups, registration,
    credential checks and account-status changes on top of it.

    Lookup helpers are best-effort: a database error is logged together with
    its traceback and reported the same way as "not found" (``None``,
    ``(None, None)``, ``[]`` or ``False``) instead of being raised.
    """

    # Hash of a throwaway password. It is computed on first use by
    # _get_dummy_password_hash() and shared by every instance.
    _dummy_password_hash = None

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------
    # Lookups
    # ------------------------------------------------------------------

    async def _fetch_one(self, stmt, description: str) -> Optional[User]:
        """Execute *stmt* and return its single ``User`` row, or ``None``.

        ``None`` is returned both when nothing matches and when the query
        fails (including when more than one row matches). Failures are logged
        as "Error fetching user by <description>".
        """
        try:
            result = await self.db.execute(stmt)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.exception("Error fetching user by %s: %s", description, e)
            return None

    async def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the user whose email exactly equals *email*, or ``None``."""
        stmt = select(User).where(User.email == email)
        return await self._fetch_one(stmt, f"email {email}")

    async def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user whose username exactly equals *username*, or ``None``."""
        stmt = select(User).where(User.username == username)
        return await self._fetch_one(stmt, f"username {username}")

    async def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user whose ``id`` equals *user_id*, or ``None``."""
        # For MySQL, user_id is already a string, no need to convert to UUID
        stmt = select(User).where(User.id == user_id)
        return await self._fetch_one(stmt, f"ID {user_id}")

    async def check_user_exists(self, email: str, username: str) -> tuple[Optional[User], Optional[User]]:
        """Look up users holding *email* or *username* with a single query.

        Returns:
            ``(email_user, username_user)``. Each element is the matching user
            or ``None``, and both may be the same object. ``(None, None)`` is
            also returned if the query fails.
        """
        try:
            stmt = select(User).where(or_(User.email == email, User.username == username))
            result = await self.db.execute(stmt)
            users = result.scalars().all()
        except Exception as e:
            logger.exception("Error checking user existence: %s", e)
            return None, None

        email_user = None
        username_user = None
        for user in users:
            if user.email == email:
                email_user = user
            if user.username == username:
                username_user = user
        return email_user, username_user

    async def get_all_users(self, skip: int = 0, limit: int = 100) -> list[User]:
        """Return one page of users (admin function).

        Args:
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).

        Returns:
            The users on the page, or ``[]`` if the query fails.
        """
        try:
            # Without ORDER BY the row order is unspecified, so OFFSET paging
            # could skip or repeat users between pages.
            stmt = select(User).order_by(User.id).offset(skip).limit(limit)
            result = await self.db.execute(stmt)
            return list(result.scalars().all())
        except Exception as e:
            logger.exception("Error fetching users list: %s", e)
            return []

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    async def create_user(self, user_data: UserCreate) -> User:
        """Create and persist a new, active user account.

        Args:
            user_data: Registration payload (username, email and plain-text
                password).

        Returns:
            The committed ``User``, refreshed from the database.

        Raises:
            HTTPException: 400 if the password fails the strength check; 409
                if the username or email is already taken; 500 for any other
                database failure.
        """
        is_strong, message = validate_password_strength(user_data.password)
        if not is_strong:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Password too weak: {message}",
            )

        # Only the hash is handed to the model; the plain password is not stored.
        user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=hash_password(user_data.password),
            is_active=True,
        )

        try:
            self.db.add(user)
            await self.db.commit()
            await self.db.refresh(user)
        except IntegrityError as e:
            await self.db.rollback()
            logger.error("Integrity error creating user %s: %s", user_data.username, e)
            conflict = await self._integrity_error_to_http(e, user_data)
            raise conflict from e
        except Exception as e:
            await self.db.rollback()
            logger.exception("Unexpected error creating user %s: %s", user_data.username, e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Internal server error during user creation",
            ) from e

        logger.info("Successfully created user: %s (%s)", user.username, user.email)
        return user

    async def _integrity_error_to_http(self, error: IntegrityError, user_data: UserCreate) -> HTTPException:
        """Translate a failed user INSERT into the HTTP error to report.

        The method first looks for the PostgreSQL-style default unique-constraint
        names in the driver message. Other backends word duplicate-key errors
        differently (MySQL reports "Duplicate entry ... for key ..."). When
        neither name is present, the method asks the database which of the two
        values is already taken. A username conflict is reported ahead of an
        email conflict.

        The caller must have rolled the session back before calling this,
        because the fallback runs a new query.
        """
        error_text = str(error.orig)
        username_taken = "users_username_key" in error_text
        email_taken = "users_email_key" in error_text

        if not (username_taken or email_taken):
            email_user, username_user = await self.check_user_exists(
                user_data.email, user_data.username
            )
            username_taken = username_user is not None
            email_taken = email_user is not None

        if username_taken:
            return HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Username already exists",
            )
        if email_taken:
            return HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Email already registered",
            )
        return HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user due to database constraint",
        )

    # ------------------------------------------------------------------
    # Authentication and account status
    # ------------------------------------------------------------------

    @classmethod
    def _get_dummy_password_hash(cls):
        """Return a cached hash of a throwaway password, hashing it on first use."""
        if cls._dummy_password_hash is None:
            cls._dummy_password_hash = hash_password("timing-equalisation-dummy-password")
        return cls._dummy_password_hash

    async def authenticate_user(self, login_data: UserLogin) -> Optional[User]:
        """Check login credentials.

        Returns:
            The user if the email exists, the password matches and the account
            is active. Otherwise returns ``None``, without saying which check
            failed.
        """
        user = await self.get_user_by_email(login_data.email)

        if user is None:
            # Still run one password verification, so an unknown email costs
            # about as much as a wrong password. Otherwise response timing
            # reveals which emails are registered. This assumes
            # verify_password's cost is dominated by the hash check.
            verify_password(login_data.password, self._get_dummy_password_hash())
            return None

        if not verify_password(login_data.password, user.hashed_password):
            return None

        # The active flag is checked only after the password matches, so a
        # wrong-password attempt learns nothing about the account's status.
        if not user.is_active:
            logger.warning("Login attempt for deactivated user: %s", user.email)
            return None

        logger.info("Successful authentication for user: %s", user.username)
        return user

    async def update_user_status(self, user_id: str, is_active: bool) -> Optional[User]:
        """Set the active flag of the user with id *user_id* and commit.

        Returns:
            The refreshed user, or ``None`` if the user was not found or the
            commit failed. On failure the session is rolled back.
        """
        user = await self.get_user_by_id(user_id)
        if user is None:
            return None

        user.is_active = is_active
        try:
            await self.db.commit()
            await self.db.refresh(user)
        except Exception as e:
            await self.db.rollback()
            logger.exception("Error updating user status: %s", e)
            return None

        logger.info("Updated user %s status to: %s", user.username, is_active)
        return user

    async def check_database_health(self) -> bool:
        """Return ``True`` if a trivial query against the users table succeeds."""
        try:
            await self.db.execute(select(User).limit(1))
            return True
        except Exception as e:
            logger.error("Database health check failed: %s", e)
            return False