336 lines
10 KiB
Python
336 lines
10 KiB
Python
"""
|
||
Weibo-HotSign Authentication Service
|
||
Main FastAPI application entry point
|
||
"""
|
||
|
||
from fastapi import FastAPI, Depends, HTTPException, status, Security
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
import uvicorn
|
||
import os
|
||
import logging
|
||
|
||
from shared.models import get_db, User
|
||
from shared.config import shared_settings
|
||
from auth_service.app.models.database import create_tables
|
||
from auth_service.app.schemas.user import (
|
||
UserCreate, UserLogin, UserResponse, Token, TokenData,
|
||
RefreshTokenRequest, AuthResponse, WxLoginRequest,
|
||
)
|
||
from auth_service.app.services.auth_service import AuthService
|
||
from auth_service.app.utils.security import (
|
||
verify_password, create_access_token, decode_access_token,
|
||
create_refresh_token, verify_refresh_token, revoke_refresh_token,
|
||
)
|
||
|
||
# Module-level logger for this service.
logger = logging.getLogger(__name__)

# Metadata published in the generated OpenAPI schema and documentation pages.
_APP_METADATA = dict(
    title="Weibo-HotSign Authentication Service",
    description="Handles user authentication and authorization for Weibo-HotSign system",
    version="1.0.0",
)

# FastAPI application instance; Swagger UI is served at /docs and ReDoc at /redoc.
app = FastAPI(**_APP_METADATA, docs_url="/docs", redoc_url="/redoc")
|
||
|
||
# CORS middleware configuration.
# Browsers serialize the Origin header without the scheme's default port, so a
# frontend served from http://localhost:80 sends "Origin: http://localhost".
# Starlette matches allow_origins entries as exact strings, so the bare form must
# be listed for such a frontend to pass CORS; the explicit ":80" entry is kept
# for any non-browser client that sends it literally.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost",
        "http://localhost:80",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security scheme for JWT: extracts the "Authorization: Bearer <token>"
# credentials consumed by get_current_user.
security = HTTPBearer()
|
||
|
||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Security(security),
    db: AsyncSession = Depends(get_db),
) -> UserResponse:
    """
    Resolve the authenticated user from the request's Bearer token.

    Used as a FastAPI dependency by protected endpoints.

    Raises:
        HTTPException 401: the token cannot be decoded (invalid/expired) or its
            payload has no ``sub`` claim.
        HTTPException 404: no user exists with the id from ``sub``.
        HTTPException 403: the user exists but is deactivated.
    """
    # Sent on every 401 so clients know Bearer auth is expected.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    claims = decode_access_token(credentials.credentials)
    if claims is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers=bearer_challenge,
        )

    subject = claims.get("sub")
    if not subject:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers=bearer_challenge,
        )

    account = await AuthService(db).get_user_by_id(subject)
    if account is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )
    if not account.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is deactivated",
        )

    return UserResponse.from_orm(account)
|
||
|
||
@app.on_event("startup")
async def startup_event():
    """
    Application startup hook; intentionally performs no database work.

    Tables are created out of band by ``create_sqlite_db.py``, so
    ``create_tables()`` is deliberately not called here. Call it from this hook
    if the service should create its own schema on boot.
    """
|
||
|
||
@app.get("/")
async def root():
    """
    Service banner: name, run state and version.

    Name and version are read from the FastAPI application's own metadata, so
    they cannot drift from what the OpenAPI docs report.
    """
    return {
        "service": app.title,
        "status": "running",
        "version": app.version,
    }
|
||
|
||
@app.get("/health")
async def health_check():
    """
    Report that the process is up and serving requests.

    This returns a constant payload; it does not probe the database or any
    other dependency.
    """
    payload = {"status": "healthy"}
    return payload
