将flask改成fastapi
This commit is contained in:
99
api/db/services/__init__.py
Normal file
99
api/db/services/__init__.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
from pathlib import PurePath
|
||||
|
||||
from .user_service import UserService as UserService
|
||||
|
||||
|
||||
def _split_name_counter(filename: str) -> tuple[str, int | None]:
|
||||
"""
|
||||
Splits a filename into main part and counter (if present in parentheses).
|
||||
|
||||
Args:
|
||||
filename: Input filename string to be parsed
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The main filename part (string)
|
||||
- The counter from parentheses (integer) or None if no counter exists
|
||||
"""
|
||||
pattern = re.compile(r"^(.*?)\((\d+)\)$")
|
||||
|
||||
match = pattern.search(filename)
|
||||
if match:
|
||||
main_part = match.group(1).rstrip()
|
||||
bracket_part = match.group(2)
|
||||
return main_part, int(bracket_part)
|
||||
|
||||
return filename, None
|
||||
|
||||
|
||||
def duplicate_name(query_func, **kwargs) -> str:
|
||||
"""
|
||||
Generates a unique filename by appending/incrementing a counter when duplicates exist.
|
||||
|
||||
Continuously checks for name availability using the provided query function,
|
||||
automatically appending (1), (2), etc. until finding an available name or
|
||||
reaching maximum retries.
|
||||
|
||||
Args:
|
||||
query_func: Callable that accepts keyword arguments and returns:
|
||||
- True if name exists (should be modified)
|
||||
- False if name is available
|
||||
**kwargs: Must contain 'name' key with original filename to check
|
||||
|
||||
Returns:
|
||||
str: Available filename, either:
|
||||
- Original name (if available)
|
||||
- Modified name with counter (e.g., "file(1).txt")
|
||||
|
||||
Raises:
|
||||
KeyError: If 'name' key not provided in kwargs
|
||||
RuntimeError: If unable to generate unique name after maximum retries
|
||||
|
||||
Example:
|
||||
>>> def name_exists(name): return name in existing_files
|
||||
>>> duplicate_name(name_exists, name="document.pdf")
|
||||
'document(1).pdf' # If original exists
|
||||
"""
|
||||
MAX_RETRIES = 1000
|
||||
|
||||
if "name" not in kwargs:
|
||||
raise KeyError("Arguments must contain 'name' key")
|
||||
|
||||
original_name = kwargs["name"]
|
||||
current_name = original_name
|
||||
retries = 0
|
||||
|
||||
while retries < MAX_RETRIES:
|
||||
if not query_func(**kwargs):
|
||||
return current_name
|
||||
|
||||
path = PurePath(current_name)
|
||||
stem = path.stem
|
||||
suffix = path.suffix
|
||||
|
||||
main_part, counter = _split_name_counter(stem)
|
||||
counter = counter + 1 if counter else 1
|
||||
|
||||
new_name = f"{main_part}({counter}){suffix}"
|
||||
|
||||
kwargs["name"] = new_name
|
||||
current_name = new_name
|
||||
retries += 1
|
||||
|
||||
raise RuntimeError(f"Failed to generate unique name within {MAX_RETRIES} attempts. Original: {original_name}")
|
||||
112
api/db/services/api_service.py
Normal file
112
api/db/services/api_service.py
Normal file
@@ -0,0 +1,112 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
import peewee
|
||||
|
||||
from api.db.db_models import DB, API4Conversation, APIToken, Dialog
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
|
||||
class APITokenService(CommonService):
|
||||
model = APIToken
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def used(cls, token):
|
||||
return cls.model.update({
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": datetime_format(datetime.now()),
|
||||
}).where(
|
||||
cls.model.token == token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
|
||||
class API4ConversationService(CommonService):
|
||||
model = API4Conversation
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, dialog_id, tenant_id,
|
||||
page_number, items_per_page,
|
||||
orderby, desc, id, user_id=None, include_dsl=True, keywords="",
|
||||
from_date=None, to_date=None
|
||||
):
|
||||
if include_dsl:
|
||||
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
|
||||
else:
|
||||
fields = [field for field in cls.model._meta.fields.values() if field.name != 'dsl']
|
||||
sessions = cls.model.select(*fields).where(cls.model.dialog_id == dialog_id)
|
||||
if id:
|
||||
sessions = sessions.where(cls.model.id == id)
|
||||
if user_id:
|
||||
sessions = sessions.where(cls.model.user_id == user_id)
|
||||
if keywords:
|
||||
sessions = sessions.where(peewee.fn.LOWER(cls.model.message).contains(keywords.lower()))
|
||||
if from_date:
|
||||
sessions = sessions.where(cls.model.create_date >= from_date)
|
||||
if to_date:
|
||||
sessions = sessions.where(cls.model.create_date <= to_date)
|
||||
if desc:
|
||||
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
||||
count = sessions.count()
|
||||
sessions = sessions.paginate(page_number, items_per_page)
|
||||
|
||||
return count, list(sessions.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def append_message(cls, id, conversation):
|
||||
cls.update_by_id(id, conversation)
|
||||
return cls.model.update(round=cls.model.round + 1).where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def stats(cls, tenant_id, from_date, to_date, source=None):
|
||||
if len(to_date) == 10:
|
||||
to_date += " 23:59:59"
|
||||
return cls.model.select(
|
||||
cls.model.create_date.truncate("day").alias("dt"),
|
||||
peewee.fn.COUNT(
|
||||
cls.model.id).alias("pv"),
|
||||
peewee.fn.COUNT(
|
||||
cls.model.user_id.distinct()).alias("uv"),
|
||||
peewee.fn.SUM(
|
||||
cls.model.tokens).alias("tokens"),
|
||||
peewee.fn.SUM(
|
||||
cls.model.duration).alias("duration"),
|
||||
peewee.fn.AVG(
|
||||
cls.model.round).alias("round"),
|
||||
peewee.fn.SUM(
|
||||
cls.model.thumb_up).alias("thumb_up")
|
||||
).join(Dialog, on=((cls.model.dialog_id == Dialog.id) & (Dialog.tenant_id == tenant_id))).where(
|
||||
cls.model.create_date >= from_date,
|
||||
cls.model.create_date <= to_date,
|
||||
cls.model.source == source
|
||||
).group_by(cls.model.create_date.truncate("day")).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_dialog_ids(cls, dialog_ids):
|
||||
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
|
||||
350
api/db/services/canvas_service.py
Normal file
350
api/db/services/canvas_service.py
Normal file
@@ -0,0 +1,350 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from uuid import uuid4
|
||||
from agent.canvas import Canvas
|
||||
from api.db import CanvasCategory, TenantPermission
|
||||
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_openai
|
||||
import tiktoken
|
||||
from peewee import fn
|
||||
|
||||
|
||||
class CanvasTemplateService(CommonService):
|
||||
model = CanvasTemplate
|
||||
|
||||
class DataFlowTemplateService(CommonService):
|
||||
"""
|
||||
Alias of CanvasTemplateService
|
||||
"""
|
||||
model = CanvasTemplate
|
||||
|
||||
|
||||
class UserCanvasService(CommonService):
|
||||
model = UserCanvas
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, tenant_id,
|
||||
page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
|
||||
agents = cls.model.select()
|
||||
if id:
|
||||
agents = agents.where(cls.model.id == id)
|
||||
if title:
|
||||
agents = agents.where(cls.model.title == title)
|
||||
agents = agents.where(cls.model.user_id == tenant_id)
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
if desc:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
agents = agents.paginate(page_number, items_per_page)
|
||||
|
||||
return list(agents.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
# find team agents and owned agents
|
||||
agents = cls.model.select(*fields).where(
|
||||
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time, asc
|
||||
agents.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
ag_batch = agents.offset(offset).limit(limit)
|
||||
_temp = list(ag_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_canvas_id(cls, pid):
|
||||
try:
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.title,
|
||||
cls.model.dsl,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.update_time,
|
||||
cls.model.user_id,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date,
|
||||
cls.model.canvas_category,
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
]
|
||||
agents = cls.model.select(*fields) \
|
||||
.join(User, on=(cls.model.user_id == User.id)) \
|
||||
.where(cls.model.id == pid)
|
||||
# obj = cls.model.query(id=pid)[0]
|
||||
return True, agents.dicts()[0]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page,
|
||||
orderby, desc, keywords, canvas_category=None
|
||||
):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.title,
|
||||
cls.model.dsl,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.user_id.alias("tenant_id"),
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
cls.model.update_time,
|
||||
cls.model.canvas_category,
|
||||
]
|
||||
if keywords:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
cls.model.user_id.in_(joined_tenant_ids),
|
||||
fn.LOWER(cls.model.title).contains(keywords.lower())
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
|
||||
#(fn.LOWER(cls.model.title).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||
cls.model.user_id.in_(joined_tenant_ids)
|
||||
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
|
||||
)
|
||||
if canvas_category:
|
||||
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||
if desc:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
agents = agents.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
count = agents.count()
|
||||
if page_number and items_per_page:
|
||||
agents = agents.paginate(page_number, items_per_page)
|
||||
return list(agents.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, canvas_id, tenant_id):
|
||||
from api.db.services.user_service import UserTenantService
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return False
|
||||
|
||||
tids = [t.tenant_id for t in UserTenantService.query(user_id=tenant_id)]
|
||||
if c["user_id"] != canvas_id and c["user_id"] not in tids:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
query = kwargs.get("query", "") or kwargs.get("question", "")
|
||||
files = kwargs.get("files", [])
|
||||
inputs = kwargs.get("inputs", {})
|
||||
user_id = kwargs.get("user_id", "")
|
||||
|
||||
if session_id:
|
||||
e, conv = API4ConversationService.get_by_id(session_id)
|
||||
assert e, "Session not found!"
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
if not isinstance(conv.dsl, str):
|
||||
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
|
||||
canvas = Canvas(conv.dsl, tenant_id, agent_id)
|
||||
else:
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
assert e, "Agent not found."
|
||||
assert cvs.user_id == tenant_id, "You do not own the agent."
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
session_id=get_uuid()
|
||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||
canvas.reset()
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": user_id,
|
||||
"message": [],
|
||||
"source": "agent",
|
||||
"dsl": cvs.dsl,
|
||||
"reference": []
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
conv = API4Conversation(**conv)
|
||||
|
||||
message_id = str(uuid4())
|
||||
conv.message.append({
|
||||
"role": "user",
|
||||
"content": query,
|
||||
"id": message_id
|
||||
})
|
||||
txt = ""
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
ans["session_id"] = session_id
|
||||
if ans["event"] == "message":
|
||||
txt += ans["data"]["content"]
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
|
||||
conv.reference = canvas.get_reference()
|
||||
conv.errors = canvas.error
|
||||
conv.dsl = str(canvas)
|
||||
conv = conv.to_dict()
|
||||
API4ConversationService.append_message(conv["id"], conv)
|
||||
|
||||
|
||||
def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
prompt_tokens = len(tiktokenenc.encode(str(question)))
|
||||
user_id = kwargs.get("user_id", "")
|
||||
|
||||
if stream:
|
||||
completion_tokens = 0
|
||||
try:
|
||||
for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
query=question,
|
||||
user_id=user_id,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(ans, str):
|
||||
try:
|
||||
ans = json.loads(ans[5:]) # remove "data:"
|
||||
except Exception as e:
|
||||
logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
|
||||
continue
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
content_piece = ""
|
||||
if ans["event"] == "message":
|
||||
content_piece = ans["data"]["content"]
|
||||
|
||||
completion_tokens += len(tiktokenenc.encode(content_piece))
|
||||
|
||||
openai_data = get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
content=content_piece,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
openai_data["choices"][0]["delta"]["reference"] = ans["data"]["reference"]
|
||||
|
||||
yield "data: " + json.dumps(openai_data, ensure_ascii=False) + "\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data: " + json.dumps(
|
||||
get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
content=f"**ERROR**: {str(e)}",
|
||||
finish_reason="stop",
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
|
||||
stream=True
|
||||
),
|
||||
ensure_ascii=False
|
||||
) + "\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
else:
|
||||
try:
|
||||
all_content = ""
|
||||
reference = {}
|
||||
for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
query=question,
|
||||
user_id=user_id,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(ans, str):
|
||||
ans = json.loads(ans[5:])
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
if ans["event"] == "message":
|
||||
all_content += ans["data"]["content"]
|
||||
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
|
||||
completion_tokens = len(tiktokenenc.encode(all_content))
|
||||
|
||||
openai_data = get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
content=all_content,
|
||||
finish_reason="stop",
|
||||
param=None
|
||||
)
|
||||
|
||||
if reference:
|
||||
openai_data["choices"][0]["message"]["reference"] = reference
|
||||
|
||||
yield openai_data
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
|
||||
content=f"**ERROR**: {str(e)}",
|
||||
finish_reason="stop",
|
||||
param=None
|
||||
)
|
||||
345
api/db/services/common_service.py
Normal file
345
api/db/services/common_service.py
Normal file
@@ -0,0 +1,345 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
import peewee
|
||||
from peewee import InterfaceError, OperationalError
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
def retry_db_operation(func):
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||
reraise=True,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class CommonService:
|
||||
"""Base service class that provides common database operations.
|
||||
|
||||
This class serves as a foundation for all service classes in the application,
|
||||
implementing standard CRUD operations and common database query patterns.
|
||||
It uses the Peewee ORM for database interactions and provides a consistent
|
||||
interface for database operations across all derived service classes.
|
||||
|
||||
Attributes:
|
||||
model: The Peewee model class that this service operates on. Must be set by subclasses.
|
||||
"""
|
||||
|
||||
model = None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
||||
"""Execute a database query with optional column selection and ordering.
|
||||
|
||||
This method provides a flexible way to query the database with various filters
|
||||
and sorting options. It supports column selection, sort order control, and
|
||||
additional filter conditions.
|
||||
|
||||
Args:
|
||||
cols (list, optional): List of column names to select. If None, selects all columns.
|
||||
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
|
||||
order_by (str, optional): Column name to sort results by.
|
||||
**kwargs: Additional filter conditions passed as keyword arguments.
|
||||
|
||||
Returns:
|
||||
peewee.ModelSelect: A query result containing matching records.
|
||||
"""
|
||||
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all(cls, cols=None, reverse=None, order_by=None):
|
||||
"""Retrieve all records from the database with optional column selection and ordering.
|
||||
|
||||
This method fetches all records from the model's table with support for
|
||||
column selection and result ordering. If no order_by is specified and reverse
|
||||
is True, it defaults to ordering by create_time.
|
||||
|
||||
Args:
|
||||
cols (list, optional): List of column names to select. If None, selects all columns.
|
||||
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
|
||||
order_by (str, optional): Column name to sort results by. Defaults to 'create_time' if reverse is specified.
|
||||
|
||||
Returns:
|
||||
peewee.ModelSelect: A query containing all matching records.
|
||||
"""
|
||||
if cols:
|
||||
query_records = cls.model.select(*cols)
|
||||
else:
|
||||
query_records = cls.model.select()
|
||||
if reverse is not None:
|
||||
if not order_by or not hasattr(cls, order_by):
|
||||
order_by = "create_time"
|
||||
if reverse is True:
|
||||
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
|
||||
elif reverse is False:
|
||||
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
|
||||
return query_records
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get(cls, **kwargs):
|
||||
"""Get a single record matching the given criteria.
|
||||
|
||||
This method retrieves a single record from the database that matches
|
||||
the specified filter conditions.
|
||||
|
||||
Args:
|
||||
**kwargs: Filter conditions as keyword arguments.
|
||||
|
||||
Returns:
|
||||
Model instance: Single matching record.
|
||||
|
||||
Raises:
|
||||
peewee.DoesNotExist: If no matching record is found.
|
||||
"""
|
||||
return cls.model.get(**kwargs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_or_none(cls, **kwargs):
|
||||
"""Get a single record or None if not found.
|
||||
|
||||
This method attempts to retrieve a single record matching the given criteria,
|
||||
returning None if no match is found instead of raising an exception.
|
||||
|
||||
Args:
|
||||
**kwargs: Filter conditions as keyword arguments.
|
||||
|
||||
Returns:
|
||||
Model instance or None: Matching record if found, None otherwise.
|
||||
"""
|
||||
try:
|
||||
return cls.model.get(**kwargs)
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
"""Save a new record to database.
|
||||
|
||||
This method creates a new record in the database with the provided field values,
|
||||
forcing an insert operation rather than an update.
|
||||
|
||||
Args:
|
||||
**kwargs: Record field values as keyword arguments.
|
||||
|
||||
Returns:
|
||||
Model instance: The created record object.
|
||||
"""
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, **kwargs):
|
||||
"""Insert a new record with automatic ID and timestamps.
|
||||
|
||||
This method creates a new record with automatically generated ID and timestamp fields.
|
||||
It handles the creation of create_time, create_date, update_time, and update_date fields.
|
||||
|
||||
Args:
|
||||
**kwargs: Record field values as keyword arguments.
|
||||
|
||||
Returns:
|
||||
Model instance: The newly created record object.
|
||||
"""
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = get_uuid()
|
||||
kwargs["create_time"] = current_timestamp()
|
||||
kwargs["create_date"] = datetime_format(datetime.now())
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
kwargs["update_date"] = datetime_format(datetime.now())
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert_many(cls, data_list, batch_size=100):
|
||||
"""Insert multiple records in batches.
|
||||
|
||||
This method efficiently inserts multiple records into the database using batch processing.
|
||||
It automatically sets creation timestamps for all records.
|
||||
|
||||
Args:
|
||||
data_list (list): List of dictionaries containing record data to insert.
|
||||
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
|
||||
"""
|
||||
with DB.atomic():
|
||||
for d in data_list:
|
||||
d["create_time"] = current_timestamp()
|
||||
d["create_date"] = datetime_format(datetime.now())
|
||||
for i in range(0, len(data_list), batch_size):
|
||||
cls.model.insert_many(data_list[i : i + batch_size]).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_many_by_id(cls, data_list):
|
||||
"""Update multiple records by their IDs.
|
||||
|
||||
This method updates multiple records in the database, identified by their IDs.
|
||||
It automatically updates the update_time and update_date fields for each record.
|
||||
|
||||
Args:
|
||||
data_list (list): List of dictionaries containing record data to update.
|
||||
Each dictionary must include an 'id' field.
|
||||
"""
|
||||
with DB.atomic():
|
||||
for data in data_list:
|
||||
data["update_time"] = current_timestamp()
|
||||
data["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(data).where(cls.model.id == data["id"]).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@retry_db_operation
|
||||
def update_by_id(cls, pid, data):
|
||||
# Update a single record by ID
|
||||
# Args:
|
||||
# pid: Record ID
|
||||
# data: Updated field values
|
||||
# Returns:
|
||||
# Number of records updated
|
||||
data["update_time"] = current_timestamp()
|
||||
data["update_date"] = datetime_format(datetime.now())
|
||||
num = cls.model.update(data).where(cls.model.id == pid).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_id(cls, pid):
|
||||
# Get a record by ID
|
||||
# Args:
|
||||
# pid: Record ID
|
||||
# Returns:
|
||||
# Tuple of (success, record)
|
||||
try:
|
||||
obj = cls.model.get_or_none(cls.model.id == pid)
|
||||
if obj:
|
||||
return True, obj
|
||||
except Exception:
|
||||
pass
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_ids(cls, pids, cols=None):
|
||||
# Get multiple records by their IDs
|
||||
# Args:
|
||||
# pids: List of record IDs
|
||||
# cols: List of columns to select
|
||||
# Returns:
|
||||
# Query of matching records
|
||||
if cols:
|
||||
objs = cls.model.select(*cols)
|
||||
else:
|
||||
objs = cls.model.select()
|
||||
return objs.where(cls.model.id.in_(pids))
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_id(cls, pid):
|
||||
# Delete a record by ID
|
||||
# Args:
|
||||
# pid: Record ID
|
||||
# Returns:
|
||||
# Number of records deleted
|
||||
return cls.model.delete().where(cls.model.id == pid).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_ids(cls, pids):
|
||||
# Delete multiple records by their IDs
|
||||
# Args:
|
||||
# pids: List of record IDs
|
||||
# Returns:
|
||||
# Number of records deleted
|
||||
with DB.atomic():
|
||||
res = cls.model.delete().where(cls.model.id.in_(pids)).execute()
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_delete(cls, filters):
|
||||
# Delete records matching given filters
|
||||
# Args:
|
||||
# filters: List of filter conditions
|
||||
# Returns:
|
||||
# Number of records deleted
|
||||
with DB.atomic():
|
||||
num = cls.model.delete().where(*filters).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_update(cls, filters, update_data):
|
||||
# Update records matching given filters
|
||||
# Args:
|
||||
# filters: List of filter conditions
|
||||
# update_data: Updated field values
|
||||
# Returns:
|
||||
# Number of records updated
|
||||
with DB.atomic():
|
||||
return cls.model.update(update_data).where(*filters).execute()
|
||||
|
||||
@staticmethod
|
||||
def cut_list(tar_list, n):
|
||||
# Split a list into chunks of size n
|
||||
# Args:
|
||||
# tar_list: List to split
|
||||
# n: Chunk size
|
||||
# Returns:
|
||||
# List of tuples containing chunks
|
||||
length = len(tar_list)
|
||||
arr = range(length)
|
||||
result = [tuple(tar_list[x : (x + n)]) for x in arr[::n]]
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
|
||||
# Get records matching IN clause filters with optional column selection
|
||||
# Args:
|
||||
# in_key: Field name for IN clause
|
||||
# in_filters_list: List of values for IN clause
|
||||
# filters: Additional filter conditions
|
||||
# cols: List of columns to select
|
||||
# Returns:
|
||||
# List of matching records
|
||||
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
|
||||
if not filters:
|
||||
filters = []
|
||||
res_list = []
|
||||
if cols:
|
||||
for i in in_filters_tuple_list:
|
||||
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
|
||||
if query_records:
|
||||
res_list.extend([query_record for query_record in query_records])
|
||||
else:
|
||||
for i in in_filters_tuple_list:
|
||||
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
|
||||
if query_records:
|
||||
res_list.extend([query_record for query_record in query_records])
|
||||
return res_list
|
||||
242
api/db/services/conversation_service.py
Normal file
242
api/db/services/conversation_service.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
from uuid import uuid4
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import Conversation, DB
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from api.utils import get_uuid
|
||||
import json
|
||||
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
|
||||
class ConversationService(CommonService):
|
||||
model = Conversation
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, dialog_id, page_number, items_per_page, orderby, desc, id, name, user_id=None):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
|
||||
if id:
|
||||
sessions = sessions.where(cls.model.id == id)
|
||||
if name:
|
||||
sessions = sessions.where(cls.model.name == name)
|
||||
if user_id:
|
||||
sessions = sessions.where(cls.model.user_id == user_id)
|
||||
if desc:
|
||||
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
sessions = sessions.paginate(page_number, items_per_page)
|
||||
|
||||
return list(sessions.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
|
||||
sessions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
s_batch = sessions.offset(offset).limit(limit)
|
||||
_temp = list(s_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def structure_answer(conv, ans, message_id, session_id):
|
||||
reference = ans["reference"]
|
||||
if not isinstance(reference, dict):
|
||||
reference = {}
|
||||
ans["reference"] = {}
|
||||
|
||||
chunk_list = chunks_format(reference)
|
||||
|
||||
reference["chunks"] = chunk_list
|
||||
ans["id"] = message_id
|
||||
ans["session_id"] = session_id
|
||||
|
||||
if not conv:
|
||||
return ans
|
||||
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
if not conv.message or conv.message[-1].get("role", "") != "assistant":
|
||||
conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
|
||||
else:
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
|
||||
if conv.reference:
|
||||
conv.reference[-1] = reference
|
||||
return ans
|
||||
|
||||
|
||||
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||
assert name, "`name` can not be empty."
|
||||
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
assert dia, "You do not own the chat."
|
||||
|
||||
if not session_id:
|
||||
session_id = get_uuid()
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": chat_id,
|
||||
"name": name,
|
||||
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue"), "created_at": time.time()}],
|
||||
"user_id": kwargs.get("user_id", "")
|
||||
}
|
||||
ConversationService.save(**conv)
|
||||
if stream:
|
||||
yield "data:" + json.dumps({"code": 0, "message": "",
|
||||
"data": {
|
||||
"answer": conv["message"][0]["content"],
|
||||
"reference": {},
|
||||
"audio_binary": None,
|
||||
"id": None,
|
||||
"session_id": session_id
|
||||
}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
return
|
||||
|
||||
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
|
||||
if not conv:
|
||||
raise LookupError("Session does not exist")
|
||||
|
||||
conv = conv[0]
|
||||
msg = []
|
||||
question = {
|
||||
"content": question,
|
||||
"role": "user",
|
||||
"id": str(uuid4())
|
||||
}
|
||||
conv.message.append(question)
|
||||
for m in conv.message:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
message_id = msg[-1].get("id")
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
|
||||
kb_ids = kwargs.get("kb_ids",[])
|
||||
dia.kb_ids = list(set(dia.kb_ids + kb_ids))
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
if stream:
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **kwargs):
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, False, **kwargs):
|
||||
answer = structure_answer(conv, ans, message_id, session_id)
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
break
|
||||
yield answer
|
||||
|
||||
|
||||
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
assert e, "Dialog not found"
|
||||
if not session_id:
|
||||
session_id = get_uuid()
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": dialog_id,
|
||||
"user_id": kwargs.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"], "created_at": time.time()}]
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "",
|
||||
"data": {
|
||||
"answer": conv["message"][0]["content"],
|
||||
"reference": {},
|
||||
"audio_binary": None,
|
||||
"id": None,
|
||||
"session_id": session_id
|
||||
}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
return
|
||||
else:
|
||||
session_id = session_id
|
||||
e, conv = API4ConversationService.get_by_id(session_id)
|
||||
assert e, "Session not found!"
|
||||
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
messages = conv.message
|
||||
question = {
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"id": str(uuid4())
|
||||
}
|
||||
messages.append(question)
|
||||
|
||||
msg = []
|
||||
for m in messages:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
if stream:
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **kwargs):
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, False, **kwargs):
|
||||
answer = structure_answer(conv, ans, message_id, session_id)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
yield answer
|
||||
868
api/db/services/dialog_service.py
Normal file
868
api/db/services/dialog_service.py
Normal file
@@ -0,0 +1,868 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import binascii
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
from langfuse import Langfuse
|
||||
from peewee import fn
|
||||
from agentic_reasoning import DeepResearcher
|
||||
from api import settings
|
||||
from api.db import LLMType, ParserType, StatusEnum
|
||||
from api.db.db_models import DB, Dialog
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
from rag.utils import num_tokens_from_string, rmSpace
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
|
||||
class DialogService(CommonService):
|
||||
model = Dialog
|
||||
|
||||
@classmethod
|
||||
def save(cls, **kwargs):
|
||||
"""Save a new record to database.
|
||||
|
||||
This method creates a new record in the database with the provided field values,
|
||||
forcing an insert operation rather than an update.
|
||||
|
||||
Args:
|
||||
**kwargs: Record field values as keyword arguments.
|
||||
|
||||
Returns:
|
||||
Model instance: The created record object.
|
||||
"""
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
|
||||
@classmethod
|
||||
def update_many_by_id(cls, data_list):
|
||||
"""Update multiple records by their IDs.
|
||||
|
||||
This method updates multiple records in the database, identified by their IDs.
|
||||
It automatically updates the update_time and update_date fields for each record.
|
||||
|
||||
Args:
|
||||
data_list (list): List of dictionaries containing record data to update.
|
||||
Each dictionary must include an 'id' field.
|
||||
"""
|
||||
with DB.atomic():
|
||||
for data in data_list:
|
||||
data["update_time"] = current_timestamp()
|
||||
data["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(data).where(cls.model.id == data["id"]).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
|
||||
chats = cls.model.select()
|
||||
if id:
|
||||
chats = chats.where(cls.model.id == id)
|
||||
if name:
|
||||
chats = chats.where(cls.model.name == name)
|
||||
chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
|
||||
if desc:
|
||||
chats = chats.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
chats = chats.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
chats = chats.paginate(page_number, items_per_page)
|
||||
|
||||
return list(chats.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None):
|
||||
from api.db.db_models import User
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.name,
|
||||
cls.model.description,
|
||||
cls.model.language,
|
||||
cls.model.llm_id,
|
||||
cls.model.llm_setting,
|
||||
cls.model.prompt_type,
|
||||
cls.model.prompt_config,
|
||||
cls.model.similarity_threshold,
|
||||
cls.model.vector_similarity_weight,
|
||||
cls.model.top_n,
|
||||
cls.model.top_k,
|
||||
cls.model.do_refer,
|
||||
cls.model.rerank_id,
|
||||
cls.model.kb_ids,
|
||||
cls.model.icon,
|
||||
cls.model.status,
|
||||
User.nickname,
|
||||
User.avatar.alias("tenant_avatar"),
|
||||
cls.model.update_time,
|
||||
cls.model.create_time,
|
||||
]
|
||||
if keywords:
|
||||
dialogs = (
|
||||
cls.model.select(*fields)
|
||||
.join(User, on=(cls.model.tenant_id == User.id))
|
||||
.where(
|
||||
(cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower())),
|
||||
)
|
||||
)
|
||||
else:
|
||||
dialogs = (
|
||||
cls.model.select(*fields)
|
||||
.join(User, on=(cls.model.tenant_id == User.id))
|
||||
.where(
|
||||
(cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
|
||||
)
|
||||
)
|
||||
if parser_id:
|
||||
dialogs = dialogs.where(cls.model.parser_id == parser_id)
|
||||
if desc:
|
||||
dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
dialogs = dialogs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
count = dialogs.count()
|
||||
|
||||
if page_number and items_per_page:
|
||||
dialogs = dialogs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(dialogs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_dialogs_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
dialogs.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
d_batch = dialogs.offset(offset).limit(limit)
|
||||
_temp = list(d_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
tts_mdl = None
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if stream:
|
||||
last_ans = ""
|
||||
delta_ans = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
delta_ans = ""
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
||||
|
||||
|
||||
def get_models(dialog):
|
||||
embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
|
||||
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embedding_list) > 1:
|
||||
raise Exception("**ERROR**: Knowledge bases use different embedding models.")
|
||||
|
||||
if embedding_list:
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
if not embd_mdl:
|
||||
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
|
||||
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
if dialog.rerank_id:
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
|
||||
if dialog.prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
|
||||
|
||||
|
||||
BAD_CITATION_PATTERNS = [
|
||||
re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12)
|
||||
re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12]
|
||||
re.compile(r"【\s*ID\s*[: ]*\s*(\d+)\s*】"), # 【ID: 12】
|
||||
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
|
||||
]
|
||||
|
||||
|
||||
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
||||
max_index = len(kbinfos["chunks"])
|
||||
|
||||
def safe_add(i):
|
||||
if 0 <= i < max_index:
|
||||
idx.add(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
|
||||
nonlocal answer
|
||||
|
||||
def replacement(match):
|
||||
try:
|
||||
i = int(match.group(group_index))
|
||||
if safe_add(i):
|
||||
return f"[{repl(i)}]"
|
||||
except Exception:
|
||||
pass
|
||||
return match.group(0)
|
||||
|
||||
answer = re.sub(pattern, replacement, answer, flags=flags)
|
||||
|
||||
for pattern in BAD_CITATION_PATTERNS:
|
||||
find_and_replace(pattern)
|
||||
|
||||
return answer, idx
|
||||
|
||||
|
||||
def convert_conditions(metadata_condition):
|
||||
if metadata_condition is None:
|
||||
metadata_condition = {}
|
||||
op_mapping = {
|
||||
"is": "=",
|
||||
"not is": "≠"
|
||||
}
|
||||
return [
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
|
||||
|
||||
def meta_filter(metas: dict, filters: list[dict]):
|
||||
doc_ids = set([])
|
||||
|
||||
def filter_out(v2docs, operator, value):
|
||||
ids = []
|
||||
for input, docids in v2docs.items():
|
||||
try:
|
||||
input = float(input)
|
||||
value = float(value)
|
||||
except Exception:
|
||||
input = str(input)
|
||||
value = str(value)
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
try:
|
||||
if all(conds):
|
||||
ids.extend(docids)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
return ids
|
||||
|
||||
for k, v2docs in metas.items():
|
||||
for f in filters:
|
||||
if k != f["key"]:
|
||||
continue
|
||||
ids = filter_out(v2docs, f["op"], f["value"])
|
||||
if not doc_ids:
|
||||
doc_ids = set(ids)
|
||||
else:
|
||||
doc_ids = doc_ids & set(ids)
|
||||
if not doc_ids:
|
||||
return []
|
||||
return list(doc_ids)
|
||||
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||
for ans in chat_solo(dialog, messages, stream):
|
||||
yield ans
|
||||
return
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
max_tokens = llm_model_config.get("max_tokens", 8192)
|
||||
|
||||
check_llm_ts = timer()
|
||||
|
||||
langfuse_tracer = None
|
||||
trace_context = {}
|
||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
|
||||
if langfuse_keys:
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||
if langfuse.auth_check():
|
||||
langfuse_tracer = langfuse
|
||||
trace_id = langfuse_tracer.create_trace_id()
|
||||
trace_context = {"trace_id": trace_id}
|
||||
|
||||
check_langfuse_tracer_ts = timer()
|
||||
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
|
||||
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
|
||||
if toolcall_session and tools:
|
||||
chat_mdl.bind_tools(toolcall_session, tools)
|
||||
bind_models_ts = timer()
|
||||
|
||||
retriever = settings.retrievaler
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
if "doc_ids" in messages[-1]:
|
||||
attachments = messages[-1]["doc_ids"]
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":
|
||||
continue
|
||||
if p["key"] not in kwargs and not p["optional"]:
|
||||
raise KeyError("Miss parameter: " + p["key"])
|
||||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||
|
||||
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||||
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
else:
|
||||
questions = questions[-1:]
|
||||
|
||||
if prompt_config.get("cross_languages"):
|
||||
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||||
|
||||
if dialog.meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||
if dialog.meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
elif dialog.meta_data_filter.get("method") == "manual":
|
||||
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
|
||||
refine_question_ts = timer()
|
||||
|
||||
thought = ""
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
knowledges = []
|
||||
|
||||
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
knowledges = []
|
||||
if prompt_config.get("reasoning", False):
|
||||
reasoner = DeepResearcher(
|
||||
chat_mdl,
|
||||
prompt_config,
|
||||
partial(
|
||||
retriever.retrieval,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
kb_ids=dialog.kb_ids,
|
||||
page=1,
|
||||
page_size=dialog.top_n,
|
||||
similarity_threshold=0.2,
|
||||
vector_similarity_weight=0.3,
|
||||
doc_ids=attachments,
|
||||
),
|
||||
)
|
||||
|
||||
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
elif stream:
|
||||
yield think
|
||||
else:
|
||||
if embd_mdl:
|
||||
kbinfos = retriever.retrieval(
|
||||
" ".join(questions),
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
dialog.kb_ids,
|
||||
1,
|
||||
dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=attachments,
|
||||
top=dialog.top_k,
|
||||
aggs=False,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs),
|
||||
)
|
||||
if prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
|
||||
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
retrieval_ts = timer()
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||
"audio_binary": tts(tts_mdl, empty_res)}
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
gen_conf = dialog.llm_setting
|
||||
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
||||
prompt4citation = ""
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
prompt4citation = citation_prompt()
|
||||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
|
||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
||||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||||
prompt = msg[0]["content"]
|
||||
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
|
||||
|
||||
refs = []
|
||||
ans = answer.split("</think>")
|
||||
think = ""
|
||||
if len(ans) == 2:
|
||||
think = ans[0] + "</think>"
|
||||
answer = ans[1]
|
||||
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
idx = set([])
|
||||
if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
|
||||
answer, idx = retriever.insert_citations(
|
||||
answer,
|
||||
[ck["content_ltks"] for ck in kbinfos["chunks"]],
|
||||
[ck["vector"] for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight,
|
||||
)
|
||||
else:
|
||||
for match in re.finditer(r"\[ID:([0-9]+)\]", answer):
|
||||
i = int(match.group(1))
|
||||
if i < len(kbinfos["chunks"]):
|
||||
idx.add(i)
|
||||
|
||||
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)
|
||||
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
|
||||
refs = deepcopy(kbinfos)
|
||||
for c in refs["chunks"]:
|
||||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
||||
finish_chat_ts = timer()
|
||||
|
||||
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
|
||||
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
|
||||
check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
|
||||
bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
|
||||
refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
|
||||
retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
|
||||
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
|
||||
|
||||
tk_num = num_tokens_from_string(think + answer)
|
||||
prompt += "\n\n### Query:\n%s" % " ".join(questions)
|
||||
prompt = (
|
||||
f"{prompt}\n\n"
|
||||
"## Time elapsed:\n"
|
||||
f" - Total: {total_time_cost:.1f}ms\n"
|
||||
f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
|
||||
f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
|
||||
f" - Bind models: {bind_embedding_time_cost:.1f}ms\n"
|
||||
f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
|
||||
f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
|
||||
f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
|
||||
"## Token usage:\n"
|
||||
f" - Generated tokens(approximately): {tk_num}\n"
|
||||
f" - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
|
||||
)
|
||||
|
||||
# Add a condition check to call the end method only if langfuse_tracer exists
|
||||
if langfuse_tracer and "langfuse_generation" in locals():
|
||||
langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
|
||||
langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}
|
||||
langfuse_generation.update(output=langfuse_output)
|
||||
langfuse_generation.end()
|
||||
|
||||
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||||
|
||||
if langfuse_tracer:
|
||||
langfuse_generation = langfuse_tracer.start_generation(
|
||||
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
|
||||
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
||||
)
|
||||
|
||||
if stream:
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||
if thought:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
res = decorate_answer(answer)
|
||||
res["audio_binary"] = tts(tts_mdl, answer)
|
||||
yield res
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
{}
|
||||
|
||||
Question are as follows:
|
||||
{}
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
|
||||
tried_times = 0
|
||||
|
||||
def get_table():
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
|
||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
sql = re.sub(r"([;;]|```).*", "", sql)
|
||||
if sql[: len("select ")] != "select ":
|
||||
return None, None
|
||||
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
|
||||
if sql[: len("select *")] != "select *":
|
||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||
else:
|
||||
flds = []
|
||||
for k in field_map.keys():
|
||||
if k in forbidden_select_fields4resume:
|
||||
continue
|
||||
if len(flds) > 11:
|
||||
break
|
||||
flds.append(k)
|
||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||
|
||||
if kb_ids:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
if "where" not in sql.lower():
|
||||
sql += f" WHERE {kb_filter}"
|
||||
else:
|
||||
sql += f" AND {kb_filter}"
|
||||
|
||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
tbl, sql = get_table()
|
||||
if tbl is None:
|
||||
return None
|
||||
if tbl.get("error") and tried_times <= 2:
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
{}
|
||||
|
||||
Question are as follows:
|
||||
{}
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Error issued by database as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
|
||||
tbl, sql = get_table()
|
||||
logging.debug("TRY it again: {}".format(sql))
|
||||
|
||||
logging.debug("GET table: {}".format(tbl))
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||
return None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
||||
|
||||
# compose Markdown table
|
||||
columns = (
|
||||
"|" + "|".join(
|
||||
[re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
|
||||
"|Source|" if docid_idx and docid_idx else "|")
|
||||
)
|
||||
|
||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
|
||||
rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
|
||||
if quota:
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
else:
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
||||
|
||||
if not docid_idx or not doc_name_idx:
|
||||
logging.warning("SQL missing field: " + sql)
|
||||
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
|
||||
|
||||
docid_idx = list(docid_idx)[0]
|
||||
doc_name_idx = list(doc_name_idx)[0]
|
||||
doc_aggs = {}
|
||||
for r in tbl["rows"]:
|
||||
if r[docid_idx] not in doc_aggs:
|
||||
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
|
||||
doc_aggs[r[docid_idx]]["count"] += 1
|
||||
return {
|
||||
"answer": "\n".join([columns, line, rows]),
|
||||
"reference": {
|
||||
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
|
||||
},
|
||||
"prompt": sys_prompt,
|
||||
}
|
||||
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text:
|
||||
return
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
|
||||
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_mdl = None
|
||||
kb_ids = search_config.get("kb_ids", kb_ids)
|
||||
chat_llm_name = search_config.get("chat_id", chat_llm_name)
|
||||
rerank_id = search_config.get("rerank_id", "")
|
||||
meta_data_filter = search_config.get("meta_data_filter")
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||
if rerank_id:
|
||||
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
|
||||
max_tokens = chat_mdl.max_length
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
kb_ids=kb_ids,
|
||||
page=1,
|
||||
page_size=12,
|
||||
similarity_threshold=search_config.get("similarity_threshold", 0.1),
|
||||
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
|
||||
top=search_config.get("top_k", 1024),
|
||||
doc_ids=doc_ids,
|
||||
aggs=False,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(question, kbs)
|
||||
)
|
||||
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges))
|
||||
|
||||
msg = [{"role": "user", "content": question}]
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal knowledges, kbinfos, sys_prompt
|
||||
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
|
||||
embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
recall_docs = kbinfos["doc_aggs"]
|
||||
kbinfos["doc_aggs"] = recall_docs
|
||||
refs = deepcopy(kbinfos)
|
||||
for c in refs["chunks"]:
|
||||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
||||
refs["chunks"] = chunks_format(refs)
|
||||
return {"answer": answer, "reference": refs}
|
||||
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
|
||||
|
||||
def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_id = search_config.get("rerank_id", "")
|
||||
rerank_mdl = None
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
if not kbs:
|
||||
return {"error": "No KB selected"}
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
if rerank_id:
|
||||
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
ranks = settings.retrievaler.retrieval(
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
kb_ids=kb_ids,
|
||||
page=1,
|
||||
page_size=12,
|
||||
similarity_threshold=search_config.get("similarity_threshold", 0.2),
|
||||
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
|
||||
top=search_config.get("top_k", 1024),
|
||||
doc_ids=doc_ids,
|
||||
aggs=False,
|
||||
rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(question, kbs),
|
||||
)
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||
return mind_map.output
|
||||
975
api/db/services/document_service.py
Normal file
975
api/db/services/document_service.py
Normal file
@@ -0,0 +1,975 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
import trio
|
||||
import xxhash
|
||||
from peewee import fn, Case, JOIN
|
||||
|
||||
from api import settings
|
||||
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
|
||||
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole, CanvasCategory
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
|
||||
User
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils import current_timestamp, get_format_time, get_uuid
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
|
||||
@classmethod
|
||||
def get_cls_model_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.thumbnail,
|
||||
cls.model.kb_id,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.parser_config,
|
||||
cls.model.source_type,
|
||||
cls.model.type,
|
||||
cls.model.created_by,
|
||||
cls.model.name,
|
||||
cls.model.location,
|
||||
cls.model.size,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.meta_fields,
|
||||
cls.model.suffix,
|
||||
cls.model.run,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, id, name):
|
||||
fields = cls.get_cls_model_fields()
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
|
||||
.join(File, on = (File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
if id:
|
||||
docs = docs.where(
|
||||
cls.model.id == id)
|
||||
if name:
|
||||
docs = docs.where(
|
||||
cls.model.name == name
|
||||
)
|
||||
if keywords:
|
||||
docs = docs.where(
|
||||
fn.LOWER(cls.model.name).contains(keywords.lower())
|
||||
)
|
||||
if desc:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
count = docs.count()
|
||||
docs = docs.paginate(page_number, items_per_page)
|
||||
return list(docs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def check_doc_health(cls, tenant_id: str, filename):
|
||||
import os
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
|
||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
|
||||
raise RuntimeError("Exceed the maximum file number of a free user!")
|
||||
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
raise RuntimeError("Exceed the maximum length of file name!")
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, run_status, types, suffix):
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
|
||||
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.join(File, on=(File.id == File2Document.file_id))\
|
||||
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(cls.model.kb_id == kb_id)
|
||||
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
docs = docs.where(cls.model.type.in_(types))
|
||||
if suffix:
|
||||
docs = docs.where(cls.model.suffix.in_(suffix))
|
||||
|
||||
count = docs.count()
|
||||
if desc:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
docs = docs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
|
||||
if page_number and items_per_page:
|
||||
docs = docs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(docs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
|
||||
"""
|
||||
returns:
|
||||
{
|
||||
"suffix": {
|
||||
"ppt": 1,
|
||||
"doxc": 2
|
||||
},
|
||||
"run_status": {
|
||||
"1": 2,
|
||||
"2": 2
|
||||
}
|
||||
}, total
|
||||
where "1" => RUNNING, "2" => CANCEL
|
||||
"""
|
||||
fields = cls.get_cls_model_fields()
|
||||
if keywords:
|
||||
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
|
||||
|
||||
|
||||
if run_status:
|
||||
query = query.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
query = query.where(cls.model.type.in_(types))
|
||||
if suffix:
|
||||
query = query.where(cls.model.suffix.in_(suffix))
|
||||
|
||||
rows = query.select(cls.model.run, cls.model.suffix)
|
||||
total = rows.count()
|
||||
|
||||
suffix_counter = {}
|
||||
run_status_counter = {}
|
||||
|
||||
for row in rows:
|
||||
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
|
||||
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
|
||||
|
||||
return {
|
||||
"suffix": suffix_counter,
|
||||
"run_status": run_status_counter
|
||||
}, total
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
|
||||
if keywords:
|
||||
docs = cls.model.select().where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
docs = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
|
||||
if run_status:
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
docs = docs.where(cls.model.type.in_(types))
|
||||
|
||||
count = docs.count()
|
||||
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]):
|
||||
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(
|
||||
cls.model.kb_id == kb_id
|
||||
)
|
||||
|
||||
if keywords:
|
||||
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
if run_status:
|
||||
query = query.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
query = query.where(cls.model.type.in_(types))
|
||||
|
||||
return int(query.scalar()) or 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
|
||||
fields = [cls.model.id]
|
||||
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_docs_by_creator_id(cls, creator_id):
|
||||
fields = [
|
||||
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
|
||||
]
|
||||
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.created_by == creator_id
|
||||
)
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, doc):
|
||||
if not cls.save(**doc):
|
||||
raise RuntimeError("Database error (Document)!")
|
||||
if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]):
|
||||
raise RuntimeError("Database error (Knowledgebase)!")
|
||||
return Document(**doc)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def remove_document(cls, doc, tenant_id):
|
||||
from api.db.services.task_service import TaskService
|
||||
cls.clear_chunk_num(doc.id)
|
||||
try:
|
||||
TaskService.filter_delete([Task.doc_id == doc.id])
|
||||
page = 0
|
||||
page_size = 1000
|
||||
all_chunk_ids = []
|
||||
while True:
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
page += 1
|
||||
for cid in all_chunk_ids:
|
||||
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
|
||||
if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
|
||||
STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
graph_source = settings.docStoreConn.getFields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
|
||||
{"removed_kwd": "Y"},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
except Exception:
|
||||
pass
|
||||
return cls.delete_by_id(doc.id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_newly_uploaded(cls):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.kb_id,
|
||||
cls.model.parser_id,
|
||||
cls.model.parser_config,
|
||||
cls.model.name,
|
||||
cls.model.type,
|
||||
cls.model.location,
|
||||
cls.model.size,
|
||||
Knowledgebase.tenant_id,
|
||||
Tenant.embd_id,
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress == 0,
|
||||
cls.model.update_time >= current_timestamp() - 1000 * 600,
|
||||
cls.model.run == TaskStatus.RUNNING.value) \
|
||||
.order_by(cls.model.update_time.asc())
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_unfinished_docs(cls):
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
|
||||
cls.model.run, cls.model.parser_id]
|
||||
docs = cls.model.select(*fields) \
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
~(cls.model.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress < 1,
|
||||
cls.model.progress > 0)
|
||||
return list(docs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
|
||||
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
||||
chunk_num=cls.model.chunk_num + chunk_num,
|
||||
process_duration=cls.model.process_duration + duration).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:
|
||||
logging.warning("Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num +
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num +
|
||||
chunk_num).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
|
||||
num = cls.model.update(token_num=cls.model.token_num - token_num,
|
||||
chunk_num=cls.model.chunk_num - chunk_num,
|
||||
process_duration=cls.model.process_duration + duration).where(
|
||||
cls.model.id == doc_id).execute()
|
||||
if num == 0:
|
||||
raise LookupError(
|
||||
"Document not found which is supposed to be there")
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
chunk_num
|
||||
).where(
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def clear_chunk_num(cls, doc_id):
|
||||
doc = cls.model.get_by_id(doc_id)
|
||||
assert doc, "Can't fine document in database."
|
||||
|
||||
num = Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num -
|
||||
doc.token_num,
|
||||
chunk_num=Knowledgebase.chunk_num -
|
||||
doc.chunk_num,
|
||||
doc_num=Knowledgebase.doc_num - 1
|
||||
).where(
|
||||
Knowledgebase.id == doc.kb_id).execute()
|
||||
return num
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def clear_chunk_num_when_rerun(cls, doc_id):
|
||||
doc = cls.model.get_by_id(doc_id)
|
||||
assert doc, "Can't fine document in database."
|
||||
|
||||
num = (
|
||||
Knowledgebase.update(
|
||||
token_num=Knowledgebase.token_num - doc.token_num,
|
||||
chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
|
||||
)
|
||||
.where(Knowledgebase.id == doc.kb_id)
|
||||
.execute()
|
||||
)
|
||||
return num
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_knowledgebase_id(cls, doc_id):
|
||||
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return docs[0]["kb_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id_by_name(cls, name):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.tenant_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, doc_id, user_id):
|
||||
docs = cls.model.select(
|
||||
cls.model.id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible4deletion(cls, doc_id, user_id):
|
||||
docs = cls.model.select(cls.model.id
|
||||
).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)
|
||||
).join(
|
||||
UserTenant, on=(
|
||||
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
|
||||
).where(
|
||||
cls.model.id == doc_id,
|
||||
UserTenant.status == StatusEnum.VALID.value,
|
||||
((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
|
||||
).paginate(0, 1)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_embd_id(cls, doc_id):
|
||||
docs = cls.model.select(
|
||||
Knowledgebase.embd_id).join(
|
||||
Knowledgebase, on=(
|
||||
Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return docs[0]["embd_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_chunking_config(cls, doc_id):
|
||||
configs = (
|
||||
cls.model.select(
|
||||
cls.model.id,
|
||||
cls.model.kb_id,
|
||||
cls.model.parser_id,
|
||||
cls.model.parser_config,
|
||||
Knowledgebase.language,
|
||||
Knowledgebase.embd_id,
|
||||
Tenant.id.alias("tenant_id"),
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
)
|
||||
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == doc_id)
|
||||
)
|
||||
configs = configs.dicts()
|
||||
if not configs:
|
||||
return None
|
||||
return configs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_id_by_doc_name(cls, doc_name):
|
||||
fields = [cls.model.id]
|
||||
doc_id = cls.model.select(*fields) \
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return
|
||||
return doc_id[0]["id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_ids_by_doc_names(cls, doc_names):
|
||||
if not doc_names:
|
||||
return []
|
||||
|
||||
query = cls.model.select(cls.model.id).where(cls.model.name.in_(doc_names))
|
||||
return list(query.scalars().iterator())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_thumbnails(cls, docids):
|
||||
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
|
||||
return list(cls.model.select(
|
||||
*fields).where(cls.model.id.in_(docids)).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_parser_config(cls, id, config):
|
||||
if not config:
|
||||
return
|
||||
e, d = cls.get_by_id(id)
|
||||
if not e:
|
||||
raise LookupError(f"Document({id}) not found.")
|
||||
|
||||
def dfs_update(old, new):
|
||||
for k, v in new.items():
|
||||
if k not in old:
|
||||
old[k] = v
|
||||
continue
|
||||
if isinstance(v, dict):
|
||||
assert isinstance(old[k], dict)
|
||||
dfs_update(old[k], v)
|
||||
else:
|
||||
old[k] = v
|
||||
|
||||
dfs_update(d.parser_config, config)
|
||||
if not config.get("raptor") and d.parser_config.get("raptor"):
|
||||
del d.parser_config["raptor"]
|
||||
cls.update_by_id(id, {"parser_config": d.parser_config})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_count(cls, tenant_id):
|
||||
docs = cls.model.select(cls.model.id).join(Knowledgebase,
|
||||
on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
Knowledgebase.tenant_id == tenant_id)
|
||||
return len(docs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def begin2parse(cls, docid):
|
||||
cls.update_by_id(
|
||||
docid, {"progress": random.random() * 1 / 100.,
|
||||
"progress_msg": "Task is queued...",
|
||||
"process_begin_at": get_format_time()
|
||||
})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_meta_fields(cls, doc_id, meta_fields):
|
||||
return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_meta_by_kbs(cls, kb_ids):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.meta_fields,
|
||||
]
|
||||
meta = {}
|
||||
for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
|
||||
doc_id = r.id
|
||||
for k,v in r.meta_fields.items():
|
||||
if k not in meta:
|
||||
meta[k] = {}
|
||||
v = str(v)
|
||||
if v not in meta[k]:
|
||||
meta[k][v] = []
|
||||
meta[k][v].append(doc_id)
|
||||
return meta
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress(cls):
|
||||
docs = cls.get_unfinished_docs()
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress_immediately(cls, docs:list[dict]):
|
||||
if not docs:
|
||||
return
|
||||
|
||||
cls._sync_progress(docs)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def _sync_progress(cls, docs:list[dict]):
|
||||
for d in docs:
|
||||
try:
|
||||
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
|
||||
if not tsks:
|
||||
continue
|
||||
msg = []
|
||||
prg = 0
|
||||
finished = True
|
||||
bad = 0
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
priority = 0
|
||||
for t in tsks:
|
||||
if 0 <= t.progress < 1:
|
||||
finished = False
|
||||
if t.progress == -1:
|
||||
bad += 1
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
if t.progress_msg.strip():
|
||||
msg.append(t.progress_msg)
|
||||
priority = max(priority, t.priority)
|
||||
prg /= len(tsks)
|
||||
if finished and bad:
|
||||
prg = -1
|
||||
status = TaskStatus.FAIL.value
|
||||
elif finished:
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
"process_duration": datetime.timestamp(
|
||||
datetime.now()) -
|
||||
d["process_begin_at"].timestamp(),
|
||||
"run": status}
|
||||
if prg != 0:
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
info["progress_msg"] = msg
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
|
||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
else:
|
||||
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
cls.update_by_id(d["id"], info)
|
||||
except Exception as e:
|
||||
if str(e).find("'0'") < 0:
|
||||
logging.exception("fetch task exception")
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_doc_count(cls, kb_id):
|
||||
return cls.model.select().where(cls.model.kb_id == kb_id).count()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_doc_count(cls):
|
||||
result = {}
|
||||
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
|
||||
for row in rows:
|
||||
result[row.kb_id] = row.count
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def do_cancel(cls, doc_id):
|
||||
try:
|
||||
_, doc = DocumentService.get_by_id(doc_id)
|
||||
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
|
||||
# cancelled: run == "2" but progress can vary
|
||||
cancelled = (
|
||||
cls.model.select(fn.COUNT(1))
|
||||
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
|
||||
.scalar()
|
||||
)
|
||||
|
||||
row = (
|
||||
cls.model.select(
|
||||
# finished: progress == 1
|
||||
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
|
||||
|
||||
# failed: progress == -1
|
||||
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
|
||||
|
||||
# processing: 0 <= progress < 1
|
||||
fn.COALESCE(
|
||||
fn.SUM(
|
||||
Case(
|
||||
None,
|
||||
[
|
||||
(((cls.model.progress == 0) | ((cls.model.progress > 0) & (cls.model.progress < 1))), 1),
|
||||
],
|
||||
0,
|
||||
)
|
||||
),
|
||||
0,
|
||||
).alias("processing"),
|
||||
)
|
||||
.where(
|
||||
(cls.model.kb_id == kb_id)
|
||||
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
|
||||
)
|
||||
.dicts()
|
||||
.get()
|
||||
)
|
||||
|
||||
return {
|
||||
"processing": int(row["processing"]),
|
||||
"finished": int(row["finished"]),
|
||||
"failed": int(row["failed"]),
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
Optionally, specify a list of doc_ids to determine which documents participate in the task.
|
||||
"""
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
hasher.update(str(chunking_config[field]).encode("utf-8"))
|
||||
|
||||
def new_task():
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
for field in ["doc_id", "from_page", "to_page"]:
|
||||
hasher.update(str(task.get(field, "")).encode("utf-8"))
|
||||
hasher.update(ty.encode("utf-8"))
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
|
||||
if ty in ["graphrag", "raptor", "mindmap"]:
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
|
||||
def get_queue_length(priority):
|
||||
group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
|
||||
if not group_info:
|
||||
return 0
|
||||
return int(group_info.get("lag", 0) or 0)
|
||||
|
||||
|
||||
async def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from rag.app import audio, email, naive, picture, presentation
|
||||
|
||||
e, conv = ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
e, conv = API4ConversationService.get_by_id(conversation_id)
|
||||
assert e, "Conversation not found!"
|
||||
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not dia.kb_ids:
|
||||
raise LookupError("No knowledge base associated with this conversation. "
|
||||
"Please add a knowledge base before uploading documents")
|
||||
kb_id = dia.kb_ids[0]
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||
|
||||
err, files = await FileService.upload_document(kb, file_objs, user_id)
|
||||
assert not err, "\n".join(err)
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
FACTORY = {
|
||||
ParserType.PRESENTATION.value: presentation,
|
||||
ParserType.PICTURE.value: picture,
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
doc_nm = {}
|
||||
for d, blob in files:
|
||||
doc_nm[d["id"]] = d["name"]
|
||||
for d, blob in files:
|
||||
kwargs = {
|
||||
"callback": dummy,
|
||||
"parser_config": parser_config,
|
||||
"from_page": 0,
|
||||
"to_page": 100000,
|
||||
"tenant_id": kb.tenant_id,
|
||||
"lang": kb.language
|
||||
}
|
||||
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
|
||||
|
||||
for (docinfo, _), th in zip(files, threads):
|
||||
docs = []
|
||||
doc = {
|
||||
"doc_id": docinfo["id"],
|
||||
"kb_id": [kb.id]
|
||||
}
|
||||
for ck in th.result():
|
||||
d = deepcopy(doc)
|
||||
d.update(ck)
|
||||
d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
|
||||
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
if not d.get("image"):
|
||||
docs.append(d)
|
||||
continue
|
||||
|
||||
output_buffer = BytesIO()
|
||||
if isinstance(d["image"], bytes):
|
||||
output_buffer = BytesIO(d["image"])
|
||||
else:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
|
||||
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
||||
d.pop("image", None)
|
||||
docs.append(d)
|
||||
|
||||
parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
|
||||
docids = [d["id"] for d, _ in files]
|
||||
chunk_counts = {id: 0 for id in docids}
|
||||
token_counts = {id: 0 for id in docids}
|
||||
es_bulk_size = 64
|
||||
|
||||
def embedding(doc_id, cnts, batch_size=16):
|
||||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vects = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vects
|
||||
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
try_create_idx = True
|
||||
|
||||
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
||||
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||
for doc_id in docids:
|
||||
cks = [c for c in docs if c["doc_id"] == doc_id]
|
||||
|
||||
if parser_ids[doc_id] != ParserType.PICTURE.value:
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc_id,
|
||||
"kb_id": [kb.id],
|
||||
"docnm_kwd": doc_nm[doc_id],
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
|
||||
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
|
||||
"content_with_weight": mind_map,
|
||||
"knowledge_graph_kwd": "mind_map"
|
||||
})
|
||||
except Exception as e:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vects)
|
||||
for i, d in enumerate(cks):
|
||||
v = vects[i]
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
return [d["id"] for d, _ in files]
|
||||
|
||||
96
api/db/services/file2document_service.py
Normal file
96
api/db/services/file2document_service.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from api.db import FileSource
|
||||
from api.db.db_models import DB
|
||||
from api.db.db_models import File, File2Document
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
|
||||
class File2DocumentService(CommonService):
|
||||
model = File2Document
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_file_id(cls, file_id):
|
||||
objs = cls.model.select().where(cls.model.file_id == file_id)
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_document_id(cls, document_id):
|
||||
objs = cls.model.select().where(cls.model.document_id == document_id)
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_document_ids(cls, document_ids):
|
||||
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
|
||||
return list(objs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, obj):
|
||||
if not cls.save(**obj):
|
||||
raise RuntimeError("Database error (File)!")
|
||||
return File2Document(**obj)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_file_id(cls, file_id):
|
||||
return cls.model.delete().where(cls.model.file_id == file_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
|
||||
if not document_ids:
|
||||
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
|
||||
elif not file_ids:
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_id(cls, doc_id):
|
||||
return cls.model.delete().where(cls.model.document_id == doc_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_by_file_id(cls, file_id, obj):
|
||||
obj["update_time"] = current_timestamp()
|
||||
obj["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(obj).where(cls.model.id == file_id).execute()
|
||||
return File2Document(**obj)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_storage_address(cls, doc_id=None, file_id=None):
|
||||
if doc_id:
|
||||
f2d = cls.get_by_document_id(doc_id)
|
||||
else:
|
||||
f2d = cls.get_by_file_id(file_id)
|
||||
if f2d:
|
||||
file = File.get_by_id(f2d[0].file_id)
|
||||
if not file.source_type or file.source_type == FileSource.LOCAL:
|
||||
return file.parent_id, file.location
|
||||
doc_id = f2d[0].document_id
|
||||
|
||||
assert doc_id, "please specify doc_id"
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
return doc.kb_id, doc.location
|
||||
547
api/db/services/file_service.py
Normal file
547
api/db/services/file_service.py
Normal file
@@ -0,0 +1,547 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
|
||||
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
|
||||
from rag.llm.cv_model import GptV4
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class FileService(CommonService):
|
||||
# Service class for managing file operations and storage
|
||||
model = File
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords):
|
||||
# Get files by parent folder ID with pagination and filtering
|
||||
# Args:
|
||||
# tenant_id: ID of the tenant
|
||||
# pf_id: Parent folder ID
|
||||
# page_number: Page number for pagination
|
||||
# items_per_page: Number of items per page
|
||||
# orderby: Field to order by
|
||||
# desc: Boolean indicating descending order
|
||||
# keywords: Search keywords
|
||||
# Returns:
|
||||
# Tuple of (file_list, total_count)
|
||||
if keywords:
|
||||
files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).contains(keywords.lower())), ~(cls.model.id == pf_id))
|
||||
else:
|
||||
files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), ~(cls.model.id == pf_id))
|
||||
count = files.count()
|
||||
if desc:
|
||||
files = files.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
files = files.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
files = files.paginate(page_number, items_per_page)
|
||||
|
||||
res_files = list(files.dicts())
|
||||
for file in res_files:
|
||||
if file["type"] == FileType.FOLDER.value:
|
||||
file["size"] = cls.get_folder_size(file["id"])
|
||||
file["kbs_info"] = []
|
||||
children = list(
|
||||
cls.model.select()
|
||||
.where(
|
||||
(cls.model.tenant_id == tenant_id),
|
||||
(cls.model.parent_id == file["id"]),
|
||||
~(cls.model.id == file["id"]),
|
||||
)
|
||||
.dicts()
|
||||
)
|
||||
file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children)
|
||||
continue
|
||||
kbs_info = cls.get_kb_id_by_file_id(file["id"])
|
||||
file["kbs_info"] = kbs_info
|
||||
|
||||
return res_files, count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_id_by_file_id(cls, file_id):
|
||||
# Get knowledge base IDs associated with a file
|
||||
# Args:
|
||||
# file_id: File ID
|
||||
# Returns:
|
||||
# List of dictionaries containing knowledge base IDs and names
|
||||
kbs = (
|
||||
cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
|
||||
.join(File2Document, on=(File2Document.file_id == file_id))
|
||||
.join(Document, on=(File2Document.document_id == Document.id))
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
||||
.where(cls.model.id == file_id)
|
||||
)
|
||||
if not kbs:
|
||||
return []
|
||||
kbs_info_list = []
|
||||
for kb in list(kbs.dicts()):
|
||||
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
|
||||
return kbs_info_list
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_pf_id_name(cls, id, name):
|
||||
# Get file by parent folder ID and name
|
||||
# Args:
|
||||
# id: Parent folder ID
|
||||
# name: File name
|
||||
# Returns:
|
||||
# File object or None if not found
|
||||
file = cls.model.select().where((cls.model.parent_id == id) & (cls.model.name == name))
|
||||
if file.count():
|
||||
e, file = cls.get_by_id(file[0].id)
|
||||
if not e:
|
||||
raise RuntimeError("Database error (File retrieval)!")
|
||||
return file
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_id_list_by_id(cls, id, name, count, res):
|
||||
# Recursively get list of file IDs by traversing folder structure
|
||||
# Args:
|
||||
# id: Starting folder ID
|
||||
# name: List of folder names to traverse
|
||||
# count: Current depth in traversal
|
||||
# res: List to store results
|
||||
# Returns:
|
||||
# List of file IDs
|
||||
if count < len(name):
|
||||
file = cls.get_by_pf_id_name(id, name[count])
|
||||
if file:
|
||||
res.append(file.id)
|
||||
return cls.get_id_list_by_id(file.id, name, count + 1, res)
|
||||
else:
|
||||
return res
|
||||
else:
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_innermost_file_ids(cls, folder_id, result_ids):
|
||||
# Get IDs of all files in the deepest level of folders
|
||||
# Args:
|
||||
# folder_id: Starting folder ID
|
||||
# result_ids: List to store results
|
||||
# Returns:
|
||||
# List of file IDs
|
||||
subfolders = cls.model.select().where(cls.model.parent_id == folder_id)
|
||||
if subfolders.exists():
|
||||
for subfolder in subfolders:
|
||||
cls.get_all_innermost_file_ids(subfolder.id, result_ids)
|
||||
else:
|
||||
result_ids.append(folder_id)
|
||||
return result_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_file_ids_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
files.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
file_batch = files.offset(offset).limit(limit)
|
||||
_temp = list(file_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_folder(cls, file, parent_id, name, count):
|
||||
# Recursively create folder structure
|
||||
# Args:
|
||||
# file: Current file object
|
||||
# parent_id: Parent folder ID
|
||||
# name: List of folder names to create
|
||||
# count: Current depth in creation
|
||||
# Returns:
|
||||
# Created file object
|
||||
if count > len(name) - 2:
|
||||
return file
|
||||
else:
|
||||
file = cls.insert(
|
||||
{"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value}
|
||||
)
|
||||
return cls.create_folder(file, file.id, name, count + 1)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_parent_folder_exist(cls, parent_id):
|
||||
# Check if parent folder exists
|
||||
# Args:
|
||||
# parent_id: Parent folder ID
|
||||
# Returns:
|
||||
# Boolean indicating if folder exists
|
||||
parent_files = cls.model.select().where(cls.model.id == parent_id)
|
||||
if parent_files.count():
|
||||
return True
|
||||
cls.delete_folder_by_pf_id(parent_id)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_root_folder(cls, tenant_id):
|
||||
# Get or create root folder for tenant
|
||||
# Args:
|
||||
# tenant_id: Tenant ID
|
||||
# Returns:
|
||||
# Root folder dictionary
|
||||
for file in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
|
||||
return file.to_dict()
|
||||
|
||||
file_id = get_uuid()
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
cls.save(**file)
|
||||
return file
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_folder(cls, tenant_id):
|
||||
# Get knowledge base folder for tenant
|
||||
# Args:
|
||||
# tenant_id: Tenant ID
|
||||
# Returns:
|
||||
# Knowledge base folder dictionary
|
||||
root_folder = cls.get_root_folder(tenant_id)
|
||||
root_id = root_folder["id"]
|
||||
kb_folder = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == root_id), (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)).first()
|
||||
if not kb_folder:
|
||||
kb_folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
|
||||
return kb_folder
|
||||
return kb_folder.to_dict()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
|
||||
# Create a new file from knowledge base
|
||||
# Args:
|
||||
# tenant_id: Tenant ID
|
||||
# name: File name
|
||||
# parent_id: Parent folder ID
|
||||
# ty: File type
|
||||
# size: File size
|
||||
# location: File location
|
||||
# Returns:
|
||||
# Created file dictionary
|
||||
for file in cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name):
|
||||
return file.to_dict()
|
||||
file = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": parent_id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"name": name,
|
||||
"type": ty,
|
||||
"size": size,
|
||||
"location": location,
|
||||
"source_type": FileSource.KNOWLEDGEBASE,
|
||||
}
|
||||
cls.save(**file)
|
||||
return file
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def init_knowledgebase_docs(cls, root_id, tenant_id):
|
||||
# Initialize knowledge base documents
|
||||
# Args:
|
||||
# root_id: Root folder ID
|
||||
# tenant_id: Tenant ID
|
||||
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME) & (cls.model.parent_id == root_id)):
|
||||
return
|
||||
folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
|
||||
|
||||
for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id == tenant_id):
|
||||
kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"])
|
||||
for doc in DocumentService.query(kb_id=kb.id):
|
||||
FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_parent_folder(cls, file_id):
|
||||
# Get parent folder of a file
|
||||
# Args:
|
||||
# file_id: File ID
|
||||
# Returns:
|
||||
# Parent folder object
|
||||
file = cls.model.select().where(cls.model.id == file_id)
|
||||
if file.count():
|
||||
e, file = cls.get_by_id(file[0].parent_id)
|
||||
if not e:
|
||||
raise RuntimeError("Database error (File retrieval)!")
|
||||
else:
|
||||
raise RuntimeError("Database error (File doesn't exist)!")
|
||||
return file
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_parent_folders(cls, start_id):
|
||||
# Get all parent folders in path
|
||||
# Args:
|
||||
# start_id: Starting file ID
|
||||
# Returns:
|
||||
# List of parent folder objects
|
||||
parent_folders = []
|
||||
current_id = start_id
|
||||
while current_id:
|
||||
e, file = cls.get_by_id(current_id)
|
||||
if file.parent_id != file.id and e:
|
||||
parent_folders.append(file)
|
||||
current_id = file.parent_id
|
||||
else:
|
||||
parent_folders.append(file)
|
||||
break
|
||||
return parent_folders
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, file):
|
||||
# Insert a new file record
|
||||
# Args:
|
||||
# file: File data dictionary
|
||||
# Returns:
|
||||
# Created file object
|
||||
if not cls.save(**file):
|
||||
raise RuntimeError("Database error (File)!")
|
||||
return File(**file)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete(cls, file):
|
||||
#
|
||||
return cls.delete_by_id(file.id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_pf_id(cls, folder_id):
|
||||
return cls.model.delete().where(cls.model.parent_id == folder_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_folder_by_pf_id(cls, user_id, folder_id):
|
||||
try:
|
||||
files = cls.model.select().where((cls.model.tenant_id == user_id) & (cls.model.parent_id == folder_id))
|
||||
for file in files:
|
||||
cls.delete_folder_by_pf_id(user_id, file.id)
|
||||
return (cls.model.delete().where((cls.model.tenant_id == user_id) & (cls.model.id == folder_id)).execute(),)
|
||||
except Exception:
|
||||
logging.exception("delete_folder_by_pf_id")
|
||||
raise RuntimeError("Database error (File retrieval)!")
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_file_count(cls, tenant_id):
|
||||
files = cls.model.select(cls.model.id).where(cls.model.tenant_id == tenant_id)
|
||||
return len(files)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_folder_size(cls, folder_id):
|
||||
size = 0
|
||||
|
||||
def dfs(parent_id):
|
||||
nonlocal size
|
||||
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id, cls.model.id != parent_id):
|
||||
size += f.size
|
||||
if f.type == FileType.FOLDER.value:
|
||||
dfs(f.id)
|
||||
|
||||
dfs(folder_id)
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
|
||||
for _ in File2DocumentService.get_by_document_id(doc["id"]):
|
||||
return
|
||||
file = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": kb_folder_id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"name": doc["name"],
|
||||
"type": doc["type"],
|
||||
"size": doc["size"],
|
||||
"location": doc["location"],
|
||||
"source_type": FileSource.KNOWLEDGEBASE,
|
||||
}
|
||||
cls.save(**file)
|
||||
File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def move_file(cls, file_ids, folder_id):
|
||||
try:
|
||||
cls.filter_update((cls.model.id << file_ids,), {"parent_id": folder_id})
|
||||
except Exception:
|
||||
logging.exception("move_file")
|
||||
raise RuntimeError("Database error (File move)!")
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
async def upload_document(self, kb, file_objs, user_id):
|
||||
root_folder = self.get_root_folder(user_id)
|
||||
pf_id = root_folder["id"]
|
||||
self.init_knowledgebase_docs(pf_id, user_id)
|
||||
kb_root_folder = self.get_kb_folder(user_id)
|
||||
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
err, files = [], []
|
||||
for file in file_objs:
|
||||
try:
|
||||
DocumentService.check_doc_health(kb.tenant_id, file.filename)
|
||||
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
|
||||
filetype = filename_type(filename)
|
||||
if filetype == FileType.OTHER.value:
|
||||
raise RuntimeError("This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while STORAGE_IMPL.obj_exist(kb.id, location):
|
||||
location += "_"
|
||||
|
||||
blob = await file.read()
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
STORAGE_IMPL.put(kb.id, location, blob)
|
||||
|
||||
doc_id = get_uuid()
|
||||
|
||||
img = thumbnail_img(filename, blob)
|
||||
thumbnail_location = ""
|
||||
if img is not None:
|
||||
thumbnail_location = f"thumbnail_{doc_id}.png"
|
||||
STORAGE_IMPL.put(kb.id, thumbnail_location, img)
|
||||
|
||||
doc = {
|
||||
"id": doc_id,
|
||||
"kb_id": kb.id,
|
||||
"parser_id": self.get_parser(filetype, filename, kb.parser_id),
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": user_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail_location,
|
||||
}
|
||||
DocumentService.insert(doc)
|
||||
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
files.append((doc, blob))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
err.append(file.filename + ": " + str(e))
|
||||
|
||||
return err, files
|
||||
|
||||
@staticmethod
|
||||
async def parse_docs(file_objs, user_id):
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
for file in file_objs:
|
||||
# Check if file has async read method (UploadFile)
|
||||
if hasattr(file, 'read') and hasattr(file.read, '__call__'):
|
||||
try:
|
||||
# Try to get the coroutine to check if it's async
|
||||
read_result = file.read()
|
||||
if hasattr(read_result, '__await__'):
|
||||
# It's an async method, await it
|
||||
blob = await read_result
|
||||
else:
|
||||
# It's a sync method
|
||||
blob = read_result
|
||||
except Exception:
|
||||
# Fallback to sync read
|
||||
blob = file.read()
|
||||
else:
|
||||
blob = file.read()
|
||||
|
||||
threads.append(exe.submit(FileService.parse, file.filename, blob, False))
|
||||
|
||||
res = []
|
||||
for th in threads:
|
||||
res.append(th.result())
|
||||
|
||||
return "\n\n".join(res)
|
||||
|
||||
@staticmethod
|
||||
def parse(filename, blob, img_base64=True, tenant_id=None):
|
||||
from rag.app import audio, email, naive, picture, presentation
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
|
||||
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": current_user.id if current_user else tenant_id}
|
||||
file_type = filename_type(filename)
|
||||
if img_base64 and file_type == FileType.VISUAL.value:
|
||||
return GptV4.image2base64(blob)
|
||||
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
|
||||
return "\n".join([ck["content_with_weight"] for ck in cks])
|
||||
|
||||
@staticmethod
|
||||
def get_parser(doc_type, filename, default):
|
||||
if doc_type == FileType.VISUAL:
|
||||
return ParserType.PICTURE.value
|
||||
if doc_type == FileType.AURAL:
|
||||
return ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
return ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(msg|eml)$", filename):
|
||||
return ParserType.EMAIL.value
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_blob(user_id, location):
|
||||
bname = f"{user_id}-downloads"
|
||||
return STORAGE_IMPL.get(bname, location)
|
||||
|
||||
@staticmethod
|
||||
def put_blob(user_id, location, blob):
|
||||
bname = f"{user_id}-downloads"
|
||||
return STORAGE_IMPL.put(bname, location, blob)
|
||||
496
api/db/services/knowledgebase_service.py
Normal file
496
api/db/services/knowledgebase_service.py
Normal file
@@ -0,0 +1,496 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import fn, JOIN
|
||||
|
||||
from api.db import StatusEnum, TenantPermission
|
||||
from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
|
||||
"""Service class for managing knowledge base operations.
|
||||
|
||||
This class extends CommonService to provide specialized functionality for knowledge base
|
||||
management, including document parsing status tracking, access control, and configuration
|
||||
management. It handles operations such as listing, creating, updating, and deleting
|
||||
knowledge bases, as well as managing their associated documents and permissions.
|
||||
|
||||
The class implements a comprehensive set of methods for:
|
||||
- Document parsing status verification
|
||||
- Knowledge base access control
|
||||
- Parser configuration management
|
||||
- Tenant-based knowledge base organization
|
||||
|
||||
Attributes:
|
||||
model: The Knowledgebase model class for database operations.
|
||||
"""
|
||||
model = Knowledgebase
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible4deletion(cls, kb_id, user_id):
|
||||
"""Check if a knowledge base can be deleted by a specific user.
|
||||
|
||||
This method verifies whether a user has permission to delete a knowledge base
|
||||
by checking if they are the creator of that knowledge base.
|
||||
|
||||
Args:
|
||||
kb_id (str): The unique identifier of the knowledge base to check.
|
||||
user_id (str): The unique identifier of the user attempting the deletion.
|
||||
|
||||
Returns:
|
||||
bool: True if the user has permission to delete the knowledge base,
|
||||
False if the user doesn't have permission or the knowledge base doesn't exist.
|
||||
|
||||
Example:
|
||||
>>> KnowledgebaseService.accessible4deletion("kb123", "user456")
|
||||
True
|
||||
|
||||
Note:
|
||||
- This method only checks creator permissions
|
||||
- A return value of False can mean either:
|
||||
1. The knowledge base doesn't exist
|
||||
2. The user is not the creator of the knowledge base
|
||||
"""
|
||||
# Check if a knowledge base can be deleted by a user
|
||||
docs = cls.model.select(
|
||||
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_parsed_done(cls, kb_id):
|
||||
# Check if all documents in the knowledge base have completed parsing
|
||||
#
|
||||
# Args:
|
||||
# kb_id: Knowledge base ID
|
||||
#
|
||||
# Returns:
|
||||
# If all documents are parsed successfully, returns (True, None)
|
||||
# If any document is not fully parsed, returns (False, error_message)
|
||||
from api.db import TaskStatus
|
||||
from api.db.services.document_service import DocumentService
|
||||
|
||||
# Get knowledge base information
|
||||
kbs = cls.query(id=kb_id)
|
||||
if not kbs:
|
||||
return False, "Knowledge base not found"
|
||||
kb = kbs[0]
|
||||
|
||||
# Get all documents in the knowledge base
|
||||
docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "", [], [])
|
||||
|
||||
# Check parsing status of each document
|
||||
for doc in docs:
|
||||
# If document is being parsed, don't allow chat creation
|
||||
if doc['run'] == TaskStatus.RUNNING.value or doc['run'] == TaskStatus.CANCEL.value or doc['run'] == TaskStatus.FAIL.value:
|
||||
return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
|
||||
# If document is not yet parsed and has no chunks, don't allow chat creation
|
||||
if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
|
||||
return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def list_documents_by_ids(cls, kb_ids):
|
||||
# Get document IDs associated with given knowledge base IDs
|
||||
# Args:
|
||||
# kb_ids: List of knowledge base IDs
|
||||
# Returns:
|
||||
# List of document IDs
|
||||
doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
|
||||
cls.model.id.in_(kb_ids)
|
||||
)
|
||||
doc_ids = list(doc_ids.dicts())
|
||||
doc_ids = [doc["document_id"] for doc in doc_ids]
|
||||
return doc_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page,
|
||||
orderby, desc, keywords,
|
||||
parser_id=None
|
||||
):
|
||||
# Get knowledge bases by tenant IDs with pagination and filtering
|
||||
# Args:
|
||||
# joined_tenant_ids: List of tenant IDs
|
||||
# user_id: Current user ID
|
||||
# page_number: Page number for pagination
|
||||
# items_per_page: Number of items per page
|
||||
# orderby: Field to order by
|
||||
# desc: Boolean indicating descending order
|
||||
# keywords: Search keywords
|
||||
# parser_id: Optional parser ID filter
|
||||
# Returns:
|
||||
# Tuple of (knowledge_base_list, total_count)
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.description,
|
||||
cls.model.tenant_id,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.embd_id,
|
||||
User.nickname,
|
||||
User.avatar.alias('tenant_avatar'),
|
||||
cls.model.update_time
|
||||
]
|
||||
if keywords:
|
||||
kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
|
||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
|
||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if parser_id:
|
||||
kbs = kbs.where(cls.model.parser_id == parser_id)
|
||||
if desc:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
count = kbs.count()
|
||||
|
||||
if page_number and items_per_page:
|
||||
kbs = kbs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(kbs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted kb, be cautious.
|
||||
fields = [
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.status,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date
|
||||
]
|
||||
# find team kb and owned kb
|
||||
kbs = cls.model.select(*fields).where(
|
||||
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time asc
|
||||
kbs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later.
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
kb_batch = kbs.offset(offset).limit(limit)
|
||||
_temp = list(kb_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_ids(cls, tenant_id):
|
||||
# Get all knowledge base IDs for a tenant
|
||||
# Args:
|
||||
# tenant_id: Tenant ID
|
||||
# Returns:
|
||||
# List of knowledge base IDs
|
||||
fields = [
|
||||
cls.model.id,
|
||||
]
|
||||
kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
kb_ids = [kb.id for kb in kbs]
|
||||
return kb_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_detail(cls, kb_id):
|
||||
# Get detailed information about a knowledge base
|
||||
# Args:
|
||||
# kb_id: Knowledge base ID
|
||||
# Returns:
|
||||
# Dictionary containing knowledge base details
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.embd_id,
|
||||
cls.model.avatar,
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.description,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
UserCanvas.title.alias("pipeline_name"),
|
||||
UserCanvas.avatar.alias("pipeline_avatar"),
|
||||
cls.model.parser_config,
|
||||
cls.model.pagerank,
|
||||
cls.model.graphrag_task_id,
|
||||
cls.model.graphrag_task_finish_at,
|
||||
cls.model.raptor_task_id,
|
||||
cls.model.raptor_task_finish_at,
|
||||
cls.model.mindmap_task_id,
|
||||
cls.model.mindmap_task_finish_at,
|
||||
cls.model.create_time,
|
||||
cls.model.update_time
|
||||
]
|
||||
kbs = cls.model.select(*fields)\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
).dicts()
|
||||
if not kbs:
|
||||
return
|
||||
return kbs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_parser_config(cls, id, config):
|
||||
# Update parser configuration for a knowledge base
|
||||
# Args:
|
||||
# id: Knowledge base ID
|
||||
# config: New parser configuration
|
||||
e, m = cls.get_by_id(id)
|
||||
if not e:
|
||||
raise LookupError(f"knowledgebase({id}) not found.")
|
||||
|
||||
def dfs_update(old, new):
|
||||
# Deep update of nested configuration
|
||||
for k, v in new.items():
|
||||
if k not in old:
|
||||
old[k] = v
|
||||
continue
|
||||
if isinstance(v, dict):
|
||||
assert isinstance(old[k], dict)
|
||||
dfs_update(old[k], v)
|
||||
elif isinstance(v, list):
|
||||
assert isinstance(old[k], list)
|
||||
old[k] = list(set(old[k] + v))
|
||||
else:
|
||||
old[k] = v
|
||||
|
||||
dfs_update(m.parser_config, config)
|
||||
cls.update_by_id(id, {"parser_config": m.parser_config})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_field_map(cls, id):
|
||||
e, m = cls.get_by_id(id)
|
||||
if not e:
|
||||
raise LookupError(f"knowledgebase({id}) not found.")
|
||||
|
||||
m.parser_config.pop("field_map", None)
|
||||
cls.update_by_id(id, {"parser_config": m.parser_config})
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_field_map(cls, ids):
|
||||
# Get field mappings for knowledge bases
|
||||
# Args:
|
||||
# ids: List of knowledge base IDs
|
||||
# Returns:
|
||||
# Dictionary of field mappings
|
||||
conf = {}
|
||||
for k in cls.get_by_ids(ids):
|
||||
if k.parser_config and "field_map" in k.parser_config:
|
||||
conf.update(k.parser_config["field_map"])
|
||||
return conf
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_name(cls, kb_name, tenant_id):
|
||||
# Get knowledge base by name and tenant ID
|
||||
# Args:
|
||||
# kb_name: Knowledge base name
|
||||
# tenant_id: Tenant ID
|
||||
# Returns:
|
||||
# Tuple of (exists, knowledge_base)
|
||||
kb = cls.model.select().where(
|
||||
(cls.model.name == kb_name)
|
||||
& (cls.model.tenant_id == tenant_id)
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if kb:
|
||||
return True, kb[0]
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_ids(cls):
|
||||
# Get all knowledge base IDs
|
||||
# Returns:
|
||||
# List of all knowledge base IDs
|
||||
return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls, joined_tenant_ids, user_id,
|
||||
page_number, items_per_page, orderby, desc, id, name):
|
||||
# Get list of knowledge bases with filtering and pagination
|
||||
# Args:
|
||||
# joined_tenant_ids: List of tenant IDs
|
||||
# user_id: Current user ID
|
||||
# page_number: Page number for pagination
|
||||
# items_per_page: Number of items per page
|
||||
# orderby: Field to order by
|
||||
# desc: Boolean indicating descending order
|
||||
# id: Optional ID filter
|
||||
# name: Optional name filter
|
||||
# Returns:
|
||||
# List of knowledge bases
|
||||
kbs = cls.model.select()
|
||||
if id:
|
||||
kbs = kbs.where(cls.model.id == id)
|
||||
if name:
|
||||
kbs = kbs.where(cls.model.name == name)
|
||||
kbs = kbs.where(
|
||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if desc:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
kbs = kbs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(kbs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible(cls, kb_id, user_id):
|
||||
# Check if a knowledge base is accessible by a user
|
||||
# Args:
|
||||
# kb_id: Knowledge base ID
|
||||
# user_id: User ID
|
||||
# Returns:
|
||||
# Boolean indicating accessibility
|
||||
docs = cls.model.select(
|
||||
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_by_id(cls, kb_id, user_id):
|
||||
# Get knowledge base by ID and user ID
|
||||
# Args:
|
||||
# kb_id: Knowledge base ID
|
||||
# user_id: User ID
|
||||
# Returns:
|
||||
# List containing knowledge base information
|
||||
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||
kbs = kbs.dicts()
|
||||
return list(kbs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_by_name(cls, kb_name, user_id):
|
||||
# Get knowledge base by name and user ID
|
||||
# Args:
|
||||
# kb_name: Knowledge base name
|
||||
# user_id: User ID
|
||||
# Returns:
|
||||
# List containing knowledge base information
|
||||
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||
).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
|
||||
kbs = kbs.dicts()
|
||||
return list(kbs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def atomic_increase_doc_num_by_id(cls, kb_id):
|
||||
data = {}
|
||||
data["update_time"] = current_timestamp()
|
||||
data["update_date"] = datetime_format(datetime.now())
|
||||
data["doc_num"] = cls.model.doc_num + 1
|
||||
num = cls.model.update(data).where(cls.model.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_document_number_in_init(cls, kb_id, doc_num):
|
||||
"""
|
||||
Only use this function when init system
|
||||
"""
|
||||
ok, kb = cls.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return
|
||||
kb.doc_num = doc_num
|
||||
|
||||
dirty_fields = kb.dirty_fields
|
||||
if cls.model._meta.combined.get("update_time") in dirty_fields:
|
||||
dirty_fields.remove(cls.model._meta.combined["update_time"])
|
||||
|
||||
if cls.model._meta.combined.get("update_date") in dirty_fields:
|
||||
dirty_fields.remove(cls.model._meta.combined["update_date"])
|
||||
|
||||
try:
|
||||
kb.save(only=dirty_fields)
|
||||
except ValueError as e:
|
||||
if str(e) == "no data to save!":
|
||||
pass # that's OK
|
||||
else:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
|
||||
kb_row = cls.model.get_by_id(kb_id)
|
||||
if not kb_row:
|
||||
raise RuntimeError(f"kb_id {kb_id} does not exist")
|
||||
update_dict = {
|
||||
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
|
||||
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
|
||||
'token_num': kb_row.token_num - doc_num_info['token_num'],
|
||||
'update_time': current_timestamp(),
|
||||
'update_date': datetime_format(datetime.now())
|
||||
}
|
||||
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||
76
api/db/services/langfuse_service.py
Normal file
76
api/db/services/langfuse_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import peewee
|
||||
|
||||
from api.db.db_models import DB, TenantLangfuse
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
|
||||
class TenantLangfuseService(CommonService):
|
||||
"""
|
||||
All methods that modify the status should be enclosed within a DB.atomic() context to ensure atomicity
|
||||
and maintain data integrity in case of errors during execution.
|
||||
"""
|
||||
|
||||
model = TenantLangfuse
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_by_tenant(cls, tenant_id):
|
||||
fields = [cls.model.tenant_id, cls.model.host, cls.model.secret_key, cls.model.public_key]
|
||||
try:
|
||||
keys = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id).first()
|
||||
return keys
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_by_tenant_with_info(cls, tenant_id):
|
||||
fields = [cls.model.tenant_id, cls.model.host, cls.model.secret_key, cls.model.public_key]
|
||||
try:
|
||||
keys = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id).dicts().first()
|
||||
return keys
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_ty_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@classmethod
|
||||
def update_by_tenant(cls, tenant_id, langfuse_keys):
|
||||
langfuse_keys["update_time"] = current_timestamp()
|
||||
langfuse_keys["update_date"] = datetime_format(datetime.now())
|
||||
return cls.model.update(**langfuse_keys).where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@classmethod
|
||||
def save(cls, **kwargs):
|
||||
kwargs["create_time"] = current_timestamp()
|
||||
kwargs["create_date"] = datetime_format(datetime.now())
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
kwargs["update_date"] = datetime_format(datetime.now())
|
||||
obj = cls.model.create(**kwargs)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def delete_model(cls, langfuse_model):
|
||||
langfuse_model.delete_instance()
|
||||
279
api/db/services/llm_service.py
Normal file
279
api/db/services/llm_service.py
Normal file
@@ -0,0 +1,279 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||
|
||||
|
||||
class LLMService(CommonService):
|
||||
model = LLM
|
||||
|
||||
|
||||
def get_init_tenant_llm(user_id):
|
||||
from api import settings
|
||||
tenant_llm = []
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
settings.CHAT_CFG,
|
||||
settings.EMBEDDING_CFG,
|
||||
settings.ASR_CFG,
|
||||
settings.IMAGE2TEXT_CFG,
|
||||
settings.RERANK_CFG,
|
||||
]:
|
||||
factory_name = factory_config["factory"]
|
||||
if factory_name not in seen:
|
||||
seen.add(factory_name)
|
||||
factory_configs.append(factory_config)
|
||||
|
||||
for factory_config in factory_configs:
|
||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": factory_config["api_key"],
|
||||
"api_base": factory_config["base_url"],
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
|
||||
if settings.LIGHTEN != 1:
|
||||
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": fid,
|
||||
"llm_name": mdlnm,
|
||||
"model_type": "embedding",
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
||||
}
|
||||
)
|
||||
|
||||
unique = {}
|
||||
for item in tenant_llm:
|
||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||
if key not in unique:
|
||||
unique[key] = item
|
||||
return list(unique.values())
|
||||
|
||||
|
||||
class LLMBundle(LLM4Tenant):
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
|
||||
|
||||
def bind_tools(self, toolcall_session, tools):
|
||||
if not self.is_tools:
|
||||
logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
|
||||
return
|
||||
self.mdl.bind_tools(toolcall_session, tools)
|
||||
|
||||
def encode(self, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
|
||||
embeddings, used_tokens = self.mdl.encode(texts)
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
|
||||
logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return embeddings, used_tokens
|
||||
|
||||
def encode_queries(self, query: str):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})
|
||||
|
||||
emd, used_tokens = self.mdl.encode_queries(query)
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
|
||||
logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return emd, used_tokens
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})
|
||||
|
||||
sim, used_tokens = self.mdl.similarity(query, texts)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return sim, used_tokens
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})
|
||||
|
||||
txt, used_tokens = self.mdl.describe(image)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return txt
|
||||
|
||||
def describe_with_prompt(self, image, prompt):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})
|
||||
|
||||
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return txt
|
||||
|
||||
def transcription(self, audio):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})
|
||||
|
||||
txt, used_tokens = self.mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return txt
|
||||
|
||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
|
||||
|
||||
for chunk in self.mdl.tts(text):
|
||||
if isinstance(chunk, int):
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
|
||||
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
return
|
||||
yield chunk
|
||||
|
||||
if self.langfuse:
|
||||
generation.end()
|
||||
|
||||
def _remove_reasoning_content(self, txt: str) -> str:
|
||||
first_think_start = txt.find("<think>")
|
||||
if first_think_start == -1:
|
||||
return txt
|
||||
|
||||
last_think_end = txt.rfind("</think>")
|
||||
if last_think_end == -1:
|
||||
return txt
|
||||
|
||||
if last_think_end < first_think_start:
|
||||
return txt
|
||||
|
||||
return txt[last_think_end + len("</think>") :]
|
||||
|
||||
@staticmethod
|
||||
def _clean_param(chat_partial, **kwargs):
|
||||
func = chat_partial.func
|
||||
sig = inspect.signature(func)
|
||||
keyword_args = []
|
||||
support_var_args = False
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD or param.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
support_var_args = True
|
||||
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||
keyword_args.append(param.name)
|
||||
|
||||
use_kwargs = kwargs
|
||||
if not support_var_args:
|
||||
use_kwargs = {k: v for k, v in kwargs.items() if k in keyword_args}
|
||||
return use_kwargs
|
||||
|
||||
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf)
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
|
||||
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
txt, used_tokens = chat_partial(**use_kwargs)
|
||||
txt = self._remove_reasoning_content(txt)
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||
generation.end()
|
||||
|
||||
return txt
|
||||
|
||||
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||
|
||||
ans = ""
|
||||
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
|
||||
total_tokens = 0
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
if self.langfuse:
|
||||
generation.update(output={"output": ans})
|
||||
generation.end()
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans.rstrip("</think>")
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
ans += txt
|
||||
yield ans
|
||||
|
||||
if total_tokens > 0:
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
|
||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
|
||||
91
api/db/services/mcp_server_service.py
Normal file
91
api/db/services/mcp_server_service.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from peewee import fn
|
||||
|
||||
from api.db.db_models import DB, MCPServer
|
||||
from api.db.services.common_service import CommonService
|
||||
|
||||
|
||||
class MCPServerService(CommonService):
|
||||
"""Service class for managing MCP server related database operations.
|
||||
|
||||
This class extends CommonService to provide specialized functionality for MCP server management,
|
||||
including MCP server creation, updates, and deletions.
|
||||
|
||||
Attributes:
|
||||
model: The MCPServer model class for database operations.
|
||||
"""
|
||||
|
||||
model = MCPServer
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords):
|
||||
"""Retrieve all MCP servers associated with a tenant.
|
||||
|
||||
This method fetches all MCP servers for a given tenant, ordered by creation time.
|
||||
It only includes fields for list display.
|
||||
|
||||
Args:
|
||||
tenant_id (str): The unique identifier of the tenant.
|
||||
id_list (list[str]): Get servers by ID list. Will ignore this condition if None.
|
||||
|
||||
Returns:
|
||||
list[dict]: List of MCP server dictionaries containing MCP server details.
|
||||
Returns None if no MCP servers are found.
|
||||
"""
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.name,
|
||||
cls.model.server_type,
|
||||
cls.model.url,
|
||||
cls.model.description,
|
||||
cls.model.variables,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
query = cls.model.select(*fields).order_by(cls.model.create_time.desc()).where(cls.model.tenant_id == tenant_id)
|
||||
|
||||
if id_list:
|
||||
query = query.where(cls.model.id.in_(id_list))
|
||||
if keywords:
|
||||
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
if desc:
|
||||
query = query.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
query = query.order_by(cls.model.getter_by(orderby).asc())
|
||||
if page_number and items_per_page:
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
servers = list(query.dicts())
|
||||
if not servers:
|
||||
return None
|
||||
return servers
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_name_and_tenant(cls, name: str, tenant_id: str):
|
||||
try:
|
||||
mcp_server = cls.model.query(name=name, tenant_id=tenant_id)
|
||||
return bool(mcp_server), mcp_server
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
263
api/db/services/pipeline_operation_log_service.py
Normal file
263
api/db/services/pipeline_operation_log_service.py
Normal file
@@ -0,0 +1,263 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType
|
||||
from api.db.db_models import DB, Document, PipelineOperationLog
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
|
||||
class PipelineOperationLogService(CommonService):
|
||||
model = PipelineOperationLog
|
||||
|
||||
@classmethod
|
||||
def get_file_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.document_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.pipeline_id,
|
||||
cls.model.pipeline_title,
|
||||
cls.model.parser_id,
|
||||
cls.model.document_name,
|
||||
cls.model.document_suffix,
|
||||
cls.model.document_type,
|
||||
cls.model.source_from,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.dsl,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_dataset_logs_fields(cls):
|
||||
return [
|
||||
cls.model.id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.progress,
|
||||
cls.model.progress_msg,
|
||||
cls.model.process_begin_at,
|
||||
cls.model.process_duration,
|
||||
cls.model.task_type,
|
||||
cls.model.operation_status,
|
||||
cls.model.avatar,
|
||||
cls.model.status,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date,
|
||||
cls.model.update_time,
|
||||
cls.model.update_date,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def save(cls, **kwargs):
|
||||
"""
|
||||
wrap this function in a transaction
|
||||
"""
|
||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return sample_obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: str = "{}"):
|
||||
referred_document_id = document_id
|
||||
|
||||
if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
|
||||
referred_document_id = fake_document_ids[0]
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
DocumentService.update_progress_immediately([document.to_dict()])
|
||||
ok, document = DocumentService.get_by_id(referred_document_id)
|
||||
if not ok:
|
||||
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
|
||||
return
|
||||
if document.progress not in [1, -1]:
|
||||
return
|
||||
operation_status = document.run
|
||||
|
||||
if pipeline_id:
|
||||
ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
||||
tenant_id = user_pipeline.user_id
|
||||
title = user_pipeline.title
|
||||
avatar = user_pipeline.avatar
|
||||
else:
|
||||
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
|
||||
|
||||
tenant_id = kb_info.tenant_id
|
||||
title = document.parser_id
|
||||
avatar = document.thumbnail
|
||||
|
||||
if task_type not in VALID_PIPELINE_TASK_TYPES:
|
||||
raise ValueError(f"Invalid task type: {task_type}")
|
||||
|
||||
if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
finish_at = document.process_begin_at + timedelta(seconds=document.process_duration)
|
||||
if task_type == PipelineTaskType.GRAPH_RAG:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"graphrag_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.RAPTOR:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"raptor_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.MINDMAP:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"mindmap_task_finish_at": finish_at},
|
||||
)
|
||||
|
||||
log = dict(
|
||||
id=get_uuid(),
|
||||
document_id=document_id, # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
|
||||
tenant_id=tenant_id,
|
||||
kb_id=document.kb_id,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_title=title,
|
||||
parser_id=document.parser_id,
|
||||
document_name=document.name,
|
||||
document_suffix=document.suffix,
|
||||
document_type=document.type,
|
||||
source_from="", # TODO: add in the future
|
||||
progress=document.progress,
|
||||
progress_msg=document.progress_msg,
|
||||
process_begin_at=document.process_begin_at,
|
||||
process_duration=document.process_duration,
|
||||
dsl=json.loads(dsl),
|
||||
task_type=task_type,
|
||||
operation_status=operation_status,
|
||||
avatar=avatar,
|
||||
)
|
||||
log["create_time"] = current_timestamp()
|
||||
log["create_date"] = datetime_format(datetime.now())
|
||||
log["update_time"] = current_timestamp()
|
||||
log["update_date"] = datetime_format(datetime.now())
|
||||
|
||||
with DB.atomic():
|
||||
obj = cls.save(**log)
|
||||
|
||||
limit = int(os.getenv("PIPELINE_OPERATION_LOG_LIMIT", 1000))
|
||||
total = cls.model.select().where(cls.model.kb_id == document.kb_id).count()
|
||||
|
||||
if total > limit:
|
||||
keep_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == document.kb_id).order_by(cls.model.create_time.desc()).limit(limit)]
|
||||
|
||||
deleted = cls.model.delete().where(cls.model.kb_id == document.kb_id, cls.model.id.not_in(keep_ids)).execute()
|
||||
logging.info(f"[PipelineOperationLogService] Cleaned {deleted} old logs, kept latest {limit} for {document.kb_id}")
|
||||
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_file_logs_fields()
|
||||
if keywords:
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
||||
else:
|
||||
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
|
||||
|
||||
logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if types:
|
||||
logs = logs.where(cls.model.document_type.in_(types))
|
||||
if suffix:
|
||||
logs = logs.where(cls.model.document_suffix.in_(suffix))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_documents_info(cls, id):
|
||||
fields = [Document.id, Document.name, Document.progress, Document.kb_id]
|
||||
return (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.document_id == Document.id))
|
||||
.where(
|
||||
cls.model.id == id
|
||||
)
|
||||
.dicts()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_dataset_logs_fields()
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
logs = logs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
if page_number and items_per_page:
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
117
api/db/services/search_service.py
Normal file
117
api/db/services/search_service.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import DB, Search, User
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
|
||||
class SearchService(CommonService):
|
||||
model = Search
|
||||
|
||||
@classmethod
|
||||
def save(cls, **kwargs):
|
||||
kwargs["create_time"] = current_timestamp()
|
||||
kwargs["create_date"] = datetime_format(datetime.now())
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
kwargs["update_date"] = datetime_format(datetime.now())
|
||||
obj = cls.model.create(**kwargs)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def accessible4deletion(cls, search_id, user_id) -> bool:
|
||||
search = (
|
||||
cls.model.select(cls.model.id)
|
||||
.where(
|
||||
cls.model.id == search_id,
|
||||
cls.model.created_by == user_id,
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return search is not None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_detail(cls, search_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.tenant_id,
|
||||
cls.model.name,
|
||||
cls.model.description,
|
||||
cls.model.created_by,
|
||||
cls.model.search_config,
|
||||
cls.model.update_time,
|
||||
User.nickname,
|
||||
User.avatar.alias("tenant_avatar"),
|
||||
]
|
||||
search = (
|
||||
cls.model.select(*fields)
|
||||
.join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
|
||||
.where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
|
||||
.first()
|
||||
.to_dict()
|
||||
)
|
||||
if not search:
|
||||
return {}
|
||||
return search
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.tenant_id,
|
||||
cls.model.name,
|
||||
cls.model.description,
|
||||
cls.model.created_by,
|
||||
cls.model.status,
|
||||
cls.model.update_time,
|
||||
cls.model.create_time,
|
||||
User.nickname,
|
||||
User.avatar.alias("tenant_avatar"),
|
||||
]
|
||||
query = (
|
||||
cls.model.select(*fields)
|
||||
.join(User, on=(cls.model.tenant_id == User.id))
|
||||
.where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
|
||||
)
|
||||
|
||||
if keywords:
|
||||
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
if desc:
|
||||
query = query.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
query = query.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
count = query.count()
|
||||
|
||||
if page_number and items_per_page:
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
522
api/db/services/task_service.py
Normal file
522
api/db/services/task_service.py
Normal file
@@ -0,0 +1,522 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import xxhash
|
||||
from datetime import datetime
|
||||
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from deepdoc.parser import PdfParser
|
||||
from peewee import JOIN
|
||||
from api.db.db_models import DB, File2Document, File
|
||||
from api.db import StatusEnum, FileType, TaskStatus
|
||||
from api.db.db_models import Task, Document, Knowledgebase, Tenant
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils import current_timestamp, get_uuid
|
||||
from deepdoc.parser.excel_parser import RAGFlowExcelParser
|
||||
from rag.settings import get_svr_queue_name
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
|
||||
CANVAS_DEBUG_DOC_ID = "dataflow_x"
|
||||
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
|
||||
|
||||
def trim_header_by_lines(text: str, max_length) -> str:
|
||||
# Trim header text to maximum length while preserving line breaks
|
||||
# Args:
|
||||
# text: Input text to trim
|
||||
# max_length: Maximum allowed length
|
||||
# Returns:
|
||||
# Trimmed text
|
||||
len_text = len(text)
|
||||
if len_text <= max_length:
|
||||
return text
|
||||
for i in range(len_text):
|
||||
if text[i] == '\n' and len_text - i <= max_length:
|
||||
return text[i + 1:]
|
||||
return text
|
||||
|
||||
|
||||
class TaskService(CommonService):
|
||||
"""Service class for managing document processing tasks.
|
||||
|
||||
This class extends CommonService to provide specialized functionality for document
|
||||
processing task management, including task creation, progress tracking, and chunk
|
||||
management. It handles various document types (PDF, Excel, etc.) and manages their
|
||||
processing lifecycle.
|
||||
|
||||
The class implements a robust task queue system with retry mechanisms and progress
|
||||
tracking, supporting both synchronous and asynchronous task execution.
|
||||
|
||||
Attributes:
|
||||
model: The Task model class for database operations.
|
||||
"""
|
||||
model = Task
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_task(cls, task_id, doc_ids=[]):
|
||||
"""Retrieve detailed task information by task ID.
|
||||
|
||||
This method fetches comprehensive task details including associated document,
|
||||
knowledge base, and tenant information. It also handles task retry logic and
|
||||
progress updates.
|
||||
|
||||
Args:
|
||||
task_id (str): The unique identifier of the task to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: Task details dictionary containing all task information and related metadata.
|
||||
Returns None if task is not found or has exceeded retry limit.
|
||||
"""
|
||||
doc_id = cls.model.doc_id
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
|
||||
doc_id = doc_ids[0]
|
||||
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.doc_id,
|
||||
cls.model.from_page,
|
||||
cls.model.to_page,
|
||||
cls.model.retry_count,
|
||||
Document.kb_id,
|
||||
Document.parser_id,
|
||||
Document.parser_config,
|
||||
Document.name,
|
||||
Document.type,
|
||||
Document.location,
|
||||
Document.size,
|
||||
Knowledgebase.tenant_id,
|
||||
Knowledgebase.language,
|
||||
Knowledgebase.embd_id,
|
||||
Knowledgebase.pagerank,
|
||||
Knowledgebase.parser_config.alias("kb_parser_config"),
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
cls.model.update_time,
|
||||
]
|
||||
docs = (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(doc_id == Document.id))
|
||||
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
|
||||
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
|
||||
.where(cls.model.id == task_id)
|
||||
)
|
||||
docs = list(docs.dicts())
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
msg = f"\n{datetime.now().strftime('%H:%M:%S')} Task has been received."
|
||||
prog = random.random() / 10.0
|
||||
if docs[0]["retry_count"] >= 3:
|
||||
msg = "\nERROR: Task is abandoned after 3 times attempts."
|
||||
prog = -1
|
||||
|
||||
cls.model.update(
|
||||
progress_msg=cls.model.progress_msg + msg,
|
||||
progress=prog,
|
||||
retry_count=docs[0]["retry_count"] + 1,
|
||||
).where(cls.model.id == docs[0]["id"]).execute()
|
||||
|
||||
if docs[0]["retry_count"] >= 3:
|
||||
return None
|
||||
|
||||
return docs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tasks(cls, doc_id: str):
|
||||
"""Retrieve all tasks associated with a document.
|
||||
|
||||
This method fetches all processing tasks for a given document, ordered by page
|
||||
number and creation time. It includes task progress and chunk information.
|
||||
|
||||
Args:
|
||||
doc_id (str): The unique identifier of the document.
|
||||
|
||||
Returns:
|
||||
list[dict]: List of task dictionaries containing task details.
|
||||
Returns None if no tasks are found.
|
||||
"""
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.from_page,
|
||||
cls.model.progress,
|
||||
cls.model.digest,
|
||||
cls.model.chunk_ids,
|
||||
]
|
||||
tasks = (
|
||||
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
|
||||
.where(cls.model.doc_id == doc_id)
|
||||
)
|
||||
tasks = list(tasks.dicts())
|
||||
if not tasks:
|
||||
return None
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_chunk_ids(cls, id: str, chunk_ids: str):
|
||||
"""Update the chunk IDs associated with a task.
|
||||
|
||||
This method updates the chunk_ids field of a task, which stores the IDs of
|
||||
processed document chunks in a space-separated string format.
|
||||
|
||||
Args:
|
||||
id (str): The unique identifier of the task.
|
||||
chunk_ids (str): Space-separated string of chunk identifiers.
|
||||
"""
|
||||
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_ongoing_doc_name(cls):
|
||||
"""Get names of documents that are currently being processed.
|
||||
|
||||
This method retrieves information about documents that are in the processing state,
|
||||
including their locations and associated IDs. It uses database locking to ensure
|
||||
thread safety when accessing the task information.
|
||||
|
||||
Returns:
|
||||
list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
|
||||
for documents currently being processed. Returns empty list if
|
||||
no documents are being processed.
|
||||
"""
|
||||
with DB.lock("get_task", -1):
|
||||
docs = (
|
||||
cls.model.select(
|
||||
*[Document.id, Document.kb_id, Document.location, File.parent_id]
|
||||
)
|
||||
.join(Document, on=(cls.model.doc_id == Document.id))
|
||||
.join(
|
||||
File2Document,
|
||||
on=(File2Document.document_id == Document.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.join(
|
||||
File,
|
||||
on=(File2Document.file_id == File.id),
|
||||
join_type=JOIN.LEFT_OUTER,
|
||||
)
|
||||
.where(
|
||||
Document.status == StatusEnum.VALID.value,
|
||||
Document.run == TaskStatus.RUNNING.value,
|
||||
~(Document.type == FileType.VIRTUAL.value),
|
||||
cls.model.progress < 1,
|
||||
cls.model.create_time >= current_timestamp() - 1000 * 600,
|
||||
)
|
||||
)
|
||||
docs = list(docs.dicts())
|
||||
if not docs:
|
||||
return []
|
||||
|
||||
return list(
|
||||
set(
|
||||
[
|
||||
(
|
||||
d["parent_id"] if d["parent_id"] else d["kb_id"],
|
||||
d["location"],
|
||||
)
|
||||
for d in docs
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def do_cancel(cls, id):
|
||||
"""Check if a task should be cancelled based on its document status.
|
||||
|
||||
This method determines whether a task should be cancelled by checking the
|
||||
associated document's run status and progress. A task should be cancelled
|
||||
if its document is marked for cancellation or has negative progress.
|
||||
|
||||
Args:
|
||||
id (str): The unique identifier of the task to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the task should be cancelled, False otherwise.
|
||||
"""
|
||||
task = cls.model.get_by_id(id)
|
||||
_, doc = DocumentService.get_by_id(task.doc_id)
|
||||
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress(cls, id, info):
|
||||
"""Update the progress information for a task.
|
||||
|
||||
This method updates both the progress message and completion percentage of a task.
|
||||
It handles platform-specific behavior (macOS vs others) and uses database locking
|
||||
when necessary to ensure thread safety.
|
||||
|
||||
Update Rules:
|
||||
- progress_msg: Always appends the new message to the existing one, and trims the result to max 3000 lines.
|
||||
- progress: Only updates if the current progress is not -1 AND
|
||||
(the new progress is -1 OR greater than the existing progress),
|
||||
to avoid overwriting valid progress with invalid or regressive values.
|
||||
|
||||
Args:
|
||||
id (str): The unique identifier of the task to update.
|
||||
info (dict): Dictionary containing progress information with keys:
|
||||
- progress_msg (str, optional): Progress message to append
|
||||
- progress (float, optional): Progress percentage (0.0 to 1.0)
|
||||
"""
|
||||
task = cls.model.get_by_id(id)
|
||||
if not task:
|
||||
logging.warning("Update_progress error: task not found")
|
||||
return
|
||||
|
||||
if os.environ.get("MACOS"):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
else:
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
|
||||
process_duration = (datetime.now() - task.begin_at).total_seconds()
|
||||
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_doc_ids(cls, doc_ids):
|
||||
"""Delete task associated with a document."""
|
||||
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
"""Create and queue document processing tasks.
|
||||
|
||||
This function creates processing tasks for a document based on its type and configuration.
|
||||
It handles different document types (PDF, Excel, etc.) differently and manages task
|
||||
chunking and configuration. It also implements task reuse optimization by checking
|
||||
for previously completed tasks.
|
||||
|
||||
Args:
|
||||
doc (dict): Document dictionary containing metadata and configuration.
|
||||
bucket (str): Storage bucket name where the document is stored.
|
||||
name (str): File name of the document.
|
||||
priority (int, optional): Priority level for task queueing (default is 0).
|
||||
|
||||
Note:
|
||||
- For PDF documents, tasks are created per page range based on configuration
|
||||
- For Excel documents, tasks are created per row range
|
||||
- Task digests are calculated for optimization and reuse
|
||||
- Previous task chunks may be reused if available
|
||||
"""
|
||||
def new_task():
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"progress": 0.0,
|
||||
"from_page": 0,
|
||||
"to_page": 100000000,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
parse_task_array = []
|
||||
|
||||
if doc["type"] == FileType.PDF.value:
|
||||
file_bin = STORAGE_IMPL.get(bucket, name)
|
||||
do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
|
||||
pages = PdfParser.total_page_number(doc["name"], file_bin)
|
||||
if pages is None:
|
||||
pages = 0
|
||||
page_size = doc["parser_config"].get("task_page_size") or 12
|
||||
if doc["parser_id"] == "paper":
|
||||
page_size = doc["parser_config"].get("task_page_size") or 22
|
||||
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc", True):
|
||||
page_size = 10 ** 9
|
||||
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
|
||||
for s, e in page_ranges:
|
||||
s -= 1
|
||||
s = max(0, s)
|
||||
e = min(e - 1, pages)
|
||||
for p in range(s, e, page_size):
|
||||
task = new_task()
|
||||
task["from_page"] = p
|
||||
task["to_page"] = min(p + page_size, e)
|
||||
parse_task_array.append(task)
|
||||
|
||||
elif doc["parser_id"] == "table":
|
||||
file_bin = STORAGE_IMPL.get(bucket, name)
|
||||
rn = RAGFlowExcelParser.row_number(doc["name"], file_bin)
|
||||
for i in range(0, rn, 3000):
|
||||
task = new_task()
|
||||
task["from_page"] = i
|
||||
task["to_page"] = min(i + 3000, rn)
|
||||
parse_task_array.append(task)
|
||||
else:
|
||||
parse_task_array.append(new_task())
|
||||
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
for task in parse_task_array:
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
if field == "parser_config":
|
||||
for k in ["raptor", "graphrag"]:
|
||||
if k in chunking_config[field]:
|
||||
del chunking_config[field][k]
|
||||
hasher.update(str(chunking_config[field]).encode("utf-8"))
|
||||
for field in ["doc_id", "from_page", "to_page"]:
|
||||
hasher.update(str(task.get(field, "")).encode("utf-8"))
|
||||
task_digest = hasher.hexdigest()
|
||||
task["digest"] = task_digest
|
||||
task["progress"] = 0.0
|
||||
task["priority"] = priority
|
||||
|
||||
prev_tasks = TaskService.get_tasks(doc["id"])
|
||||
ck_num = 0
|
||||
if prev_tasks:
|
||||
for task in parse_task_array:
|
||||
ck_num += reuse_prev_task_chunks(task, prev_tasks, chunking_config)
|
||||
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
||||
pre_chunk_ids = []
|
||||
for pre_task in prev_tasks:
|
||||
if pre_task["chunk_ids"]:
|
||||
pre_chunk_ids.extend(pre_task["chunk_ids"].split())
|
||||
if pre_chunk_ids:
|
||||
settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
|
||||
chunking_config["kb_id"])
|
||||
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
|
||||
|
||||
bulk_insert_into_db(Task, parse_task_array, True)
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
|
||||
unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
|
||||
for unfinished_task in unfinished_task_array:
|
||||
assert REDIS_CONN.queue_product(
|
||||
get_svr_queue_name(priority), message=unfinished_task
|
||||
), "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
|
||||
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
|
||||
"""Attempt to reuse chunks from previous tasks for optimization.
|
||||
|
||||
This function checks if chunks from previously completed tasks can be reused for
|
||||
the current task, which can significantly improve processing efficiency. It matches
|
||||
tasks based on page ranges and configuration digests.
|
||||
|
||||
Args:
|
||||
task (dict): Current task dictionary to potentially reuse chunks for.
|
||||
prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
|
||||
chunking_config (dict): Configuration dictionary for chunk processing.
|
||||
|
||||
Returns:
|
||||
int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
|
||||
|
||||
Note:
|
||||
Chunks can only be reused if:
|
||||
- A previous task exists with matching page range and configuration digest
|
||||
- The previous task was completed successfully (progress = 1.0)
|
||||
- The previous task has valid chunk IDs
|
||||
"""
|
||||
idx = 0
|
||||
while idx < len(prev_tasks):
|
||||
prev_task = prev_tasks[idx]
|
||||
if prev_task.get("from_page", 0) == task.get("from_page", 0) \
|
||||
and prev_task.get("digest", 0) == task.get("digest", ""):
|
||||
break
|
||||
idx += 1
|
||||
|
||||
if idx >= len(prev_tasks):
|
||||
return 0
|
||||
prev_task = prev_tasks[idx]
|
||||
if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
|
||||
return 0
|
||||
task["chunk_ids"] = prev_task["chunk_ids"]
|
||||
task["progress"] = 1.0
|
||||
if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
|
||||
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
|
||||
else:
|
||||
task["progress_msg"] = ""
|
||||
task["progress_msg"] = " ".join(
|
||||
[datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
|
||||
prev_task["chunk_ids"] = ""
|
||||
|
||||
return len(task["chunk_ids"].split())
|
||||
|
||||
|
||||
def cancel_all_task_of(doc_id):
|
||||
for t in TaskService.query(doc_id=doc_id):
|
||||
try:
|
||||
REDIS_CONN.set(f"{t.id}-cancel", "x")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
|
||||
def has_canceled(task_id):
|
||||
try:
|
||||
if REDIS_CONN.get(f"{task_id}-cancel"):
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return False
|
||||
|
||||
|
||||
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
|
||||
|
||||
task = dict(
|
||||
id=task_id,
|
||||
doc_id=doc_id,
|
||||
from_page=0,
|
||||
to_page=100000000,
|
||||
task_type="dataflow" if not rerun else "dataflow_rerun",
|
||||
priority=priority,
|
||||
begin_at=datetime.now(),
|
||||
)
|
||||
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
|
||||
TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
|
||||
DocumentService.begin2parse(doc_id)
|
||||
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
|
||||
|
||||
task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
|
||||
task["tenant_id"] = tenant_id
|
||||
task["dataflow_id"] = flow_id
|
||||
task["file"] = file
|
||||
|
||||
if not REDIS_CONN.queue_product(
|
||||
get_svr_queue_name(priority), message=task
|
||||
):
|
||||
return False, "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
return True, ""
|
||||
257
api/db/services/tenant_llm_service.py
Normal file
257
api/db/services/tenant_llm_service.py
Normal file
@@ -0,0 +1,257 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from langfuse import Langfuse
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.user_service import TenantService
|
||||
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
||||
|
||||
|
||||
class LLMFactoriesService(CommonService):
|
||||
model = LLMFactories
|
||||
|
||||
|
||||
class TenantLLMService(CommonService):
|
||||
model = TenantLLM
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_name):
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||
if not fid:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
||||
else:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
|
||||
if (not objs) and fid:
|
||||
if fid == "LocalAI":
|
||||
mdlnm += "___LocalAI"
|
||||
elif fid == "HuggingFace":
|
||||
mdlnm += "___HuggingFace"
|
||||
elif fid == "OpenAI-API-Compatible":
|
||||
mdlnm += "___OpenAI-API"
|
||||
elif fid == "VLLM":
|
||||
mdlnm += "___VLLM"
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
if not objs:
|
||||
return
|
||||
return objs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
arr = model_name.split("@")
|
||||
if len(arr) < 2:
|
||||
return model_name, None
|
||||
if len(arr) > 2:
|
||||
return "@".join(arr[0:-1]), arr[-1]
|
||||
|
||||
# model name must be xxx@yyy
|
||||
try:
|
||||
model_factories = settings.FACTORY_LLM_INFOS
|
||||
model_providers = set([f["name"] for f in model_factories])
|
||||
if arr[-1] not in model_providers:
|
||||
return model_name, None
|
||||
return arr[0], arr[-1]
|
||||
except Exception as e:
|
||||
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
||||
return model_name, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
||||
from api.db.services.llm_service import LLMService
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
mdlnm = tenant.embd_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||
mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.TTS:
|
||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
if not model_config: # for some cases seems fid mismatch
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
if not llm and fid: # for some cases seems fid mismatch
|
||||
llm = LLMService.query(llm_name=mdlnm)
|
||||
if llm:
|
||||
model_config["is_tools"] = llm[0].is_tools
|
||||
if not model_config:
|
||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
||||
if not model_config:
|
||||
if mdlnm == "flag-embedding":
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
||||
else:
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||
return model_config
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
kwargs.update({"provider": model_config["llm_factory"]})
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
if llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
model_config["api_key"],
|
||||
model_config["llm_name"],
|
||||
base_url=model_config["api_base"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
logging.error(f"Tenant not found: {tenant_id}")
|
||||
return 0
|
||||
|
||||
llm_map = {
|
||||
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
|
||||
LLMType.SPEECH2TEXT.value: tenant.asr_id,
|
||||
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
||||
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
||||
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
||||
}
|
||||
|
||||
mdlnm = llm_map.get(llm_type)
|
||||
if mdlnm is None:
|
||||
logging.error(f"LLM type error: {llm_type}")
|
||||
return 0
|
||||
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
|
||||
try:
|
||||
num = (
|
||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.execute()
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
||||
return 0
|
||||
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
from api.db.services.llm_service import LLMService
|
||||
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
||||
llm_factories = settings.FACTORY_LLM_INFOS
|
||||
for llm_factory in llm_factories:
|
||||
for llm in llm_factory["llm"]:
|
||||
if llm_id == llm["llm_name"]:
|
||||
return llm["model_type"].split(",")[-1]
|
||||
|
||||
for llm in LLMService.query(llm_name=llm_id):
|
||||
return llm.model_type
|
||||
|
||||
llm = TenantLLMService.get_or_none(llm_name=llm_id)
|
||||
if llm:
|
||||
return llm.model_type
|
||||
for llm in TenantLLMService.query(llm_name=llm_id):
|
||||
return llm.model_type
|
||||
|
||||
|
||||
class LLM4Tenant:
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
self.max_length = model_config.get("max_tokens", 8192)
|
||||
|
||||
self.is_tools = model_config.get("is_tools", False)
|
||||
self.verbose_tool_use = kwargs.get("verbose_tool_use")
|
||||
|
||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
||||
self.langfuse = None
|
||||
if langfuse_keys:
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||
if langfuse.auth_check():
|
||||
self.langfuse = langfuse
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
63
api/db/services/user_canvas_version.py
Normal file
63
api/db/services/user_canvas_version.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from api.db.db_models import UserCanvasVersion, DB
|
||||
from api.db.services.common_service import CommonService
|
||||
from peewee import DoesNotExist
|
||||
|
||||
class UserCanvasVersionService(CommonService):
|
||||
model = UserCanvasVersion
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def list_by_canvas_id(cls, user_canvas_id):
|
||||
try:
|
||||
user_canvas_version = cls.model.select(
|
||||
*[cls.model.id,
|
||||
cls.model.create_time,
|
||||
cls.model.title,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date,
|
||||
cls.model.user_canvas_id,
|
||||
cls.model.update_time]
|
||||
).where(cls.model.user_canvas_id == user_canvas_id)
|
||||
return user_canvas_version
|
||||
except DoesNotExist:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
|
||||
fields = [cls.model.id]
|
||||
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
|
||||
versions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
version_batch = versions.offset(offset).limit(limit)
|
||||
_temp = list(version_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_all_versions(cls, user_canvas_id):
|
||||
try:
|
||||
user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(cls.model.create_time.desc())
|
||||
if user_canvas_version.count() > 20:
|
||||
delete_ids = []
|
||||
for i in range(20, user_canvas_version.count()):
|
||||
delete_ids.append(user_canvas_version[i].id)
|
||||
|
||||
cls.delete_by_ids(delete_ids)
|
||||
return True
|
||||
except DoesNotExist:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
318
api/db/services/user_service.py
Normal file
318
api/db/services/user_service.py
Normal file
@@ -0,0 +1,318 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import User, Tenant
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import get_uuid, current_timestamp, datetime_format
|
||||
from api.db import StatusEnum
|
||||
from rag.settings import MINIO
|
||||
|
||||
|
||||
class UserService(CommonService):
|
||||
"""Service class for managing user-related database operations.
|
||||
|
||||
This class extends CommonService to provide specialized functionality for user management,
|
||||
including authentication, user creation, updates, and deletions.
|
||||
|
||||
Attributes:
|
||||
model: The User model class for database operations.
|
||||
"""
|
||||
model = User
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
||||
if 'access_token' in kwargs:
|
||||
access_token = kwargs['access_token']
|
||||
|
||||
# Reject empty, None, or whitespace-only access tokens
|
||||
if not access_token or not str(access_token).strip():
|
||||
logging.warning("UserService.query: Rejecting empty access_token query")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
|
||||
|
||||
# Reject tokens that are too short (should be UUID, 32+ chars)
|
||||
if len(str(access_token).strip()) < 32:
|
||||
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
|
||||
|
||||
# Reject tokens that start with "INVALID_" (from logout)
|
||||
if str(access_token).startswith("INVALID_"):
|
||||
logging.warning("UserService.query: Rejecting invalidated access_token")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
|
||||
|
||||
# Call parent query method for valid requests
|
||||
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_by_id(cls, user_id):
|
||||
"""Retrieve a user by their ID.
|
||||
|
||||
Args:
|
||||
user_id: The unique identifier of the user.
|
||||
|
||||
Returns:
|
||||
User object if found, None otherwise.
|
||||
"""
|
||||
try:
|
||||
user = cls.model.select().where(cls.model.id == user_id).get()
|
||||
return user
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_user(cls, email, password):
|
||||
"""Authenticate a user with email and password.
|
||||
|
||||
Args:
|
||||
email: User's email address.
|
||||
password: User's password in plain text.
|
||||
|
||||
Returns:
|
||||
User object if authentication successful, None otherwise.
|
||||
"""
|
||||
user = cls.model.select().where((cls.model.email == email),
|
||||
(cls.model.status == StatusEnum.VALID.value)).first()
|
||||
if user and check_password_hash(str(user.password), password):
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_user_by_email(cls, email):
|
||||
users = cls.model.select().where((cls.model.email == email))
|
||||
return list(users)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = get_uuid()
|
||||
if "password" in kwargs:
|
||||
kwargs["password"] = generate_password_hash(
|
||||
str(kwargs["password"]))
|
||||
|
||||
kwargs["create_time"] = current_timestamp()
|
||||
kwargs["create_date"] = datetime_format(datetime.now())
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
kwargs["update_date"] = datetime_format(datetime.now())
|
||||
obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_user(cls, user_ids, update_user_dict):
|
||||
with DB.atomic():
|
||||
cls.model.update({"status": 0}).where(
|
||||
cls.model.id.in_(user_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user(cls, user_id, user_dict):
|
||||
with DB.atomic():
|
||||
if user_dict:
|
||||
user_dict["update_time"] = current_timestamp()
|
||||
user_dict["update_date"] = datetime_format(datetime.now())
|
||||
cls.model.update(user_dict).where(
|
||||
cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user_password(cls, user_id, new_password):
|
||||
with DB.atomic():
|
||||
update_dict = {
|
||||
"password": generate_password_hash(str(new_password)),
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": datetime_format(datetime.now())
|
||||
}
|
||||
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_admin(cls, user_id):
|
||||
return cls.model.select().where(
|
||||
cls.model.id == user_id,
|
||||
cls.model.is_superuser == 1).count() > 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_users(cls):
|
||||
users = cls.model.select()
|
||||
return list(users)
|
||||
|
||||
|
||||
class TenantService(CommonService):
|
||||
"""Service class for managing tenant-related database operations.
|
||||
|
||||
This class extends CommonService to provide functionality for tenant management,
|
||||
including tenant information retrieval and credit management.
|
||||
|
||||
Attributes:
|
||||
model: The Tenant model class for database operations.
|
||||
"""
|
||||
model = Tenant
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_info_by(cls, user_id):
|
||||
fields = [
|
||||
cls.model.id.alias("tenant_id"),
|
||||
cls.model.name,
|
||||
cls.model.llm_id,
|
||||
cls.model.embd_id,
|
||||
cls.model.rerank_id,
|
||||
cls.model.asr_id,
|
||||
cls.model.img2txt_id,
|
||||
cls.model.tts_id,
|
||||
cls.model.parser_ids,
|
||||
UserTenant.role]
|
||||
return list(cls.model.select(*fields)
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.OWNER)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_joined_tenants_by_user_id(cls, user_id):
|
||||
fields = [
|
||||
cls.model.id.alias("tenant_id"),
|
||||
cls.model.name,
|
||||
cls.model.llm_id,
|
||||
cls.model.embd_id,
|
||||
cls.model.asr_id,
|
||||
cls.model.img2txt_id,
|
||||
UserTenant.role]
|
||||
return list(cls.model.select(*fields)
|
||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrease(cls, user_id, num):
|
||||
num = cls.model.update(credit=cls.model.credit - num).where(
|
||||
cls.model.id == user_id).execute()
|
||||
if num == 0:
|
||||
raise LookupError("Tenant not found which is supposed to be there")
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def user_gateway(cls, tenant_id):
|
||||
hashobj = hashlib.sha256(tenant_id.encode("utf-8"))
|
||||
return int(hashobj.hexdigest(), 16)%len(MINIO)
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
|
||||
"""Service class for managing user-tenant relationship operations.
|
||||
|
||||
This class extends CommonService to handle the many-to-many relationship
|
||||
between users and tenants, managing user roles and tenant memberships.
|
||||
|
||||
Attributes:
|
||||
model: The UserTenant model class for database operations.
|
||||
"""
|
||||
model = UserTenant
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_by_id(cls, user_tenant_id):
|
||||
try:
|
||||
user_tenant = cls.model.select().where((cls.model.id == user_tenant_id) & (cls.model.status == StatusEnum.VALID.value)).get()
|
||||
return user_tenant
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = get_uuid()
|
||||
obj = cls.model(**kwargs).save(force_insert=True)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, tenant_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.user_id,
|
||||
cls.model.status,
|
||||
cls.model.role,
|
||||
User.nickname,
|
||||
User.email,
|
||||
User.avatar,
|
||||
User.is_authenticated,
|
||||
User.is_active,
|
||||
User.is_anonymous,
|
||||
User.status,
|
||||
User.update_date,
|
||||
User.is_superuser]
|
||||
return list(cls.model.select(*fields)
|
||||
.join(User, on=((cls.model.user_id == User.id) & (cls.model.status == StatusEnum.VALID.value) & (cls.model.role != UserTenantRole.OWNER)))
|
||||
.where(cls.model.tenant_id == tenant_id)
|
||||
.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenants_by_user_id(cls, user_id):
|
||||
fields = [
|
||||
cls.model.tenant_id,
|
||||
cls.model.role,
|
||||
User.nickname,
|
||||
User.email,
|
||||
User.avatar,
|
||||
User.update_date
|
||||
]
|
||||
return list(cls.model.select(*fields)
|
||||
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_user_tenant_relation_by_user_id(cls, user_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.user_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.role
|
||||
]
|
||||
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_num_members(cls, user_id: str):
|
||||
cnt_members = cls.model.select(peewee.fn.COUNT(cls.model.id)).where(cls.model.tenant_id == user_id).scalar()
|
||||
return cnt_members
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def filter_by_tenant_and_user_id(cls, tenant_id, user_id):
|
||||
try:
|
||||
user_tenant = cls.model.select().where(
|
||||
(cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) &
|
||||
(cls.model.user_id == user_id)
|
||||
).first()
|
||||
return user_tenant
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
Reference in New Issue
Block a user