Fix SSE route dependency and align architecture docs
This commit is contained in:
@@ -1,137 +1,127 @@
|
||||
# tests/test_milvus.py
|
||||
"""Milvus集成测试"""
|
||||
"""新架构下的检索与 Milvus dense-only 约定测试。"""
|
||||
|
||||
import pytest
|
||||
from loguru import logger
|
||||
import sys
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
|
||||
sys.path.insert(0, os.path.join(PROJECT_ROOT, "backend"))
|
||||
|
||||
from app.services.storage.milvus_client import MilvusClient, SearchResult
|
||||
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
|
||||
from app.config.settings import settings
|
||||
from app.application.agent.services import AgentConversationService
|
||||
from app.application.knowledge.services import KnowledgeRetrievalService
|
||||
from app.domain.conversation.models import AnswerResult, AnswerSource, ConversationSession
|
||||
from app.domain.retrieval import RetrievalQuery, RetrievedChunk
|
||||
|
||||
|
||||
class TestMilvusConnection:
|
||||
"""Milvus连接测试"""
|
||||
class FakeRetriever:
    """Retriever test double that starts with an empty log of received queries."""

    def __init__(self) -> None:
        # Kept in call order so a test can assert on exactly what the
        # service under test forwarded to the retriever.
        self.queries: list[RetrievalQuery] = []
