TERES_fastapi_backend/api/apps/conversation_app.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import logging
from copy import deepcopy
from typing import Optional
from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
from fastapi.responses import StreamingResponse
from api import settings
from api.db import LLMType
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from api.utils import get_uuid
from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.conversation_models import (
    SetConversationRequest,
    DeleteConversationsRequest,
    CompletionRequest,
    TTSRequest,
    DeleteMessageRequest,
    ThumbupRequest,
    AskRequest,
    MindmapRequest,
    RelatedQuestionsRequest,
)

# Create the router
router = APIRouter()


@router.post('/set')
async def set_conversation(
    request: SetConversationRequest,
    current_user = Depends(get_current_user)
):
    """Create or update a conversation."""
    req = request.model_dump(exclude_unset=True)
    conv_id = req.get("conversation_id")
    is_new = req.get("is_new")
    name = req.get("name", "New conversation")
    req["user_id"] = current_user.id

    if len(name) > 255:
        name = name[0:255]

    if not is_new:
        if not conv_id:
            return get_data_error_result(message="conversation_id is required when is_new is False!")
        try:
            if not ConversationService.update_by_id(conv_id, req):
                return get_data_error_result(message="Conversation not found!")
            e, conv = ConversationService.get_by_id(conv_id)
            if not e:
                return get_data_error_result(message="Fail to update a conversation!")
            conv = conv.to_dict()
            return get_json_result(data=conv)
        except Exception as e:
            return server_error_response(e)

    try:
        e, dia = DialogService.get_by_id(req["dialog_id"])
        if not e:
            return get_data_error_result(message="Dialog not found")
        conv = {
            "id": conv_id or get_uuid(),
            "dialog_id": req["dialog_id"],
            "name": name,
            "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
            "user_id": current_user.id,
            "reference": [],
        }
        ConversationService.save(**conv)
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)
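
# Illustrative request bodies for POST /set (field names inferred from the
# req.get() calls above; the authoritative schema is SetConversationRequest):
#   create:  {"is_new": true,  "dialog_id": "<dialog-id>", "name": "New conversation"}
#   rename:  {"is_new": false, "conversation_id": "<conversation-id>", "name": "Renamed"}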


@router.get('/get')
async def get(
    conversation_id: str = Query(..., description="Conversation ID"),
    current_user = Depends(get_current_user)
):
    """Get a conversation."""
    try:
        e, conv = ConversationService.get_by_id(conversation_id)
        if not e:
            return get_data_error_result(message="Conversation not found!")

        tenants = UserTenantService.query(user_id=current_user.id)
        avatar = None
        for tenant in tenants:
            dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
            if dialog and len(dialog) > 0:
                avatar = dialog[0].icon
                break
        else:
            return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)

        for ref in conv.reference:
            if isinstance(ref, list):
                continue
            ref["chunks"] = chunks_format(ref)

        conv = conv.to_dict()
        conv["avatar"] = avatar
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@router.get('/getsse/{dialog_id}')
async def getsse(
    dialog_id: str,
    authorization: Optional[str] = Header(None, alias="Authorization")
):
    """Get a dialog for SSE access (authenticated with an API token)."""
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header is required"
        )

    token_parts = authorization.split()
    if len(token_parts) != 2:
        return get_data_error_result(message="Authorization is not valid!")
    token = token_parts[1]

    objs = APIToken.query(beta=token)
    if not objs:
        return get_data_error_result(message="Authentication error: API key is invalid!")

    try:
        e, conv = DialogService.get_by_id(dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        conv = conv.to_dict()
        conv["avatar"] = conv["icon"]
        del conv["icon"]
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)
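
# The /getsse/{dialog_id} endpoint expects a two-part Authorization header,
# e.g. "Bearer <beta-token>" (the scheme name itself is not checked); only the
# second part is used, and it must match an APIToken record's `beta` field.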


@router.post('/rm')
async def rm(
    request: DeleteConversationsRequest,
    current_user = Depends(get_current_user)
):
    """Delete conversations."""
    try:
        for cid in request.conversation_ids:
            exist, conv = ConversationService.get_by_id(cid)
            if not exist:
                return get_data_error_result(message="Conversation not found!")
            tenants = UserTenantService.query(user_id=current_user.id)
            for tenant in tenants:
                if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
                    break
            else:
                return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
            ConversationService.delete_by_id(cid)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)


@router.get('/list')
async def list_conversation(
    dialog_id: str = Query(..., description="Dialog ID"),
    current_user = Depends(get_current_user)
):
    """List the conversations of a dialog."""
    try:
        if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
            return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
        convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
        convs = [d.to_dict() for d in convs]
        return get_json_result(data=convs)
    except Exception as e:
        return server_error_response(e)


@router.post('/completion')
async def completion(
    request: CompletionRequest,
    current_user = Depends(get_current_user)
):
    """Chat completion (streaming by default)."""
    req = request.model_dump(exclude_unset=True)

    msg = []
    for m in req["messages"]:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append(m)
    if not msg:
        return get_data_error_result(message="No valid messages found!")
    message_id = msg[-1].get("id")

    chat_model_id = req.get("llm_id", "")
    req.pop("llm_id", None)

    chat_model_config = {}
    for model_config in [
        "temperature",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "max_tokens",
    ]:
        config = req.get(model_config)
        if config:
            chat_model_config[model_config] = config

    try:
        e, conv = ConversationService.get_by_id(req["conversation_id"])
        if not e:
            return get_data_error_result(message="Conversation not found!")
        conv.message = deepcopy(req["messages"])

        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        del req["conversation_id"]
        del req["messages"]

        if not conv.reference:
            conv.reference = []
        conv.reference = [r for r in conv.reference if r]
        conv.reference.append({"chunks": [], "doc_aggs": []})

        if chat_model_id:
            if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
                req.pop("chat_model_id", None)
                req.pop("chat_model_config", None)
                return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
            dia.llm_id = chat_model_id
            dia.llm_setting = chat_model_config

        is_embedded = bool(chat_model_id)

        def stream():
            nonlocal dia, msg, req, conv
            try:
                for ans in chat(dia, msg, True, **req):
                    ans = structure_answer(conv, ans, message_id, conv.id)
                    yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
                if not is_embedded:
                    ConversationService.update_by_id(conv.id, conv.to_dict())
            except Exception as e:
                logging.exception(e)
                yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

        stream_enabled = request.stream if request.stream is not None else True
        if stream_enabled:
            return StreamingResponse(
                stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                    "Content-Type": "text/event-stream; charset=utf-8"
                }
            )
        else:
            answer = None
            for ans in chat(dia, msg, **req):
                answer = structure_answer(conv, ans, message_id, conv.id)
                if not is_embedded:
                    ConversationService.update_by_id(conv.id, conv.to_dict())
                break
            return get_json_result(data=answer)
    except Exception as e:
        return server_error_response(e)
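
# Shape of the SSE frames emitted by the /completion stream above (illustrative):
#   data:{"code": 0, "message": "", "data": {...partial answer...}}\n\n
#   data:{"code": 0, "message": "", "data": true}\n\n        <- end-of-stream marker
#   data:{"code": 500, "message": "<error>", "data": {"answer": "**ERROR**: ...", "reference": []}}\n\n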


@router.post('/tts')
async def tts(
    request: TTSRequest,
    current_user = Depends(get_current_user)
):
    """Text to speech."""
    text = request.text

    tenants = TenantService.get_info_by(current_user.id)
    if not tenants:
        return get_data_error_result(message="Tenant not found!")

    tts_id = tenants[0]["tts_id"]
    if not tts_id:
        return get_data_error_result(message="No default TTS model is set")

    tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)

    def stream_audio():
        try:
            for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
                for chunk in tts_mdl.tts(txt):
                    yield chunk
        except Exception as e:
            yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")

    return StreamingResponse(
        stream_audio(),
        media_type="audio/mpeg",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )
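
# Note on /tts streaming: on the happy path the generator yields raw audio
# bytes (served as audio/mpeg); if synthesis fails it falls back to a single
# UTF-8-encoded "data:{...}" JSON error line, so clients should be prepared
# for either payload.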


@router.post('/delete_msg')
async def delete_msg(
    request: DeleteMessageRequest,
    current_user = Depends(get_current_user)
):
    """Delete a message (and its paired reply) from a conversation."""
    e, conv = ConversationService.get_by_id(request.conversation_id)
    if not e:
        return get_data_error_result(message="Conversation not found!")

    conv = conv.to_dict()
    for i, msg in enumerate(conv["message"]):
        if request.message_id != msg.get("id", ""):
            continue
        # A user message and its assistant reply share the same id: remove both
        # and drop the reference entry that belongs to that turn.
        assert conv["message"][i + 1]["id"] == request.message_id
        conv["message"].pop(i)
        conv["message"].pop(i)
        conv["reference"].pop(max(0, i // 2 - 1))
        break

    ConversationService.update_by_id(conv["id"], conv)
    return get_json_result(data=conv)


@router.post('/thumbup')
async def thumbup(
    request: ThumbupRequest,
    current_user = Depends(get_current_user)
):
    """Thumb up / thumb down an assistant message."""
    e, conv = ConversationService.get_by_id(request.conversation_id)
    if not e:
        return get_data_error_result(message="Conversation not found!")

    up_down = request.thumbup
    feedback = request.feedback or ""

    conv = conv.to_dict()
    for i, msg in enumerate(conv["message"]):
        if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
            if up_down:
                msg["thumbup"] = True
                if "feedback" in msg:
                    del msg["feedback"]
            else:
                msg["thumbup"] = False
                if feedback:
                    msg["feedback"] = feedback
            break

    ConversationService.update_by_id(conv["id"], conv)
    return get_json_result(data=conv)


@router.post('/ask')
async def ask_about(
    request: AskRequest,
    current_user = Depends(get_current_user)
):
    """Ask a question against the selected knowledge bases."""
    uid = current_user.id

    search_id = request.search_id or ""
    search_app = None
    search_config = {}
    if search_id:
        search_app = SearchService.get_detail(search_id)
        if search_app:
            search_config = search_app.get("search_config", {})

    def stream():
        nonlocal request, uid
        try:
            for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={
            "Cache-control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Content-Type": "text/event-stream; charset=utf-8"
        }
    )


@router.post('/mindmap')
async def mindmap(
    request: MindmapRequest,
    current_user = Depends(get_current_user)
):
    """Generate a mind map."""
    search_id = request.search_id or ""
    search_app = SearchService.get_detail(search_id) if search_id else {}
    search_config = search_app.get("search_config", {}) if search_app else {}

    kb_ids = search_config.get("kb_ids", [])
    kb_ids.extend(request.kb_ids)
    kb_ids = list(set(kb_ids))

    mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
    if "error" in mind_map:
        return server_error_response(Exception(mind_map["error"]))
    return get_json_result(data=mind_map)


@router.post('/related_questions')
async def related_questions(
    request: RelatedQuestionsRequest,
    current_user = Depends(get_current_user)
):
    """Generate related questions."""
    search_id = request.search_id or ""
    search_config = {}
    if search_id:
        if search_app := SearchService.get_detail(search_id):
            search_config = search_app.get("search_config", {})

    question = request.question
    chat_id = search_config.get("chat_id", "")
    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)

    gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
    if "parameter" in gen_conf:
        del gen_conf["parameter"]

    prompt = load_prompt("related_question")
    ans = chat_mdl.chat(
        prompt,
        [
            {
                "role": "user",
                "content": f"""
Keywords: {question}
Related search terms:
""",
            }
        ],
        gen_conf,
    )
    return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
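
# Illustrative: the related_question prompt is expected to yield a numbered
# list such as
#   1. first related question
#   2. second related question
# and the final list comprehension keeps only lines starting with a single
# digit followed by ". ", stripping that prefix before returning them.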