"""OpenAI-compatible answer generator for retrieval-augmented question answering."""

from __future__ import annotations

import time
from typing import Generator

from app.config.settings import settings
from app.domain.conversation import AnswerGenerator, AnswerResult, AnswerSource
from app.domain.retrieval import RetrievedChunk
from app.services.llm.llm_factory import get_llm_client

# Keep adapter behavior explicit so integration details remain easy to audit.
# System prompts keyed by template name; unknown names fall back to "default".
PROMPT_TEMPLATES = {
    "default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
    "compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
}


class OpenAICompatibleAnswerGenerator(AnswerGenerator):
    """Answer questions from retrieved chunks via the client returned by ``get_llm_client``."""

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Estimate the token count of mixed Chinese/English text.

        Heuristic only: characters in the ``"一"``..``"鿿"`` range are counted at
        ~1.5 characters per token, every other character at ~4 characters per
        token. The result is always at least 1, including for empty text.
        """
        chinese = sum(1 for c in text if "一" <= c <= "鿿")
        other = len(text) - chinese
        return int(chinese / 1.5 + other / 4) + 1

    @staticmethod
    def _format_chunk(idx: int, chunk: RetrievedChunk) -> str:
        """Render one retrieved chunk as a numbered context block for the prompt.

        This is the single source of truth for the block layout, so the token
        budgeting in ``_build_messages`` and the truncation check in
        ``_is_context_truncated`` always measure the same text.
        """
        return (
            f"[{idx}] 文档: {chunk.doc_name}\n"
            f"章节: {chunk.section_title or '未标注'}\n"
            f"页码: {chunk.page_number}\n"
            f"内容: {chunk.content}"
        )

    def _build_messages(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None,
        prompt_template: str | None,
    ) -> tuple[list[dict[str, str]], int]:
        """Build the chat messages and report how many context tokens were used.

        Args:
            query: The user's question.
            retrieved_chunks: Chunks to offer as context, in the order given.
            history: Prior turns as dicts with ``"role"`` and ``"content"`` keys,
                or ``None`` for no history.
            prompt_template: Key into ``PROMPT_TEMPLATES``. ``None`` or an empty
                value selects ``"compliance_qa"``; an unknown key selects
                ``"default"``.

        Returns:
            ``(messages, context_tokens)`` where ``messages`` is the system
            prompt, the history, and a final user message embedding the context,
            and ``context_tokens`` is the estimated size of the included blocks.
        """
        system_prompt = PROMPT_TEMPLATES.get(
            prompt_template or "compliance_qa",
            PROMPT_TEMPLATES["default"],
        )

        context_blocks: list[str] = []
        context_tokens = 0
        for idx, chunk in enumerate(retrieved_chunks, start=1):
            block = self._format_chunk(idx, chunk)
            block_tokens = self._estimate_tokens(block)
            # Stop at the first block that would exceed the budget; later chunks
            # are dropped too, even if they would individually fit.
            if context_tokens + block_tokens > settings.rag_max_context_tokens:
                break
            context_tokens += block_tokens
            context_blocks.append(block)
        context = "\n\n".join(context_blocks)

        messages = [{"role": "system", "content": system_prompt}]
        # Copy only the two keys that are forwarded to the LLM client.
        messages.extend(
            {"role": item["role"], "content": item["content"]}
            for item in history or []
        )
        messages.append(
            {
                "role": "user",
                "content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
            }
        )
        return messages, context_tokens

    def _is_context_truncated(
        self,
        *,
        retrieved_chunks: list[RetrievedChunk],
        context_tokens: int,
    ) -> bool:
        """Return whether chunks were omitted from the prompt to fit the token budget.

        Compares the estimated size of *all* chunks against ``context_tokens``
        (the size of the chunks actually included). Every block estimates to at
        least one token, so the totals differ exactly when a chunk was dropped.
        """
        if not retrieved_chunks:
            return False
        estimated_total_tokens = sum(
            self._estimate_tokens(self._format_chunk(idx, chunk))
            for idx, chunk in enumerate(retrieved_chunks, start=1)
        )
        return estimated_total_tokens > context_tokens

    def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
        """Convert retrieved chunks into ``AnswerSource`` records, field by field."""
        return [
            AnswerSource(
                doc_id=chunk.doc_id,
                doc_name=chunk.doc_name,
                chunk_id=chunk.chunk_id,
                section_title=chunk.section_title,
                page_number=chunk.page_number,
                score=chunk.score,
                content=chunk.content,
                metadata=chunk.metadata,
            )
            for chunk in chunks
        ]

    def generate(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None = None,
        provider: str | None = None,
        model: str | None = None,
        prompt_template: str | None = None,
    ) -> AnswerResult:
        """Generate a complete answer with a single blocking chat call.

        Args:
            query: The user's question.
            retrieved_chunks: Chunks to use as context and to report as sources.
            history: Optional prior turns (``"role"``/``"content"`` dicts).
            provider: LLM provider; defaults to ``settings.llm_provider``.
            model: LLM model; defaults to ``settings.llm_model``.
            prompt_template: Key into ``PROMPT_TEMPLATES``.

        Returns:
            An ``AnswerResult``. ``answer`` is empty when the client reports
            failure, in which case ``error`` carries the client's error value.
            ``sources`` lists every retrieved chunk, including any that were
            omitted from the prompt (see ``truncated``).
        """
        # perf_counter is monotonic, so latency cannot be skewed by wall-clock
        # adjustments the way time.time() can.
        start = time.perf_counter()
        messages, context_tokens = self._build_messages(
            query=query,
            retrieved_chunks=retrieved_chunks,
            history=history,
            prompt_template=prompt_template,
        )
        client = get_llm_client(
            provider=provider or settings.llm_provider,
            model=model or settings.llm_model,
        )
        response = client.chat(messages)
        latency_ms = int((time.perf_counter() - start) * 1000)
        return AnswerResult(
            answer=response.content if response.is_success else "",
            sources=self._sources(retrieved_chunks),
            model=response.model or (model or settings.llm_model),
            latency_ms=latency_ms,
            retrieved_count=len(retrieved_chunks),
            context_tokens=context_tokens,
            truncated=self._is_context_truncated(
                retrieved_chunks=retrieved_chunks,
                context_tokens=context_tokens,
            ),
            error=response.error,
        )

    def stream_generate(
        self,
        *,
        query: str,
        retrieved_chunks: list[RetrievedChunk],
        history: list[dict[str, str]] | None = None,
        provider: str | None = None,
        model: str | None = None,
        prompt_template: str | None = None,
    ) -> Generator[dict, None, AnswerResult]:
        """Stream an answer as a sequence of event dicts.

        Events, in order:
            * ``{"event": "sources", "data": [...]}`` once, before the LLM call.
            * ``{"event": "content", "data": <text>}`` zero or more times.
            * Exactly one terminal event: ``"done"`` (with ``latency_ms``,
              ``retrieved_count``, ``context_tokens``, ``model``) on success, or
              ``"error"`` (with the error text) on failure.

        If the client has no ``stream_chat`` attribute, a single blocking
        ``chat`` call is used and its content is emitted as one ``content``
        event.

        Args:
            query: The user's question.
            retrieved_chunks: Chunks to use as context and to report as sources.
            history: Optional prior turns (``"role"``/``"content"`` dicts).
            provider: LLM provider; defaults to ``settings.llm_provider``.
            model: LLM model; defaults to ``settings.llm_model``.
            prompt_template: Key into ``PROMPT_TEMPLATES``.

        Returns:
            The generator's return value (``StopIteration.value`` / the result
            of ``yield from``) is an ``AnswerResult`` summarizing the stream.
            On failure its ``answer`` holds whatever text was streamed before
            the error and ``error`` holds the error text.
        """
        start = time.perf_counter()
        messages, context_tokens = self._build_messages(
            query=query,
            retrieved_chunks=retrieved_chunks,
            history=history,
            prompt_template=prompt_template,
        )
        answer_sources = self._sources(retrieved_chunks)
        # Sources go out first so consumers can render citations while the
        # answer is still being generated.
        yield {"event": "sources", "data": [source.__dict__ for source in answer_sources]}

        resolved_model = model or settings.llm_model
        client = get_llm_client(
            provider=provider or settings.llm_provider,
            model=resolved_model,
        )

        answer_parts: list[str] = []
        error: str | None = None
        try:
            if hasattr(client, "stream_chat"):
                for piece in client.stream_chat(messages):
                    answer_parts.append(piece)
                    yield {"event": "content", "data": piece}
            else:
                response = client.chat(messages)
                if response.is_success:
                    answer_parts.append(response.content)
                    yield {"event": "content", "data": response.content}
                else:
                    # Mirror generate(): a failed response is an error, not content.
                    error = str(response.error) if response.error else "LLM request failed"
        except Exception as exc:
            # Boundary handler: surface any client failure as an error event
            # instead of breaking the consumer's iteration.
            error = str(exc)

        latency_ms = int((time.perf_counter() - start) * 1000)
        if error is not None:
            yield {"event": "error", "data": error}
        else:
            yield {
                "event": "done",
                "data": {
                    "latency_ms": latency_ms,
                    "retrieved_count": len(retrieved_chunks),
                    "context_tokens": context_tokens,
                    "model": resolved_model,
                },
            }

        # Honor the declared Generator[..., AnswerResult] return type.
        return AnswerResult(
            answer="".join(answer_parts),
            sources=answer_sources,
            model=resolved_model,
            latency_ms=latency_ms,
            retrieved_count=len(retrieved_chunks),
            context_tokens=context_tokens,
            truncated=self._is_context_truncated(
                retrieved_chunks=retrieved_chunks,
                context_tokens=context_tokens,
            ),
            error=error,
        )