Fix the canvas, MCP service, search, and document APIs

2025-11-07 09:34:35 +08:00
parent c5f8fe06e7
commit 54532747d2
21 changed files with 1023 additions and 272 deletions

View File

@@ -165,17 +165,23 @@ def setup_routes(app: FastAPI):
from api.apps.tenant_app import router as tenant_router
from api.apps.dialog_app import router as dialog_router
from api.apps.system_app import router as system_router
from api.apps.search_app import router as search_router
from api.apps.conversation_app import router as conversation_router
from api.apps.file_app import router as file_router
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"])
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp_server", tags=["MCP"])
app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"])
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
app.include_router(dialog_router, prefix=f"/{API_VERSION}/dialog", tags=["Dialog"])
app.include_router(system_router, prefix=f"/{API_VERSION}/system", tags=["System"])
app.include_router(search_router, prefix=f"/{API_VERSION}/search", tags=["Search"])
app.include_router(conversation_router, prefix=f"/{API_VERSION}/conversation", tags=["Conversation"])
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
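
A minimal, self-contained sketch (not part of this commit) of how the include_router calls above compose the final URL paths; API_VERSION is assumed to be "v1" here, and the toy router stands in for the real ones.

from fastapi import APIRouter, FastAPI

API_VERSION = "v1"  # assumed value, for illustration only

search_router = APIRouter()

@search_router.get("/detail")
async def detail():
    # Stand-in handler; the real one lives in api.apps.search_app.
    return {"ok": True}

app = FastAPI()
app.include_router(search_router, prefix=f"/{API_VERSION}/search", tags=["Search"])
# The handler above is now served at GET /v1/search/detail.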

View File

@@ -17,8 +17,10 @@ import json
import re
import logging
from copy import deepcopy
from flask import Response, request
from flask_login import current_user, login_required
from typing import Optional
from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
from fastapi.responses import StreamingResponse
from api import settings
from api.db import LLMType
from api.db.db_models import APIToken
@@ -28,15 +30,35 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from api.utils import get_uuid
from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.conversation_models import (
SetConversationRequest,
DeleteConversationsRequest,
CompletionRequest,
TTSRequest,
DeleteMessageRequest,
ThumbupRequest,
AskRequest,
MindmapRequest,
RelatedQuestionsRequest,
)
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
def set_conversation():
req = request.json
# Create the router
router = APIRouter()
@router.post('/set')
async def set_conversation(
request: SetConversationRequest,
current_user = Depends(get_current_user)
):
"""设置对话"""
req = request.model_dump(exclude_unset=True)
conv_id = req.get("conversation_id")
is_new = req.get("is_new")
name = req.get("name", "New conversation")
@@ -45,9 +67,9 @@ def set_conversation():
if len(name) > 255:
name = name[0:255]
del req["is_new"]
if not is_new:
del req["conversation_id"]
if not conv_id:
return get_data_error_result(message="conversation_id is required when is_new is False!")
try:
if not ConversationService.update_by_id(conv_id, req):
return get_data_error_result(message="Conversation not found!")
@@ -64,7 +86,7 @@ def set_conversation():
if not e:
return get_data_error_result(message="Dialog not found")
conv = {
"id": conv_id,
"id": conv_id or get_uuid(),
"dialog_id": req["dialog_id"],
"name": name,
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
@@ -77,12 +99,14 @@ def set_conversation():
return server_error_response(e)
@manager.route("/get", methods=["GET"]) # noqa: F821
@login_required
def get():
conv_id = request.args["conversation_id"]
@router.get('/get')
async def get(
conversation_id: str = Query(..., description="Conversation ID"),
current_user = Depends(get_current_user)
):
"""获取对话"""
try:
e, conv = ConversationService.get_by_id(conv_id)
e, conv = ConversationService.get_by_id(conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
@@ -107,15 +131,27 @@ def get():
return server_error_response(e)
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
def getsse(dialog_id):
token = request.headers.get("Authorization").split()
if len(token) != 2:
@router.get('/getsse/{dialog_id}')
async def getsse(
dialog_id: str,
authorization: Optional[str] = Header(None, alias="Authorization")
):
"""通过 SSE 获取对话(使用 API token 认证)"""
if not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header is required"
)
token_parts = authorization.split()
if len(token_parts) != 2:
return get_data_error_result(message='Authorization is not valid!')
token = token[1]
token = token_parts[1]
objs = APIToken.query(beta=token)
if not objs:
return get_data_error_result(message='Authentication error: API key is invalid!')
try:
e, conv = DialogService.get_by_id(dialog_id)
if not e:
@@ -128,12 +164,14 @@ def getsse(dialog_id):
return server_error_response(e)
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
def rm():
conv_ids = request.json["conversation_ids"]
@router.post('/rm')
async def rm(
request: DeleteConversationsRequest,
current_user = Depends(get_current_user)
):
"""删除对话"""
try:
for cid in conv_ids:
for cid in request.conversation_ids:
exist, conv = ConversationService.get_by_id(cid)
if not exist:
return get_data_error_result(message="Conversation not found!")
@@ -149,10 +187,12 @@ def rm():
return server_error_response(e)
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def list_conversation():
dialog_id = request.args["dialog_id"]
@router.get('/list')
async def list_conversation(
dialog_id: str = Query(..., description="Dialog ID"),
current_user = Depends(get_current_user)
):
"""列出对话"""
try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
@@ -164,11 +204,13 @@ def list_conversation():
return server_error_response(e)
@manager.route("/completion", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
req = request.json
@router.post('/completion')
async def completion(
request: CompletionRequest,
current_user = Depends(get_current_user)
):
"""完成请求(聊天完成)"""
req = request.model_dump(exclude_unset=True)
msg = []
for m in req["messages"]:
if m["role"] == "system":
@@ -176,6 +218,10 @@ def completion():
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg:
return get_data_error_result(message="No valid messages found!")
message_id = msg[-1].get("id")
chat_model_id = req.get("llm_id", "")
req.pop("llm_id", None)
@@ -217,6 +263,7 @@ def completion():
dia.llm_setting = chat_model_config
is_embedded = bool(chat_model_id)
def stream():
nonlocal dia, msg, req, conv
try:
@@ -230,14 +277,18 @@ def completion():
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
stream_enabled = request.stream if request.stream is not None else True
if stream_enabled:
return StreamingResponse(
stream(),
media_type="text/event-stream",
headers={
"Cache-control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Type": "text/event-stream; charset=utf-8"
}
)
else:
answer = None
for ans in chat(dia, msg, **req):
@@ -250,11 +301,13 @@ def completion():
return server_error_response(e)
@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required
def tts():
req = request.json
text = req["text"]
@router.post('/tts')
async def tts(
request: TTSRequest,
current_user = Depends(get_current_user)
):
"""文本转语音"""
text = request.text
tenants = TenantService.get_info_by(current_user.id)
if not tenants:
@@ -274,28 +327,32 @@ def tts():
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
return StreamingResponse(
stream_audio(),
media_type="audio/mpeg",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def delete_msg():
req = request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
@router.post('/delete_msg')
async def delete_msg(
request: DeleteMessageRequest,
current_user = Depends(get_current_user)
):
"""删除消息"""
e, conv = ConversationService.get_by_id(request.conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
conv = conv.to_dict()
for i, msg in enumerate(conv["message"]):
if req["message_id"] != msg.get("id", ""):
if request.message_id != msg.get("id", ""):
continue
assert conv["message"][i + 1]["id"] == req["message_id"]
assert conv["message"][i + 1]["id"] == request.message_id
conv["message"].pop(i)
conv["message"].pop(i)
conv["reference"].pop(max(0, i // 2 - 1))
@@ -305,19 +362,21 @@ def delete_msg():
return get_json_result(data=conv)
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id", "message_id")
def thumbup():
req = request.json
e, conv = ConversationService.get_by_id(req["conversation_id"])
@router.post('/thumbup')
async def thumbup(
request: ThumbupRequest,
current_user = Depends(get_current_user)
):
"""点赞/点踩"""
e, conv = ConversationService.get_by_id(request.conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
up_down = req.get("thumbup")
feedback = req.get("feedback", "")
up_down = request.thumbup
feedback = request.feedback or ""
conv = conv.to_dict()
for i, msg in enumerate(conv["message"]):
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
if up_down:
msg["thumbup"] = True
if "feedback" in msg:
@@ -332,14 +391,15 @@ def thumbup():
return get_json_result(data=conv)
@manager.route("/ask", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def ask_about():
req = request.json
@router.post('/ask')
async def ask_about(
request: AskRequest,
current_user = Depends(get_current_user)
):
"""提问"""
uid = current_user.id
search_id = req.get("search_id", "")
search_id = request.search_id or ""
search_app = None
search_config = {}
if search_id:
@@ -348,53 +408,58 @@ def ask_about():
search_config = search_app.get("search_config", {})
def stream():
nonlocal req, uid
nonlocal request, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
return StreamingResponse(
stream(),
media_type="text/event-stream",
headers={
"Cache-control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Type": "text/event-stream; charset=utf-8"
}
)
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def mindmap():
req = request.json
search_id = req.get("search_id", "")
@router.post('/mindmap')
async def mindmap(
request: MindmapRequest,
current_user = Depends(get_current_user)
):
"""思维导图"""
search_id = request.search_id or ""
search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {}
kb_ids = search_config.get("kb_ids", [])
kb_ids.extend(req["kb_ids"])
kb_ids.extend(request.kb_ids)
kb_ids = list(set(kb_ids))
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
def related_questions():
req = request.json
search_id = req.get("search_id", "")
@router.post('/related_questions')
async def related_questions(
request: RelatedQuestionsRequest,
current_user = Depends(get_current_user)
):
"""相关问题"""
search_id = request.search_id or ""
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
question = req["question"]
question = request.question
chat_id = search_config.get("chat_id", "")
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
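
For reference, a hedged client-side sketch of the migrated completion endpoint; the host, port, and bearer token are assumptions, not part of this commit.

import httpx

payload = {
    "conversation_id": "abc123",  # hypothetical conversation ID
    "messages": [{"role": "user", "content": "Hello", "id": "m1"}],
    "stream": False,  # the non-streaming path returns a single JSON answer
}
resp = httpx.post(
    "http://localhost:9380/v1/conversation/completion",  # host/port assumed
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)
print(resp.json())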

View File

@@ -72,18 +72,17 @@ router = APIRouter()
@router.post("/upload")
async def upload(
kb_id: str = Form(...),
files: List[UploadFile] = File(...),
file: UploadFile = File(...),
current_user = Depends(get_current_user)
):
"""上传文档"""
if not files:
if not file:
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
for file_obj in files:
if not file_obj.filename or file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
if not file.filename or file.filename == "":
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
if len(file.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@@ -91,7 +90,7 @@ async def upload(
if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
err, uploaded_files = FileService.upload_document(kb, files, current_user.id)
err, uploaded_files = FileService.upload_document(kb, [file], current_user.id)
if err:
return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
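
The document upload route now accepts exactly one file per request instead of a list. A hedged usage sketch; the endpoint host and kb_id are placeholders.

import httpx

with open("report.pdf", "rb") as f:
    resp = httpx.post(
        "http://localhost:9380/v1/document/upload",  # host/port assumed
        data={"kb_id": "kb-001"},                    # hypothetical dataset ID
        files={"file": ("report.pdf", f, "application/pdf")},
        headers={"Authorization": "Bearer <token>"},
    )
print(resp.status_code, resp.json())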

View File

@@ -17,15 +17,23 @@ import logging
import os
import pathlib
import re
from typing import Optional, List
import flask
from flask import request
from flask_login import login_required, current_user
from fastapi import APIRouter, Depends, Query, UploadFile, File, Form
from fastapi.responses import Response
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.file_models import (
CreateFileRequest,
DeleteFilesRequest,
RenameFileRequest,
MoveFilesRequest,
)
from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils.api_utils import server_error_response, get_data_error_result
from api.utils import get_uuid
from api.db import FileType, FileSource
from api.db.services import duplicate_name
@@ -36,35 +44,41 @@ from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from rag.utils.storage_factory import STORAGE_IMPL
# Create the router
router = APIRouter()
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
def upload():
pf_id = request.form.get("parent_id")
@router.post('/upload')
async def upload(
files: List[UploadFile] = File(...),
parent_id: Optional[str] = Form(None),
current_user = Depends(get_current_user)
):
"""上传文件"""
pf_id = parent_id
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
if 'file' not in request.files:
if not files:
return get_json_result(
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
for file_obj in files:
if not file_obj.filename or file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
file_res = []
try:
e, pf_folder = FileService.get_by_id(pf_id)
if not e:
return get_data_error_result( message="Can't find this folder!")
for file_obj in file_objs:
return get_data_error_result(message="Can't find this folder!")
for file_obj in files:
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
return get_data_error_result( message="Exceed the maximum file number of a free user!")
return get_data_error_result(message="Exceed the maximum file number of a free user!")
# split file name path
if not file_obj.filename:
@@ -97,7 +111,7 @@ def upload():
location = file_obj_names[file_len - 1]
while STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_"
blob = file_obj.read()
blob = await file_obj.read()
filename = duplicate_name(
FileService.query,
name=file_obj_names[file_len - 1],
@@ -120,13 +134,16 @@ def upload():
return server_error_response(e)
@manager.route('/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.json
pf_id = request.json.get("parent_id")
input_file_type = request.json.get("type")
@router.post('/create')
async def create(
request: CreateFileRequest,
current_user = Depends(get_current_user)
):
"""创建文件/文件夹"""
req = request.model_dump(exclude_unset=True)
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@@ -160,17 +177,22 @@ def create():
return server_error_response(e)
@manager.route('/list', methods=['GET']) # noqa: F821
@login_required
def list_files():
pf_id = request.args.get("parent_id")
@router.get('/list')
async def list_files(
parent_id: Optional[str] = Query(None, description="Parent folder ID"),
keywords: Optional[str] = Query("", description="Search keywords"),
page: Optional[int] = Query(1, description="Page number"),
page_size: Optional[int] = Query(15, description="Items per page"),
orderby: Optional[str] = Query("create_time", description="Field to order by"),
desc: Optional[bool] = Query(True, description="Sort descending"),
current_user = Depends(get_current_user)
):
"""列出文件"""
pf_id = parent_id
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 15))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
page_number = int(page) if page else 1
items_per_page = int(page_size) if page_size else 15
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@@ -192,9 +214,11 @@ def list_files():
return server_error_response(e)
@manager.route('/root_folder', methods=['GET']) # noqa: F821
@login_required
def get_root_folder():
@router.get('/root_folder')
async def get_root_folder(
current_user = Depends(get_current_user)
):
"""获取根文件夹"""
try:
root_folder = FileService.get_root_folder(current_user.id)
return get_json_result(data={"root_folder": root_folder})
@@ -202,10 +226,12 @@ def get_root_folder():
return server_error_response(e)
@manager.route('/parent_folder', methods=['GET']) # noqa: F821
@login_required
def get_parent_folder():
file_id = request.args.get("file_id")
@router.get('/parent_folder')
async def get_parent_folder(
file_id: str = Query(..., description="File ID"),
current_user = Depends(get_current_user)
):
"""获取父文件夹"""
try:
e, file = FileService.get_by_id(file_id)
if not e:
@@ -217,10 +243,12 @@ def get_parent_folder():
return server_error_response(e)
@manager.route('/all_parent_folder', methods=['GET']) # noqa: F821
@login_required
def get_all_parent_folders():
file_id = request.args.get("file_id")
@router.get('/all_parent_folder')
async def get_all_parent_folders(
file_id: str = Query(..., description="File ID"),
current_user = Depends(get_current_user)
):
"""获取所有父文件夹"""
try:
e, file = FileService.get_by_id(file_id)
if not e:
@@ -235,12 +263,13 @@ def get_all_parent_folders():
return server_error_response(e)
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
@validate_request("file_ids")
def rm():
req = request.json
file_ids = req["file_ids"]
@router.post("/rm")
async def rm(
request: DeleteFilesRequest,
current_user = Depends(get_current_user)
):
"""删除文件"""
file_ids = request.file_ids
def _delete_single_file(file):
try:
@@ -296,11 +325,13 @@ def rm():
return server_error_response(e)
@manager.route('/rename', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_id", "name")
def rename():
req = request.json
@router.post('/rename')
async def rename(
request: RenameFileRequest,
current_user = Depends(get_current_user)
):
"""重命名文件"""
req = request.model_dump()
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@@ -314,8 +345,8 @@ def rename():
data=False,
message="The extension of file can't be changed",
code=settings.RetCode.ARGUMENT_ERROR)
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
if file.name == req["name"]:
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
if existing_file.name == req["name"]:
return get_data_error_result(
message="Duplicated file name in the same folder.")
@@ -336,9 +367,12 @@ def rename():
return server_error_response(e)
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
@login_required
def get(file_id):
@router.get('/get/{file_id}')
async def get(
file_id: str,
current_user = Depends(get_current_user)
):
"""获取文件内容"""
try:
e, file = FileService.get_by_id(file_id)
if not e:
@@ -351,25 +385,28 @@ def get(file_id):
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
ext = ext.group(1) if ext else None
content_type = "application/octet-stream"
if ext:
if file.type == FileType.VISUAL.value:
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
else:
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
response.headers.set("Content-Type", content_type)
return response
return Response(content=blob, media_type=content_type)
except Exception as e:
return server_error_response(e)
@manager.route("/mv", methods=["POST"]) # noqa: F821
@login_required
@validate_request("src_file_ids", "dest_file_id")
def move():
req = request.json
@router.post("/mv")
async def move(
request: MoveFilesRequest,
current_user = Depends(get_current_user)
):
"""移动文件"""
req = request.model_dump()
try:
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]

View File

@@ -169,14 +169,17 @@ async def update(
):
"""更新知识库"""
req = request.model_dump(exclude_unset=True)
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "":
return get_data_error_result(message="Dataset name can't be empty.")
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
# Validate the name field (if provided)
if "name" in req:
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "":
return get_data_error_result(message="Dataset name can't be empty.")
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
# Strip disallowed parameters
not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"]
@@ -202,7 +205,8 @@ async def update(
return get_data_error_result(
message="Can't find this knowledgebase!")
if req["name"].lower() != kb.name.lower() \
# Check for a duplicate name (only when a new name is provided)
if "name" in req and req["name"].lower() != kb.name.lower() \
and len(
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
return get_data_error_result(

View File

@@ -0,0 +1,84 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
class SetConversationRequest(BaseModel):
"""设置对话请求"""
conversation_id: Optional[str] = None
is_new: bool
name: Optional[str] = Field(default="New conversation", max_length=255)
dialog_id: str
class DeleteConversationsRequest(BaseModel):
"""删除对话请求"""
conversation_ids: List[str]
class CompletionRequest(BaseModel):
"""完成请求(聊天完成)"""
conversation_id: str
messages: List[Dict[str, Any]]
llm_id: Optional[str] = None
stream: Optional[bool] = True
temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
class TTSRequest(BaseModel):
"""文本转语音请求"""
text: str
class DeleteMessageRequest(BaseModel):
"""删除消息请求"""
conversation_id: str
message_id: str
class ThumbupRequest(BaseModel):
"""点赞/点踩请求"""
conversation_id: str
message_id: str
thumbup: Optional[bool] = None
feedback: Optional[str] = ""
class AskRequest(BaseModel):
"""提问请求"""
question: str
kb_ids: List[str]
search_id: Optional[str] = ""
class MindmapRequest(BaseModel):
"""思维导图请求"""
question: str
kb_ids: List[str]
search_id: Optional[str] = ""
class RelatedQuestionsRequest(BaseModel):
"""相关问题请求"""
question: str
search_id: Optional[str] = ""
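
A hedged usage sketch: Pydantic validates the payload before the handler runs, which is what replaces the old @validate_request decorator.

from pydantic import ValidationError
from api.apps.models.conversation_models import CompletionRequest

try:
    CompletionRequest(conversation_id="c1")  # "messages" is required but missing
except ValidationError as e:
    print(e.errors()[0]["loc"])  # ('messages',)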

View File

@@ -0,0 +1,43 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List
from pydantic import BaseModel, Field
class CreateFileRequest(BaseModel):
"""创建文件/文件夹请求"""
name: str
parent_id: Optional[str] = None
type: Optional[str] = None
class DeleteFilesRequest(BaseModel):
"""删除文件请求"""
file_ids: List[str]
class RenameFileRequest(BaseModel):
"""重命名文件请求"""
file_id: str
name: str
class MoveFilesRequest(BaseModel):
"""移动文件请求"""
src_file_ids: List[str]
dest_file_id: str

View File

@@ -60,9 +60,16 @@ class CreateKnowledgeBaseRequest(BaseModel):
class UpdateKnowledgeBaseRequest(BaseModel):
"""更新知识库请求"""
kb_id: str
name: str
description: str
parser_id: str
name: Optional[str] = None
avatar: Optional[str] = None
language: Optional[str] = None
description: Optional[str] = None
permission: Optional[str] = None
doc_num: Optional[int] = None
token_num: Optional[int] = None
chunk_num: Optional[int] = None
parser_id: Optional[str] = None
embd_id: Optional[str] = None
pagerank: Optional[int] = None
# All other fields are optional; id, tenant_id, created_by, create_time, update_time, create_date and update_date are excluded.
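
Making every field Optional pairs with the model_dump(exclude_unset=True) call in the update handler: fields the client never sent are absent from the resulting dict, which is why the name checks earlier in this commit became conditional. A small sketch with a stand-in model.

from typing import Optional
from pydantic import BaseModel

class UpdateKB(BaseModel):  # stand-in for UpdateKnowledgeBaseRequest
    kb_id: str
    name: Optional[str] = None

req = UpdateKB(kb_id="kb-001").model_dump(exclude_unset=True)
assert "name" not in req  # hence the `if "name" in req:` guard in the handler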

View File

@@ -0,0 +1,53 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
class CreateSearchRequest(BaseModel):
"""创建搜索应用请求"""
name: str
description: Optional[str] = ""
class UpdateSearchRequest(BaseModel):
"""更新搜索应用请求"""
search_id: str
name: str
search_config: Dict[str, Any]
tenant_id: str
description: Optional[str] = None
class DeleteSearchRequest(BaseModel):
"""删除搜索应用请求"""
search_id: str
class ListSearchAppsQuery(BaseModel):
"""列出搜索应用查询参数"""
keywords: Optional[str] = ""
page: Optional[int] = 0
page_size: Optional[int] = 0
orderby: Optional[str] = "create_time"
desc: Optional[str] = "true"
class ListSearchAppsBody(BaseModel):
"""列出搜索应用请求体"""
owner_ids: Optional[List[str]] = []
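
ListSearchAppsQuery is consumed with Depends() in the handler below; a minimal sketch of that FastAPI pattern, where a Pydantic model's fields become individual query parameters (names here are stand-ins).

from typing import Optional
from fastapi import Depends, FastAPI
from pydantic import BaseModel

app = FastAPI()

class PageQuery(BaseModel):  # stand-in for ListSearchAppsQuery
    keywords: Optional[str] = ""
    page: Optional[int] = 0
    page_size: Optional[int] = 0

@app.post("/list")
async def list_apps(query: PageQuery = Depends()):
    return query.model_dump()
# POST /list?keywords=faq&page=1 yields {"keywords": "faq", "page": 1, "page_size": 0}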

View File

@@ -14,8 +14,18 @@
# limitations under the License.
#
from flask import request
from flask_login import current_user, login_required
from typing import Optional
from fastapi import APIRouter, Depends, Query
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.search_models import (
CreateSearchRequest,
UpdateSearchRequest,
DeleteSearchRequest,
ListSearchAppsQuery,
ListSearchAppsBody,
)
from api import settings
from api.constants import DATASET_NAME_LIMIT
@@ -25,14 +35,23 @@ from api.db.services import duplicate_name
from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
)
# Create the router
router = APIRouter()
@manager.route("/create", methods=["post"]) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.get_json()
@router.post('/create')
async def create(
request: CreateSearchRequest,
current_user = Depends(get_current_user)
):
"""创建搜索应用"""
req = request.model_dump(exclude_unset=True)
search_name = req["name"]
description = req.get("description", "")
if not isinstance(search_name, str):
@@ -62,12 +81,13 @@ def create():
return server_error_response(e)
@manager.route("/update", methods=["post"]) # noqa: F821
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
req = request.get_json()
@router.post('/update')
async def update(
request: UpdateSearchRequest,
current_user = Depends(get_current_user)
):
"""更新搜索应用"""
req = request.model_dump(exclude_unset=True)
if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "":
@@ -84,6 +104,12 @@ def update():
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
# Strip disallowed parameters
not_allowed = ["id", "created_by", "create_time", "update_time", "create_date", "update_date"]
for key in not_allowed:
if key in req:
del req[key]
try:
search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0]
if not search_app:
@@ -119,10 +145,12 @@ def update():
return server_error_response(e)
@manager.route("/detail", methods=["GET"]) # noqa: F821
@login_required
def detail():
search_id = request.args["search_id"]
@router.get('/detail')
async def detail(
search_id: str = Query(..., description="Search app ID"),
current_user = Depends(get_current_user)
):
"""获取搜索应用详情"""
try:
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
@@ -139,20 +167,23 @@ def detail():
return server_error_response(e)
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_search_app():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
@router.post('/list')
async def list_search_app(
query: ListSearchAppsQuery = Depends(),
body: Optional[ListSearchAppsBody] = None,
current_user = Depends(get_current_user)
):
"""列出搜索应用"""
if body is None:
body = ListSearchAppsBody()
keywords = query.keywords or ""
page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
req = request.get_json()
owner_ids = req.get("owner_ids", [])
owner_ids = body.owner_ids or [] if body else []
try:
if not owner_ids:
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
@@ -171,12 +202,13 @@ def list_search_app():
return server_error_response(e)
@manager.route("/rm", methods=["post"]) # noqa: F821
@login_required
@validate_request("search_id")
def rm():
req = request.get_json()
search_id = req["search_id"]
@router.post('/rm')
async def rm(
request: DeleteSearchRequest,
current_user = Depends(get_current_user)
):
"""删除搜索应用"""
search_id = request.search_id
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
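
A hedged client sketch for the migrated POST /list route, which now splits pagination into the query string and owner_ids into the JSON body; the host and token are placeholders.

import httpx

resp = httpx.post(
    "http://localhost:9380/v1/search/list",  # host/port assumed
    params={"keywords": "faq", "page": 1, "page_size": 10, "desc": "true"},
    json={"owner_ids": []},
    headers={"Authorization": "Bearer <token>"},
)
print(resp.json())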

View File

@@ -15,12 +15,15 @@
#
import json
import logging
import os
import re
import secrets
import string
import time
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response, Query, status
from api.apps.models.auth_dependencies import get_current_user
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, EmailStr
@@ -60,6 +63,19 @@ from api.utils.api_utils import (
validate_request,
)
from api.utils.crypt import decrypt
from rag.utils.redis_conn import REDIS_CONN
from api.apps import smtp_mail_server
from api.utils.web_utils import (
send_email_html,
OTP_LENGTH,
OTP_TTL_SECONDS,
ATTEMPT_LIMIT,
ATTEMPT_LOCK_SECONDS,
RESEND_COOLDOWN_SECONDS,
otp_keys,
hash_code,
captcha_key,
)
# Create the router
router = APIRouter()
@@ -77,9 +93,8 @@ class RegisterRequest(BaseModel):
password: str
class UserSettingRequest(BaseModel):
nickname: Optional[str] = None
password: Optional[str] = None
new_password: Optional[str] = None
language: Optional[str] = None
class TenantInfoRequest(BaseModel):
tenant_id: str
@@ -88,6 +103,16 @@ class TenantInfoRequest(BaseModel):
img2txt_id: str
llm_id: str
class ForgetOtpRequest(BaseModel):
email: str
captcha: str
class ForgetPasswordRequest(BaseModel):
email: str
otp: str
new_password: str
confirm_new_password: str
# Dependency: get the current user (imported from auth_dependencies)
@router.post("/login")
@@ -481,3 +506,357 @@ async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
@router.get("/github_callback")
async def github_callback(code: Optional[str] = Query(None)):
"""
**Deprecated**, Use `/oauth/callback/<channel>` instead.
GitHub OAuth callback endpoint.
"""
import requests
if not code:
return RedirectResponse(url="/?error=missing_code")
res = requests.post(
settings.GITHUB_OAUTH.get("url"),
data={
"client_id": settings.GITHUB_OAUTH.get("client_id"),
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
"code": code,
},
headers={"Accept": "application/json"},
)
res = res.json()
if "error" in res:
return RedirectResponse(url=f"/?error={res.get('error_description', res.get('error'))}")
if "user:email" not in res.get("scope", "").split(","):
return RedirectResponse(url="/?error=user:email not in scope")
access_token = res["access_token"]
user_info = user_info_from_github(access_token)
email_address = user_info["email"]
users = UserService.query(email=email_address)
user_id = get_uuid()
if not users:
try:
try:
avatar = download_img(user_info["avatar_url"])
except Exception as e:
logging.exception(e)
avatar = ""
users = user_register(
user_id,
{
"access_token": access_token,
"email": email_address,
"avatar": avatar,
"nickname": user_info["login"],
"login_channel": "github",
"last_login_time": get_format_time(),
"is_superuser": False,
},
)
if not users:
raise Exception(f"Fail to register {email_address}.")
if len(users) > 1:
raise Exception(f"Same email: {email_address} exists!")
user = users[0]
return RedirectResponse(url=f"/?auth={user.get_id()}")
except Exception as e:
rollback_user_registration(user_id)
logging.exception(e)
return RedirectResponse(url=f"/?error={str(e)}")
# User has already registered, try to log in
user = users[0]
user.access_token = get_uuid()
if user and hasattr(user, 'is_active') and user.is_active == "0":
return RedirectResponse(url="/?error=user_inactive")
user.save()
return RedirectResponse(url=f"/?auth={user.get_id()}")
@router.get("/feishu_callback")
async def feishu_callback(code: Optional[str] = Query(None)):
"""
Feishu OAuth callback endpoint.
"""
import requests
if not code:
return RedirectResponse(url="/?error=missing_code")
app_access_token_res = requests.post(
settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps(
{
"app_id": settings.FEISHU_OAUTH.get("app_id"),
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
}
),
headers={"Content-Type": "application/json; charset=utf-8"},
)
app_access_token_res = app_access_token_res.json()
if app_access_token_res.get("code") != 0:
return RedirectResponse(url=f"/?error={app_access_token_res}")
res = requests.post(
settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps(
{
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
"code": code,
}
),
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {app_access_token_res['app_access_token']}",
},
)
res = res.json()
if res.get("code") != 0:
return RedirectResponse(url=f"/?error={res.get('message', 'unknown_error')}")
if "contact:user.email:readonly" not in res.get("data", {}).get("scope", "").split():
return RedirectResponse(url="/?error=contact:user.email:readonly not in scope")
access_token = res["data"]["access_token"]
user_info = user_info_from_feishu(access_token)
email_address = user_info["email"]
users = UserService.query(email=email_address)
user_id = get_uuid()
if not users:
try:
try:
avatar = download_img(user_info["avatar_url"])
except Exception as e:
logging.exception(e)
avatar = ""
users = user_register(
user_id,
{
"access_token": access_token,
"email": email_address,
"avatar": avatar,
"nickname": user_info["en_name"],
"login_channel": "feishu",
"last_login_time": get_format_time(),
"is_superuser": False,
},
)
if not users:
raise Exception(f"Fail to register {email_address}.")
if len(users) > 1:
raise Exception(f"Same email: {email_address} exists!")
user = users[0]
return RedirectResponse(url=f"/?auth={user.get_id()}")
except Exception as e:
rollback_user_registration(user_id)
logging.exception(e)
return RedirectResponse(url=f"/?error={str(e)}")
# User has already registered, try to log in
user = users[0]
if user and hasattr(user, 'is_active') and user.is_active == "0":
return RedirectResponse(url="/?error=user_inactive")
user.access_token = get_uuid()
user.save()
return RedirectResponse(url=f"/?auth={user.get_id()}")
def user_info_from_feishu(access_token):
"""从飞书获取用户信息"""
import requests
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {access_token}",
}
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
user_info = res.json()["data"]
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
return user_info
def user_info_from_github(access_token):
"""从GitHub获取用户信息"""
import requests
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
user_info = res.json()
email_info = requests.get(
f"https://api.github.com/user/emails?access_token={access_token}",
headers=headers,
).json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return user_info
@router.get("/forget/captcha")
async def forget_get_captcha(email: str = Query(...)):
"""
GET /forget/captcha?email=<email>
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = 60 seconds.
- Returns the captcha as a JPEG image.
"""
if not email:
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required")
users = UserService.query(email=email)
if not users:
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
# Generate captcha text
allowed = string.ascii_uppercase + string.digits
captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
REDIS_CONN.set(captcha_key(email), captcha_text, 60) # Valid for 60 seconds
from captcha.image import ImageCaptcha
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
img_bytes = image.generate(captcha_text).read()
return Response(content=img_bytes, media_type="image/jpeg")
@router.post("/forget/otp")
async def forget_send_otp(request: ForgetOtpRequest):
"""
POST /forget/otp
- Verify the image captcha stored at captcha:{email} (case-insensitive).
- On success, generate an email OTP (A-Z, length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
"""
email = request.email or ""
captcha = (request.captcha or "").strip()
if not email or not captcha:
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required")
users = UserService.query(email=email)
if not users:
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
stored_captcha = REDIS_CONN.get(captcha_key(email))
if not stored_captcha:
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
if (stored_captcha or "").strip().lower() != captcha.lower():
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")
# Delete captcha to prevent reuse
REDIS_CONN.delete(captcha_key(email))
k_code, k_attempts, k_last, k_lock = otp_keys(email)
now = int(time.time())
last_ts = REDIS_CONN.get(k_last)
if last_ts:
try:
elapsed = now - int(last_ts)
except Exception:
elapsed = RESEND_COOLDOWN_SECONDS
remaining = RESEND_COOLDOWN_SECONDS - elapsed
if remaining > 0:
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds")
# Generate OTP (uppercase letters only) and store hashed
otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
salt = os.urandom(16)
code_hash = hash_code(otp, salt)
REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS)
REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS)
REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS)
REDIS_CONN.delete(k_lock)
ttl_min = OTP_TTL_SECONDS // 60
if not smtp_mail_server:
logging.warning("SMTP mail server not initialized; skip sending email.")
else:
try:
send_email_html(
subject="Your Password Reset Code",
to_email=email,
template_key="reset_code",
code=otp,
ttl_min=ttl_min,
)
except Exception:
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email")
return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent")
@router.post("/forget")
async def forget(request: ForgetPasswordRequest):
"""
POST: Verify email + OTP and reset password, then log the user in.
Request JSON: { email, otp, new_password, confirm_new_password }
"""
email = request.email or ""
otp = (request.otp or "").strip()
new_pwd = request.new_password
new_pwd2 = request.confirm_new_password
if not all([email, otp, new_pwd, new_pwd2]):
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")
# For reset, passwords are provided as-is (no decrypt needed)
if new_pwd != new_pwd2:
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match")
users = UserService.query(email=email)
if not users:
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
user = users[0]
# Verify OTP from Redis
k_code, k_attempts, k_last, k_lock = otp_keys(email)
if REDIS_CONN.get(k_lock):
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later")
stored = REDIS_CONN.get(k_code)
if not stored:
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp")
try:
stored_hash, salt_hex = str(stored).split(":", 1)
salt = bytes.fromhex(salt_hex)
except Exception:
return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted")
# Case-insensitive verification: OTP generated uppercase
calc = hash_code(otp.upper(), salt)
if calc != stored_hash:
# bump attempts
try:
attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
except Exception:
attempts = 1
REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
if attempts >= ATTEMPT_LIMIT:
REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp")
# Success: consume OTP and reset password
REDIS_CONN.delete(k_code)
REDIS_CONN.delete(k_attempts)
REDIS_CONN.delete(k_last)
REDIS_CONN.delete(k_lock)
try:
UserService.update_user_password(user.id, new_pwd)
except Exception as e:
logging.exception(e)
return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password")
# Auto login (reuse login flow)
user.access_token = get_uuid()
user.update_time = current_timestamp()
user.update_date = datetime_format(datetime.now())
user.save()
msg = "Password reset successful. Logged in."
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
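
The OTP flow stores "hash:salt" and verifies case-insensitively. A sketch of that scheme under the assumption that hash_code is a salted digest; the project's actual implementation lives in api.utils.web_utils and may differ.

import hashlib
import os
import secrets
import string

OTP_LENGTH = 6  # assumed value; the real constant comes from api.utils.web_utils

def hash_code(code: str, salt: bytes) -> str:
    # Assumed shape: a salted SHA-256, not necessarily the project's exact choice.
    return hashlib.sha256(salt + code.encode("utf-8")).hexdigest()

otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
salt = os.urandom(16)
stored = f"{hash_code(otp, salt)}:{salt.hex()}"  # same "hash:salt" layout as above

stored_hash, salt_hex = stored.split(":", 1)
user_input = otp.lower()  # users may type lowercase; verification uppercases first
assert hash_code(user_input.upper(), bytes.fromhex(salt_hex)) == stored_hash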

View File

@@ -8,6 +8,10 @@ minio:
user: 'rag_flow'
password: 'infini_rag_flow'
host: 'localhost:9000'
es:
hosts: 'http://localhost:1200'
username: 'elastic'
password: 'infini_rag_flow'
os:
hosts: 'http://localhost:1201'
username: 'admin'

View File

@@ -3,7 +3,7 @@
# - `elasticsearch` (default)
# - `infinity` (https://github.com/infiniflow/infinity)
# - `opensearch` (https://github.com/opensearch-project/OpenSearch)
DOC_ENGINE=opensearch
DOC_ENGINE=elasticsearch
# ------------------------------
# docker env var for specifying vector db type at startup
@@ -98,7 +98,7 @@ ADMIN_SVR_HTTP_PORT=9381
# The RAGFlow Docker image to download.
# Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models.
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi-web
#
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1

View File

@@ -1,36 +1,34 @@
services:
opensearch01:
container_name: ragflow-opensearch-01
es01:
container_name: ragflow-es-01
profiles:
- opensearch
image: hub.icert.top/opensearchproject/opensearch:2.19.1
- elasticsearch
image: elasticsearch:${STACK_VERSION}
volumes:
- osdata01:/usr/share/opensearch/data
- esdata01:/usr/share/elasticsearch/data
ports:
- ${OS_PORT}:9201
- ${ES_PORT}:9200
env_file: .env
environment:
- node.name=opensearch01
- OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD}
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_PASSWORD}
- node.name=es01
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
- bootstrap.memory_lock=false
- discovery.type=single-node
- plugins.security.disabled=false
- plugins.security.ssl.http.enabled=false
- plugins.security.ssl.transport.enabled=true
- xpack.security.enabled=true
- xpack.security.http.ssl.enabled=false
- xpack.security.transport.ssl.enabled=false
- cluster.routing.allocation.disk.watermark.low=5gb
- cluster.routing.allocation.disk.watermark.high=3gb
- cluster.routing.allocation.disk.watermark.flood_stage=2gb
- TZ=${TIMEZONE}
- http.port=9201
mem_limit: ${MEM_LIMIT}
ulimits:
memlock:
soft: -1
hard: -1
healthcheck:
test: ["CMD-SHELL", "curl http://localhost:9201"]
test: ["CMD-SHELL", "curl http://localhost:9200"]
interval: 10s
timeout: 10s
retries: 120

View File

@@ -64,7 +64,8 @@ class Dealer:
if key in req and req[key] is not None:
condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
"removed_kwd"]:
if key in req and req[key] is not None:
condition[key] = req[key]
return condition
@@ -135,7 +136,8 @@ class Dealer:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
@@ -212,8 +214,9 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0]*len(ans_v[0])
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
chunk_v[i] = [0.0] * len(ans_v[0])
logging.warning(
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
@@ -267,7 +270,7 @@ class Dealer:
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD):
@@ -280,8 +283,8 @@ class Dealer:
if denor == 0:
rank_fea.append(0)
else:
rank_fea.append(nor/np.sqrt(denor)/q_denor)
return np.array(rank_fea)*10. + pageranks
rank_fea.append(nor / np.sqrt(denor) / q_denor)
return np.array(rank_fea) * 10. + pageranks
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
@@ -343,7 +346,7 @@ class Dealer:
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
@@ -360,13 +363,13 @@ class Dealer:
return ranks
# Ensure RERANK_LIMIT is multiple of page_size
RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size * page / RERANK_LIMIT),
"size": RERANK_LIMIT,
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
@@ -392,15 +395,15 @@ class Dealer:
tsim = sim
vsim = sim
# Already paginated in search function
begin = ((page % (RERANK_LIMIT//page_size)) - 1) * page_size
sim = sim[begin : begin + page_size]
begin = ((page % (RERANK_LIMIT // page_size)) - 1) * page_size
sim = sim[begin: begin + page_size]
sim_np = np.array(sim)
idx = np.argsort(sim_np * -1)
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx:
if sim[i] < similarity_threshold:
break
@@ -447,8 +450,8 @@ class Dealer:
ranks["doc_aggs"] = [{"doc_name": k,
"doc_id": v["doc_id"],
"count": v["count"]} for k,
v in sorted(ranks["doc_aggs"].items(),
key=lambda x: x[1]["count"] * -1)]
v in sorted(ranks["doc_aggs"].items(),
key=lambda x: x[1]["count"] * -1)]
ranks["chunks"] = ranks["chunks"][:page_size]
return ranks
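
A worked check of the RERANK_LIMIT pagination arithmetic above, with assumed values.

import math

page_size, page = 10, 3
RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
search_page = math.ceil(page_size * page / RERANK_LIMIT)
begin = ((page % (RERANK_LIMIT // page_size)) - 1) * page_size
print(RERANK_LIMIT, search_page, begin)  # 70 1 20
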
@@ -505,13 +508,14 @@ class Dealer:
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
return True
@@ -527,11 +531,11 @@ class Dealer:
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6):
def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
if not chunks:
return []
idx_nms = [index_name(tid) for tid in tenant_ids]
@@ -541,9 +545,10 @@ class Dealer:
ranks[ck["doc_id"]] = 0
ranks[ck["doc_id"]] += ck["similarity"]
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
kb_ids = [doc_id2kb_id[doc_id]]
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
@@ -555,10 +560,10 @@ class Dealer:
if not toc:
return chunks
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2)
if not ids:
return chunks
vector_size = 1024
id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)}
for cid, sim in ids:
@@ -589,4 +594,4 @@ class Dealer:
break
chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]

rag/res/deepdoc/.gitattributes (vendored, new file, 35 lines)
View File

@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

Several binary model files under rag/res/deepdoc/ changed (contents not shown); the new files include layout.onnx and tsr.onnx.