修复 画布,mcp服务,搜索,文档的接口

This commit is contained in:
2025-11-07 09:34:35 +08:00
parent c5f8fe06e7
commit 54532747d2
21 changed files with 1023 additions and 272 deletions

View File

@@ -165,17 +165,23 @@ def setup_routes(app: FastAPI):
from api.apps.tenant_app import router as tenant_router from api.apps.tenant_app import router as tenant_router
from api.apps.dialog_app import router as dialog_router from api.apps.dialog_app import router as dialog_router
from api.apps.system_app import router as system_router from api.apps.system_app import router as system_router
from api.apps.search_app import router as search_router
from api.apps.conversation_app import router as conversation_router
from api.apps.file_app import router as file_router
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"]) app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"]) app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"])
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"]) app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"]) app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"]) app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"]) app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp_server", tags=["MCP"])
app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"]) app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"])
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"]) app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
app.include_router(dialog_router, prefix=f"/{API_VERSION}/dialog", tags=["Dialog"]) app.include_router(dialog_router, prefix=f"/{API_VERSION}/dialog", tags=["Dialog"])
app.include_router(system_router, prefix=f"/{API_VERSION}/system", tags=["System"]) app.include_router(system_router, prefix=f"/{API_VERSION}/system", tags=["System"])
app.include_router(search_router, prefix=f"/{API_VERSION}/search", tags=["Search"])
app.include_router(conversation_router, prefix=f"/{API_VERSION}/conversation", tags=["Conversation"])
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])

View File

