1. Add 登陆功能
2. 调整字体大小 3. 新增部分功能
This commit is contained in:
58
tests/test_auth_routes.py
Normal file
58
tests/test_auth_routes.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Integration tests for the auth routes.
|
||||
|
||||
Uses FastAPI TestClient with a mocked user store.
|
||||
Does not require a running PostgreSQL.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.infrastructure.auth.user_store import UserRecord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Return a TestClient with a mocked user store."""
|
||||
mock_store = MagicMock()
|
||||
alice = UserRecord(id="uuid-1", username="alice", hashed_pw="hashed", role="admin", is_active=True)
|
||||
mock_store.authenticate.side_effect = lambda u, p: alice if (u == "alice" and p == "correct") else None
|
||||
|
||||
with patch("app.shared.bootstrap.get_user_store", return_value=mock_store):
|
||||
# Import after patch so the mock is active when routes import bootstrap.
|
||||
from app.api.main import app
|
||||
with TestClient(app, raise_server_exceptions=False) as c:
|
||||
yield c
|
||||
|
||||
|
||||
def test_login_returns_token_for_valid_credentials(client):
|
||||
"""POST /auth/token must return an access_token for valid credentials."""
|
||||
resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "access_token" in body
|
||||
assert body["token_type"] == "bearer"
|
||||
|
||||
|
||||
def test_login_returns_401_for_wrong_password(client):
|
||||
"""POST /auth/token must return 401 for wrong password."""
|
||||
resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "wrong"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_me_returns_user_when_authenticated(client):
|
||||
"""GET /auth/me must return user identity when a valid token is provided."""
|
||||
login_resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
|
||||
assert login_resp.status_code == 200, login_resp.text
|
||||
token = login_resp.json()["access_token"]
|
||||
|
||||
me_resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||
assert me_resp.status_code == 200
|
||||
assert me_resp.json()["username"] == "alice"
|
||||
assert me_resp.json()["role"] == "admin"
|
||||
|
||||
|
||||
def test_me_returns_401_without_token(client):
|
||||
"""GET /auth/me must return 401 when no token is provided."""
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
45
tests/test_celery_tasks.py
Normal file
45
tests/test_celery_tasks.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Tests for Celery task infrastructure.
|
||||
|
||||
Verifies Celery app configuration and task registration without
|
||||
starting a real worker or connecting to Redis.
|
||||
"""
|
||||
|
||||
|
||||
def test_celery_app_uses_redis_broker():
|
||||
"""Celery broker URL must be a Redis URL built from settings."""
|
||||
from app.infrastructure.tasks.celery_app import celery_app
|
||||
assert celery_app.conf.broker_url.startswith("redis://")
|
||||
|
||||
|
||||
def test_celery_app_uses_redis_backend():
|
||||
"""Celery result backend must be Redis."""
|
||||
from app.infrastructure.tasks.celery_app import celery_app
|
||||
assert celery_app.conf.result_backend.startswith("redis://")
|
||||
|
||||
|
||||
def test_celery_app_has_json_serializer():
|
||||
"""Task serializer must be JSON for portability."""
|
||||
from app.infrastructure.tasks.celery_app import celery_app
|
||||
assert celery_app.conf.task_serializer == "json"
|
||||
|
||||
|
||||
def test_process_document_task_is_registered():
|
||||
"""process_document_task must be discoverable in the Celery task registry."""
|
||||
import app.infrastructure.tasks.document_tasks # noqa: F401 — triggers task registration
|
||||
from app.infrastructure.tasks.celery_app import celery_app
|
||||
registered = list(celery_app.tasks.keys())
|
||||
assert any("process_document_task" in name for name in registered), (
|
||||
f"process_document_task not found in {registered}"
|
||||
)
|
||||
|
||||
|
||||
def test_document_command_service_has_process_document():
|
||||
"""DocumentCommandService must expose _process_document method."""
|
||||
from app.application.documents.services import DocumentCommandService
|
||||
assert hasattr(DocumentCommandService, "_process_document")
|
||||
|
||||
|
||||
def test_document_command_service_has_store_document():
|
||||
"""DocumentCommandService must expose store_document method."""
|
||||
from app.application.documents.services import DocumentCommandService
|
||||
assert hasattr(DocumentCommandService, "store_document")
|
||||
58
tests/test_jwt_handler.py
Normal file
58
tests/test_jwt_handler.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for JWTHandler token creation and decoding.
|
||||
|
||||
These tests do not require a running server or database.
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
|
||||
SECRET = "test-secret-key-minimum-32-characters-long"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
"""Return a JWTHandler configured with a test secret."""
|
||||
from app.infrastructure.auth.jwt_handler import JWTHandler
|
||||
return JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=30)
|
||||
|
||||
|
||||
def test_create_token_returns_string(handler):
|
||||
"""create_access_token must return a non-empty string."""
|
||||
token = handler.create_access_token(user_id="u1", username="alice", role="admin")
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
|
||||
|
||||
def test_decode_token_returns_correct_claims(handler):
|
||||
"""decode_token must return UserClaims matching the input."""
|
||||
token = handler.create_access_token(user_id="u1", username="alice", role="admin")
|
||||
claims = handler.decode_token(token)
|
||||
assert claims.user_id == "u1"
|
||||
assert claims.username == "alice"
|
||||
assert claims.role == "admin"
|
||||
|
||||
|
||||
def test_decode_expired_token_raises(handler):
|
||||
"""decode_token must raise ValueError on an expired token."""
|
||||
from app.infrastructure.auth.jwt_handler import JWTHandler
|
||||
short_handler = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=0)
|
||||
token = short_handler.create_access_token(user_id="u2", username="bob", role="readonly")
|
||||
time.sleep(1)
|
||||
with pytest.raises(ValueError, match="expired"):
|
||||
short_handler.decode_token(token)
|
||||
|
||||
|
||||
def test_decode_invalid_token_raises(handler):
|
||||
"""decode_token must raise ValueError for a tampered token."""
|
||||
with pytest.raises(ValueError):
|
||||
handler.decode_token("not.a.valid.jwt.token")
|
||||
|
||||
|
||||
def test_decode_wrong_secret_raises():
|
||||
"""decode_token must raise ValueError when signed with a different secret."""
|
||||
from app.infrastructure.auth.jwt_handler import JWTHandler
|
||||
creator = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=60)
|
||||
verifier = JWTHandler(secret_key="wrong-secret-key-also-minimum-32-chars", algorithm="HS256", expire_minutes=60)
|
||||
token = creator.create_access_token(user_id="u3", username="carol", role="legal")
|
||||
with pytest.raises(ValueError):
|
||||
verifier.decode_token(token)
|
||||
89
tests/test_redis_conversation_store.py
Normal file
89
tests/test_redis_conversation_store.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests for RedisConversationStore.
|
||||
|
||||
Uses fakeredis so no real Redis connection is required.
|
||||
All tests follow the same ConversationStore contract as InMemoryConversationStore.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import fakeredis
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_client():
|
||||
"""Return an in-process fake Redis client."""
|
||||
return fakeredis.FakeRedis()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(redis_client):
|
||||
"""Return a RedisConversationStore backed by fake Redis."""
|
||||
from app.infrastructure.session.redis_conversation_store import RedisConversationStore
|
||||
return RedisConversationStore(redis_client=redis_client, timeout_seconds=1800)
|
||||
|
||||
|
||||
def test_create_session_returns_session_with_id(store):
|
||||
"""create_session() must return a ConversationSession with a non-empty session_id."""
|
||||
session = store.create_session()
|
||||
assert session.session_id
|
||||
assert len(session.session_id) > 0
|
||||
|
||||
|
||||
def test_get_session_returns_same_session(store):
|
||||
"""get_session() must return the previously created session."""
|
||||
session = store.create_session()
|
||||
fetched = store.get_session(session.session_id)
|
||||
assert fetched is not None
|
||||
assert fetched.session_id == session.session_id
|
||||
|
||||
|
||||
def test_get_session_returns_none_for_unknown_id(store):
|
||||
"""get_session() must return None when the session_id does not exist."""
|
||||
assert store.get_session("nonexistent-id") is None
|
||||
|
||||
|
||||
def test_save_message_appends_to_session(store):
|
||||
"""save_message() must append a message and return the updated session."""
|
||||
session = store.create_session()
|
||||
updated = store.save_message(session.session_id, role="user", content="Hello")
|
||||
assert updated is not None
|
||||
assert len(updated.messages) == 1
|
||||
assert updated.messages[0].role == "user"
|
||||
assert updated.messages[0].content == "Hello"
|
||||
|
||||
|
||||
def test_save_message_persists_across_lookups(store):
|
||||
"""Messages saved to a session must be visible in subsequent get_session calls."""
|
||||
session = store.create_session()
|
||||
store.save_message(session.session_id, role="user", content="test")
|
||||
fetched = store.get_session(session.session_id)
|
||||
assert fetched is not None
|
||||
assert len(fetched.messages) == 1
|
||||
|
||||
|
||||
def test_delete_session_removes_it(store):
|
||||
"""delete_session() must return True and remove the session."""
|
||||
session = store.create_session()
|
||||
assert store.delete_session(session.session_id) is True
|
||||
assert store.get_session(session.session_id) is None
|
||||
|
||||
|
||||
def test_delete_session_returns_false_for_unknown(store):
|
||||
"""delete_session() must return False when the session does not exist."""
|
||||
assert store.delete_session("ghost-id") is False
|
||||
|
||||
|
||||
def test_list_sessions_includes_created_session(store):
|
||||
"""list_sessions() must include all active sessions."""
|
||||
session = store.create_session()
|
||||
ids = [s["session_id"] for s in store.list_sessions()]
|
||||
assert session.session_id in ids
|
||||
|
||||
|
||||
def test_session_expires_after_ttl(redis_client):
|
||||
"""Sessions must disappear after the TTL expires."""
|
||||
from app.infrastructure.session.redis_conversation_store import RedisConversationStore
|
||||
store = RedisConversationStore(redis_client=redis_client, timeout_seconds=1)
|
||||
session = store.create_session()
|
||||
# Simulate TTL expiry by deleting the key directly (fakeredis expire(0) is a no-op).
|
||||
redis_client.delete(f"session:{session.session_id}")
|
||||
assert store.get_session(session.session_id) is None
|
||||
39
tests/test_reranker_bootstrap.py
Normal file
39
tests/test_reranker_bootstrap.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Verify that bootstrap correctly wires the reranker when the setting is enabled."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_get_reranker_returns_none_when_disabled():
|
||||
"""get_reranker() must return None when reranker_enabled is False."""
|
||||
with patch("app.shared.bootstrap._build_binary_store"), \
|
||||
patch("app.shared.bootstrap._build_vector_index"):
|
||||
from app.shared import bootstrap
|
||||
bootstrap.get_reranker.cache_clear()
|
||||
|
||||
with patch("app.shared.bootstrap.settings") as mock_settings:
|
||||
mock_settings.reranker_enabled = False
|
||||
mock_settings.reranker_base_url = ""
|
||||
result = bootstrap.get_reranker()
|
||||
|
||||
bootstrap.get_reranker.cache_clear()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_reranker_returns_instance_when_enabled():
|
||||
"""get_reranker() must return an OpenAICompatibleReranker when enabled."""
|
||||
from app.shared import bootstrap
|
||||
bootstrap.get_reranker.cache_clear()
|
||||
|
||||
with patch("app.shared.bootstrap.settings") as mock_settings:
|
||||
mock_settings.reranker_enabled = True
|
||||
mock_settings.reranker_base_url = "http://localhost:8082"
|
||||
mock_settings.reranker_model = "BAAI/bge-reranker-v2-m3"
|
||||
mock_settings.reranker_api_key = ""
|
||||
mock_settings.reranker_top_k = 5
|
||||
result = bootstrap.get_reranker()
|
||||
|
||||
bootstrap.get_reranker.cache_clear()
|
||||
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
|
||||
assert isinstance(result, OpenAICompatibleReranker)
|
||||
Reference in New Issue
Block a user