# File: AIRegulation-DocAnalysis/backend/app/services/rag/retriever.py
# Viewer metadata: 83 lines, 3.0 KiB, Python

"""Provide service-layer logic for retriever."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Optional
from app.shared.bootstrap import get_retrieval_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class RetrievedDocument:
    """One search hit, flattened into plain fields for downstream consumers.

    ``Retriever.retrieve`` builds these from the objects returned by the
    retrieval service.

    Attributes:
        content: ``content`` of the service hit, copied as-is.
        doc_id: ``doc_id`` of the service hit, copied as-is.
        doc_name: ``doc_name`` of the service hit, copied as-is.
        section_title: ``section_title`` of the service hit, copied as-is.
        clause_number: Taken from the hit's ``metadata["clause_number"]``;
            an empty string when that key is missing.
        page_number: ``page_number`` of the service hit, copied as-is.
        score: Value compared against the retriever's ``min_score`` cut-off.
        metadata: Extra key/value data for the hit. Defaults to a fresh
            empty dict per instance, so instances never share one dict.
    """

    # Payload and provenance.
    content: str
    doc_id: str
    doc_name: str

    # Position inside the source document.
    section_title: str
    clause_number: str
    page_number: int

    # Ranking value and free-form extras.
    score: float
    metadata: dict[str, Any] = field(default_factory=dict)
class Retriever:
    """Facade over the shared retrieval service that yields ``RetrievedDocument`` rows.

    An instance carries default search settings (``top_k`` and ``min_score``)
    that apply to every call unless a call overrides them.
    """

    def __init__(self, top_k: int = 5, rerank: bool = False, min_score: float = 0.0):
        """Store the default search settings.

        Args:
            top_k: Number of hits requested from the service when a call
                does not supply its own ``top_k``.
            rerank: Stored on the instance only; no method of this class
                reads it.
            min_score: Hits whose score is below this value are dropped by
                :meth:`retrieve`.
        """
        self.top_k = top_k
        self.rerank = rerank
        self.min_score = min_score

    @staticmethod
    def _equals_filter(field_name: str, value: str) -> str:
        """Return ``field_name == "value"`` with ``value`` safely quoted.

        Backslashes and double quotes inside ``value`` are backslash-escaped
        so a caller-supplied name cannot terminate the string literal early
        and append its own conditions to the filter.

        NOTE(review): assumes the retrieval backend's filter grammar honours
        backslash escapes inside double-quoted strings (Milvus-style boolean
        expressions do) - confirm against the configured service.
        """
        # Escape backslashes first so the ones added for quotes stay intact.
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'{field_name} == "{escaped}"'

    @staticmethod
    def _to_document(item: Any) -> RetrievedDocument:
        """Convert one retrieval-service hit into a ``RetrievedDocument``."""
        # A hit without metadata must not abort the whole result list.
        metadata = item.metadata if item.metadata is not None else {}
        return RetrievedDocument(
            content=item.content,
            doc_id=item.doc_id,
            doc_name=item.doc_name,
            section_title=item.section_title,
            clause_number=metadata.get("clause_number", ""),
            page_number=item.page_number,
            score=item.score,
            metadata=metadata,
        )

    def retrieve(self, query: str, filters: Optional[str] = None, top_k: Optional[int] = None) -> list[RetrievedDocument]:
        """Query the retrieval service and return the hits that pass ``min_score``.

        Args:
            query: Search text forwarded to the service.
            filters: Optional filter expression, forwarded unchanged.
            top_k: Per-call override for the number of hits requested.
                ``None`` and ``0`` both fall back to the instance default.

        Returns:
            Hits with ``score >= self.min_score``, in the order the service
            returned them.
        """
        # `or` is deliberate: a falsy override (None or 0) means "use default".
        limit = top_k or self.top_k
        hits = get_retrieval_service().retrieve(query=query, top_k=limit, filters=filters)
        return [self._to_document(hit) for hit in hits if hit.score >= self.min_score]

    def retrieve_with_scores(self, query: str, filters: Optional[str] = None) -> list[dict]:
        """Run :meth:`retrieve` and return each hit as a plain dict.

        The dicts carry every ``RetrievedDocument`` field except ``metadata``.
        """
        return [
            {
                "content": doc.content,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score,
            }
            for doc in self.retrieve(query, filters)
        ]

    def search_by_doc_name(self, query: str, doc_name: str) -> list[RetrievedDocument]:
        """Run :meth:`retrieve` restricted to hits whose ``doc_name`` equals ``doc_name``."""
        return self.retrieve(query, filters=self._equals_filter("doc_name", doc_name))

    def search_by_regulation_type(self, query: str, regulation_type: str) -> list[RetrievedDocument]:
        """Run :meth:`retrieve` restricted to hits whose ``regulation_type`` equals ``regulation_type``."""
        return self.retrieve(query, filters=self._equals_filter("regulation_type", regulation_type))

    def close(self):
        """Do nothing and return ``None``; this class holds no resources of its own."""
        return None
def retrieve_regulations(query: str, top_k: int = 10, filters: Optional[str] = None) -> list[RetrievedDocument]:
    """Run a one-off search through a throwaway ``Retriever``.

    Args:
        query: Search text handed to ``Retriever.retrieve``.
        top_k: Default hit count given to the temporary ``Retriever``.
        filters: Optional filter expression, passed through unchanged.

    Returns:
        Whatever ``Retriever.retrieve`` returns for these arguments.
    """
    # A fresh instance per call: only `top_k` is customised here.
    one_shot = Retriever(top_k=top_k)
    return one_shot.retrieve(query, filters)