@@ -17,8 +17,10 @@ import json
import re import re
import logging import logging
from copy import deepcopy from copy import deepcopy
from flask import Response, request from typing import Optional
from flask_login import current_user, login_required from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
from fastapi.responses import StreamingResponse
from api import settings from api import settings
from api.db import LLMType from api.db import LLMType
from api.db.db_models import APIToken from api.db.db_models import APIToken
@@ -28,15 +30,35 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from api.utils import get_uuid
from rag.prompts.template import load_prompt from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format from rag.prompts.generator import chunks_format
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.conversation_models import (
SetConversationRequest,
DeleteConversationsRequest,
CompletionRequest,
TTSRequest,
DeleteMessageRequest,
ThumbupRequest,
AskRequest,
MindmapRequest,
RelatedQuestionsRequest,
)
@manager.route("/set", methods=["POST"]) # noqa: F821 # 创建路由器
@login_required router = APIRouter()
def set_conversation():
req = request.json
@router.post('/set')
async def set_conversation(
request: SetConversationRequest,
current_user = Depends(get_current_user)
):
"""设置对话"""
req = request.model_dump(exclude_unset=True)
conv_id = req.get("conversation_id") conv_id = req.get("conversation_id")
is_new = req.get("is_new") is_new = req.get("is_new")
name = req.get("name", "New conversation") name = req.get("name", "New conversation")
@@ -45,9 +67,9 @@ def set_conversation():
if len(name) > 255: if len(name) > 255:
name = name[0:255] name = name[0:255]
del req["is_new"]
if not is_new: if not is_new:
del req["conversation_id"] if not conv_id:
return get_data_error_result(message="conversation_id is required when is_new is False!")
try: try:
if not ConversationService.update_by_id(conv_id, req): if not ConversationService.update_by_id(conv_id, req):
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
@@ -64,7 +86,7 @@ def set_conversation():
if not e: if not e:
return get_data_error_result(message="Dialog not found") return get_data_error_result(message="Dialog not found")
conv = { conv = {
"id": conv_id, "id": conv_id or get_uuid(),
"dialog_id": req["dialog_id"], "dialog_id": req["dialog_id"],
"name": name, "name": name,
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}], "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
@@ -77,12 +99,14 @@ def set_conversation():
return server_error_response(e) return server_error_response(e)
@manager.route("/get", methods=["GET"]) # noqa: F821 @router.get('/get')
@login_required async def get(
def get(): conversation_id: str = Query(..., description="对话ID"),
conv_id = request.args["conversation_id"] current_user = Depends(get_current_user)
):
"""获取对话"""
try: try:
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conversation_id)
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
@@ -107,15 +131,27 @@ def get():
return server_error_response(e) return server_error_response(e)
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821 @router.get('/getsse/{dialog_id}')
def getsse(dialog_id): async def getsse(
token = request.headers.get("Authorization").split() dialog_id: str,
if len(token) != 2: authorization: Optional[str] = Header(None, alias="Authorization")
):
"""通过 SSE 获取对话(使用 API token 认证)"""
if not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header is required"
)
token_parts = authorization.split()
if len(token_parts) != 2:
return get_data_error_result(message='Authorization is not valid!"') return get_data_error_result(message='Authorization is not valid!"')
token = token[1] token = token_parts[1]
objs = APIToken.query(beta=token) objs = APIToken.query(beta=token)
if not objs: if not objs:
return get_data_error_result(message='Authentication error: API key is invalid!"') return get_data_error_result(message='Authentication error: API key is invalid!"')
try: try:
e, conv = DialogService.get_by_id(dialog_id) e, conv = DialogService.get_by_id(dialog_id)
if not e: if not e:
@@ -128,12 +164,14 @@ def getsse(dialog_id):
return server_error_response(e) return server_error_response(e)
@manager.route("/rm", methods=["POST"]) # noqa: F821 @router.post('/rm')
@login_required async def rm(
def rm(): request: DeleteConversationsRequest,
conv_ids = request.json["conversation_ids"] current_user = Depends(get_current_user)
):
"""删除对话"""
try: try:
for cid in conv_ids: for cid in request.conversation_ids:
exist, conv = ConversationService.get_by_id(cid) exist, conv = ConversationService.get_by_id(cid)
if not exist: if not exist:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
@@ -149,10 +187,12 @@ def rm():
return server_error_response(e) return server_error_response(e)
@manager.route("/list", methods=["GET"]) # noqa: F821 @router.get('/list')
@login_required async def list_conversation(
def list_conversation(): dialog_id: str = Query(..., description="对话ID"),
dialog_id = request.args["dialog_id"] current_user = Depends(get_current_user)
):
"""列出对话"""
try: try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id): if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR) return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
@@ -164,11 +204,13 @@ def list_conversation():
return server_error_response(e) return server_error_response(e)
@manager.route("/completion", methods=["POST"]) # noqa: F821 @router.post('/completion')
@login_required async def completion(
@validate_request("conversation_id", "messages") request: CompletionRequest,
def completion(): current_user = Depends(get_current_user)
req = request.json ):
"""完成请求(聊天完成)"""
req = request.model_dump(exclude_unset=True)
msg = [] msg = []
for m in req["messages"]: for m in req["messages"]:
if m["role"] == "system": if m["role"] == "system":
@@ -176,6 +218,10 @@ def completion():
if m["role"] == "assistant" and not msg: if m["role"] == "assistant" and not msg:
continue continue
msg.append(m) msg.append(m)
if not msg:
return get_data_error_result(message="No valid messages found!")
message_id = msg[-1].get("id") message_id = msg[-1].get("id")
chat_model_id = req.get("llm_id", "") chat_model_id = req.get("llm_id", "")
req.pop("llm_id", None) req.pop("llm_id", None)
@@ -217,6 +263,7 @@ def completion():
dia.llm_setting = chat_model_config dia.llm_setting = chat_model_config
is_embedded = bool(chat_model_id) is_embedded = bool(chat_model_id)
def stream(): def stream():
nonlocal dia, msg, req, conv nonlocal dia, msg, req, conv
try: try:
@@ -230,14 +277,18 @@ def completion():
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True): stream_enabled = request.stream if request.stream is not None else True
resp = Response(stream(), mimetype="text/event-stream") if stream_enabled:
resp.headers.add_header("Cache-control", "no-cache") return StreamingResponse(
resp.headers.add_header("Connection", "keep-alive") stream(),
resp.headers.add_header("X-Accel-Buffering", "no") media_type="text/event-stream",
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") headers={
return resp "Cache-control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Type": "text/event-stream; charset=utf-8"
}
)
else: else:
answer = None answer = None
for ans in chat(dia, msg, **req): for ans in chat(dia, msg, **req):
@@ -250,11 +301,13 @@ def completion():
return server_error_response(e) return server_error_response(e)
@manager.route("/tts", methods=["POST"]) # noqa: F821 @router.post('/tts')
@login_required async def tts(
def tts(): request: TTSRequest,
req = request.json current_user = Depends(get_current_user)
text = req["text"] ):
"""文本转语音"""
text = request.text
tenants = TenantService.get_info_by(current_user.id) tenants = TenantService.get_info_by(current_user.id)
if not tenants: if not tenants:
@@ -274,28 +327,32 @@ def tts():
except Exception as e: except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8") yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
resp = Response(stream_audio(), mimetype="audio/mpeg") return StreamingResponse(
resp.headers.add_header("Cache-Control", "no-cache") stream_audio(),
resp.headers.add_header("Connection", "keep-alive") media_type="audio/mpeg",
resp.headers.add_header("X-Accel-Buffering", "no") headers={
"Cache-Control": "no-cache",
return resp "Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821 @router.post('/delete_msg')
@login_required async def delete_msg(
@validate_request("conversation_id", "message_id") request: DeleteMessageRequest,
def delete_msg(): current_user = Depends(get_current_user)
req = request.json ):
e, conv = ConversationService.get_by_id(req["conversation_id"]) """删除消息"""
e, conv = ConversationService.get_by_id(request.conversation_id)
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
conv = conv.to_dict() conv = conv.to_dict()
for i, msg in enumerate(conv["message"]): for i, msg in enumerate(conv["message"]):
if req["message_id"] != msg.get("id", ""): if request.message_id != msg.get("id", ""):
continue continue
assert conv["message"][i + 1]["id"] == req["message_id"] assert conv["message"][i + 1]["id"] == request.message_id
conv["message"].pop(i) conv["message"].pop(i)
conv["message"].pop(i) conv["message"].pop(i)
conv["reference"].pop(max(0, i // 2 - 1)) conv["reference"].pop(max(0, i // 2 - 1))
@@ -305,19 +362,21 @@ def delete_msg():
return get_json_result(data=conv) return get_json_result(data=conv)
@manager.route("/thumbup", methods=["POST"]) # noqa: F821 @router.post('/thumbup')
@login_required async def thumbup(
@validate_request("conversation_id", "message_id") request: ThumbupRequest,
def thumbup(): current_user = Depends(get_current_user)
req = request.json ):
e, conv = ConversationService.get_by_id(req["conversation_id"]) """点赞/点踩"""
e, conv = ConversationService.get_by_id(request.conversation_id)
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
up_down = req.get("thumbup")
feedback = req.get("feedback", "") up_down = request.thumbup
feedback = request.feedback or ""
conv = conv.to_dict() conv = conv.to_dict()
for i, msg in enumerate(conv["message"]): for i, msg in enumerate(conv["message"]):
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant": if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
if up_down: if up_down:
msg["thumbup"] = True msg["thumbup"] = True
if "feedback" in msg: if "feedback" in msg:
@@ -332,14 +391,15 @@ def thumbup():
return get_json_result(data=conv) return get_json_result(data=conv)
@manager.route("/ask", methods=["POST"]) # noqa: F821 @router.post('/ask')
@login_required async def ask_about(
@validate_request("question", "kb_ids") request: AskRequest,
def ask_about(): current_user = Depends(get_current_user)
req = request.json ):
"""提问"""
uid = current_user.id uid = current_user.id
search_id = req.get("search_id", "") search_id = request.search_id or ""
search_app = None search_app = None
search_config = {} search_config = {}
if search_id: if search_id:
@@ -348,53 +408,58 @@ def ask_about():
search_config = search_app.get("search_config", {}) search_config = search_app.get("search_config", {})
def stream(): def stream():
nonlocal req, uid nonlocal request, uid
try: try:
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e: except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream") return StreamingResponse(
resp.headers.add_header("Cache-control", "no-cache") stream(),
resp.headers.add_header("Connection", "keep-alive") media_type="text/event-stream",
resp.headers.add_header("X-Accel-Buffering", "no") headers={
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") "Cache-control": "no-cache",
return resp "Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Type": "text/event-stream; charset=utf-8"
}
)
@manager.route("/mindmap", methods=["POST"]) # noqa: F821 @router.post('/mindmap')
@login_required async def mindmap(
@validate_request("question", "kb_ids") request: MindmapRequest,
def mindmap(): current_user = Depends(get_current_user)
req = request.json ):
search_id = req.get("search_id", "") """思维导图"""
search_id = request.search_id or ""
search_app = SearchService.get_detail(search_id) if search_id else {} search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {} search_config = search_app.get("search_config", {}) if search_app else {}
kb_ids = search_config.get("kb_ids", []) kb_ids = search_config.get("kb_ids", [])
kb_ids.extend(req["kb_ids"]) kb_ids.extend(request.kb_ids)
kb_ids = list(set(kb_ids)) kb_ids = list(set(kb_ids))
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config) mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
if "error" in mind_map: if "error" in mind_map:
return server_error_response(Exception(mind_map["error"])) return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map) return get_json_result(data=mind_map)
@manager.route("/related_questions", methods=["POST"]) # noqa: F821 @router.post('/related_questions')
@login_required async def related_questions(
@validate_request("question") request: RelatedQuestionsRequest,
def related_questions(): current_user = Depends(get_current_user)
req = request.json ):
"""相关问题"""
search_id = req.get("search_id", "") search_id = request.search_id or ""
search_config = {} search_config = {}
if search_id: if search_id:
if search_app := SearchService.get_detail(search_id): if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {}) search_config = search_app.get("search_config", {})
question = req["question"] question = request.question
chat_id = search_config.get("chat_id", "") chat_id = search_config.get("chat_id", "")
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id) chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)

View File

@@ -72,18 +72,17 @@ router = APIRouter()
@router.post("/upload") @router.post("/upload")
async def upload( async def upload(
kb_id: str = Form(...), kb_id: str = Form(...),
files: List[UploadFile] = File(...), file: UploadFile = File(...),
current_user = Depends(get_current_user) current_user = Depends(get_current_user)
): ):
"""上传文档""" """上传文档"""
if not files: if not file:
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR) return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
for file_obj in files: if not file.filename or file.filename == "":
if not file_obj.filename or file_obj.filename == "": return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR) if len(file.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
@@ -91,7 +90,7 @@ async def upload(
if not check_kb_team_permission(kb, current_user.id): if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
err, uploaded_files = FileService.upload_document(kb, files, current_user.id) err, uploaded_files = FileService.upload_document(kb, [file], current_user.id)
if err: if err:
return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR) return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)

