v0.21.1-fastapi
This commit is contained in:
@@ -18,12 +18,14 @@ import re
|
||||
from abc import ABC
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from api.db import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
@@ -57,6 +59,8 @@ class RetrievalParam(ToolParamBase):
|
||||
self.empty_response = ""
|
||||
self.use_kg = False
|
||||
self.cross_languages = []
|
||||
self.toc_enhance = False
|
||||
self.meta_data_filter={}
|
||||
|
||||
def check(self):
|
||||
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
|
||||
@@ -116,12 +120,27 @@ class Retrieval(ToolBase, ABC):
|
||||
vars = self.get_input_elements_from_text(kwargs["query"])
|
||||
vars = {k:o["value"] for k,o in vars.items()}
|
||||
query = self.string_format(kwargs["query"], vars)
|
||||
|
||||
doc_ids=[]
|
||||
if self._param.meta_data_filter!={}:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if self._param.meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
filters = gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif self._param.meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, self._param.meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
if self._param.cross_languages:
|
||||
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
|
||||
if kbs:
|
||||
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
||||
kbinfos = settings.retrievaler.retrieval(
|
||||
kbinfos = settings.retriever.retrieval(
|
||||
query,
|
||||
embd_mdl,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
@@ -130,12 +149,18 @@ class Retrieval(ToolBase, ABC):
|
||||
self._param.top_n,
|
||||
self._param.similarity_threshold,
|
||||
1 - self._param.keywords_similarity_weight,
|
||||
doc_ids=doc_ids,
|
||||
aggs=False,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(query, kbs),
|
||||
)
|
||||
if self._param.toc_enhance:
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(query,
|
||||
ck = settings.kg_retriever.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
@@ -146,7 +171,7 @@ class Retrieval(ToolBase, ABC):
|
||||
kbinfos = {"chunks": [], "doc_aggs": []}
|
||||
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ck["content"] = ck["content_with_weight"]
|
||||
del ck["content_with_weight"]
|
||||
|
||||
Reference in New Issue
Block a user