Change flask to fastapi

commit 88db2539b0
2025-10-13 13:18:03 +08:00
476 changed files with 739741 additions and 0 deletions


@@ -0,0 +1,99 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from pathlib import PurePath
from .user_service import UserService as UserService
def _split_name_counter(filename: str) -> tuple[str, int | None]:
"""
Splits a filename into main part and counter (if present in parentheses).
Args:
filename: Input filename string to be parsed
Returns:
A tuple containing:
- The main filename part (string)
- The counter from parentheses (integer) or None if no counter exists
"""
pattern = re.compile(r"^(.*?)\((\d+)\)$")
match = pattern.search(filename)
if match:
main_part = match.group(1).rstrip()
bracket_part = match.group(2)
return main_part, int(bracket_part)
return filename, None
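# Illustrative examples (not part of the original file):
#   _split_name_counter("report(3)")  -> ("report", 3)
#   _split_name_counter("report")     -> ("report", None)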
def duplicate_name(query_func, **kwargs) -> str:
"""
Generates a unique filename by appending/incrementing a counter when duplicates exist.
Continuously checks for name availability using the provided query function,
automatically appending (1), (2), etc. until finding an available name or
reaching maximum retries.
Args:
query_func: Callable that accepts keyword arguments and returns:
- True if name exists (should be modified)
- False if name is available
**kwargs: Must contain 'name' key with original filename to check
Returns:
str: Available filename, either:
- Original name (if available)
- Modified name with counter (e.g., "file(1).txt")
Raises:
KeyError: If 'name' key not provided in kwargs
RuntimeError: If unable to generate unique name after maximum retries
Example:
>>> def name_exists(name): return name in existing_files
>>> duplicate_name(name_exists, name="document.pdf")
'document(1).pdf' # If original exists
"""
MAX_RETRIES = 1000
if "name" not in kwargs:
raise KeyError("Arguments must contain 'name' key")
original_name = kwargs["name"]
current_name = original_name
retries = 0
while retries < MAX_RETRIES:
if not query_func(**kwargs):
return current_name
path = PurePath(current_name)
stem = path.stem
suffix = path.suffix
main_part, counter = _split_name_counter(stem)
counter = counter + 1 if counter else 1
new_name = f"{main_part}({counter}){suffix}"
kwargs["name"] = new_name
current_name = new_name
retries += 1
raise RuntimeError(f"Failed to generate unique name within {MAX_RETRIES} attempts. Original: {original_name}")


@@ -0,0 +1,112 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from api.db.db_models import DB, API4Conversation, APIToken, Dialog
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format
class APITokenService(CommonService):
model = APIToken
@classmethod
@DB.connection_context()
def used(cls, token):
return cls.model.update({
"update_time": current_timestamp(),
"update_date": datetime_format(datetime.now()),
}).where(
cls.model.token == token
)
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
class API4ConversationService(CommonService):
model = API4Conversation
@classmethod
@DB.connection_context()
def get_list(cls, dialog_id, tenant_id,
page_number, items_per_page,
orderby, desc, id, user_id=None, include_dsl=True, keywords="",
from_date=None, to_date=None
):
if include_dsl:
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
else:
fields = [field for field in cls.model._meta.fields.values() if field.name != 'dsl']
sessions = cls.model.select(*fields).where(cls.model.dialog_id == dialog_id)
if id:
sessions = sessions.where(cls.model.id == id)
if user_id:
sessions = sessions.where(cls.model.user_id == user_id)
if keywords:
sessions = sessions.where(peewee.fn.LOWER(cls.model.message).contains(keywords.lower()))
if from_date:
sessions = sessions.where(cls.model.create_date >= from_date)
if to_date:
sessions = sessions.where(cls.model.create_date <= to_date)
if desc:
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
count = sessions.count()
sessions = sessions.paginate(page_number, items_per_page)
return count, list(sessions.dicts())
@classmethod
@DB.connection_context()
def append_message(cls, id, conversation):
cls.update_by_id(id, conversation)
return cls.model.update(round=cls.model.round + 1).where(cls.model.id == id).execute()
@classmethod
@DB.connection_context()
def stats(cls, tenant_id, from_date, to_date, source=None):
if len(to_date) == 10:
to_date += " 23:59:59"
return cls.model.select(
cls.model.create_date.truncate("day").alias("dt"),
peewee.fn.COUNT(
cls.model.id).alias("pv"),
peewee.fn.COUNT(
cls.model.user_id.distinct()).alias("uv"),
peewee.fn.SUM(
cls.model.tokens).alias("tokens"),
peewee.fn.SUM(
cls.model.duration).alias("duration"),
peewee.fn.AVG(
cls.model.round).alias("round"),
peewee.fn.SUM(
cls.model.thumb_up).alias("thumb_up")
).join(Dialog, on=((cls.model.dialog_id == Dialog.id) & (Dialog.tenant_id == tenant_id))).where(
cls.model.create_date >= from_date,
cls.model.create_date <= to_date,
cls.model.source == source
).group_by(cls.model.create_date.truncate("day")).dicts()
@classmethod
@DB.connection_context()
def delete_by_dialog_ids(cls, dialog_ids):
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
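A hedged call sketch (not part of the commit) of how the session listing above is typically consumed; the dialog_id, tenant_id and keyword values are placeholders and assume an initialized RAGFlow database.

total, sessions = API4ConversationService.get_list(
    dialog_id="dlg-placeholder", tenant_id="tenant-placeholder",
    page_number=1, items_per_page=20,
    orderby="create_time", desc=True,
    id=None, user_id=None,
    include_dsl=False,              # skip the heavy dsl column
    keywords="invoice",             # case-insensitive match against message text
    from_date="2024-01-01", to_date="2024-12-31",
)
print(total, len(sessions))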


@@ -0,0 +1,350 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import time
from uuid import uuid4
from agent.canvas import Canvas
from api.db import CanvasCategory, TenantPermission
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.utils import get_uuid
from api.utils.api_utils import get_data_openai
import tiktoken
from peewee import fn
class CanvasTemplateService(CommonService):
model = CanvasTemplate
class DataFlowTemplateService(CommonService):
"""
Alias of CanvasTemplateService
"""
model = CanvasTemplate
class UserCanvasService(CommonService):
model = UserCanvas
@classmethod
@DB.connection_context()
def get_list(cls, tenant_id,
page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
agents = cls.model.select()
if id:
agents = agents.where(cls.model.id == id)
if title:
agents = agents.where(cls.model.title == title)
agents = agents.where(cls.model.user_id == tenant_id)
agents = agents.where(cls.model.canvas_category == canvas_category)
if desc:
agents = agents.order_by(cls.model.getter_by(orderby).desc())
else:
agents = agents.order_by(cls.model.getter_by(orderby).asc())
agents = agents.paginate(page_number, items_per_page)
return list(agents.dicts())
@classmethod
@DB.connection_context()
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted agents, be cautious
fields = [
cls.model.id,
cls.model.title,
cls.model.permission,
cls.model.canvas_type,
cls.model.canvas_category
]
# find team agents and owned agents
agents = cls.model.select(*fields).where(
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
cls.model.user_id == user_id
)
)
# sort by create_time, asc
agents = agents.order_by(cls.model.create_time.asc())
# may cause slow queries due to deep pagination; optimize later
offset, limit = 0, 50
res = []
while True:
ag_batch = agents.offset(offset).limit(limit)
_temp = list(ag_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_by_canvas_id(cls, pid):
try:
fields = [
cls.model.id,
cls.model.avatar,
cls.model.title,
cls.model.dsl,
cls.model.description,
cls.model.permission,
cls.model.update_time,
cls.model.user_id,
cls.model.create_time,
cls.model.create_date,
cls.model.update_date,
cls.model.canvas_category,
User.nickname,
User.avatar.alias('tenant_avatar'),
]
agents = cls.model.select(*fields) \
.join(User, on=(cls.model.user_id == User.id)) \
.where(cls.model.id == pid)
# obj = cls.model.query(id=pid)[0]
return True, agents.dicts()[0]
except Exception as e:
logging.exception(e)
return False, None
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page,
orderby, desc, keywords, canvas_category=None
):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.title,
cls.model.dsl,
cls.model.description,
cls.model.permission,
cls.model.user_id.alias("tenant_id"),
User.nickname,
User.avatar.alias('tenant_avatar'),
cls.model.update_time,
cls.model.canvas_category,
]
if keywords:
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
cls.model.user_id.in_(joined_tenant_ids),
fn.LOWER(cls.model.title).contains(keywords.lower())
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
#(fn.LOWER(cls.model.title).contains(keywords.lower()))
)
else:
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
cls.model.user_id.in_(joined_tenant_ids)
#(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
)
if canvas_category:
agents = agents.where(cls.model.canvas_category == canvas_category)
if desc:
agents = agents.order_by(cls.model.getter_by(orderby).desc())
else:
agents = agents.order_by(cls.model.getter_by(orderby).asc())
count = agents.count()
if page_number and items_per_page:
agents = agents.paginate(page_number, items_per_page)
return list(agents.dicts()), count
@classmethod
@DB.connection_context()
def accessible(cls, canvas_id, tenant_id):
from api.db.services.user_service import UserTenantService
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return False
tids = [t.tenant_id for t in UserTenantService.query(user_id=tenant_id)]
if c["user_id"] != canvas_id and c["user_id"] not in tids:
return False
return True
def completion(tenant_id, agent_id, session_id=None, **kwargs):
query = kwargs.get("query", "") or kwargs.get("question", "")
files = kwargs.get("files", [])
inputs = kwargs.get("inputs", {})
user_id = kwargs.get("user_id", "")
if session_id:
e, conv = API4ConversationService.get_by_id(session_id)
assert e, "Session not found!"
if not conv.message:
conv.message = []
if not isinstance(conv.dsl, str):
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
canvas = Canvas(conv.dsl, tenant_id, agent_id)
else:
e, cvs = UserCanvasService.get_by_id(agent_id)
assert e, "Agent not found."
assert cvs.user_id == tenant_id, "You do not own the agent."
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
session_id=get_uuid()
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
canvas.reset()
conv = {
"id": session_id,
"dialog_id": cvs.id,
"user_id": user_id,
"message": [],
"source": "agent",
"dsl": cvs.dsl,
"reference": []
}
API4ConversationService.save(**conv)
conv = API4Conversation(**conv)
message_id = str(uuid4())
conv.message.append({
"role": "user",
"content": query,
"id": message_id
})
txt = ""
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
ans["session_id"] = session_id
if ans["event"] == "message":
txt += ans["data"]["content"]
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
conv.reference = canvas.get_reference()
conv.errors = canvas.error
conv.dsl = str(canvas)
conv = conv.to_dict()
API4ConversationService.append_message(conv["id"], conv)
def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
tiktokenenc = tiktoken.get_encoding("cl100k_base")
prompt_tokens = len(tiktokenenc.encode(str(question)))
user_id = kwargs.get("user_id", "")
if stream:
completion_tokens = 0
try:
for ans in completion(
tenant_id=tenant_id,
agent_id=agent_id,
session_id=session_id,
query=question,
user_id=user_id,
**kwargs
):
if isinstance(ans, str):
try:
ans = json.loads(ans[5:]) # remove "data:"
except Exception as e:
logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
continue
if ans.get("event") not in ["message", "message_end"]:
continue
content_piece = ""
if ans["event"] == "message":
content_piece = ans["data"]["content"]
completion_tokens += len(tiktokenenc.encode(content_piece))
openai_data = get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
content=content_piece,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
stream=True
)
if ans.get("data", {}).get("reference", None):
openai_data["choices"][0]["delta"]["reference"] = ans["data"]["reference"]
yield "data: " + json.dumps(openai_data, ensure_ascii=False) + "\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logging.exception(e)
yield "data: " + json.dumps(
get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
content=f"**ERROR**: {str(e)}",
finish_reason="stop",
prompt_tokens=prompt_tokens,
completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
stream=True
),
ensure_ascii=False
) + "\n\n"
yield "data: [DONE]\n\n"
else:
try:
all_content = ""
reference = {}
for ans in completion(
tenant_id=tenant_id,
agent_id=agent_id,
session_id=session_id,
query=question,
user_id=user_id,
**kwargs
):
if isinstance(ans, str):
ans = json.loads(ans[5:])
if ans.get("event") not in ["message", "message_end"]:
continue
if ans["event"] == "message":
all_content += ans["data"]["content"]
if ans.get("data", {}).get("reference", None):
reference.update(ans["data"]["reference"])
completion_tokens = len(tiktokenenc.encode(all_content))
openai_data = get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
content=all_content,
finish_reason="stop",
param=None
)
if reference:
openai_data["choices"][0]["message"]["reference"] = reference
yield openai_data
except Exception as e:
logging.exception(e)
yield get_data_openai(
id=session_id or str(uuid4()),
model=agent_id,
prompt_tokens=prompt_tokens,
completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
content=f"**ERROR**: {str(e)}",
finish_reason="stop",
param=None
)


@@ -0,0 +1,345 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import peewee
from peewee import InterfaceError, OperationalError
from api.db.db_models import DB
from api.utils import current_timestamp, datetime_format, get_uuid
def retry_db_operation(func):
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=5),
retry=retry_if_exception_type((InterfaceError, OperationalError)),
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
reraise=True,
)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
class CommonService:
"""Base service class that provides common database operations.
This class serves as a foundation for all service classes in the application,
implementing standard CRUD operations and common database query patterns.
It uses the Peewee ORM for database interactions and provides a consistent
interface for database operations across all derived service classes.
Attributes:
model: The Peewee model class that this service operates on. Must be set by subclasses.
"""
model = None
@classmethod
@DB.connection_context()
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
"""Execute a database query with optional column selection and ordering.
This method provides a flexible way to query the database with various filters
and sorting options. It supports column selection, sort order control, and
additional filter conditions.
Args:
cols (list, optional): List of column names to select. If None, selects all columns.
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
order_by (str, optional): Column name to sort results by.
**kwargs: Additional filter conditions passed as keyword arguments.
Returns:
peewee.ModelSelect: A query result containing matching records.
"""
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
@classmethod
@DB.connection_context()
def get_all(cls, cols=None, reverse=None, order_by=None):
"""Retrieve all records from the database with optional column selection and ordering.
This method fetches all records from the model's table with support for
column selection and result ordering. If reverse is specified but no valid
order_by is given, ordering defaults to create_time.
Args:
cols (list, optional): List of column names to select. If None, selects all columns.
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
order_by (str, optional): Column name to sort results by. Defaults to 'create_time' if reverse is specified.
Returns:
peewee.ModelSelect: A query containing all matching records.
"""
if cols:
query_records = cls.model.select(*cols)
else:
query_records = cls.model.select()
if reverse is not None:
if not order_by or not hasattr(cls, order_by):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
elif reverse is False:
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
return query_records
@classmethod
@DB.connection_context()
def get(cls, **kwargs):
"""Get a single record matching the given criteria.
This method retrieves a single record from the database that matches
the specified filter conditions.
Args:
**kwargs: Filter conditions as keyword arguments.
Returns:
Model instance: Single matching record.
Raises:
peewee.DoesNotExist: If no matching record is found.
"""
return cls.model.get(**kwargs)
@classmethod
@DB.connection_context()
def get_or_none(cls, **kwargs):
"""Get a single record or None if not found.
This method attempts to retrieve a single record matching the given criteria,
returning None if no match is found instead of raising an exception.
Args:
**kwargs: Filter conditions as keyword arguments.
Returns:
Model instance or None: Matching record if found, None otherwise.
"""
try:
return cls.model.get(**kwargs)
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
"""Save a new record to database.
This method creates a new record in the database with the provided field values,
forcing an insert operation rather than an update.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The created record object.
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def insert(cls, **kwargs):
"""Insert a new record with automatic ID and timestamps.
This method creates a new record with automatically generated ID and timestamp fields.
It handles the creation of create_time, create_date, update_time, and update_date fields.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The newly created record object.
"""
if "id" not in kwargs:
kwargs["id"] = get_uuid()
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def insert_many(cls, data_list, batch_size=100):
"""Insert multiple records in batches.
This method efficiently inserts multiple records into the database using batch processing.
It automatically sets creation timestamps for all records.
Args:
data_list (list): List of dictionaries containing record data to insert.
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
"""
with DB.atomic():
for d in data_list:
d["create_time"] = current_timestamp()
d["create_date"] = datetime_format(datetime.now())
for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i : i + batch_size]).execute()
@classmethod
@DB.connection_context()
def update_many_by_id(cls, data_list):
"""Update multiple records by their IDs.
This method updates multiple records in the database, identified by their IDs.
It automatically updates the update_time and update_date fields for each record.
Args:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
with DB.atomic():
for data in data_list:
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod
@DB.connection_context()
@retry_db_operation
def update_by_id(cls, pid, data):
# Update a single record by ID
# Args:
# pid: Record ID
# data: Updated field values
# Returns:
# Number of records updated
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute()
return num
@classmethod
@DB.connection_context()
def get_by_id(cls, pid):
# Get a record by ID
# Args:
# pid: Record ID
# Returns:
# Tuple of (success, record)
try:
obj = cls.model.get_or_none(cls.model.id == pid)
if obj:
return True, obj
except Exception:
pass
return False, None
@classmethod
@DB.connection_context()
def get_by_ids(cls, pids, cols=None):
# Get multiple records by their IDs
# Args:
# pids: List of record IDs
# cols: List of columns to select
# Returns:
# Query of matching records
if cols:
objs = cls.model.select(*cols)
else:
objs = cls.model.select()
return objs.where(cls.model.id.in_(pids))
@classmethod
@DB.connection_context()
def delete_by_id(cls, pid):
# Delete a record by ID
# Args:
# pid: Record ID
# Returns:
# Number of records deleted
return cls.model.delete().where(cls.model.id == pid).execute()
@classmethod
@DB.connection_context()
def delete_by_ids(cls, pids):
# Delete multiple records by their IDs
# Args:
# pids: List of record IDs
# Returns:
# Number of records deleted
with DB.atomic():
res = cls.model.delete().where(cls.model.id.in_(pids)).execute()
return res
@classmethod
@DB.connection_context()
def filter_delete(cls, filters):
# Delete records matching given filters
# Args:
# filters: List of filter conditions
# Returns:
# Number of records deleted
with DB.atomic():
num = cls.model.delete().where(*filters).execute()
return num
@classmethod
@DB.connection_context()
def filter_update(cls, filters, update_data):
# Update records matching given filters
# Args:
# filters: List of filter conditions
# update_data: Updated field values
# Returns:
# Number of records updated
with DB.atomic():
return cls.model.update(update_data).where(*filters).execute()
@staticmethod
def cut_list(tar_list, n):
# Split a list into chunks of size n
# Args:
# tar_list: List to split
# n: Chunk size
# Returns:
# List of tuples containing chunks
length = len(tar_list)
arr = range(length)
result = [tuple(tar_list[x : (x + n)]) for x in arr[::n]]
return result
@classmethod
@DB.connection_context()
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
# Get records matching IN clause filters with optional column selection
# Args:
# in_key: Field name for IN clause
# in_filters_list: List of values for IN clause
# filters: Additional filter conditions
# cols: List of columns to select
# Returns:
# List of matching records
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
if not filters:
filters = []
res_list = []
if cols:
for i in in_filters_tuple_list:
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
else:
for i in in_filters_tuple_list:
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
return res_list
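A short sketch of deriving a new service from CommonService: point model at a Peewee model and the CRUD helpers above become available. The Tag model is an assumption for illustration and is not part of this commit.

from api.db.db_models import Tag   # assumed model, for illustration only

class TagService(CommonService):
    model = Tag

# found, tag = TagService.get_by_id("some-id")                      # (bool, Tag | None)
# TagService.insert(name="demo")                                    # auto id + timestamps
# TagService.filter_update([Tag.name == "demo"], {"name": "demo2"}) # bulk update by filter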


@@ -0,0 +1,242 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from uuid import uuid4
from api.db import StatusEnum
from api.db.db_models import Conversation, DB
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService, chat
from api.utils import get_uuid
import json
from rag.prompts.generator import chunks_format
class ConversationService(CommonService):
model = Conversation
@classmethod
@DB.connection_context()
def get_list(cls, dialog_id, page_number, items_per_page, orderby, desc, id, name, user_id=None):
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
if id:
sessions = sessions.where(cls.model.id == id)
if name:
sessions = sessions.where(cls.model.name == name)
if user_id:
sessions = sessions.where(cls.model.user_id == user_id)
if desc:
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
sessions = sessions.paginate(page_number, items_per_page)
return list(sessions.dicts())
@classmethod
@DB.connection_context()
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
sessions = sessions.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
s_batch = sessions.offset(offset).limit(limit)
_temp = list(s_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
def structure_answer(conv, ans, message_id, session_id):
reference = ans["reference"]
if not isinstance(reference, dict):
reference = {}
ans["reference"] = {}
chunk_list = chunks_format(reference)
reference["chunks"] = chunk_list
ans["id"] = message_id
ans["session_id"] = session_id
if not conv:
return ans
if not conv.message:
conv.message = []
if not conv.message or conv.message[-1].get("role", "") != "assistant":
conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
else:
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
if conv.reference:
conv.reference[-1] = reference
return ans
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
assert name, "`name` can not be empty."
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
assert dia, "You do not own the chat."
if not session_id:
session_id = get_uuid()
conv = {
"id": session_id,
"dialog_id": chat_id,
"name": name,
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue"), "created_at": time.time()}],
"user_id": kwargs.get("user_id", "")
}
ConversationService.save(**conv)
if stream:
yield "data:" + json.dumps({"code": 0, "message": "",
"data": {
"answer": conv["message"][0]["content"],
"reference": {},
"audio_binary": None,
"id": None,
"session_id": session_id
}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
return
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
if not conv:
raise LookupError("Session does not exist")
conv = conv[0]
msg = []
question = {
"content": question,
"role": "user",
"id": str(uuid4())
}
conv.message.append(question)
for m in conv.message:
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
message_id = msg[-1].get("id")
e, dia = DialogService.get_by_id(conv.dialog_id)
kb_ids = kwargs.get("kb_ids",[])
dia.kb_ids = list(set(dia.kb_ids + kb_ids))
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
if stream:
try:
for ans in chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
else:
answer = None
for ans in chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
ConversationService.update_by_id(conv.id, conv.to_dict())
break
yield answer
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
e, dia = DialogService.get_by_id(dialog_id)
assert e, "Dialog not found"
if not session_id:
session_id = get_uuid()
conv = {
"id": session_id,
"dialog_id": dialog_id,
"user_id": kwargs.get("user_id", ""),
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"], "created_at": time.time()}]
}
API4ConversationService.save(**conv)
yield "data:" + json.dumps({"code": 0, "message": "",
"data": {
"answer": conv["message"][0]["content"],
"reference": {},
"audio_binary": None,
"id": None,
"session_id": session_id
}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
return
else:
session_id = session_id
e, conv = API4ConversationService.get_by_id(session_id)
assert e, "Session not found!"
if not conv.message:
conv.message = []
messages = conv.message
question = {
"role": "user",
"content": question,
"id": str(uuid4())
}
messages.append(question)
msg = []
for m in messages:
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
if not conv.reference:
conv.reference = []
conv.reference.append({"chunks": [], "doc_aggs": []})
if stream:
try:
for ans in chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
else:
answer = None
for ans in chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
API4ConversationService.append_message(conv.id, conv.to_dict())
break
yield answer
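A hedged sketch of consuming the chat completion stream above; chat-1 and tenant-1 are placeholders. Without a session_id the generator only emits the dialog prologue and an end sentinel; with an existing session it streams the assistant answer.

import json

for frame in completion(tenant_id="tenant-1", chat_id="chat-1",
                        question="What does the warranty cover?"):
    payload = json.loads(frame[len("data:"):])
    data = payload.get("data")
    if isinstance(data, dict):
        print(data.get("answer", ""))
    elif data is True:               # end-of-stream sentinel
        break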


@@ -0,0 +1,868 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import binascii
import logging
import re
import time
from copy import deepcopy
from datetime import datetime
from functools import partial
from timeit import default_timer as timer
import trio
from langfuse import Langfuse
from peewee import fn
from agentic_reasoning import DeepResearcher
from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily
class DialogService(CommonService):
model = Dialog
@classmethod
def save(cls, **kwargs):
"""Save a new record to database.
This method creates a new record in the database with the provided field values,
forcing an insert operation rather than an update.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The created record object.
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
def update_many_by_id(cls, data_list):
"""Update multiple records by their IDs.
This method updates multiple records in the database, identified by their IDs.
It automatically updates the update_time and update_date fields for each record.
Args:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
with DB.atomic():
for data in data_list:
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod
@DB.connection_context()
def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
chats = cls.model.select()
if id:
chats = chats.where(cls.model.id == id)
if name:
chats = chats.where(cls.model.name == name)
chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
if desc:
chats = chats.order_by(cls.model.getter_by(orderby).desc())
else:
chats = chats.order_by(cls.model.getter_by(orderby).asc())
chats = chats.paginate(page_number, items_per_page)
return list(chats.dicts())
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None):
from api.db.db_models import User
fields = [
cls.model.id,
cls.model.tenant_id,
cls.model.name,
cls.model.description,
cls.model.language,
cls.model.llm_id,
cls.model.llm_setting,
cls.model.prompt_type,
cls.model.prompt_config,
cls.model.similarity_threshold,
cls.model.vector_similarity_weight,
cls.model.top_n,
cls.model.top_k,
cls.model.do_refer,
cls.model.rerank_id,
cls.model.kb_ids,
cls.model.icon,
cls.model.status,
User.nickname,
User.avatar.alias("tenant_avatar"),
cls.model.update_time,
cls.model.create_time,
]
if keywords:
dialogs = (
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(
(cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
(fn.LOWER(cls.model.name).contains(keywords.lower())),
)
)
else:
dialogs = (
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(
(cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
)
)
if parser_id:
dialogs = dialogs.where(cls.model.parser_id == parser_id)
if desc:
dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc())
else:
dialogs = dialogs.order_by(cls.model.getter_by(orderby).asc())
count = dialogs.count()
if page_number and items_per_page:
dialogs = dialogs.paginate(page_number, items_per_page)
return list(dialogs.dicts()), count
@classmethod
@DB.connection_context()
def get_all_dialogs_by_tenant_id(cls, tenant_id):
fields = [cls.model.id]
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
dialogs = dialogs.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
d_batch = dialogs.offset(offset).limit(limit)
_temp = list(d_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
def chat_solo(dialog, messages, stream=True):
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
prompt_config = dialog.prompt_config
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
if stream:
last_ans = ""
delta_ans = ""
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
delta_ans = ""
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
else:
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
def get_models(dialog):
embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
if len(embedding_list) > 1:
raise Exception("**ERROR**: Knowledge bases use different embedding models.")
if embedding_list:
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
if dialog.prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
BAD_CITATION_PATTERNS = [
re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12)
re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12]
re.compile(r"\s*ID\s*[: ]*\s*(\d+)\s*】"), # 【ID: 12】
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
]
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
max_index = len(kbinfos["chunks"])
def safe_add(i):
if 0 <= i < max_index:
idx.add(i)
return True
return False
def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
nonlocal answer
def replacement(match):
try:
i = int(match.group(group_index))
if safe_add(i):
return f"[{repl(i)}]"
except Exception:
pass
return match.group(0)
answer = re.sub(pattern, replacement, answer, flags=flags)
for pattern in BAD_CITATION_PATTERNS:
find_and_replace(pattern)
return answer, idx
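# Illustrative sketch (not in the original file): repair_bad_citation_formats
# rewrites loose citation spellings such as "(ID: 3)" or "ref 3" into the
# canonical "[ID:3]" form, but only when the index points at an existing chunk.
# The values below are hypothetical.
#   kbinfos = {"chunks": [{}, {}, {}, {}]}                # four retrieved chunks
#   answer = "Pricing is tiered (ID: 3), see also ref 99."
#   answer, idx = repair_bad_citation_formats(answer, kbinfos, set())
#   # answer -> "Pricing is tiered [ID:3], see also ref 99."   (99 is out of range)
#   # idx    -> {3}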
def convert_conditions(metadata_condition):
if metadata_condition is None:
metadata_condition = {}
op_mapping = {
"is": "=",
"not is": ""
}
return [
{
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
"key": cond["name"],
"value": cond["value"]
}
for cond in metadata_condition.get("conditions", [])
]
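# Illustrative sketch (not in the original file): convert_conditions normalizes
# an external metadata-condition payload into the filter shape meta_filter expects.
#   convert_conditions({"conditions": [
#       {"name": "department", "comparison_operator": "is", "value": "legal"}]})
#   # -> [{"op": "=", "key": "department", "value": "legal"}]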
def meta_filter(metas: dict, filters: list[dict]):
doc_ids = set([])
def filter_out(v2docs, operator, value):
ids = []
for input, docids in v2docs.items():
try:
input = float(input)
value = float(value)
except Exception:
input = str(input)
value = str(value)
for conds in [
(operator == "contains", str(value).lower() in str(input).lower()),
(operator == "not contains", str(value).lower() not in str(input).lower()),
(operator == "start with", str(input).lower().startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower())),
(operator == "empty", not input),
(operator == "not empty", input),
(operator == "=", input == value),
(operator == "", input != value),
(operator == ">", input > value),
(operator == "<", input < value),
(operator == "", input >= value),
(operator == "", input <= value),
]:
try:
if all(conds):
ids.extend(docids)
break
except Exception:
pass
return ids
for k, v2docs in metas.items():
for f in filters:
if k != f["key"]:
continue
ids = filter_out(v2docs, f["op"], f["value"])
if not doc_ids:
doc_ids = set(ids)
else:
doc_ids = doc_ids & set(ids)
if not doc_ids:
return []
return list(doc_ids)
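# Illustrative sketch (not in the original file): meta_filter intersects document
# ids across all matching filters. The metas mapping ({field -> {value -> [doc ids]}})
# below is hypothetical.
#   metas = {
#       "department": {"legal": ["d1", "d2"], "hr": ["d3"]},
#       "year":       {"2023": ["d1"], "2024": ["d2", "d3"]},
#   }
#   filters = [{"key": "department", "op": "=", "value": "legal"},
#              {"key": "year",       "op": "=", "value": "2024"}]
#   meta_filter(metas, filters)   # -> ["d2"]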
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
for ans in chat_solo(dialog, messages, stream):
yield ans
return
chat_start_ts = timer()
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
max_tokens = llm_model_config.get("max_tokens", 8192)
check_llm_ts = timer()
langfuse_tracer = None
trace_context = {}
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
if langfuse.auth_check():
langfuse_tracer = langfuse
trace_id = langfuse_tracer.create_trace_id()
trace_context = {"trace_id": trace_id}
check_langfuse_tracer_ts = timer()
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
if toolcall_session and tools:
chat_mdl.bind_tools(toolcall_session, tools)
bind_models_ts = timer()
retriever = settings.retrievaler
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
if "doc_ids" in messages[-1]:
attachments = messages[-1]["doc_ids"]
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
return
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
continue
if p["key"] not in kwargs and not p["optional"]:
raise KeyError("Miss parameter: " + p["key"])
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
else:
questions = questions[-1:]
if prompt_config.get("cross_languages"):
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
if dialog.meta_data_filter:
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
if dialog.meta_data_filter.get("method") == "auto":
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
attachments.extend(meta_filter(metas, filters))
if not attachments:
attachments = None
elif dialog.meta_data_filter.get("method") == "manual":
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
if not attachments:
attachments = None
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
refine_question_ts = timer()
thought = ""
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
knowledges = []
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
if prompt_config.get("reasoning", False):
reasoner = DeepResearcher(
chat_mdl,
prompt_config,
partial(
retriever.retrieval,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
kb_ids=dialog.kb_ids,
page=1,
page_size=dialog.top_n,
similarity_threshold=0.2,
vector_similarity_weight=0.3,
doc_ids=attachments,
),
)
for think in reasoner.thinking(kbinfos, " ".join(questions)):
if isinstance(think, str):
thought = think
knowledges = [t for t in think.split("\n") if t]
elif stream:
yield think
else:
if embd_mdl:
kbinfos = retriever.retrieval(
" ".join(questions),
embd_mdl,
tenant_ids,
dialog.kb_ids,
1,
dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs),
)
if prompt_config.get("tavily_api_key"):
tav = Tavily(prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(" ".join(questions))
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
prompt4citation = ""
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
prompt4citation = citation_prompt()
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
prompt = msg[0]["content"]
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
def decorate_answer(answer):
nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
refs = []
ans = answer.split("</think>")
think = ""
if len(ans) == 2:
think = ans[0] + "</think>"
answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
idx = set([])
if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
answer, idx = retriever.insert_citations(
answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],
[ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight,
)
else:
for match in re.finditer(r"\[ID:([0-9]+)\]", answer):
i = int(match.group(1))
if i < len(kbinfos["chunks"]):
idx.add(i)
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs
refs = deepcopy(kbinfos)
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
finish_chat_ts = timer()
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
tk_num = num_tokens_from_string(think + answer)
prompt += "\n\n### Query:\n%s" % " ".join(questions)
prompt = (
f"{prompt}\n\n"
"## Time elapsed:\n"
f" - Total: {total_time_cost:.1f}ms\n"
f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
f" - Bind models: {bind_embedding_time_cost:.1f}ms\n"
f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
"## Token usage:\n"
f" - Generated tokens(approximately): {tk_num}\n"
f" - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
)
# Add a condition check to call the end method only if langfuse_tracer exists
if langfuse_tracer and "langfuse_generation" in locals():
langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}
langfuse_generation.update(output=langfuse_output)
langfuse_generation.end()
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
if langfuse_tracer:
langfuse_generation = langfuse_tracer.start_generation(
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
)
if stream:
last_ans = ""
answer = ""
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):]
if delta_ans:
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
else:
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
yield res
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
tried_times = 0
def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"([;]|```).*", "", sql)
if sql[: len("select ")] != "select ":
return None, None
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
if sql[: len("select *")] != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
else:
flds = []
for k in field_map.keys():
if k in forbidden_select_fields4resume:
continue
if len(flds) > 11:
break
flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower():
sql += f" WHERE {kb_filter}"
else:
sql += f" AND {kb_filter}"
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
return None
if tbl.get("error") and tried_times <= 2:
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
tbl, sql = get_table()
logging.debug("TRY it again: {}".format(sql))
logging.debug("GET table: {}".format(tbl))
if tbl.get("error") or len(tbl["rows"]) == 0:
return None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
# compose Markdown table
columns = (
"|" + "|".join(
[re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
"|Source|" if docid_idx and doc_name_idx else "|")
)
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    # NOTE: `quota` currently has no effect here; citation markers (##i$$) are always appended.
    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql)
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
docid_idx = list(docid_idx)[0]
doc_name_idx = list(doc_name_idx)[0]
doc_aggs = {}
for r in tbl["rows"]:
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
return {
"answer": "\n".join([columns, line, rows]),
"reference": {
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
},
"prompt": sys_prompt,
}
def tts(tts_mdl, text):
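    # Concatenate the streamed audio chunks and return them hex-encoded for JSON-safe transport.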
if not tts_mdl or not text:
return
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
return binascii.hexlify(bin).decode("utf-8")
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
kb_ids = search_config.get("kb_ids", kb_ids)
chat_llm_name = search_config.get("chat_id", chat_llm_name)
rerank_id = search_config.get("rerank_id", "")
meta_data_filter = search_config.get("meta_data_filter")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
if meta_data_filter:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
filters = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
if not doc_ids:
doc_ids = None
kbinfos = retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=search_config.get("similarity_threshold", 0.1),
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
top=search_config.get("top_k", 1024),
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges))
msg = [{"role": "user", "content": question}]
def decorate_answer(answer):
nonlocal knowledges, kbinfos, sys_prompt
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl, tkweight=0.7, vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs
refs = deepcopy(kbinfos)
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
refs["chunks"] = chunks_format(refs)
return {"answer": answer, "reference": refs}
answer = ""
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
meta_data_filter = search_config.get("meta_data_filter", {})
doc_ids = search_config.get("doc_ids", [])
rerank_id = search_config.get("rerank_id", "")
rerank_mdl = None
kbs = KnowledgebaseService.get_by_ids(kb_ids)
if not kbs:
return {"error": "No KB selected"}
embedding_list = list(set([kb.embd_id for kb in kbs]))
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
if meta_data_filter:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
filters = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
if not doc_ids:
doc_ids = None
ranks = settings.retrievaler.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=search_config.get("similarity_threshold", 0.2),
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
top=search_config.get("top_k", 1024),
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs),
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
return mind_map.output

View File

@@ -0,0 +1,975 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import random
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO
import trio
import xxhash
from peewee import fn, Case, JOIN
from api import settings
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
User
from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import current_timestamp, get_format_time, get_uuid
from rag.nlp import rag_tokenizer, search
from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
class DocumentService(CommonService):
model = Document
@classmethod
def get_cls_model_fields(cls):
return [
cls.model.id,
cls.model.thumbnail,
cls.model.kb_id,
cls.model.parser_id,
cls.model.pipeline_id,
cls.model.parser_config,
cls.model.source_type,
cls.model.type,
cls.model.created_by,
cls.model.name,
cls.model.location,
cls.model.size,
cls.model.token_num,
cls.model.chunk_num,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.meta_fields,
cls.model.suffix,
cls.model.run,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id, name):
fields = cls.get_cls_model_fields()
docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
.join(File, on = (File.id == File2Document.file_id))\
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
.where(cls.model.kb_id == kb_id)
if id:
docs = docs.where(
cls.model.id == id)
if name:
docs = docs.where(
cls.model.name == name
)
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
)
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
count = docs.count()
docs = docs.paginate(page_number, items_per_page)
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
def check_doc_health(cls, tenant_id: str, filename):
import os
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!")
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
raise RuntimeError("Exceed the maximum length of file name!")
return True
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, run_status, types, suffix):
fields = cls.get_cls_model_fields()
if keywords:
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
.join(File, on=(File.id == File2Document.file_id))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.join(File, on=(File.id == File2Document.file_id))\
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
.where(cls.model.kb_id == kb_id)
if run_status:
docs = docs.where(cls.model.run.in_(run_status))
if types:
docs = docs.where(cls.model.type.in_(types))
if suffix:
docs = docs.where(cls.model.suffix.in_(suffix))
count = docs.count()
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
docs = docs.paginate(page_number, items_per_page)
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
"""
returns:
{
"suffix": {
"ppt": 1,
"doxc": 2
},
"run_status": {
"1": 2,
"2": 2
}
}, total
where "1" => RUNNING, "2" => CANCEL
"""
fields = cls.get_cls_model_fields()
if keywords:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
query = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
if run_status:
query = query.where(cls.model.run.in_(run_status))
if types:
query = query.where(cls.model.type.in_(types))
if suffix:
query = query.where(cls.model.suffix.in_(suffix))
rows = query.select(cls.model.run, cls.model.suffix)
total = rows.count()
suffix_counter = {}
run_status_counter = {}
for row in rows:
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
return {
"suffix": suffix_counter,
"run_status": run_status_counter
}, total
@classmethod
@DB.connection_context()
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
if keywords:
docs = cls.model.select().where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
if run_status:
docs = docs.where(cls.model.run.in_(run_status))
if types:
docs = docs.where(cls.model.type.in_(types))
count = docs.count()
return count
@classmethod
@DB.connection_context()
def get_total_size_by_kb_id(cls, kb_id, keywords="", run_status=[], types=[]):
query = cls.model.select(fn.COALESCE(fn.SUM(cls.model.size), 0)).where(
cls.model.kb_id == kb_id
)
if keywords:
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if run_status:
query = query.where(cls.model.run.in_(run_status))
if types:
query = query.where(cls.model.type.in_(types))
return int(query.scalar()) or 0
@classmethod
@DB.connection_context()
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
fields = [cls.model.id]
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
        docs = docs.order_by(cls.model.create_time.asc())
        # deep pagination may cause slow queries; optimize later
offset, limit = 0, 100
res = []
while True:
doc_batch = docs.offset(offset).limit(limit)
_temp = list(doc_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_all_docs_by_creator_id(cls, creator_id):
fields = [
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
cls.model.created_by == creator_id
)
        docs = docs.order_by(cls.model.create_time.asc())
        # deep pagination may cause slow queries; optimize later
offset, limit = 0, 100
res = []
while True:
doc_batch = docs.offset(offset).limit(limit)
_temp = list(doc_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def insert(cls, doc):
if not cls.save(**doc):
raise RuntimeError("Database error (Document)!")
if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]):
raise RuntimeError("Database error (Knowledgebase)!")
return Document(**doc)
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
from api.db.services.task_service import TaskService
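        # Remove everything tied to this document: its tasks, stored chunk images and thumbnail,
        # index entries, and any knowledge-graph records that list it as a source.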
cls.clear_chunk_num(doc.id)
try:
TaskService.filter_delete([Task.doc_id == doc.id])
page = 0
page_size = 1000
all_chunk_ids = []
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
page += 1
for cid in all_chunk_ids:
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
STORAGE_IMPL.rm(doc.kb_id, cid)
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
graph_source = settings.docStoreConn.getFields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
return cls.delete_by_id(doc.id)
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls):
fields = [
cls.model.id,
cls.model.kb_id,
cls.model.parser_id,
cls.model.parser_config,
cls.model.name,
cls.model.type,
cls.model.location,
cls.model.size,
Knowledgebase.tenant_id,
Tenant.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= current_timestamp() - 1000 * 600,
cls.model.run == TaskStatus.RUNNING.value) \
.order_by(cls.model.update_time.asc())
return list(docs.dicts())
@classmethod
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run, cls.model.parser_id]
docs = cls.model.select(*fields) \
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.progress > 0)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
if num == 0:
logging.warning("Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration):
num = cls.model.update(token_num=cls.model.token_num - token_num,
chunk_num=cls.model.chunk_num - chunk_num,
process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute()
if num == 0:
raise LookupError(
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def clear_chunk_num(cls, doc_id):
doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1
).where(
Knowledgebase.id == doc.kb_id).execute()
return num
@classmethod
@DB.connection_context()
def clear_chunk_num_when_rerun(cls, doc_id):
doc = cls.model.get_by_id(doc_id)
        assert doc, "Can't find document in database."
num = (
Knowledgebase.update(
token_num=Knowledgebase.token_num - doc.token_num,
chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
)
.where(Knowledgebase.id == doc.kb_id)
.execute()
)
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def get_knowledgebase_id(cls, doc_id):
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
docs = docs.dicts()
if not docs:
return
return docs[0]["kb_id"]
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(cls.model.id
).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(
UserTenant, on=(
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
).where(
cls.model.id == doc_id,
UserTenant.status == StatusEnum.VALID.value,
((UserTenant.role == UserTenantRole.NORMAL) | (UserTenant.role == UserTenantRole.OWNER))
).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
return
return docs[0]["embd_id"]
@classmethod
@DB.connection_context()
def get_chunking_config(cls, doc_id):
configs = (
cls.model.select(
cls.model.id,
cls.model.kb_id,
cls.model.parser_id,
cls.model.parser_config,
Knowledgebase.language,
Knowledgebase.embd_id,
Tenant.id.alias("tenant_id"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
return None
return configs[0]
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
doc_id = doc_id.dicts()
if not doc_id:
return
return doc_id[0]["id"]
@classmethod
@DB.connection_context()
def get_doc_ids_by_doc_names(cls, doc_names):
if not doc_names:
return []
query = cls.model.select(cls.model.id).where(cls.model.name.in_(doc_names))
return list(query.scalars().iterator())
@classmethod
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
if not config:
return
e, d = cls.get_by_id(id)
if not e:
raise LookupError(f"Document({id}) not found.")
def dfs_update(old, new):
for k, v in new.items():
if k not in old:
old[k] = v
continue
if isinstance(v, dict):
assert isinstance(old[k], dict)
dfs_update(old[k], v)
else:
old[k] = v
dfs_update(d.parser_config, config)
if not config.get("raptor") and d.parser_config.get("raptor"):
del d.parser_config["raptor"]
cls.update_by_id(id, {"parser_config": d.parser_config})
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
docs = cls.model.select(cls.model.id).join(Knowledgebase,
on=(Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.tenant_id == tenant_id)
return len(docs)
@classmethod
@DB.connection_context()
def begin2parse(cls, docid):
cls.update_by_id(
docid, {"progress": random.random() * 1 / 100.,
"progress_msg": "Task is queued...",
"process_begin_at": get_format_time()
})
@classmethod
@DB.connection_context()
def update_meta_fields(cls, doc_id, meta_fields):
return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
@classmethod
@DB.connection_context()
def get_meta_by_kbs(cls, kb_ids):
fields = [
cls.model.id,
cls.model.meta_fields,
]
meta = {}
for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
doc_id = r.id
for k,v in r.meta_fields.items():
if k not in meta:
meta[k] = {}
v = str(v)
if v not in meta[k]:
meta[k][v] = []
meta[k][v].append(doc_id)
return meta
@classmethod
@DB.connection_context()
def update_progress(cls):
docs = cls.get_unfinished_docs()
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def update_progress_immediately(cls, docs:list[dict]):
if not docs:
return
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def _sync_progress(cls, docs:list[dict]):
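        # For each document, aggregate the progress of its tasks: average per-task progress, mark the
        # document failed if any task errored once all tasks finished, and surface a queue-position
        # hint in progress_msg while tasks are still queued.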
for d in docs:
try:
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
if not tsks:
continue
msg = []
prg = 0
finished = True
bad = 0
e, doc = DocumentService.get_by_id(d["id"])
status = doc.run # TaskStatus.RUNNING.value
priority = 0
for t in tsks:
if 0 <= t.progress < 1:
finished = False
if t.progress == -1:
bad += 1
prg += t.progress if t.progress >= 0 else 0
if t.progress_msg.strip():
msg.append(t.progress_msg)
priority = max(priority, t.priority)
prg /= len(tsks)
if finished and bad:
prg = -1
status = TaskStatus.FAIL.value
elif finished:
prg = 1
status = TaskStatus.DONE.value
msg = "\n".join(sorted(msg))
info = {
"process_duration": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
"run": status}
if prg != 0:
info["progress"] = prg
if msg:
info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
cls.update_by_id(d["id"], info)
except Exception as e:
if str(e).find("'0'") < 0:
logging.exception("fetch task exception")
@classmethod
@DB.connection_context()
def get_kb_doc_count(cls, kb_id):
return cls.model.select().where(cls.model.kb_id == kb_id).count()
@classmethod
@DB.connection_context()
def get_all_kb_doc_count(cls):
result = {}
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
for row in rows:
result[row.kb_id] = row.count
return result
@classmethod
@DB.connection_context()
def do_cancel(cls, doc_id):
try:
_, doc = DocumentService.get_by_id(doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except Exception:
pass
return False
@classmethod
@DB.connection_context()
def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
# cancelled: run == "2" but progress can vary
cancelled = (
cls.model.select(fn.COUNT(1))
.where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))
.scalar()
)
row = (
cls.model.select(
# finished: progress == 1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == 1, 1)], 0)), 0).alias("finished"),
# failed: progress == -1
fn.COALESCE(fn.SUM(Case(None, [(cls.model.progress == -1, 1)], 0)), 0).alias("failed"),
# processing: 0 <= progress < 1
fn.COALESCE(
fn.SUM(
Case(
None,
[
(((cls.model.progress == 0) | ((cls.model.progress > 0) & (cls.model.progress < 1))), 1),
],
0,
)
),
0,
).alias("processing"),
)
.where(
(cls.model.kb_id == kb_id)
& ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL))
)
.dicts()
.get()
)
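        # Illustrative result (hypothetical counts): {"processing": 2, "finished": 7, "failed": 1, "cancelled": 0}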
return {
"processing": int(row["processing"]),
"finished": int(row["finished"]),
"failed": int(row["failed"]),
"cancelled": int(cancelled),
}
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
"""
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
Optionally, specify a list of doc_ids to determine which documents participate in the task.
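    Example (illustrative):
        queue_raptor_o_graphrag_tasks(doc, "graphrag", priority=0, doc_ids=["doc_a", "doc_b"])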
"""
chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
hasher.update(str(chunking_config[field]).encode("utf-8"))
def new_task():
nonlocal doc
return {
"id": get_uuid(),
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"task_type": ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"begin_at": datetime.now(),
}
task = new_task()
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True)
if ty in ["graphrag", "raptor", "mindmap"]:
task["doc_ids"] = doc_ids
DocumentService.begin2parse(doc["id"])
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
return task["id"]
def get_queue_length(priority):
group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
if not group_info:
return 0
return int(group_info.get("lag", 0) or 0)
async def doc_upload_and_parse(conversation_id, file_objs, user_id):
from api.db.services.api_service import API4ConversationService
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from rag.app import audio, email, naive, picture, presentation
e, conv = ConversationService.get_by_id(conversation_id)
if not e:
e, conv = API4ConversationService.get_by_id(conversation_id)
assert e, "Conversation not found!"
e, dia = DialogService.get_by_id(conv.dialog_id)
if not dia.kb_ids:
raise LookupError("No knowledge base associated with this conversation. "
"Please add a knowledge base before uploading documents")
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
err, files = await FileService.upload_document(kb, file_objs, user_id)
assert not err, "\n".join(err)
def dummy(prog=None, msg=""):
pass
FACTORY = {
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
doc_nm = {}
for d, blob in files:
doc_nm[d["id"]] = d["name"]
for d, blob in files:
kwargs = {
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": kb.tenant_id,
"lang": kb.language
}
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
for (docinfo, _), th in zip(files, threads):
docs = []
doc = {
"doc_id": docinfo["id"],
"kb_id": [kb.id]
}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
docs.append(d)
continue
output_buffer = BytesIO()
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
d.pop("image", None)
docs.append(d)
parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
docids = [d["id"] for d, _ in files]
chunk_counts = {id: 0 for id in docids}
token_counts = {id: 0 for id in docids}
es_bulk_size = 64
def embedding(doc_id, cnts, batch_size=16):
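        # Encode chunk contents in batches, accumulating per-document chunk and token counts.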
nonlocal embd_mdl, chunk_counts, token_counts
vects = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vects.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
token_counts[doc_id] += c
return vects
idxnm = search.index_name(kb.tenant_id)
try_create_idx = True
_, tenant = TenantService.get_by_id(kb.tenant_id)
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
for doc_id in docids:
cks = [c for c in docs if c["doc_id"] == doc_id]
if parser_ids[doc_id] != ParserType.PICTURE.value:
from graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map"
})
except Exception as e:
logging.exception("Mind map generation error")
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vects)
for i, d in enumerate(cks):
v = vects[i]
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d, _ in files]

View File

@@ -0,0 +1,96 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from api.db import FileSource
from api.db.db_models import DB
from api.db.db_models import File, File2Document
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp, datetime_format
class File2DocumentService(CommonService):
model = File2Document
@classmethod
@DB.connection_context()
def get_by_file_id(cls, file_id):
objs = cls.model.select().where(cls.model.file_id == file_id)
return objs
@classmethod
@DB.connection_context()
def get_by_document_id(cls, document_id):
objs = cls.model.select().where(cls.model.document_id == document_id)
return objs
@classmethod
@DB.connection_context()
def get_by_document_ids(cls, document_ids):
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
return list(objs.dicts())
@classmethod
@DB.connection_context()
def insert(cls, obj):
if not cls.save(**obj):
raise RuntimeError("Database error (File)!")
return File2Document(**obj)
@classmethod
@DB.connection_context()
def delete_by_file_id(cls, file_id):
return cls.model.delete().where(cls.model.file_id == file_id).execute()
@classmethod
@DB.connection_context()
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
if not document_ids:
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
elif not file_ids:
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
@classmethod
@DB.connection_context()
def delete_by_document_id(cls, doc_id):
return cls.model.delete().where(cls.model.document_id == doc_id).execute()
@classmethod
@DB.connection_context()
def update_by_file_id(cls, file_id, obj):
obj["update_time"] = current_timestamp()
obj["update_date"] = datetime_format(datetime.now())
cls.model.update(obj).where(cls.model.id == file_id).execute()
return File2Document(**obj)
@classmethod
@DB.connection_context()
def get_storage_address(cls, doc_id=None, file_id=None):
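        # Resolve the storage address (bucket-like parent id, object location): prefer the linked File
        # record for locally stored files, otherwise fall back to the Document's kb_id and location.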
if doc_id:
f2d = cls.get_by_document_id(doc_id)
else:
f2d = cls.get_by_file_id(file_id)
if f2d:
file = File.get_by_id(f2d[0].file_id)
if not file.source_type or file.source_type == FileSource.LOCAL:
return file.parent_id, file.location
doc_id = f2d[0].document_id
assert doc_id, "please specify doc_id"
e, doc = DocumentService.get_by_id(doc_id)
return doc.kb_id, doc.location

View File

@@ -0,0 +1,547 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from flask_login import current_user
from peewee import fn
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils import get_uuid
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from rag.llm.cv_model import GptV4
from rag.utils.storage_factory import STORAGE_IMPL
class FileService(CommonService):
# Service class for managing file operations and storage
model = File
@classmethod
@DB.connection_context()
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords):
# Get files by parent folder ID with pagination and filtering
# Args:
# tenant_id: ID of the tenant
# pf_id: Parent folder ID
# page_number: Page number for pagination
# items_per_page: Number of items per page
# orderby: Field to order by
# desc: Boolean indicating descending order
# keywords: Search keywords
# Returns:
# Tuple of (file_list, total_count)
if keywords:
files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).contains(keywords.lower())), ~(cls.model.id == pf_id))
else:
files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), ~(cls.model.id == pf_id))
count = files.count()
if desc:
files = files.order_by(cls.model.getter_by(orderby).desc())
else:
files = files.order_by(cls.model.getter_by(orderby).asc())
files = files.paginate(page_number, items_per_page)
res_files = list(files.dicts())
for file in res_files:
if file["type"] == FileType.FOLDER.value:
file["size"] = cls.get_folder_size(file["id"])
file["kbs_info"] = []
children = list(
cls.model.select()
.where(
(cls.model.tenant_id == tenant_id),
(cls.model.parent_id == file["id"]),
~(cls.model.id == file["id"]),
)
.dicts()
)
file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children)
continue
kbs_info = cls.get_kb_id_by_file_id(file["id"])
file["kbs_info"] = kbs_info
return res_files, count
@classmethod
@DB.connection_context()
def get_kb_id_by_file_id(cls, file_id):
# Get knowledge base IDs associated with a file
# Args:
# file_id: File ID
# Returns:
# List of dictionaries containing knowledge base IDs and names
kbs = (
cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
.join(File2Document, on=(File2Document.file_id == file_id))
.join(Document, on=(File2Document.document_id == Document.id))
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
.where(cls.model.id == file_id)
)
if not kbs:
return []
kbs_info_list = []
for kb in list(kbs.dicts()):
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
return kbs_info_list
@classmethod
@DB.connection_context()
def get_by_pf_id_name(cls, id, name):
# Get file by parent folder ID and name
# Args:
# id: Parent folder ID
# name: File name
# Returns:
# File object or None if not found
file = cls.model.select().where((cls.model.parent_id == id) & (cls.model.name == name))
if file.count():
e, file = cls.get_by_id(file[0].id)
if not e:
raise RuntimeError("Database error (File retrieval)!")
return file
return None
@classmethod
@DB.connection_context()
def get_id_list_by_id(cls, id, name, count, res):
# Recursively get list of file IDs by traversing folder structure
# Args:
# id: Starting folder ID
# name: List of folder names to traverse
# count: Current depth in traversal
# res: List to store results
# Returns:
# List of file IDs
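        # Illustrative call (hypothetical nested folders "a/b/c" under root_id):
        #   get_id_list_by_id(root_id, ["a", "b", "c"], 0, []) -> [id_of_a, id_of_b, id_of_c]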
if count < len(name):
file = cls.get_by_pf_id_name(id, name[count])
if file:
res.append(file.id)
return cls.get_id_list_by_id(file.id, name, count + 1, res)
else:
return res
else:
return res
@classmethod
@DB.connection_context()
def get_all_innermost_file_ids(cls, folder_id, result_ids):
# Get IDs of all files in the deepest level of folders
# Args:
# folder_id: Starting folder ID
# result_ids: List to store results
# Returns:
# List of file IDs
subfolders = cls.model.select().where(cls.model.parent_id == folder_id)
if subfolders.exists():
for subfolder in subfolders:
cls.get_all_innermost_file_ids(subfolder.id, result_ids)
else:
result_ids.append(folder_id)
return result_ids
@classmethod
@DB.connection_context()
def get_all_file_ids_by_tenant_id(cls, tenant_id):
fields = [cls.model.id]
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
        files = files.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
file_batch = files.offset(offset).limit(limit)
_temp = list(file_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def create_folder(cls, file, parent_id, name, count):
# Recursively create folder structure
# Args:
# file: Current file object
# parent_id: Parent folder ID
# name: List of folder names to create
# count: Current depth in creation
# Returns:
# Created file object
if count > len(name) - 2:
return file
else:
file = cls.insert(
{"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value}
)
return cls.create_folder(file, file.id, name, count + 1)
@classmethod
@DB.connection_context()
def is_parent_folder_exist(cls, parent_id):
# Check if parent folder exists
# Args:
# parent_id: Parent folder ID
# Returns:
# Boolean indicating if folder exists
parent_files = cls.model.select().where(cls.model.id == parent_id)
if parent_files.count():
return True
cls.delete_folder_by_pf_id(parent_id)
return False
@classmethod
@DB.connection_context()
def get_root_folder(cls, tenant_id):
# Get or create root folder for tenant
# Args:
# tenant_id: Tenant ID
# Returns:
# Root folder dictionary
for file in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
return file.to_dict()
file_id = get_uuid()
file = {
"id": file_id,
"parent_id": file_id,
"tenant_id": tenant_id,
"created_by": tenant_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
cls.save(**file)
return file
@classmethod
@DB.connection_context()
def get_kb_folder(cls, tenant_id):
# Get knowledge base folder for tenant
# Args:
# tenant_id: Tenant ID
# Returns:
# Knowledge base folder dictionary
root_folder = cls.get_root_folder(tenant_id)
root_id = root_folder["id"]
kb_folder = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == root_id), (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)).first()
if not kb_folder:
kb_folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
return kb_folder
return kb_folder.to_dict()
@classmethod
@DB.connection_context()
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
# Create a new file from knowledge base
# Args:
# tenant_id: Tenant ID
# name: File name
# parent_id: Parent folder ID
# ty: File type
# size: File size
# location: File location
# Returns:
# Created file dictionary
for file in cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name):
return file.to_dict()
file = {
"id": get_uuid(),
"parent_id": parent_id,
"tenant_id": tenant_id,
"created_by": tenant_id,
"name": name,
"type": ty,
"size": size,
"location": location,
"source_type": FileSource.KNOWLEDGEBASE,
}
cls.save(**file)
return file
@classmethod
@DB.connection_context()
def init_knowledgebase_docs(cls, root_id, tenant_id):
# Initialize knowledge base documents
# Args:
# root_id: Root folder ID
# tenant_id: Tenant ID
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME) & (cls.model.parent_id == root_id)):
return
folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id == tenant_id):
kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"])
for doc in DocumentService.query(kb_id=kb.id):
FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id)
@classmethod
@DB.connection_context()
def get_parent_folder(cls, file_id):
# Get parent folder of a file
# Args:
# file_id: File ID
# Returns:
# Parent folder object
file = cls.model.select().where(cls.model.id == file_id)
if file.count():
e, file = cls.get_by_id(file[0].parent_id)
if not e:
raise RuntimeError("Database error (File retrieval)!")
else:
raise RuntimeError("Database error (File doesn't exist)!")
return file
@classmethod
@DB.connection_context()
def get_all_parent_folders(cls, start_id):
# Get all parent folders in path
# Args:
# start_id: Starting file ID
# Returns:
# List of parent folder objects
parent_folders = []
current_id = start_id
while current_id:
e, file = cls.get_by_id(current_id)
if file.parent_id != file.id and e:
parent_folders.append(file)
current_id = file.parent_id
else:
parent_folders.append(file)
break
return parent_folders
@classmethod
@DB.connection_context()
def insert(cls, file):
# Insert a new file record
# Args:
# file: File data dictionary
# Returns:
# Created file object
if not cls.save(**file):
raise RuntimeError("Database error (File)!")
return File(**file)
@classmethod
@DB.connection_context()
def delete(cls, file):
        # Delete a file record by its ID
return cls.delete_by_id(file.id)
@classmethod
@DB.connection_context()
def delete_by_pf_id(cls, folder_id):
return cls.model.delete().where(cls.model.parent_id == folder_id).execute()
@classmethod
@DB.connection_context()
def delete_folder_by_pf_id(cls, user_id, folder_id):
try:
files = cls.model.select().where((cls.model.tenant_id == user_id) & (cls.model.parent_id == folder_id))
for file in files:
cls.delete_folder_by_pf_id(user_id, file.id)
            return cls.model.delete().where((cls.model.tenant_id == user_id) & (cls.model.id == folder_id)).execute()
except Exception:
logging.exception("delete_folder_by_pf_id")
raise RuntimeError("Database error (File retrieval)!")
@classmethod
@DB.connection_context()
def get_file_count(cls, tenant_id):
files = cls.model.select(cls.model.id).where(cls.model.tenant_id == tenant_id)
return len(files)
@classmethod
@DB.connection_context()
def get_folder_size(cls, folder_id):
size = 0
def dfs(parent_id):
nonlocal size
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id, cls.model.id != parent_id):
size += f.size
if f.type == FileType.FOLDER.value:
dfs(f.id)
dfs(folder_id)
return size
@classmethod
@DB.connection_context()
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
for _ in File2DocumentService.get_by_document_id(doc["id"]):
return
file = {
"id": get_uuid(),
"parent_id": kb_folder_id,
"tenant_id": tenant_id,
"created_by": tenant_id,
"name": doc["name"],
"type": doc["type"],
"size": doc["size"],
"location": doc["location"],
"source_type": FileSource.KNOWLEDGEBASE,
}
cls.save(**file)
File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]})
@classmethod
@DB.connection_context()
def move_file(cls, file_ids, folder_id):
try:
cls.filter_update((cls.model.id << file_ids,), {"parent_id": folder_id})
except Exception:
logging.exception("move_file")
raise RuntimeError("Database error (File move)!")
@classmethod
@DB.connection_context()
    async def upload_document(cls, kb, file_objs, user_id):
        root_folder = cls.get_root_folder(user_id)
        pf_id = root_folder["id"]
        cls.init_knowledgebase_docs(pf_id, user_id)
        kb_root_folder = cls.get_kb_folder(user_id)
        kb_folder = cls.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
err, files = [], []
for file in file_objs:
try:
DocumentService.check_doc_health(kb.tenant_id, file.filename)
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
filetype = filename_type(filename)
if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!")
location = filename
while STORAGE_IMPL.obj_exist(kb.id, location):
location += "_"
blob = await file.read()
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
STORAGE_IMPL.put(kb.id, location, blob)
doc_id = get_uuid()
img = thumbnail_img(filename, blob)
thumbnail_location = ""
if img is not None:
thumbnail_location = f"thumbnail_{doc_id}.png"
STORAGE_IMPL.put(kb.id, thumbnail_location, img)
doc = {
"id": doc_id,
"kb_id": kb.id,
"parser_id": self.get_parser(filetype, filename, kb.parser_id),
"pipeline_id": kb.pipeline_id,
"parser_config": kb.parser_config,
"created_by": user_id,
"type": filetype,
"name": filename,
"suffix": Path(filename).suffix.lstrip("."),
"location": location,
"size": len(blob),
"thumbnail": thumbnail_location,
}
DocumentService.insert(doc)
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
files.append((doc, blob))
except Exception as e:
traceback.print_exc()
err.append(file.filename + ": " + str(e))
return err, files
@staticmethod
async def parse_docs(file_objs, user_id):
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for file in file_objs:
# Check if file has async read method (UploadFile)
if hasattr(file, 'read') and hasattr(file.read, '__call__'):
try:
# Try to get the coroutine to check if it's async
read_result = file.read()
if hasattr(read_result, '__await__'):
# It's an async method, await it
blob = await read_result
else:
# It's a sync method
blob = read_result
except Exception:
# Fallback to sync read
blob = file.read()
else:
blob = file.read()
threads.append(exe.submit(FileService.parse, file.filename, blob, False))
res = []
for th in threads:
res.append(th.result())
return "\n\n".join(res)
@staticmethod
def parse(filename, blob, img_base64=True, tenant_id=None):
from rag.app import audio, email, naive, picture, presentation
def dummy(prog=None, msg=""):
pass
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": current_user.id if current_user else tenant_id}
file_type = filename_type(filename)
if img_base64 and file_type == FileType.VISUAL.value:
return GptV4.image2base64(blob)
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
return "\n".join([ck["content_with_weight"] for ck in cks])
@staticmethod
def get_parser(doc_type, filename, default):
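        # e.g. an image -> picture parser, "slides.pptx" -> presentation, "mail.eml" -> email; otherwise the supplied default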
if doc_type == FileType.VISUAL:
return ParserType.PICTURE.value
if doc_type == FileType.AURAL:
return ParserType.AUDIO.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
return ParserType.PRESENTATION.value
if re.search(r"\.(msg|eml)$", filename):
return ParserType.EMAIL.value
return default
@staticmethod
def get_blob(user_id, location):
bname = f"{user_id}-downloads"
return STORAGE_IMPL.get(bname, location)
@staticmethod
def put_blob(user_id, location, blob):
bname = f"{user_id}-downloads"
return STORAGE_IMPL.put(bname, location, blob)

View File

@@ -0,0 +1,496 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from peewee import fn, JOIN
from api.db import StatusEnum, TenantPermission
from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format
class KnowledgebaseService(CommonService):
"""Service class for managing knowledge base operations.
This class extends CommonService to provide specialized functionality for knowledge base
management, including document parsing status tracking, access control, and configuration
management. It handles operations such as listing, creating, updating, and deleting
knowledge bases, as well as managing their associated documents and permissions.
The class implements a comprehensive set of methods for:
- Document parsing status verification
- Knowledge base access control
- Parser configuration management
- Tenant-based knowledge base organization
Attributes:
model: The Knowledgebase model class for database operations.
"""
model = Knowledgebase
@classmethod
@DB.connection_context()
def accessible4deletion(cls, kb_id, user_id):
"""Check if a knowledge base can be deleted by a specific user.
This method verifies whether a user has permission to delete a knowledge base
by checking if they are the creator of that knowledge base.
Args:
kb_id (str): The unique identifier of the knowledge base to check.
user_id (str): The unique identifier of the user attempting the deletion.
Returns:
bool: True if the user has permission to delete the knowledge base,
False if the user doesn't have permission or the knowledge base doesn't exist.
Example:
>>> KnowledgebaseService.accessible4deletion("kb123", "user456")
True
Note:
- This method only checks creator permissions
- A return value of False can mean either:
1. The knowledge base doesn't exist
2. The user is not the creator of the knowledge base
"""
# Check if a knowledge base can be deleted by a user
docs = cls.model.select(
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def is_parsed_done(cls, kb_id):
# Check if all documents in the knowledge base have completed parsing
#
# Args:
# kb_id: Knowledge base ID
#
# Returns:
# If all documents are parsed successfully, returns (True, None)
# If any document is not fully parsed, returns (False, error_message)
from api.db import TaskStatus
from api.db.services.document_service import DocumentService
# Get knowledge base information
kbs = cls.query(id=kb_id)
if not kbs:
return False, "Knowledge base not found"
kb = kbs[0]
# Get all documents in the knowledge base
        docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "", [], [], [])
# Check parsing status of each document
for doc in docs:
# If document is being parsed, don't allow chat creation
if doc['run'] == TaskStatus.RUNNING.value or doc['run'] == TaskStatus.CANCEL.value or doc['run'] == TaskStatus.FAIL.value:
return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
# If document is not yet parsed and has no chunks, don't allow chat creation
if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
return True, None
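    # Usage sketch (illustrative only; the surrounding call site and error handling are assumptions):
    #   ok, err = KnowledgebaseService.is_parsed_done(kb_id)
    #   if not ok:
    #       raise ValueError(err)  # e.g. surface the message to the chat-creation endpoint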
@classmethod
@DB.connection_context()
def list_documents_by_ids(cls, kb_ids):
# Get document IDs associated with given knowledge base IDs
# Args:
# kb_ids: List of knowledge base IDs
# Returns:
# List of document IDs
doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
cls.model.id.in_(kb_ids)
)
doc_ids = list(doc_ids.dicts())
doc_ids = [doc["document_id"] for doc in doc_ids]
return doc_ids
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page,
orderby, desc, keywords,
parser_id=None
):
# Get knowledge bases by tenant IDs with pagination and filtering
# Args:
# joined_tenant_ids: List of tenant IDs
# user_id: Current user ID
# page_number: Page number for pagination
# items_per_page: Number of items per page
# orderby: Field to order by
# desc: Boolean indicating descending order
# keywords: Search keywords
# parser_id: Optional parser ID filter
# Returns:
# Tuple of (knowledge_base_list, total_count)
fields = [
cls.model.id,
cls.model.avatar,
cls.model.name,
cls.model.language,
cls.model.description,
cls.model.tenant_id,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id,
cls.model.embd_id,
User.nickname,
User.avatar.alias('tenant_avatar'),
cls.model.update_time
]
if keywords:
kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if parser_id:
kbs = kbs.where(cls.model.parser_id == parser_id)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
count = kbs.count()
if page_number and items_per_page:
kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts()), count
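    # Visibility rule applied above: a knowledge base is listed when it is owned by the current
    # user (tenant_id == user_id) or shared with a joined team (tenant_id in joined_tenant_ids
    # and permission == TEAM), and its status is valid.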
@classmethod
@DB.connection_context()
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
        # Returns every knowledge base the user is permitted to access (team-shared or owned); use with caution, as the result set is unbounded.
fields = [
cls.model.name,
cls.model.language,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.status,
cls.model.create_date,
cls.model.update_date
]
# find team kb and owned kb
kbs = cls.model.select(*fields).where(
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id
)
)
        # Sort by create_time ascending (reassign, otherwise the ordering would be discarded)
        kbs = kbs.order_by(cls.model.create_time.asc())
        # Fetch in batches below; deep offset pagination can be slow on large tables and may be optimized later.
offset, limit = 0, 50
res = []
while True:
kb_batch = kbs.offset(offset).limit(limit)
_temp = list(kb_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_kb_ids(cls, tenant_id):
# Get all knowledge base IDs for a tenant
# Args:
# tenant_id: Tenant ID
# Returns:
# List of knowledge base IDs
fields = [
cls.model.id,
]
kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
kb_ids = [kb.id for kb in kbs]
return kb_ids
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
# Get detailed information about a knowledge base
# Args:
# kb_id: Knowledge base ID
# Returns:
# Dictionary containing knowledge base details
fields = [
cls.model.id,
cls.model.embd_id,
cls.model.avatar,
cls.model.name,
cls.model.language,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id,
cls.model.pipeline_id,
UserCanvas.title.alias("pipeline_name"),
UserCanvas.avatar.alias("pipeline_avatar"),
cls.model.parser_config,
cls.model.pagerank,
cls.model.graphrag_task_id,
cls.model.graphrag_task_finish_at,
cls.model.raptor_task_id,
cls.model.raptor_task_finish_at,
cls.model.mindmap_task_id,
cls.model.mindmap_task_finish_at,
cls.model.create_time,
cls.model.update_time
]
kbs = cls.model.select(*fields)\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
).dicts()
if not kbs:
return
return kbs[0]
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
# Update parser configuration for a knowledge base
# Args:
# id: Knowledge base ID
# config: New parser configuration
e, m = cls.get_by_id(id)
if not e:
raise LookupError(f"knowledgebase({id}) not found.")
def dfs_update(old, new):
# Deep update of nested configuration
for k, v in new.items():
if k not in old:
old[k] = v
continue
if isinstance(v, dict):
assert isinstance(old[k], dict)
dfs_update(old[k], v)
elif isinstance(v, list):
assert isinstance(old[k], list)
old[k] = list(set(old[k] + v))
else:
old[k] = v
dfs_update(m.parser_config, config)
cls.update_by_id(id, {"parser_config": m.parser_config})
@classmethod
@DB.connection_context()
def delete_field_map(cls, id):
e, m = cls.get_by_id(id)
if not e:
raise LookupError(f"knowledgebase({id}) not found.")
m.parser_config.pop("field_map", None)
cls.update_by_id(id, {"parser_config": m.parser_config})
@classmethod
@DB.connection_context()
def get_field_map(cls, ids):
# Get field mappings for knowledge bases
# Args:
# ids: List of knowledge base IDs
# Returns:
# Dictionary of field mappings
conf = {}
for k in cls.get_by_ids(ids):
if k.parser_config and "field_map" in k.parser_config:
conf.update(k.parser_config["field_map"])
return conf
@classmethod
@DB.connection_context()
def get_by_name(cls, kb_name, tenant_id):
# Get knowledge base by name and tenant ID
# Args:
# kb_name: Knowledge base name
# tenant_id: Tenant ID
# Returns:
# Tuple of (exists, knowledge_base)
kb = cls.model.select().where(
(cls.model.name == kb_name)
& (cls.model.tenant_id == tenant_id)
& (cls.model.status == StatusEnum.VALID.value)
)
if kb:
return True, kb[0]
return False, None
@classmethod
@DB.connection_context()
def get_all_ids(cls):
# Get all knowledge base IDs
# Returns:
# List of all knowledge base IDs
return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
@classmethod
@DB.connection_context()
def get_list(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc, id, name):
# Get list of knowledge bases with filtering and pagination
# Args:
# joined_tenant_ids: List of tenant IDs
# user_id: Current user ID
# page_number: Page number for pagination
# items_per_page: Number of items per page
# orderby: Field to order by
# desc: Boolean indicating descending order
# id: Optional ID filter
# name: Optional name filter
# Returns:
# List of knowledge bases
kbs = cls.model.select()
if id:
kbs = kbs.where(cls.model.id == id)
if name:
kbs = kbs.where(cls.model.name == name)
kbs = kbs.where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts())
@classmethod
@DB.connection_context()
def accessible(cls, kb_id, user_id):
# Check if a knowledge base is accessible by a user
# Args:
# kb_id: Knowledge base ID
# user_id: User ID
# Returns:
# Boolean indicating accessibility
docs = cls.model.select(
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def get_kb_by_id(cls, kb_id, user_id):
# Get knowledge base by ID and user ID
# Args:
# kb_id: Knowledge base ID
# user_id: User ID
# Returns:
# List containing knowledge base information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts()
return list(kbs)
@classmethod
@DB.connection_context()
def get_kb_by_name(cls, kb_name, user_id):
# Get knowledge base by name and user ID
# Args:
# kb_name: Knowledge base name
# user_id: User ID
# Returns:
# List containing knowledge base information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts()
return list(kbs)
@classmethod
@DB.connection_context()
def atomic_increase_doc_num_by_id(cls, kb_id):
data = {}
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
data["doc_num"] = cls.model.doc_num + 1
num = cls.model.update(data).where(cls.model.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def update_document_number_in_init(cls, kb_id, doc_num):
"""
Only use this function when init system
"""
ok, kb = cls.get_by_id(kb_id)
if not ok:
return
kb.doc_num = doc_num
dirty_fields = kb.dirty_fields
if cls.model._meta.combined.get("update_time") in dirty_fields:
dirty_fields.remove(cls.model._meta.combined["update_time"])
if cls.model._meta.combined.get("update_date") in dirty_fields:
dirty_fields.remove(cls.model._meta.combined["update_date"])
try:
kb.save(only=dirty_fields)
except ValueError as e:
if str(e) == "no data to save!":
pass # that's OK
else:
raise e
@classmethod
@DB.connection_context()
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
kb_row = cls.model.get_by_id(kb_id)
if not kb_row:
raise RuntimeError(f"kb_id {kb_id} does not exist")
update_dict = {
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
'token_num': kb_row.token_num - doc_num_info['token_num'],
'update_time': current_timestamp(),
'update_date': datetime_format(datetime.now())
}
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
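    # Illustrative shape of doc_num_info (numbers are made up for the example):
    #   {"doc_num": 1, "chunk_num": 42, "token_num": 1357}
    # Each counter is subtracted from the knowledge base totals in a single UPDATE statement.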

View File

@@ -0,0 +1,76 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from api.db.db_models import DB, TenantLangfuse
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format
class TenantLangfuseService(CommonService):
"""
All methods that modify the status should be enclosed within a DB.atomic() context to ensure atomicity
and maintain data integrity in case of errors during execution.
"""
model = TenantLangfuse
@classmethod
@DB.connection_context()
def filter_by_tenant(cls, tenant_id):
fields = [cls.model.tenant_id, cls.model.host, cls.model.secret_key, cls.model.public_key]
try:
keys = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id).first()
return keys
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def filter_by_tenant_with_info(cls, tenant_id):
fields = [cls.model.tenant_id, cls.model.host, cls.model.secret_key, cls.model.public_key]
try:
keys = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id).dicts().first()
return keys
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def delete_ty_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
@classmethod
def update_by_tenant(cls, tenant_id, langfuse_keys):
langfuse_keys["update_time"] = current_timestamp()
langfuse_keys["update_date"] = datetime_format(datetime.now())
return cls.model.update(**langfuse_keys).where(cls.model.tenant_id == tenant_id).execute()
@classmethod
def save(cls, **kwargs):
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
obj = cls.model.create(**kwargs)
return obj
@classmethod
def delete_model(cls, langfuse_model):
langfuse_model.delete_instance()
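    # Usage sketch (illustrative): the two lookup helpers differ only in return shape.
    #   keys = TenantLangfuseService.filter_by_tenant(tenant_id)            # model instance or None
    #   info = TenantLangfuseService.filter_by_tenant_with_info(tenant_id)  # plain dict or None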

View File

@@ -0,0 +1,279 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import logging
import re
from functools import partial
from typing import Generator
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
class LLMService(CommonService):
model = LLM
def get_init_tenant_llm(user_id):
from api import settings
tenant_llm = []
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG,
settings.EMBEDDING_CFG,
settings.ASR_CFG,
settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG,
]:
factory_name = factory_config["factory"]
if factory_name not in seen:
seen.add(factory_name)
factory_configs.append(factory_config)
for factory_config in factory_configs:
for llm in LLMService.query(fid=factory_config["factory"]):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
if settings.LIGHTEN != 1:
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": fid,
"llm_name": mdlnm,
"model_type": "embedding",
"api_key": "",
"api_base": "",
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
}
)
unique = {}
for item in tenant_llm:
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
if key not in unique:
unique[key] = item
return list(unique.values())
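# Illustrative result shape of get_init_tenant_llm (all values below are placeholders):
#   [{"tenant_id": "<user_id>", "llm_factory": "OpenAI", "llm_name": "gpt-4o", "model_type": "chat",
#     "api_key": "...", "api_base": "...", "max_tokens": 8192}, ...]
# Entries are de-duplicated on (tenant_id, llm_factory, llm_name), keeping the first one seen.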
class LLMBundle(LLM4Tenant):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
def bind_tools(self, toolcall_session, tools):
if not self.is_tools:
logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
return
self.mdl.bind_tools(toolcall_session, tools)
def encode(self, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
embeddings, used_tokens = self.mdl.encode(texts)
llm_name = getattr(self, "llm_name", None)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
generation.update(usage_details={"total_tokens": used_tokens})
generation.end()
return embeddings, used_tokens
def encode_queries(self, query: str):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.llm_name, input={"query": query})
emd, used_tokens = self.mdl.encode_queries(query)
llm_name = getattr(self, "llm_name", None)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name):
logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
generation.update(usage_details={"total_tokens": used_tokens})
generation.end()
return emd, used_tokens
def similarity(self, query: str, texts: list):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.llm_name, input={"query": query, "texts": texts})
sim, used_tokens = self.mdl.similarity(query, texts)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
generation.update(usage_details={"total_tokens": used_tokens})
generation.end()
return sim, used_tokens
def describe(self, image, max_tokens=300):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.llm_name})
txt, used_tokens = self.mdl.describe(image)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
return txt
def describe_with_prompt(self, image, prompt):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.llm_name, "prompt": prompt})
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
return txt
def transcription(self, audio):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.llm_name})
txt, used_tokens = self.mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error("LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
if self.langfuse:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
return txt
def tts(self, text: str) -> Generator[bytes, None, None]:
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
for chunk in self.mdl.tts(text):
if isinstance(chunk, int):
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, chunk, self.llm_name):
logging.error("LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk
if self.langfuse:
generation.end()
def _remove_reasoning_content(self, txt: str) -> str:
first_think_start = txt.find("<think>")
if first_think_start == -1:
return txt
last_think_end = txt.rfind("</think>")
if last_think_end == -1:
return txt
if last_think_end < first_think_start:
return txt
return txt[last_think_end + len("</think>") :]
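    # Behaviour of _remove_reasoning_content on sample responses (examples only):
    #   "<think>step 1 ... step 2</think>The answer is 42."  -> "The answer is 42."
    #   "no reasoning tags here"                             -> unchanged
    #   "<think>unclosed reasoning"                          -> unchanged (no closing tag)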
@staticmethod
def _clean_param(chat_partial, **kwargs):
func = chat_partial.func
sig = inspect.signature(func)
keyword_args = []
support_var_args = False
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD or param.kind == inspect.Parameter.VAR_POSITIONAL:
support_var_args = True
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
keyword_args.append(param.name)
use_kwargs = kwargs
if not support_var_args:
use_kwargs = {k: v for k, v in kwargs.items() if k in keyword_args}
return use_kwargs
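    # What _clean_param does, assuming a model method with this (hypothetical) signature:
    #   def chat(self, system, history, gen_conf, *, images=None): ...
    #   _clean_param(partial(mdl.chat, system, history, gen_conf), images=[...], stop="###")
    #   -> {"images": [...]}   # "stop" is dropped: not keyword-only and no **kwargs to absorb it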
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
chat_partial = partial(self.mdl.chat, system, history, gen_conf)
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
txt, used_tokens = chat_partial(**use_kwargs)
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
        if isinstance(used_tokens, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if self.langfuse:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
return txt
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
ans = ""
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
total_tokens = 0
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
if self.langfuse:
generation.update(output={"output": ans})
generation.end()
break
if txt.endswith("</think>"):
ans = ans.rstrip("</think>")
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield ans
        if total_tokens > 0:
            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

View File

@@ -0,0 +1,91 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import fn
from api.db.db_models import DB, MCPServer
from api.db.services.common_service import CommonService
class MCPServerService(CommonService):
"""Service class for managing MCP server related database operations.
This class extends CommonService to provide specialized functionality for MCP server management,
including MCP server creation, updates, and deletions.
Attributes:
model: The MCPServer model class for database operations.
"""
model = MCPServer
@classmethod
@DB.connection_context()
def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords):
"""Retrieve all MCP servers associated with a tenant.
This method fetches all MCP servers for a given tenant, ordered by creation time.
It only includes fields for list display.
Args:
tenant_id (str): The unique identifier of the tenant.
id_list (list[str]): Get servers by ID list. Will ignore this condition if None.
Returns:
list[dict]: List of MCP server dictionaries containing MCP server details.
Returns None if no MCP servers are found.
"""
fields = [
cls.model.id,
cls.model.name,
cls.model.server_type,
cls.model.url,
cls.model.description,
cls.model.variables,
cls.model.create_date,
cls.model.update_date,
]
query = cls.model.select(*fields).order_by(cls.model.create_time.desc()).where(cls.model.tenant_id == tenant_id)
if id_list:
query = query.where(cls.model.id.in_(id_list))
if keywords:
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if desc:
query = query.order_by(cls.model.getter_by(orderby).desc())
else:
query = query.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
query = query.paginate(page_number, items_per_page)
servers = list(query.dicts())
if not servers:
return None
return servers
@classmethod
@DB.connection_context()
def get_by_name_and_tenant(cls, name: str, tenant_id: str):
try:
            mcp_server = cls.query(name=name, tenant_id=tenant_id)
return bool(mcp_server), mcp_server
except Exception:
return False, None
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id: str):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()

View File

@@ -0,0 +1,263 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
from datetime import datetime, timedelta
from peewee import fn
from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType
from api.db.db_models import DB, Document, PipelineOperationLog
from api.db.services.canvas_service import UserCanvasService
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
from api.utils import current_timestamp, datetime_format, get_uuid
class PipelineOperationLogService(CommonService):
model = PipelineOperationLog
@classmethod
def get_file_logs_fields(cls):
return [
cls.model.id,
cls.model.document_id,
cls.model.tenant_id,
cls.model.kb_id,
cls.model.pipeline_id,
cls.model.pipeline_title,
cls.model.parser_id,
cls.model.document_name,
cls.model.document_suffix,
cls.model.document_type,
cls.model.source_from,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.dsl,
cls.model.task_type,
cls.model.operation_status,
cls.model.avatar,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
@classmethod
def get_dataset_logs_fields(cls):
return [
cls.model.id,
cls.model.tenant_id,
cls.model.kb_id,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.task_type,
cls.model.operation_status,
cls.model.avatar,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
@classmethod
def save(cls, **kwargs):
"""
wrap this function in a transaction
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: str = "{}"):
referred_document_id = document_id
if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
referred_document_id = fake_document_ids[0]
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
return
DocumentService.update_progress_immediately([document.to_dict()])
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
return
if document.progress not in [1, -1]:
return
operation_status = document.run
if pipeline_id:
ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
if not ok:
raise RuntimeError(f"Pipeline {pipeline_id} not found")
tenant_id = user_pipeline.user_id
title = user_pipeline.title
avatar = user_pipeline.avatar
else:
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
if not ok:
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
tenant_id = kb_info.tenant_id
title = document.parser_id
avatar = document.thumbnail
if task_type not in VALID_PIPELINE_TASK_TYPES:
raise ValueError(f"Invalid task type: {task_type}")
if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
finish_at = document.process_begin_at + timedelta(seconds=document.process_duration)
if task_type == PipelineTaskType.GRAPH_RAG:
KnowledgebaseService.update_by_id(
document.kb_id,
{"graphrag_task_finish_at": finish_at},
)
elif task_type == PipelineTaskType.RAPTOR:
KnowledgebaseService.update_by_id(
document.kb_id,
{"raptor_task_finish_at": finish_at},
)
elif task_type == PipelineTaskType.MINDMAP:
KnowledgebaseService.update_by_id(
document.kb_id,
{"mindmap_task_finish_at": finish_at},
)
log = dict(
id=get_uuid(),
document_id=document_id, # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
tenant_id=tenant_id,
kb_id=document.kb_id,
pipeline_id=pipeline_id,
pipeline_title=title,
parser_id=document.parser_id,
document_name=document.name,
document_suffix=document.suffix,
document_type=document.type,
source_from="", # TODO: add in the future
progress=document.progress,
progress_msg=document.progress_msg,
process_begin_at=document.process_begin_at,
process_duration=document.process_duration,
dsl=json.loads(dsl),
task_type=task_type,
operation_status=operation_status,
avatar=avatar,
)
log["create_time"] = current_timestamp()
log["create_date"] = datetime_format(datetime.now())
log["update_time"] = current_timestamp()
log["update_date"] = datetime_format(datetime.now())
with DB.atomic():
obj = cls.save(**log)
limit = int(os.getenv("PIPELINE_OPERATION_LOG_LIMIT", 1000))
total = cls.model.select().where(cls.model.kb_id == document.kb_id).count()
if total > limit:
keep_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == document.kb_id).order_by(cls.model.create_time.desc()).limit(limit)]
deleted = cls.model.delete().where(cls.model.kb_id == document.kb_id, cls.model.id.not_in(keep_ids)).execute()
logging.info(f"[PipelineOperationLogService] Cleaned {deleted} old logs, kept latest {limit} for {document.kb_id}")
return obj
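    # Retention sketch: with PIPELINE_OPERATION_LOG_LIMIT at its assumed default of 1000, once a
    # dataset holds more than 1000 log rows, everything beyond the 1000 most recent rows (by
    # create_time) is deleted alongside the insert above, inside the same DB.atomic() block.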
@classmethod
@DB.connection_context()
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
@classmethod
@DB.connection_context()
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
fields = cls.get_file_logs_fields()
if keywords:
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
else:
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
if types:
logs = logs.where(cls.model.document_type.in_(types))
if suffix:
logs = logs.where(cls.model.document_suffix.in_(suffix))
if create_date_from:
logs = logs.where(cls.model.create_date >= create_date_from)
if create_date_to:
logs = logs.where(cls.model.create_date <= create_date_to)
count = logs.count()
if desc:
logs = logs.order_by(cls.model.getter_by(orderby).desc())
else:
logs = logs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count
@classmethod
@DB.connection_context()
def get_documents_info(cls, id):
fields = [Document.id, Document.name, Document.progress, Document.kb_id]
return (
cls.model.select(*fields)
.join(Document, on=(cls.model.document_id == Document.id))
.where(
cls.model.id == id
)
.dicts()
)
@classmethod
@DB.connection_context()
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
fields = cls.get_dataset_logs_fields()
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
if create_date_from:
logs = logs.where(cls.model.create_date >= create_date_from)
if create_date_to:
logs = logs.where(cls.model.create_date <= create_date_to)
count = logs.count()
if desc:
logs = logs.order_by(cls.model.getter_by(orderby).desc())
else:
logs = logs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count

View File

@@ -0,0 +1,117 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from peewee import fn
from api.db import StatusEnum
from api.db.db_models import DB, Search, User
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format
class SearchService(CommonService):
model = Search
@classmethod
def save(cls, **kwargs):
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
obj = cls.model.create(**kwargs)
return obj
@classmethod
@DB.connection_context()
def accessible4deletion(cls, search_id, user_id) -> bool:
search = (
cls.model.select(cls.model.id)
.where(
cls.model.id == search_id,
cls.model.created_by == user_id,
cls.model.status == StatusEnum.VALID.value,
)
.first()
)
return search is not None
@classmethod
@DB.connection_context()
def get_detail(cls, search_id):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.tenant_id,
cls.model.name,
cls.model.description,
cls.model.created_by,
cls.model.search_config,
cls.model.update_time,
User.nickname,
User.avatar.alias("tenant_avatar"),
]
        # .first() may return None, so convert to a dict only after the existence check
        search = (
            cls.model.select(*fields)
            .join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
            .where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
            .first()
        )
        if not search:
            return {}
        return search.to_dict()
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.tenant_id,
cls.model.name,
cls.model.description,
cls.model.created_by,
cls.model.status,
cls.model.update_time,
cls.model.create_time,
User.nickname,
User.avatar.alias("tenant_avatar"),
]
query = (
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
)
if keywords:
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if desc:
query = query.order_by(cls.model.getter_by(orderby).desc())
else:
query = query.order_by(cls.model.getter_by(orderby).asc())
count = query.count()
if page_number and items_per_page:
query = query.paginate(page_number, items_per_page)
return list(query.dicts()), count
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()

View File

@@ -0,0 +1,522 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import random
import xxhash
from datetime import datetime
from api.db.db_utils import bulk_insert_into_db
from deepdoc.parser import PdfParser
from peewee import JOIN
from api.db.db_models import DB, File2Document, File
from api.db import StatusEnum, FileType, TaskStatus
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp, get_uuid
from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.settings import get_svr_queue_name
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search
CANVAS_DEBUG_DOC_ID = "dataflow_x"
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
def trim_header_by_lines(text: str, max_length: int) -> str:
    # Drop the oldest lines from the start of the text so that, where possible, the remainder
    # fits within max_length; cuts only at line boundaries
    # Args:
    #     text: Input text to trim
    #     max_length: Maximum allowed length
    # Returns:
    #     Trimmed text
len_text = len(text)
if len_text <= max_length:
return text
for i in range(len_text):
if text[i] == '\n' and len_text - i <= max_length:
return text[i + 1:]
return text
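# Example behaviour of trim_header_by_lines (illustrative):
#   trim_header_by_lines("line1\nline2\nline3", 12) -> "line2\nline3"
#   If no line boundary allows the text to fit, the original text is returned unchanged.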
class TaskService(CommonService):
"""Service class for managing document processing tasks.
This class extends CommonService to provide specialized functionality for document
processing task management, including task creation, progress tracking, and chunk
management. It handles various document types (PDF, Excel, etc.) and manages their
processing lifecycle.
The class implements a robust task queue system with retry mechanisms and progress
tracking, supporting both synchronous and asynchronous task execution.
Attributes:
model: The Task model class for database operations.
"""
model = Task
@classmethod
@DB.connection_context()
def get_task(cls, task_id, doc_ids=[]):
"""Retrieve detailed task information by task ID.
This method fetches comprehensive task details including associated document,
knowledge base, and tenant information. It also handles task retry logic and
progress updates.
Args:
task_id (str): The unique identifier of the task to retrieve.
Returns:
dict: Task details dictionary containing all task information and related metadata.
Returns None if task is not found or has exceeded retry limit.
"""
doc_id = cls.model.doc_id
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
doc_id = doc_ids[0]
fields = [
cls.model.id,
cls.model.doc_id,
cls.model.from_page,
cls.model.to_page,
cls.model.retry_count,
Document.kb_id,
Document.parser_id,
Document.parser_config,
Document.name,
Document.type,
Document.location,
Document.size,
Knowledgebase.tenant_id,
Knowledgebase.language,
Knowledgebase.embd_id,
Knowledgebase.pagerank,
Knowledgebase.parser_config.alias("kb_parser_config"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Document, on=(doc_id == Document.id))
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == task_id)
)
docs = list(docs.dicts())
if not docs:
return None
msg = f"\n{datetime.now().strftime('%H:%M:%S')} Task has been received."
prog = random.random() / 10.0
if docs[0]["retry_count"] >= 3:
msg = "\nERROR: Task is abandoned after 3 times attempts."
prog = -1
cls.model.update(
progress_msg=cls.model.progress_msg + msg,
progress=prog,
retry_count=docs[0]["retry_count"] + 1,
).where(cls.model.id == docs[0]["id"]).execute()
if docs[0]["retry_count"] >= 3:
return None
return docs[0]
@classmethod
@DB.connection_context()
def get_tasks(cls, doc_id: str):
"""Retrieve all tasks associated with a document.
This method fetches all processing tasks for a given document, ordered by page
number and creation time. It includes task progress and chunk information.
Args:
doc_id (str): The unique identifier of the document.
Returns:
list[dict]: List of task dictionaries containing task details.
Returns None if no tasks are found.
"""
fields = [
cls.model.id,
cls.model.from_page,
cls.model.progress,
cls.model.digest,
cls.model.chunk_ids,
]
tasks = (
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
.where(cls.model.doc_id == doc_id)
)
tasks = list(tasks.dicts())
if not tasks:
return None
return tasks
@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
"""Update the chunk IDs associated with a task.
This method updates the chunk_ids field of a task, which stores the IDs of
processed document chunks in a space-separated string format.
Args:
id (str): The unique identifier of the task.
chunk_ids (str): Space-separated string of chunk identifiers.
"""
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
"""Get names of documents that are currently being processed.
This method retrieves information about documents that are in the processing state,
including their locations and associated IDs. It uses database locking to ensure
thread safety when accessing the task information.
Returns:
list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
for documents currently being processed. Returns empty list if
no documents are being processed.
"""
with DB.lock("get_task", -1):
docs = (
cls.model.select(
*[Document.id, Document.kb_id, Document.location, File.parent_id]
)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(
File2Document,
on=(File2Document.document_id == Document.id),
join_type=JOIN.LEFT_OUTER,
)
.join(
File,
on=(File2Document.file_id == File.id),
join_type=JOIN.LEFT_OUTER,
)
.where(
Document.status == StatusEnum.VALID.value,
Document.run == TaskStatus.RUNNING.value,
~(Document.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.create_time >= current_timestamp() - 1000 * 600,
)
)
docs = list(docs.dicts())
if not docs:
return []
return list(
set(
[
(
d["parent_id"] if d["parent_id"] else d["kb_id"],
d["location"],
)
for d in docs
]
)
)
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
"""Check if a task should be cancelled based on its document status.
This method determines whether a task should be cancelled by checking the
associated document's run status and progress. A task should be cancelled
if its document is marked for cancellation or has negative progress.
Args:
id (str): The unique identifier of the task to check.
Returns:
bool: True if the task should be cancelled, False otherwise.
"""
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
"""Update the progress information for a task.
This method updates both the progress message and completion percentage of a task.
It handles platform-specific behavior (macOS vs others) and uses database locking
when necessary to ensure thread safety.
Update Rules:
- progress_msg: Always appends the new message to the existing one, and trims the result to max 3000 lines.
- progress: Only updates if the current progress is not -1 AND
(the new progress is -1 OR greater than the existing progress),
to avoid overwriting valid progress with invalid or regressive values.
Args:
id (str): The unique identifier of the task to update.
info (dict): Dictionary containing progress information with keys:
- progress_msg (str, optional): Progress message to append
- progress (float, optional): Progress percentage (0.0 to 1.0)
"""
task = cls.model.get_by_id(id)
if not task:
logging.warning("Update_progress error: task not found")
return
if os.environ.get("MACOS"):
if info["progress_msg"]:
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
prog = info["progress"]
cls.model.update(progress=prog).where(
(cls.model.id == id) &
(
(cls.model.progress != -1) &
((prog == -1) | (prog > cls.model.progress))
)
).execute()
else:
with DB.lock("update_progress", -1):
if info["progress_msg"]:
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
prog = info["progress"]
cls.model.update(progress=prog).where(
(cls.model.id == id) &
(
(cls.model.progress != -1) &
((prog == -1) | (prog > cls.model.progress))
)
).execute()
process_duration = (datetime.now() - task.begin_at).total_seconds()
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
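    # Effect of the conditional update above (values are illustrative):
    #   stored 0.4, new 0.6 -> becomes 0.6
    #   stored 0.6, new 0.4 -> stays 0.6 (regressive updates are ignored)
    #   stored 0.6, new -1  -> becomes -1 (failure is always recorded)
    #   stored -1,  new 0.9 -> stays -1  (failed tasks are not resurrected)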
@classmethod
@DB.connection_context()
def delete_by_doc_ids(cls, doc_ids):
"""Delete task associated with a document."""
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
"""Create and queue document processing tasks.
This function creates processing tasks for a document based on its type and configuration.
It handles different document types (PDF, Excel, etc.) differently and manages task
chunking and configuration. It also implements task reuse optimization by checking
for previously completed tasks.
Args:
doc (dict): Document dictionary containing metadata and configuration.
bucket (str): Storage bucket name where the document is stored.
name (str): File name of the document.
        priority (int): Priority level for task queueing.
Note:
- For PDF documents, tasks are created per page range based on configuration
- For Excel documents, tasks are created per row range
- Task digests are calculated for optimization and reuse
- Previous task chunks may be reused if available
"""
def new_task():
return {
"id": get_uuid(),
"doc_id": doc["id"],
"progress": 0.0,
"from_page": 0,
"to_page": 100000000,
"begin_at": datetime.now(),
}
parse_task_array = []
if doc["type"] == FileType.PDF.value:
file_bin = STORAGE_IMPL.get(bucket, name)
do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
pages = PdfParser.total_page_number(doc["name"], file_bin)
if pages is None:
pages = 0
page_size = doc["parser_config"].get("task_page_size") or 12
if doc["parser_id"] == "paper":
page_size = doc["parser_config"].get("task_page_size") or 22
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc", True):
page_size = 10 ** 9
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
for s, e in page_ranges:
s -= 1
s = max(0, s)
e = min(e - 1, pages)
for p in range(s, e, page_size):
task = new_task()
task["from_page"] = p
task["to_page"] = min(p + page_size, e)
parse_task_array.append(task)
elif doc["parser_id"] == "table":
file_bin = STORAGE_IMPL.get(bucket, name)
rn = RAGFlowExcelParser.row_number(doc["name"], file_bin)
for i in range(0, rn, 3000):
task = new_task()
task["from_page"] = i
task["to_page"] = min(i + 3000, rn)
parse_task_array.append(task)
else:
parse_task_array.append(new_task())
chunking_config = DocumentService.get_chunking_config(doc["id"])
for task in parse_task_array:
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
if field == "parser_config":
for k in ["raptor", "graphrag"]:
if k in chunking_config[field]:
del chunking_config[field][k]
hasher.update(str(chunking_config[field]).encode("utf-8"))
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
task_digest = hasher.hexdigest()
task["digest"] = task_digest
task["progress"] = 0.0
task["priority"] = priority
prev_tasks = TaskService.get_tasks(doc["id"])
ck_num = 0
if prev_tasks:
for task in parse_task_array:
ck_num += reuse_prev_task_chunks(task, prev_tasks, chunking_config)
TaskService.filter_delete([Task.doc_id == doc["id"]])
pre_chunk_ids = []
for pre_task in prev_tasks:
if pre_task["chunk_ids"]:
pre_chunk_ids.extend(pre_task["chunk_ids"].split())
if pre_chunk_ids:
settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
chunking_config["kb_id"])
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
bulk_insert_into_db(Task, parse_task_array, True)
DocumentService.begin2parse(doc["id"])
unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
for unfinished_task in unfinished_task_array:
assert REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=unfinished_task
), "Can't access Redis. Please check the Redis' status."
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
"""Attempt to reuse chunks from previous tasks for optimization.
This function checks if chunks from previously completed tasks can be reused for
the current task, which can significantly improve processing efficiency. It matches
tasks based on page ranges and configuration digests.
Args:
task (dict): Current task dictionary to potentially reuse chunks for.
prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
chunking_config (dict): Configuration dictionary for chunk processing.
Returns:
int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
Note:
Chunks can only be reused if:
- A previous task exists with matching page range and configuration digest
- The previous task was completed successfully (progress = 1.0)
- The previous task has valid chunk IDs
"""
idx = 0
while idx < len(prev_tasks):
prev_task = prev_tasks[idx]
if prev_task.get("from_page", 0) == task.get("from_page", 0) \
and prev_task.get("digest", 0) == task.get("digest", ""):
break
idx += 1
if idx >= len(prev_tasks):
return 0
prev_task = prev_tasks[idx]
if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
return 0
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
else:
task["progress_msg"] = ""
task["progress_msg"] = " ".join(
[datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
prev_task["chunk_ids"] = ""
return len(task["chunk_ids"].split())
def cancel_all_task_of(doc_id):
for t in TaskService.query(doc_id=doc_id):
try:
REDIS_CONN.set(f"{t.id}-cancel", "x")
except Exception as e:
logging.exception(e)
def has_canceled(task_id):
try:
if REDIS_CONN.get(f"{task_id}-cancel"):
return True
except Exception as e:
logging.exception(e)
return False
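# Cancellation protocol sketch (illustrative): the API side sets a Redis flag and the task
# consumer polls it between processing steps.
#   cancel_all_task_of(doc_id)     # sets "<task_id>-cancel" = "x" for every task of the document
#   ...
#   if has_canceled(task_id):      # checked periodically by the worker
#       return                     # stop processing this task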
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
task = dict(
id=task_id,
doc_id=doc_id,
from_page=0,
to_page=100000000,
task_type="dataflow" if not rerun else "dataflow_rerun",
priority=priority,
begin_at=datetime.now(),
)
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
DocumentService.begin2parse(doc_id)
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
task["tenant_id"] = tenant_id
task["dataflow_id"] = flow_id
task["file"] = file
if not REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=task
):
return False, "Can't access Redis. Please check the Redis' status."
return True, ""

View File

@@ -0,0 +1,257 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from langfuse import Langfuse
from api import settings
from api.db import LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
class LLMFactoriesService(CommonService):
model = LLMFactories
class TenantLLMService(CommonService):
model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
if not fid:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
else:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
if (not objs) and fid:
if fid == "LocalAI":
mdlnm += "___LocalAI"
elif fid == "HuggingFace":
mdlnm += "___HuggingFace"
elif fid == "OpenAI-API-Compatible":
mdlnm += "___OpenAI-API"
elif fid == "VLLM":
mdlnm += "___VLLM"
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
if not objs:
return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@staticmethod
def split_model_name_and_factory(model_name):
arr = model_name.split("@")
if len(arr) < 2:
return model_name, None
if len(arr) > 2:
return "@".join(arr[0:-1]), arr[-1]
        # at this point the name has exactly one '@' (xxx@yyy); verify that the suffix is a known factory
try:
model_factories = settings.FACTORY_LLM_INFOS
model_providers = set([f["name"] for f in model_factories])
if arr[-1] not in model_providers:
return model_name, None
return arr[0], arr[-1]
except Exception as e:
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
return model_name, None
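    # Examples of split_model_name_and_factory (factory names are only returned when they are
    # present in settings.FACTORY_LLM_INFOS; the names below are illustrative):
    #   "qwen-max@Tongyi-Qianwen" -> ("qwen-max", "Tongyi-Qianwen")
    #   "text-embedding-3-small"  -> ("text-embedding-3-small", None)
    #   "a@b@SomeFactory"         -> ("a@b", "SomeFactory")  # more than one '@': suffix not validated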
@classmethod
@DB.connection_context()
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
from api.db.services.llm_service import LLMService
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id if not llm_name else llm_name
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id if not llm_name else llm_name
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
mdlnm = tenant.rerank_id if not llm_name else llm_name
elif llm_type == LLMType.TTS:
mdlnm = tenant.tts_id if not llm_name else llm_name
else:
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if not model_config: # fallback: in some cases the stored factory id does not match
model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config:
model_config = model_config.to_dict()
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if not llm and fid: # fallback: in some cases the stored factory id does not match
llm = LLMService.query(llm_name=mdlnm)
if llm:
model_config["is_tools"] = llm[0].is_tools
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
raise LookupError("Model({}) not authorized".format(mdlnm))
return model_config
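# Hedged usage sketch for get_model_config: the returned dict carries at least
# "llm_factory", "llm_name", "api_key" and "api_base" (plus "is_tools" when the
# matching LLM record is found), and a LookupError is raised when nothing is
# configured or authorized for the tenant:
#
#     cfg = TenantLLMService.get_model_config(tenant_id, LLMType.CHAT.value)
#     chat_factory = cfg["llm_factory"]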
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
kwargs.update({"provider": model_config["llm_factory"]})
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
if llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return
return TTSModel[model_config["llm_factory"]](
model_config["api_key"],
model_config["llm_name"],
base_url=model_config["api_base"],
)
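# Illustrative call (sketch only): model_instance builds the provider-specific
# client for the requested type and returns None when the configured factory is
# not registered in the corresponding registry (EmbeddingModel, ChatModel, ...):
#
#     embd_mdl = TenantLLMService.model_instance(tenant_id, LLMType.EMBEDDING.value)
#     if embd_mdl is None:
#         raise LookupError("Embedding factory not supported")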
@classmethod
@DB.connection_context()
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
logging.error(f"Tenant not found: {tenant_id}")
return 0
llm_map = {
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
LLMType.SPEECH2TEXT.value: tenant.asr_id,
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
}
mdlnm = llm_map.get(llm_type)
if mdlnm is None:
logging.error(f"LLM type error: {llm_type}")
return 0
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
try:
num = (
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
.execute()
)
except Exception:
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
return 0
return num
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
@staticmethod
def llm_id2llm_type(llm_id: str) -> str | None:
from api.db.services.llm_service import LLMService
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].split(",")[-1]
for llm in LLMService.query(llm_name=llm_id):
return llm.model_type
llm = TenantLLMService.get_or_none(llm_name=llm_id)
if llm:
return llm.model_type
for llm in TenantLLMService.query(llm_name=llm_id):
return llm.model_type
class LLM4Tenant:
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.max_length = model_config.get("max_tokens", 8192)
self.is_tools = model_config.get("is_tools", False)
self.verbose_tool_use = kwargs.get("verbose_tool_use")
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
self.langfuse = None
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
if langfuse.auth_check():
self.langfuse = langfuse
trace_id = self.langfuse.create_trace_id()
self.trace_context = {"trace_id": trace_id}
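# LLM4Tenant is a thin wrapper that resolves the tenant's model once and keeps
# the pieces callers usually need; a minimal, comment-only sketch (what you do
# with self.mdl depends on the concrete model class):
#
#     bot = LLM4Tenant(tenant_id, LLMType.CHAT.value)
#     # bot.mdl        -> provider-specific model instance
#     # bot.max_length -> token budget from the model config (default 8192)
#     # bot.langfuse   -> Langfuse client when tracing keys are configured, else None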

View File

@@ -0,0 +1,63 @@
from api.db.db_models import UserCanvasVersion, DB
from api.db.services.common_service import CommonService
from peewee import DoesNotExist
class UserCanvasVersionService(CommonService):
model = UserCanvasVersion
@classmethod
@DB.connection_context()
def list_by_canvas_id(cls, user_canvas_id):
try:
user_canvas_version = cls.model.select(
*[cls.model.id,
cls.model.create_time,
cls.model.title,
cls.model.create_date,
cls.model.update_date,
cls.model.user_canvas_id,
cls.model.update_time]
).where(cls.model.user_canvas_id == user_canvas_id)
return user_canvas_version
except DoesNotExist:
return None
except Exception:
return None
@classmethod
@DB.connection_context()
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
fields = [cls.model.id]
# order_by returns a new query, so it must be chained/reassigned for the ordering to take effect
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids)).order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
version_batch = versions.offset(offset).limit(limit)
_temp = list(version_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
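# Note: results are pulled in fixed offset/limit batches of 100 until an empty
# batch comes back, so large canvases never load every version row in one query.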
@classmethod
@DB.connection_context()
def delete_all_versions(cls, user_canvas_id):
try:
user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(cls.model.create_time.desc())
total = user_canvas_version.count()
if total > 20:
delete_ids = [user_canvas_version[i].id for i in range(20, total)]
cls.delete_by_ids(delete_ids)
return True
except DoesNotExist:
return None
except Exception:
return None
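# Despite its name, delete_all_versions is a retention pass: it keeps the 20
# most recent versions of a canvas (ordered by create_time) and removes the
# older ones via delete_by_ids, returning True on success and None on error.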

View File

@@ -0,0 +1,318 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
from datetime import datetime
import logging
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant
from api.db.services.common_service import CommonService
from api.utils import get_uuid, current_timestamp, datetime_format
from api.db import StatusEnum
from rag.settings import MINIO
class UserService(CommonService):
"""Service class for managing user-related database operations.
This class extends CommonService to provide specialized functionality for user management,
including authentication, user creation, updates, and deletions.
Attributes:
model: The User model class for database operations.
"""
model = User
@classmethod
@DB.connection_context()
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
if 'access_token' in kwargs:
access_token = kwargs['access_token']
# Reject empty, None, or whitespace-only access tokens
if not access_token or not str(access_token).strip():
logging.warning("UserService.query: Rejecting empty access_token query")
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
# Reject tokens that are too short (should be UUID, 32+ chars)
if len(str(access_token).strip()) < 32:
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
# Reject tokens that start with "INVALID_" (from logout)
if str(access_token).startswith("INVALID_"):
logging.warning("UserService.query: Rejecting invalidated access_token")
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
# Call parent query method for valid requests
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
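# The override above acts as a guard in front of access_token lookups; a hedged
# sketch of its effect (token values are illustrative only):
#
#     UserService.query(access_token="")           # empty result set
#     UserService.query(access_token="too-short")  # empty result set (< 32 chars)
#     UserService.query(access_token="INVALID_x")  # empty result set (logged-out token)
#     UserService.query(access_token=valid_token)  # falls through to CommonService.query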
@classmethod
@DB.connection_context()
def filter_by_id(cls, user_id):
"""Retrieve a user by their ID.
Args:
user_id: The unique identifier of the user.
Returns:
User object if found, None otherwise.
"""
try:
user = cls.model.select().where(cls.model.id == user_id).get()
return user
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def query_user(cls, email, password):
"""Authenticate a user with email and password.
Args:
email: User's email address.
password: User's password in plain text.
Returns:
User object if authentication successful, None otherwise.
"""
user = cls.model.select().where((cls.model.email == email),
(cls.model.status == StatusEnum.VALID.value)).first()
if user and check_password_hash(str(user.password), password):
return user
else:
return None
@classmethod
@DB.connection_context()
def query_user_by_email(cls, email):
users = cls.model.select().where((cls.model.email == email))
return list(users)
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
if "password" in kwargs:
kwargs["password"] = generate_password_hash(
str(kwargs["password"]))
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
obj = cls.model(**kwargs).save(force_insert=True)
return obj
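# Usage sketch for UserService.save (field names mirror the User model used
# elsewhere in this file): the password is hashed and create/update timestamps
# are stamped before the insert.
#
#     UserService.save(email="alice@example.com", nickname="alice", password="secret")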
@classmethod
@DB.connection_context()
def delete_user(cls, user_ids, update_user_dict):
with DB.atomic():
cls.model.update({"status": 0}).where(
cls.model.id.in_(user_ids)).execute()
@classmethod
@DB.connection_context()
def update_user(cls, user_id, user_dict):
with DB.atomic():
if user_dict:
user_dict["update_time"] = current_timestamp()
user_dict["update_date"] = datetime_format(datetime.now())
cls.model.update(user_dict).where(
cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def update_user_password(cls, user_id, new_password):
with DB.atomic():
update_dict = {
"password": generate_password_hash(str(new_password)),
"update_time": current_timestamp(),
"update_date": datetime_format(datetime.now())
}
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def is_admin(cls, user_id):
return cls.model.select().where(
cls.model.id == user_id,
cls.model.is_superuser == 1).count() > 0
@classmethod
@DB.connection_context()
def get_all_users(cls):
users = cls.model.select()
return list(users)
class TenantService(CommonService):
"""Service class for managing tenant-related database operations.
This class extends CommonService to provide functionality for tenant management,
including tenant information retrieval and credit management.
Attributes:
model: The Tenant model class for database operations.
"""
model = Tenant
@classmethod
@DB.connection_context()
def get_info_by(cls, user_id):
fields = [
cls.model.id.alias("tenant_id"),
cls.model.name,
cls.model.llm_id,
cls.model.embd_id,
cls.model.rerank_id,
cls.model.asr_id,
cls.model.img2txt_id,
cls.model.tts_id,
cls.model.parser_ids,
UserTenant.role]
return list(cls.model.select(*fields)
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.OWNER)))
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_joined_tenants_by_user_id(cls, user_id):
fields = [
cls.model.id.alias("tenant_id"),
cls.model.name,
cls.model.llm_id,
cls.model.embd_id,
cls.model.asr_id,
cls.model.img2txt_id,
UserTenant.role]
return list(cls.model.select(*fields)
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL)))
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def decrease(cls, user_id, num):
num = cls.model.update(credit=cls.model.credit - num).where(
cls.model.id == user_id).execute()
if num == 0:
raise LookupError("Tenant not found which is supposed to be there")
@classmethod
@DB.connection_context()
def user_gateway(cls, tenant_id):
hashobj = hashlib.sha256(tenant_id.encode("utf-8"))
return int(hashobj.hexdigest(), 16) % len(MINIO)
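# user_gateway deterministically maps a tenant to one of the len(MINIO)
# configured storage entries by hashing the tenant_id; a comment-only example
# (the concrete index depends on the MINIO configuration):
#
#     idx = TenantService.user_gateway(tenant_id)   # 0 <= idx < len(MINIO)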
class UserTenantService(CommonService):
"""Service class for managing user-tenant relationship operations.
This class extends CommonService to handle the many-to-many relationship
between users and tenants, managing user roles and tenant memberships.
Attributes:
model: The UserTenant model class for database operations.
"""
model = UserTenant
@classmethod
@DB.connection_context()
def filter_by_id(cls, user_tenant_id):
try:
user_tenant = cls.model.select().where((cls.model.id == user_tenant_id) & (cls.model.status == StatusEnum.VALID.value)).get()
return user_tenant
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
obj = cls.model(**kwargs).save(force_insert=True)
return obj
@classmethod
@DB.connection_context()
def get_by_tenant_id(cls, tenant_id):
fields = [
cls.model.id,
cls.model.user_id,
cls.model.status,
cls.model.role,
User.nickname,
User.email,
User.avatar,
User.is_authenticated,
User.is_active,
User.is_anonymous,
User.status,
User.update_date,
User.is_superuser]
return list(cls.model.select(*fields)
.join(User, on=((cls.model.user_id == User.id) & (cls.model.status == StatusEnum.VALID.value) & (cls.model.role != UserTenantRole.OWNER)))
.where(cls.model.tenant_id == tenant_id)
.dicts())
@classmethod
@DB.connection_context()
def get_tenants_by_user_id(cls, user_id):
fields = [
cls.model.tenant_id,
cls.model.role,
User.nickname,
User.email,
User.avatar,
User.update_date
]
return list(cls.model.select(*fields)
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_user_tenant_relation_by_user_id(cls, user_id):
fields = [
cls.model.id,
cls.model.user_id,
cls.model.tenant_id,
cls.model.role
]
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts())
@classmethod
@DB.connection_context()
def get_num_members(cls, user_id: str):
cnt_members = cls.model.select(peewee.fn.COUNT(cls.model.id)).where(cls.model.tenant_id == user_id).scalar()
return cnt_members
@classmethod
@DB.connection_context()
def filter_by_tenant_and_user_id(cls, tenant_id, user_id):
try:
user_tenant = cls.model.select().where(
(cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) &
(cls.model.user_id == user_id)
).first()
return user_tenant
except peewee.DoesNotExist:
return None