View File

@@ -17,15 +17,23 @@ import logging
import os import os
import pathlib import pathlib
import re import re
from typing import Optional, List
import flask from fastapi import APIRouter, Depends, Query, UploadFile, File, Form
from flask import request from fastapi.responses import Response
from flask_login import login_required, current_user
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.file_models import (
CreateFileRequest,
DeleteFilesRequest,
RenameFileRequest,
MoveFilesRequest,
)
from api.common.check_team_permission import check_file_team_permission from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result
from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType, FileSource from api.db import FileType, FileSource
from api.db.services import duplicate_name from api.db.services import duplicate_name
@@ -36,35 +44,41 @@ from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP from api.utils.web_utils import CONTENT_TYPE_MAP
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
# 创建路由器
router = APIRouter()
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required @router.post('/upload')
# @validate_request("parent_id") async def upload(
def upload(): files: List[UploadFile] = File(...),
pf_id = request.form.get("parent_id") parent_id: Optional[str] = Form(None),
current_user = Depends(get_current_user)
):
"""上传文件"""
pf_id = parent_id
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
if 'file' not in request.files: if not files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in files:
if file_obj.filename == '': if not file_obj.filename or file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
file_res = [] file_res = []
try: try:
e, pf_folder = FileService.get_by_id(pf_id) e, pf_folder = FileService.get_by_id(pf_id)
if not e: if not e:
return get_data_error_result( message="Can't find this folder!") return get_data_error_result(message="Can't find this folder!")
for file_obj in file_objs: for file_obj in files:
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER: if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
return get_data_error_result( message="Exceed the maximum file number of a free user!") return get_data_error_result(message="Exceed the maximum file number of a free user!")
# split file name path # split file name path
if not file_obj.filename: if not file_obj.filename:
@@ -97,7 +111,7 @@ def upload():
location = file_obj_names[file_len - 1] location = file_obj_names[file_len - 1]
while STORAGE_IMPL.obj_exist(last_folder.id, location): while STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_" location += "_"
blob = file_obj.read() blob = await file_obj.read()
filename = duplicate_name( filename = duplicate_name(
FileService.query, FileService.query,
name=file_obj_names[file_len - 1], name=file_obj_names[file_len - 1],
@@ -120,13 +134,16 @@ def upload():
return server_error_response(e) return server_error_response(e)
@manager.route('/create', methods=['POST']) # noqa: F821 @router.post('/create')
@login_required async def create(
@validate_request("name") request: CreateFileRequest,
def create(): current_user = Depends(get_current_user)
req = request.json ):
pf_id = request.json.get("parent_id") """创建文件/文件夹"""
input_file_type = request.json.get("type") req = request.model_dump(exclude_unset=True)
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
@@ -160,17 +177,22 @@ def create():
return server_error_response(e) return server_error_response(e)
@manager.route('/list', methods=['GET']) # noqa: F821 @router.get('/list')
@login_required async def list_files(
def list_files(): parent_id: Optional[str] = Query(None, description="父文件夹ID"),
pf_id = request.args.get("parent_id") keywords: Optional[str] = Query("", description="搜索关键词"),
page: Optional[int] = Query(1, description="页码"),
page_size: Optional[int] = Query(15, description="每页数量"),
orderby: Optional[str] = Query("create_time", description="排序字段"),
desc: Optional[bool] = Query(True, description="是否降序"),
current_user = Depends(get_current_user)
):
"""列出文件"""
pf_id = parent_id
keywords = request.args.get("keywords", "") page_number = int(page) if page else 1
items_per_page = int(page_size) if page_size else 15
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 15))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
@@ -192,9 +214,11 @@ def list_files():
return server_error_response(e) return server_error_response(e)
@manager.route('/root_folder', methods=['GET']) # noqa: F821 @router.get('/root_folder')
@login_required async def get_root_folder(
def get_root_folder(): current_user = Depends(get_current_user)
):
"""获取根文件夹"""
try: try:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
return get_json_result(data={"root_folder": root_folder}) return get_json_result(data={"root_folder": root_folder})
@@ -202,10 +226,12 @@ def get_root_folder():
return server_error_response(e) return server_error_response(e)
@manager.route('/parent_folder', methods=['GET']) # noqa: F821 @router.get('/parent_folder')
@login_required async def get_parent_folder(
def get_parent_folder(): file_id: str = Query(..., description="文件ID"),
file_id = request.args.get("file_id") current_user = Depends(get_current_user)
):
"""获取父文件夹"""
try: try:
e, file = FileService.get_by_id(file_id) e, file = FileService.get_by_id(file_id)
if not e: if not e:
@@ -217,10 +243,12 @@ def get_parent_folder():
return server_error_response(e) return server_error_response(e)
@manager.route('/all_parent_folder', methods=['GET']) # noqa: F821 @router.get('/all_parent_folder')
@login_required async def get_all_parent_folders(
def get_all_parent_folders(): file_id: str = Query(..., description="文件ID"),
file_id = request.args.get("file_id") current_user = Depends(get_current_user)
):
"""获取所有父文件夹"""
try: try:
e, file = FileService.get_by_id(file_id) e, file = FileService.get_by_id(file_id)
if not e: if not e:
@@ -235,12 +263,13 @@ def get_all_parent_folders():
return server_error_response(e) return server_error_response(e)
@manager.route("/rm", methods=["POST"]) # noqa: F821 @router.post("/rm")
@login_required async def rm(
@validate_request("file_ids") request: DeleteFilesRequest,
def rm(): current_user = Depends(get_current_user)
req = request.json ):
file_ids = req["file_ids"] """删除文件"""
file_ids = request.file_ids
def _delete_single_file(file): def _delete_single_file(file):
try: try:
@@ -296,11 +325,13 @@ def rm():
return server_error_response(e) return server_error_response(e)
@manager.route('/rename', methods=['POST']) # noqa: F821 @router.post('/rename')
@login_required async def rename(
@validate_request("file_id", "name") request: RenameFileRequest,
def rename(): current_user = Depends(get_current_user)
req = request.json ):
"""重命名文件"""
req = request.model_dump()
try: try:
e, file = FileService.get_by_id(req["file_id"]) e, file = FileService.get_by_id(req["file_id"])
if not e: if not e:
@@ -314,8 +345,8 @@ def rename():
data=False, data=False,
message="The extension of file can't be changed", message="The extension of file can't be changed",
code=settings.RetCode.ARGUMENT_ERROR) code=settings.RetCode.ARGUMENT_ERROR)
for file in FileService.query(name=req["name"], pf_id=file.parent_id): for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
if file.name == req["name"]: if existing_file.name == req["name"]:
return get_data_error_result( return get_data_error_result(
message="Duplicated file name in the same folder.") message="Duplicated file name in the same folder.")
@@ -336,9 +367,12 @@ def rename():
return server_error_response(e) return server_error_response(e)
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821 @router.get('/get/{file_id}')
@login_required async def get(
def get(file_id): file_id: str,
current_user = Depends(get_current_user)
):
"""获取文件内容"""
try: try:
e, file = FileService.get_by_id(file_id) e, file = FileService.get_by_id(file_id)
if not e: if not e:
@@ -351,25 +385,28 @@ def get(file_id):
b, n = File2DocumentService.get_storage_address(file_id=file_id) b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = STORAGE_IMPL.get(b, n) blob = STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower()) ext = re.search(r"\.([^.]+)$", file.name.lower())
ext = ext.group(1) if ext else None ext = ext.group(1) if ext else None
content_type = "application/octet-stream"
if ext: if ext:
if file.type == FileType.VISUAL.value: if file.type == FileType.VISUAL.value:
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}") content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
else: else:
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}") content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
response.headers.set("Content-Type", content_type)
return response return Response(content=blob, media_type=content_type)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route("/mv", methods=["POST"]) # noqa: F821 @router.post("/mv")
@login_required async def move(
@validate_request("src_file_ids", "dest_file_id") request: MoveFilesRequest,
def move(): current_user = Depends(get_current_user)
req = request.json ):
"""移动文件"""
req = request.model_dump()
try: try:
file_ids = req["src_file_ids"] file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"] dest_parent_id = req["dest_file_id"]

