172 lines
5.5 KiB
Python
172 lines
5.5 KiB
Python
"""
|
||
Tests for the shared module: crypto, response format, and ORM models.
|
||
Validates tasks 1.1 – 1.5 (excluding optional PBT task 1.4).
|
||
"""
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from sqlalchemy import select
|
||
|
||
from shared.crypto import derive_key, encrypt_cookie, decrypt_cookie
|
||
from shared.response import success_response, error_response
|
||
from shared.models import User, Account, Task, SigninLog
|
||
|
||
from tests.conftest import TestSessionLocal
|
||
|
||
|
||
# ===================== Crypto tests =====================
|
||
|
||
|
||
class TestCrypto:
    """Exercise the AES-256-GCM cookie helpers: round-trips and wrong-key failure."""

    def setup_method(self):
        # A key is derived anew before every test method runs.
        self.key = derive_key("test-encryption-key")

    def test_encrypt_decrypt_roundtrip(self):
        """Decrypting with the same key returns the original cookie string."""
        plaintext = "SUB=abc123; SUBP=xyz789;"
        ciphertext, nonce = encrypt_cookie(plaintext, self.key)
        recovered = decrypt_cookie(ciphertext, nonce, self.key)
        assert recovered == plaintext

    def test_different_plaintexts_produce_different_ciphertexts(self):
        """Two distinct inputs must not encrypt to the same ciphertext."""
        first_ct, _first_iv = encrypt_cookie("cookie_a", self.key)
        second_ct, _second_iv = encrypt_cookie("cookie_b", self.key)
        assert first_ct != second_ct

    def test_wrong_key_raises(self):
        """Decrypting with a key derived from another secret raises."""
        ciphertext, nonce = encrypt_cookie("secret", self.key)
        other_key = derive_key("wrong-key")
        with pytest.raises(Exception):
            decrypt_cookie(ciphertext, nonce, other_key)

    def test_empty_string_roundtrip(self):
        """The empty string survives an encrypt/decrypt cycle."""
        ciphertext, nonce = encrypt_cookie("", self.key)
        recovered = decrypt_cookie(ciphertext, nonce, self.key)
        assert recovered == ""

    def test_unicode_roundtrip(self):
        """Non-ASCII cookie text survives an encrypt/decrypt cycle."""
        plaintext = "微博Cookie=值; 中文=测试"
        ciphertext, nonce = encrypt_cookie(plaintext, self.key)
        recovered = decrypt_cookie(ciphertext, nonce, self.key)
        assert recovered == plaintext
|
||
|
||
|
||
# ===================== Response format tests =====================
|
||
|
||
|
||
class TestResponseFormat:
    """Check the shape of the unified success/error response helpers."""

    def test_success_response_structure(self):
        """Explicit data and message are echoed back in the success payload."""
        payload = success_response({"id": 1}, "ok")
        assert payload["success"] is True
        assert payload["data"] == {"id": 1}
        assert payload["message"] == "ok"

    def test_success_response_defaults(self):
        """Called with no arguments, the helper yields no data and a stock message."""
        payload = success_response()
        assert payload["success"] is True
        assert payload["data"] is None
        assert "Operation successful" in payload["message"]

    def test_error_response_structure(self):
        """The error helper returns a response object with status and JSON body."""
        import json

        response = error_response(
            "bad",
            "VALIDATION_ERROR",
            [{"field": "email"}],
            400,
        )
        assert response.status_code == 400

        # The body is serialized JSON; decode it before inspecting fields.
        decoded = json.loads(response.body)
        assert decoded["success"] is False
        assert decoded["data"] is None
        error_section = decoded["error"]
        assert error_section["code"] == "VALIDATION_ERROR"
        assert len(error_section["details"]) == 1
|
||
|
||
|
||
# ===================== ORM model smoke tests =====================
|
||
|
||
|
||
class TestORMModels:
    """Smoke-test the ORM models with SQLite: insert rows, then read them back."""

    @pytest.mark.asyncio
    async def test_create_user(self, db_session):
        """A persisted User can be fetched by username; is_active was not set here."""
        new_user = User(
            username="testuser",
            email="test@example.com",
            hashed_password="hashed",
        )
        db_session.add(new_user)
        await db_session.commit()

        stmt = select(User).where(User.username == "testuser")
        row = (await db_session.execute(stmt)).scalar_one()
        assert row.email == "test@example.com"
        assert row.is_active is True

    @pytest.mark.asyncio
    async def test_create_account_linked_to_user(self, db_session):
        """An Account referencing a User can be fetched by user_id."""
        owner = User(username="u1", email="u1@x.com", hashed_password="h")
        db_session.add(owner)
        await db_session.commit()

        account = Account(
            user_id=owner.id,
            weibo_user_id="12345",
            remark="test",
            encrypted_cookies="enc",
            iv="iv123",
        )
        db_session.add(account)
        await db_session.commit()

        stmt = select(Account).where(Account.user_id == owner.id)
        row = (await db_session.execute(stmt)).scalar_one()
        assert row.weibo_user_id == "12345"
        # status was not passed to the constructor above.
        assert row.status == "pending"

    @pytest.mark.asyncio
    async def test_create_task_linked_to_account(self, db_session):
        """A Task referencing an Account can be fetched by account_id."""
        owner = User(username="u2", email="u2@x.com", hashed_password="h")
        db_session.add(owner)
        await db_session.commit()

        account = Account(
            user_id=owner.id,
            weibo_user_id="99",
            remark="r",
            encrypted_cookies="e",
            iv="i",
        )
        db_session.add(account)
        await db_session.commit()

        scheduled = Task(account_id=account.id, cron_expression="0 8 * * *")
        db_session.add(scheduled)
        await db_session.commit()

        stmt = select(Task).where(Task.account_id == account.id)
        row = (await db_session.execute(stmt)).scalar_one()
        assert row.cron_expression == "0 8 * * *"
        # is_enabled was not passed to the constructor above.
        assert row.is_enabled is True

    @pytest.mark.asyncio
    async def test_create_signin_log(self, db_session):
        """A SigninLog referencing an Account can be fetched by account_id."""
        owner = User(username="u3", email="u3@x.com", hashed_password="h")
        db_session.add(owner)
        await db_session.commit()

        account = Account(
            user_id=owner.id,
            weibo_user_id="77",
            remark="r",
            encrypted_cookies="e",
            iv="i",
        )
        db_session.add(account)
        await db_session.commit()

        entry = SigninLog(
            account_id=account.id,
            topic_title="超话A",
            status="success",
        )
        db_session.add(entry)
        await db_session.commit()

        stmt = select(SigninLog).where(SigninLog.account_id == account.id)
        row = (await db_session.execute(stmt)).scalar_one()
        assert row.status == "success"
        assert row.topic_title == "超话A"
|