Files
TERES_fastapi_backend/api/apps/dialog_app.py
2025-11-06 17:15:46 +08:00

250 lines
9.3 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from fastapi import APIRouter, Depends, Query
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.dialog_models import (
SetDialogRequest,
ListDialogsNextQuery,
ListDialogsNextBody,
DeleteDialogRequest,
)
from api.db.services import duplicate_name
from api.db.services.dialog_service import DialogService
from api.db import StatusEnum
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
# 创建路由器
router = APIRouter()
@router.post('/set')
async def set_dialog(
request: SetDialogRequest,
current_user = Depends(get_current_user)
):
"""设置/创建对话框"""
req = request.model_dump(exclude_unset=True)
dialog_id = req.get("dialog_id", "")
is_create = not dialog_id
name = req.get("name", "New Dialog")
if not isinstance(name, str):
return get_data_error_result(message="Dialog name must be string.")
if name.strip() == "":
return get_data_error_result(message="Dialog name can't be empty.")
if len(name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
name = name.strip()
name = duplicate_name(
DialogService.query,
name=name,
tenant_id=current_user.id,
status=StatusEnum.VALID.value)
description = req.get("description", "A helpful dialog")
icon = req.get("icon", "")
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
if not rerank_id:
req["rerank_id"] = ""
similarity_threshold = req.get("similarity_threshold", 0.1)
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
llm_setting = req.get("llm_setting", {})
meta_data_filter = req.get("meta_data_filter", {})
prompt_config = req["prompt_config"]
if not is_create:
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no knowledge base / Tavily used here.")
for p in prompt_config["parameters"]:
if p["optional"]:
continue
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
return get_data_error_result(
message="Parameter '{}' is not used".format(p["key"]))
try:
e, tenant = TenantService.get_by_id(current_user.id)
if not e:
return get_data_error_result(message="Tenant not found!")
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids", []))
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = len(set(embd_ids))
if embd_count > 1:
return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')
llm_id = req.get("llm_id", tenant.llm_id)
if not dialog_id:
dia = {
"id": get_uuid(),
"tenant_id": current_user.id,
"name": name,
"kb_ids": req.get("kb_ids", []),
"description": description,
"llm_id": llm_id,
"llm_setting": llm_setting,
"prompt_config": prompt_config,
"meta_data_filter": meta_data_filter,
"top_n": top_n,
"top_k": top_k,
"rerank_id": rerank_id,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight,
"icon": icon
}
if not DialogService.save(**dia):
return get_data_error_result(message="Fail to new a dialog!")
return get_json_result(data=dia)
else:
del req["dialog_id"]
if "kb_names" in req:
del req["kb_names"]
if not DialogService.update_by_id(dialog_id, req):
return get_data_error_result(message="Dialog not found!")
e, dia = DialogService.get_by_id(dialog_id)
if not e:
return get_data_error_result(message="Fail to update a dialog!")
dia = dia.to_dict()
dia.update(req)
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
return get_json_result(data=dia)
except Exception as e:
return server_error_response(e)
@router.get('/get')
async def get(
dialog_id: str = Query(..., description="对话框ID"),
current_user = Depends(get_current_user)
):
"""获取对话框详情"""
try:
e, dia = DialogService.get_by_id(dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
dia = dia.to_dict()
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
return get_json_result(data=dia)
except Exception as e:
return server_error_response(e)
def get_kb_names(kb_ids):
ids, nms = [], []
for kid in kb_ids:
e, kb = KnowledgebaseService.get_by_id(kid)
if not e or kb.status != StatusEnum.VALID.value:
continue
ids.append(kid)
nms.append(kb.name)
return ids, nms
@router.get('/list')
async def list_dialogs(
current_user = Depends(get_current_user)
):
"""列出对话框"""
try:
diags = DialogService.query(
tenant_id=current_user.id,
status=StatusEnum.VALID.value,
reverse=True,
order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
return get_json_result(data=diags)
except Exception as e:
return server_error_response(e)
@router.post('/next')
async def list_dialogs_next(
query: ListDialogsNextQuery = Depends(),
body: Optional[ListDialogsNextBody] = None,
current_user = Depends(get_current_user)
):
"""列出对话框(分页)"""
if body is None:
body = ListDialogsNextBody()
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
parser_id = query.parser_id
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
owner_ids = body.owner_ids or []
try:
if not owner_ids:
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
# tenants = [tenant["tenant_id"] for tenant in tenants]
tenants = [] # keep it here
dialogs, total = DialogService.get_by_tenant_ids(
tenants, current_user.id, page_number,
items_per_page, orderby, desc, keywords, parser_id)
else:
tenants = owner_ids
dialogs, total = DialogService.get_by_tenant_ids(
tenants, current_user.id, 0,
0, orderby, desc, keywords, parser_id)
dialogs = [dialog for dialog in dialogs if dialog["tenant_id"] in tenants]
total = len(dialogs)
if page_number and items_per_page:
dialogs = dialogs[(page_number-1)*items_per_page:page_number*items_per_page]
return get_json_result(data={"dialogs": dialogs, "total": total})
except Exception as e:
return server_error_response(e)
@router.post('/rm')
async def rm(
request: DeleteDialogRequest,
current_user = Depends(get_current_user)
):
"""删除对话框"""
dialog_list = []
tenants = UserTenantService.query(user_id=current_user.id)
try:
for id in request.dialog_ids:
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=id):
break
else:
return get_json_result(
data=False, message='Only owner of dialog authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
dialog_list.append({"id": id, "status": StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)