View File

@@ -169,14 +169,17 @@ async def update(
): ):
"""更新知识库""" """更新知识库"""
req = request.model_dump(exclude_unset=True) req = request.model_dump(exclude_unset=True)
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.") # 验证 name 字段(如果提供)
if req["name"].strip() == "": if "name" in req:
return get_data_error_result(message="Dataset name can't be empty.") if not isinstance(req["name"], str):
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT: return get_data_error_result(message="Dataset name must be string.")
return get_data_error_result( if req["name"].strip() == "":
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}") return get_data_error_result(message="Dataset name can't be empty.")
req["name"] = req["name"].strip() if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
# 验证不允许的参数 # 验证不允许的参数
not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"] not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"]
@@ -202,7 +205,8 @@ async def update(
return get_data_error_result( return get_data_error_result(
message="Can't find this knowledgebase!") message="Can't find this knowledgebase!")
if req["name"].lower() != kb.name.lower() \ # 检查名称重复(仅在提供新名称时)
if "name" in req and req["name"].lower() != kb.name.lower() \
and len( and len(
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1: KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
return get_data_error_result( return get_data_error_result(

View File

@@ -0,0 +1,84 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
class SetConversationRequest(BaseModel):
    """Request model for setting a conversation."""
    # Optional; defaults to None when the client does not send it.
    conversation_id: Optional[str] = None
    # Required flag.
    # NOTE(review): presumably selects create vs. update in the handler; confirm against the route.
    is_new: bool
    # Optional display name, limited to 255 characters; defaults to "New conversation".
    name: Optional[str] = Field(default="New conversation", max_length=255)
    # Required.
    dialog_id: str
class DeleteConversationsRequest(BaseModel):
    """Request model for deleting conversations."""
    # Required list of conversation ids.
    conversation_ids: List[str]
class CompletionRequest(BaseModel):
    """Request model for a completion (chat completion)."""
    # Required.
    conversation_id: str
    # Required list of free-form message dictionaries (no per-key validation here).
    messages: List[Dict[str, Any]]
    # Optional; defaults to None.
    llm_id: Optional[str] = None
    # Optional; defaults to True.
    stream: Optional[bool] = True
    # The remaining generation parameters are all optional and default to None.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    max_tokens: Optional[int] = None
class TTSRequest(BaseModel):
    """Request model for text-to-speech."""
    # Required text payload.
    text: str
class DeleteMessageRequest(BaseModel):
    """Request model for deleting a message."""
    # Both identifiers are required.
    conversation_id: str
    message_id: str
class ThumbupRequest(BaseModel):
    """Request model for a thumbs-up / thumbs-down on a message."""
    # Both identifiers are required.
    conversation_id: str
    message_id: str
    # Optional; defaults to None.
    thumbup: Optional[bool] = None
    # Optional feedback text; defaults to an empty string.
    feedback: Optional[str] = ""
class AskRequest(BaseModel):
    """Request model for asking a question."""
    # Required question text.
    question: str
    # Required list of knowledge base ids.
    kb_ids: List[str]
    # Optional; defaults to an empty string.
    search_id: Optional[str] = ""
class MindmapRequest(BaseModel):
    """Request model for a mind map."""
    # Required question text.
    question: str
    # Required list of knowledge base ids.
    kb_ids: List[str]
    # Optional; defaults to an empty string.
    search_id: Optional[str] = ""
class RelatedQuestionsRequest(BaseModel):
    """Request model for related questions."""
    # Required question text.
    question: str
    # Optional; defaults to an empty string.
    search_id: Optional[str] = ""

View File

@@ -0,0 +1,43 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List
from pydantic import BaseModel, Field
class CreateFileRequest(BaseModel):
    """Request model for creating a file or folder."""
    # Required name of the new entry.
    name: str
    # Optional; defaults to None.
    parent_id: Optional[str] = None
    # Optional; defaults to None. The field name is part of the request schema,
    # so it intentionally mirrors the JSON key even though it shadows a builtin.
    type: Optional[str] = None
class DeleteFilesRequest(BaseModel):
    """Request model for deleting files."""
    # Required list of file ids.
    file_ids: List[str]
class RenameFileRequest(BaseModel):
    """Request model for renaming a file."""
    # Required id of the file to rename.
    file_id: str
    # Required new name.
    name: str
class MoveFilesRequest(BaseModel):
    """Request model for moving files."""
    # Required list of ids of the files to move.
    src_file_ids: List[str]
    # Required; the move handler reads this value as the destination parent id.
    dest_file_id: str

View File

@@ -60,9 +60,16 @@ class CreateKnowledgeBaseRequest(BaseModel):
class UpdateKnowledgeBaseRequest(BaseModel): class UpdateKnowledgeBaseRequest(BaseModel):
"""更新知识库请求""" """更新知识库请求"""
kb_id: str kb_id: str
name: str name: Optional[str] = None
description: str avatar: Optional[str] = None
parser_id: str language: Optional[str] = None
description: Optional[str] = None
permission: Optional[str] = None
doc_num: Optional[int] = None
token_num: Optional[int] = None
chunk_num: Optional[int] = None
parser_id: Optional[str] = None
embd_id: Optional[str] = None
pagerank: Optional[int] = None pagerank: Optional[int] = None
# 其他可选字段,但排除 id, tenant_id, created_by, create_time, update_time, create_date, update_date # 其他可选字段,但排除 id, tenant_id, created_by, create_time, update_time, create_date, update_date

View File

@@ -0,0 +1,53 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
class CreateSearchRequest(BaseModel):
    """Request model for creating a search app."""
    # Required name of the search app.
    name: str
    # Optional; defaults to an empty string.
    description: Optional[str] = ""
class UpdateSearchRequest(BaseModel):
    """Request model for updating a search app."""
    # Required id of the search app to update.
    search_id: str
    # Required name.
    name: str
    # Required free-form configuration dictionary (no per-key validation here).
    search_config: Dict[str, Any]
    # Required.
    tenant_id: str
    # Optional; defaults to None.
    description: Optional[str] = None
class DeleteSearchRequest(BaseModel):
    """Request model for deleting a search app."""
    # Required id of the search app to delete.
    search_id: str
class ListSearchAppsQuery(BaseModel):
    """Query parameters for listing search apps."""
    # Optional; defaults to an empty string.
    keywords: Optional[str] = ""
    # Optional page number; defaults to 0.
    page: Optional[int] = 0
    # Optional page size; defaults to 0.
    page_size: Optional[int] = 0
    # Optional ordering column; defaults to "create_time".
    orderby: Optional[str] = "create_time"
    # String flag rather than a bool; defaults to "true".
    desc: Optional[str] = "true"
class ListSearchAppsBody(BaseModel):
    """Request body for listing search apps."""
    # Optional list of owner ids; defaults to an empty list.
    owner_ids: Optional[List[str]] = []

View File

@@ -14,8 +14,18 @@
# limitations under the License. # limitations under the License.
# #
from flask import request from typing import Optional
from flask_login import current_user, login_required
from fastapi import APIRouter, Depends, Query
from api.apps.models.auth_dependencies import get_current_user
from api.apps.models.search_models import (
CreateSearchRequest,
UpdateSearchRequest,
DeleteSearchRequest,
ListSearchAppsQuery,
ListSearchAppsBody,
)
from api import settings from api import settings
from api.constants import DATASET_NAME_LIMIT from api.constants import DATASET_NAME_LIMIT
@@ -25,14 +35,23 @@ from api.db.services import duplicate_name
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
)
# 创建路由器
router = APIRouter()
@manager.route("/create", methods=["post"]) # noqa: F821 @router.post('/create')
@login_required async def create(
@validate_request("name") request: CreateSearchRequest,
def create(): current_user = Depends(get_current_user)
req = request.get_json() ):
"""创建搜索应用"""
req = request.model_dump(exclude_unset=True)
search_name = req["name"] search_name = req["name"]
description = req.get("description", "") description = req.get("description", "")
if not isinstance(search_name, str): if not isinstance(search_name, str):
@@ -62,12 +81,13 @@ def create():
return server_error_response(e) return server_error_response(e)
@manager.route("/update", methods=["post"]) # noqa: F821 @router.post('/update')
@login_required async def update(
@validate_request("search_id", "name", "search_config", "tenant_id") request: UpdateSearchRequest,
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") current_user = Depends(get_current_user)
def update(): ):
req = request.get_json() """更新搜索应用"""
req = request.model_dump(exclude_unset=True)
if not isinstance(req["name"], str): if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.") return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "": if req["name"].strip() == "":
@@ -84,6 +104,12 @@ def update():
if not SearchService.accessible4deletion(search_id, current_user.id): if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
# 验证不允许的参数
not_allowed = ["id", "created_by", "create_time", "update_time", "create_date", "update_date"]
for key in not_allowed:
if key in req:
del req[key]
try: try:
search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0] search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0]
if not search_app: if not search_app:
@@ -119,10 +145,12 @@ def update():
return server_error_response(e) return server_error_response(e)
@manager.route("/detail", methods=["GET"]) # noqa: F821 @router.get('/detail')
@login_required async def detail(
def detail(): search_id: str = Query(..., description="搜索应用ID"),
search_id = request.args["search_id"] current_user = Depends(get_current_user)
):
"""获取搜索应用详情"""
try: try:
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants: for tenant in tenants:
@@ -139,20 +167,23 @@ def detail():
return server_error_response(e) return server_error_response(e)
@manager.route("/list", methods=["POST"]) # noqa: F821 @router.post('/list')
@login_required async def list_search_app(
def list_search_app(): query: ListSearchAppsQuery = Depends(),
keywords = request.args.get("keywords", "") body: Optional[ListSearchAppsBody] = None,
page_number = int(request.args.get("page", 0)) current_user = Depends(get_current_user)
items_per_page = int(request.args.get("page_size", 0)) ):
orderby = request.args.get("orderby", "create_time") """列出搜索应用"""
if request.args.get("desc", "true").lower() == "false": if body is None:
desc = False body = ListSearchAppsBody()
else:
desc = True
req = request.get_json() keywords = query.keywords or ""
owner_ids = req.get("owner_ids", []) page_number = int(query.page or 0)
items_per_page = int(query.page_size or 0)
orderby = query.orderby or "create_time"
desc = query.desc.lower() == "true" if query.desc else True
owner_ids = body.owner_ids or [] if body else []
try: try:
if not owner_ids: if not owner_ids:
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) # tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
@@ -171,12 +202,13 @@ def list_search_app():
return server_error_response(e) return server_error_response(e)
@manager.route("/rm", methods=["post"]) # noqa: F821 @router.post('/rm')
@login_required async def rm(
@validate_request("search_id") request: DeleteSearchRequest,
def rm(): current_user = Depends(get_current_user)
req = request.get_json() ):
search_id = req["search_id"] """删除搜索应用"""
search_id = request.search_id
if not SearchService.accessible4deletion(search_id, current_user.id): if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

View File

@@ -15,12 +15,15 @@
# #
import json import json
import logging import logging
import os
import re import re
import secrets import secrets
import string
import time
from datetime import datetime from datetime import datetime
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi import APIRouter, Depends, HTTPException, Request, Response, Query, status
from api.apps.models.auth_dependencies import get_current_user from api.apps.models.auth_dependencies import get_current_user
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
@@ -60,6 +63,19 @@ from api.utils.api_utils import (
validate_request, validate_request,
) )
from api.utils.crypt import decrypt from api.utils.crypt import decrypt
from rag.utils.redis_conn import REDIS_CONN
from api.apps import smtp_mail_server
from api.utils.web_utils import (
send_email_html,
OTP_LENGTH,
OTP_TTL_SECONDS,
ATTEMPT_LIMIT,
ATTEMPT_LOCK_SECONDS,
RESEND_COOLDOWN_SECONDS,
otp_keys,
hash_code,
captcha_key,
)
# 创建路由器 # 创建路由器
router = APIRouter() router = APIRouter()
@@ -77,9 +93,8 @@ class RegisterRequest(BaseModel):
password: str password: str
class UserSettingRequest(BaseModel): class UserSettingRequest(BaseModel):
nickname: Optional[str] = None language: Optional[str] = None
password: Optional[str] = None
new_password: Optional[str] = None
class TenantInfoRequest(BaseModel): class TenantInfoRequest(BaseModel):
tenant_id: str tenant_id: str
@@ -88,6 +103,16 @@ class TenantInfoRequest(BaseModel):
img2txt_id: str img2txt_id: str
llm_id: str llm_id: str
class ForgetOtpRequest(BaseModel):
    """Request model for requesting a password-reset OTP."""
    # Required email address.
    email: str
    # Required image-captcha answer.
    captcha: str
