2026-05-22 09:50:30 +08:00
|
|
|
"""Implement application-layer logic for agent services."""
|
2026-05-18 16:32:42 +08:00
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-05-22 09:50:30 +08:00
|
|
|
from dataclasses import dataclass
|
2026-05-18 16:32:42 +08:00
|
|
|
from typing import Generator
|
|
|
|
|
|
|
|
|
|
from app.domain.conversation import AnswerGenerator, AnswerResult, ConversationStore
|
|
|
|
|
from app.domain.retrieval import RetrievedChunk
|
|
|
|
|
|
|
|
|
|
from app.application.knowledge import KnowledgeRetrievalService
|
|
|
|
|
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentConversationService:
    """Coordinate knowledge retrieval, answer generation, and conversation persistence.

    Every entry point runs the same pipeline: retrieve chunks for the query,
    hand them (plus recent session history, when a session is involved) to the
    answer generator, and record the exchange in the conversation store.
    """

    # Number of trailing session messages forwarded to the generator as history.
    _HISTORY_WINDOW = 10

    def __init__(
        self,
        *,
        retrieval_service: KnowledgeRetrievalService,
        answer_generator: AnswerGenerator,
        conversation_store: ConversationStore,
    ) -> None:
        """Wire the service to its collaborators.

        Args:
            retrieval_service: Used to look up chunks relevant to a query.
            answer_generator: Used to produce blocking or streamed answers.
            conversation_store: Used to read, create, and append to sessions.
        """
        self.retrieval_service = retrieval_service
        self.answer_generator = answer_generator
        self.conversation_store = conversation_store

    def _recent_history(self, session) -> list[dict[str, str]]:
        """Return the session's trailing messages as ``{"role", "content"}`` dicts."""
        return [
            {"role": msg.role, "content": msg.content}
            for msg in session.messages[-self._HISTORY_WINDOW:]
        ]

    def _open_session(self, session_id: str | None):
        """Return the stored session for ``session_id``, creating a new one when absent."""
        session = self.conversation_store.get_session(session_id) if session_id else None
        if session is None:
            session = self.conversation_store.create_session()
        return session

    @staticmethod
    def _serialize_sources(sources) -> list[dict]:
        """Convert answer source objects into plain dicts for persistence.

        Each attribute dict is shallow-copied so the stored payload does not
        alias (and cannot mutate) the live source object.
        """
        return [dict(source.__dict__) for source in sources]

    def ask(
        self,
        *,
        query: str,
        filters: str | None = None,
        provider: str | None = None,
        model: str | None = None,
        top_k: int = 5,
        prompt_template: str | None = None,
        session_id: str | None = None,
    ) -> tuple[str | None, AnswerResult]:
        """Answer a single query, optionally inside an existing session.

        Unlike ``chat``, an unknown ``session_id`` is an error rather than a
        reason to open a new session, and no session is created when
        ``session_id`` is omitted.

        Returns:
            ``(session_id, result)``; ``session_id`` is ``None`` when the call
            was not bound to a session.

        Raises:
            ValueError: If ``session_id`` is given but the store has no such session.
        """
        history = None
        active_session_id = None
        if session_id:
            session = self.conversation_store.get_session(session_id)
            if not session:
                raise ValueError("会话不存在或已过期")
            # Snapshot history before recording the new query so it is not duplicated.
            history = self._recent_history(session)
            active_session_id = session.session_id
            self.conversation_store.save_message(session_id, role="user", content=query)

        retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
        result = self.answer_generator.generate(
            query=query,
            retrieved_chunks=retrieved,
            history=history,
            provider=provider,
            model=model,
            prompt_template=prompt_template,
        )

        if active_session_id:
            self.conversation_store.save_message(
                active_session_id,
                role="assistant",
                content=result.answer,
                sources=self._serialize_sources(result.sources),
            )
        return active_session_id, result

    def chat(
        self,
        *,
        query: str,
        session_id: str | None = None,
        filters: str | None = None,
        provider: str | None = None,
        model: str | None = None,
        top_k: int = 5,
        prompt_template: str | None = None,
    ) -> tuple[str, AnswerResult]:
        """Answer ``query`` within a session, creating the session if needed.

        An unknown or missing ``session_id`` silently starts a new session.
        ``prompt_template`` is forwarded to the generator exactly as in ``ask``.

        Returns:
            ``(session_id, result)`` for the session the exchange was stored in.
        """
        session = self._open_session(session_id)
        # Snapshot history before saving the new query (same ordering as ``ask``),
        # so the current query is not sent twice when the store hands back live
        # session objects.
        history = self._recent_history(session)
        self.conversation_store.save_message(session.session_id, role="user", content=query)

        retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)
        result = self.answer_generator.generate(
            query=query,
            retrieved_chunks=retrieved,
            history=history,
            provider=provider,
            model=model,
            prompt_template=prompt_template,
        )

        self.conversation_store.save_message(
            session.session_id,
            role="assistant",
            content=result.answer,
            sources=self._serialize_sources(result.sources),
        )
        return session.session_id, result

    def stream_chat(
        self,
        *,
        query: str,
        session_id: str | None = None,
        filters: str | None = None,
        provider: str | None = None,
        model: str | None = None,
        top_k: int = 5,
        prompt_template: str | None = None,
    ) -> tuple[str, Generator[dict, None, None]]:
        """Start a streamed answer for ``query`` within a (possibly new) session.

        Session resolution, saving the user message, and retrieval happen
        eagerly. Generation is deferred to the returned generator.

        Returns:
            ``(session_id, events)``. ``events`` first yields a status event,
            then re-yields every event from the generator's ``stream_generate``.
        """
        session = self._open_session(session_id)
        # Snapshot history before saving the new query (same ordering as ``ask``).
        history = self._recent_history(session)
        self.conversation_store.save_message(session.session_id, role="user", content=query)
        retrieved = self.retrieval_service.retrieve(query=query, top_k=top_k, filters=filters)

        def event_stream() -> Generator[dict, None, None]:
            """Relay generator events and persist the assembled answer at the end."""
            yield {"event": "status", "data": f"找到{len(retrieved)}条相关法规,正在生成回答..."}
            answer_parts: list[str] = []
            sources_payload: list[dict] = []
            for event in self.answer_generator.stream_generate(
                query=query,
                retrieved_chunks=retrieved,
                history=history,
                provider=provider,
                model=model,
                prompt_template=prompt_template,
            ):
                kind = event.get("event")
                if kind == "sources":
                    sources_payload = event.get("data", [])
                elif kind == "content":
                    answer_parts.append(str(event.get("data", "")))
                yield event

            # NOTE(review): this is only reached when the stream is fully consumed;
            # if the consumer stops early or generation raises, no assistant
            # message is stored for the already-saved user message.
            self.conversation_store.save_message(
                session.session_id,
                role="assistant",
                content="".join(answer_parts),
                sources=sources_payload,
            )

        return session.session_id, event_stream()
