v0.21.1-fastapi
This commit is contained in:
@@ -143,15 +143,12 @@ class UserCanvasService(CommonService):
|
||||
]
|
||||
if keywords:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
cls.model.user_id.in_(joined_tenant_ids),
|
||||
fn.LOWER(cls.model.title).contains(keywords.lower())
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
|
||||
#(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
|
||||
(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
cls.model.user_id.in_(joined_tenant_ids)
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
)
|
||||
if canvas_category:
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
|
||||
@@ -370,7 +370,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
chat_mdl.bind_tools(toolcall_session, tools)
|
||||
bind_models_ts = timer()
|
||||
|
||||
retriever = settings.retrievaler
|
||||
retriever = settings.retriever
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
if "doc_ids" in messages[-1]:
|
||||
@@ -466,13 +466,17 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs),
|
||||
)
|
||||
if prompt_config.get("toc_enhance"):
|
||||
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
if prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||
ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
@@ -658,7 +662,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
tbl, sql = get_table()
|
||||
if tbl is None:
|
||||
@@ -752,7 +756,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
||||
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||
@@ -848,7 +852,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
ranks = settings.retrievaler.retrieval(
|
||||
ranks = settings.retriever.retrieval(
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
|
||||
@@ -79,7 +79,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name):
|
||||
orderby, desc, keywords, id, name, suffix=None, run = None):
|
||||
fields = cls.get_cls_model_fields()
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
||||
.join(File, on = (File.id == File2Document.file_id))\
|
||||
@@ -96,6 +96,10 @@ class DocumentService(CommonService):
|
||||
docs = docs.where(
|
||||
fn.LOWER(cls.model.name).contains(keywords.lower())
|
||||
)
|
||||
if suffix:
|
||||
docs = docs.where(cls.model.suffix.in_(suffix))
|
||||
if run:
|
||||
docs = docs.where(cls.model.run.in_(run))
|
||||
if desc:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
@@ -667,9 +671,11 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def _sync_progress(cls, docs:list[dict]):
|
||||
from api.db.services.task_service import TaskService
|
||||
|
||||
for d in docs:
|
||||
try:
|
||||
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
|
||||
tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
|
||||
if not tsks:
|
||||
continue
|
||||
msg = []
|
||||
@@ -787,21 +793,23 @@ class DocumentService(CommonService):
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
Optionally, specify a list of doc_ids to determine which documents participate in the task.
|
||||
"""
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
assert ty in ["graphrag", "raptor", "mindmap"], "type should be graphrag, raptor or mindmap"
|
||||
|
||||
chunking_config = DocumentService.get_chunking_config(sample_doc_id["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
hasher.update(str(chunking_config[field]).encode("utf-8"))
|
||||
|
||||
def new_task():
|
||||
nonlocal doc
|
||||
nonlocal sample_doc_id
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
|
||||
"doc_id": sample_doc_id["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
@@ -816,9 +824,9 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[])
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
|
||||
if ty in ["graphrag", "raptor", "mindmap"]:
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
task["doc_id"] = fake_doc_id
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(sample_doc_id["id"])
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
@@ -830,7 +838,7 @@ def get_queue_length(priority):
|
||||
return int(group_info.get("lag", 0) or 0)
|
||||
|
||||
|
||||
async def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
@@ -855,7 +863,7 @@ async def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||
|
||||
err, files = await FileService.upload_document(kb, file_objs, user_id)
|
||||
err, files = FileService.upload_document(kb, file_objs, user_id)
|
||||
assert not err, "\n".join(err)
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
@@ -421,7 +420,7 @@ class FileService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
async def upload_document(self, kb, file_objs, user_id):
|
||||
def upload_document(self, kb, file_objs, user_id):
|
||||
root_folder = self.get_root_folder(user_id)
|
||||
pf_id = root_folder["id"]
|
||||
self.init_knowledgebase_docs(pf_id, user_id)
|
||||
@@ -441,7 +440,14 @@ class FileService(CommonService):
|
||||
while STORAGE_IMPL.obj_exist(kb.id, location):
|
||||
location += "_"
|
||||
|
||||
blob = await file.read()
|
||||
# 支持 FastAPI UploadFile,直接使用 file 属性进行同步读取
|
||||
if hasattr(file, 'file') and hasattr(file, 'filename'):
|
||||
# FastAPI UploadFile
|
||||
file.file.seek(0)
|
||||
blob = file.file.read()
|
||||
else:
|
||||
# 普通文件对象
|
||||
blob = file.read()
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
STORAGE_IMPL.put(kb.id, location, blob)
|
||||
@@ -473,33 +479,33 @@ class FileService(CommonService):
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
files.append((doc, blob))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
err.append(file.filename + ": " + str(e))
|
||||
|
||||
return err, files
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def list_all_files_by_parent_id(cls, parent_id):
|
||||
try:
|
||||
files = cls.model.select().where((cls.model.parent_id == parent_id) & (cls.model.id != parent_id))
|
||||
return list(files)
|
||||
except Exception:
|
||||
logging.exception("list_by_parent_id failed")
|
||||
raise RuntimeError("Database error (list_by_parent_id)!")
|
||||
|
||||
@staticmethod
|
||||
async def parse_docs(file_objs, user_id):
|
||||
def parse_docs(file_objs, user_id):
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
for file in file_objs:
|
||||
# Check if file has async read method (UploadFile)
|
||||
if hasattr(file, 'read') and hasattr(file.read, '__call__'):
|
||||
try:
|
||||
# Try to get the coroutine to check if it's async
|
||||
read_result = file.read()
|
||||
if hasattr(read_result, '__await__'):
|
||||
# It's an async method, await it
|
||||
blob = await read_result
|
||||
else:
|
||||
# It's a sync method
|
||||
blob = read_result
|
||||
except Exception:
|
||||
# Fallback to sync read
|
||||
blob = file.read()
|
||||
# 支持 FastAPI UploadFile,直接使用 file 属性进行同步读取
|
||||
if hasattr(file, 'file') and hasattr(file, 'filename'):
|
||||
# FastAPI UploadFile
|
||||
file.file.seek(0)
|
||||
blob = file.file.read()
|
||||
else:
|
||||
# 普通文件对象
|
||||
blob = file.read()
|
||||
|
||||
threads.append(exe.submit(FileService.parse, file.filename, blob, False))
|
||||
|
||||
res = []
|
||||
|
||||
@@ -379,6 +379,7 @@ class KnowledgebaseService(CommonService):
|
||||
# name: Optional name filter
|
||||
# Returns:
|
||||
# List of knowledge bases
|
||||
# Total count of knowledge bases
|
||||
kbs = cls.model.select()
|
||||
if id:
|
||||
kbs = kbs.where(cls.model.id == id)
|
||||
@@ -390,14 +391,16 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
|
||||
if desc:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
total = kbs.count()
|
||||
kbs = kbs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(kbs.dicts())
|
||||
return list(kbs.dicts()), total
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
||||
@@ -205,32 +205,31 @@ class LLMBundle(LLM4Tenant):
|
||||
return txt
|
||||
|
||||
return txt[last_think_end + len("</think>") :]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _clean_param(chat_partial, **kwargs):
|
||||
func = chat_partial.func
|
||||
sig = inspect.signature(func)
|
||||
keyword_args = []
|
||||
support_var_args = False
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD or param.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
support_var_args = True
|
||||
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||
keyword_args.append(param.name)
|
||||
allowed_params = set()
|
||||
|
||||
use_kwargs = kwargs
|
||||
if not support_var_args:
|
||||
use_kwargs = {k: v for k, v in kwargs.items() if k in keyword_args}
|
||||
return use_kwargs
|
||||
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
support_var_args = True
|
||||
elif param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
|
||||
allowed_params.add(param.name)
|
||||
if support_var_args:
|
||||
return kwargs
|
||||
else:
|
||||
return {k: v for k, v in kwargs.items() if k in allowed_params}
|
||||
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf)
|
||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
|
||||
|
||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
||||
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
txt, used_tokens = chat_partial(**use_kwargs)
|
||||
txt = self._remove_reasoning_content(txt)
|
||||
@@ -266,7 +265,7 @@ class LLMBundle(LLM4Tenant):
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans.rstrip("</think>")
|
||||
ans = ans[: -len("</think>")]
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
@@ -33,7 +33,8 @@ class MCPServerService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords):
|
||||
def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc,
|
||||
keywords):
|
||||
"""Retrieve all MCP servers associated with a tenant.
|
||||
|
||||
This method fetches all MCP servers for a given tenant, ordered by creation time.
|
||||
|
||||
@@ -94,7 +94,8 @@ class SearchService(CommonService):
|
||||
query = (
|
||||
cls.model.select(*fields)
|
||||
.join(User, on=(cls.model.tenant_id == User.id))
|
||||
.where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
|
||||
.where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (
|
||||
cls.model.status == StatusEnum.VALID.value))
|
||||
)
|
||||
|
||||
if keywords:
|
||||
|
||||
@@ -165,7 +165,7 @@ class TaskService(CommonService):
|
||||
]
|
||||
tasks = (
|
||||
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
|
||||
.where(cls.model.doc_id == doc_id)
|
||||
.where(cls.model.doc_id == doc_id)
|
||||
)
|
||||
tasks = list(tasks.dicts())
|
||||
if not tasks:
|
||||
@@ -205,18 +205,18 @@ class TaskService(CommonService):
|
||||
cls.model.select(
|
||||
*[Document.id, Document.kb_id, Document.location, File.parent_id]
|
||||
)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(
|
||||
File2Document,
|
||||
on=(File2Document.document_id == Document.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.join(
|
||||
.join(
|
||||
File,
|
||||
on=(File2Document.file_id == File.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.where(
|
||||
.where(
|
||||
Document.status == StatusEnum.VALID.value,
|
||||
Document.run == TaskStatus.RUNNING.value,
|
||||
~(Document.type == FileType.VIRTUAL.value),
|
||||
@@ -294,8 +294,8 @@ class TaskService(CommonService):
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
else:
|
||||
@@ -343,6 +343,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
- Task digests are calculated for optimization and reuse
|
||||
- Previous task chunks may be reused if available
|
||||
"""
|
||||
|
||||
def new_task():
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
@@ -350,7 +351,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
"progress": 0.0,
|
||||
"from_page": 0,
|
||||
"to_page": 100000000,
|
||||
"begin_at": datetime.now(),
|
||||
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
parse_task_array = []
|
||||
@@ -502,7 +503,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
|
||||
to_page=100000000,
|
||||
task_type="dataflow" if not rerun else "dataflow_rerun",
|
||||
priority=priority,
|
||||
begin_at=datetime.now(),
|
||||
begin_at= datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
|
||||
TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
|
||||
@@ -515,7 +516,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
|
||||
task["file"] = file
|
||||
|
||||
if not REDIS_CONN.queue_product(
|
||||
get_svr_queue_name(priority), message=task
|
||||
get_svr_queue_name(priority), message=task
|
||||
):
|
||||
return False, "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
|
||||
@@ -57,8 +57,10 @@ class TenantLLMService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
|
||||
cls.model.used_tokens]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
|
||||
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
|
||||
@@ -122,7 +124,8 @@ class TenantLLMService(CommonService):
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
||||
if not model_config:
|
||||
if mdlnm == "flag-embedding":
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name,
|
||||
"api_base": ""}
|
||||
else:
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
@@ -137,27 +140,33 @@ class TenantLLMService(CommonService):
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
|
||||
base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
|
||||
base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
|
||||
model_name=model_config["llm_name"], lang=lang,
|
||||
base_url=model_config["api_base"])
|
||||
if llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return
|
||||
@@ -194,11 +203,14 @@ class TenantLLMService(CommonService):
|
||||
try:
|
||||
num = (
|
||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name,
|
||||
cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.execute()
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
||||
logging.exception(
|
||||
"TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s",
|
||||
tenant_id, llm_name)
|
||||
return 0
|
||||
|
||||
return num
|
||||
@@ -206,7 +218,9 @@ class TenantLLMService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"),
|
||||
~(cls.model.llm_name == "text-embedding-3-small"),
|
||||
~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@@ -250,8 +264,9 @@ class LLM4Tenant:
|
||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
||||
self.langfuse = None
|
||||
if langfuse_keys:
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key,
|
||||
host=langfuse_keys.host)
|
||||
if langfuse.auth_check():
|
||||
self.langfuse = langfuse
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
|
||||
@@ -2,22 +2,22 @@ from api.db.db_models import UserCanvasVersion, DB
|
||||
from api.db.services.common_service import CommonService
|
||||
from peewee import DoesNotExist
|
||||
|
||||
|
||||
class UserCanvasVersionService(CommonService):
|
||||
model = UserCanvasVersion
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def list_by_canvas_id(cls, user_canvas_id):
|
||||
try:
|
||||
user_canvas_version = cls.model.select(
|
||||
*[cls.model.id,
|
||||
cls.model.create_time,
|
||||
cls.model.title,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date,
|
||||
cls.model.user_canvas_id,
|
||||
cls.model.update_time]
|
||||
*[cls.model.id,
|
||||
cls.model.create_time,
|
||||
cls.model.title,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date,
|
||||
cls.model.user_canvas_id,
|
||||
cls.model.update_time]
|
||||
).where(cls.model.user_canvas_id == user_canvas_id)
|
||||
return user_canvas_version
|
||||
except DoesNotExist:
|
||||
@@ -46,18 +46,16 @@ class UserCanvasVersionService(CommonService):
|
||||
@DB.connection_context()
|
||||
def delete_all_versions(cls, user_canvas_id):
|
||||
try:
|
||||
user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(cls.model.create_time.desc())
|
||||
user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(
|
||||
cls.model.create_time.desc())
|
||||
if user_canvas_version.count() > 20:
|
||||
delete_ids = []
|
||||
for i in range(20, user_canvas_version.count()):
|
||||
delete_ids.append(user_canvas_version[i].id)
|
||||
|
||||
|
||||
cls.delete_by_ids(delete_ids)
|
||||
return True
|
||||
except DoesNotExist:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user