class ForgetPasswordRequest(BaseModel):
    """Request model for resetting a password with an emailed OTP."""
    # Required email address.
    email: str
    # Required one-time code.
    otp: str
    # Required new password and its confirmation.
    new_password: str
    confirm_new_password: str
# 依赖项:获取当前用户 - 从 auth_dependencies 导入 # 依赖项:获取当前用户 - 从 auth_dependencies 导入
@router.post("/login") @router.post("/login")
@@ -481,3 +506,357 @@ async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e) detail=str(e)
) )
@router.get("/github_callback")
async def github_callback(code: Optional[str] = Query(None)):
    """
    **Deprecated**, Use `/oauth/callback/<channel>` instead.

    GitHub OAuth callback endpoint.

    Exchanges ``code`` for an access token, loads the GitHub profile, then
    registers a new user or logs an existing one in. Every outcome is reported
    by redirecting to ``/`` with either an ``auth`` or an ``error`` query
    parameter.

    Args:
        code: Authorization code taken from the ``code`` query parameter.

    Returns:
        A ``RedirectResponse``.
    """
    import requests
    from urllib.parse import quote

    # Bound the outbound call so an unresponsive provider cannot hang the request.
    http_timeout = 10  # seconds

    def _error_redirect(message):
        # Percent-encode provider/exception text so characters such as "&" or
        # "#" cannot truncate the message or add extra query parameters.
        return RedirectResponse(url=f"/?error={quote(str(message), safe='')}")

    if not code:
        return RedirectResponse(url="/?error=missing_code")

    # NOTE(review): `requests` is synchronous, so this call blocks the event
    # loop while it runs; consider moving it to a thread pool.
    res = requests.post(
        settings.GITHUB_OAUTH.get("url"),
        data={
            "client_id": settings.GITHUB_OAUTH.get("client_id"),
            "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
            "code": code,
        },
        headers={"Accept": "application/json"},
        timeout=http_timeout,
    )
    res = res.json()
    if "error" in res:
        return _error_redirect(res.get("error_description", res.get("error")))

    if "user:email" not in res.get("scope", "").split(","):
        return RedirectResponse(url="/?error=user:email not in scope")

    access_token = res["access_token"]
    user_info = user_info_from_github(access_token)
    email_address = user_info["email"]
    users = UserService.query(email=email_address)
    # Generated up front so a failed registration can be rolled back by id.
    user_id = get_uuid()

    if not users:
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                # The avatar is optional: log and continue without it.
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": access_token,
                    "email": email_address,
                    "avatar": avatar,
                    "nickname": user_info["login"],
                    "login_channel": "github",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )

            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")

            user = users[0]
            return RedirectResponse(url=f"/?auth={user.get_id()}")
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return _error_redirect(e)

    # User has already registered, try to log in
    user = users[0]
    if user and hasattr(user, 'is_active') and user.is_active == "0":
        return RedirectResponse(url="/?error=user_inactive")
    # Rotate the token only for active users (same order as the Feishu handler).
    user.access_token = get_uuid()
    user.save()
    return RedirectResponse(url=f"/?auth={user.get_id()}")
