将flask改成fastapi
This commit is contained in:
48
agent/tools/__init__.py
Normal file
48
agent/tools/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import importlib
|
||||
import inspect
|
||||
from types import ModuleType
|
||||
from typing import Dict, Type
|
||||
|
||||
_package_path = os.path.dirname(__file__)
|
||||
__all_classes: Dict[str, Type] = {}
|
||||
|
||||
def _import_submodules() -> None:
|
||||
for filename in os.listdir(_package_path): # noqa: F821
|
||||
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
|
||||
continue
|
||||
module_name = filename[:-3]
|
||||
|
||||
try:
|
||||
module = importlib.import_module(f".{module_name}", package=__name__)
|
||||
_extract_classes_from_module(module) # noqa: F821
|
||||
except ImportError as e:
|
||||
print(f"Warning: Failed to import module {module_name}: {str(e)}")
|
||||
|
||||
def _extract_classes_from_module(module: ModuleType) -> None:
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if (inspect.isclass(obj) and
|
||||
obj.__module__ == module.__name__ and not name.startswith("_")):
|
||||
__all_classes[name] = obj
|
||||
globals()[name] = obj
|
||||
|
||||
_import_submodules()
|
||||
|
||||
__all__ = list(__all_classes.keys()) + ["__all_classes"]
|
||||
|
||||
del _package_path, _import_submodules, _extract_classes_from_module
|
||||
56
agent/tools/akshare.py
Normal file
56
agent/tools/akshare.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class AkShareParam(ComponentParamBase):
|
||||
"""
|
||||
Define the AkShare component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
|
||||
class AkShare(ComponentBase, ABC):
|
||||
component_name = "AkShare"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
import akshare as ak
|
||||
ans = self.get_input()
|
||||
ans = ",".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return AkShare.be_output("")
|
||||
|
||||
try:
|
||||
ak_res = []
|
||||
stock_news_em_df = ak.stock_news_em(symbol=ans)
|
||||
stock_news_em_df = stock_news_em_df.head(self._param.top_n)
|
||||
ak_res = [{"content": '<a href="' + i["新闻链接"] + '">' + i["新闻标题"] + '</a>\n 新闻内容: ' + i[
|
||||
"新闻内容"] + " \n发布时间:" + i["发布时间"] + " \n文章来源: " + i["文章来源"]} for index, i in stock_news_em_df.iterrows()]
|
||||
except Exception as e:
|
||||
return AkShare.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not ak_res:
|
||||
return AkShare.be_output("")
|
||||
|
||||
return pd.DataFrame(ak_res)
|
||||
102
agent/tools/arxiv.py
Normal file
102
agent/tools/arxiv.py
Normal file
@@ -0,0 +1,102 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import arxiv
|
||||
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class ArXivParam(ToolParamBase):
|
||||
"""
|
||||
Define the ArXiv component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "arxiv_search",
|
||||
"description": """arXiv is a free distribution service and an open-access archive for nearly 2.4 million scholarly articles in the fields of physics, mathematics, computer science, quantitative biology, quantitative finance, statistics, electrical engineering and systems science, and economics. Materials on this site are not peer-reviewed by arXiv.""",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keywords to execute with arXiv. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 12
|
||||
self.sort_by = 'submittedDate'
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.sort_by, "ArXiv Search Sort_by",
|
||||
['submittedDate', 'lastUpdatedDate', 'relevance'])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ArXiv(ToolBase, ABC):
|
||||
component_name = "ArXiv"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
|
||||
"lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
|
||||
'submittedDate': arxiv.SortCriterion.SubmittedDate}
|
||||
arxiv_client = arxiv.Client()
|
||||
search = arxiv.Search(
|
||||
query=kwargs["query"],
|
||||
max_results=self._param.top_n,
|
||||
sort_by=sort_choices[self._param.sort_by]
|
||||
)
|
||||
self._retrieve_chunks(list(arxiv_client.results(search)),
|
||||
get_title=lambda r: r.title,
|
||||
get_url=lambda r: r.pdf_url,
|
||||
get_content=lambda r: r.summary)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"ArXiv error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"ArXiv error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
173
agent/tools/base.py
Normal file
173
agent/tools/base.py
Normal file
@@ -0,0 +1,173 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.utils import hash_str2int
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
from rag.prompts.generator import kb_prompt
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
||||
from timeit import default_timer as timer
|
||||
|
||||
|
||||
class ToolParameter(TypedDict):
|
||||
type: str
|
||||
description: str
|
||||
displayDescription: str
|
||||
enum: List[str]
|
||||
required: bool
|
||||
|
||||
|
||||
class ToolMeta(TypedDict):
|
||||
name: str
|
||||
displayName: str
|
||||
description: str
|
||||
displayDescription: str
|
||||
parameters: dict[str, ToolParameter]
|
||||
|
||||
|
||||
class LLMToolPluginCallSession(ToolCallSession):
|
||||
def __init__(self, tools_map: dict[str, object], callback: partial):
|
||||
self.tools_map = tools_map
|
||||
self.callback = callback
|
||||
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||
st = timer()
|
||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
||||
else:
|
||||
resp = self.tools_map[name].invoke(**arguments)
|
||||
|
||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||
return resp
|
||||
|
||||
def get_tool_obj(self, name):
|
||||
return self.tools_map[name]
|
||||
|
||||
|
||||
class ToolParamBase(ComponentParamBase):
|
||||
def __init__(self):
|
||||
#self.meta:ToolMeta = None
|
||||
super().__init__()
|
||||
self._init_inputs()
|
||||
self._init_attr_by_meta()
|
||||
|
||||
def _init_inputs(self):
|
||||
self.inputs = {}
|
||||
for k,p in self.meta["parameters"].items():
|
||||
self.inputs[k] = deepcopy(p)
|
||||
|
||||
def _init_attr_by_meta(self):
|
||||
for k,p in self.meta["parameters"].items():
|
||||
if not hasattr(self, k):
|
||||
setattr(self, k, p.get("default"))
|
||||
|
||||
def get_meta(self):
|
||||
params = {}
|
||||
for k, p in self.meta["parameters"].items():
|
||||
params[k] = {
|
||||
"type": p["type"],
|
||||
"description": p["description"]
|
||||
}
|
||||
if "enum" in p:
|
||||
params[k]["enum"] = p["enum"]
|
||||
|
||||
desc = self.meta["description"]
|
||||
if hasattr(self, "description"):
|
||||
desc = self.description
|
||||
|
||||
function_name = self.meta["name"]
|
||||
if hasattr(self, "function_name"):
|
||||
function_name = self.function_name
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"description": desc,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": params,
|
||||
"required": [k for k, p in self.meta["parameters"].items() if p["required"]]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ToolBase(ComponentBase):
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
from agent.canvas import Canvas # Local import to avoid cyclic dependency
|
||||
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
|
||||
self._canvas = canvas
|
||||
self._id = id
|
||||
self._param = param
|
||||
self._param.check()
|
||||
|
||||
def get_meta(self) -> dict[str, Any]:
|
||||
return self._param.get_meta()
|
||||
|
||||
def invoke(self, **kwargs):
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
res = self._invoke(**kwargs)
|
||||
except Exception as e:
|
||||
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||
logging.exception(e)
|
||||
res = str(e)
|
||||
self._param.debug_inputs = []
|
||||
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||
chunks = []
|
||||
aggs = []
|
||||
for r in res_list:
|
||||
content = get_content(r)
|
||||
if not content:
|
||||
continue
|
||||
content = re.sub(r"!?\[[a-z]+\]\(data:image/png;base64,[ 0-9A-Za-z/_=+-]+\)", "", content)
|
||||
content = content[:10000]
|
||||
if not content:
|
||||
continue
|
||||
id = str(hash_str2int(content))
|
||||
title = get_title(r)
|
||||
url = get_url(r)
|
||||
score = get_score(r) if get_score else 1
|
||||
chunks.append({
|
||||
"chunk_id": id,
|
||||
"content": content,
|
||||
"doc_id": id,
|
||||
"docnm_kwd": title,
|
||||
"similarity": score,
|
||||
"url": url
|
||||
})
|
||||
aggs.append({
|
||||
"doc_name": title,
|
||||
"doc_id": id,
|
||||
"count": 1,
|
||||
"url": url
|
||||
})
|
||||
self._canvas.add_reference(chunks, aggs)
|
||||
self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True)))
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return self._canvas.get_component_name(self._id) + " is running..."
|
||||
201
agent/tools/code_exec.py
Normal file
201
agent/tools/code_exec.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from strenum import StrEnum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from api import settings
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class Language(StrEnum):
|
||||
PYTHON = "python"
|
||||
NODEJS = "nodejs"
|
||||
|
||||
|
||||
class CodeExecutionRequest(BaseModel):
|
||||
code_b64: str = Field(..., description="Base64 encoded code string")
|
||||
language: str = Field(default=Language.PYTHON.value, description="Programming language")
|
||||
arguments: Optional[dict] = Field(default={}, description="Arguments")
|
||||
|
||||
@field_validator("code_b64")
|
||||
@classmethod
|
||||
def validate_base64(cls, v: str) -> str:
|
||||
try:
|
||||
base64.b64decode(v, validate=True)
|
||||
return v
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid base64 encoding: {str(e)}")
|
||||
|
||||
@field_validator("language", mode="before")
|
||||
@classmethod
|
||||
def normalize_language(cls, v) -> str:
|
||||
if isinstance(v, str):
|
||||
low = v.lower()
|
||||
if low in ("python", "python3"):
|
||||
return "python"
|
||||
elif low in ("javascript", "nodejs"):
|
||||
return "nodejs"
|
||||
raise ValueError(f"Unsupported language: {v}")
|
||||
|
||||
|
||||
class CodeExecParam(ToolParamBase):
|
||||
"""
|
||||
Define the code sandbox component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "execute_code",
|
||||
"description": """
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
|
||||
Here's a code example for Python(`main` function MUST be included):
|
||||
def main() -> dict:
|
||||
\"\"\"
|
||||
Generate Fibonacci numbers within 100.
|
||||
\"\"\"
|
||||
def fibonacci_recursive(n):
|
||||
if n <= 1:
|
||||
return n
|
||||
else:
|
||||
return fibonacci_recursive(n-1) + fibonacci_recursive(n-2)
|
||||
return {
|
||||
"result": fibonacci_recursive(100),
|
||||
}
|
||||
|
||||
Here's a code example for Javascript(`main` function MUST be included and exported):
|
||||
const axios = require('axios');
|
||||
async function main(args) {
|
||||
try {
|
||||
const response = await axios.get('https://github.com/infiniflow/ragflow');
|
||||
console.log('Body:', response.data);
|
||||
} catch (error) {
|
||||
console.error('Error:', error.message);
|
||||
}
|
||||
}
|
||||
module.exports = { main };
|
||||
""",
|
||||
"parameters": {
|
||||
"lang": {
|
||||
"type": "string",
|
||||
"description": "The programming language of this piece of code.",
|
||||
"enum": ["python", "javascript"],
|
||||
"required": True,
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "A piece of code in right format. There MUST be main function.",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.lang = Language.PYTHON.value
|
||||
self.script = "def main(arg1: str, arg2: str) -> dict: return {\"result\": arg1 + arg2}"
|
||||
self.arguments = {}
|
||||
self.outputs = {"result": {"value": "", "type": "string"}}
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.lang, "Support languages", ["python", "python3", "nodejs", "javascript"])
|
||||
self.check_empty(self.script, "Script")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
res = {}
|
||||
for k, v in self.arguments.items():
|
||||
res[k] = {
|
||||
"type": "line",
|
||||
"name": k
|
||||
}
|
||||
return res
|
||||
|
||||
|
||||
class CodeExec(ToolBase, ABC):
|
||||
component_name = "CodeExec"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
lang = kwargs.get("lang", self._param.lang)
|
||||
script = kwargs.get("script", self._param.script)
|
||||
arguments = {}
|
||||
for k, v in self._param.arguments.items():
|
||||
if kwargs.get(k):
|
||||
arguments[k] = kwargs[k]
|
||||
continue
|
||||
arguments[k] = self._canvas.get_variable_value(v) if v else None
|
||||
|
||||
self._execute_code(
|
||||
language=lang,
|
||||
code=script,
|
||||
arguments=arguments
|
||||
)
|
||||
|
||||
def _execute_code(self, language: str, code: str, arguments: dict):
|
||||
import requests
|
||||
|
||||
try:
|
||||
code_b64 = self._encode_code(code)
|
||||
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
|
||||
except Exception as e:
|
||||
self.set_output("_ERROR", "construct code request error: " + str(e))
|
||||
|
||||
try:
|
||||
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
|
||||
if resp.status_code != 200:
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
if body:
|
||||
stderr = body.get("stderr")
|
||||
if stderr:
|
||||
self.set_output("_ERROR", stderr)
|
||||
return
|
||||
try:
|
||||
rt = eval(body.get("stdout", ""))
|
||||
except Exception:
|
||||
rt = body.get("stdout", "")
|
||||
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
|
||||
if isinstance(rt, tuple):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[i]
|
||||
elif isinstance(rt, dict):
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if k not in rt or k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt[k]
|
||||
else:
|
||||
for i, (k, o) in enumerate(self._param.outputs.items()):
|
||||
if k.find("_") == 0:
|
||||
continue
|
||||
o["value"] = rt
|
||||
else:
|
||||
self.set_output("_ERROR", "There is no response from sandbox")
|
||||
|
||||
except Exception as e:
|
||||
self.set_output("_ERROR", "Exception executing code: " + str(e))
|
||||
|
||||
return self.output()
|
||||
|
||||
def _encode_code(self, code: str) -> str:
|
||||
return base64.b64encode(code.encode("utf-8")).decode("utf-8")
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Running a short script to process data."
|
||||
68
agent/tools/crawler.py
Normal file
68
agent/tools/crawler.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
from agent.tools.base import ToolParamBase, ToolBase
|
||||
|
||||
|
||||
|
||||
class CrawlerParam(ToolParamBase):
|
||||
"""
|
||||
Define the Crawler component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.proxy = None
|
||||
self.extract_type = "markdown"
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
|
||||
|
||||
|
||||
class Crawler(ToolBase, ABC):
|
||||
component_name = "Crawler"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
from api.utils.web_utils import is_valid_url
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not is_valid_url(ans):
|
||||
return Crawler.be_output("URL not valid")
|
||||
try:
|
||||
result = asyncio.run(self.get_web(ans))
|
||||
|
||||
return Crawler.be_output(result)
|
||||
|
||||
except Exception as e:
|
||||
return Crawler.be_output(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
async def get_web(self, url):
|
||||
proxy = self._param.proxy if self._param.proxy else None
|
||||
async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
bypass_cache=True
|
||||
)
|
||||
|
||||
if self._param.extract_type == 'html':
|
||||
return result.cleaned_html
|
||||
elif self._param.extract_type == 'markdown':
|
||||
return result.markdown
|
||||
elif self._param.extract_type == 'content':
|
||||
return result.extracted_content
|
||||
return result.markdown
|
||||
61
agent/tools/deepl.py
Normal file
61
agent/tools/deepl.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
import deepl
|
||||
|
||||
|
||||
class DeepLParam(ComponentParamBase):
|
||||
"""
|
||||
Define the DeepL component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.auth_key = "xxx"
|
||||
self.parameters = []
|
||||
self.source_lang = 'ZH'
|
||||
self.target_lang = 'EN-GB'
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.source_lang, "Source language",
|
||||
['AR', 'BG', 'CS', 'DA', 'DE', 'EL', 'EN', 'ES', 'ET', 'FI', 'FR', 'HU', 'ID', 'IT',
|
||||
'JA', 'KO', 'LT', 'LV', 'NB', 'NL', 'PL', 'PT', 'RO', 'RU', 'SK', 'SL', 'SV', 'TR',
|
||||
'UK', 'ZH'])
|
||||
self.check_valid_value(self.target_lang, "Target language",
|
||||
['AR', 'BG', 'CS', 'DA', 'DE', 'EL', 'EN-GB', 'EN-US', 'ES', 'ET', 'FI', 'FR', 'HU',
|
||||
'ID', 'IT', 'JA', 'KO', 'LT', 'LV', 'NB', 'NL', 'PL', 'PT-BR', 'PT-PT', 'RO', 'RU',
|
||||
'SK', 'SL', 'SV', 'TR', 'UK', 'ZH'])
|
||||
|
||||
|
||||
class DeepL(ComponentBase, ABC):
|
||||
component_name = "DeepL"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return DeepL.be_output("")
|
||||
|
||||
try:
|
||||
translator = deepl.Translator(self._param.auth_key)
|
||||
result = translator.translate_text(ans, source_lang=self._param.source_lang,
|
||||
target_lang=self._param.target_lang)
|
||||
|
||||
return DeepL.be_output(result.text)
|
||||
except Exception as e:
|
||||
DeepL.be_output("**Error**:" + str(e))
|
||||
120
agent/tools/duckduckgo.py
Normal file
120
agent/tools/duckduckgo.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from duckduckgo_search import DDGS
|
||||
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class DuckDuckGoParam(ToolParamBase):
|
||||
"""
|
||||
Define the DuckDuckGo component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "duckduckgo_search",
|
||||
"description": "DuckDuckGo is a search engine focused on privacy. It offers search capabilities for web pages, images, and provides translation services. DuckDuckGo also features a private AI chat interface, providing users with an AI assistant that prioritizes data protection.",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keywords to execute with DuckDuckGo. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "default:general. The category of the search. `news` is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. `general` is for broader, more general-purpose searches that may include a wide range of sources.",
|
||||
"enum": ["general", "news"],
|
||||
"default": "general",
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.channel = "text"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.channel, "Web Search or News", ["text", "news"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
},
|
||||
"channel": {
|
||||
"name": "Channel",
|
||||
"type": "options",
|
||||
"value": "general",
|
||||
"options": ["general", "news"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DuckDuckGo(ToolBase, ABC):
|
||||
component_name = "DuckDuckGo"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
if kwargs.get("topic", "general") == "general":
|
||||
with DDGS() as ddgs:
|
||||
# {'title': '', 'href': '', 'body': ''}
|
||||
duck_res = ddgs.text(kwargs["query"], max_results=self._param.top_n)
|
||||
self._retrieve_chunks(duck_res,
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r.get("href", r.get("url")),
|
||||
get_content=lambda r: r["body"])
|
||||
self.set_output("json", duck_res)
|
||||
return self.output("formalized_content")
|
||||
else:
|
||||
with DDGS() as ddgs:
|
||||
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
|
||||
duck_res = ddgs.news(kwargs["query"], max_results=self._param.top_n)
|
||||
self._retrieve_chunks(duck_res,
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r.get("href", r.get("url")),
|
||||
get_content=lambda r: r["body"])
|
||||
self.set_output("json", duck_res)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"DuckDuckGo error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"DuckDuckGo error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
215
agent/tools/email.py
Normal file
215
agent/tools/email.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import json
|
||||
import smtplib
|
||||
import logging
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.header import Header
|
||||
from email.utils import formataddr
|
||||
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class EmailParam(ToolParamBase):
|
||||
"""
|
||||
Define the Email component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "email",
|
||||
"description": "The email is a method of electronic communication for sending and receiving information through the Internet. This tool helps users to send emails to one person or to multiple recipients with support for CC, BCC, file attachments, and markdown-to-HTML conversion.",
|
||||
"parameters": {
|
||||
"to_email": {
|
||||
"type": "string",
|
||||
"description": "The target email address.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
},
|
||||
"cc_email": {
|
||||
"type": "string",
|
||||
"description": "The other email addresses needs to be send to. Comma splited.",
|
||||
"default": "",
|
||||
"required": False
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content of the email.",
|
||||
"default": "",
|
||||
"required": False
|
||||
},
|
||||
"subject": {
|
||||
"type": "string",
|
||||
"description": "The subject/title of the email.",
|
||||
"default": "",
|
||||
"required": False
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
# Fixed configuration parameters
|
||||
self.smtp_server = "" # SMTP server address
|
||||
self.smtp_port = 465 # SMTP port
|
||||
self.email = "" # Sender email
|
||||
self.password = "" # Email authorization code
|
||||
self.sender_name = "" # Sender name
|
||||
|
||||
def check(self):
|
||||
# Check required parameters
|
||||
self.check_empty(self.smtp_server, "SMTP Server")
|
||||
self.check_empty(self.email, "Email")
|
||||
self.check_empty(self.password, "Password")
|
||||
self.check_empty(self.sender_name, "Sender Name")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"to_email": {
|
||||
"name": "To ",
|
||||
"type": "line"
|
||||
},
|
||||
"subject": {
|
||||
"name": "Subject",
|
||||
"type": "line",
|
||||
"optional": True
|
||||
},
|
||||
"cc_email": {
|
||||
"name": "CC To",
|
||||
"type": "line",
|
||||
"optional": True
|
||||
},
|
||||
}
|
||||
|
||||
class Email(ToolBase, ABC):
|
||||
component_name = "Email"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("to_email"):
|
||||
self.set_output("success", False)
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
# Parse JSON string passed from upstream
|
||||
email_data = kwargs
|
||||
|
||||
# Validate required fields
|
||||
if "to_email" not in email_data:
|
||||
return Email.be_output("Missing required field: to_email")
|
||||
|
||||
# Create email object
|
||||
msg = MIMEMultipart('alternative')
|
||||
|
||||
# Properly handle sender name encoding
|
||||
msg['From'] = formataddr((str(Header(self._param.sender_name,'utf-8')), self._param.email))
|
||||
msg['To'] = email_data["to_email"]
|
||||
if email_data.get("cc_email"):
|
||||
msg['Cc'] = email_data["cc_email"]
|
||||
msg['Subject'] = Header(email_data.get("subject", "No Subject"), 'utf-8').encode()
|
||||
|
||||
# Use content from email_data or default content
|
||||
email_content = email_data.get("content", "No content provided")
|
||||
# msg.attach(MIMEText(email_content, 'plain', 'utf-8'))
|
||||
msg.attach(MIMEText(email_content, 'html', 'utf-8'))
|
||||
|
||||
# Connect to SMTP server and send
|
||||
logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}")
|
||||
|
||||
context = smtplib.ssl.create_default_context()
|
||||
with smtplib.SMTP(self._param.smtp_server, self._param.smtp_port) as server:
|
||||
server.ehlo()
|
||||
server.starttls(context=context)
|
||||
server.ehlo()
|
||||
# Login
|
||||
logging.info(f"Attempting to login with email: {self._param.email}")
|
||||
server.login(self._param.email, self._param.password)
|
||||
|
||||
# Get all recipient list
|
||||
recipients = [email_data["to_email"]]
|
||||
if email_data.get("cc_email"):
|
||||
recipients.extend(email_data["cc_email"].split(','))
|
||||
|
||||
# Send email
|
||||
logging.info(f"Sending email to recipients: {recipients}")
|
||||
try:
|
||||
server.send_message(msg, self._param.email, recipients)
|
||||
success = True
|
||||
except Exception as e:
|
||||
logging.error(f"Error during send_message: {str(e)}")
|
||||
# Try alternative method
|
||||
server.sendmail(self._param.email, recipients, msg.as_string())
|
||||
success = True
|
||||
|
||||
try:
|
||||
server.quit()
|
||||
except Exception as e:
|
||||
# Ignore errors when closing connection
|
||||
logging.warning(f"Non-fatal error during connection close: {str(e)}")
|
||||
|
||||
self.set_output("success", success)
|
||||
return success
|
||||
|
||||
except json.JSONDecodeError:
|
||||
error_msg = "Invalid JSON format in input"
|
||||
logging.error(error_msg)
|
||||
self.set_output("_ERROR", error_msg)
|
||||
self.set_output("success", False)
|
||||
return False
|
||||
|
||||
except smtplib.SMTPAuthenticationError:
|
||||
error_msg = "SMTP Authentication failed. Please check your email and authorization code."
|
||||
logging.error(error_msg)
|
||||
self.set_output("_ERROR", error_msg)
|
||||
self.set_output("success", False)
|
||||
return False
|
||||
|
||||
except smtplib.SMTPConnectError:
|
||||
error_msg = f"Failed to connect to SMTP server {self._param.smtp_server}:{self._param.smtp_port}"
|
||||
logging.error(error_msg)
|
||||
last_e = error_msg
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
except smtplib.SMTPException as e:
|
||||
error_msg = f"SMTP error occurred: {str(e)}"
|
||||
logging.error(error_msg)
|
||||
last_e = error_msg
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error: {str(e)}"
|
||||
logging.error(error_msg)
|
||||
self.set_output("_ERROR", error_msg)
|
||||
self.set_output("success", False)
|
||||
return False
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return False
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
inputs = self.get_input()
|
||||
return """
|
||||
To: {}
|
||||
Subject: {}
|
||||
Your email is on its way—sit tight!
|
||||
""".format(inputs.get("to_email", "-_-!"), inputs.get("subject", "-_-!"))
|
||||
212
agent/tools/exesql.py
Normal file
212
agent/tools/exesql.py
Normal file
@@ -0,0 +1,212 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
import psycopg2
|
||||
import pyodbc
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class ExeSQLParam(ToolParamBase):
|
||||
"""
|
||||
Define the ExeSQL component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "execute_sql",
|
||||
"description": "This is a tool that can execute SQL.",
|
||||
"parameters": {
|
||||
"sql": {
|
||||
"type": "string",
|
||||
"description": "The SQL needs to be executed.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.db_type = "mysql"
|
||||
self.database = ""
|
||||
self.username = ""
|
||||
self.host = ""
|
||||
self.port = 3306
|
||||
self.password = ""
|
||||
self.max_records = 1024
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2'])
|
||||
self.check_empty(self.database, "Database name")
|
||||
self.check_empty(self.username, "database username")
|
||||
self.check_empty(self.host, "IP Address")
|
||||
self.check_positive_integer(self.port, "IP Port")
|
||||
self.check_empty(self.password, "Database password")
|
||||
self.check_positive_integer(self.max_records, "Maximum number of records")
|
||||
if self.database == "rag_flow":
|
||||
if self.host == "ragflow-mysql":
|
||||
raise ValueError("For the security reason, it dose not support database named rag_flow.")
|
||||
if self.password == "infini_rag_flow":
|
||||
raise ValueError("For the security reason, it dose not support database named rag_flow.")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"sql": {
|
||||
"name": "SQL",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ExeSQL(ToolBase, ABC):
|
||||
component_name = "ExeSQL"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
|
||||
def convert_decimals(obj):
|
||||
from decimal import Decimal
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj) # 或 str(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: convert_decimals(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_decimals(item) for item in obj]
|
||||
return obj
|
||||
|
||||
sql = kwargs.get("sql")
|
||||
if not sql:
|
||||
raise Exception("SQL for `ExeSQL` MUST not be empty.")
|
||||
|
||||
vars = self.get_input_elements_from_text(sql)
|
||||
args = {}
|
||||
for k, o in vars.items():
|
||||
args[k] = o["value"]
|
||||
if not isinstance(args[k], str):
|
||||
try:
|
||||
args[k] = json.dumps(args[k], ensure_ascii=False)
|
||||
except Exception:
|
||||
args[k] = str(args[k])
|
||||
self.set_input_value(k, args[k])
|
||||
sql = self.string_format(sql, args)
|
||||
|
||||
sqls = sql.split(";")
|
||||
if self._param.db_type in ["mysql", "mariadb"]:
|
||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
||||
port=self._param.port, password=self._param.password)
|
||||
elif self._param.db_type == 'postgres':
|
||||
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
|
||||
port=self._param.port, password=self._param.password)
|
||||
elif self._param.db_type == 'mssql':
|
||||
conn_str = (
|
||||
r'DRIVER={ODBC Driver 17 for SQL Server};'
|
||||
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
|
||||
r'DATABASE=' + self._param.database + ';'
|
||||
r'UID=' + self._param.username + ';'
|
||||
r'PWD=' + self._param.password
|
||||
)
|
||||
db = pyodbc.connect(conn_str)
|
||||
elif self._param.db_type == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={self._param.database};"
|
||||
f"HOSTNAME={self._param.host};"
|
||||
f"PORT={self._param.port};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={self._param.username};"
|
||||
f"PWD={self._param.password};"
|
||||
)
|
||||
try:
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
except Exception as e:
|
||||
raise Exception("Database Connection Failed! \n" + str(e))
|
||||
|
||||
sql_res = []
|
||||
formalized_content = []
|
||||
for single_sql in sqls:
|
||||
single_sql = single_sql.replace("```", "").strip()
|
||||
if not single_sql:
|
||||
continue
|
||||
single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
|
||||
|
||||
stmt = ibm_db.exec_immediate(conn, single_sql)
|
||||
rows = []
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
while row and len(rows) < self._param.max_records:
|
||||
rows.append(row)
|
||||
row = ibm_db.fetch_assoc(stmt)
|
||||
|
||||
if not rows:
|
||||
sql_res.append({"content": "No record in the database!"})
|
||||
continue
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
for col in df.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(df[col]):
|
||||
df[col] = df[col].dt.strftime("%Y-%m-%d")
|
||||
|
||||
df = df.where(pd.notnull(df), None)
|
||||
|
||||
sql_res.append(convert_decimals(df.to_dict(orient="records")))
|
||||
formalized_content.append(df.to_markdown(index=False, floatfmt=".6f"))
|
||||
|
||||
ibm_db.close(conn)
|
||||
|
||||
self.set_output("json", sql_res)
|
||||
self.set_output("formalized_content", "\n\n".join(formalized_content))
|
||||
return self.output("formalized_content")
|
||||
try:
|
||||
cursor = db.cursor()
|
||||
except Exception as e:
|
||||
raise Exception("Database Connection Failed! \n" + str(e))
|
||||
|
||||
sql_res = []
|
||||
formalized_content = []
|
||||
for single_sql in sqls:
|
||||
single_sql = single_sql.replace('```','')
|
||||
if not single_sql:
|
||||
continue
|
||||
single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
|
||||
cursor.execute(single_sql)
|
||||
if cursor.rowcount == 0:
|
||||
sql_res.append({"content": "No record in the database!"})
|
||||
break
|
||||
if self._param.db_type == 'mssql':
|
||||
single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records),
|
||||
columns=[desc[0] for desc in cursor.description])
|
||||
else:
|
||||
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)])
|
||||
single_res.columns = [i[0] for i in cursor.description]
|
||||
|
||||
for col in single_res.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(single_res[col]):
|
||||
single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
|
||||
|
||||
single_res = single_res.where(pd.notnull(single_res), None)
|
||||
|
||||
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
|
||||
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
|
||||
|
||||
self.set_output("json", sql_res)
|
||||
self.set_output("formalized_content", "\n\n".join(formalized_content))
|
||||
return self.output("formalized_content")
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Query sent—waiting for the data."
|
||||
91
agent/tools/github.py
Normal file
91
agent/tools/github.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import requests
|
||||
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class GitHubParam(ToolParamBase):
|
||||
"""
|
||||
Define the GitHub component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "github_search",
|
||||
"description": """GitHub repository search is a feature that enables users to find specific repositories on the GitHub platform. This search functionality allows users to locate projects, codebases, and other content hosted on GitHub based on various criteria.""",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keywords to execute with GitHub. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class GitHub(ToolBase, ABC):
|
||||
component_name = "GitHub"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + '&sort=stars&order=desc&per_page=' + str(
|
||||
self._param.top_n)
|
||||
headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'}
|
||||
response = requests.get(url=url, headers=headers).json()
|
||||
self._retrieve_chunks(response['items'],
|
||||
get_title=lambda r: r["name"],
|
||||
get_url=lambda r: r["html_url"],
|
||||
get_content=lambda r: str(r["description"]) + '\n stars:' + str(r['watchers']))
|
||||
self.set_output("json", response['items'])
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"GitHub error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"GitHub error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Scanning GitHub repos related to `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
159
agent/tools/google.py
Normal file
159
agent/tools/google.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from serpapi import GoogleSearch
|
||||
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class GoogleParam(ToolParamBase):
|
||||
"""
|
||||
Define the Google component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "google_search",
|
||||
"description": """Search the world's information, including webpages, images, videos and more. Google has many special features to help you find exactly what you're looking ...""",
|
||||
"parameters": {
|
||||
"q": {
|
||||
"type": "string",
|
||||
"description": "The search keywords to execute with Google. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
},
|
||||
"start": {
|
||||
"type": "integer",
|
||||
"description": "Parameter defines the result offset. It skips the given number of results. It's used for pagination. (e.g., 0 (default) is the first page of results, 10 is the 2nd page of results, 20 is the 3rd page of results, etc.). Google Local Results only accepts multiples of 20(e.g. 20 for the second page results, 40 for the third page results, etc.) as the `start` value.",
|
||||
"default": "0",
|
||||
"required": False,
|
||||
},
|
||||
"num": {
|
||||
"type": "integer",
|
||||
"description": "Parameter defines the maximum number of results to return. (e.g., 10 (default) returns 10 results, 40 returns 40 results, and 100 returns 100 results). The use of num may introduce latency, and/or prevent the inclusion of specialized result types. It is better to omit this parameter unless it is strictly necessary to increase the number of results per page. Results are not guaranteed to have the number of results specified in num.",
|
||||
"default": "6",
|
||||
"required": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.start = 0
|
||||
self.num = 6
|
||||
self.api_key = ""
|
||||
self.country = "cn"
|
||||
self.language = "en"
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.api_key, "SerpApi API key")
|
||||
self.check_valid_value(self.country, "Google Country",
|
||||
['af', 'al', 'dz', 'as', 'ad', 'ao', 'ai', 'aq', 'ag', 'ar', 'am', 'aw', 'au', 'at',
|
||||
'az', 'bs', 'bh', 'bd', 'bb', 'by', 'be', 'bz', 'bj', 'bm', 'bt', 'bo', 'ba', 'bw',
|
||||
'bv', 'br', 'io', 'bn', 'bg', 'bf', 'bi', 'kh', 'cm', 'ca', 'cv', 'ky', 'cf', 'td',
|
||||
'cl', 'cn', 'cx', 'cc', 'co', 'km', 'cg', 'cd', 'ck', 'cr', 'ci', 'hr', 'cu', 'cy',
|
||||
'cz', 'dk', 'dj', 'dm', 'do', 'ec', 'eg', 'sv', 'gq', 'er', 'ee', 'et', 'fk', 'fo',
|
||||
'fj', 'fi', 'fr', 'gf', 'pf', 'tf', 'ga', 'gm', 'ge', 'de', 'gh', 'gi', 'gr', 'gl',
|
||||
'gd', 'gp', 'gu', 'gt', 'gn', 'gw', 'gy', 'ht', 'hm', 'va', 'hn', 'hk', 'hu', 'is',
|
||||
'in', 'id', 'ir', 'iq', 'ie', 'il', 'it', 'jm', 'jp', 'jo', 'kz', 'ke', 'ki', 'kp',
|
||||
'kr', 'kw', 'kg', 'la', 'lv', 'lb', 'ls', 'lr', 'ly', 'li', 'lt', 'lu', 'mo', 'mk',
|
||||
'mg', 'mw', 'my', 'mv', 'ml', 'mt', 'mh', 'mq', 'mr', 'mu', 'yt', 'mx', 'fm', 'md',
|
||||
'mc', 'mn', 'ms', 'ma', 'mz', 'mm', 'na', 'nr', 'np', 'nl', 'an', 'nc', 'nz', 'ni',
|
||||
'ne', 'ng', 'nu', 'nf', 'mp', 'no', 'om', 'pk', 'pw', 'ps', 'pa', 'pg', 'py', 'pe',
|
||||
'ph', 'pn', 'pl', 'pt', 'pr', 'qa', 're', 'ro', 'ru', 'rw', 'sh', 'kn', 'lc', 'pm',
|
||||
'vc', 'ws', 'sm', 'st', 'sa', 'sn', 'rs', 'sc', 'sl', 'sg', 'sk', 'si', 'sb', 'so',
|
||||
'za', 'gs', 'es', 'lk', 'sd', 'sr', 'sj', 'sz', 'se', 'ch', 'sy', 'tw', 'tj', 'tz',
|
||||
'th', 'tl', 'tg', 'tk', 'to', 'tt', 'tn', 'tr', 'tm', 'tc', 'tv', 'ug', 'ua', 'ae',
|
||||
'uk', 'gb', 'us', 'um', 'uy', 'uz', 'vu', 've', 'vn', 'vg', 'vi', 'wf', 'eh', 'ye',
|
||||
'zm', 'zw'])
|
||||
self.check_valid_value(self.language, "Google languages",
|
||||
['af', 'ak', 'sq', 'ws', 'am', 'ar', 'hy', 'az', 'eu', 'be', 'bem', 'bn', 'bh',
|
||||
'xx-bork', 'bs', 'br', 'bg', 'bt', 'km', 'ca', 'chr', 'ny', 'zh-cn', 'zh-tw', 'co',
|
||||
'hr', 'cs', 'da', 'nl', 'xx-elmer', 'en', 'eo', 'et', 'ee', 'fo', 'tl', 'fi', 'fr',
|
||||
'fy', 'gaa', 'gl', 'ka', 'de', 'el', 'kl', 'gn', 'gu', 'xx-hacker', 'ht', 'ha', 'haw',
|
||||
'iw', 'hi', 'hu', 'is', 'ig', 'id', 'ia', 'ga', 'it', 'ja', 'jw', 'kn', 'kk', 'rw',
|
||||
'rn', 'xx-klingon', 'kg', 'ko', 'kri', 'ku', 'ckb', 'ky', 'lo', 'la', 'lv', 'ln', 'lt',
|
||||
'loz', 'lg', 'ach', 'mk', 'mg', 'ms', 'ml', 'mt', 'mv', 'mi', 'mr', 'mfe', 'mo', 'mn',
|
||||
'sr-me', 'my', 'ne', 'pcm', 'nso', 'no', 'nn', 'oc', 'or', 'om', 'ps', 'fa',
|
||||
'xx-pirate', 'pl', 'pt', 'pt-br', 'pt-pt', 'pa', 'qu', 'ro', 'rm', 'nyn', 'ru', 'gd',
|
||||
'sr', 'sh', 'st', 'tn', 'crs', 'sn', 'sd', 'si', 'sk', 'sl', 'so', 'es', 'es-419', 'su',
|
||||
'sw', 'sv', 'tg', 'ta', 'tt', 'te', 'th', 'ti', 'to', 'lua', 'tum', 'tr', 'tk', 'tw',
|
||||
'ug', 'uk', 'ur', 'uz', 'vu', 'vi', 'cy', 'wo', 'xh', 'yi', 'yo', 'zu']
|
||||
)
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"q": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
},
|
||||
"start": {
|
||||
"name": "From",
|
||||
"type": "integer",
|
||||
"value": 0
|
||||
},
|
||||
"num": {
|
||||
"name": "Limit",
|
||||
"type": "integer",
|
||||
"value": 12
|
||||
}
|
||||
}
|
||||
|
||||
class Google(ToolBase, ABC):
|
||||
component_name = "Google"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("q"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
params = {
|
||||
"api_key": self._param.api_key,
|
||||
"engine": "google",
|
||||
"q": kwargs["q"],
|
||||
"google_domain": "google.com",
|
||||
"gl": self._param.country,
|
||||
"hl": self._param.language
|
||||
}
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
search = GoogleSearch(params).get_dict()
|
||||
self._retrieve_chunks(search["organic_results"],
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r["link"],
|
||||
get_content=lambda r: r.get("about_this_result", {}).get("source", {}).get("description", r["snippet"])
|
||||
)
|
||||
self.set_output("json", search["organic_results"])
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"Google error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"Google error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
96
agent/tools/googlescholar.py
Normal file
96
agent/tools/googlescholar.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from scholarly import scholarly
|
||||
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class GoogleScholarParam(ToolParamBase):
|
||||
"""
|
||||
Define the GoogleScholar component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "google_scholar_search",
|
||||
"description": """Google Scholar provides a simple way to broadly search for scholarly literature. From one place, you can search across many disciplines and sources: articles, theses, books, abstracts and court opinions, from academic publishers, professional societies, online repositories, universities and other web sites. Google Scholar helps you find relevant work across the world of scholarly research.""",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keyword to execute with Google Scholar. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 12
|
||||
self.sort_by = 'relevance'
|
||||
self.year_low = None
|
||||
self.year_high = None
|
||||
self.patents = True
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ['date', 'relevance'])
|
||||
self.check_boolean(self.patents, "Whether or not to include patents, defaults to True")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class GoogleScholar(ToolBase, ABC):
|
||||
component_name = "GoogleScholar"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low,
|
||||
year_high=self._param.year_high, sort_by=self._param.sort_by)
|
||||
self._retrieve_chunks(scholar_client,
|
||||
get_title=lambda r: r['bib']['title'],
|
||||
get_url=lambda r: r["pub_url"],
|
||||
get_content=lambda r: "\n author: " + ",".join(r['bib']['author']) + '\n Abstract: ' + r['bib'].get('abstract', 'no abstract')
|
||||
)
|
||||
self.set_output("json", list(scholar_client))
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"GoogleScholar error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"GoogleScholar error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
130
agent/tools/jin10.py
Normal file
130
agent/tools/jin10.py
Normal file
@@ -0,0 +1,130 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
import requests
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class Jin10Param(ComponentParamBase):
|
||||
"""
|
||||
Define the Jin10 component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.type = "flash"
|
||||
self.secret_key = "xxx"
|
||||
self.flash_type = '1'
|
||||
self.calendar_type = 'cj'
|
||||
self.calendar_datatype = 'data'
|
||||
self.symbols_type = 'GOODS'
|
||||
self.symbols_datatype = 'symbols'
|
||||
self.contain = ""
|
||||
self.filter = ""
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.type, "Type", ['flash', 'calendar', 'symbols', 'news'])
|
||||
self.check_valid_value(self.flash_type, "Flash Type", ['1', '2', '3', '4', '5'])
|
||||
self.check_valid_value(self.calendar_type, "Calendar Type", ['cj', 'qh', 'hk', 'us'])
|
||||
self.check_valid_value(self.calendar_datatype, "Calendar DataType", ['data', 'event', 'holiday'])
|
||||
self.check_valid_value(self.symbols_type, "Symbols Type", ['GOODS', 'FOREX', 'FUTURE', 'CRYPTO'])
|
||||
self.check_valid_value(self.symbols_datatype, 'Symbols DataType', ['symbols', 'quotes'])
|
||||
|
||||
|
||||
class Jin10(ComponentBase, ABC):
|
||||
component_name = "Jin10"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return Jin10.be_output("")
|
||||
|
||||
jin10_res = []
|
||||
headers = {'secret-key': self._param.secret_key}
|
||||
try:
|
||||
if self._param.type == "flash":
|
||||
params = {
|
||||
'category': self._param.flash_type,
|
||||
'contain': self._param.contain,
|
||||
'filter': self._param.filter
|
||||
}
|
||||
response = requests.get(
|
||||
url='https://open-data-api.jin10.com/data-api/flash?category=' + self._param.flash_type,
|
||||
headers=headers, data=json.dumps(params))
|
||||
response = response.json()
|
||||
for i in response['data']:
|
||||
jin10_res.append({"content": i['data']['content']})
|
||||
if self._param.type == "calendar":
|
||||
params = {
|
||||
'category': self._param.calendar_type
|
||||
}
|
||||
response = requests.get(
|
||||
url='https://open-data-api.jin10.com/data-api/calendar/' + self._param.calendar_datatype + '?category=' + self._param.calendar_type,
|
||||
headers=headers, data=json.dumps(params))
|
||||
|
||||
response = response.json()
|
||||
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
|
||||
if self._param.type == "symbols":
|
||||
params = {
|
||||
'type': self._param.symbols_type
|
||||
}
|
||||
if self._param.symbols_datatype == "quotes":
|
||||
params['codes'] = 'BTCUSD'
|
||||
response = requests.get(
|
||||
url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type,
|
||||
headers=headers, data=json.dumps(params))
|
||||
response = response.json()
|
||||
if self._param.symbols_datatype == "symbols":
|
||||
for i in response['data']:
|
||||
i['Commodity Code'] = i['c']
|
||||
i['Stock Exchange'] = i['e']
|
||||
i['Commodity Name'] = i['n']
|
||||
i['Commodity Type'] = i['t']
|
||||
del i['c'], i['e'], i['n'], i['t']
|
||||
if self._param.symbols_datatype == "quotes":
|
||||
for i in response['data']:
|
||||
i['Selling Price'] = i['a']
|
||||
i['Buying Price'] = i['b']
|
||||
i['Commodity Code'] = i['c']
|
||||
i['Stock Exchange'] = i['e']
|
||||
i['Highest Price'] = i['h']
|
||||
i['Yesterday’s Closing Price'] = i['hc']
|
||||
i['Lowest Price'] = i['l']
|
||||
i['Opening Price'] = i['o']
|
||||
i['Latest Price'] = i['p']
|
||||
i['Market Quote Time'] = i['t']
|
||||
del i['a'], i['b'], i['c'], i['e'], i['h'], i['hc'], i['l'], i['o'], i['p'], i['t']
|
||||
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
|
||||
if self._param.type == "news":
|
||||
params = {
|
||||
'contain': self._param.contain,
|
||||
'filter': self._param.filter
|
||||
}
|
||||
response = requests.get(
|
||||
url='https://open-data-api.jin10.com/data-api/news',
|
||||
headers=headers, data=json.dumps(params))
|
||||
response = response.json()
|
||||
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
|
||||
except Exception as e:
|
||||
return Jin10.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not jin10_res:
|
||||
return Jin10.be_output("")
|
||||
|
||||
return pd.DataFrame(jin10_res)
|
||||
108
agent/tools/pubmed.py
Normal file
108
agent/tools/pubmed.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from Bio import Entrez
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class PubMedParam(ToolParamBase):
|
||||
"""
|
||||
Define the PubMed component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "pubmed_search",
|
||||
"description": """
|
||||
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
|
||||
In addition to MEDLINE, PubMed provides access to:
|
||||
- older references from the print version of Index Medicus, back to 1951 and earlier
|
||||
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
|
||||
- very recent entries to records for an article before it is indexed with Medical Subject Headings (MeSH) and added to MEDLINE
|
||||
- a collection of books available full-text and other subsets of NLM records[4]
|
||||
- PMC citations
|
||||
- NCBI Bookshelf
|
||||
""",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keywords to execute with PubMed. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 12
|
||||
self.email = "A.N.Other@example.com"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class PubMed(ToolBase, ABC):
|
||||
component_name = "PubMed"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
Entrez.email = self._param.email
|
||||
pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=kwargs["query"]))['IdList']
|
||||
pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids),
|
||||
retmode="xml").read().decode("utf-8")))
|
||||
self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
|
||||
get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
|
||||
get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
|
||||
get_content=lambda child: child.find("MedlineCitation") \
|
||||
.find("Article") \
|
||||
.find("Abstract") \
|
||||
.find("AbstractText").text \
|
||||
if child.find("MedlineCitation")\
|
||||
.find("Article").find("Abstract") \
|
||||
else "No abstract available")
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"PubMed error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"PubMed error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Looking for scholarly papers on `{}`,” prioritising reputable sources.".format(self.get_input().get("query", "-_-!"))
|
||||
111
agent/tools/qweather.py
Normal file
111
agent/tools/qweather.py
Normal file
@@ -0,0 +1,111 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
import requests
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class QWeatherParam(ComponentParamBase):
|
||||
"""
|
||||
Define the QWeather component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.web_apikey = "xxx"
|
||||
self.lang = "zh"
|
||||
self.type = "weather"
|
||||
self.user_type = 'free'
|
||||
self.error_code = {
|
||||
"204": "The request was successful, but the region you are querying does not have the data you need at this time.",
|
||||
"400": "Request error, may contain incorrect request parameters or missing mandatory request parameters.",
|
||||
"401": "Authentication fails, possibly using the wrong KEY, wrong digital signature, wrong type of KEY (e.g. using the SDK's KEY to access the Web API).",
|
||||
"402": "Exceeded the number of accesses or the balance is not enough to support continued access to the service, you can recharge, upgrade the accesses or wait for the accesses to be reset.",
|
||||
"403": "No access, may be the binding PackageName, BundleID, domain IP address is inconsistent, or the data that requires additional payment.",
|
||||
"404": "The queried data or region does not exist.",
|
||||
"429": "Exceeded the limited QPM (number of accesses per minute), please refer to the QPM description",
|
||||
"500": "No response or timeout, interface service abnormality please contact us"
|
||||
}
|
||||
# Weather
|
||||
self.time_period = 'now'
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.web_apikey, "BaiduFanyi APPID")
|
||||
self.check_valid_value(self.type, "Type", ["weather", "indices", "airquality"])
|
||||
self.check_valid_value(self.user_type, "Free subscription or paid subscription", ["free", "paid"])
|
||||
self.check_valid_value(self.lang, "Use language",
|
||||
['zh', 'zh-hant', 'en', 'de', 'es', 'fr', 'it', 'ja', 'ko', 'ru', 'hi', 'th', 'ar', 'pt',
|
||||
'bn', 'ms', 'nl', 'el', 'la', 'sv', 'id', 'pl', 'tr', 'cs', 'et', 'vi', 'fil', 'fi',
|
||||
'he', 'is', 'nb'])
|
||||
self.check_valid_value(self.time_period, "Time period", ['now', '3d', '7d', '10d', '15d', '30d'])
|
||||
|
||||
|
||||
class QWeather(ComponentBase, ABC):
|
||||
component_name = "QWeather"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = "".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return QWeather.be_output("")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json()
|
||||
if response["code"] == "200":
|
||||
location_id = response["location"][0]["id"]
|
||||
else:
|
||||
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
|
||||
|
||||
base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/"
|
||||
|
||||
if self._param.type == "weather":
|
||||
url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
|
||||
response = requests.get(url=url).json()
|
||||
if response["code"] == "200":
|
||||
if self._param.time_period == "now":
|
||||
return QWeather.be_output(str(response["now"]))
|
||||
else:
|
||||
qweather_res = [{"content": str(i) + "\n"} for i in response["daily"]]
|
||||
if not qweather_res:
|
||||
return QWeather.be_output("")
|
||||
|
||||
df = pd.DataFrame(qweather_res)
|
||||
return df
|
||||
else:
|
||||
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
|
||||
|
||||
elif self._param.type == "indices":
|
||||
url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
|
||||
response = requests.get(url=url).json()
|
||||
if response["code"] == "200":
|
||||
indices_res = response["daily"][0]["date"] + "\n" + "\n".join(
|
||||
[i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]])
|
||||
return QWeather.be_output(indices_res)
|
||||
|
||||
else:
|
||||
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
|
||||
|
||||
elif self._param.type == "airquality":
|
||||
url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
|
||||
response = requests.get(url=url).json()
|
||||
if response["code"] == "200":
|
||||
return QWeather.be_output(str(response["now"]))
|
||||
else:
|
||||
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
|
||||
except Exception as e:
|
||||
return QWeather.be_output("**Error**" + str(e))
|
||||
181
agent/tools/retrieval.py
Normal file
181
agent/tools/retrieval.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
"""
|
||||
Define the Retrieval component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "search_my_dateset",
|
||||
"description": "This tool can be utilized for relevant content searching in the datasets.",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The keywords to search the dataset. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.function_name = "search_my_dateset"
|
||||
self.description = "This tool can be utilized for relevant content searching in the datasets."
|
||||
self.similarity_threshold = 0.2
|
||||
self.keywords_similarity_weight = 0.5
|
||||
self.top_n = 8
|
||||
self.top_k = 1024
|
||||
self.kb_ids = []
|
||||
self.kb_vars = []
|
||||
self.rerank_id = ""
|
||||
self.empty_response = ""
|
||||
self.use_kg = False
|
||||
self.cross_languages = []
|
||||
|
||||
def check(self):
|
||||
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
|
||||
self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keyword similarity weight")
|
||||
self.check_positive_number(self.top_n, "[Retrieval] Top N")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class Retrieval(ToolBase, ABC):
|
||||
component_name = "Retrieval"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
|
||||
kb_ids: list[str] = []
|
||||
for id in self._param.kb_ids:
|
||||
if id.find("@") < 0:
|
||||
kb_ids.append(id)
|
||||
continue
|
||||
kb_nm = self._canvas.get_variable_value(id)
|
||||
# if kb_nm is a list
|
||||
kb_nm_list = kb_nm if isinstance(kb_nm, list) else [kb_nm]
|
||||
for nm_or_id in kb_nm_list:
|
||||
e, kb = KnowledgebaseService.get_by_name(nm_or_id,
|
||||
self._canvas._tenant_id)
|
||||
if not e:
|
||||
e, kb = KnowledgebaseService.get_by_id(nm_or_id)
|
||||
if not e:
|
||||
raise Exception(f"Dataset({nm_or_id}) does not exist.")
|
||||
kb_ids.append(kb.id)
|
||||
|
||||
filtered_kb_ids: list[str] = list(set([kb_id for kb_id in kb_ids if kb_id]))
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
|
||||
if not kbs:
|
||||
raise Exception("No dataset is selected.")
|
||||
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
|
||||
|
||||
embd_mdl = None
|
||||
if embd_nms:
|
||||
embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0])
|
||||
|
||||
rerank_mdl = None
|
||||
if self._param.rerank_id:
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||
|
||||
vars = self.get_input_elements_from_text(kwargs["query"])
|
||||
vars = {k:o["value"] for k,o in vars.items()}
|
||||
query = self.string_format(kwargs["query"], vars)
|
||||
if self._param.cross_languages:
|
||||
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
|
||||
if kbs:
|
||||
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
||||
kbinfos = settings.retrievaler.retrieval(
|
||||
query,
|
||||
embd_mdl,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
filtered_kb_ids,
|
||||
1,
|
||||
self._param.top_n,
|
||||
self._param.similarity_threshold,
|
||||
1 - self._param.keywords_similarity_weight,
|
||||
aggs=False,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(query, kbs),
|
||||
)
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
else:
|
||||
kbinfos = {"chunks": [], "doc_aggs": []}
|
||||
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ck["content"] = ck["content_with_weight"]
|
||||
del ck["content_with_weight"]
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
for ck in kbinfos["chunks"]:
|
||||
if "vector" in ck:
|
||||
del ck["vector"]
|
||||
if "content_ltks" in ck:
|
||||
del ck["content_ltks"]
|
||||
|
||||
if not kbinfos["chunks"]:
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
# Format the chunks for JSON output (similar to how other tools do it)
|
||||
json_output = kbinfos["chunks"].copy()
|
||||
|
||||
self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
|
||||
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
|
||||
|
||||
# Set both formalized content and JSON output
|
||||
self.set_output("formalized_content", form_cnt)
|
||||
self.set_output("json", json_output)
|
||||
|
||||
return form_cnt
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
151
agent/tools/searxng.py
Normal file
151
agent/tools/searxng.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import requests
|
||||
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class SearXNGParam(ToolParamBase):
|
||||
"""
|
||||
Define the SearXNG component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta: ToolMeta = {
|
||||
"name": "searxng_search",
|
||||
"description": "SearXNG is a privacy-focused metasearch engine that aggregates results from multiple search engines without tracking users. It provides comprehensive web search capabilities.",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keywords to execute with SearXNG. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
},
|
||||
"searxng_url": {
|
||||
"type": "string",
|
||||
"description": "The base URL of your SearXNG instance (e.g., http://localhost:4000). This is required to connect to your SearXNG server.",
|
||||
"required": False,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.searxng_url = ""
|
||||
|
||||
def check(self):
|
||||
# Keep validation lenient so opening try-run panel won't fail without URL.
|
||||
# Coerce top_n to int if it comes as string from UI.
|
||||
try:
|
||||
if isinstance(self.top_n, str):
|
||||
self.top_n = int(self.top_n.strip())
|
||||
except Exception:
|
||||
pass
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
},
|
||||
"searxng_url": {
|
||||
"name": "SearXNG URL",
|
||||
"type": "line",
|
||||
"placeholder": "http://localhost:4000"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class SearXNG(ToolBase, ABC):
|
||||
component_name = "SearXNG"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
# Gracefully handle try-run without inputs
|
||||
query = kwargs.get("query")
|
||||
if not query or not isinstance(query, str) or not query.strip():
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
searxng_url = (getattr(self._param, "searxng_url", "") or kwargs.get("searxng_url") or "").strip()
|
||||
# In try-run, if no URL configured, just return empty instead of raising
|
||||
if not searxng_url:
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
search_params = {
|
||||
'q': query,
|
||||
'format': 'json',
|
||||
'categories': 'general',
|
||||
'language': 'auto',
|
||||
'safesearch': 1,
|
||||
'pageno': 1
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
f"{searxng_url}/search",
|
||||
params=search_params,
|
||||
timeout=10
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
if not data or not isinstance(data, dict):
|
||||
raise ValueError("Invalid response from SearXNG")
|
||||
|
||||
results = data.get("results", [])
|
||||
if not isinstance(results, list):
|
||||
raise ValueError("Invalid results format from SearXNG")
|
||||
|
||||
results = results[:self._param.top_n]
|
||||
|
||||
self._retrieve_chunks(results,
|
||||
get_title=lambda r: r.get("title", ""),
|
||||
get_url=lambda r: r.get("url", ""),
|
||||
get_content=lambda r: r.get("content", ""))
|
||||
|
||||
self.set_output("json", results)
|
||||
return self.output("formalized_content")
|
||||
|
||||
except requests.RequestException as e:
|
||||
last_e = f"Network error: {e}"
|
||||
logging.exception(f"SearXNG network error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
except Exception as e:
|
||||
last_e = str(e)
|
||||
logging.exception(f"SearXNG error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", last_e)
|
||||
return f"SearXNG error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Searching with SearXNG for relevant results...
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
227
agent/tools/tavily.py
Normal file
227
agent/tools/tavily.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from tavily import TavilyClient
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class TavilySearchParam(ToolParamBase):
|
||||
"""
|
||||
Define the Retrieval component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "tavily_search",
|
||||
"description": """
|
||||
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
|
||||
When searching:
|
||||
- Start with specific query which should focus on just a single aspect.
|
||||
- Number of keywords in query should be less than 5.
|
||||
- Broaden search terms if needed
|
||||
- Cross-reference information from multiple sources
|
||||
""",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keywords to execute with Tavily. The keywords should be the most important words/terms(includes synonyms) from the original request.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
},
|
||||
"topic": {
|
||||
"type": "string",
|
||||
"description": "default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.",
|
||||
"enum": ["general", "news"],
|
||||
"default": "general",
|
||||
"required": False,
|
||||
},
|
||||
"include_domains": {
|
||||
"type": "array",
|
||||
"description": "default:[]. A list of domains only from which the search results can be included.",
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"description": "Domain name that must be included, e.g. www.yahoo.com"
|
||||
},
|
||||
"required": False
|
||||
},
|
||||
"exclude_domains": {
|
||||
"type": "array",
|
||||
"description": "default:[]. A list of domains from which the search results can not be included",
|
||||
"default": [],
|
||||
"items": {
|
||||
"type": "string",
|
||||
"description": "Domain name that must be excluded, e.g. www.yahoo.com"
|
||||
},
|
||||
"required": False
|
||||
},
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.api_key = ""
|
||||
self.search_depth = "basic" # basic/advanced
|
||||
self.max_results = 6
|
||||
self.days = 14
|
||||
self.include_answer = False
|
||||
self.include_raw_content = False
|
||||
self.include_images = False
|
||||
self.include_image_descriptions = False
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.topic, "Tavily topic: should be in 'general/news'", ["general", "news"])
|
||||
self.check_valid_value(self.search_depth, "Tavily search depth should be in 'basic/advanced'", ["basic", "advanced"])
|
||||
self.check_positive_integer(self.max_results, "Tavily max result number should be within [1, 20]")
|
||||
self.check_positive_integer(self.days, "Tavily days should be greater than 1")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class TavilySearch(ToolBase, ABC):
|
||||
component_name = "TavilySearch"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
||||
last_e = None
|
||||
for fld in ["search_depth", "topic", "max_results", "days", "include_answer", "include_raw_content", "include_images", "include_image_descriptions", "include_domains", "exclude_domains"]:
|
||||
if fld not in kwargs:
|
||||
kwargs[fld] = getattr(self._param, fld)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
kwargs["include_images"] = False
|
||||
kwargs["include_raw_content"] = False
|
||||
res = self.tavily_client.search(**kwargs)
|
||||
self._retrieve_chunks(res["results"],
|
||||
get_title=lambda r: r["title"],
|
||||
get_url=lambda r: r["url"],
|
||||
get_content=lambda r: r["raw_content"] if r["raw_content"] else r["content"],
|
||||
get_score=lambda r: r["score"])
|
||||
self.set_output("json", res["results"])
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"Tavily error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"Tavily error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
|
||||
|
||||
class TavilyExtractParam(ToolParamBase):
|
||||
"""
|
||||
Define the Retrieval component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "tavily_extract",
|
||||
"description": "Extract web page content from one or more specified URLs using Tavily Extract.",
|
||||
"parameters": {
|
||||
"urls": {
|
||||
"type": "array",
|
||||
"description": "The URLs to extract content from.",
|
||||
"default": "",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"description": "The URL to extract content from, e.g. www.yahoo.com"
|
||||
},
|
||||
"required": True
|
||||
},
|
||||
"extract_depth": {
|
||||
"type": "string",
|
||||
"description": "The depth of the extraction process. advanced extraction retrieves more data, including tables and embedded content, with higher success but may increase latency.basic extraction costs 1 credit per 5 successful URL extractions, while advanced extraction costs 2 credits per 5 successful URL extractions.",
|
||||
"enum": ["basic", "advanced"],
|
||||
"default": "basic",
|
||||
"required": False,
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"description": "The format of the extracted web page content. markdown returns content in markdown format. text returns plain text and may increase latency.",
|
||||
"enum": ["markdown", "text"],
|
||||
"default": "markdown",
|
||||
"required": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.api_key = ""
|
||||
self.extract_depth = "basic" # basic/advanced
|
||||
self.urls = []
|
||||
self.format = "markdown"
|
||||
self.include_images = False
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.extract_depth, "Tavily extract depth should be in 'basic/advanced'", ["basic", "advanced"])
|
||||
self.check_valid_value(self.format, "Tavily extract format should be in 'markdown/text'", ["markdown", "text"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"urls": {
|
||||
"name": "URLs",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class TavilyExtract(ToolBase, ABC):
|
||||
component_name = "TavilyExtract"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
self.tavily_client = TavilyClient(api_key=self._param.api_key)
|
||||
last_e = None
|
||||
for fld in ["urls", "extract_depth", "format"]:
|
||||
if fld not in kwargs:
|
||||
kwargs[fld] = getattr(self._param, fld)
|
||||
if kwargs.get("urls") and isinstance(kwargs["urls"], str):
|
||||
kwargs["urls"] = kwargs["urls"].split(",")
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
kwargs["include_images"] = False
|
||||
res = self.tavily_client.extract(**kwargs)
|
||||
self.set_output("json", res["results"])
|
||||
return self.output("json")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"Tavily error: {e}")
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"Tavily error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))
|
||||
72
agent/tools/tushare.py
Normal file
72
agent/tools/tushare.py
Normal file
@@ -0,0 +1,72 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
import time
|
||||
import requests
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class TuShareParam(ComponentParamBase):
|
||||
"""
|
||||
Define the TuShare component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.token = "xxx"
|
||||
self.src = "eastmoney"
|
||||
self.start_date = "2024-01-01 09:00:00"
|
||||
self.end_date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
self.keyword = ""
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.src, "Quick News Source",
|
||||
["sina", "wallstreetcn", "10jqka", "eastmoney", "yuncaijing", "fenghuang", "jinrongjie"])
|
||||
|
||||
|
||||
class TuShare(ComponentBase, ABC):
|
||||
component_name = "TuShare"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
ans = self.get_input()
|
||||
ans = ",".join(ans["content"]) if "content" in ans else ""
|
||||
if not ans:
|
||||
return TuShare.be_output("")
|
||||
|
||||
try:
|
||||
tus_res = []
|
||||
params = {
|
||||
"api_name": "news",
|
||||
"token": self._param.token,
|
||||
"params": {"src": self._param.src, "start_date": self._param.start_date,
|
||||
"end_date": self._param.end_date}
|
||||
}
|
||||
response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8'))
|
||||
response = response.json()
|
||||
if response['code'] != 0:
|
||||
return TuShare.be_output(response['msg'])
|
||||
df = pd.DataFrame(response['data']['items'])
|
||||
df.columns = response['data']['fields']
|
||||
tus_res.append({"content": (df[df['content'].str.contains(self._param.keyword, case=False)]).to_markdown()})
|
||||
except Exception as e:
|
||||
return TuShare.be_output("**ERROR**: " + str(e))
|
||||
|
||||
if not tus_res:
|
||||
return TuShare.be_output("")
|
||||
|
||||
return pd.DataFrame(tus_res)
|
||||
114
agent/tools/wencai.py
Normal file
114
agent/tools/wencai.py
Normal file
@@ -0,0 +1,114 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
import pywencai
|
||||
|
||||
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class WenCaiParam(ToolParamBase):
|
||||
"""
|
||||
Define the WenCai component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "iwencai",
|
||||
"description": """
|
||||
iwencai search: search platform is committed to providing hundreds of millions of investors with the most timely, accurate and comprehensive information, covering news, announcements, research reports, blogs, forums, Weibo, characters, etc.
|
||||
robo-advisor intelligent stock selection platform: through AI technology, is committed to providing investors with intelligent stock selection, quantitative investment, main force tracking, value investment, technical analysis and other types of stock selection technologies.
|
||||
fund selection platform: through AI technology, is committed to providing excellent fund, value investment, quantitative analysis and other fund selection technologies for foundation citizens.
|
||||
""",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The question/conditions to select stocks.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.query_type = "stock"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.query_type, "Query type",
|
||||
['stock', 'zhishu', 'fund', 'hkstock', 'usstock', 'threeboard', 'conbond', 'insurance',
|
||||
'futures', 'lccp',
|
||||
'foreign_exchange'])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class WenCai(ToolBase, ABC):
|
||||
component_name = "WenCai"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("report", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
wencai_res = []
|
||||
res = pywencai.get(query=kwargs["query"], query_type=self._param.query_type, perpage=self._param.top_n)
|
||||
if isinstance(res, pd.DataFrame):
|
||||
wencai_res.append(res.to_markdown())
|
||||
elif isinstance(res, dict):
|
||||
for item in res.items():
|
||||
if isinstance(item[1], list):
|
||||
wencai_res.append(item[0] + "\n" + pd.DataFrame(item[1]).to_markdown())
|
||||
elif isinstance(item[1], str):
|
||||
wencai_res.append(item[0] + "\n" + item[1])
|
||||
elif isinstance(item[1], dict):
|
||||
if "meta" in item[1].keys():
|
||||
continue
|
||||
wencai_res.append(pd.DataFrame.from_dict(item[1], orient='index').to_markdown())
|
||||
elif isinstance(item[1], pd.DataFrame):
|
||||
if "image_url" in item[1].columns:
|
||||
continue
|
||||
wencai_res.append(item[1].to_markdown())
|
||||
else:
|
||||
wencai_res.append(item[0] + "\n" + str(item[1]))
|
||||
self.set_output("report", "\n\n".join(wencai_res))
|
||||
return self.output("report")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"WenCai error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"WenCai error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("query", "-_-!"))
|
||||
104
agent/tools/wikipedia.py
Normal file
104
agent/tools/wikipedia.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import wikipedia
|
||||
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class WikipediaParam(ToolParamBase):
|
||||
"""
|
||||
Define the Wikipedia component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "wikipedia_search",
|
||||
"description": """A wide range of how-to and information pages are made available in wikipedia. Since 2001, it has grown rapidly to become the world's largest reference website. From Wikipedia, the free encyclopedia.""",
|
||||
"parameters": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search keyword to execute with wikipedia. The keyword MUST be a specific subject that can match the title.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.top_n = 10
|
||||
self.language = "en"
|
||||
|
||||
def check(self):
|
||||
self.check_positive_integer(self.top_n, "Top N")
|
||||
self.check_valid_value(self.language, "Wikipedia languages",
|
||||
['af', 'pl', 'ar', 'ast', 'az', 'bg', 'nan', 'bn', 'be', 'ca', 'cs', 'cy', 'da', 'de',
|
||||
'et', 'el', 'en', 'es', 'eo', 'eu', 'fa', 'fr', 'gl', 'ko', 'hy', 'hi', 'hr', 'id',
|
||||
'it', 'he', 'ka', 'lld', 'la', 'lv', 'lt', 'hu', 'mk', 'arz', 'ms', 'min', 'my', 'nl',
|
||||
'ja', 'nb', 'nn', 'ce', 'uz', 'pt', 'kk', 'ro', 'ru', 'ceb', 'sk', 'sl', 'sr', 'sh',
|
||||
'fi', 'sv', 'ta', 'tt', 'th', 'tg', 'azb', 'tr', 'uk', 'ur', 'vi', 'war', 'zh', 'yue'])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class Wikipedia(ToolBase, ABC):
|
||||
component_name = "Wikipedia"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
try:
|
||||
wikipedia.set_lang(self._param.language)
|
||||
wiki_engine = wikipedia
|
||||
pages = []
|
||||
for p in wiki_engine.search(kwargs["query"], results=self._param.top_n):
|
||||
try:
|
||||
pages.append(wikipedia.page(p))
|
||||
except Exception:
|
||||
pass
|
||||
self._retrieve_chunks(pages,
|
||||
get_title=lambda r: r.title,
|
||||
get_url=lambda r: r.url,
|
||||
get_content=lambda r: r.summary)
|
||||
return self.output("formalized_content")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"Wikipedia error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"Wikipedia error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return """
|
||||
Keywords: {}
|
||||
Looking for the most relevant articles.
|
||||
""".format(self.get_input().get("query", "-_-!"))
|
||||
114
agent/tools/yahoofinance.py
Normal file
114
agent/tools/yahoofinance.py
Normal file
@@ -0,0 +1,114 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class YahooFinanceParam(ToolParamBase):
|
||||
"""
|
||||
Define the YahooFinance component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.meta:ToolMeta = {
|
||||
"name": "yahoo_finance",
|
||||
"description": "The Yahoo Finance is a service that provides access to real-time and historical stock market data. It enables users to fetch various types of stock information, such as price quotes, historical prices, company profiles, and financial news. The API offers structured data, allowing developers to integrate market data into their applications and analysis tools.",
|
||||
"parameters": {
|
||||
"stock_code": {
|
||||
"type": "string",
|
||||
"description": "The stock code or company name.",
|
||||
"default": "{sys.query}",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
super().__init__()
|
||||
self.info = True
|
||||
self.history = False
|
||||
self.count = False
|
||||
self.financials = False
|
||||
self.income_stmt = False
|
||||
self.balance_sheet = False
|
||||
self.cash_flow_statement = False
|
||||
self.news = True
|
||||
|
||||
def check(self):
|
||||
self.check_boolean(self.info, "get all stock info")
|
||||
self.check_boolean(self.history, "get historical market data")
|
||||
self.check_boolean(self.count, "show share count")
|
||||
self.check_boolean(self.financials, "show financials")
|
||||
self.check_boolean(self.income_stmt, "income statement")
|
||||
self.check_boolean(self.balance_sheet, "balance sheet")
|
||||
self.check_boolean(self.cash_flow_statement, "cash flow statement")
|
||||
self.check_boolean(self.news, "show news")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"stock_code": {
|
||||
"name": "Stock code/Company name",
|
||||
"type": "line"
|
||||
}
|
||||
}
|
||||
|
||||
class YahooFinance(ToolBase, ABC):
|
||||
component_name = "YahooFinance"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||
def _invoke(self, **kwargs):
|
||||
if not kwargs.get("stock_code"):
|
||||
self.set_output("report", "")
|
||||
return ""
|
||||
|
||||
last_e = ""
|
||||
for _ in range(self._param.max_retries+1):
|
||||
yohoo_res = []
|
||||
try:
|
||||
msft = yf.Ticker(kwargs["stock_code"])
|
||||
if self._param.info:
|
||||
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
|
||||
if self._param.history:
|
||||
yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
|
||||
if self._param.financials:
|
||||
yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
|
||||
if self._param.balance_sheet:
|
||||
yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
|
||||
yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
|
||||
if self._param.cash_flow_statement:
|
||||
yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
|
||||
yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
|
||||
if self._param.news:
|
||||
yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
|
||||
self.set_output("report", "\n\n".join(yohoo_res))
|
||||
return self.output("report")
|
||||
except Exception as e:
|
||||
last_e = e
|
||||
logging.exception(f"YahooFinance error: {e}")
|
||||
time.sleep(self._param.delay_after_error)
|
||||
|
||||
if last_e:
|
||||
self.set_output("_ERROR", str(last_e))
|
||||
return f"YahooFinance error: {last_e}"
|
||||
|
||||
assert False, self.output()
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Pulling live financial data for `{}`.".format(self.get_input().get("stock_code", "-_-!"))
|
||||
Reference in New Issue
Block a user