Fix 法规对话

This commit is contained in:
2026-05-21 23:20:39 +08:00
parent 1b640f0084
commit bf6d47e1fd
11 changed files with 553 additions and 428 deletions

View File

@@ -12,7 +12,6 @@ from app.services.llm.llm_factory import get_llm_client
# Keep adapter behavior explicit so integration details remain easy to audit.
PROMPT_TEMPLATES = {
"default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
"compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
@@ -21,6 +20,17 @@ PROMPT_TEMPLATES = {
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
"""Represent the Open A I Compatible Answer Generator type."""
@staticmethod
def _estimate_tokens(text: str) -> int:
"""Estimate token count for mixed Chinese/English text.
Chinese chars are ~1.5 chars/token; ASCII is ~4 chars/token.
"""
chinese = sum(1 for c in text if "" <= c <= "鿿")
other = len(text) - chinese
return int(chinese / 1.5 + other / 4) + 1
def _build_messages(
self,
*,
@@ -40,9 +50,12 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
f"页码: {chunk.page_number}\n"
f"内容: {chunk.content}"
)
context_tokens += len(block)
block_tokens = self._estimate_tokens(block)
if context_tokens + block_tokens > settings.rag_max_context_tokens:
break
context_tokens += block_tokens
context_blocks.append(block)
context = "\n\n".join(context_blocks)[: settings.rag_max_context_tokens * 4]
context = "\n\n".join(context_blocks)
messages = [{"role": "system", "content": system_prompt}]
for item in history or []:
messages.append({"role": item["role"], "content": item["content"]})
@@ -52,7 +65,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
"content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
}
)
return messages, min(context_tokens, settings.rag_max_context_tokens)
return messages, context_tokens
def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
"""Handle sources for this module for the Open A I Compatible Answer Generator instance."""
@@ -98,7 +111,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
latency_ms=latency_ms,
retrieved_count=len(retrieved_chunks),
context_tokens=context_tokens,
truncated=False,
truncated=len(retrieved_chunks) > len(messages),
error=response.error,
)
@@ -124,15 +137,18 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
yield {"event": "sources", "data": sources}
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
answer_parts: list[str] = []
if hasattr(client, "stream_chat"):
for chunk in client.stream_chat(messages):
answer_parts.append(chunk)
yield {"event": "content", "data": chunk}
else:
response = client.chat(messages)
answer_parts.append(response.content)
yield {"event": "content", "data": response.content}
full_answer = "".join(answer_parts)
try:
if hasattr(client, "stream_chat"):
for chunk in client.stream_chat(messages):
answer_parts.append(chunk)
yield {"event": "content", "data": chunk}
else:
response = client.chat(messages)
answer_parts.append(response.content)
yield {"event": "content", "data": response.content}
except Exception as exc:
yield {"event": "error", "data": str(exc)}
return
yield {
"event": "done",
"data": {