1. Add 登陆功能

2. 调整字体大小
3. 新增部分功能
This commit is contained in:
2026-06-05 18:00:31 +08:00
parent 06e0967128
commit 9fea9c6a53
58 changed files with 5028 additions and 322 deletions

58
tests/test_auth_routes.py Normal file
View File

@@ -0,0 +1,58 @@
"""Integration tests for the auth routes.
Uses FastAPI TestClient with a mocked user store.
Does not require a running PostgreSQL.
"""
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from app.infrastructure.auth.user_store import UserRecord
@pytest.fixture
def client():
"""Return a TestClient with a mocked user store."""
mock_store = MagicMock()
alice = UserRecord(id="uuid-1", username="alice", hashed_pw="hashed", role="admin", is_active=True)
mock_store.authenticate.side_effect = lambda u, p: alice if (u == "alice" and p == "correct") else None
with patch("app.shared.bootstrap.get_user_store", return_value=mock_store):
# Import after patch so the mock is active when routes import bootstrap.
from app.api.main import app
with TestClient(app, raise_server_exceptions=False) as c:
yield c
def test_login_returns_token_for_valid_credentials(client):
"""POST /auth/token must return an access_token for valid credentials."""
resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
assert resp.status_code == 200
body = resp.json()
assert "access_token" in body
assert body["token_type"] == "bearer"
def test_login_returns_401_for_wrong_password(client):
"""POST /auth/token must return 401 for wrong password."""
resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "wrong"})
assert resp.status_code == 401
def test_me_returns_user_when_authenticated(client):
"""GET /auth/me must return user identity when a valid token is provided."""
login_resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
assert login_resp.status_code == 200, login_resp.text
token = login_resp.json()["access_token"]
me_resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
assert me_resp.status_code == 200
assert me_resp.json()["username"] == "alice"
assert me_resp.json()["role"] == "admin"
def test_me_returns_401_without_token(client):
"""GET /auth/me must return 401 when no token is provided."""
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401

View File

@@ -0,0 +1,45 @@
"""Tests for Celery task infrastructure.
Verifies Celery app configuration and task registration without
starting a real worker or connecting to Redis.
"""
def test_celery_app_uses_redis_broker():
"""Celery broker URL must be a Redis URL built from settings."""
from app.infrastructure.tasks.celery_app import celery_app
assert celery_app.conf.broker_url.startswith("redis://")
def test_celery_app_uses_redis_backend():
"""Celery result backend must be Redis."""
from app.infrastructure.tasks.celery_app import celery_app
assert celery_app.conf.result_backend.startswith("redis://")
def test_celery_app_has_json_serializer():
"""Task serializer must be JSON for portability."""
from app.infrastructure.tasks.celery_app import celery_app
assert celery_app.conf.task_serializer == "json"
def test_process_document_task_is_registered():
"""process_document_task must be discoverable in the Celery task registry."""
import app.infrastructure.tasks.document_tasks # noqa: F401 — triggers task registration
from app.infrastructure.tasks.celery_app import celery_app
registered = list(celery_app.tasks.keys())
assert any("process_document_task" in name for name in registered), (
f"process_document_task not found in {registered}"
)
def test_document_command_service_has_process_document():
"""DocumentCommandService must expose _process_document method."""
from app.application.documents.services import DocumentCommandService
assert hasattr(DocumentCommandService, "_process_document")
def test_document_command_service_has_store_document():
"""DocumentCommandService must expose store_document method."""
from app.application.documents.services import DocumentCommandService
assert hasattr(DocumentCommandService, "store_document")

58
tests/test_jwt_handler.py Normal file
View File

@@ -0,0 +1,58 @@
"""Tests for JWTHandler token creation and decoding.
These tests do not require a running server or database.
"""
import time
import pytest
SECRET = "test-secret-key-minimum-32-characters-long"
@pytest.fixture
def handler():
"""Return a JWTHandler configured with a test secret."""
from app.infrastructure.auth.jwt_handler import JWTHandler
return JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=30)
def test_create_token_returns_string(handler):
"""create_access_token must return a non-empty string."""
token = handler.create_access_token(user_id="u1", username="alice", role="admin")
assert isinstance(token, str)
assert len(token) > 20
def test_decode_token_returns_correct_claims(handler):
"""decode_token must return UserClaims matching the input."""
token = handler.create_access_token(user_id="u1", username="alice", role="admin")
claims = handler.decode_token(token)
assert claims.user_id == "u1"
assert claims.username == "alice"
assert claims.role == "admin"
def test_decode_expired_token_raises(handler):
"""decode_token must raise ValueError on an expired token."""
from app.infrastructure.auth.jwt_handler import JWTHandler
short_handler = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=0)
token = short_handler.create_access_token(user_id="u2", username="bob", role="readonly")
time.sleep(1)
with pytest.raises(ValueError, match="expired"):
short_handler.decode_token(token)
def test_decode_invalid_token_raises(handler):
"""decode_token must raise ValueError for a tampered token."""
with pytest.raises(ValueError):
handler.decode_token("not.a.valid.jwt.token")
def test_decode_wrong_secret_raises():
"""decode_token must raise ValueError when signed with a different secret."""
from app.infrastructure.auth.jwt_handler import JWTHandler
creator = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=60)
verifier = JWTHandler(secret_key="wrong-secret-key-also-minimum-32-chars", algorithm="HS256", expire_minutes=60)
token = creator.create_access_token(user_id="u3", username="carol", role="legal")
with pytest.raises(ValueError):
verifier.decode_token(token)

