diff --git a/api/apps/__init___fastapi.py b/api/apps/__init___fastapi.py index afc01c7..30e0c82 100644 --- a/api/apps/__init___fastapi.py +++ b/api/apps/__init___fastapi.py @@ -165,17 +165,23 @@ def setup_routes(app: FastAPI): from api.apps.tenant_app import router as tenant_router from api.apps.dialog_app import router as dialog_router from api.apps.system_app import router as system_router + from api.apps.search_app import router as search_router + from api.apps.conversation_app import router as conversation_router + from api.apps.file_app import router as file_router app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"]) app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"]) app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"]) app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"]) app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"]) - app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"]) + app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp_server", tags=["MCP"]) app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"]) app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"]) app.include_router(dialog_router, prefix=f"/{API_VERSION}/dialog", tags=["Dialog"]) app.include_router(system_router, prefix=f"/{API_VERSION}/system", tags=["System"]) + app.include_router(search_router, prefix=f"/{API_VERSION}/search", tags=["Search"]) + app.include_router(conversation_router, prefix=f"/{API_VERSION}/conversation", tags=["Conversation"]) + app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"]) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 48b9a15..2b18504 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -17,8 +17,10 @@ import json import re import logging from copy import deepcopy -from flask import Response, request -from flask_login import current_user, login_required +from typing import Optional +from fastapi import APIRouter, Depends, Query, Header, HTTPException, status +from fastapi.responses import StreamingResponse + from api import settings from api.db import LLMType from api.db.db_models import APIToken @@ -28,15 +30,35 @@ from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response +from api.utils import get_uuid from rag.prompts.template import load_prompt from rag.prompts.generator import chunks_format +from api.apps.models.auth_dependencies import get_current_user +from api.apps.models.conversation_models import ( + SetConversationRequest, + DeleteConversationsRequest, + CompletionRequest, + TTSRequest, + DeleteMessageRequest, + ThumbupRequest, + AskRequest, + MindmapRequest, + RelatedQuestionsRequest, +) -@manager.route("/set", methods=["POST"]) # noqa: F821 -@login_required -def set_conversation(): - req = request.json +# 创建路由器 +router = APIRouter() + + +@router.post('/set') +async def set_conversation( + request: SetConversationRequest, + current_user = Depends(get_current_user) +): + """设置对话""" + req = request.model_dump(exclude_unset=True) conv_id = req.get("conversation_id") is_new = req.get("is_new") name = req.get("name", "New conversation") @@ -45,9 +67,9 @@ def set_conversation(): if len(name) > 255: name = name[0:255] - del req["is_new"] if not is_new: - del req["conversation_id"] + if not conv_id: + return get_data_error_result(message="conversation_id is required when is_new is False!") try: if not ConversationService.update_by_id(conv_id, req): return get_data_error_result(message="Conversation not found!") @@ -64,7 +86,7 @@ def set_conversation(): if not e: return get_data_error_result(message="Dialog not found") conv = { - "id": conv_id, + "id": conv_id or get_uuid(), "dialog_id": req["dialog_id"], "name": name, "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}], @@ -77,12 +99,14 @@ def set_conversation(): return server_error_response(e) -@manager.route("/get", methods=["GET"]) # noqa: F821 -@login_required -def get(): - conv_id = request.args["conversation_id"] +@router.get('/get') +async def get( + conversation_id: str = Query(..., description="对话ID"), + current_user = Depends(get_current_user) +): + """获取对话""" try: - e, conv = ConversationService.get_by_id(conv_id) + e, conv = ConversationService.get_by_id(conversation_id) if not e: return get_data_error_result(message="Conversation not found!") tenants = UserTenantService.query(user_id=current_user.id) @@ -107,15 +131,27 @@ def get(): return server_error_response(e) -@manager.route("/getsse/", methods=["GET"]) # type: ignore # noqa: F821 -def getsse(dialog_id): - token = request.headers.get("Authorization").split() - if len(token) != 2: +@router.get('/getsse/{dialog_id}') +async def getsse( + dialog_id: str, + authorization: Optional[str] = Header(None, alias="Authorization") +): + """通过 SSE 获取对话(使用 API token 认证)""" + if not authorization: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header is required" + ) + + token_parts = authorization.split() + if len(token_parts) != 2: return get_data_error_result(message='Authorization is not valid!"') - token = token[1] + token = token_parts[1] + objs = APIToken.query(beta=token) if not objs: return get_data_error_result(message='Authentication error: API key is invalid!"') + try: e, conv = DialogService.get_by_id(dialog_id) if not e: @@ -128,12 +164,14 @@ def getsse(dialog_id): return server_error_response(e) -@manager.route("/rm", methods=["POST"]) # noqa: F821 -@login_required -def rm(): - conv_ids = request.json["conversation_ids"] +@router.post('/rm') +async def rm( + request: DeleteConversationsRequest, + current_user = Depends(get_current_user) +): + """删除对话""" try: - for cid in conv_ids: + for cid in request.conversation_ids: exist, conv = ConversationService.get_by_id(cid) if not exist: return get_data_error_result(message="Conversation not found!") @@ -149,10 +187,12 @@ def rm(): return server_error_response(e) -@manager.route("/list", methods=["GET"]) # noqa: F821 -@login_required -def list_conversation(): - dialog_id = request.args["dialog_id"] +@router.get('/list') +async def list_conversation( + dialog_id: str = Query(..., description="对话ID"), + current_user = Depends(get_current_user) +): + """列出对话""" try: if not DialogService.query(tenant_id=current_user.id, id=dialog_id): return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR) @@ -164,11 +204,13 @@ def list_conversation(): return server_error_response(e) -@manager.route("/completion", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("conversation_id", "messages") -def completion(): - req = request.json +@router.post('/completion') +async def completion( + request: CompletionRequest, + current_user = Depends(get_current_user) +): + """完成请求(聊天完成)""" + req = request.model_dump(exclude_unset=True) msg = [] for m in req["messages"]: if m["role"] == "system": @@ -176,6 +218,10 @@ def completion(): if m["role"] == "assistant" and not msg: continue msg.append(m) + + if not msg: + return get_data_error_result(message="No valid messages found!") + message_id = msg[-1].get("id") chat_model_id = req.get("llm_id", "") req.pop("llm_id", None) @@ -217,6 +263,7 @@ def completion(): dia.llm_setting = chat_model_config is_embedded = bool(chat_model_id) + def stream(): nonlocal dia, msg, req, conv try: @@ -230,14 +277,18 @@ def completion(): yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - if req.get("stream", True): - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - + stream_enabled = request.stream if request.stream is not None else True + if stream_enabled: + return StreamingResponse( + stream(), + media_type="text/event-stream", + headers={ + "Cache-control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "Content-Type": "text/event-stream; charset=utf-8" + } + ) else: answer = None for ans in chat(dia, msg, **req): @@ -250,11 +301,13 @@ def completion(): return server_error_response(e) -@manager.route("/tts", methods=["POST"]) # noqa: F821 -@login_required -def tts(): - req = request.json - text = req["text"] +@router.post('/tts') +async def tts( + request: TTSRequest, + current_user = Depends(get_current_user) +): + """文本转语音""" + text = request.text tenants = TenantService.get_info_by(current_user.id) if not tenants: @@ -274,28 +327,32 @@ def tts(): except Exception as e: yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8") - resp = Response(stream_audio(), mimetype="audio/mpeg") - resp.headers.add_header("Cache-Control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - - return resp + return StreamingResponse( + stream_audio(), + media_type="audio/mpeg", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) -@manager.route("/delete_msg", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("conversation_id", "message_id") -def delete_msg(): - req = request.json - e, conv = ConversationService.get_by_id(req["conversation_id"]) +@router.post('/delete_msg') +async def delete_msg( + request: DeleteMessageRequest, + current_user = Depends(get_current_user) +): + """删除消息""" + e, conv = ConversationService.get_by_id(request.conversation_id) if not e: return get_data_error_result(message="Conversation not found!") conv = conv.to_dict() for i, msg in enumerate(conv["message"]): - if req["message_id"] != msg.get("id", ""): + if request.message_id != msg.get("id", ""): continue - assert conv["message"][i + 1]["id"] == req["message_id"] + assert conv["message"][i + 1]["id"] == request.message_id conv["message"].pop(i) conv["message"].pop(i) conv["reference"].pop(max(0, i // 2 - 1)) @@ -305,19 +362,21 @@ def delete_msg(): return get_json_result(data=conv) -@manager.route("/thumbup", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("conversation_id", "message_id") -def thumbup(): - req = request.json - e, conv = ConversationService.get_by_id(req["conversation_id"]) +@router.post('/thumbup') +async def thumbup( + request: ThumbupRequest, + current_user = Depends(get_current_user) +): + """点赞/点踩""" + e, conv = ConversationService.get_by_id(request.conversation_id) if not e: return get_data_error_result(message="Conversation not found!") - up_down = req.get("thumbup") - feedback = req.get("feedback", "") + + up_down = request.thumbup + feedback = request.feedback or "" conv = conv.to_dict() for i, msg in enumerate(conv["message"]): - if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant": + if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant": if up_down: msg["thumbup"] = True if "feedback" in msg: @@ -332,14 +391,15 @@ def thumbup(): return get_json_result(data=conv) -@manager.route("/ask", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("question", "kb_ids") -def ask_about(): - req = request.json +@router.post('/ask') +async def ask_about( + request: AskRequest, + current_user = Depends(get_current_user) +): + """提问""" uid = current_user.id - search_id = req.get("search_id", "") + search_id = request.search_id or "" search_app = None search_config = {} if search_id: @@ -348,53 +408,58 @@ def ask_about(): search_config = search_app.get("search_config", {}) def stream(): - nonlocal req, uid + nonlocal request, uid try: - for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): + for ans in ask(request.question, request.kb_ids, uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp + return StreamingResponse( + stream(), + media_type="text/event-stream", + headers={ + "Cache-control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "Content-Type": "text/event-stream; charset=utf-8" + } + ) -@manager.route("/mindmap", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("question", "kb_ids") -def mindmap(): - req = request.json - search_id = req.get("search_id", "") +@router.post('/mindmap') +async def mindmap( + request: MindmapRequest, + current_user = Depends(get_current_user) +): + """思维导图""" + search_id = request.search_id or "" search_app = SearchService.get_detail(search_id) if search_id else {} search_config = search_app.get("search_config", {}) if search_app else {} kb_ids = search_config.get("kb_ids", []) - kb_ids.extend(req["kb_ids"]) + kb_ids.extend(request.kb_ids) kb_ids = list(set(kb_ids)) - mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config) + mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config) if "error" in mind_map: return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) -@manager.route("/related_questions", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("question") -def related_questions(): - req = request.json - - search_id = req.get("search_id", "") +@router.post('/related_questions') +async def related_questions( + request: RelatedQuestionsRequest, + current_user = Depends(get_current_user) +): + """相关问题""" + search_id = request.search_id or "" search_config = {} if search_id: if search_app := SearchService.get_detail(search_id): search_config = search_app.get("search_config", {}) - question = req["question"] + question = request.question chat_id = search_config.get("chat_id", "") chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index b54ce76..45f89ff 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -72,18 +72,17 @@ router = APIRouter() @router.post("/upload") async def upload( kb_id: str = Form(...), - files: List[UploadFile] = File(...), + file: UploadFile = File(...), current_user = Depends(get_current_user) ): """上传文档""" - if not files: + if not file: return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR) - for file_obj in files: - if not file_obj.filename or file_obj.filename == "": - return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR) - if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: - return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR) + if not file.filename or file.filename == "": + return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR) + if len(file.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: @@ -91,7 +90,7 @@ async def upload( if not check_kb_team_permission(kb, current_user.id): return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) - err, uploaded_files = FileService.upload_document(kb, files, current_user.id) + err, uploaded_files = FileService.upload_document(kb, [file], current_user.id) if err: return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR) diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 7828a82..2b207d2 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -17,15 +17,23 @@ import logging import os import pathlib import re +from typing import Optional, List -import flask -from flask import request -from flask_login import login_required, current_user +from fastapi import APIRouter, Depends, Query, UploadFile, File, Form +from fastapi.responses import Response + +from api.apps.models.auth_dependencies import get_current_user +from api.apps.models.file_models import ( + CreateFileRequest, + DeleteFilesRequest, + RenameFileRequest, + MoveFilesRequest, +) from api.common.check_team_permission import check_file_team_permission from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils.api_utils import server_error_response, get_data_error_result from api.utils import get_uuid from api.db import FileType, FileSource from api.db.services import duplicate_name @@ -36,35 +44,41 @@ from api.utils.file_utils import filename_type from api.utils.web_utils import CONTENT_TYPE_MAP from rag.utils.storage_factory import STORAGE_IMPL +# 创建路由器 +router = APIRouter() -@manager.route('/upload', methods=['POST']) # noqa: F821 -@login_required -# @validate_request("parent_id") -def upload(): - pf_id = request.form.get("parent_id") + +@router.post('/upload') +async def upload( + files: List[UploadFile] = File(...), + parent_id: Optional[str] = Form(None), + current_user = Depends(get_current_user) +): + """上传文件""" + pf_id = parent_id if not pf_id: root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] - if 'file' not in request.files: + if not files: return get_json_result( data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist('file') - for file_obj in file_objs: - if file_obj.filename == '': + for file_obj in files: + if not file_obj.filename or file_obj.filename == '': return get_json_result( data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) + file_res = [] try: e, pf_folder = FileService.get_by_id(pf_id) if not e: - return get_data_error_result( message="Can't find this folder!") - for file_obj in file_objs: + return get_data_error_result(message="Can't find this folder!") + for file_obj in files: MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER: - return get_data_error_result( message="Exceed the maximum file number of a free user!") + return get_data_error_result(message="Exceed the maximum file number of a free user!") # split file name path if not file_obj.filename: @@ -97,7 +111,7 @@ def upload(): location = file_obj_names[file_len - 1] while STORAGE_IMPL.obj_exist(last_folder.id, location): location += "_" - blob = file_obj.read() + blob = await file_obj.read() filename = duplicate_name( FileService.query, name=file_obj_names[file_len - 1], @@ -120,13 +134,16 @@ def upload(): return server_error_response(e) -@manager.route('/create', methods=['POST']) # noqa: F821 -@login_required -@validate_request("name") -def create(): - req = request.json - pf_id = request.json.get("parent_id") - input_file_type = request.json.get("type") +@router.post('/create') +async def create( + request: CreateFileRequest, + current_user = Depends(get_current_user) +): + """创建文件/文件夹""" + req = request.model_dump(exclude_unset=True) + pf_id = req.get("parent_id") + input_file_type = req.get("type") + if not pf_id: root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] @@ -160,17 +177,22 @@ def create(): return server_error_response(e) -@manager.route('/list', methods=['GET']) # noqa: F821 -@login_required -def list_files(): - pf_id = request.args.get("parent_id") +@router.get('/list') +async def list_files( + parent_id: Optional[str] = Query(None, description="父文件夹ID"), + keywords: Optional[str] = Query("", description="搜索关键词"), + page: Optional[int] = Query(1, description="页码"), + page_size: Optional[int] = Query(15, description="每页数量"), + orderby: Optional[str] = Query("create_time", description="排序字段"), + desc: Optional[bool] = Query(True, description="是否降序"), + current_user = Depends(get_current_user) +): + """列出文件""" + pf_id = parent_id - keywords = request.args.get("keywords", "") - - page_number = int(request.args.get("page", 1)) - items_per_page = int(request.args.get("page_size", 15)) - orderby = request.args.get("orderby", "create_time") - desc = request.args.get("desc", True) + page_number = int(page) if page else 1 + items_per_page = int(page_size) if page_size else 15 + if not pf_id: root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] @@ -192,9 +214,11 @@ def list_files(): return server_error_response(e) -@manager.route('/root_folder', methods=['GET']) # noqa: F821 -@login_required -def get_root_folder(): +@router.get('/root_folder') +async def get_root_folder( + current_user = Depends(get_current_user) +): + """获取根文件夹""" try: root_folder = FileService.get_root_folder(current_user.id) return get_json_result(data={"root_folder": root_folder}) @@ -202,10 +226,12 @@ def get_root_folder(): return server_error_response(e) -@manager.route('/parent_folder', methods=['GET']) # noqa: F821 -@login_required -def get_parent_folder(): - file_id = request.args.get("file_id") +@router.get('/parent_folder') +async def get_parent_folder( + file_id: str = Query(..., description="文件ID"), + current_user = Depends(get_current_user) +): + """获取父文件夹""" try: e, file = FileService.get_by_id(file_id) if not e: @@ -217,10 +243,12 @@ def get_parent_folder(): return server_error_response(e) -@manager.route('/all_parent_folder', methods=['GET']) # noqa: F821 -@login_required -def get_all_parent_folders(): - file_id = request.args.get("file_id") +@router.get('/all_parent_folder') +async def get_all_parent_folders( + file_id: str = Query(..., description="文件ID"), + current_user = Depends(get_current_user) +): + """获取所有父文件夹""" try: e, file = FileService.get_by_id(file_id) if not e: @@ -235,12 +263,13 @@ def get_all_parent_folders(): return server_error_response(e) -@manager.route("/rm", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("file_ids") -def rm(): - req = request.json - file_ids = req["file_ids"] +@router.post("/rm") +async def rm( + request: DeleteFilesRequest, + current_user = Depends(get_current_user) +): + """删除文件""" + file_ids = request.file_ids def _delete_single_file(file): try: @@ -296,11 +325,13 @@ def rm(): return server_error_response(e) -@manager.route('/rename', methods=['POST']) # noqa: F821 -@login_required -@validate_request("file_id", "name") -def rename(): - req = request.json +@router.post('/rename') +async def rename( + request: RenameFileRequest, + current_user = Depends(get_current_user) +): + """重命名文件""" + req = request.model_dump() try: e, file = FileService.get_by_id(req["file_id"]) if not e: @@ -314,8 +345,8 @@ def rename(): data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR) - for file in FileService.query(name=req["name"], pf_id=file.parent_id): - if file.name == req["name"]: + for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id): + if existing_file.name == req["name"]: return get_data_error_result( message="Duplicated file name in the same folder.") @@ -336,9 +367,12 @@ def rename(): return server_error_response(e) -@manager.route('/get/', methods=['GET']) # noqa: F821 -@login_required -def get(file_id): +@router.get('/get/{file_id}') +async def get( + file_id: str, + current_user = Depends(get_current_user) +): + """获取文件内容""" try: e, file = FileService.get_by_id(file_id) if not e: @@ -351,25 +385,28 @@ def get(file_id): b, n = File2DocumentService.get_storage_address(file_id=file_id) blob = STORAGE_IMPL.get(b, n) - response = flask.make_response(blob) ext = re.search(r"\.([^.]+)$", file.name.lower()) ext = ext.group(1) if ext else None + + content_type = "application/octet-stream" if ext: if file.type == FileType.VISUAL.value: content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}") else: content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}") - response.headers.set("Content-Type", content_type) - return response + + return Response(content=blob, media_type=content_type) except Exception as e: return server_error_response(e) -@manager.route("/mv", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("src_file_ids", "dest_file_id") -def move(): - req = request.json +@router.post("/mv") +async def move( + request: MoveFilesRequest, + current_user = Depends(get_current_user) +): + """移动文件""" + req = request.model_dump() try: file_ids = req["src_file_ids"] dest_parent_id = req["dest_file_id"] diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 355a0f5..71399a4 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -169,14 +169,17 @@ async def update( ): """更新知识库""" req = request.model_dump(exclude_unset=True) - if not isinstance(req["name"], str): - return get_data_error_result(message="Dataset name must be string.") - if req["name"].strip() == "": - return get_data_error_result(message="Dataset name can't be empty.") - if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT: - return get_data_error_result( - message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}") - req["name"] = req["name"].strip() + + # 验证 name 字段(如果提供) + if "name" in req: + if not isinstance(req["name"], str): + return get_data_error_result(message="Dataset name must be string.") + if req["name"].strip() == "": + return get_data_error_result(message="Dataset name can't be empty.") + if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT: + return get_data_error_result( + message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}") + req["name"] = req["name"].strip() # 验证不允许的参数 not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"] @@ -202,7 +205,8 @@ async def update( return get_data_error_result( message="Can't find this knowledgebase!") - if req["name"].lower() != kb.name.lower() \ + # 检查名称重复(仅在提供新名称时) + if "name" in req and req["name"].lower() != kb.name.lower() \ and len( KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1: return get_data_error_result( diff --git a/api/apps/models/conversation_models.py b/api/apps/models/conversation_models.py new file mode 100644 index 0000000..e0fb8d6 --- /dev/null +++ b/api/apps/models/conversation_models.py @@ -0,0 +1,84 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + + +class SetConversationRequest(BaseModel): + """设置对话请求""" + conversation_id: Optional[str] = None + is_new: bool + name: Optional[str] = Field(default="New conversation", max_length=255) + dialog_id: str + + +class DeleteConversationsRequest(BaseModel): + """删除对话请求""" + conversation_ids: List[str] + + +class CompletionRequest(BaseModel): + """完成请求(聊天完成)""" + conversation_id: str + messages: List[Dict[str, Any]] + llm_id: Optional[str] = None + stream: Optional[bool] = True + temperature: Optional[float] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + max_tokens: Optional[int] = None + + +class TTSRequest(BaseModel): + """文本转语音请求""" + text: str + + +class DeleteMessageRequest(BaseModel): + """删除消息请求""" + conversation_id: str + message_id: str + + +class ThumbupRequest(BaseModel): + """点赞/点踩请求""" + conversation_id: str + message_id: str + thumbup: Optional[bool] = None + feedback: Optional[str] = "" + + +class AskRequest(BaseModel): + """提问请求""" + question: str + kb_ids: List[str] + search_id: Optional[str] = "" + + +class MindmapRequest(BaseModel): + """思维导图请求""" + question: str + kb_ids: List[str] + search_id: Optional[str] = "" + + +class RelatedQuestionsRequest(BaseModel): + """相关问题请求""" + question: str + search_id: Optional[str] = "" + diff --git a/api/apps/models/file_models.py b/api/apps/models/file_models.py new file mode 100644 index 0000000..4c8a043 --- /dev/null +++ b/api/apps/models/file_models.py @@ -0,0 +1,43 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, List +from pydantic import BaseModel, Field + + +class CreateFileRequest(BaseModel): + """创建文件/文件夹请求""" + name: str + parent_id: Optional[str] = None + type: Optional[str] = None + + +class DeleteFilesRequest(BaseModel): + """删除文件请求""" + file_ids: List[str] + + +class RenameFileRequest(BaseModel): + """重命名文件请求""" + file_id: str + name: str + + +class MoveFilesRequest(BaseModel): + """移动文件请求""" + src_file_ids: List[str] + dest_file_id: str + diff --git a/api/apps/models/kb_models.py b/api/apps/models/kb_models.py index ca58398..b823412 100644 --- a/api/apps/models/kb_models.py +++ b/api/apps/models/kb_models.py @@ -60,9 +60,16 @@ class CreateKnowledgeBaseRequest(BaseModel): class UpdateKnowledgeBaseRequest(BaseModel): """更新知识库请求""" kb_id: str - name: str - description: str - parser_id: str + name: Optional[str] = None + avatar: Optional[str] = None + language: Optional[str] = None + description: Optional[str] = None + permission: Optional[str] = None + doc_num: Optional[int] = None + token_num: Optional[int] = None + chunk_num: Optional[int] = None + parser_id: Optional[str] = None + embd_id: Optional[str] = None pagerank: Optional[int] = None # 其他可选字段,但排除 id, tenant_id, created_by, create_time, update_time, create_date, update_date diff --git a/api/apps/models/search_models.py b/api/apps/models/search_models.py new file mode 100644 index 0000000..eb9d70b --- /dev/null +++ b/api/apps/models/search_models.py @@ -0,0 +1,53 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + + +class CreateSearchRequest(BaseModel): + """创建搜索应用请求""" + name: str + description: Optional[str] = "" + + +class UpdateSearchRequest(BaseModel): + """更新搜索应用请求""" + search_id: str + name: str + search_config: Dict[str, Any] + tenant_id: str + description: Optional[str] = None + + +class DeleteSearchRequest(BaseModel): + """删除搜索应用请求""" + search_id: str + + +class ListSearchAppsQuery(BaseModel): + """列出搜索应用查询参数""" + keywords: Optional[str] = "" + page: Optional[int] = 0 + page_size: Optional[int] = 0 + orderby: Optional[str] = "create_time" + desc: Optional[str] = "true" + + +class ListSearchAppsBody(BaseModel): + """列出搜索应用请求体""" + owner_ids: Optional[List[str]] = [] + diff --git a/api/apps/search_app.py b/api/apps/search_app.py index e0002f8..9d8c6f8 100644 --- a/api/apps/search_app.py +++ b/api/apps/search_app.py @@ -14,8 +14,18 @@ # limitations under the License. # -from flask import request -from flask_login import current_user, login_required +from typing import Optional + +from fastapi import APIRouter, Depends, Query + +from api.apps.models.auth_dependencies import get_current_user +from api.apps.models.search_models import ( + CreateSearchRequest, + UpdateSearchRequest, + DeleteSearchRequest, + ListSearchAppsQuery, + ListSearchAppsBody, +) from api import settings from api.constants import DATASET_NAME_LIMIT @@ -25,14 +35,23 @@ from api.db.services import duplicate_name from api.db.services.search_service import SearchService from api.db.services.user_service import TenantService, UserTenantService from api.utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request +from api.utils.api_utils import ( + get_data_error_result, + get_json_result, + server_error_response, +) + +# 创建路由器 +router = APIRouter() -@manager.route("/create", methods=["post"]) # noqa: F821 -@login_required -@validate_request("name") -def create(): - req = request.get_json() +@router.post('/create') +async def create( + request: CreateSearchRequest, + current_user = Depends(get_current_user) +): + """创建搜索应用""" + req = request.model_dump(exclude_unset=True) search_name = req["name"] description = req.get("description", "") if not isinstance(search_name, str): @@ -62,12 +81,13 @@ def create(): return server_error_response(e) -@manager.route("/update", methods=["post"]) # noqa: F821 -@login_required -@validate_request("search_id", "name", "search_config", "tenant_id") -@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") -def update(): - req = request.get_json() +@router.post('/update') +async def update( + request: UpdateSearchRequest, + current_user = Depends(get_current_user) +): + """更新搜索应用""" + req = request.model_dump(exclude_unset=True) if not isinstance(req["name"], str): return get_data_error_result(message="Search name must be string.") if req["name"].strip() == "": @@ -84,6 +104,12 @@ def update(): if not SearchService.accessible4deletion(search_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) + # 验证不允许的参数 + not_allowed = ["id", "created_by", "create_time", "update_time", "create_date", "update_date"] + for key in not_allowed: + if key in req: + del req[key] + try: search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0] if not search_app: @@ -119,10 +145,12 @@ def update(): return server_error_response(e) -@manager.route("/detail", methods=["GET"]) # noqa: F821 -@login_required -def detail(): - search_id = request.args["search_id"] +@router.get('/detail') +async def detail( + search_id: str = Query(..., description="搜索应用ID"), + current_user = Depends(get_current_user) +): + """获取搜索应用详情""" try: tenants = UserTenantService.query(user_id=current_user.id) for tenant in tenants: @@ -139,20 +167,23 @@ def detail(): return server_error_response(e) -@manager.route("/list", methods=["POST"]) # noqa: F821 -@login_required -def list_search_app(): - keywords = request.args.get("keywords", "") - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True +@router.post('/list') +async def list_search_app( + query: ListSearchAppsQuery = Depends(), + body: Optional[ListSearchAppsBody] = None, + current_user = Depends(get_current_user) +): + """列出搜索应用""" + if body is None: + body = ListSearchAppsBody() + + keywords = query.keywords or "" + page_number = int(query.page or 0) + items_per_page = int(query.page_size or 0) + orderby = query.orderby or "create_time" + desc = query.desc.lower() == "true" if query.desc else True - req = request.get_json() - owner_ids = req.get("owner_ids", []) + owner_ids = body.owner_ids or [] if body else [] try: if not owner_ids: # tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) @@ -171,12 +202,13 @@ def list_search_app(): return server_error_response(e) -@manager.route("/rm", methods=["post"]) # noqa: F821 -@login_required -@validate_request("search_id") -def rm(): - req = request.get_json() - search_id = req["search_id"] +@router.post('/rm') +async def rm( + request: DeleteSearchRequest, + current_user = Depends(get_current_user) +): + """删除搜索应用""" + search_id = request.search_id if not SearchService.accessible4deletion(search_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) diff --git a/api/apps/user_app_fastapi.py b/api/apps/user_app_fastapi.py index 16e98a8..88b8f3f 100644 --- a/api/apps/user_app_fastapi.py +++ b/api/apps/user_app_fastapi.py @@ -15,12 +15,15 @@ # import json import logging +import os import re import secrets +import string +import time from datetime import datetime from typing import Optional, Dict, Any -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi import APIRouter, Depends, HTTPException, Request, Response, Query, status from api.apps.models.auth_dependencies import get_current_user from fastapi.responses import RedirectResponse from pydantic import BaseModel, EmailStr @@ -60,6 +63,19 @@ from api.utils.api_utils import ( validate_request, ) from api.utils.crypt import decrypt +from rag.utils.redis_conn import REDIS_CONN +from api.apps import smtp_mail_server +from api.utils.web_utils import ( + send_email_html, + OTP_LENGTH, + OTP_TTL_SECONDS, + ATTEMPT_LIMIT, + ATTEMPT_LOCK_SECONDS, + RESEND_COOLDOWN_SECONDS, + otp_keys, + hash_code, + captcha_key, +) # 创建路由器 router = APIRouter() @@ -77,9 +93,8 @@ class RegisterRequest(BaseModel): password: str class UserSettingRequest(BaseModel): - nickname: Optional[str] = None - password: Optional[str] = None - new_password: Optional[str] = None + language: Optional[str] = None + class TenantInfoRequest(BaseModel): tenant_id: str @@ -88,6 +103,16 @@ class TenantInfoRequest(BaseModel): img2txt_id: str llm_id: str +class ForgetOtpRequest(BaseModel): + email: str + captcha: str + +class ForgetPasswordRequest(BaseModel): + email: str + otp: str + new_password: str + confirm_new_password: str + # 依赖项:获取当前用户 - 从 auth_dependencies 导入 @router.post("/login") @@ -481,3 +506,357 @@ async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) + +@router.get("/github_callback") +async def github_callback(code: Optional[str] = Query(None)): + """ + **Deprecated**, Use `/oauth/callback/` instead. + + GitHub OAuth callback endpoint. + """ + import requests + + if not code: + return RedirectResponse(url="/?error=missing_code") + + res = requests.post( + settings.GITHUB_OAUTH.get("url"), + data={ + "client_id": settings.GITHUB_OAUTH.get("client_id"), + "client_secret": settings.GITHUB_OAUTH.get("secret_key"), + "code": code, + }, + headers={"Accept": "application/json"}, + ) + res = res.json() + if "error" in res: + return RedirectResponse(url=f"/?error={res.get('error_description', res.get('error'))}") + + if "user:email" not in res.get("scope", "").split(","): + return RedirectResponse(url="/?error=user:email not in scope") + + access_token = res["access_token"] + user_info = user_info_from_github(access_token) + email_address = user_info["email"] + users = UserService.query(email=email_address) + user_id = get_uuid() + + if not users: + try: + try: + avatar = download_img(user_info["avatar_url"]) + except Exception as e: + logging.exception(e) + avatar = "" + + users = user_register( + user_id, + { + "access_token": access_token, + "email": email_address, + "avatar": avatar, + "nickname": user_info["login"], + "login_channel": "github", + "last_login_time": get_format_time(), + "is_superuser": False, + }, + ) + + if not users: + raise Exception(f"Fail to register {email_address}.") + if len(users) > 1: + raise Exception(f"Same email: {email_address} exists!") + + user = users[0] + return RedirectResponse(url=f"/?auth={user.get_id()}") + except Exception as e: + rollback_user_registration(user_id) + logging.exception(e) + return RedirectResponse(url=f"/?error={str(e)}") + + # User has already registered, try to log in + user = users[0] + user.access_token = get_uuid() + if user and hasattr(user, 'is_active') and user.is_active == "0": + return RedirectResponse(url="/?error=user_inactive") + user.save() + return RedirectResponse(url=f"/?auth={user.get_id()}") + +@router.get("/feishu_callback") +async def feishu_callback(code: Optional[str] = Query(None)): + """ + Feishu OAuth callback endpoint. + """ + import requests + + if not code: + return RedirectResponse(url="/?error=missing_code") + + app_access_token_res = requests.post( + settings.FEISHU_OAUTH.get("app_access_token_url"), + data=json.dumps( + { + "app_id": settings.FEISHU_OAUTH.get("app_id"), + "app_secret": settings.FEISHU_OAUTH.get("app_secret"), + } + ), + headers={"Content-Type": "application/json; charset=utf-8"}, + ) + app_access_token_res = app_access_token_res.json() + if app_access_token_res.get("code") != 0: + return RedirectResponse(url=f"/?error={app_access_token_res}") + + res = requests.post( + settings.FEISHU_OAUTH.get("user_access_token_url"), + data=json.dumps( + { + "grant_type": settings.FEISHU_OAUTH.get("grant_type"), + "code": code, + } + ), + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {app_access_token_res['app_access_token']}", + }, + ) + res = res.json() + if res.get("code") != 0: + return RedirectResponse(url=f"/?error={res.get('message', 'unknown_error')}") + + if "contact:user.email:readonly" not in res.get("data", {}).get("scope", "").split(): + return RedirectResponse(url="/?error=contact:user.email:readonly not in scope") + + access_token = res["data"]["access_token"] + user_info = user_info_from_feishu(access_token) + email_address = user_info["email"] + users = UserService.query(email=email_address) + user_id = get_uuid() + + if not users: + try: + try: + avatar = download_img(user_info["avatar_url"]) + except Exception as e: + logging.exception(e) + avatar = "" + + users = user_register( + user_id, + { + "access_token": access_token, + "email": email_address, + "avatar": avatar, + "nickname": user_info["en_name"], + "login_channel": "feishu", + "last_login_time": get_format_time(), + "is_superuser": False, + }, + ) + + if not users: + raise Exception(f"Fail to register {email_address}.") + if len(users) > 1: + raise Exception(f"Same email: {email_address} exists!") + + user = users[0] + return RedirectResponse(url=f"/?auth={user.get_id()}") + except Exception as e: + rollback_user_registration(user_id) + logging.exception(e) + return RedirectResponse(url=f"/?error={str(e)}") + + # User has already registered, try to log in + user = users[0] + if user and hasattr(user, 'is_active') and user.is_active == "0": + return RedirectResponse(url="/?error=user_inactive") + user.access_token = get_uuid() + user.save() + return RedirectResponse(url=f"/?auth={user.get_id()}") + +def user_info_from_feishu(access_token): + """从飞书获取用户信息""" + import requests + + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {access_token}", + } + res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) + user_info = res.json()["data"] + user_info["email"] = None if user_info.get("email") == "" else user_info["email"] + return user_info + +def user_info_from_github(access_token): + """从GitHub获取用户信息""" + import requests + + headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} + res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) + user_info = res.json() + email_info = requests.get( + f"https://api.github.com/user/emails?access_token={access_token}", + headers=headers, + ).json() + user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] + return user_info + +@router.get("/forget/captcha") +async def forget_get_captcha(email: str = Query(...)): + """ + GET /forget/captcha?email= + - Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = 60 seconds. + - Returns the captcha as a JPEG image. + """ + if not email: + return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required") + + users = UserService.query(email=email) + if not users: + return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email") + + # Generate captcha text + allowed = string.ascii_uppercase + string.digits + captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH)) + REDIS_CONN.set(captcha_key(email), captcha_text, 60) # Valid for 60 seconds + + from captcha.image import ImageCaptcha + image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70]) + img_bytes = image.generate(captcha_text).read() + + return Response(content=img_bytes, media_type="image/JPEG") + +@router.post("/forget/otp") +async def forget_send_otp(request: ForgetOtpRequest): + """ + POST /forget/otp + - Verify the image captcha stored at captcha:{email} (case-insensitive). + - On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email. + """ + email = request.email or "" + captcha = (request.captcha or "").strip() + + if not email or not captcha: + return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required") + + users = UserService.query(email=email) + if not users: + return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email") + + stored_captcha = REDIS_CONN.get(captcha_key(email)) + if not stored_captcha: + return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha") + if (stored_captcha or "").strip().lower() != captcha.lower(): + return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha") + + # Delete captcha to prevent reuse + REDIS_CONN.delete(captcha_key(email)) + + k_code, k_attempts, k_last, k_lock = otp_keys(email) + now = int(time.time()) + last_ts = REDIS_CONN.get(k_last) + if last_ts: + try: + elapsed = now - int(last_ts) + except Exception: + elapsed = RESEND_COOLDOWN_SECONDS + remaining = RESEND_COOLDOWN_SECONDS - elapsed + if remaining > 0: + return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds") + + # Generate OTP (uppercase letters only) and store hashed + otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH)) + salt = os.urandom(16) + code_hash = hash_code(otp, salt) + REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS) + REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS) + REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS) + REDIS_CONN.delete(k_lock) + + ttl_min = OTP_TTL_SECONDS // 60 + + if not smtp_mail_server: + logging.warning("SMTP mail server not initialized; skip sending email.") + else: + try: + send_email_html( + subject="Your Password Reset Code", + to_email=email, + template_key="reset_code", + code=otp, + ttl_min=ttl_min, + ) + except Exception: + return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email") + + return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent") + +@router.post("/forget") +async def forget(request: ForgetPasswordRequest): + """ + POST: Verify email + OTP and reset password, then log the user in. + Request JSON: { email, otp, new_password, confirm_new_password } + """ + email = request.email or "" + otp = (request.otp or "").strip() + new_pwd = request.new_password + new_pwd2 = request.confirm_new_password + + if not all([email, otp, new_pwd, new_pwd2]): + return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required") + + # For reset, passwords are provided as-is (no decrypt needed) + if new_pwd != new_pwd2: + return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match") + + users = UserService.query(email=email) + if not users: + return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email") + + user = users[0] + # Verify OTP from Redis + k_code, k_attempts, k_last, k_lock = otp_keys(email) + if REDIS_CONN.get(k_lock): + return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later") + + stored = REDIS_CONN.get(k_code) + if not stored: + return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp") + + try: + stored_hash, salt_hex = str(stored).split(":", 1) + salt = bytes.fromhex(salt_hex) + except Exception: + return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted") + + # Case-insensitive verification: OTP generated uppercase + calc = hash_code(otp.upper(), salt) + if calc != stored_hash: + # bump attempts + try: + attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1 + except Exception: + attempts = 1 + REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS) + if attempts >= ATTEMPT_LIMIT: + REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS) + return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp") + + # Success: consume OTP and reset password + REDIS_CONN.delete(k_code) + REDIS_CONN.delete(k_attempts) + REDIS_CONN.delete(k_last) + REDIS_CONN.delete(k_lock) + + try: + UserService.update_user_password(user.id, new_pwd) + except Exception as e: + logging.exception(e) + return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password") + + # Auto login (reuse login flow) + user.access_token = get_uuid() + user.update_time = (current_timestamp(),) + user.update_date = (datetime_format(datetime.now()),) + user.save() + msg = "Password reset successful. Logged in." + return construct_response(data=user.to_json(), auth=user.get_id(), message=msg) diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 2b60500..002b619 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -8,6 +8,10 @@ minio: user: 'rag_flow' password: 'infini_rag_flow' host: 'localhost:9000' +es: + hosts: 'http://localhost:1200' + username: 'elastic' + password: 'infini_rag_flow' os: hosts: 'http://localhost:1201' username: 'admin' diff --git a/docker/.env b/docker/.env index d4f3bc2..67062ea 100644 --- a/docker/.env +++ b/docker/.env @@ -3,7 +3,7 @@ # - `elasticsearch` (default) # - `infinity` (https://github.com/infiniflow/infinity) # - `opensearch` (https://github.com/opensearch-project/OpenSearch) -DOC_ENGINE=opensearch +DOC_ENGINE=elasticsearch # ------------------------------ # docker env var for specifying vector db type at startup @@ -98,7 +98,7 @@ ADMIN_SVR_HTTP_PORT=9381 # The RAGFlow Docker image to download. # Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models. -RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi +RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi-web # # To download the RAGFlow Docker image with embedding models, uncomment the following line instead: # RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1 diff --git a/docker/docker-compose-base.yml b/docker/docker-compose-base.yml index d7aa601..1088a00 100644 --- a/docker/docker-compose-base.yml +++ b/docker/docker-compose-base.yml @@ -1,36 +1,34 @@ services: - opensearch01: - container_name: ragflow-opensearch-01 + es01: + container_name: ragflow-es-01 profiles: - - opensearch - image: hub.icert.top/opensearchproject/opensearch:2.19.1 + - elasticsearch + image: elasticsearch:${STACK_VERSION} volumes: - - osdata01:/usr/share/opensearch/data + - esdata01:/usr/share/elasticsearch/data ports: - - ${OS_PORT}:9201 + - ${ES_PORT}:9200 env_file: .env environment: - - node.name=opensearch01 - - OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD} - - OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_PASSWORD} + - node.name=es01 + - ELASTIC_PASSWORD=${ELASTIC_PASSWORD} - bootstrap.memory_lock=false - discovery.type=single-node - - plugins.security.disabled=false - - plugins.security.ssl.http.enabled=false - - plugins.security.ssl.transport.enabled=true + - xpack.security.enabled=true + - xpack.security.http.ssl.enabled=false + - xpack.security.transport.ssl.enabled=false - cluster.routing.allocation.disk.watermark.low=5gb - cluster.routing.allocation.disk.watermark.high=3gb - cluster.routing.allocation.disk.watermark.flood_stage=2gb - TZ=${TIMEZONE} - - http.port=9201 mem_limit: ${MEM_LIMIT} ulimits: memlock: soft: -1 hard: -1 healthcheck: - test: ["CMD-SHELL", "curl http://localhost:9201"] + test: ["CMD-SHELL", "curl http://localhost:9200"] interval: 10s timeout: 10s retries: 120 diff --git a/rag/nlp/search.py b/rag/nlp/search.py index c10b803..1aa9f4e 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -64,7 +64,8 @@ class Dealer: if key in req and req[key] is not None: condition[field] = req[key] # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns. - for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]: + for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", + "removed_kwd"]: if key in req and req[key] is not None: condition[key] = req[key] return condition @@ -135,7 +136,8 @@ class Dealer: matchText, _ = self.qryr.question(qst, min_match=0.1) matchDense.extra_options["similarity"] = 0.17 res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], - orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) + orderBy, offset, limit, idx_names, kb_ids, + rank_feature=rank_feature) total = self.dataStore.getTotal(res) logging.debug("Dealer.search 2 TOTAL: {}".format(total)) @@ -212,8 +214,9 @@ class Dealer: ans_v, _ = embd_mdl.encode(pieces_) for i in range(len(chunk_v)): if len(ans_v[0]) != len(chunk_v[i]): - chunk_v[i] = [0.0]*len(ans_v[0]) - logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i]))) + chunk_v[i] = [0.0] * len(ans_v[0]) + logging.warning( + "The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i]))) assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( len(ans_v[0]), len(chunk_v[0])) @@ -267,7 +270,7 @@ class Dealer: if not query_rfea: return np.array([0 for _ in range(len(search_res.ids))]) + pageranks - q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD])) + q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD])) for i in search_res.ids: nor, denor = 0, 0 if not search_res.field[i].get(TAG_FLD): @@ -280,8 +283,8 @@ class Dealer: if denor == 0: rank_fea.append(0) else: - rank_fea.append(nor/np.sqrt(denor)/q_denor) - return np.array(rank_fea)*10. + pageranks + rank_fea.append(nor / np.sqrt(denor) / q_denor) + return np.array(rank_fea) * 10. + pageranks def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", @@ -343,7 +346,7 @@ class Dealer: ## For rank feature(tag_fea) scores. rank_fea = self._rank_feature_scores(rank_feature, sres) - return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim + return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): return self.qryr.hybrid_similarity(ans_embd, @@ -360,13 +363,13 @@ class Dealer: return ranks # Ensure RERANK_LIMIT is multiple of page_size - RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1 - req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT, + RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1 + req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size * page / RERANK_LIMIT), + "size": RERANK_LIMIT, "question": question, "vector": True, "topk": top, "similarity": similarity_threshold, "available_int": 1} - if isinstance(tenant_ids, str): tenant_ids = tenant_ids.split(",") @@ -392,15 +395,15 @@ class Dealer: tsim = sim vsim = sim # Already paginated in search function - begin = ((page % (RERANK_LIMIT//page_size)) - 1) * page_size - sim = sim[begin : begin + page_size] + begin = ((page % (RERANK_LIMIT // page_size)) - 1) * page_size + sim = sim[begin: begin + page_size] sim_np = np.array(sim) idx = np.argsort(sim_np * -1) dim = len(sres.query_vector) vector_column = f"q_{dim}_vec" zero_vector = [0.0] * dim filtered_count = (sim_np >= similarity_threshold).sum() - ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error + ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error for i in idx: if sim[i] < similarity_threshold: break @@ -447,8 +450,8 @@ class Dealer: ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k, - v in sorted(ranks["doc_aggs"].items(), - key=lambda x: x[1]["count"] * -1)] + v in sorted(ranks["doc_aggs"].items(), + key=lambda x: x[1]["count"] * -1)] ranks["chunks"] = ranks["chunks"][:page_size] return ranks @@ -505,13 +508,14 @@ class Dealer: def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000): idx_nm = index_name(tenant_id) - match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn) + match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), + keywords_topn) res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"]) aggs = self.dataStore.getAggregation(res, "tag_kwd") if not aggs: return False cnt = np.sum([c for _, c in aggs]) - tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], + tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags] doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0} return True @@ -527,11 +531,11 @@ class Dealer: if not aggs: return {} cnt = np.sum([c for _, c in aggs]) - tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], + tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], key=lambda x: x[1] * -1)[:topn_tags] return {a.replace(".", "_"): max(1, c) for a, c in tag_fea} - def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6): + def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6): if not chunks: return [] idx_nms = [index_name(tid) for tid in tenant_ids] @@ -541,9 +545,10 @@ class Dealer: ranks[ck["doc_id"]] = 0 ranks[ck["doc_id"]] += ck["similarity"] doc_id2kb_id[ck["doc_id"]] = ck["kb_id"] - doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0] + doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0] kb_ids = [doc_id2kb_id[doc_id]] - es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms, + es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], + OrderByExpr(), 0, 128, idx_nms, kb_ids) toc = [] dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"]) @@ -555,10 +560,10 @@ class Dealer: if not toc: return chunks - ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2) + ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2) if not ids: return chunks - + vector_size = 1024 id2idx = {ck["chunk_id"]: i for i, ck in enumerate(chunks)} for cid, sim in ids: @@ -589,4 +594,4 @@ class Dealer: break chunks.append(d) - return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn] + return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn] diff --git a/rag/res/deepdoc/.gitattributes b/rag/res/deepdoc/.gitattributes new file mode 100644 index 0000000..a6344aa --- /dev/null +++ b/rag/res/deepdoc/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/rag/res/deepdoc/layout.laws.onnx b/rag/res/deepdoc/layout.laws.onnx new file mode 100644 index 0000000..6dfc9ef Binary files /dev/null and b/rag/res/deepdoc/layout.laws.onnx differ diff --git a/rag/res/deepdoc/layout.manual.onnx b/rag/res/deepdoc/layout.manual.onnx new file mode 100644 index 0000000..6dfc9ef Binary files /dev/null and b/rag/res/deepdoc/layout.manual.onnx differ diff --git a/rag/res/deepdoc/layout.onnx b/rag/res/deepdoc/layout.onnx new file mode 100644 index 0000000..6dfc9ef Binary files /dev/null and b/rag/res/deepdoc/layout.onnx differ diff --git a/rag/res/deepdoc/layout.paper.onnx b/rag/res/deepdoc/layout.paper.onnx new file mode 100644 index 0000000..6dfc9ef Binary files /dev/null and b/rag/res/deepdoc/layout.paper.onnx differ diff --git a/rag/res/deepdoc/tsr.onnx b/rag/res/deepdoc/tsr.onnx new file mode 100644 index 0000000..445aa09 Binary files /dev/null and b/rag/res/deepdoc/tsr.onnx differ