@router.get("/feishu_callback")
async def feishu_callback(code: Optional[str] = Query(None)):
    """
    Feishu OAuth callback endpoint.

    Obtains an app access token, exchanges ``code`` for a user access token,
    loads the Feishu profile, then registers a new user or logs an existing
    one in. Every outcome is reported by redirecting to ``/`` with either an
    ``auth`` or an ``error`` query parameter.

    Args:
        code: Authorization code taken from the ``code`` query parameter.

    Returns:
        A ``RedirectResponse``.
    """
    import requests
    from urllib.parse import quote

    # Bound the outbound calls so an unresponsive provider cannot hang the request.
    http_timeout = 10  # seconds

    def _error_redirect(message):
        # Percent-encode provider/exception text so characters such as "&" or
        # "#" cannot truncate the message or add extra query parameters.
        return RedirectResponse(url=f"/?error={quote(str(message), safe='')}")

    if not code:
        return RedirectResponse(url="/?error=missing_code")

    # NOTE(review): `requests` is synchronous, so these calls block the event
    # loop while they run; consider moving them to a thread pool.
    app_access_token_res = requests.post(
        settings.FEISHU_OAUTH.get("app_access_token_url"),
        data=json.dumps(
            {
                "app_id": settings.FEISHU_OAUTH.get("app_id"),
                "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
            }
        ),
        headers={"Content-Type": "application/json; charset=utf-8"},
        timeout=http_timeout,
    )
    app_access_token_res = app_access_token_res.json()
    if app_access_token_res.get("code") != 0:
        return _error_redirect(app_access_token_res)

    res = requests.post(
        settings.FEISHU_OAUTH.get("user_access_token_url"),
        data=json.dumps(
            {
                "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
                "code": code,
            }
        ),
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
        },
        timeout=http_timeout,
    )
    res = res.json()
    if res.get("code") != 0:
        return _error_redirect(res.get('message', 'unknown_error'))

    # Feishu scopes are whitespace-separated.
    if "contact:user.email:readonly" not in res.get("data", {}).get("scope", "").split():
        return RedirectResponse(url="/?error=contact:user.email:readonly not in scope")

    access_token = res["data"]["access_token"]
    user_info = user_info_from_feishu(access_token)
    email_address = user_info["email"]
    users = UserService.query(email=email_address)
    # Generated up front so a failed registration can be rolled back by id.
    user_id = get_uuid()

    if not users:
        try:
            try:
                avatar = download_img(user_info["avatar_url"])
            except Exception as e:
                # The avatar is optional: log and continue without it.
                logging.exception(e)
                avatar = ""
            users = user_register(
                user_id,
                {
                    "access_token": access_token,
                    "email": email_address,
                    "avatar": avatar,
                    "nickname": user_info["en_name"],
                    "login_channel": "feishu",
                    "last_login_time": get_format_time(),
                    "is_superuser": False,
                },
            )

            if not users:
                raise Exception(f"Fail to register {email_address}.")
            if len(users) > 1:
                raise Exception(f"Same email: {email_address} exists!")

            user = users[0]
            return RedirectResponse(url=f"/?auth={user.get_id()}")
        except Exception as e:
            rollback_user_registration(user_id)
            logging.exception(e)
            return _error_redirect(e)

    # User has already registered, try to log in
    user = users[0]
    if user and hasattr(user, 'is_active') and user.is_active == "0":
        return RedirectResponse(url="/?error=user_inactive")
    user.access_token = get_uuid()
    user.save()
    return RedirectResponse(url=f"/?auth={user.get_id()}")
def user_info_from_feishu(access_token):
    """Fetch the Feishu profile of the user owning ``access_token``.

    Args:
        access_token: Feishu user access token, sent as a Bearer token.

    Returns:
        The ``data`` object of the Feishu user-info response, with ``email``
        normalised to ``None`` when it is empty or absent.
    """
    import requests

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {access_token}",
    }
    # Bounded wait so an unresponsive provider cannot hang the caller.
    res = requests.get(
        "https://open.feishu.cn/open-apis/authen/v1/user_info",
        headers=headers,
        timeout=10,
    )
    user_info = res.json()["data"]
    # `.get` keeps a response without an "email" key from raising KeyError;
    # an empty string is still mapped to None as before.
    user_info["email"] = user_info.get("email") or None
    return user_info
