Files
AIRegulation-DocAnalysis/tests/test_parser.py

236 lines
7.8 KiB
Python
Raw Permalink Normal View History

"""API contract checks for the migrated backend architecture."""
2026-04-28 11:29:33 +08:00
from __future__ import annotations
2026-04-28 11:29:33 +08:00
from dataclasses import dataclass
2026-04-28 11:29:33 +08:00
from fastapi.testclient import TestClient
2026-04-28 11:29:33 +08:00
from app.api.main import app
from app.application.documents import DocumentProcessResult
from app.domain.conversation.models import AnswerResult, AnswerSource, ConversationSession
from app.domain.documents import Document, DocumentStatus
from app.domain.retrieval import RetrievedChunk
2026-04-28 11:29:33 +08:00
@dataclass
class FakeMessage:
    """Lightweight chat-message stand-in used to seed ``ConversationSession.messages``."""

    # Who produced the message; the fixtures in this module use "user" and "assistant".
    role: str
    # Message text.
    content: str
2026-04-28 11:29:33 +08:00
class FakeDocumentCommandService:
    """Test double for the document command service patched into the documents routes.

    ``upload_and_process`` never touches the uploaded bytes; it always reports
    an indexed document with two chunks and an empty summary.
    """

    def upload_and_process(self, **kwargs) -> DocumentProcessResult:
        """Return a canned processing result.

        ``doc_name`` is echoed back when supplied and truthy; otherwise the
        result falls back to ``"test.pdf"``. All other keyword arguments are
        accepted and ignored.
        """
        name = kwargs.get("doc_name") or "test.pdf"
        # "处理成功" is the success message ("processing succeeded").
        fields = {
            "doc_id": "doc-api-1",
            "doc_name": name,
            "status": "indexed",
            "message": "处理成功",
            "num_chunks": 2,
            "summary": "",
            "summary_latency_ms": 0,
        }
        return DocumentProcessResult(**fields)
2026-04-28 11:29:33 +08:00
class FakeDocumentQueryService:
    """In-memory stand-in for the document query service.

    It knows exactly one indexed document, ``doc-api-1``. Every lookup is
    answered from a single canned record, so ``get``, ``list_documents`` and
    ``download`` cannot disagree about its fields.
    """

    # The only document id this fake recognises.
    _KNOWN_DOC_ID = "doc-api-1"

    @classmethod
    def _indexed_document(cls) -> Document:
        """Build the canned, already-indexed ``Document`` record."""
        return Document(
            doc_id=cls._KNOWN_DOC_ID,
            doc_name="测试法规",
            file_name="test.pdf",
            object_name="doc-api-1/test.pdf",
            content_type="application/pdf",
            size_bytes=12,
            status=DocumentStatus.INDEXED,
            chunk_count=2,
        )

    def get(self, doc_id: str) -> Document | None:
        """Return the canned document for the known id, or ``None`` for any other id."""
        if doc_id != self._KNOWN_DOC_ID:
            return None
        return self._indexed_document()

    def list_documents(self, limit: int | None = None) -> list[Document]:
        """Return all known documents, sliced to the first *limit* entries when given."""
        documents = [self._indexed_document()]
        return documents if limit is None else documents[:limit]

    def download(self, doc_id: str) -> tuple[Document, bytes]:
        """Return the document together with fixed fake file bytes.

        Raises:
            FileNotFoundError: if *doc_id* is not the known document id.
        """
        document = self.get(doc_id)
        if document is None:
            raise FileNotFoundError(doc_id)
        return document, b"pdf-bytes"
class FakeRetrievalService:
    """Retrieval double that always yields exactly one chunk from the fake document."""

    def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
        """Return a single chunk whose content embeds *query*.

        ``top_k`` is accepted for signature compatibility but ignored. The
        ``filters`` expression is copied into the chunk metadata (as ``""``
        when falsy) so the value passed through by the route can be inspected.
        """
        recorded_filters = filters if filters else ""
        chunk = RetrievedChunk(
            chunk_id="chunk-1",
            doc_id="doc-api-1",
            doc_name="测试法规",
            content=f"关于 {query} 的法规内容",
            score=0.92,
            section_title="第一章",
            page_number=1,
            metadata={"filters": recorded_filters},
        )
        return [chunk]
class FakeConversationStore:
    """In-memory conversation store holding a single pre-seeded session.

    Every method answers from ``self.session``, so changing the seeded session
    (id, timestamps or messages) is reflected consistently by all lookups.
    """

    def __init__(self) -> None:
        # Seed one user/assistant exchange so message counts are non-zero.
        self.session = ConversationSession(
            session_id="sess-1",
            created_at=1,
            updated_at=1,
            messages=[
                FakeMessage(role="user", content="历史问题"),
                FakeMessage(role="assistant", content="历史回答"),
            ],
        )

    def get_session(self, session_id: str) -> ConversationSession | None:
        """Return the seeded session when *session_id* matches it, else ``None``."""
        if session_id == self.session.session_id:
            return self.session
        return None

    def delete_session(self, session_id: str) -> bool:
        """Report whether *session_id* names the seeded session.

        Nothing is actually removed: the seeded session remains retrievable
        through ``get_session`` afterwards.
        """
        return session_id == self.session.session_id

    def list_sessions(self) -> list[dict]:
        """Summarise the seeded session; every field is read from ``self.session``."""
        session = self.session
        return [
            {
                "session_id": session.session_id,
                "message_count": len(session.messages),
                "created_at": session.created_at,
                "updated_at": session.updated_at,
            }
        ]
