2026-05-18 16:32:42 +08:00
|
|
|
"""Implement infrastructure support for openai compatible answer generator."""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import time
|
|
|
|
|
from typing import Generator
|
|
|
|
|
|
|
|
|
|
from app.config.settings import settings
|
|
|
|
|
from app.domain.conversation import AnswerGenerator, AnswerResult, AnswerSource
|
|
|
|
|
from app.domain.retrieval import RetrievedChunk
|
|
|
|
|
from app.services.llm.llm_factory import get_llm_client
|
|
|
|
|
# Keep adapter behavior explicit so integration details remain easy to audit.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompts keyed by template name. The "default" entry is the fallback
# used when a requested template name is not present in this mapping.
#   default:       regulation Q&A assistant; answer only from the supplied
#                  context and say so explicitly when the context is insufficient.
#   compliance_qa: compliance Q&A assistant; prefer quoting the regulation text,
#                  be accurate and restrained, and cite the sources relied on.
PROMPT_TEMPLATES: dict[str, str] = dict(
    default="你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
    compliance_qa="你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
    """Answer generator that sends a RAG prompt to an OpenAI-compatible chat client.

    Retrieved chunks are formatted as numbered context blocks, limited by the
    estimated-token budget ``settings.rag_max_context_tokens``. The blocks are
    combined with the selected system prompt and any chat history. The result is
    sent through the client returned by ``get_llm_client``.
    """

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Estimate token count for mixed Chinese/English text.

        Heuristic: CJK Unified Ideographs (U+4E00..U+9FFF) count as ~1.5 chars
        per token; every other character counts as ~4 chars per token. The
        trailing ``+ 1`` keeps the estimate strictly positive, so even an empty
        string costs one token.
        """
        chinese = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
        other = len(text) - chinese
        return int(chinese / 1.5 + other / 4) + 1

    def _select_context_blocks(
        self, retrieved_chunks: list[RetrievedChunk]
    ) -> tuple[list[str], int]:
        """Format chunks as numbered context blocks until the token budget is reached.

        Returns ``(blocks, used_tokens)``: the formatted blocks in retrieval order
        and their summed estimated token count.

        Selection stops at the first block that would push the total over
        ``settings.rag_max_context_tokens``. Later, possibly smaller, chunks are
        not tried. The kept blocks are therefore always a prefix of
        ``retrieved_chunks``, so citation ``[n]`` corresponds to the n-th
        retrieved chunk.
        """
        budget = settings.rag_max_context_tokens
        blocks: list[str] = []
        used_tokens = 0
        for idx, chunk in enumerate(retrieved_chunks, start=1):
            block = (
                f"[{idx}] 文档: {chunk.doc_name}\n"
                f"章节: {chunk.section_title or '未标注'}\n"
                f"页码: {chunk.page_number}\n"
                f"内容: {chunk.content}"
            )
            block_tokens = self._estimate_tokens(block)
            if used_tokens + block_tokens > budget:
                break
            used_tokens += block_tokens
            blocks.append(block)
        return blocks, used_tokens

    def _prepare_prompt(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None,
        prompt_template: str | None,
    ) -> tuple[list[dict[str, str]], int, int]:
        """Build the chat messages and report how much retrieved context fit.

        Returns:
            ``(messages, context_tokens, included_count)``, where:

            * ``messages`` is ordered as the system prompt, then each history
              turn, then one user turn carrying the question and the context.
            * ``context_tokens`` is the estimated size of the included context.
            * ``included_count`` is how many leading retrieved chunks made it
              into the prompt.
        """
        # No template -> "compliance_qa"; an unknown template name -> "default".
        system_prompt = PROMPT_TEMPLATES.get(
            prompt_template or "compliance_qa", PROMPT_TEMPLATES["default"]
        )
        context_blocks, context_tokens = self._select_context_blocks(retrieved_chunks)
        context = "\n\n".join(context_blocks)

        messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}]
        # Forward only role/content; any other keys on history entries are dropped.
        messages.extend(
            {"role": item["role"], "content": item["content"]} for item in history or []
        )
        # User turn (Chinese): the question, the reference context, and a request
        # to append brief citation numbers after the answer.
        messages.append(
            {
                "role": "user",
                "content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
            }
        )
        return messages, context_tokens, len(context_blocks)

    def _build_messages(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None,
        prompt_template: str | None,
    ) -> tuple[list[dict[str, str]], int]:
        """Build the chat messages for a query.

        Returns ``(messages, context_tokens)``. This is kept for compatibility;
        use ``_prepare_prompt`` when the number of included chunks is also needed.
        """
        messages, context_tokens, _ = self._prepare_prompt(
            query=query,
            retrieved_chunks=retrieved_chunks,
            history=history,
            prompt_template=prompt_template,
        )
        return messages, context_tokens

    def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
        """Convert retrieved chunks into ``AnswerSource`` attributions.

        One source is produced per chunk, in retrieval order. This includes
        chunks that the context budget left out of the prompt.
        """
        return [
            AnswerSource(
                doc_id=chunk.doc_id,
                doc_name=chunk.doc_name,
                chunk_id=chunk.chunk_id,
                section_title=chunk.section_title,
                page_number=chunk.page_number,
                score=chunk.score,
                content=chunk.content,
                metadata=chunk.metadata,
            )
            for chunk in chunks
        ]

    def generate(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None = None,
        provider: str | None = None,
        model: str | None = None,
        prompt_template: str | None = None,
    ) -> AnswerResult:
        """Generate a complete answer for ``query`` in a single ``client.chat`` call.

        ``provider`` and ``model`` default to ``settings.llm_provider`` and
        ``settings.llm_model``.

        If the client reports failure (``response.is_success`` is false), the
        answer is the empty string and ``error`` carries ``response.error``.
        Exceptions raised by the client propagate to the caller.

        ``truncated`` is true when the context token budget left at least one
        retrieved chunk out of the prompt.
        """
        start = time.perf_counter()
        messages, context_tokens, included = self._prepare_prompt(
            query=query,
            retrieved_chunks=retrieved_chunks,
            history=history,
            prompt_template=prompt_template,
        )
        requested_model = model or settings.llm_model
        client = get_llm_client(
            provider=provider or settings.llm_provider, model=requested_model
        )
        response = client.chat(messages)
        latency_ms = int((time.perf_counter() - start) * 1000)
        return AnswerResult(
            answer=response.content if response.is_success else "",
            sources=self._sources(retrieved_chunks),
            model=response.model or requested_model,
            latency_ms=latency_ms,
            retrieved_count=len(retrieved_chunks),
            context_tokens=context_tokens,
            truncated=included < len(retrieved_chunks),
            error=response.error,
        )

    def stream_generate(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None = None,
        provider: str | None = None,
        model: str | None = None,
        prompt_template: str | None = None,
    ) -> Generator[dict, None, AnswerResult]:
        """Stream an answer as a sequence of event dicts.

        Yields, in order:

        * ``{"event": "sources", "data": [...]}``: the ``__dict__`` of each
          ``AnswerSource``, one per retrieved chunk;
        * zero or more ``{"event": "content", "data": <text piece>}``;
        * exactly one of ``{"event": "error", "data": <message>}`` or
          ``{"event": "done", "data": {...}}``. The ``done`` payload holds
          ``latency_ms``, ``retrieved_count``, ``context_tokens`` and ``model``.

        ``client.stream_chat`` is used when the client has it. Otherwise a single
        ``client.chat`` call is made. Exceptions raised while talking to the LLM,
        and failed (non-success) ``chat`` responses, become an ``error`` event.
        Errors from prompt building or client construction still propagate.

        The generator's return value is an ``AnswerResult`` summarising the
        stream, as the annotation declares. It is visible via ``yield from`` or
        ``StopIteration.value``; plain ``for`` loops ignore it.
        """
        start = time.perf_counter()
        messages, context_tokens, included = self._prepare_prompt(
            query=query,
            retrieved_chunks=retrieved_chunks,
            history=history,
            prompt_template=prompt_template,
        )
        sources = self._sources(retrieved_chunks)
        yield {"event": "sources", "data": [source.__dict__ for source in sources]}

        requested_model = model or settings.llm_model
        client = get_llm_client(
            provider=provider or settings.llm_provider, model=requested_model
        )
        answer_parts: list[str] = []
        answered_model = requested_model
        error: str | None = None
        try:
            if hasattr(client, "stream_chat"):
                for piece in client.stream_chat(messages):
                    answer_parts.append(piece)
                    yield {"event": "content", "data": piece}
            else:
                response = client.chat(messages)
                answered_model = response.model or requested_model
                if response.is_success:
                    answer_parts.append(response.content)
                    yield {"event": "content", "data": response.content}
                else:
                    # Same contract as generate(): a failed response is an error,
                    # not answer content.
                    error = response.error or "LLM request failed"
        except Exception as exc:  # stream boundary: report to the consumer instead of raising
            error = str(exc) or type(exc).__name__

        latency_ms = int((time.perf_counter() - start) * 1000)
        result = AnswerResult(
            answer="".join(answer_parts),
            sources=sources,
            model=answered_model,
            latency_ms=latency_ms,
            retrieved_count=len(retrieved_chunks),
            context_tokens=context_tokens,
            truncated=included < len(retrieved_chunks),
            error=error,
        )
        if error is not None:
            yield {"event": "error", "data": error}
            return result

        yield {
            "event": "done",
            "data": {
                "latency_ms": latency_ms,
                "retrieved_count": len(retrieved_chunks),
                "context_tokens": context_tokens,
                "model": answered_model,
            },
        }
        return result
|