def user_info_from_github(access_token):
    """Fetch the GitHub profile and primary email for ``access_token``.

    Args:
        access_token: GitHub OAuth access token.

    Returns:
        The ``/user`` response dictionary with ``email`` replaced by the
        address flagged as primary in the ``/user/emails`` response.

    Raises:
        ValueError: If GitHub returns no primary email (for example when the
            emails endpoint answers with an error object instead of a list).
    """
    import requests

    # The token travels only in the Authorization header. It is no longer also
    # appended as an `access_token` query parameter, which put the secret into
    # URLs/logs and is a mechanism GitHub has deprecated.
    headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
    timeout = 10  # seconds; bounded wait so a stalled provider cannot hang the caller

    user_info = requests.get("https://api.github.com/user", headers=headers, timeout=timeout).json()
    email_info = requests.get("https://api.github.com/user/emails", headers=headers, timeout=timeout).json()

    primary = None
    if isinstance(email_info, list):
        primary = next(
            (entry for entry in email_info if isinstance(entry, dict) and entry.get("primary")),
            None,
        )
    if primary is None:
        # Previously this surfaced as an opaque TypeError from subscripting None.
        raise ValueError("GitHub did not return a primary email address for this account.")

    user_info["email"] = primary["email"]
    return user_info
@router.get("/forget/captcha")
async def forget_get_captcha(email: str = Query(...)):
    """
    GET /forget/captcha?email=<email>
    - Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = 60 seconds.
    - Returns the captcha as a JPEG image.
    """
    if not email:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    # Generate captcha text
    allowed = string.ascii_uppercase + string.digits
    captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
    REDIS_CONN.set(captcha_key(email), captcha_text, 60)  # Valid for 60 seconds

    from captcha.image import ImageCaptcha

    image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
    # `generate` encodes as PNG unless told otherwise; request JPEG explicitly
    # so the payload matches the declared media type and the documented contract.
    img_bytes = image.generate(captcha_text, format="jpeg").read()
    return Response(content=img_bytes, media_type="image/jpeg")
@router.post("/forget/otp")
async def forget_send_otp(request: ForgetOtpRequest):
    """
    POST /forget/otp
    - Verify the image captcha stored at captcha:{email} (case-insensitive).
    - On success, generate an email OTP (A-Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
    """
    email = request.email or ""
    captcha = (request.captcha or "").strip()

    if not email or not captcha:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    stored_captcha = REDIS_CONN.get(captcha_key(email))
    if not stored_captcha:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
    if (stored_captcha or "").strip().lower() != captcha.lower():
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")

    # Delete captcha to prevent reuse
    REDIS_CONN.delete(captcha_key(email))

    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    now = int(time.time())
    last_ts = REDIS_CONN.get(k_last)
    if last_ts:
        try:
            elapsed = now - int(last_ts)
        except Exception:
            # An unparsable timestamp is treated as "cooldown already over".
            elapsed = RESEND_COOLDOWN_SECONDS
        remaining = RESEND_COOLDOWN_SECONDS - elapsed
        if remaining > 0:
            return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds")

    # Generate OTP (uppercase letters only) and store hashed
    otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
    salt = os.urandom(16)
    code_hash = hash_code(otp, salt)
    REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS)
    REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS)
    REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS)
    REDIS_CONN.delete(k_lock)

    ttl_min = OTP_TTL_SECONDS // 60

    if not smtp_mail_server:
        logging.warning("SMTP mail server not initialized; skip sending email.")
    else:
        try:
            send_email_html(
                subject="Your Password Reset Code",
                to_email=email,
                template_key="reset_code",
                code=otp,
                ttl_min=ttl_min,
            )
        except Exception as e:
            # Record the cause; the client only receives a generic message.
            logging.exception(e)
            return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email")

    return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent")
@router.post("/forget")
async def forget(request: ForgetPasswordRequest):
    """
    POST: Verify email + OTP and reset password, then log the user in.
    Request JSON: { email, otp, new_password, confirm_new_password }
    """
    email = request.email or ""
    otp = (request.otp or "").strip()
    new_pwd = request.new_password
    new_pwd2 = request.confirm_new_password

    if not all([email, otp, new_pwd, new_pwd2]):
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")

    # For reset, passwords are provided as-is (no decrypt needed)
    if new_pwd != new_pwd2:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    user = users[0]
    # Verify OTP from Redis
    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    if REDIS_CONN.get(k_lock):
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later")

    stored = REDIS_CONN.get(k_code)
    if not stored:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp")

    try:
        # Stored as "<hash>:<salt hex>".
        stored_hash, salt_hex = str(stored).split(":", 1)
        salt = bytes.fromhex(salt_hex)
    except Exception:
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted")

    # Case-insensitive verification: OTP generated uppercase
    calc = hash_code(otp.upper(), salt)
    if calc != stored_hash:
        # bump attempts
        try:
            attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
        except Exception:
            attempts = 1
        REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
        if attempts >= ATTEMPT_LIMIT:
            REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp")

    # Success: consume OTP and reset password
    REDIS_CONN.delete(k_code)
    REDIS_CONN.delete(k_attempts)
    REDIS_CONN.delete(k_last)
    REDIS_CONN.delete(k_lock)

    try:
        UserService.update_user_password(user.id, new_pwd)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password")

    # Auto login (reuse login flow)
    user.access_token = get_uuid()
    # Assign plain scalars: the earlier trailing commas stored one-element
    # tuples in these fields instead of the timestamp / formatted date.
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Password reset successful. Logged in."
    return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)

View File

@@ -8,6 +8,10 @@ minio:
user: 'rag_flow' user: 'rag_flow'
password: 'infini_rag_flow' password: 'infini_rag_flow'
host: 'localhost:9000' host: 'localhost:9000'
es:
hosts: 'http://localhost:1200'
username: 'elastic'
password: 'infini_rag_flow'
os: os:
hosts: 'http://localhost:1201' hosts: 'http://localhost:1201'
username: 'admin' username: 'admin'

View File

@@ -3,7 +3,7 @@
# - `elasticsearch` (default) # - `elasticsearch` (default)
# - `infinity` (https://github.com/infiniflow/infinity) # - `infinity` (https://github.com/infiniflow/infinity)
# - `opensearch` (https://github.com/opensearch-project/OpenSearch) # - `opensearch` (https://github.com/opensearch-project/OpenSearch)
DOC_ENGINE=opensearch DOC_ENGINE=elasticsearch
# ------------------------------ # ------------------------------
# docker env var for specifying vector db type at startup # docker env var for specifying vector db type at startup
@@ -98,7 +98,7 @@ ADMIN_SVR_HTTP_PORT=9381
# The RAGFlow Docker image to download. # The RAGFlow Docker image to download.
# Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models. # Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models.
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi-web
# #
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead: # To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1 # RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1

View File

