Files
AIRegulation-DocAnalysis/backend/app/application/agent/services.py

146 lines
5.7 KiB
Python
Raw Normal View History

"""Implement application-layer logic for services."""
from __future__ import annotations
from typing import Generator
from app.domain.conversation import AnswerGenerator, AnswerResult, ConversationStore
from app.domain.retrieval import RetrievedChunk
from app.application.knowledge import KnowledgeRetrievalService
# Keep orchestration logic centralized so use-case flow stays easy to trace.
class AgentConversationService:
    """Orchestrate retrieval-augmented answering over conversation sessions.

    Collaborators:

    * ``retrieval_service`` -- ``retrieve(query=, top_k=, filters=)`` returns
      the chunks handed to the answer generator;
    * ``answer_generator`` -- ``generate(...)`` returns an ``AnswerResult``;
      ``stream_generate(...)`` yields event dicts;
    * ``conversation_store`` -- loads/creates sessions and persists the user
      and assistant turns via ``save_message``.
    """

    # Number of most recent stored messages forwarded to the generator as history.
    _HISTORY_WINDOW = 10

    def __init__(
        self,
        *,
        retrieval_service: KnowledgeRetrievalService,
        answer_generator: AnswerGenerator,
        conversation_store: ConversationStore,
    ) -> None:
        """Store the collaborators used by every conversation entry point."""
        self.retrieval_service = retrieval_service
        self.answer_generator = answer_generator
        self.conversation_store = conversation_store

    @classmethod
    def _build_history(cls, session) -> list[dict[str, str]]:
        """Return the last ``_HISTORY_WINDOW`` session messages as role/content dicts."""
        return [
            {"role": msg.role, "content": msg.content}
            for msg in session.messages[-cls._HISTORY_WINDOW:]
        ]

    @staticmethod
    def _serialize_sources(result: AnswerResult) -> list[dict]:
        """Return one attribute dict per source in *result* for persistence."""
        # Shallow-copy so the stored payload does not alias each source's live __dict__.
        return [dict(vars(source)) for source in result.sources]

    def _get_or_create_session(self, session_id: str | None):
        """Return the stored session for *session_id*, creating a new one if absent."""
        session = self.conversation_store.get_session(session_id) if session_id else None
        if session is None:
            session = self.conversation_store.create_session()
        return session

    def ask(
        self,
        *,
        query: str,
        filters: str | None = None,
        provider: str | None = None,
        model: str | None = None,
        top_k: int = 5,
        prompt_template: str | None = None,
        session_id: str | None = None,
    ) -> tuple[str | None, AnswerResult]:
        """Answer *query*, optionally continuing an existing session.

        Unlike :meth:`chat`, this never creates a session: without
        ``session_id`` nothing is persisted and ``history=None`` is sent.

        Args:
            query: The user's question.
            filters: Forwarded unchanged to ``retrieval_service.retrieve``.
            provider: Forwarded to ``answer_generator.generate``.
            model: Forwarded to ``answer_generator.generate``.
            top_k: Number of chunks requested from the retrieval service.
            prompt_template: Forwarded to ``answer_generator.generate``.
            session_id: Existing session to continue, or ``None``.

        Returns:
            ``(session_id, result)``; ``session_id`` is ``None`` when no
            session was supplied.

        Raises:
            ValueError: If ``session_id`` is given but the store returns no session.
        """
        history = None
        active_session_id = None
        if session_id:
            session = self.conversation_store.get_session(session_id)
            if not session:
                raise ValueError("会话不存在或已过期")
            # Capture history before recording the new question: the query is
            # passed to the generator separately.
            history = self._build_history(session)
            active_session_id = session.session_id
            self.conversation_store.save_message(active_session_id, role="user", content=query)
        retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
        result = self.answer_generator.generate(
            query=query,
            retrieved_chunks=retrieved,
            history=history,
            provider=provider,
            model=model,
            prompt_template=prompt_template,
        )
        if active_session_id:
            self.conversation_store.save_message(
                active_session_id,
                role="assistant",
                content=result.answer,
                sources=self._serialize_sources(result),
            )
        return active_session_id, result

    def chat(
        self,
        *,
        query: str,
        session_id: str | None = None,
        filters: str | None = None,
        provider: str | None = None,
        model: str | None = None,
        top_k: int = 5,
        prompt_template: str | None = None,
    ) -> tuple[str, AnswerResult]:
        """Answer *query* within a session, creating the session when needed.

        If ``session_id`` is missing or unknown to the store, a new session is
        created. Both the user turn and the assistant turn are persisted.

        Args:
            query: The user's question.
            session_id: Session to continue; ``None``/unknown starts a new one.
            filters: Forwarded unchanged to ``retrieval_service.retrieve``.
            provider: Forwarded to ``answer_generator.generate``.
            model: Forwarded to ``answer_generator.generate``.
            top_k: Number of chunks requested from the retrieval service.
            prompt_template: Forwarded to ``answer_generator.generate``.

        Returns:
            ``(session_id, result)`` for the session that was used.
        """
        session = self._get_or_create_session(session_id)
        # Capture history before saving the new question so it is not sent twice
        # (once in ``history`` and once as ``query``).
        history = self._build_history(session)
        self.conversation_store.save_message(session.session_id, role="user", content=query)
        retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
        result = self.answer_generator.generate(
            query=query,
            retrieved_chunks=retrieved,
            history=history,
            provider=provider,
            model=model,
            prompt_template=prompt_template,
        )
        self.conversation_store.save_message(
            session.session_id,
            role="assistant",
            content=result.answer,
            sources=self._serialize_sources(result),
        )
        return session.session_id, result

    def stream_chat(
        self,
        *,
        query: str,
        session_id: str | None = None,
        filters: str | None = None,
        provider: str | None = None,
        model: str | None = None,
        top_k: int = 5,
        prompt_template: str | None = None,
    ) -> tuple[str, Generator[dict, None, None]]:
        """Start a streamed answer to *query* within a session.

        Session lookup/creation, saving the user turn and retrieval all happen
        eagerly inside this call. The returned generator first yields a status
        event, then re-yields every event from
        ``answer_generator.stream_generate``. It accumulates ``"content"`` data
        into the answer text and keeps the latest ``"sources"`` data.

        Args:
            query: The user's question.
            session_id: Session to continue; ``None``/unknown starts a new one.
            filters: Forwarded unchanged to ``retrieval_service.retrieve``.
            provider: Forwarded to ``answer_generator.stream_generate``.
            model: Forwarded to ``answer_generator.stream_generate``.
            top_k: Number of chunks requested from the retrieval service.
            prompt_template: Forwarded to ``answer_generator.stream_generate``.

        Returns:
            ``(session_id, event_generator)``.
        """
        session = self._get_or_create_session(session_id)
        # Capture history before saving the new question (see ``chat``).
        history = self._build_history(session)
        self.conversation_store.save_message(session.session_id, role="user", content=query)
        # Retrieval runs now, so its errors are raised from this call rather
        # than while the caller iterates the stream.
        retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)

        def event_stream() -> Generator[dict, None, None]:
            """Yield a status event, then relay generator events, then persist the answer."""
            yield {"event": "status", "data": f"找到{len(retrieved)}条相关法规,正在生成回答..."}
            answer_parts: list[str] = []
            sources_payload: list[dict] = []
            for event in self.answer_generator.stream_generate(
                query=query,
                retrieved_chunks=retrieved,
                history=history,
                provider=provider,
                model=model,
                prompt_template=prompt_template,
            ):
                kind = event.get("event")
                if kind == "sources":
                    sources_payload = event.get("data", [])
                elif kind == "content":
                    answer_parts.append(str(event.get("data", "")))
                yield event
            # Reached only when the consumer exhausts the stream. An abandoned
            # stream, or an exception from stream_generate, saves no assistant turn.
            self.conversation_store.save_message(
                session.session_id,
                role="assistant",
                content="".join(answer_parts),
                sources=sources_payload,
            )

        return session.session_id, event_stream()