扫码登录,获取cookies
This commit is contained in:
Binary file not shown.
@@ -15,7 +15,7 @@ import logging
|
||||
from shared.models import get_db, User
|
||||
from auth_service.app.models.database import create_tables
|
||||
from auth_service.app.schemas.user import (
|
||||
UserCreate, UserLogin, UserResponse, Token, TokenData, RefreshTokenRequest,
|
||||
UserCreate, UserLogin, UserResponse, Token, TokenData, RefreshTokenRequest, AuthResponse,
|
||||
)
|
||||
from auth_service.app.services.auth_service import AuthService
|
||||
from auth_service.app.utils.security import (
|
||||
@@ -92,7 +92,9 @@ async def get_current_user(
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize database tables on startup"""
|
||||
await create_tables()
|
||||
# 表已通过 create_sqlite_db.py 创建,无需重复创建
|
||||
# await create_tables()
|
||||
pass
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
@@ -106,10 +108,10 @@ async def root():
|
||||
async def health_check():
|
||||
return {"status": "healthy"}
|
||||
|
||||
@app.post("/auth/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@app.post("/auth/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
Register a new user account
|
||||
Register a new user account and return tokens
|
||||
"""
|
||||
auth_service = AuthService(db)
|
||||
|
||||
@@ -131,17 +133,28 @@ async def register_user(user_data: UserCreate, db: AsyncSession = Depends(get_db
|
||||
# Create new user
|
||||
try:
|
||||
user = await auth_service.create_user(user_data)
|
||||
return UserResponse.from_orm(user)
|
||||
|
||||
# Create tokens for auto-login
|
||||
access_token = create_access_token(data={"sub": str(user.id), "username": user.username})
|
||||
refresh_token = await create_refresh_token(str(user.id))
|
||||
|
||||
return AuthResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=3600,
|
||||
user=UserResponse.from_orm(user)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create user: {str(e)}"
|
||||
)
|
||||
|
||||
@app.post("/auth/login", response_model=Token)
|
||||
@app.post("/auth/login", response_model=AuthResponse)
|
||||
async def login_user(login_data: UserLogin, db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
Authenticate user and return JWT token
|
||||
Authenticate user and return JWT token with user info
|
||||
"""
|
||||
auth_service = AuthService(db)
|
||||
|
||||
@@ -173,11 +186,12 @@ async def login_user(login_data: UserLogin, db: AsyncSession = Depends(get_db)):
|
||||
# Create refresh token (stored in Redis)
|
||||
refresh_token = await create_refresh_token(str(user.id))
|
||||
|
||||
return Token(
|
||||
return AuthResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=3600 # 1 hour
|
||||
expires_in=3600,
|
||||
user=UserResponse.from_orm(user)
|
||||
)
|
||||
|
||||
@app.post("/auth/refresh", response_model=Token)
|
||||
|
||||
Binary file not shown.
@@ -10,6 +10,10 @@ __all__ = ["Base", "get_db", "engine", "AsyncSessionLocal", "User"]
|
||||
|
||||
|
||||
async def create_tables():
|
||||
"""Create all tables in the database."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
"""Create all tables in the database if they don't exist."""
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
except Exception as e:
|
||||
# 表已存在或其他错误,忽略
|
||||
print(f"Warning: Could not create tables: {e}")
|
||||
|
||||
Binary file not shown.
@@ -45,6 +45,15 @@ class Token(BaseModel):
|
||||
expires_in: int = Field(..., description="Access token expiration time in seconds")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""Schema for authentication response with user info"""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
user: UserResponse
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Schema for token refresh request"""
|
||||
refresh_token: str = Field(..., description="The refresh token to exchange")
|
||||
|
||||
Binary file not shown.
@@ -7,9 +7,7 @@ import hashlib
|
||||
import jwt
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from typing import Optional, Dict
|
||||
|
||||
from shared.config import shared_settings
|
||||
|
||||
@@ -17,17 +15,28 @@ from shared.config import shared_settings
|
||||
BCRYPT_ROUNDS = 12
|
||||
REFRESH_TOKEN_TTL = 7 * 24 * 3600 # 7 days in seconds
|
||||
|
||||
# Lazy-initialised async Redis client
|
||||
_redis_client: Optional[aioredis.Redis] = None
|
||||
# Lazy-initialised async Redis client (可选)
|
||||
_redis_client: Optional[object] = None
|
||||
|
||||
# 内存存储(本地开发用,不使用 Redis 时)
|
||||
_memory_store: Dict[str, tuple[str, datetime]] = {}
|
||||
|
||||
|
||||
async def get_redis() -> aioredis.Redis:
|
||||
"""Return a shared async Redis connection."""
|
||||
async def get_redis():
|
||||
"""Return a shared async Redis connection if enabled."""
|
||||
if not shared_settings.USE_REDIS:
|
||||
return None
|
||||
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = aioredis.from_url(
|
||||
shared_settings.REDIS_URL, decode_responses=True
|
||||
)
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
_redis_client = aioredis.from_url(
|
||||
shared_settings.REDIS_URL, decode_responses=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[警告] Redis 连接失败: {e},将使用内存存储")
|
||||
return None
|
||||
return _redis_client
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
@@ -118,31 +127,68 @@ def _hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _clean_expired_tokens():
|
||||
"""清理过期的内存 token"""
|
||||
now = datetime.utcnow()
|
||||
expired_keys = [k for k, (_, exp) in _memory_store.items() if exp < now]
|
||||
for k in expired_keys:
|
||||
del _memory_store[k]
|
||||
|
||||
|
||||
async def create_refresh_token(user_id: str) -> str:
|
||||
"""
|
||||
Generate a cryptographically random refresh token, store its hash in Redis
|
||||
with a 7-day TTL, and return the raw token string.
|
||||
(or memory if Redis disabled) with a 7-day TTL, and return the raw token string.
|
||||
"""
|
||||
token = secrets.token_urlsafe(48)
|
||||
token_hash = _hash_token(token)
|
||||
|
||||
r = await get_redis()
|
||||
await r.setex(f"refresh_token:{token_hash}", REFRESH_TOKEN_TTL, user_id)
|
||||
if r:
|
||||
# 使用 Redis
|
||||
await r.setex(f"refresh_token:{token_hash}", REFRESH_TOKEN_TTL, user_id)
|
||||
else:
|
||||
# 使用内存存储
|
||||
_clean_expired_tokens()
|
||||
expire_at = datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_TTL)
|
||||
_memory_store[f"refresh_token:{token_hash}"] = (user_id, expire_at)
|
||||
|
||||
return token
|
||||
|
||||
|
||||
async def verify_refresh_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
Verify a refresh token by looking up its hash in Redis.
|
||||
Verify a refresh token by looking up its hash in Redis or memory.
|
||||
Returns the associated user_id if valid, None otherwise.
|
||||
"""
|
||||
token_hash = _hash_token(token)
|
||||
key = f"refresh_token:{token_hash}"
|
||||
|
||||
r = await get_redis()
|
||||
user_id = await r.get(f"refresh_token:{token_hash}")
|
||||
return user_id
|
||||
if r:
|
||||
# 使用 Redis
|
||||
user_id = await r.get(key)
|
||||
return user_id
|
||||
else:
|
||||
# 使用内存存储
|
||||
_clean_expired_tokens()
|
||||
if key in _memory_store:
|
||||
user_id, expire_at = _memory_store[key]
|
||||
if expire_at > datetime.utcnow():
|
||||
return user_id
|
||||
return None
|
||||
|
||||
|
||||
async def revoke_refresh_token(token: str) -> None:
|
||||
"""Delete a refresh token from Redis (used during rotation)."""
|
||||
"""Delete a refresh token from Redis or memory (used during rotation)."""
|
||||
token_hash = _hash_token(token)
|
||||
key = f"refresh_token:{token_hash}"
|
||||
|
||||
r = await get_redis()
|
||||
await r.delete(f"refresh_token:{token_hash}")
|
||||
if r:
|
||||
# 使用 Redis
|
||||
await r.delete(key)
|
||||
else:
|
||||
# 使用内存存储
|
||||
if key in _memory_store:
|
||||
del _memory_store[key]
|
||||
|
||||
Reference in New Issue
Block a user