|
||
|
||
@app.post("/auth/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED)
async def register_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
    """
    Register a new user account and return tokens for immediate login.

    Raises:
        HTTPException 409: the email or username is already in use, including
            when a concurrent registration claims it between the existence
            check and the insert.
        HTTPException 500: user creation failed for another reason. The
            underlying error is logged server-side and NOT returned to the
            client, because it can contain database internals.
    """
    from sqlalchemy.exc import IntegrityError

    auth_service = AuthService(db)

    # Check email and username availability with a single service call.
    email_user, username_user = await auth_service.check_user_exists(user_data.email, user_data.username)

    if email_user:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="User with this email already exists",
        )

    if username_user:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already taken",
        )

    # Only the insert is guarded: token issuance below is a separate concern
    # and must not be reported as "Failed to create user".
    try:
        user = await auth_service.create_user(user_data)
    except HTTPException:
        # Already a deliberate client-facing error; don't re-wrap it as a 500.
        raise
    except IntegrityError as exc:
        # The existence check above is not atomic with the insert, so a
        # concurrent request can win the race. Assumes the users table enforces
        # uniqueness on email/username — TODO confirm against the model.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="User with this email or username already exists",
        ) from exc
    except Exception as exc:
        logger.exception("Failed to create user %r", user_data.username)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create user",
        ) from exc

    # Issue tokens so the client is logged in right after registering.
    access_token = create_access_token(data={"sub": str(user.id), "username": user.username})
    new_refresh_token = await create_refresh_token(str(user.id))

    return AuthResponse(
        access_token=access_token,
        refresh_token=new_refresh_token,
        token_type="bearer",
        expires_in=3600,
        user=UserResponse.from_orm(user),
    )
|
||
|
||
@app.post("/auth/login", response_model=AuthResponse)
async def login_user(login_data: UserLogin, db: AsyncSession = Depends(get_db)):
    """
    Authenticate a user by email and password; return JWT tokens with user info.

    Raises:
        HTTPException 401: unknown email, wrong password, or an account with no
            password hash (accounts auto-registered through /auth/wx-login are
            created without ``hashed_password``). All cases share one message
            so the caller cannot tell which check failed.
        HTTPException 403: the credentials are valid but the account is deactivated.
    """
    auth_service = AuthService(db)

    user = await auth_service.get_user_by_email(login_data.email)

    # Check the hash exists before verifying, so password-less accounts fail as
    # bad credentials instead of handing an empty hash to verify_password.
    # NOTE(review): the unknown-email path skips password hashing entirely, so
    # response timing may reveal whether an email is registered — consider a
    # dummy verify if that matters.
    if (
        not user
        or not user.hashed_password
        or not verify_password(login_data.password, user.hashed_password)
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )

    # Checked only after the password, so account status is not disclosed to
    # callers who cannot authenticate as this user.
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is deactivated",
        )

    access_token = create_access_token(data={"sub": str(user.id), "username": user.username})

    # Create refresh token (stored in Redis)
    new_refresh_token = await create_refresh_token(str(user.id))

    return AuthResponse(
        access_token=access_token,
        refresh_token=new_refresh_token,
        token_type="bearer",
        expires_in=3600,
        user=UserResponse.from_orm(user),
    )
|
||
|
||
@app.post("/auth/refresh", response_model=Token)
async def refresh_token(body: RefreshTokenRequest, db: AsyncSession = Depends(get_db)):
    """
    Rotate tokens: trade a valid refresh token for a new access/refresh pair.

    The presented refresh token is revoked before the new pair is issued.

    Raises:
        HTTPException 401: the refresh token is invalid/expired, or its owner no
            longer exists or is deactivated.

    NOTE(review): verification and revocation are two separate awaits, so two
    concurrent requests carrying the same refresh token may both succeed unless
    the token store enforces single use — confirm.
    """
    presented = body.refresh_token

    owner_id = await verify_refresh_token(presented)
    if owner_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )

    # The owner must still exist and be active to receive new tokens.
    owner = await AuthService(db).get_user_by_id(owner_id)
    if not (owner is not None and owner.is_active):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or deactivated",
        )

    await revoke_refresh_token(presented)

    owner_key = str(owner.id)
    return Token(
        access_token=create_access_token(data={"sub": owner_key, "username": owner.username}),
        refresh_token=await create_refresh_token(owner_key),
        token_type="bearer",
        expires_in=3600,
    )
|
||
|
||
@app.get("/auth/me", response_model=UserResponse)
async def get_current_user_info(current_user: UserResponse = Depends(get_current_user)):
    """
    Return the profile of the user identified by the request's Bearer token.

    Token validation and the user lookup happen in the ``get_current_user``
    dependency; this handler only echoes its result.
    """
    me = current_user
    return me