View File

@@ -0,0 +1,89 @@
"""Tests for RedisConversationStore.
Uses fakeredis so no real Redis connection is required.
All tests follow the same ConversationStore contract as InMemoryConversationStore.
"""
import pytest
import fakeredis
@pytest.fixture
def redis_client():
"""Return an in-process fake Redis client."""
return fakeredis.FakeRedis()
@pytest.fixture
def store(redis_client):
"""Return a RedisConversationStore backed by fake Redis."""
from app.infrastructure.session.redis_conversation_store import RedisConversationStore
return RedisConversationStore(redis_client=redis_client, timeout_seconds=1800)
def test_create_session_returns_session_with_id(store):
"""create_session() must return a ConversationSession with a non-empty session_id."""
session = store.create_session()
assert session.session_id
assert len(session.session_id) > 0
def test_get_session_returns_same_session(store):
"""get_session() must return the previously created session."""
session = store.create_session()
fetched = store.get_session(session.session_id)
assert fetched is not None
assert fetched.session_id == session.session_id
def test_get_session_returns_none_for_unknown_id(store):
"""get_session() must return None when the session_id does not exist."""
assert store.get_session("nonexistent-id") is None
def test_save_message_appends_to_session(store):
"""save_message() must append a message and return the updated session."""
session = store.create_session()
updated = store.save_message(session.session_id, role="user", content="Hello")
assert updated is not None
assert len(updated.messages) == 1
assert updated.messages[0].role == "user"
assert updated.messages[0].content == "Hello"
def test_save_message_persists_across_lookups(store):
"""Messages saved to a session must be visible in subsequent get_session calls."""
session = store.create_session()
store.save_message(session.session_id, role="user", content="test")
fetched = store.get_session(session.session_id)
assert fetched is not None
assert len(fetched.messages) == 1
def test_delete_session_removes_it(store):
"""delete_session() must return True and remove the session."""
session = store.create_session()
assert store.delete_session(session.session_id) is True
assert store.get_session(session.session_id) is None
def test_delete_session_returns_false_for_unknown(store):
"""delete_session() must return False when the session does not exist."""
assert store.delete_session("ghost-id") is False
def test_list_sessions_includes_created_session(store):
"""list_sessions() must include all active sessions."""
session = store.create_session()
ids = [s["session_id"] for s in store.list_sessions()]
assert session.session_id in ids
def test_session_expires_after_ttl(redis_client):
"""Sessions must disappear after the TTL expires."""
from app.infrastructure.session.redis_conversation_store import RedisConversationStore
store = RedisConversationStore(redis_client=redis_client, timeout_seconds=1)
session = store.create_session()
# Simulate TTL expiry by deleting the key directly (fakeredis expire(0) is a no-op).
redis_client.delete(f"session:{session.session_id}")
assert store.get_session(session.session_id) is None

View File

@@ -0,0 +1,39 @@
"""Verify that bootstrap correctly wires the reranker when the setting is enabled."""
from unittest.mock import patch
import pytest
def test_get_reranker_returns_none_when_disabled():
"""get_reranker() must return None when reranker_enabled is False."""
with patch("app.shared.bootstrap._build_binary_store"), \
patch("app.shared.bootstrap._build_vector_index"):
from app.shared import bootstrap
bootstrap.get_reranker.cache_clear()
with patch("app.shared.bootstrap.settings") as mock_settings:
mock_settings.reranker_enabled = False
mock_settings.reranker_base_url = ""
result = bootstrap.get_reranker()
bootstrap.get_reranker.cache_clear()
assert result is None
def test_get_reranker_returns_instance_when_enabled():
"""get_reranker() must return an OpenAICompatibleReranker when enabled."""
from app.shared import bootstrap
bootstrap.get_reranker.cache_clear()
with patch("app.shared.bootstrap.settings") as mock_settings:
mock_settings.reranker_enabled = True
mock_settings.reranker_base_url = "http://localhost:8082"
mock_settings.reranker_model = "BAAI/bge-reranker-v2-m3"
mock_settings.reranker_api_key = ""
mock_settings.reranker_top_k = 5
result = bootstrap.get_reranker()
bootstrap.get_reranker.cache_clear()
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
assert isinstance(result, OpenAICompatibleReranker)