- Removed multiple failed document entries from `documents.json`.
- Added a new document entry with updated metadata and changed the index name to `regulations_dense_1024_v2`.
- Updated architecture documentation to reflect changes in the Milvus collection name.
- Adjusted requirements by removing the sqlalchemy dependency.
- Modified test cases to align with new document structure and naming conventions.
- Introduced a new test file for Milvus vector index runtime recovery and error handling.
- Updated assertions in various test files to ensure compatibility with the new schema.
184 lines
7.5 KiB
Python
184 lines
7.5 KiB
Python
"""Implement infrastructure support for openai compatible answer generator."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from typing import Generator
|
|
|
|
from app.config.settings import settings
|
|
from app.domain.conversation import AnswerGenerator, AnswerResult, AnswerSource
|
|
from app.domain.retrieval import RetrievedChunk
|
|
from app.services.llm.llm_factory import get_llm_client
|
|
# Keep adapter behavior explicit so integration details remain easy to audit.
|
|
|
|
|
|
# System prompt used when a requested template name is not registered.
_DEFAULT_SYSTEM_PROMPT = "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。"

# System prompt for compliance Q&A: asks the model to quote the supplied
# regulation text and cite its sources.
_COMPLIANCE_QA_SYSTEM_PROMPT = "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。"

# Registry mapping a prompt-template name to the system prompt sent to the LLM.
PROMPT_TEMPLATES: dict[str, str] = {
    "default": _DEFAULT_SYSTEM_PROMPT,
    "compliance_qa": _COMPLIANCE_QA_SYSTEM_PROMPT,
}
|
|
|
|
|
|
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
    """Generate answers from retrieved chunks through an LLM chat client.

    For every request the client is obtained from ``get_llm_client``.
    Retrieved chunks are rendered into numbered context blocks, and the
    context is capped at ``settings.rag_max_context_tokens`` estimated tokens.
    """

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Estimate the token count of mixed Chinese/English text.

        Characters in the range U+4E00..U+9FFF are counted at roughly
        1.5 characters per token; every other character at roughly 4
        characters per token. The result is always at least 1.
        """
        chinese = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
        other = len(text) - chinese
        return int(chinese / 1.5 + other / 4) + 1

    @staticmethod
    def _format_context_block(idx: int, chunk: RetrievedChunk) -> str:
        """Render one retrieved chunk as a numbered context block.

        Shared by ``_build_messages`` and ``_is_context_truncated`` so both
        always estimate tokens over exactly the same text.
        """
        page_range = f"{chunk.page_start}"
        # Only show an end page when it exists and differs from the start page.
        if chunk.page_end and chunk.page_end != chunk.page_start:
            page_range += f"-{chunk.page_end}"
        return (
            f"[{idx}] 文档: {chunk.doc_title}\n"
            f"章节: {chunk.section_title or '未标注'}\n"
            f"页码: {page_range}\n"
            f"内容: {chunk.text}"
        )

    def _build_messages(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None,
        prompt_template: str | None,
    ) -> tuple[list[dict[str, str]], int]:
        """Build the chat messages for a query and its retrieved context.

        Args:
            query: The user question.
            retrieved_chunks: Chunks to render as context, in the given order.
            history: Earlier turns; each item must provide ``role`` and
                ``content`` keys.
            prompt_template: Key into ``PROMPT_TEMPLATES``. ``None`` selects
                ``"compliance_qa"``; an unknown key falls back to ``"default"``.

        Returns:
            A ``(messages, context_tokens)`` tuple, where ``context_tokens``
            is the estimated token total of the context blocks included.
        """
        system_prompt = PROMPT_TEMPLATES.get(
            prompt_template or "compliance_qa",
            PROMPT_TEMPLATES["default"],
        )

        context_blocks: list[str] = []
        context_tokens = 0
        for idx, chunk in enumerate(retrieved_chunks, start=1):
            block = self._format_context_block(idx, chunk)
            block_tokens = self._estimate_tokens(block)
            # Stop at the first block that would exceed the budget; later
            # chunks are dropped as well, so the included ones stay contiguous.
            if context_tokens + block_tokens > settings.rag_max_context_tokens:
                break
            context_tokens += block_tokens
            context_blocks.append(block)
        context = "\n\n".join(context_blocks)

        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(
            {"role": item["role"], "content": item["content"]}
            for item in history or []
        )
        messages.append(
            {
                "role": "user",
                "content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
            }
        )
        return messages, context_tokens

    def _is_context_truncated(self, *, retrieved_chunks: list[RetrievedChunk], context_tokens: int) -> bool:
        """Return whether retrieved chunks were omitted to fit the token budget.

        Compares the estimated token total of all chunks, rendered the same
        way as in ``_build_messages``, with the total actually included.
        """
        if not retrieved_chunks:
            return False
        estimated_total_tokens = sum(
            self._estimate_tokens(self._format_context_block(idx, chunk))
            for idx, chunk in enumerate(retrieved_chunks, start=1)
        )
        return estimated_total_tokens > context_tokens

    def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
        """Convert retrieved chunks into ``AnswerSource`` records, field by field."""
        return [
            AnswerSource(
                doc_id=chunk.doc_id,
                doc_title=chunk.doc_title,
                chunk_id=chunk.chunk_id,
                chunk_type=chunk.chunk_type,
                section_title=chunk.section_title,
                page_start=chunk.page_start,
                page_end=chunk.page_end,
                section_level=chunk.section_level,
                chunk_index=chunk.chunk_index,
                piece_index=chunk.piece_index,
                score=chunk.score,
                text=chunk.text,
                metadata=chunk.metadata,
            )
            for chunk in chunks
        ]

    def generate(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None = None,
        provider: str | None = None,
        model: str | None = None,
        prompt_template: str | None = None,
    ) -> AnswerResult:
        """Generate a complete answer with a single chat call.

        ``provider`` and ``model`` default to ``settings.llm_provider`` and
        ``settings.llm_model``. When the response is not successful the
        answer is an empty string and ``error`` carries ``response.error``.
        """
        # perf_counter is monotonic, so latency cannot go negative or jump
        # when the wall clock is adjusted.
        start = time.perf_counter()
        messages, context_tokens = self._build_messages(
            query=query,
            retrieved_chunks=retrieved_chunks,
            history=history,
            prompt_template=prompt_template,
        )
        client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
        response = client.chat(messages)
        latency_ms = int((time.perf_counter() - start) * 1000)
        return AnswerResult(
            answer=response.content if response.is_success else "",
            sources=self._sources(retrieved_chunks),
            model=response.model or (model or settings.llm_model),
            latency_ms=latency_ms,
            retrieved_count=len(retrieved_chunks),
            context_tokens=context_tokens,
            truncated=self._is_context_truncated(
                retrieved_chunks=retrieved_chunks,
                context_tokens=context_tokens,
            ),
            error=response.error,
        )

    def stream_generate(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None = None,
        provider: str | None = None,
        model: str | None = None,
        prompt_template: str | None = None,
    ) -> Generator[dict, None, AnswerResult]:
        """Stream an answer as a sequence of event dictionaries.

        Yields, in order:
            ``{"event": "sources", "data": [...]}`` once, before the LLM call;
            ``{"event": "content", "data": <text>}`` for each piece of answer;
            then either ``{"event": "done", "data": {...}}`` on success or
            ``{"event": "error", "data": <message>}`` on failure.

        Returns:
            The aggregated ``AnswerResult`` as the generator's return value
            (available through ``yield from`` or ``StopIteration.value``).
            On failure, ``answer`` holds whatever text was streamed before
            the error and ``error`` holds the message.
        """
        start = time.perf_counter()
        messages, context_tokens = self._build_messages(
            query=query,
            retrieved_chunks=retrieved_chunks,
            history=history,
            prompt_template=prompt_template,
        )
        sources = self._sources(retrieved_chunks)
        yield {"event": "sources", "data": [source.__dict__ for source in sources]}

        resolved_model = model or settings.llm_model
        client = get_llm_client(provider=provider or settings.llm_provider, model=resolved_model)
        answer_parts: list[str] = []
        error: str | None = None
        # Broad catch is deliberate: this is the streaming boundary, and any
        # client failure is reported to the consumer as an "error" event.
        try:
            if hasattr(client, "stream_chat"):
                for chunk in client.stream_chat(messages):
                    answer_parts.append(chunk)
                    yield {"event": "content", "data": chunk}
            else:
                # Client without streaming support: emit the whole answer as
                # one content event, honouring the success flag as generate().
                response = client.chat(messages)
                if response.is_success:
                    answer_parts.append(response.content)
                    yield {"event": "content", "data": response.content}
                else:
                    error = str(response.error) if response.error else "LLM request failed"
        except Exception as exc:
            error = str(exc)

        latency_ms = int((time.perf_counter() - start) * 1000)
        result = AnswerResult(
            # Ignore non-text pieces so aggregation can never break the stream.
            answer="".join(part for part in answer_parts if isinstance(part, str)),
            sources=sources,
            model=resolved_model,
            latency_ms=latency_ms,
            retrieved_count=len(retrieved_chunks),
            context_tokens=context_tokens,
            truncated=self._is_context_truncated(
                retrieved_chunks=retrieved_chunks,
                context_tokens=context_tokens,
            ),
            error=error,
        )

        if error is not None:
            yield {"event": "error", "data": error}
            return result

        yield {
            "event": "done",
            "data": {
                "latency_ms": latency_ms,
                "retrieved_count": len(retrieved_chunks),
                "context_tokens": context_tokens,
                "model": resolved_model,
            },
        }
        return result
|