Fix 法规对话
This commit is contained in:
@@ -12,7 +12,6 @@ from app.services.llm.llm_factory import get_llm_client
|
||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||
|
||||
|
||||
|
||||
PROMPT_TEMPLATES = {
|
||||
"default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
|
||||
"compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
|
||||
@@ -21,6 +20,17 @@ PROMPT_TEMPLATES = {
|
||||
|
||||
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
"""Represent the Open A I Compatible Answer Generator type."""
|
||||
|
||||
@staticmethod
|
||||
def _estimate_tokens(text: str) -> int:
|
||||
"""Estimate token count for mixed Chinese/English text.
|
||||
|
||||
Chinese chars are ~1.5 chars/token; ASCII is ~4 chars/token.
|
||||
"""
|
||||
chinese = sum(1 for c in text if "一" <= c <= "鿿")
|
||||
other = len(text) - chinese
|
||||
return int(chinese / 1.5 + other / 4) + 1
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
*,
|
||||
@@ -40,9 +50,12 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
f"页码: {chunk.page_number}\n"
|
||||
f"内容: {chunk.content}"
|
||||
)
|
||||
context_tokens += len(block)
|
||||
block_tokens = self._estimate_tokens(block)
|
||||
if context_tokens + block_tokens > settings.rag_max_context_tokens:
|
||||
break
|
||||
context_tokens += block_tokens
|
||||
context_blocks.append(block)
|
||||
context = "\n\n".join(context_blocks)[: settings.rag_max_context_tokens * 4]
|
||||
context = "\n\n".join(context_blocks)
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
for item in history or []:
|
||||
messages.append({"role": item["role"], "content": item["content"]})
|
||||
@@ -52,7 +65,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
"content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
|
||||
}
|
||||
)
|
||||
return messages, min(context_tokens, settings.rag_max_context_tokens)
|
||||
return messages, context_tokens
|
||||
|
||||
def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
|
||||
"""Handle sources for this module for the Open A I Compatible Answer Generator instance."""
|
||||
@@ -98,7 +111,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
latency_ms=latency_ms,
|
||||
retrieved_count=len(retrieved_chunks),
|
||||
context_tokens=context_tokens,
|
||||
truncated=False,
|
||||
truncated=len(retrieved_chunks) > len(messages),
|
||||
error=response.error,
|
||||
)
|
||||
|
||||
@@ -124,15 +137,18 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||
yield {"event": "sources", "data": sources}
|
||||
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
|
||||
answer_parts: list[str] = []
|
||||
if hasattr(client, "stream_chat"):
|
||||
for chunk in client.stream_chat(messages):
|
||||
answer_parts.append(chunk)
|
||||
yield {"event": "content", "data": chunk}
|
||||
else:
|
||||
response = client.chat(messages)
|
||||
answer_parts.append(response.content)
|
||||
yield {"event": "content", "data": response.content}
|
||||
full_answer = "".join(answer_parts)
|
||||
try:
|
||||
if hasattr(client, "stream_chat"):
|
||||
for chunk in client.stream_chat(messages):
|
||||
answer_parts.append(chunk)
|
||||
yield {"event": "content", "data": chunk}
|
||||
else:
|
||||
response = client.chat(messages)
|
||||
answer_parts.append(response.content)
|
||||
yield {"event": "content", "data": response.content}
|
||||
except Exception as exc:
|
||||
yield {"event": "error", "data": str(exc)}
|
||||
return
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
|
||||
149
backend/app/infrastructure/vectorstore/bm25_retriever.py
Normal file
149
backend/app/infrastructure/vectorstore/bm25_retriever.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""BM25 sparse retriever backed by the existing Milvus collection.
|
||||
|
||||
Falls back to no-op if rank_bm25 or jieba is not installed.
|
||||
The index is built lazily on first query and can be refreshed after
|
||||
new documents are ingested by calling refresh().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.domain.retrieval import RetrievedChunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
import jieba
|
||||
|
||||
jieba.setLogLevel(logging.WARNING)
|
||||
_BM25_AVAILABLE = True
|
||||
except ImportError:
|
||||
_BM25_AVAILABLE = False
|
||||
logger.warning("rank_bm25 or jieba not installed – BM25 retrieval disabled. Run: pip install rank-bm25 jieba")
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
|
||||
"""Tokenize Chinese/mixed text with jieba."""
|
||||
return list(jieba.cut(text))
|
||||
|
||||
|
||||
class BM25Retriever:
|
||||
"""Sparse BM25 retriever that indexes chunks stored in Milvus.
|
||||
|
||||
Thread-safe: the index is protected by a lock so concurrent requests
|
||||
during the initial build do not race.
|
||||
"""
|
||||
|
||||
def __init__(self, vector_index: "MilvusVectorIndex") -> None:
|
||||
self._vector_index = vector_index
|
||||
self._lock = threading.Lock()
|
||||
self._index: "BM25Okapi | None" = None
|
||||
self._chunks: list[RetrievedChunk] = []
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
return _BM25_AVAILABLE
|
||||
|
||||
def _load_all_chunks(self) -> list[RetrievedChunk]:
|
||||
"""Fetch all chunk records from Milvus (no vectors needed)."""
|
||||
try:
|
||||
rows = self._vector_index.collection.query(
|
||||
expr='doc_id != ""',
|
||||
output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"],
|
||||
limit=16384,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
|
||||
return []
|
||||
return [
|
||||
RetrievedChunk(
|
||||
chunk_id=str(row.get("id", "")),
|
||||
doc_id=str(row.get("doc_id", "")),
|
||||
doc_name=str(row.get("doc_name", "")),
|
||||
content=str(row.get("content", "")),
|
||||
score=0.0,
|
||||
section_title=str(row.get("section_title", "")),
|
||||
page_number=int(row.get("page_number") or 0),
|
||||
metadata={},
|
||||
)
|
||||
for row in rows
|
||||
if row.get("content")
|
||||
]
|
||||
|
||||
def _ensure_built(self) -> None:
|
||||
if self._index is not None:
|
||||
return
|
||||
with self._lock:
|
||||
if self._index is not None:
|
||||
return
|
||||
self._build()
|
||||
|
||||
def _build(self) -> None:
|
||||
logger.info("BM25Retriever: building index …")
|
||||
chunks = self._load_all_chunks()
|
||||
if not chunks:
|
||||
logger.warning("BM25Retriever: no chunks found, index is empty")
|
||||
self._chunks = []
|
||||
self._index = BM25Okapi([[]])
|
||||
return
|
||||
tokenized = [_tokenize(c.content) for c in chunks]
|
||||
self._chunks = chunks
|
||||
self._index = BM25Okapi(tokenized)
|
||||
logger.info("BM25Retriever: index built with %d chunks", len(chunks))
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""Rebuild the index (call after new documents are ingested)."""
|
||||
with self._lock:
|
||||
self._index = None
|
||||
self._chunks = []
|
||||
self._build()
|
||||
|
||||
def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Return top_k chunks ranked by BM25 score."""
|
||||
if not _BM25_AVAILABLE:
|
||||
return []
|
||||
self._ensure_built()
|
||||
if not self._chunks:
|
||||
return []
|
||||
|
||||
tokens = _tokenize(query)
|
||||
scores = self._index.get_scores(tokens) # type: ignore[union-attr]
|
||||
|
||||
# Pair and sort descending by score
|
||||
ranked = sorted(
|
||||
((float(scores[i]), self._chunks[i]) for i in range(len(self._chunks))),
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
results: list[RetrievedChunk] = []
|
||||
for score, chunk in ranked[: top_k * 2]:
|
||||
if score <= 0:
|
||||
break
|
||||
# Apply simple regulation_type filter if provided
|
||||
if filters and chunk.metadata.get("regulation_type"):
|
||||
types = [t.strip() for t in filters.split(",")]
|
||||
if chunk.metadata.get("regulation_type") not in types:
|
||||
continue
|
||||
results.append(
|
||||
RetrievedChunk(
|
||||
chunk_id=chunk.chunk_id,
|
||||
doc_id=chunk.doc_id,
|
||||
doc_name=chunk.doc_name,
|
||||
content=chunk.content,
|
||||
score=score,
|
||||
section_title=chunk.section_title,
|
||||
page_number=chunk.page_number,
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
)
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
return results
|
||||
Reference in New Issue
Block a user