Fix 法规对话

This commit is contained in:
2026-05-21 23:20:39 +08:00
parent 1b640f0084
commit bf6d47e1fd
11 changed files with 553 additions and 428 deletions

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
import json
from typing import AsyncGenerator, List, Optional
@@ -20,6 +19,7 @@ from app.api.models import (
SessionInfo,
)
from app.config.settings import settings
from app.shared.async_utils import iter_in_thread
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
# Keep route handlers close to their transport-layer wiring for easier auditing.
@@ -101,14 +101,13 @@ async def chat_stream_get(
top_k=settings.rag_top_k,
)
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
for event_data in event_stream:
async for event_data in iter_in_thread(event_stream):
event_type = event_data.get("event", "content")
data = event_data.get("data", "")
if isinstance(data, (dict, list)):
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
else:
yield f"event: {event_type}\ndata: {data}\n\n"
await asyncio.sleep(0)
except Exception as exc:
yield f"event: error\ndata: {str(exc)}\n\n"

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
import json
from typing import AsyncGenerator
@@ -11,6 +10,7 @@ from fastapi.responses import StreamingResponse
from app.config.settings import settings
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.shared.async_utils import iter_in_thread
from app.shared.bootstrap import get_agent_conversation_service
@@ -29,7 +29,7 @@ _DEFAULT_QUICK_QUESTIONS = [
@router.post("/chat")
async def rag_chat(request: RagChatRequest):
"""Stream RAG Q&A using the real agent service."""
_, event_stream = get_agent_conversation_service().stream_chat(
session_id, event_stream = get_agent_conversation_service().stream_chat(
query=request.query,
session_id=request.session_id,
filters=request.filters,
@@ -38,7 +38,11 @@ async def rag_chat(request: RagChatRequest):
async def generate() -> AsyncGenerator[str, None]:
"""Translate agent SSE events to rag format."""
for event in event_stream:
yield (
"event: message\n"
f"data: {json.dumps({'type': 'session', 'session_id': session_id}, ensure_ascii=False)}\n\n"
)
async for event in iter_in_thread(event_stream):
event_type = event.get("event", "")
data = event.get("data", "")
if event_type == "sources":
@@ -69,14 +73,18 @@ async def rag_chat(request: RagChatRequest):
elif event_type == "done":
yield (
"event: message\n"
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
f"data: {json.dumps({'type': 'done', 'session_id': session_id}, ensure_ascii=False)}\n\n"
)
elif event_type == "status":
yield (
"event: message\n"
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
)
await asyncio.sleep(0)
elif event_type == "error":
yield (
"event: message\n"
f"data: {json.dumps({'type': 'error', 'text': str(data)}, ensure_ascii=False)}\n\n"
)
return StreamingResponse(
generate(),

View File

@@ -7,20 +7,75 @@ from app.domain.retrieval.ports import Reranker
# Keep orchestration logic centralized so use-case flow stays easy to trace.
def _reciprocal_rank_fusion(
ranked_lists: list[list[RetrievedChunk]], k: int = 60
) -> list[RetrievedChunk]:
"""Merge multiple ranked lists with Reciprocal Rank Fusion.
Score for chunk c = sum over lists of 1 / (k + rank(c)).
A chunk appearing in multiple lists gets a higher fused score.
"""
scores: dict[str, float] = {}
chunk_map: dict[str, RetrievedChunk] = {}
for ranked in ranked_lists:
for rank, chunk in enumerate(ranked):
key = chunk.chunk_id
scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank + 1)
chunk_map[key] = chunk
sorted_keys = sorted(scores, key=lambda ck: scores[ck], reverse=True)
return [
RetrievedChunk(
chunk_id=chunk_map[ck].chunk_id,
doc_id=chunk_map[ck].doc_id,
doc_name=chunk_map[ck].doc_name,
content=chunk_map[ck].content,
score=scores[ck],
section_title=chunk_map[ck].section_title,
page_number=chunk_map[ck].page_number,
metadata=chunk_map[ck].metadata,
)
for ck in sorted_keys
]
class KnowledgeRetrievalService:
"""Provide the Knowledge Retrieval Service service."""
def __init__(self, *, retriever: Retriever, reranker: Reranker | None = None, reranker_top_k: int = 5) -> None:
def __init__(
self,
*,
retriever: Retriever,
bm25_retriever=None,
reranker: Reranker | None = None,
reranker_top_k: int = 5,
) -> None:
"""Initialize the Knowledge Retrieval Service instance."""
self.retriever = retriever
self.bm25_retriever = bm25_retriever
self.reranker = reranker
self.reranker_top_k = reranker_top_k
def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Retrieve and optionally rerank chunks for a query."""
candidate_k = top_k if self.reranker is None else max(top_k * 4, 20)
"""Retrieve and optionally rerank chunks for a query.
When a BM25 retriever is available, combines dense + sparse results
via Reciprocal Rank Fusion before optional reranking.
"""
use_hybrid = self.bm25_retriever is not None and getattr(self.bm25_retriever, "available", False)
candidate_k = max(top_k * 4, 20) if (self.reranker is not None or use_hybrid) else top_k
retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters)
candidates = self.retriever.retrieve(retrieval_query)
dense_results = self.retriever.retrieve(retrieval_query)
if use_hybrid:
bm25_results = self.bm25_retriever.retrieve(query, top_k=candidate_k, filters=filters)
candidates = _reciprocal_rank_fusion([dense_results, bm25_results])
else:
candidates = dense_results
if self.reranker and candidates:
return self.reranker.rerank(query, candidates, top_k=self.reranker_top_k)
return candidates[:top_k]

View File

@@ -12,7 +12,6 @@ from app.services.llm.llm_factory import get_llm_client
# Keep adapter behavior explicit so integration details remain easy to audit.
PROMPT_TEMPLATES = {
"default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
"compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
@@ -21,6 +20,17 @@ PROMPT_TEMPLATES = {
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
"""Represent the Open A I Compatible Answer Generator type."""
@staticmethod
def _estimate_tokens(text: str) -> int:
"""Estimate token count for mixed Chinese/English text.
Chinese chars are ~1.5 chars/token; ASCII is ~4 chars/token.
"""
chinese = sum(1 for c in text if "" <= c <= "鿿")
other = len(text) - chinese
return int(chinese / 1.5 + other / 4) + 1
def _build_messages(
self,
*,
@@ -40,9 +50,12 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
f"页码: {chunk.page_number}\n"
f"内容: {chunk.content}"
)
context_tokens += len(block)
block_tokens = self._estimate_tokens(block)
if context_tokens + block_tokens > settings.rag_max_context_tokens:
break
context_tokens += block_tokens
context_blocks.append(block)
context = "\n\n".join(context_blocks)[: settings.rag_max_context_tokens * 4]
context = "\n\n".join(context_blocks)
messages = [{"role": "system", "content": system_prompt}]
for item in history or []:
messages.append({"role": item["role"], "content": item["content"]})
@@ -52,7 +65,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
"content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
}
)
return messages, min(context_tokens, settings.rag_max_context_tokens)
return messages, context_tokens
def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
"""Handle sources for this module for the Open A I Compatible Answer Generator instance."""
@@ -98,7 +111,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
latency_ms=latency_ms,
retrieved_count=len(retrieved_chunks),
context_tokens=context_tokens,
truncated=False,
truncated=len(retrieved_chunks) > len(messages),
error=response.error,
)
@@ -124,15 +137,18 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
yield {"event": "sources", "data": sources}
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
answer_parts: list[str] = []
if hasattr(client, "stream_chat"):
for chunk in client.stream_chat(messages):
answer_parts.append(chunk)
yield {"event": "content", "data": chunk}
else:
response = client.chat(messages)
answer_parts.append(response.content)
yield {"event": "content", "data": response.content}
full_answer = "".join(answer_parts)
try:
if hasattr(client, "stream_chat"):
for chunk in client.stream_chat(messages):
answer_parts.append(chunk)
yield {"event": "content", "data": chunk}
else:
response = client.chat(messages)
answer_parts.append(response.content)
yield {"event": "content", "data": response.content}
except Exception as exc:
yield {"event": "error", "data": str(exc)}
return
yield {
"event": "done",
"data": {

View File

@@ -0,0 +1,149 @@
"""BM25 sparse retriever backed by the existing Milvus collection.
Falls back to no-op if rank_bm25 or jieba is not installed.
The index is built lazily on first query and can be refreshed after
new documents are ingested by calling refresh().
"""
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING
from app.domain.retrieval import RetrievedChunk
if TYPE_CHECKING:
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
logger = logging.getLogger(__name__)
try:
from rank_bm25 import BM25Okapi
import jieba
jieba.setLogLevel(logging.WARNING)
_BM25_AVAILABLE = True
except ImportError:
_BM25_AVAILABLE = False
logger.warning("rank_bm25 or jieba not installed BM25 retrieval disabled. Run: pip install rank-bm25 jieba")
def _tokenize(text: str) -> list[str]:
"""Tokenize Chinese/mixed text with jieba."""
return list(jieba.cut(text))
class BM25Retriever:
"""Sparse BM25 retriever that indexes chunks stored in Milvus.
Thread-safe: the index is protected by a lock so concurrent requests
during the initial build do not race.
"""
def __init__(self, vector_index: "MilvusVectorIndex") -> None:
self._vector_index = vector_index
self._lock = threading.Lock()
self._index: "BM25Okapi | None" = None
self._chunks: list[RetrievedChunk] = []
@property
def available(self) -> bool:
return _BM25_AVAILABLE
def _load_all_chunks(self) -> list[RetrievedChunk]:
"""Fetch all chunk records from Milvus (no vectors needed)."""
try:
rows = self._vector_index.collection.query(
expr='doc_id != ""',
output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"],
limit=16384,
)
except Exception:
logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
return []
return [
RetrievedChunk(
chunk_id=str(row.get("id", "")),
doc_id=str(row.get("doc_id", "")),
doc_name=str(row.get("doc_name", "")),
content=str(row.get("content", "")),
score=0.0,
section_title=str(row.get("section_title", "")),
page_number=int(row.get("page_number") or 0),
metadata={},
)
for row in rows
if row.get("content")
]
def _ensure_built(self) -> None:
if self._index is not None:
return
with self._lock:
if self._index is not None:
return
self._build()
def _build(self) -> None:
logger.info("BM25Retriever: building index …")
chunks = self._load_all_chunks()
if not chunks:
logger.warning("BM25Retriever: no chunks found, index is empty")
self._chunks = []
self._index = BM25Okapi([[]])
return
tokenized = [_tokenize(c.content) for c in chunks]
self._chunks = chunks
self._index = BM25Okapi(tokenized)
logger.info("BM25Retriever: index built with %d chunks", len(chunks))
def refresh(self) -> None:
"""Rebuild the index (call after new documents are ingested)."""
with self._lock:
self._index = None
self._chunks = []
self._build()
def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Return top_k chunks ranked by BM25 score."""
if not _BM25_AVAILABLE:
return []
self._ensure_built()
if not self._chunks:
return []
tokens = _tokenize(query)
scores = self._index.get_scores(tokens) # type: ignore[union-attr]
# Pair and sort descending by score
ranked = sorted(
((float(scores[i]), self._chunks[i]) for i in range(len(self._chunks))),
key=lambda x: x[0],
reverse=True,
)
results: list[RetrievedChunk] = []
for score, chunk in ranked[: top_k * 2]:
if score <= 0:
break
# Apply simple regulation_type filter if provided
if filters and chunk.metadata.get("regulation_type"):
types = [t.strip() for t in filters.split(",")]
if chunk.metadata.get("regulation_type") not in types:
continue
results.append(
RetrievedChunk(
chunk_id=chunk.chunk_id,
doc_id=chunk.doc_id,
doc_name=chunk.doc_name,
content=chunk.content,
score=score,
section_title=chunk.section_title,
page_number=chunk.page_number,
metadata=chunk.metadata,
)
)
if len(results) >= top_k:
break
return results

View File

@@ -0,0 +1,38 @@
"""Async utility helpers for bridging sync generators to async FastAPI routes."""
from __future__ import annotations
import asyncio
import threading
from typing import AsyncGenerator, Generator, TypeVar
T = TypeVar("T")
_SENTINEL = object()
async def iter_in_thread(sync_gen: Generator[T, None, None]) -> AsyncGenerator[T, None]:
"""Yield items from a synchronous generator without blocking the event loop.
Runs the generator in a daemon thread, forwarding items via an asyncio.Queue.
Use this to wrap blocking LLM streaming generators inside async FastAPI routes.
"""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue(maxsize=32)
def _drain() -> None:
try:
for item in sync_gen:
future = asyncio.run_coroutine_threadsafe(queue.put(item), loop)
future.result()
finally:
asyncio.run_coroutine_threadsafe(queue.put(_SENTINEL), loop).result()
thread = threading.Thread(target=_drain, daemon=True)
thread.start()
while True:
item = await queue.get()
if item is _SENTINEL:
break
yield item # type: ignore[misc]

View File

@@ -19,6 +19,7 @@ from app.infrastructure.storage.json_document_repository import JsonDocumentRepo
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
from app.infrastructure.vectorstore.bm25_retriever import BM25Retriever
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
@@ -87,6 +88,13 @@ def get_reranker():
return None
@lru_cache
def get_bm25_retriever() -> BM25Retriever | None:
"""Return BM25 retriever if rank_bm25 + jieba are installed, else None."""
retriever = BM25Retriever(vector_index=get_vector_index())
return retriever if retriever.available else None
@lru_cache
def get_retrieval_service() -> KnowledgeRetrievalService:
"""Return retrieval service."""
@@ -96,6 +104,7 @@ def get_retrieval_service() -> KnowledgeRetrievalService:
)
return KnowledgeRetrievalService(
retriever=retriever,
bm25_retriever=get_bm25_retriever(),
reranker=get_reranker(),
reranker_top_k=settings.reranker_top_k,
)