修复 画布,mcp服务,搜索,文档的接口

This commit is contained in:
2025-11-07 09:34:35 +08:00
parent c5f8fe06e7
commit 54532747d2
21 changed files with 1023 additions and 272 deletions

View File

@@ -17,8 +17,10 @@ import json
import re
import logging
from copy import deepcopy
from flask import Response, request
from flask_login import current_user, login_required
from typing import Optional
from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
from fastapi.responses import StreamingResponse
from api import settings
from api.db import LLMType
from api.db.db_models import APIToken
@@ -28,15 +30,35 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from api.utils import get_uuid
from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.conversation_models import (
SetConversationRequest,
DeleteConversationsRequest,
CompletionRequest,
TTSRequest,
DeleteMessageRequest,
ThumbupRequest,
AskRequest,
MindmapRequest,
RelatedQuestionsRequest,
)
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_conversation():
req = request.json
# 创建路由器
router = APIRouter()
@router.post('/set')
async def set_conversation(
request: SetConversationRequest,
current_user = Depends(get_current_user)
):
"""设置对话"""
req = request.model_dump(exclude_unset=True)
conv_id = req.get("conversation_id")
is_new = req.get("is_new")
name = req.get("name", "New conversation")
@@ -45,9 +67,9 @@ def set_conversation():
if len(name) > 255:
name = name[0:255]
del req["is_new"]
if not is_new:
del req["conversation_id"]
if not conv_id:
return get_data_error_result(message="conversation_id is required when is_new is False!")
try:
if not ConversationService.update_by_id(conv_id, req):
return get_data_error_result(message="Conversation not found!")
@@ -64,7 +86,7 @@ def set_conversation():
if not e:
return get_data_error_result(message="Dialog not found")
conv = {
"id": conv_id,
"id": conv_id or get_uuid(),
"dialog_id": req["dialog_id"],
"name": name,
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
@@ -77,12 +99,14 @@ def set_conversation():
return server_error_response(e)
@manager.route("/get", methods=["GET"]) # noqa: F821
@login_required
def get():
conv_id = request.args["conversation_id"]
@router.get('/get')
async def get(
conversation_id: str = Query(..., description="对话ID"),
current_user = Depends(get_current_user)
):
"""获取对话"""
try:
e, conv = ConversationService.get_by_id(conv_id)
e, conv = ConversationService.get_by_id(conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
@@ -107,15 +131,27 @@ def get():
return server_error_response(e)
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
def getsse(dialog_id):
token = request.headers.get("Authorization").split()
if len(token) != 2:
@router.get('/getsse/{dialog_id}')
async def getsse(
dialog_id: str,
authorization: Optional[str] = Header(None, alias="Authorization")
):
"""通过 SSE 获取对话(使用 API token 认证)"""
if not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header is required"
)
token_parts = authorization.split()
if len(token_parts) != 2:
return get_data_error_result(message='Authorization is not valid!"')
token = token[1]
token = token_parts[1]
objs = APIToken.query(beta=token)
if not objs:
return get_data_error_result(message='Authentication error: API key is invalid!"')
try:
e, conv = DialogService.get_by_id(dialog_id)
if not e:
@@ -128,12 +164,14 @@ def getsse(dialog_id):
return server_error_response(e)
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
def rm():
conv_ids = request.json["conversation_ids"]
@router.post('/rm')
async def rm(
request: DeleteConversationsRequest,
current_user = Depends(get_current_user)
):
"""删除对话"""
try:
for cid in conv_ids:
for cid in request.conversation_ids:
exist, conv = ConversationService.get_by_id(cid)
if not exist:
return get_data_error_result(message="Conversation not found!")
@@ -149,10 +187,12 @@ def rm():
return server_error_response(e)
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def list_conversation():
dialog_id = request.args["dialog_id"]
@router.get('/list')
async def list_conversation(
dialog_id: str = Query(..., description="对话ID"),
current_user = Depends(get_current_user)
):
"""列出对话"""
try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
@@ -164,11 +204,13 @@ def list_conversation():
return server_error_response(e)
@manager.route("/completion", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
req = request.json
@router.post('/completion')
async def completion(
request: CompletionRequest,
current_user = Depends(get_current_user)
):
"""完成请求(聊天完成)"""
req = request.model_dump(exclude_unset=True)
msg = []
for m in req["messages"]:
if m["role"] == "system":
@@ -176,6 +218,10 @@ def completion():
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg:
return get_data_error_result(message="No valid messages found!")
message_id = msg[-1].get("id")
chat_model_id = req.get("llm_id", "")
req.pop("llm_id", None)
@@ -217,6 +263,7 @@ def completion():
dia.llm_setting = chat_model_config
is_embedded = bool(chat_model_id)
def stream():
nonlocal dia, msg, req, conv
try:
@@ -230,14 +277,18 @@ def completion():
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
stream_enabled = request.stream if request.stream is not None else True
if stream_enabled:
return StreamingResponse(
stream(),
media_type="text/event-stream",
headers={
"Cache-control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Type": "text/event-stream; charset=utf-8"
}
)
else:
answer = None
for ans in chat(dia, msg, **req):
@@ -250,11 +301,13 @@ def completion():
return server_error_response(e)
@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required
def tts():
req = request.json
text = req["text"]
@router.post('/tts')
async def tts(
request: TTSRequest,
current_user = Depends(get_current_user)
):
"""文本转语音"""
text = request.text
tenants = TenantService.get_info_by(current_user.id)
if not tenants:
@@ -274,28 +327,32 @@ def tts():
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
return StreamingResponse(
stream_audio(),
media_type="audio/mpeg",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def delete_msg():
req = request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
@router.post('/delete_msg')
async def delete_msg(
request: DeleteMessageRequest,
current_user = Depends(get_current_user)
):
"""删除消息"""
e, conv = ConversationService.get_by_id(request.conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
conv = conv.to_dict()
for i, msg in enumerate(conv["message"]):
if req["message_id"] != msg.get("id", ""):
if request.message_id != msg.get("id", ""):
continue
assert conv["message"][i + 1]["id"] == req["message_id"]
assert conv["message"][i + 1]["id"] == request.message_id
conv["message"].pop(i)
conv["message"].pop(i)
conv["reference"].pop(max(0, i // 2 - 1))
@@ -305,19 +362,21 @@ def delete_msg():
return get_json_result(data=conv)
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def thumbup():
req = request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
@router.post('/thumbup')
async def thumbup(
request: ThumbupRequest,
current_user = Depends(get_current_user)
):
"""点赞/点踩"""
e, conv = ConversationService.get_by_id(request.conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
up_down = req.get("thumbup")
feedback = req.get("feedback", "")
up_down = request.thumbup
feedback = request.feedback or ""
conv = conv.to_dict()
for i, msg in enumerate(conv["message"]):
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
if up_down:
msg["thumbup"] = True
if "feedback" in msg:
@@ -332,14 +391,15 @@ def thumbup():
return get_json_result(data=conv)
@manager.route("/ask", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def ask_about():
req = request.json
@router.post('/ask')
async def ask_about(
request: AskRequest,
current_user = Depends(get_current_user)
):
"""提问"""
uid = current_user.id
search_id = req.get("search_id", "")
search_id = request.search_id or ""
search_app = None
search_config = {}
if search_id:
@@ -348,53 +408,58 @@ def ask_about():
search_config = search_app.get("search_config", {})
def stream():
nonlocal req, uid
nonlocal request, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
return StreamingResponse(
stream(),
media_type="text/event-stream",
headers={
"Cache-control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Type": "text/event-stream; charset=utf-8"
}
)
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def mindmap():
req = request.json
search_id = req.get("search_id", "")
@router.post('/mindmap')
async def mindmap(
request: MindmapRequest,
current_user = Depends(get_current_user)
):
"""思维导图"""
search_id = request.search_id or ""
search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {}
kb_ids = search_config.get("kb_ids", [])
kb_ids.extend(req["kb_ids"])
kb_ids.extend(request.kb_ids)
kb_ids = list(set(kb_ids))
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
def related_questions():
req = request.json
search_id = req.get("search_id", "")
@router.post('/related_questions')
async def related_questions(
request: RelatedQuestionsRequest,
current_user = Depends(get_current_user)
):
"""相关问题"""
search_id = request.search_id or ""
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
question = req["question"]
question = request.question
chat_id = search_config.get("chat_id", "")
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)