class FakeAgentConversationService:
    """Agent service double returning canned answers for ask, chat and stream_chat."""

    # Model name reported when the caller does not pass a truthy ``model``.
    _DEFAULT_MODEL = "qwen3.5-flash"

    @classmethod
    def _answer(cls, requested_model, **fields) -> AnswerResult:
        """Build an ``AnswerResult`` with the fields shared by ``ask`` and ``chat``."""
        return AnswerResult(
            model=requested_model or cls._DEFAULT_MODEL,
            truncated=False,
            error=None,
            **fields,
        )

    def ask(self, **kwargs):
        """Answer a one-off question with one cited source.

        Returns a ``(None, result)`` pair; the first slot is ``None`` here,
        whereas ``chat`` puts ``"sess-1"`` there.
        """
        cited = AnswerSource(
            doc_id="doc-api-1",
            doc_title="测试法规",
            chunk_id="chunk-1",
            section_title="第一章",
            page_start=1,
            score=0.92,
            text="法规原文",
            metadata={"section_title": "第一章"},
        )
        outcome = self._answer(
            kwargs.get("model"),
            answer="这是基于法规上下文的回答",
            sources=[cited],
            latency_ms=11,
            retrieved_count=1,
            context_tokens=128,
        )
        return None, outcome

    def chat(self, **kwargs):
        """Answer within the fixed session ``"sess-1"``; no sources are cited."""
        outcome = self._answer(
            kwargs.get("model"),
            answer="会话回答",
            sources=[],
            latency_ms=12,
            retrieved_count=1,
            context_tokens=64,
        )
        return "sess-1", outcome

    def stream_chat(self, **kwargs):
        """Return ``"sess-1"`` and an iterator over three canned stream events."""
        events = [
            {"event": "status", "data": "正在处理"},
            {"event": "content", "data": "流式回答"},
            {"event": "done", "data": {"retrieved_count": 1}},
        ]
        return "sess-1", iter(events)
2026-04-28 11:29:33 +08:00
def test_documents_upload_contract_preserved(monkeypatch):
    """Upload endpoint returns 200 and surfaces the fake service's result fields."""
    from app.api.routes import documents

    # Replace the command-service factory so the route uses the canned fake.
    monkeypatch.setattr(documents, "get_document_command_service", lambda: FakeDocumentCommandService())

    upload = ("test.pdf", b"dummy-pdf", "application/pdf")
    form = {"doc_name": "测试法规", "regulation_type": "车辆安全", "version": "2026"}
    client = TestClient(app)
    response = client.post("/api/v1/documents/upload", files={"file": upload}, data=form)

    assert response.status_code == 200
    body = response.json()
    expected = {
        "doc_id": "doc-api-1",
        "doc_name": "测试法规",
        "status": "indexed",
        "num_chunks": 2,
    }
    for field_name, field_value in expected.items():
        assert body[field_name] == field_value
2026-04-28 11:29:33 +08:00
def test_documents_query_contract_preserved(monkeypatch):
    """Status, list and download endpoints all serve the single fake document."""
    from app.api.routes import documents

    # Route every query through the in-memory fake that knows only doc-api-1.
    monkeypatch.setattr(documents, "get_document_query_service", lambda: FakeDocumentQueryService())
    client = TestClient(app)

    status_resp = client.get("/api/v1/documents/status/doc-api-1")
    assert status_resp.status_code == 200
    assert status_resp.json()["status"] == "indexed"

    listing_resp = client.get("/api/v1/documents/list")
    assert listing_resp.status_code == 200
    assert listing_resp.json()["total"] == 1

    # The fake's download returns fixed bytes, which should be streamed verbatim.
    file_resp = client.get("/api/v1/documents/download/doc-api-1")
    assert file_resp.status_code == 200
    assert file_resp.content == b"pdf-bytes"
2026-04-28 11:29:33 +08:00
def test_knowledge_retrieval_contract_preserved(monkeypatch):
    """Retrieval endpoint echoes the query and reports chunk doc_id/section_title in metadata."""
    from app.api.routes import knowledge

    monkeypatch.setattr(knowledge, "get_retrieval_service", lambda: FakeRetrievalService())

    request_body = {"query": "机动车安全", "top_k": 3, "filters": 'doc_id == "doc-api-1"'}
    response = TestClient(app).post("/api/v1/knowledge/retrieval", json=request_body)

    assert response.status_code == 200
    payload = response.json()
    assert payload["query"] == "机动车安全"
    assert payload["total"] == 1
    # The fake only puts "filters" into chunk metadata, so these keys are
    # presumably merged in by the route from the chunk's own fields.
    first_metadata = payload["results"][0]["metadata"]
    assert first_metadata["doc_id"] == "doc-api-1"
    assert first_metadata["section_title"] == "第一章"
2026-04-28 11:29:33 +08:00
def test_agent_ask_and_stream_contract_preserved(monkeypatch):
from app.api.routes import agent
2026-04-28 11:29:33 +08:00
store = FakeConversationStore()
monkeypatch.setattr(agent, "get_agent_conversation_service", lambda: FakeAgentConversationService())
2026-04-28 11:29:33 +08:00
client = TestClient(app)
2026-04-28 11:29:33 +08:00
ask_response = client.post("/api/v1/agent/ask", json={"query": "这个法规要求什么?"})
assert ask_response.status_code == 200
ask_payload = ask_response.json()
assert ask_payload["answer"] == "这是基于法规上下文的回答"
assert ask_payload["retrieved_count"] == 1
assert ask_payload["sources"][0]["doc_id"] == "doc-api-1"
2026-04-28 11:29:33 +08:00
stream_response = client.get("/api/v1/agent/chat/stream", params={"query": "继续说明"})
assert stream_response.status_code == 200
assert stream_response.headers["content-type"].startswith("text/event-stream")
assert "event: session" in stream_response.text
assert "event: content" in stream_response.text