|
||||
|
||||
def test_connection(self):
|
||||
"""测试Milvus连接"""
|
||||
client = MilvusClient()
|
||||
|
||||
result = client.connect()
|
||||
assert result == True
|
||||
|
||||
client.disconnect()
|
||||
|
||||
def test_create_collection(self):
|
||||
"""测试创建Collection"""
|
||||
client = MilvusClient()
|
||||
client.connect()
|
||||
|
||||
result = client.create_collection(recreate=True)
|
||||
assert result == True
|
||||
|
||||
# 检查Collection是否存在
|
||||
stats = client.get_collection_stats()
|
||||
assert stats["name"] == settings.milvus_collection
|
||||
|
||||
client.disconnect()
|
||||
|
||||
|
||||
class TestMilvusOperations:
|
||||
"""Milvus操作测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""创建测试客户端"""
|
||||
client = MilvusClient()
|
||||
client.connect()
|
||||
client.create_collection(recreate=True)
|
||||
client.load_collection()
|
||||
yield client
|
||||
client.disconnect()
|
||||
|
||||
def test_insert_and_search(self, client):
|
||||
"""测试插入和检索"""
|
||||
from app.services.embedding.text_chunker import TextChunk, ChunkMetadata
|
||||
|
||||
# 创建测试数据
|
||||
chunks = [
|
||||
TextChunk(
|
||||
content="第一条 为保障机动车安全技术性能,预防和减少机动车交通事故,保护人身安全,制定本标准。",
|
||||
metadata=ChunkMetadata(
|
||||
doc_id="test_doc",
|
||||
doc_name="测试文档",
|
||||
chunk_id="test_chunk_1",
|
||||
clause_number="第一条",
|
||||
regulation_type="车辆安全"
|
||||
)
|
||||
),
|
||||
TextChunk(
|
||||
content="第二条 本标准适用于在我国道路上行驶的所有机动车。",
|
||||
metadata=ChunkMetadata(
|
||||
doc_id="test_doc",
|
||||
doc_name="测试文档",
|
||||
chunk_id="test_chunk_2",
|
||||
clause_number="第二条",
|
||||
regulation_type="车辆安全"
|
||||
)
|
||||
def retrieve(self, query: RetrievalQuery) -> list[RetrievedChunk]:
    """Record ``query`` and return a single canned chunk.

    Args:
        query: The retrieval request; appended to ``self.queries`` unchanged.

    Returns:
        A one-element list holding a fixed ``RetrievedChunk`` (score 0.91,
        page 1), independent of the query contents.
    """
    self.queries.append(query)
    return [
        RetrievedChunk(
            chunk_id="chunk-1",
            doc_id="doc-1",
            doc_name="测试法规",
            content="法规正文",
            score=0.91,
            section_title="第一章",
            page_number=1,
            metadata={"section_title": "第一章"},
        )
    ]
|
||||
|
||||
# 生成嵌入
|
||||
embedder = BGEM3Embedder()
|
||||
embeddings = embedder.embed([c.content for c in chunks])
|
||||
def search(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
    """Wrap the loose arguments in a ``RetrievalQuery`` and delegate to ``retrieve``.

    Args:
        query: Query text.
        top_k: Requested number of results, passed through as ``top_k``.
        filters: Optional filter expression, passed through as ``filters``.

    Returns:
        Whatever ``self.retrieve`` returns for the constructed query.
    """
    request = RetrievalQuery(query=query, top_k=top_k, filters=filters)
    return self.retrieve(request)
|
||||
|
||||
# 插入数据
|
||||
inserted_ids = client.insert_chunks(chunks, embeddings)
|
||||
assert len(inserted_ids) == 2
|
||||
|
||||
# 执行检索
|
||||
query = "机动车安全标准"
|
||||
query_embedding = embedder.embed_single(query)
|
||||
|
||||
results = client.hybrid_search(
|
||||
query_dense=query_embedding['dense'].tolist(),
|
||||
query_sparse=query_embedding['sparse'],
|
||||
top_k=2
|
||||
class FakeAnswerGenerator:
    """Answer-generator test double that produces a deterministic result."""

    def generate(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None = None,
        provider: str | None = None,
        model: str | None = None,
        prompt_template: str | None = None,
    ) -> AnswerResult:
        """Build an ``AnswerResult`` that echoes the query and mirrors the chunks.

        Args:
            query: Question text; echoed back inside the answer string.
            retrieved_chunks: Chunks to convert one-to-one into answer sources.
            history: Accepted but not used by this fake.
            provider: Accepted but not used by this fake.
            model: Model name to report; falls back to ``"deepseek-v4-flash"``
                when falsy.
            prompt_template: Accepted but not used by this fake.

        Returns:
            An ``AnswerResult`` with fixed ``latency_ms`` (12) and
            ``context_tokens`` (128), and ``retrieved_count`` equal to the
            number of chunks passed in.
        """
        # Copy every chunk field across so assertions on sources can be
        # traced straight back to the retriever output.
        sources = [
            AnswerSource(
                doc_id=item.doc_id,
                doc_name=item.doc_name,
                chunk_id=item.chunk_id,
                section_title=item.section_title,
                page_number=item.page_number,
                score=item.score,
                content=item.content,
                metadata=item.metadata,
            )
            for item in retrieved_chunks
        ]
        return AnswerResult(
            answer=f"回答: {query}",
            sources=sources,
            model=model or "deepseek-v4-flash",
            latency_ms=12,
            retrieved_count=len(retrieved_chunks),
            context_tokens=128,
        )
|
||||
|
||||
assert len(results) > 0
|
||||
assert "机动车" in results[0].content or "安全" in results[0].content
|
||||
def stream_generate(self, **kwargs):
    """Yield a fixed three-event stream: sources, content, done.

    Args:
        **kwargs: Forwarded unchanged to ``self.generate``; only the
            ``sources`` of its result are used.

    Yields:
        ``{"event": "sources", ...}`` carrying each source's attribute dict,
        then one ``"content"`` event with a fixed text, then a ``"done"``
        event whose ``retrieved_count`` is the constant 1.
    """
    answer = self.generate(**kwargs)
    # vars() returns the same attribute dict as ``source.__dict__``.
    sources = [vars(source) for source in answer.sources]
    yield {"event": "sources", "data": sources}
    yield {"event": "content", "data": "流式回答"}
    yield {"event": "done", "data": {"retrieved_count": 1}}
|
||||
|
||||
|
||||
class TestEmbedding:
|
||||
"""嵌入模型测试"""
|
||||
class FakeConversationStore:
    """Conversation-store test double backed by a plain dict."""

    def __init__(self) -> None:
        # Maps session_id -> session object; starts empty.
        self.sessions: dict[str, ConversationSession] = {}
|
||||
|
||||
def test_embed_single_text(self):
|
||||
"""测试单文本嵌入"""
|
||||
embedder = BGEM3Embedder()
|
||||
def create_session(self, metadata: dict | None = None) -> ConversationSession:
    """Create, store, and return a session with the fixed id ``"sess-1"``.

    Args:
        metadata: Optional session metadata; an empty dict is used when
            this is ``None`` or otherwise falsy.

    Returns:
        The new ``ConversationSession``. Because the id is constant, a
        second call replaces the previously stored session.
    """
    session = ConversationSession(
        session_id="sess-1",
        created_at=1,
        updated_at=1,
        metadata=metadata or {},
    )
    self.sessions[session.session_id] = session
    return session
|
||||
|
||||
result = embedder.embed_single("这是一条测试文本")
|
||||
def get_session(self, session_id: str) -> ConversationSession | None:
    """Return the stored session for ``session_id``, or ``None`` if absent."""
    return self.sessions.get(session_id)
|
||||
|
||||
assert 'dense' in result
|
||||
assert 'sparse' in result
|
||||
assert len(result['dense']) == 1024 # BGE-M3默认维度
|
||||
def save_message(self, session_id: str, *, role: str, content: str, sources: list[dict] | None = None):
    """Append a message to an existing session.

    Args:
        session_id: Id of the session to append to.
        role: Message role, stored as ``message.role``.
        content: Message text, stored as ``message.content``.
        sources: Optional source dicts; stored as ``message.sources``
            (an empty list when omitted) instead of being discarded.

    Returns:
        The updated session, or ``None`` when ``session_id`` is unknown.
    """
    # Local import: the module's import header is not fully visible here.
    from types import SimpleNamespace

    session = self.sessions.get(session_id)
    if session is None:
        return None
    # A plain namespace gives attribute access without minting a new
    # class object for every message.
    message = SimpleNamespace(role=role, content=content, sources=list(sources or []))
    session.messages.append(message)
    return session
|
||||
|
||||
def test_embed_batch(self):
|
||||
"""测试批量嵌入"""
|
||||
embedder = BGEM3Embedder()
|
||||
def delete_session(self, session_id: str) -> bool:
    """Remove a session and report whether one was actually stored.

    Returns:
        ``True`` if ``session_id`` mapped to a non-``None`` session that
        was removed, ``False`` otherwise.
    """
    removed = self.sessions.pop(session_id, None)
    return removed is not None
|
||||
|
||||
texts = [
|
||||
"第一条 本标准规定了机动车安全要求",
|
||||
"第二条 机动车应符合以下技术条件",
|
||||
"第三条 生产企业应建立质量管理体系"
|
||||
]
|
||||
|
||||
result = embedder.embed(texts)
|
||||
|
||||
assert len(result.dense_embeddings) == 3
|
||||
assert result.dense_embeddings.shape[1] == 1024
|
||||
def list_sessions(self) -> list[dict]:
    """Summarise every stored session as a dict.

    Returns:
        One dict per session, in the store's insertion order, with the
        keys ``session_id``, ``message_count``, ``created_at`` and
        ``updated_at``.
    """
    return [
        {
            "session_id": session_id,
            "message_count": len(session.messages),
            "created_at": session.created_at,
            "updated_at": session.updated_at,
        }
        for session_id, session in self.sessions.items()
    ]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
def test_knowledge_retrieval_service_builds_retrieval_query():
    """The service must forward query, top_k and filters to the retriever unchanged."""
    retriever = FakeRetriever()
    service = KnowledgeRetrievalService(retriever=retriever)

    results = service.retrieve(query="机动车安全", top_k=3, filters='doc_name == "测试法规"')

    # The fake retriever returns exactly one canned chunk.
    assert len(results) == 1
    # Inspect the query object the retriever actually received.
    recorded = retriever.queries[0]
    assert recorded.query == "机动车安全"
    assert recorded.top_k == 3
    assert recorded.filters == 'doc_name == "测试法规"'
|
||||
|
||||
|
||||
def test_agent_conversation_service_reuses_shared_retrieval_service():
    """A chat turn must go through the injected retrieval service and persist messages."""
    retriever = FakeRetriever()
    retrieval_service = KnowledgeRetrievalService(retriever=retriever)
    conversation_store = FakeConversationStore()
    service = AgentConversationService(
        retrieval_service=retrieval_service,
        answer_generator=FakeAnswerGenerator(),
        conversation_store=conversation_store,
    )

    session_id, result = service.chat(query="问一个问题", top_k=2, model="qwen3.5-flash")

    # The fake store always issues the id "sess-1".
    assert session_id == "sess-1"
    assert result.answer == "回答: 问一个问题"
    assert result.retrieved_count == 1
    # top_k must reach the retriever injected via the retrieval service.
    assert retriever.queries[0].top_k == 2
    # Presumably one user message plus one assistant message — confirm
    # against AgentConversationService.chat.
    assert len(conversation_store.sessions["sess-1"].messages) == 2
|
||||
|
||||
Reference in New Issue
Block a user