修复 画布,mcp服务,搜索,文档的接口
This commit is contained in:
@@ -17,8 +17,10 @@ import json
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import APIToken
|
||||
@@ -28,15 +30,35 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
|
||||
from api.utils import get_uuid
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.conversation_models import (
|
||||
SetConversationRequest,
|
||||
DeleteConversationsRequest,
|
||||
CompletionRequest,
|
||||
TTSRequest,
|
||||
DeleteMessageRequest,
|
||||
ThumbupRequest,
|
||||
AskRequest,
|
||||
MindmapRequest,
|
||||
RelatedQuestionsRequest,
|
||||
)
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_conversation():
|
||||
req = request.json
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/set')
|
||||
async def set_conversation(
|
||||
request: SetConversationRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""设置对话"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
conv_id = req.get("conversation_id")
|
||||
is_new = req.get("is_new")
|
||||
name = req.get("name", "New conversation")
|
||||
@@ -45,9 +67,9 @@ def set_conversation():
|
||||
if len(name) > 255:
|
||||
name = name[0:255]
|
||||
|
||||
del req["is_new"]
|
||||
if not is_new:
|
||||
del req["conversation_id"]
|
||||
if not conv_id:
|
||||
return get_data_error_result(message="conversation_id is required when is_new is False!")
|
||||
try:
|
||||
if not ConversationService.update_by_id(conv_id, req):
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@@ -64,7 +86,7 @@ def set_conversation():
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": conv_id,
|
||||
"id": conv_id or get_uuid(),
|
||||
"dialog_id": req["dialog_id"],
|
||||
"name": name,
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
|
||||
@@ -77,12 +99,14 @@ def set_conversation():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
conv_id = request.args["conversation_id"]
|
||||
@router.get('/get')
|
||||
async def get(
|
||||
conversation_id: str = Query(..., description="对话ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取对话"""
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
e, conv = ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
@@ -107,15 +131,27 @@ def get():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
|
||||
def getsse(dialog_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@router.get('/getsse/{dialog_id}')
|
||||
async def getsse(
|
||||
dialog_id: str,
|
||||
authorization: Optional[str] = Header(None, alias="Authorization")
|
||||
):
|
||||
"""通过 SSE 获取对话(使用 API token 认证)"""
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header is required"
|
||||
)
|
||||
|
||||
token_parts = authorization.split()
|
||||
if len(token_parts) != 2:
|
||||
return get_data_error_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
token = token_parts[1]
|
||||
|
||||
objs = APIToken.query(beta=token)
|
||||
if not objs:
|
||||
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
try:
|
||||
e, conv = DialogService.get_by_id(dialog_id)
|
||||
if not e:
|
||||
@@ -128,12 +164,14 @@ def getsse(dialog_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm():
|
||||
conv_ids = request.json["conversation_ids"]
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: DeleteConversationsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除对话"""
|
||||
try:
|
||||
for cid in conv_ids:
|
||||
for cid in request.conversation_ids:
|
||||
exist, conv = ConversationService.get_by_id(cid)
|
||||
if not exist:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@@ -149,10 +187,12 @@ def rm():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_conversation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
@router.get('/list')
|
||||
async def list_conversation(
|
||||
dialog_id: str = Query(..., description="对话ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出对话"""
|
||||
try:
|
||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
@@ -164,11 +204,13 @@ def list_conversation():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
@router.post('/completion')
|
||||
async def completion(
|
||||
request: CompletionRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""完成请求(聊天完成)"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
@@ -176,6 +218,10 @@ def completion():
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
|
||||
if not msg:
|
||||
return get_data_error_result(message="No valid messages found!")
|
||||
|
||||
message_id = msg[-1].get("id")
|
||||
chat_model_id = req.get("llm_id", "")
|
||||
req.pop("llm_id", None)
|
||||
@@ -217,6 +263,7 @@ def completion():
|
||||
dia.llm_setting = chat_model_config
|
||||
|
||||
is_embedded = bool(chat_model_id)
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
@@ -230,14 +277,18 @@ def completion():
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
stream_enabled = request.stream if request.stream is not None else True
|
||||
if stream_enabled:
|
||||
return StreamingResponse(
|
||||
stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Type": "text/event-stream; charset=utf-8"
|
||||
}
|
||||
)
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
@@ -250,11 +301,13 @@ def completion():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def tts():
|
||||
req = request.json
|
||||
text = req["text"]
|
||||
@router.post('/tts')
|
||||
async def tts(
|
||||
request: TTSRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""文本转语音"""
|
||||
text = request.text
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
@@ -274,28 +327,32 @@ def tts():
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
|
||||
return resp
|
||||
return StreamingResponse(
|
||||
stream_audio(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def delete_msg():
|
||||
req = request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
@router.post('/delete_msg')
|
||||
async def delete_msg(
|
||||
request: DeleteMessageRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除消息"""
|
||||
e, conv = ConversationService.get_by_id(request.conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
for i, msg in enumerate(conv["message"]):
|
||||
if req["message_id"] != msg.get("id", ""):
|
||||
if request.message_id != msg.get("id", ""):
|
||||
continue
|
||||
assert conv["message"][i + 1]["id"] == req["message_id"]
|
||||
assert conv["message"][i + 1]["id"] == request.message_id
|
||||
conv["message"].pop(i)
|
||||
conv["message"].pop(i)
|
||||
conv["reference"].pop(max(0, i // 2 - 1))
|
||||
@@ -305,19 +362,21 @@ def delete_msg():
|
||||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def thumbup():
|
||||
req = request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
@router.post('/thumbup')
|
||||
async def thumbup(
|
||||
request: ThumbupRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""点赞/点踩"""
|
||||
e, conv = ConversationService.get_by_id(request.conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
up_down = req.get("thumbup")
|
||||
feedback = req.get("feedback", "")
|
||||
|
||||
up_down = request.thumbup
|
||||
feedback = request.feedback or ""
|
||||
conv = conv.to_dict()
|
||||
for i, msg in enumerate(conv["message"]):
|
||||
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
|
||||
if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
|
||||
if up_down:
|
||||
msg["thumbup"] = True
|
||||
if "feedback" in msg:
|
||||
@@ -332,14 +391,15 @@ def thumbup():
|
||||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about():
|
||||
req = request.json
|
||||
@router.post('/ask')
|
||||
async def ask_about(
|
||||
request: AskRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""提问"""
|
||||
uid = current_user.id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_id = request.search_id or ""
|
||||
search_app = None
|
||||
search_config = {}
|
||||
if search_id:
|
||||
@@ -348,53 +408,58 @@ def ask_about():
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
def stream():
|
||||
nonlocal req, uid
|
||||
nonlocal request, uid
|
||||
try:
|
||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
return StreamingResponse(
|
||||
stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Type": "text/event-stream; charset=utf-8"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
req = request.json
|
||||
search_id = req.get("search_id", "")
|
||||
@router.post('/mindmap')
|
||||
async def mindmap(
|
||||
request: MindmapRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""思维导图"""
|
||||
search_id = request.search_id or ""
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||
kb_ids = search_config.get("kb_ids", [])
|
||||
kb_ids.extend(req["kb_ids"])
|
||||
kb_ids.extend(request.kb_ids)
|
||||
kb_ids = list(set(kb_ids))
|
||||
|
||||
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||
mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
||||
|
||||
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
def related_questions():
|
||||
req = request.json
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@router.post('/related_questions')
|
||||
async def related_questions(
|
||||
request: RelatedQuestionsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""相关问题"""
|
||||
search_id = request.search_id or ""
|
||||
search_config = {}
|
||||
if search_id:
|
||||
if search_app := SearchService.get_detail(search_id):
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
question = req["question"]
|
||||
question = request.question
|
||||
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
|
||||
|
||||
Reference in New Issue
Block a user