v0.21.1-fastapi

This commit is contained in:
2025-11-04 16:06:36 +08:00
parent 3e58c3d0e9
commit d57b5d76ae
218 changed files with 19617 additions and 72339 deletions

View File

@@ -203,7 +203,6 @@ class Canvas(Graph):
self.history = []
self.retrieval = []
self.memory = []
for k in self.globals.keys():
if isinstance(self.globals[k], str):
self.globals[k] = ""
@@ -292,7 +291,6 @@ class Canvas(Graph):
"thoughts": self.get_component_thoughts(self.path[i])
})
_run_batch(idx, to)
# post processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])
@@ -393,7 +391,6 @@ class Canvas(Graph):
self.path = path
yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
return
self.path = self.path[:idx]
if not self.error:
yield decorate("workflow_finished",

View File

@@ -346,7 +346,11 @@ Respond immediately with your final comprehensive answer.
return "Error occurred."
def reset(self):
def reset(self, temp=False):
"""
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
"""
for k, cpn in self.tools.items():
cpn.reset()
if hasattr(cpn, "reset") and callable(cpn.reset):
cpn.reset()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -53,12 +53,13 @@ class ExeSQLParam(ToolParamBase):
self.max_records = 1024
def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
if self.db_type != "trino":
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.max_records, "Maximum number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql":
@@ -123,6 +124,45 @@ class ExeSQL(ToolBase, ABC):
r'PWD=' + self._param.password
)
db = pyodbc.connect(conn_str)
elif self._param.db_type == 'trino':
try:
import trino
from trino.auth import BasicAuthentication
except Exception:
raise Exception("Missing dependency 'trino'. Please install: pip install trino")
def _parse_catalog_schema(db: str):
if not db:
return None, None
if "." in db:
c, s = db.split(".", 1)
elif "/" in db:
c, s = db.split("/", 1)
else:
c, s = db, "default"
return c, s
catalog, schema = _parse_catalog_schema(self._param.database)
if not catalog:
raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
auth = None
if http_scheme == "https" and self._param.password:
auth = BasicAuthentication(self._param.username, self._param.password)
try:
db = trino.dbapi.connect(
host=self._param.host,
port=int(self._param.port or 8080),
user=self._param.username or "ragflow",
catalog=catalog,
schema=schema or "default",
http_scheme=http_scheme,
auth=auth
)
except Exception as e:
raise Exception("Database Connection Failed! \n" + str(e))
elif self._param.db_type == 'IBM DB2':
import ibm_db
conn_str = (

View File

@@ -85,13 +85,7 @@ class PubMed(ToolBase, ABC):
self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
get_content=lambda child: child.find("MedlineCitation") \
.find("Article") \
.find("Abstract") \
.find("AbstractText").text \
if child.find("MedlineCitation")\
.find("Article").find("Abstract") \
else "No abstract available")
get_content=lambda child: self._format_pubmed_content(child),)
return self.output("formalized_content")
except Exception as e:
last_e = e
@@ -104,5 +98,50 @@ class PubMed(ToolBase, ABC):
assert False, self.output()
def _format_pubmed_content(self, child):
"""Extract structured reference info from PubMed XML"""
def safe_find(path):
node = child
for p in path.split("/"):
if node is None:
return None
node = node.find(p)
return node.text if node is not None and node.text else None
title = safe_find("MedlineCitation/Article/ArticleTitle") or "No title"
abstract = safe_find("MedlineCitation/Article/Abstract/AbstractText") or "No abstract available"
journal = safe_find("MedlineCitation/Article/Journal/Title") or "Unknown Journal"
volume = safe_find("MedlineCitation/Article/Journal/JournalIssue/Volume") or "-"
issue = safe_find("MedlineCitation/Article/Journal/JournalIssue/Issue") or "-"
pages = safe_find("MedlineCitation/Article/Pagination/MedlinePgn") or "-"
# Authors
authors = []
for author in child.findall(".//AuthorList/Author"):
lastname = safe_find("LastName") or ""
forename = safe_find("ForeName") or ""
fullname = f"{forename} {lastname}".strip()
if fullname:
authors.append(fullname)
authors_str = ", ".join(authors) if authors else "Unknown Authors"
# DOI
doi = None
for eid in child.findall(".//ArticleId"):
if eid.attrib.get("IdType") == "doi":
doi = eid.text
break
return (
f"Title: {title}\n"
f"Authors: {authors_str}\n"
f"Journal: {journal}\n"
f"Volume: {volume}\n"
f"Issue: {issue}\n"
f"Pages: {pages}\n"
f"DOI: {doi or '-'}\n"
f"Abstract: {abstract.strip()}"
)
def thoughts(self) -> str:
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))

View File

@@ -18,12 +18,14 @@ import re
from abc import ABC
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from api.utils.api_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
class RetrievalParam(ToolParamBase):
@@ -57,6 +59,8 @@ class RetrievalParam(ToolParamBase):
self.empty_response = ""
self.use_kg = False
self.cross_languages = []
self.toc_enhance = False
self.meta_data_filter={}
def check(self):
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
@@ -116,12 +120,27 @@ class Retrieval(ToolBase, ABC):
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
doc_ids=[]
if self._param.meta_data_filter!={}:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if self._param.meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
filters = gen_meta_filter(chat_mdl, metas, query)
doc_ids.extend(meta_filter(metas, filters))
if not doc_ids:
doc_ids = None
elif self._param.meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, self._param.meta_data_filter["manual"]))
if not doc_ids:
doc_ids = None
if self._param.cross_languages:
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
if kbs:
query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
kbinfos = settings.retrievaler.retrieval(
kbinfos = settings.retriever.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
@@ -130,12 +149,18 @@ class Retrieval(ToolBase, ABC):
self._param.top_n,
self._param.similarity_threshold,
1 - self._param.keywords_similarity_weight,
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(query, kbs),
)
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
if cks:
kbinfos["chunks"] = cks
if self._param.use_kg:
ck = settings.kg_retrievaler.retrieval(query,
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
@@ -146,7 +171,7 @@ class Retrieval(ToolBase, ABC):
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ck["content"] = ck["content_with_weight"]
del ck["content_with_weight"]