83 lines
3.0 KiB
Python
83 lines
3.0 KiB
Python
"""Provide service-layer logic for retriever."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Optional
|
|
|
|
from app.shared.bootstrap import get_retrieval_service
|
|
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
|
|
|
|
|
|
|
@dataclass
class RetrievedDocument:
    """Hold one retrieval hit together with its source-location fields.

    All fields except ``metadata`` are required and are accepted
    positionally in the order they are declared below.
    """

    # Text of the hit and the identifiers of the document it came from.
    content: str
    doc_id: str
    doc_name: str

    # Location of the hit inside that document.
    section_title: str
    clause_number: str
    page_number: int

    # Score reported for the hit.
    score: float

    # Free-form extras; every instance gets its own empty dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
class Retriever:
    """Facade over the shared retrieval service.

    Holds per-instance defaults and converts the service's hits into
    ``RetrievedDocument`` records.

    NOTE(review): ``rerank`` is stored on the instance but no method of this
    class reads it.
    """

    def __init__(self, top_k: int = 5, rerank: bool = False, min_score: float = 0.0):
        """Store the default result count, rerank flag and score threshold."""
        self.top_k = top_k
        self.rerank = rerank
        self.min_score = min_score

    @staticmethod
    def _equals_filter(field_name: str, value: str) -> str:
        """Return a ``field_name == "value"`` filter with *value* safely quoted.

        Backslashes and double quotes inside *value* are backslash-escaped so
        a caller-supplied string cannot terminate the literal early and change
        the meaning of the filter. Values without those characters yield the
        same expression as plain interpolation.
        """
        # NOTE(review): assumes the filter language accepts backslash escapes
        # inside double-quoted literals - confirm against the backing store.
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'{field_name} == "{escaped}"'

    def retrieve(self, query: str, filters: Optional[str] = None, top_k: Optional[int] = None) -> list[RetrievedDocument]:
        """Query the retrieval service and return the hits as documents.

        Args:
            query: Query text forwarded to the service.
            filters: Optional filter expression forwarded unchanged.
            top_k: Per-call result count; a falsy value (``None`` or ``0``)
                falls back to the instance default.

        Returns:
            One ``RetrievedDocument`` per hit whose score is at least
            ``self.min_score``, in the order the service returned them.
        """
        limit = top_k or self.top_k
        hits = get_retrieval_service().retrieve(query=query, top_k=limit, filters=filters)
        return [
            RetrievedDocument(
                content=item.content,
                doc_id=item.doc_id,
                doc_name=item.doc_name,
                section_title=item.section_title,
                # The clause number is read from the hit's metadata, not from
                # an attribute; hits without the key get an empty string.
                clause_number=item.metadata.get("clause_number", ""),
                page_number=item.page_number,
                score=item.score,
                metadata=item.metadata,
            )
            for item in hits
            if item.score >= self.min_score
        ]

    def retrieve_with_scores(self, query: str, filters: Optional[str] = None) -> list[dict]:
        """Return the results of ``retrieve`` as plain dictionaries.

        Each dictionary carries every ``RetrievedDocument`` field except
        ``metadata``.
        """
        return [
            {
                "content": item.content,
                "doc_id": item.doc_id,
                "doc_name": item.doc_name,
                "section_title": item.section_title,
                "clause_number": item.clause_number,
                "page_number": item.page_number,
                "score": item.score,
            }
            for item in self.retrieve(query, filters)
        ]

    def search_by_doc_name(self, query: str, doc_name: str) -> list[RetrievedDocument]:
        """Retrieve results restricted to hits whose ``doc_name`` equals *doc_name*."""
        return self.retrieve(query, filters=self._equals_filter("doc_name", doc_name))

    def search_by_regulation_type(self, query: str, regulation_type: str) -> list[RetrievedDocument]:
        """Retrieve results restricted to hits whose ``regulation_type`` equals *regulation_type*."""
        return self.retrieve(query, filters=self._equals_filter("regulation_type", regulation_type))

    def close(self):
        """Do nothing and return ``None``; this class holds no resources of its own."""
        return None
|
|
|
|
|
|
def retrieve_regulations(query: str, top_k: int = 10, filters: Optional[str] = None) -> list[RetrievedDocument]:
    """Run a one-off retrieval using a throwaway ``Retriever``.

    Args:
        query: Query text passed to ``Retriever.retrieve``.
        top_k: Result count given to the ``Retriever`` constructor; defaults
            to 10 here.
        filters: Optional filter expression passed through unchanged.

    Returns:
        Whatever ``Retriever.retrieve`` returns for the query.
    """
    one_shot = Retriever(top_k=top_k)
    return one_shot.retrieve(query, filters)
|