@@ -1,36 +1,34 @@
services: services:
opensearch01: es01:
container_name: ragflow-opensearch-01 container_name: ragflow-es-01
profiles: profiles:
- opensearch - elasticsearch
image: hub.icert.top/opensearchproject/opensearch:2.19.1 image: elasticsearch:${STACK_VERSION}
volumes: volumes:
- osdata01:/usr/share/opensearch/data - esdata01:/usr/share/elasticsearch/data
ports: ports:
- ${OS_PORT}:9201 - ${ES_PORT}:9200
env_file: .env env_file: .env
environment: environment:
- node.name=opensearch01 - node.name=es01
- OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD} - ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_PASSWORD}
- bootstrap.memory_lock=false - bootstrap.memory_lock=false
- discovery.type=single-node - discovery.type=single-node
- plugins.security.disabled=false - xpack.security.enabled=true
- plugins.security.ssl.http.enabled=false - xpack.security.http.ssl.enabled=false
- plugins.security.ssl.transport.enabled=true - xpack.security.transport.ssl.enabled=false
- cluster.routing.allocation.disk.watermark.low=5gb - cluster.routing.allocation.disk.watermark.low=5gb
- cluster.routing.allocation.disk.watermark.high=3gb - cluster.routing.allocation.disk.watermark.high=3gb
- cluster.routing.allocation.disk.watermark.flood_stage=2gb - cluster.routing.allocation.disk.watermark.flood_stage=2gb
- TZ=${TIMEZONE} - TZ=${TIMEZONE}
- http.port=9201
mem_limit: ${MEM_LIMIT} mem_limit: ${MEM_LIMIT}
ulimits: ulimits:
memlock: memlock:
soft: -1 soft: -1
hard: -1 hard: -1
healthcheck: healthcheck:
test: ["CMD-SHELL", "curl http://localhost:9201"] test: ["CMD-SHELL", "curl http://localhost:9200"]
interval: 10s interval: 10s
timeout: 10s timeout: 10s
retries: 120 retries: 120

View File

@@ -64,7 +64,8 @@ class Dealer:
if key in req and req[key] is not None: if key in req and req[key] is not None:
condition[field] = req[key] condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns. # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]: for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
"removed_kwd"]:
if key in req and req[key] is not None: if key in req and req[key] is not None:
condition[key] = req[key] condition[key] = req[key]
return condition return condition
@@ -135,7 +136,8 @@ class Dealer:
matchText, _ = self.qryr.question(qst, min_match=0.1) matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17 matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.getTotal(res) total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total)) logging.debug("Dealer.search 2 TOTAL: {}".format(total))
@@ -212,8 +214,9 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_) ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)): for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]): if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0]*len(ans_v[0]) chunk_v[i] = [0.0] * len(ans_v[0])
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i]))) logging.warning(
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0])) len(ans_v[0]), len(chunk_v[0]))
@@ -267,7 +270,7 @@ class Dealer:
if not query_rfea: if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD])) q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids: for i in search_res.ids:
nor, denor = 0, 0 nor, denor = 0, 0
if not search_res.field[i].get(TAG_FLD): if not search_res.field[i].get(TAG_FLD):
@@ -280,8 +283,8 @@ class Dealer:
if denor == 0: if denor == 0:
rank_fea.append(0) rank_fea.append(0)
else: else:
rank_fea.append(nor/np.sqrt(denor)/q_denor) rank_fea.append(nor / np.sqrt(denor) / q_denor)
return np.array(rank_fea)*10. + pageranks return np.array(rank_fea) * 10. + pageranks
def rerank(self, sres, query, tkweight=0.3, def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks", vtweight=0.7, cfield="content_ltks",
@@ -343,7 +346,7 @@ class Dealer:
## For rank feature(tag_fea) scores. ## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres) rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd, return self.qryr.hybrid_similarity(ans_embd,
@@ -360,13 +363,13 @@ class Dealer:
return ranks return ranks
# Ensure RERANK_LIMIT is multiple of page_size # Ensure RERANK_LIMIT is multiple of page_size
RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1 RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT, req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size * page / RERANK_LIMIT),
"size": RERANK_LIMIT,
"question": question, "vector": True, "topk": top, "question": question, "vector": True, "topk": top,
"similarity": similarity_threshold, "similarity": similarity_threshold,
"available_int": 1} "available_int": 1}
if isinstance(tenant_ids, str): if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",") tenant_ids = tenant_ids.split(",")
@@ -392,15 +395,15 @@ class Dealer:
tsim = sim tsim = sim
vsim = sim vsim = sim
# Already paginated in search function # Already paginated in search function
begin = ((page % (RERANK_LIMIT//page_size)) - 1) * page_size begin = ((page % (RERANK_LIMIT // page_size)) - 1) * page_size
sim = sim[begin : begin + page_size] sim = sim[begin: begin + page_size]
sim_np = np.array(sim) sim_np = np.array(sim)
idx = np.argsort(sim_np * -1) idx = np.argsort(sim_np * -1)
dim = len(sres.query_vector) dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec" vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim zero_vector = [0.0] * dim
filtered_count = (sim_np >= similarity_threshold).sum() filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx: for i in idx:
if sim[i] < similarity_threshold: if sim[i] < similarity_threshold:
break break
@@ -447,8 +450,8 @@ class Dealer:
ranks["doc_aggs"] = [{"doc_name": k, ranks["doc_aggs"] = [{"doc_name": k,
"doc_id": v["doc_id"], "doc_id": v["doc_id"],
"count": v["count"]} for k, "count": v["count"]} for k,
v in sorted(ranks["doc_aggs"].items(), v in sorted(ranks["doc_aggs"].items(),
key=lambda x: x[1]["count"] * -1)] key=lambda x: x[1]["count"] * -1)]
ranks["chunks"] = ranks["chunks"][:page_size] ranks["chunks"] = ranks["chunks"][:page_size]
return ranks return ranks
@@ -505,13 +508,14 @@ class Dealer:
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000): def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id) idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn) match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"]) res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd") aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs: if not aggs:
return False return False
cnt = np.sum([c for _, c in aggs]) cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags] key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0} doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
return True return True
@@ -527,11 +531,11 @@ class Dealer:
if not aggs: if not aggs:
return {} return {}
cnt = np.sum([c for _, c in aggs]) cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs], tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags] key=lambda x: x[1] * -1)[:topn_tags]
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea} return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6): def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
if not chunks: if not chunks:
return [] return []
idx_nms = [index_name(tid) for tid in tenant_ids] idx_nms = [index_name(tid) for tid in tenant_ids]
@@ -541,9 +545,10 @@ class Dealer:
ranks[ck["doc_id"]] = 0 ranks[ck["doc_id"]] = 0
ranks[ck["doc_id"]] += ck["similarity"] ranks[ck["doc_id"]] += ck["similarity"]
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"] doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0] doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
kb_ids = [doc_id2kb_id[doc_id]] kb_ids = [doc_id2kb_id[doc_id]]
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms, es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
OrderByExpr(), 0, 128, idx_nms,
kb_ids) kb_ids)
toc = [] toc = []
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"]) dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
@@ -555,7 +560,7 @@ class Dealer:
if not toc: if not toc:
return chunks return chunks
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2) ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2)
if not ids: if not ids:
return chunks return chunks
@@ -589,4 +594,4 @@ class Dealer:
break break
chunks.append(d) chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn] return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]

35
rag/res/deepdoc/.gitattributes vendored Normal file
View File

@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

Binary file not shown.

Binary file not shown.

BIN
rag/res/deepdoc/layout.onnx Normal file

Binary file not shown.

Binary file not shown.

BIN
rag/res/deepdoc/tsr.onnx Normal file

Binary file not shown.