"""Provide service-layer logic for the QA agent."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Generator, List, Optional
|
|
|
|
from app.config.settings import settings
|
|
from app.shared.bootstrap import get_agent_conversation_service
|
|
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
|
|
|
|
|
|
|
@dataclass
class AgentResponse:
    """Hold the outcome of one QA agent request.

    Only ``answer`` is required; every other field has an empty or zero
    default. A response is considered successful when ``error`` is ``None``.
    """

    answer: str
    # A default factory gives each instance its own list.
    sources: List[Dict] = field(default_factory=list)
    model: str = ""
    latency_ms: int = 0
    retrieved_count: int = 0
    context_tokens: int = 0
    truncated: bool = False
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        """Report whether this response carries no error value."""
        has_error = self.error is not None
        return not has_error
|
|
|
|
|
|
@dataclass
class AgentConfig:
    """Group the options a ``QAAgent`` is constructed with.

    Defaults that reference ``settings`` are evaluated once, when this class
    body is executed, not on every instantiation. The remaining defaults are
    literals.
    """

    # Defaults taken from application settings at class-definition time.
    llm_provider: str = settings.llm_provider
    llm_model: str = settings.llm_model
    top_k: int = settings.rag_top_k
    # Literal default.
    min_score: float = 0.0
    # Defaults taken from application settings at class-definition time.
    max_context_tokens: int = settings.rag_max_context_tokens
    temperature: float = settings.llm_temperature
    # Literal defaults.
    prompt_template: str = "compliance_qa"
    include_metadata: bool = True
|
|
|
|
|
|
class QAAgent:
    """Facade that sends questions to the agent conversation service.

    The agent holds an ``AgentConfig`` and forwards its ``llm_provider``,
    ``llm_model``, ``top_k`` and ``prompt_template`` values on every call.
    The other config fields are stored but not passed to the service by
    this class.
    """

    def __init__(self, config: Optional[AgentConfig] = None):
        """Store ``config``, or build a default ``AgentConfig`` when omitted.

        Args:
            config: Options to use for every request made by this agent.
        """
        # Explicit None check so a supplied config is never silently replaced.
        self.config = config if config is not None else AgentConfig()

    def ask(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None,
    ) -> AgentResponse:
        """Send ``query`` to the conversation service and wrap its result.

        Args:
            query: The question text.
            filters: Optional filter expression, forwarded unchanged.
            prompt_template: Per-call template name; when falsy, the
                template from the agent's config is used instead.

        Returns:
            An ``AgentResponse`` populated field-by-field from the service
            result.
        """
        # The service returns a pair; only the second element is used here.
        _, result = get_agent_conversation_service().ask(
            query=query,
            filters=filters,
            provider=self.config.llm_provider,
            model=self.config.llm_model,
            top_k=self.config.top_k,
            prompt_template=prompt_template or self.config.prompt_template,
        )
        return AgentResponse(
            answer=result.answer,
            # Shallow-copy each attribute dict so that editing the response
            # cannot mutate the service's own source objects.
            sources=[dict(vars(source)) for source in result.sources],
            model=result.model,
            latency_ms=result.latency_ms,
            retrieved_count=result.retrieved_count,
            context_tokens=result.context_tokens,
            truncated=result.truncated,
            error=result.error,
        )

    def ask_stream(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None,
    ) -> Generator[dict, None, None]:
        """Yield the events the conversation service streams for ``query``.

        This is a generator function, so the service is not contacted until
        the returned generator is first iterated.

        Args:
            query: The question text.
            filters: Optional filter expression, forwarded unchanged.
            prompt_template: Per-call template name; when falsy, the
                template from the agent's config is used (same rule as
                ``ask``).

        Yields:
            Each event produced by the service stream, unmodified.
        """
        # The service returns a pair; only the second element is used here.
        _, stream = get_agent_conversation_service().stream_chat(
            query=query,
            filters=filters,
            provider=self.config.llm_provider,
            model=self.config.llm_model,
            top_k=self.config.top_k,
            prompt_template=prompt_template or self.config.prompt_template,
        )
        # Delegating also forwards close()/throw() to the underlying stream
        # when it is a generator, so it can clean up if the consumer stops.
        yield from stream

    def close(self):
        """Do nothing and return ``None``; the agent holds no resources."""
        return None
|
|
|
|
|
|
def ask_compliance_question(query: str, top_k: int = 5) -> AgentResponse:
    """Ask ``query`` through a throwaway ``QAAgent``.

    Args:
        query: The question text.
        top_k: Value placed in the agent config's ``top_k`` field; all other
            config fields keep their ``AgentConfig`` defaults.

    Returns:
        The ``AgentResponse`` returned by ``QAAgent.ask``.
    """
    one_off_config = AgentConfig(top_k=top_k)
    agent = QAAgent(one_off_config)
    return agent.ask(query)
|