扫码登录,获取cookies

This commit is contained in:
2026-03-09 16:10:29 +08:00
parent 754e720ba7
commit 8229208165
7775 changed files with 1150053 additions and 208 deletions

View File

@@ -15,7 +15,7 @@ import logging
from shared.models import get_db, User
from auth_service.app.models.database import create_tables
from auth_service.app.schemas.user import (
UserCreate, UserLogin, UserResponse, Token, TokenData, RefreshTokenRequest,
UserCreate, UserLogin, UserResponse, Token, TokenData, RefreshTokenRequest, AuthResponse,
)
from auth_service.app.services.auth_service import AuthService
from auth_service.app.utils.security import (
@@ -92,7 +92,9 @@ async def get_current_user(
@app.on_event("startup")
async def startup_event():
"""Initialize database tables on startup"""
await create_tables()
# 表已通过 create_sqlite_db.py 创建,无需重复创建
# await create_tables()
pass
@app.get("/")
async def root():
@@ -106,10 +108,10 @@ async def root():
async def health_check():
return {"status": "healthy"}
@app.post("/auth/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@app.post("/auth/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED)
async def register_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
"""
Register a new user account
Register a new user account and return tokens
"""
auth_service = AuthService(db)
@@ -131,17 +133,28 @@ async def register_user(user_data: UserCreate, db: AsyncSession = Depends(get_db
# Create new user
try:
user = await auth_service.create_user(user_data)
return UserResponse.from_orm(user)
# Create tokens for auto-login
access_token = create_access_token(data={"sub": str(user.id), "username": user.username})
refresh_token = await create_refresh_token(str(user.id))
return AuthResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=3600,
user=UserResponse.from_orm(user)
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create user: {str(e)}"
)
@app.post("/auth/login", response_model=Token)
@app.post("/auth/login", response_model=AuthResponse)
async def login_user(login_data: UserLogin, db: AsyncSession = Depends(get_db)):
"""
Authenticate user and return JWT token
Authenticate user and return JWT token with user info
"""
auth_service = AuthService(db)
@@ -173,11 +186,12 @@ async def login_user(login_data: UserLogin, db: AsyncSession = Depends(get_db)):
# Create refresh token (stored in Redis)
refresh_token = await create_refresh_token(str(user.id))
return Token(
return AuthResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=3600 # 1 hour
expires_in=3600,
user=UserResponse.from_orm(user)
)
@app.post("/auth/refresh", response_model=Token)

View File

@@ -10,6 +10,10 @@ __all__ = ["Base", "get_db", "engine", "AsyncSessionLocal", "User"]
async def create_tables():
"""Create all tables in the database."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
"""Create all tables in the database if they don't exist."""
try:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
except Exception as e:
# 表已存在或其他错误,忽略
print(f"Warning: Could not create tables: {e}")

View File

@@ -45,6 +45,15 @@ class Token(BaseModel):
expires_in: int = Field(..., description="Access token expiration time in seconds")
class AuthResponse(BaseModel):
"""Schema for authentication response with user info"""
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int
user: UserResponse
class RefreshTokenRequest(BaseModel):
"""Schema for token refresh request"""
refresh_token: str = Field(..., description="The refresh token to exchange")

View File

@@ -7,9 +7,7 @@ import hashlib
import jwt
import secrets
from datetime import datetime, timedelta
from typing import Optional
import redis.asyncio as aioredis
from typing import Optional, Dict
from shared.config import shared_settings
@@ -17,17 +15,28 @@ from shared.config import shared_settings
BCRYPT_ROUNDS = 12
REFRESH_TOKEN_TTL = 7 * 24 * 3600 # 7 days in seconds
# Lazy-initialised async Redis client
_redis_client: Optional[aioredis.Redis] = None
# Lazy-initialised async Redis client (可选)
_redis_client: Optional[object] = None
# 内存存储(本地开发用,不使用 Redis 时)
_memory_store: Dict[str, tuple[str, datetime]] = {}
async def get_redis() -> aioredis.Redis:
"""Return a shared async Redis connection."""
async def get_redis():
"""Return a shared async Redis connection if enabled."""
if not shared_settings.USE_REDIS:
return None
global _redis_client
if _redis_client is None:
_redis_client = aioredis.from_url(
shared_settings.REDIS_URL, decode_responses=True
)
try:
import redis.asyncio as aioredis
_redis_client = aioredis.from_url(
shared_settings.REDIS_URL, decode_responses=True
)
except Exception as e:
print(f"[警告] Redis 连接失败: {e},将使用内存存储")
return None
return _redis_client
def hash_password(password: str) -> str:
@@ -118,31 +127,68 @@ def _hash_token(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def _clean_expired_tokens():
"""清理过期的内存 token"""
now = datetime.utcnow()
expired_keys = [k for k, (_, exp) in _memory_store.items() if exp < now]
for k in expired_keys:
del _memory_store[k]
async def create_refresh_token(user_id: str) -> str:
"""
Generate a cryptographically random refresh token, store its hash in Redis
with a 7-day TTL, and return the raw token string.
(or memory if Redis disabled) with a 7-day TTL, and return the raw token string.
"""
token = secrets.token_urlsafe(48)
token_hash = _hash_token(token)
r = await get_redis()
await r.setex(f"refresh_token:{token_hash}", REFRESH_TOKEN_TTL, user_id)
if r:
# 使用 Redis
await r.setex(f"refresh_token:{token_hash}", REFRESH_TOKEN_TTL, user_id)
else:
# 使用内存存储
_clean_expired_tokens()
expire_at = datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_TTL)
_memory_store[f"refresh_token:{token_hash}"] = (user_id, expire_at)
return token
async def verify_refresh_token(token: str) -> Optional[str]:
"""
Verify a refresh token by looking up its hash in Redis.
Verify a refresh token by looking up its hash in Redis or memory.
Returns the associated user_id if valid, None otherwise.
"""
token_hash = _hash_token(token)
key = f"refresh_token:{token_hash}"
r = await get_redis()
user_id = await r.get(f"refresh_token:{token_hash}")
return user_id
if r:
# 使用 Redis
user_id = await r.get(key)
return user_id
else:
# 使用内存存储
_clean_expired_tokens()
if key in _memory_store:
user_id, expire_at = _memory_store[key]
if expire_at > datetime.utcnow():
return user_id
return None
async def revoke_refresh_token(token: str) -> None:
"""Delete a refresh token from Redis (used during rotation)."""
"""Delete a refresh token from Redis or memory (used during rotation)."""
token_hash = _hash_token(token)
key = f"refresh_token:{token_hash}"
r = await get_redis()
await r.delete(f"refresh_token:{token_hash}")
if r:
# 使用 Redis
await r.delete(key)
else:
# 使用内存存储
if key in _memory_store:
del _memory_store[key]