v0.21.1-fastapi
This commit is contained in:
@@ -203,7 +203,6 @@ class Canvas(Graph):
|
||||
self.history = []
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
|
||||
for k in self.globals.keys():
|
||||
if isinstance(self.globals[k], str):
|
||||
self.globals[k] = ""
|
||||
@@ -292,7 +291,6 @@ class Canvas(Graph):
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
_run_batch(idx, to)
|
||||
|
||||
# post processing of components invocation
|
||||
for i in range(idx, to):
|
||||
cpn = self.get_component(self.path[i])
|
||||
@@ -393,7 +391,6 @@ class Canvas(Graph):
|
||||
self.path = path
|
||||
yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
|
||||
return
|
||||
|
||||
self.path = self.path[:idx]
|
||||
if not self.error:
|
||||
yield decorate("workflow_finished",
|
||||
|
||||
@@ -346,7 +346,11 @@ Respond immediately with your final comprehensive answer.
|
||||
|
||||
return "Error occurred."
|
||||
|
||||
def reset(self):
|
||||
def reset(self, temp=False):
|
||||
"""
|
||||
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
|
||||
"""
|
||||
for k, cpn in self.tools.items():
|
||||
cpn.reset()
|
||||
if hasattr(cpn, "reset") and callable(cpn.reset):
|
||||
cpn.reset()
|
||||
|
||||
|
||||
726
agent/templates/advanced_ingestion_pipeline.json
Normal file
726
agent/templates/advanced_ingestion_pipeline.json
Normal file
File diff suppressed because one or more lines are too long
493
agent/templates/chunk_summary.json
Normal file
493
agent/templates/chunk_summary.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1172
agent/templates/stock_research_report.json
Normal file
1172
agent/templates/stock_research_report.json
Normal file
File diff suppressed because one or more lines are too long
369
agent/templates/title_chunker.json
Normal file
369
agent/templates/title_chunker.json
Normal file
File diff suppressed because one or more lines are too long
@@ -53,12 +53,13 @@ class ExeSQLParam(ToolParamBase):
|
||||
self.max_records = 1024
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
|
||||
self.check_empty(self.database, "Database name")
|
||||
self.check_empty(self.username, "database username")
|
||||
self.check_empty(self.host, "IP Address")
|
||||
self.check_positive_integer(self.port, "IP Port")
|
||||
self.check_empty(self.password, "Database password")
|
||||
if self.db_type != "trino":
|
||||
self.check_empty(self.password, "Database password")
|
||||
self.check_positive_integer(self.max_records, "Maximum number of records")
|
||||
if self.database == "rag_flow":
|
||||
if self.host == "ragflow-mysql":
|
||||
@@ -123,6 +124,45 @@ class ExeSQL(ToolBase, ABC):
|
||||
r'PWD=' + self._param.password
|
||||
)
|
||||
db = pyodbc.connect(conn_str)
|
||||
elif self._param.db_type == 'trino':
|
||||
try:
|
||||
import trino
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception:
|
||||
raise Exception("Missing dependency 'trino'. Please install: pip install trino")
|
||||
|
||||
def _parse_catalog_schema(db: str):
|
||||
if not db:
|
||||
return None, None
|
||||
if "." in db:
|
||||
c, s = db.split(".", 1)
|
||||
elif "/" in db:
|
||||
c, s = db.split("/", 1)
|
||||
else:
|
||||
c, s = db, "default"
|
||||
return c, s
|
||||
|
||||
catalog, schema = _parse_catalog_schema(self._param.database)
|
||||
if not catalog:
|
||||
raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")
|
||||
|
||||
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||
auth = None
|
||||
if http_scheme == "https" and self._param.password:
|
||||
auth = BasicAuthentication(self._param.username, self._param.password)
|
||||
|
||||
try:
|
||||
db = trino.dbapi.connect(
|
||||
host=self._param.host,
|
||||
port=int(self._param.port or 8080),
|
||||
user=self._param.username or "ragflow",
|
||||
catalog=catalog,
|
||||
schema=schema or "default",
|
||||
http_scheme=http_scheme,
|
||||
auth=auth
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception("Database Connection Failed! \n" + str(e))
|
||||
elif self._param.db_type == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
|
||||
@@ -85,13 +85,7 @@ class PubMed(ToolBase, ABC):
|
||||
self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
|
||||
get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
|
||||
get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
|
||||
get_content=lambda child: child.find("MedlineCitation") \
|
||||
.find("Article") \
|
||||
.find("Abstract") \
|
||||
.find("AbstractText").text \
|
||||
if child.find("MedlineCitation")\
|
||||
.find("Article").find("Abstract") \
|
||||
else "No abstract available")
|
||||
get_content=lambda child: self._format_pubmed_content(child),)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
@@ -104,5 +98,50 @@ class PubMed(ToolBase, ABC):
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def _format_pubmed_content(self, child):
|
||||
"""Extract structured reference info from PubMed XML"""
|
||||
def safe_find(path):
|
||||
node = child
|
||||
for p in path.split("/"):
|
||||
if node is None:
|
||||
return None
|
||||
node = node.find(p)
|
||||
return node.text if node is not None and node.text else None
|
||||
|
||||
title = safe_find("MedlineCitation/Article/ArticleTitle") or "No title"
|
||||
abstract = safe_find("MedlineCitation/Article/Abstract/AbstractText") or "No abstract available"
|
||||
journal = safe_find("MedlineCitation/Article/Journal/Title") or "Unknown Journal"
|
||||
volume = safe_find("MedlineCitation/Article/Journal/JournalIssue/Volume") or "-"
|
||||
issue = safe_find("MedlineCitation/Article/Journal/JournalIssue/Issue") or "-"
|
||||
pages = safe_find("MedlineCitation/Article/Pagination/MedlinePgn") or "-"
|
||||
|
||||
# Authors
|
||||
authors = []
|
||||
for author in child.findall(".//AuthorList/Author"):
|
||||
lastname = safe_find("LastName") or ""
|
||||
forename = safe_find("ForeName") or ""
|
||||
fullname = f"{forename} {lastname}".strip()
|
||||
if fullname:
|
||||
authors.append(fullname)
|
||||
authors_str = ", ".join(authors) if authors else "Unknown Authors"
|
||||
|
||||
# DOI
|
||||
doi = None
|
||||
for eid in child.findall(".//ArticleId"):
|
||||
if eid.attrib.get("IdType") == "doi":
|
||||
doi = eid.text
|
||||
break
|
||||
|
||||
return (
|
||||
f"Title: {title}\n"
|
||||
f"Authors: {authors_str}\n"
|
||||
f"Journal: {journal}\n"
|
||||
f"Volume: {volume}\n"
|
||||
f"Issue: {issue}\n"
|
||||
f"Pages: {pages}\n"
|
||||
f"DOI: {doi or '-'}\n"
|
||||
f"Abstract: {abstract.strip()}"
|
||||
)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
@@ -18,12 +18,14 @@ import re
|
||||
from abc import ABC
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from api.db import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
@@ -57,6 +59,8 @@ class RetrievalParam(ToolParamBase):
|
||||
self.empty_response = ""
|
||||
self.use_kg = False
|
||||
self.cross_languages = []
|
||||
self.toc_enhance = False
|
||||
self.meta_data_filter={}
|
||||
|
||||
def check(self):
|
||||
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
|
||||
@@ -116,12 +120,27 @@ class Retrieval(ToolBase, ABC):
|
||||
vars = self.get_input_elements_from_text(kwargs["query"])
|
||||
vars = {k:o["value"] for k,o in vars.items()}
|
||||
query = self.string_format(kwargs["query"], vars)
|
||||
|
||||
doc_ids=[]
|
||||
if self._param.meta_data_filter!={}:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if self._param.meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
filters = gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif self._param.meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, self._param.meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
if self._param.cross_languages:
|
||||
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
|
||||
if kbs:
|
||||
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
||||
kbinfos = settings.retrievaler.retrieval(
|
||||
kbinfos = settings.retriever.retrieval(
|
||||
query,
|
||||
embd_mdl,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
@@ -130,12 +149,18 @@ class Retrieval(ToolBase, ABC):
|
||||
self._param.top_n,
|
||||
self._param.similarity_threshold,
|
||||
1 - self._param.keywords_similarity_weight,
|
||||
doc_ids=doc_ids,
|
||||
aggs=False,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(query, kbs),
|
||||
)
|
||||
if self._param.toc_enhance:
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(query,
|
||||
ck = settings.kg_retriever.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
@@ -146,7 +171,7 @@ class Retrieval(ToolBase, ABC):
|
||||
kbinfos = {"chunks": [], "doc_aggs": []}
|
||||
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ck["content"] = ck["content_with_weight"]
|
||||
del ck["content_with_weight"]
|
||||
|
||||
Reference in New Issue
Block a user