|
||
|
||
|
||
@app.post("/auth/wx-login", response_model=AuthResponse)
async def wx_login(body: WxLoginRequest, db: AsyncSession = Depends(get_db)):
    """
    WeChat Mini Program login.

    Flow:
    1. Exchange ``body.code`` for an openid via WeChat's code2Session API.
    2. Look up an existing user bound to that openid.
    3. Auto-register a new user if none exists; otherwise log the user in.
    4. Return a JWT access/refresh token pair.

    Raises:
        HTTPException 500: WX_APPID / WX_SECRET are not configured.
        HTTPException 502: WeChat could not be reached or returned an unusable body.
        HTTPException 401: WeChat rejected the code (no openid in its response).
        HTTPException 403: the account bound to the openid is deactivated.
    """
    import uuid

    appid = shared_settings.WX_APPID
    secret = shared_settings.WX_SECRET

    if not appid or not secret:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="微信小程序未配置 APPID 和 SECRET",
        )

    # Step 1: exchange the code for an openid.
    wx_data = await _wx_code2session(appid, secret, body.code)

    openid = wx_data.get("openid")
    if not openid:
        errcode = wx_data.get("errcode", "unknown")
        errmsg = wx_data.get("errmsg", "未知错误")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"微信登录失败: {errmsg} (errcode={errcode})",
        )

    # Step 2: look up an existing user bound to this openid.
    result = await db.execute(select(User).where(User.wx_openid == openid))
    user = result.scalar_one_or_none()

    if user is not None:
        # Reject disabled accounts before modifying or committing their profile.
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="账号已被禁用",
            )
        # Existing user: refresh nickname/avatar only when new values were sent.
        changed = False
        if body.nickname and body.nickname != user.wx_nickname:
            user.wx_nickname = body.nickname
            changed = True
        if body.avatar_url and body.avatar_url != user.wx_avatar:
            user.wx_avatar = body.avatar_url
            changed = True
        if changed:
            await db.commit()
            await db.refresh(user)
    else:
        # Step 3: auto-register a new, active user for this openid.
        username = await _unique_wx_username(db, openid)
        user = User(
            id=str(uuid.uuid4()),
            username=username,
            wx_openid=openid,
            wx_nickname=body.nickname or f"wx_{openid[:8]}",
            wx_avatar=body.avatar_url,
            is_active=True,
        )
        db.add(user)
        await db.commit()
        await db.refresh(user)
        # Only a prefix of the openid is logged.
        logger.info("微信用户自动注册: openid=%s..., username=%s", openid[:16], username)

    # Step 4: issue tokens.
    access_token = create_access_token(
        data={"sub": str(user.id), "username": user.username}
    )
    new_refresh_token = await create_refresh_token(str(user.id))

    return AuthResponse(
        access_token=access_token,
        refresh_token=new_refresh_token,
        token_type="bearer",
        expires_in=3600,
        user=UserResponse.from_orm(user),
    )


async def _wx_code2session(appid: str, secret: str, code: str) -> dict:
    """
    Call WeChat's jscode2session endpoint and return its decoded JSON object.

    Raises:
        HTTPException 502: the request failed (network error, timeout) or the
            body is not a JSON object.
    """
    import httpx

    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(
                "https://api.weixin.qq.com/sns/jscode2session",
                params={
                    "appid": appid,
                    "secret": secret,
                    "js_code": code,
                    "grant_type": "authorization_code",
                },
            )
        data = resp.json()  # ValueError on a non-JSON body
    except (httpx.HTTPError, ValueError) as exc:
        # Log only the exception type: httpx errors can carry the request URL,
        # whose query string contains the app secret.
        logger.warning("WeChat code2Session call failed: %s", type(exc).__name__)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="微信服务暂时不可用",
        ) from exc

    if not isinstance(data, dict):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="微信服务暂时不可用",
        )
    return data


async def _unique_wx_username(db: AsyncSession, openid: str) -> str:
    """
    Return an unused username for a new WeChat user.

    Prefers ``wx_<first 12 chars of openid>``. That prefix is not guaranteed to
    be unique, so on collision a random hex suffix is appended and re-checked.
    """
    import uuid

    base = f"wx_{openid[:12]}"
    candidate = base
    for _ in range(5):
        taken = await db.execute(select(User.id).where(User.username == candidate))
        if taken.first() is None:
            return candidate
        candidate = f"{base}_{uuid.uuid4().hex[:6]}"
    # Practically unreachable: fall back to a full 128-bit random suffix.
    return f"{base}_{uuid.uuid4().hex}"
|