|
2026-05-22 09:50:30 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AgentSessionFeedbackResult:
    """Normalized outcome of a validated session feedback submission.

    Attributes:
        session_id: Identifier of the session the feedback refers to.
        message_index: Zero-based position of the rated message in the
            session's message list.
    """

    # Session that owns the rated message.
    session_id: str
    # Index into the session's messages; validated by the caller before construction.
    message_index: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentSessionService:
    """Provide application-layer access to session management workflows."""

    def __init__(self, *, conversation_store: ConversationStore) -> None:
        """Keep the conversation store that backs every session workflow."""
        self.conversation_store = conversation_store

    def get_session(self, session_id: str):
        """Return the session with ``session_id``.

        Raises:
            ValueError: If the store returns a falsy value for ``session_id``.
        """
        session = self.conversation_store.get_session(session_id)
        if not session:
            raise ValueError("会话不存在或已过期")
        return session

    def get_history(self, *, session_id: str, max_turns: int = 5) -> list[dict[str, str]]:
        """Return the most recent turns of a session as role/content dicts.

        Each turn counts as two messages (presumably a user/assistant pair), so
        at most ``max_turns * 2`` trailing messages are returned. A non-positive
        ``max_turns`` yields an empty list.

        Raises:
            ValueError: If the session does not exist.
        """
        session = self.get_session(session_id)
        if max_turns <= 0:
            # Without this guard, ``messages[-0:]`` would return the whole history
            # and negative values would drop messages from the front instead.
            return []
        return [
            {"role": msg.role, "content": msg.content}
            for msg in session.messages[-(max_turns * 2):]
        ]

    def delete_session(self, session_id: str) -> None:
        """Delete the session with ``session_id``.

        Raises:
            ValueError: If the store reports that nothing was deleted.
        """
        if not self.conversation_store.delete_session(session_id):
            raise ValueError("会话不存在")

    def list_sessions(self) -> list[dict]:
        """Return whatever session listing the conversation store provides."""
        return self.conversation_store.list_sessions()

    def submit_feedback(self, *, session_id: str, message_index: int) -> AgentSessionFeedbackResult:
        """Validate a feedback target and return it as a normalized result.

        Raises:
            ValueError: If the session does not exist, or ``message_index`` is
                outside ``[0, len(session.messages))``.
        """
        session = self.get_session(session_id)
        if message_index < 0 or message_index >= len(session.messages):
            raise ValueError("消息索引不存在")
        # Preserve the existing API behavior until a persistent feedback store is introduced.
        return AgentSessionFeedbackResult(session_id=session_id, message_index=message_index)
|