83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
"""Define domain ports for retrieval."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from app.domain.documents.models import Chunk
|
|
|
|
from .models import RetrievalQuery, RetrievedChunk
|
|
# Keep domain contracts explicit so adapters can swap implementations cleanly.
|
|
|
|
|
|
class EmbeddingProvider(ABC):
    """Port for adapters that turn text into dense float vectors.

    Adapters implement this interface so the embedding backend can be
    swapped without touching the domain layer.
    """

    @abstractmethod
    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return embedding vectors for a batch of *texts*.

        The result is a list of float vectors; presumably one vector per
        input text, in input order (TODO confirm against adapters).
        """

    @abstractmethod
    def embed_query(self, text: str) -> list[float]:
        """Return a single embedding vector for the query string *text*."""
|
|
|
|
|
|
class VectorIndex(ABC):
    """Port for the vector store that holds embedded chunks and answers searches."""

    @abstractmethod
    def upsert(self, chunks: list[Chunk], vectors: list[list[float]]) -> int:
        """Insert or update *chunks* together with their embedding *vectors*.

        *vectors* is assumed to be aligned index-by-index with *chunks*
        (TODO confirm). Returns an int; presumably the number of chunks
        written — verify against adapters.
        """

    @abstractmethod
    def delete_by_document(self, doc_id: str) -> int:
        """Remove the indexed chunks that belong to document *doc_id*.

        Returns an int; presumably the number of chunks removed — verify
        against adapters.
        """

    @abstractmethod
    def search(
        self,
        query_vector: list[float],
        top_k: int,
        filters: str | None = None,
    ) -> list[RetrievedChunk]:
        """Search the index with *query_vector*, bounded by *top_k* results.

        *filters* is an optional string whose syntax is defined by the
        adapter, not by this port.
        """

    @abstractmethod
    def count_by_document(self) -> dict[str, int]:
        """Return a mapping of doc_id -> chunk count taken from the vector store."""

    @abstractmethod
    def list_document_metadata(self) -> list[dict]:
        """Return one metadata row per document held in the vector store.

        Each row contains at minimum: doc_id, doc_name, chunk_count.
        Optional fields: regulation_type, version.
        """

    @abstractmethod
    def health(self) -> dict:
        """Report the status of the underlying store as a dict.

        The keys are adapter-defined.
        """
|
|
|
|
|
|
class Reranker(ABC):
    """Port that re-scores and re-orders candidate chunks with a cross-encoder model."""

    @abstractmethod
    def rerank(
        self,
        query: str,
        chunks: list[RetrievedChunk],
        top_k: int,
    ) -> list[RetrievedChunk]:
        """Return the *top_k* best *chunks* for *query*.

        Results are ordered by cross-encoder score, highest first.
        """
|
|
|
|
|
|
class Retriever(ABC):
    """Port that returns relevant chunks for a query."""

    @abstractmethod
    def retrieve(self, query: RetrievalQuery) -> list[RetrievedChunk]:
        """Return the chunks retrieved for a structured *query* object."""

    @abstractmethod
    def search(
        self,
        query: str,
        top_k: int,
        filters: str | None = None,
    ) -> list[RetrievedChunk]:
        """Return chunks for a raw *query* string, bounded by *top_k*.

        *filters* is an optional string whose syntax is defined by the
        implementation, not by this port.
        """
|