修复 画布,mcp服务,搜索,文档的接口
This commit is contained in:
@@ -165,17 +165,23 @@ def setup_routes(app: FastAPI):
|
||||
from api.apps.tenant_app import router as tenant_router
|
||||
from api.apps.dialog_app import router as dialog_router
|
||||
from api.apps.system_app import router as system_router
|
||||
from api.apps.search_app import router as search_router
|
||||
from api.apps.conversation_app import router as conversation_router
|
||||
from api.apps.file_app import router as file_router
|
||||
|
||||
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
|
||||
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"])
|
||||
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
|
||||
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
|
||||
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
|
||||
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
|
||||
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp_server", tags=["MCP"])
|
||||
app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"])
|
||||
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
|
||||
app.include_router(dialog_router, prefix=f"/{API_VERSION}/dialog", tags=["Dialog"])
|
||||
app.include_router(system_router, prefix=f"/{API_VERSION}/system", tags=["System"])
|
||||
app.include_router(search_router, prefix=f"/{API_VERSION}/search", tags=["Search"])
|
||||
app.include_router(conversation_router, prefix=f"/{API_VERSION}/conversation", tags=["Conversation"])
|
||||
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -17,8 +17,10 @@ import json
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import APIToken
|
||||
@@ -28,15 +30,35 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
|
||||
from api.utils import get_uuid
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.conversation_models import (
|
||||
SetConversationRequest,
|
||||
DeleteConversationsRequest,
|
||||
CompletionRequest,
|
||||
TTSRequest,
|
||||
DeleteMessageRequest,
|
||||
ThumbupRequest,
|
||||
AskRequest,
|
||||
MindmapRequest,
|
||||
RelatedQuestionsRequest,
|
||||
)
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_conversation():
|
||||
req = request.json
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/set')
|
||||
async def set_conversation(
|
||||
request: SetConversationRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""设置对话"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
conv_id = req.get("conversation_id")
|
||||
is_new = req.get("is_new")
|
||||
name = req.get("name", "New conversation")
|
||||
@@ -45,9 +67,9 @@ def set_conversation():
|
||||
if len(name) > 255:
|
||||
name = name[0:255]
|
||||
|
||||
del req["is_new"]
|
||||
if not is_new:
|
||||
del req["conversation_id"]
|
||||
if not conv_id:
|
||||
return get_data_error_result(message="conversation_id is required when is_new is False!")
|
||||
try:
|
||||
if not ConversationService.update_by_id(conv_id, req):
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@@ -64,7 +86,7 @@ def set_conversation():
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": conv_id,
|
||||
"id": conv_id or get_uuid(),
|
||||
"dialog_id": req["dialog_id"],
|
||||
"name": name,
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
|
||||
@@ -77,12 +99,14 @@ def set_conversation():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
conv_id = request.args["conversation_id"]
|
||||
@router.get('/get')
|
||||
async def get(
|
||||
conversation_id: str = Query(..., description="对话ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取对话"""
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
e, conv = ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
@@ -107,15 +131,27 @@ def get():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
|
||||
def getsse(dialog_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@router.get('/getsse/{dialog_id}')
|
||||
async def getsse(
|
||||
dialog_id: str,
|
||||
authorization: Optional[str] = Header(None, alias="Authorization")
|
||||
):
|
||||
"""通过 SSE 获取对话(使用 API token 认证)"""
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header is required"
|
||||
)
|
||||
|
||||
token_parts = authorization.split()
|
||||
if len(token_parts) != 2:
|
||||
return get_data_error_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
token = token_parts[1]
|
||||
|
||||
objs = APIToken.query(beta=token)
|
||||
if not objs:
|
||||
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
try:
|
||||
e, conv = DialogService.get_by_id(dialog_id)
|
||||
if not e:
|
||||
@@ -128,12 +164,14 @@ def getsse(dialog_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm():
|
||||
conv_ids = request.json["conversation_ids"]
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: DeleteConversationsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除对话"""
|
||||
try:
|
||||
for cid in conv_ids:
|
||||
for cid in request.conversation_ids:
|
||||
exist, conv = ConversationService.get_by_id(cid)
|
||||
if not exist:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@@ -149,10 +187,12 @@ def rm():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_conversation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
@router.get('/list')
|
||||
async def list_conversation(
|
||||
dialog_id: str = Query(..., description="对话ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出对话"""
|
||||
try:
|
||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
@@ -164,11 +204,13 @@ def list_conversation():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
@router.post('/completion')
|
||||
async def completion(
|
||||
request: CompletionRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""完成请求(聊天完成)"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
@@ -176,6 +218,10 @@ def completion():
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
|
||||
if not msg:
|
||||
return get_data_error_result(message="No valid messages found!")
|
||||
|
||||
message_id = msg[-1].get("id")
|
||||
chat_model_id = req.get("llm_id", "")
|
||||
req.pop("llm_id", None)
|
||||
@@ -217,6 +263,7 @@ def completion():
|
||||
dia.llm_setting = chat_model_config
|
||||
|
||||
is_embedded = bool(chat_model_id)
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
@@ -230,14 +277,18 @@ def completion():
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
stream_enabled = request.stream if request.stream is not None else True
|
||||
if stream_enabled:
|
||||
return StreamingResponse(
|
||||
stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Type": "text/event-stream; charset=utf-8"
|
||||
}
|
||||
)
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
@@ -250,11 +301,13 @@ def completion():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def tts():
|
||||
req = request.json
|
||||
text = req["text"]
|
||||
@router.post('/tts')
|
||||
async def tts(
|
||||
request: TTSRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""文本转语音"""
|
||||
text = request.text
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
@@ -274,28 +327,32 @@ def tts():
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
|
||||
return resp
|
||||
return StreamingResponse(
|
||||
stream_audio(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def delete_msg():
|
||||
req = request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
@router.post('/delete_msg')
|
||||
async def delete_msg(
|
||||
request: DeleteMessageRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除消息"""
|
||||
e, conv = ConversationService.get_by_id(request.conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
for i, msg in enumerate(conv["message"]):
|
||||
if req["message_id"] != msg.get("id", ""):
|
||||
if request.message_id != msg.get("id", ""):
|
||||
continue
|
||||
assert conv["message"][i + 1]["id"] == req["message_id"]
|
||||
assert conv["message"][i + 1]["id"] == request.message_id
|
||||
conv["message"].pop(i)
|
||||
conv["message"].pop(i)
|
||||
conv["reference"].pop(max(0, i // 2 - 1))
|
||||
@@ -305,19 +362,21 @@ def delete_msg():
|
||||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def thumbup():
|
||||
req = request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
@router.post('/thumbup')
|
||||
async def thumbup(
|
||||
request: ThumbupRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""点赞/点踩"""
|
||||
e, conv = ConversationService.get_by_id(request.conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
up_down = req.get("thumbup")
|
||||
feedback = req.get("feedback", "")
|
||||
|
||||
up_down = request.thumbup
|
||||
feedback = request.feedback or ""
|
||||
conv = conv.to_dict()
|
||||
for i, msg in enumerate(conv["message"]):
|
||||
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
|
||||
if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
|
||||
if up_down:
|
||||
msg["thumbup"] = True
|
||||
if "feedback" in msg:
|
||||
@@ -332,14 +391,15 @@ def thumbup():
|
||||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about():
|
||||
req = request.json
|
||||
@router.post('/ask')
|
||||
async def ask_about(
|
||||
request: AskRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""提问"""
|
||||
uid = current_user.id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_id = request.search_id or ""
|
||||
search_app = None
|
||||
search_config = {}
|
||||
if search_id:
|
||||
@@ -348,53 +408,58 @@ def ask_about():
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
def stream():
|
||||
nonlocal req, uid
|
||||
nonlocal request, uid
|
||||
try:
|
||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
return StreamingResponse(
|
||||
stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Type": "text/event-stream; charset=utf-8"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
req = request.json
|
||||
search_id = req.get("search_id", "")
|
||||
@router.post('/mindmap')
|
||||
async def mindmap(
|
||||
request: MindmapRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""思维导图"""
|
||||
search_id = request.search_id or ""
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||
kb_ids = search_config.get("kb_ids", [])
|
||||
kb_ids.extend(req["kb_ids"])
|
||||
kb_ids.extend(request.kb_ids)
|
||||
kb_ids = list(set(kb_ids))
|
||||
|
||||
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||
mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
||||
|
||||
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
def related_questions():
|
||||
req = request.json
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@router.post('/related_questions')
|
||||
async def related_questions(
|
||||
request: RelatedQuestionsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""相关问题"""
|
||||
search_id = request.search_id or ""
|
||||
search_config = {}
|
||||
if search_id:
|
||||
if search_app := SearchService.get_detail(search_id):
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
question = req["question"]
|
||||
question = request.question
|
||||
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
|
||||
|
||||
@@ -72,17 +72,16 @@ router = APIRouter()
|
||||
@router.post("/upload")
|
||||
async def upload(
|
||||
kb_id: str = Form(...),
|
||||
files: List[UploadFile] = File(...),
|
||||
file: UploadFile = File(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""上传文档"""
|
||||
if not files:
|
||||
if not file:
|
||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
for file_obj in files:
|
||||
if not file_obj.filename or file_obj.filename == "":
|
||||
if not file.filename or file.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
if len(file.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
@@ -91,7 +90,7 @@ async def upload(
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, uploaded_files = FileService.upload_document(kb, files, current_user.id)
|
||||
err, uploaded_files = FileService.upload_document(kb, [file], current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@@ -17,15 +17,23 @@ import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from typing import Optional, List
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile, File, Form
|
||||
from fastapi.responses import Response
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.file_models import (
|
||||
CreateFileRequest,
|
||||
DeleteFilesRequest,
|
||||
RenameFileRequest,
|
||||
MoveFilesRequest,
|
||||
)
|
||||
|
||||
from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result
|
||||
from api.utils import get_uuid
|
||||
from api.db import FileType, FileSource
|
||||
from api.db.services import duplicate_name
|
||||
@@ -36,35 +44,41 @@ from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
# @validate_request("parent_id")
|
||||
def upload():
|
||||
pf_id = request.form.get("parent_id")
|
||||
|
||||
@router.post('/upload')
|
||||
async def upload(
|
||||
files: List[UploadFile] = File(...),
|
||||
parent_id: Optional[str] = Form(None),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""上传文件"""
|
||||
pf_id = parent_id
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
if not files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
file_objs = request.files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
for file_obj in files:
|
||||
if not file_obj.filename or file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_res = []
|
||||
try:
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_data_error_result( message="Can't find this folder!")
|
||||
for file_obj in file_objs:
|
||||
return get_data_error_result(message="Can't find this folder!")
|
||||
for file_obj in files:
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
|
||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||
return get_data_error_result(message="Exceed the maximum file number of a free user!")
|
||||
|
||||
# split file name path
|
||||
if not file_obj.filename:
|
||||
@@ -97,7 +111,7 @@ def upload():
|
||||
location = file_obj_names[file_len - 1]
|
||||
while STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||
location += "_"
|
||||
blob = file_obj.read()
|
||||
blob = await file_obj.read()
|
||||
filename = duplicate_name(
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
@@ -120,13 +134,16 @@ def upload():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
@router.post('/create')
|
||||
async def create(
|
||||
request: CreateFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""创建文件/文件夹"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
pf_id = req.get("parent_id")
|
||||
input_file_type = req.get("type")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@@ -160,17 +177,22 @@ def create():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_files():
|
||||
pf_id = request.args.get("parent_id")
|
||||
@router.get('/list')
|
||||
async def list_files(
|
||||
parent_id: Optional[str] = Query(None, description="父文件夹ID"),
|
||||
keywords: Optional[str] = Query("", description="搜索关键词"),
|
||||
page: Optional[int] = Query(1, description="页码"),
|
||||
page_size: Optional[int] = Query(15, description="每页数量"),
|
||||
orderby: Optional[str] = Query("create_time", description="排序字段"),
|
||||
desc: Optional[bool] = Query(True, description="是否降序"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出文件"""
|
||||
pf_id = parent_id
|
||||
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(page) if page else 1
|
||||
items_per_page = int(page_size) if page_size else 15
|
||||
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 15))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@@ -192,9 +214,11 @@ def list_files():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/root_folder', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get_root_folder():
|
||||
@router.get('/root_folder')
|
||||
async def get_root_folder(
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取根文件夹"""
|
||||
try:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
return get_json_result(data={"root_folder": root_folder})
|
||||
@@ -202,10 +226,12 @@ def get_root_folder():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/parent_folder', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get_parent_folder():
|
||||
file_id = request.args.get("file_id")
|
||||
@router.get('/parent_folder')
|
||||
async def get_parent_folder(
|
||||
file_id: str = Query(..., description="文件ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取父文件夹"""
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@@ -217,10 +243,12 @@ def get_parent_folder():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get_all_parent_folders():
|
||||
file_id = request.args.get("file_id")
|
||||
@router.get('/all_parent_folder')
|
||||
async def get_all_parent_folders(
|
||||
file_id: str = Query(..., description="文件ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取所有父文件夹"""
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@@ -235,12 +263,13 @@ def get_all_parent_folders():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
file_ids = req["file_ids"]
|
||||
@router.post("/rm")
|
||||
async def rm(
|
||||
request: DeleteFilesRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除文件"""
|
||||
file_ids = request.file_ids
|
||||
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
@@ -296,11 +325,13 @@ def rm():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rename', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
@router.post('/rename')
|
||||
async def rename(
|
||||
request: RenameFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""重命名文件"""
|
||||
req = request.model_dump()
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
@@ -314,8 +345,8 @@ def rename():
|
||||
data=False,
|
||||
message="The extension of file can't be changed",
|
||||
code=settings.RetCode.ARGUMENT_ERROR)
|
||||
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if file.name == req["name"]:
|
||||
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if existing_file.name == req["name"]:
|
||||
return get_data_error_result(
|
||||
message="Duplicated file name in the same folder.")
|
||||
|
||||
@@ -336,9 +367,12 @@ def rename():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(file_id):
|
||||
@router.get('/get/{file_id}')
|
||||
async def get(
|
||||
file_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取文件内容"""
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@@ -351,25 +385,28 @@ def get(file_id):
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = STORAGE_IMPL.get(b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
|
||||
content_type = "application/octet-stream"
|
||||
if ext:
|
||||
if file.type == FileType.VISUAL.value:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
response.headers.set("Content-Type", content_type)
|
||||
return response
|
||||
|
||||
return Response(content=blob, media_type=content_type)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/mv", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("src_file_ids", "dest_file_id")
|
||||
def move():
|
||||
req = request.json
|
||||
@router.post("/mv")
|
||||
async def move(
|
||||
request: MoveFilesRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""移动文件"""
|
||||
req = request.model_dump()
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
dest_parent_id = req["dest_file_id"]
|
||||
|
||||
@@ -169,6 +169,9 @@ async def update(
|
||||
):
|
||||
"""更新知识库"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
|
||||
# 验证 name 字段(如果提供)
|
||||
if "name" in req:
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@@ -202,7 +205,8 @@ async def update(
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
|
||||
if req["name"].lower() != kb.name.lower() \
|
||||
# 检查名称重复(仅在提供新名称时)
|
||||
if "name" in req and req["name"].lower() != kb.name.lower() \
|
||||
and len(
|
||||
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
|
||||
return get_data_error_result(
|
||||
|
||||
84
api/apps/models/conversation_models.py
Normal file
84
api/apps/models/conversation_models.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SetConversationRequest(BaseModel):
|
||||
"""设置对话请求"""
|
||||
conversation_id: Optional[str] = None
|
||||
is_new: bool
|
||||
name: Optional[str] = Field(default="New conversation", max_length=255)
|
||||
dialog_id: str
|
||||
|
||||
|
||||
class DeleteConversationsRequest(BaseModel):
|
||||
"""删除对话请求"""
|
||||
conversation_ids: List[str]
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
"""完成请求(聊天完成)"""
|
||||
conversation_id: str
|
||||
messages: List[Dict[str, Any]]
|
||||
llm_id: Optional[str] = None
|
||||
stream: Optional[bool] = True
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
"""文本转语音请求"""
|
||||
text: str
|
||||
|
||||
|
||||
class DeleteMessageRequest(BaseModel):
|
||||
"""删除消息请求"""
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
class ThumbupRequest(BaseModel):
|
||||
"""点赞/点踩请求"""
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
thumbup: Optional[bool] = None
|
||||
feedback: Optional[str] = ""
|
||||
|
||||
|
||||
class AskRequest(BaseModel):
|
||||
"""提问请求"""
|
||||
question: str
|
||||
kb_ids: List[str]
|
||||
search_id: Optional[str] = ""
|
||||
|
||||
|
||||
class MindmapRequest(BaseModel):
|
||||
"""思维导图请求"""
|
||||
question: str
|
||||
kb_ids: List[str]
|
||||
search_id: Optional[str] = ""
|
||||
|
||||
|
||||
class RelatedQuestionsRequest(BaseModel):
|
||||
"""相关问题请求"""
|
||||
question: str
|
||||
search_id: Optional[str] = ""
|
||||
|
||||
43
api/apps/models/file_models.py
Normal file
43
api/apps/models/file_models.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateFileRequest(BaseModel):
|
||||
"""创建文件/文件夹请求"""
|
||||
name: str
|
||||
parent_id: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
|
||||
|
||||
class DeleteFilesRequest(BaseModel):
|
||||
"""删除文件请求"""
|
||||
file_ids: List[str]
|
||||
|
||||
|
||||
class RenameFileRequest(BaseModel):
|
||||
"""重命名文件请求"""
|
||||
file_id: str
|
||||
name: str
|
||||
|
||||
|
||||
class MoveFilesRequest(BaseModel):
|
||||
"""移动文件请求"""
|
||||
src_file_ids: List[str]
|
||||
dest_file_id: str
|
||||
|
||||
@@ -60,9 +60,16 @@ class CreateKnowledgeBaseRequest(BaseModel):
|
||||
class UpdateKnowledgeBaseRequest(BaseModel):
|
||||
"""更新知识库请求"""
|
||||
kb_id: str
|
||||
name: str
|
||||
description: str
|
||||
parser_id: str
|
||||
name: Optional[str] = None
|
||||
avatar: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
permission: Optional[str] = None
|
||||
doc_num: Optional[int] = None
|
||||
token_num: Optional[int] = None
|
||||
chunk_num: Optional[int] = None
|
||||
parser_id: Optional[str] = None
|
||||
embd_id: Optional[str] = None
|
||||
pagerank: Optional[int] = None
|
||||
# 其他可选字段,但排除 id, tenant_id, created_by, create_time, update_time, create_date, update_date
|
||||
|
||||
|
||||
53
api/apps/models/search_models.py
Normal file
53
api/apps/models/search_models.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateSearchRequest(BaseModel):
|
||||
"""创建搜索应用请求"""
|
||||
name: str
|
||||
description: Optional[str] = ""
|
||||
|
||||
|
||||
class UpdateSearchRequest(BaseModel):
|
||||
"""更新搜索应用请求"""
|
||||
search_id: str
|
||||
name: str
|
||||
search_config: Dict[str, Any]
|
||||
tenant_id: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class DeleteSearchRequest(BaseModel):
|
||||
"""删除搜索应用请求"""
|
||||
search_id: str
|
||||
|
||||
|
||||
class ListSearchAppsQuery(BaseModel):
|
||||
"""列出搜索应用查询参数"""
|
||||
keywords: Optional[str] = ""
|
||||
page: Optional[int] = 0
|
||||
page_size: Optional[int] = 0
|
||||
orderby: Optional[str] = "create_time"
|
||||
desc: Optional[str] = "true"
|
||||
|
||||
|
||||
class ListSearchAppsBody(BaseModel):
|
||||
"""列出搜索应用请求体"""
|
||||
owner_ids: Optional[List[str]] = []
|
||||
|
||||
@@ -14,8 +14,18 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.search_models import (
|
||||
CreateSearchRequest,
|
||||
UpdateSearchRequest,
|
||||
DeleteSearchRequest,
|
||||
ListSearchAppsQuery,
|
||||
ListSearchAppsBody,
|
||||
)
|
||||
|
||||
from api import settings
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
@@ -25,14 +35,23 @@ from api.db.services import duplicate_name
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
)
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@manager.route("/create", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.get_json()
|
||||
@router.post('/create')
|
||||
async def create(
|
||||
request: CreateSearchRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""创建搜索应用"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
search_name = req["name"]
|
||||
description = req.get("description", "")
|
||||
if not isinstance(search_name, str):
|
||||
@@ -62,12 +81,13 @@ def create():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/update", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("search_id", "name", "search_config", "tenant_id")
|
||||
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.get_json()
|
||||
@router.post('/update')
|
||||
async def update(
|
||||
request: UpdateSearchRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""更新搜索应用"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Search name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@@ -84,6 +104,12 @@ def update():
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
# 验证不允许的参数
|
||||
not_allowed = ["id", "created_by", "create_time", "update_time", "create_date", "update_date"]
|
||||
for key in not_allowed:
|
||||
if key in req:
|
||||
del req[key]
|
||||
|
||||
try:
|
||||
search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0]
|
||||
if not search_app:
|
||||
@@ -119,10 +145,12 @@ def update():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/detail", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def detail():
|
||||
search_id = request.args["search_id"]
|
||||
@router.get('/detail')
|
||||
async def detail(
|
||||
search_id: str = Query(..., description="搜索应用ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取搜索应用详情"""
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
@@ -139,20 +167,23 @@ def detail():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_search_app():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
@router.post('/list')
|
||||
async def list_search_app(
|
||||
query: ListSearchAppsQuery = Depends(),
|
||||
body: Optional[ListSearchAppsBody] = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出搜索应用"""
|
||||
if body is None:
|
||||
body = ListSearchAppsBody()
|
||||
|
||||
req = request.get_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
keywords = query.keywords or ""
|
||||
page_number = int(query.page or 0)
|
||||
items_per_page = int(query.page_size or 0)
|
||||
orderby = query.orderby or "create_time"
|
||||
desc = query.desc.lower() == "true" if query.desc else True
|
||||
|
||||
owner_ids = body.owner_ids or [] if body else []
|
||||
try:
|
||||
if not owner_ids:
|
||||
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
@@ -171,12 +202,13 @@ def list_search_app():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("search_id")
|
||||
def rm():
|
||||
req = request.get_json()
|
||||
search_id = req["search_id"]
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: DeleteSearchRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除搜索应用"""
|
||||
search_id = request.search_id
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
|
||||
@@ -15,12 +15,15 @@
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, Query, status
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, EmailStr
|
||||
@@ -60,6 +63,19 @@ from api.utils.api_utils import (
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import smtp_mail_server
|
||||
from api.utils.web_utils import (
|
||||
send_email_html,
|
||||
OTP_LENGTH,
|
||||
OTP_TTL_SECONDS,
|
||||
ATTEMPT_LIMIT,
|
||||
ATTEMPT_LOCK_SECONDS,
|
||||
RESEND_COOLDOWN_SECONDS,
|
||||
otp_keys,
|
||||
hash_code,
|
||||
captcha_key,
|
||||
)
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
@@ -77,9 +93,8 @@ class RegisterRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
class UserSettingRequest(BaseModel):
|
||||
nickname: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
new_password: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
class TenantInfoRequest(BaseModel):
|
||||
tenant_id: str
|
||||
@@ -88,6 +103,16 @@ class TenantInfoRequest(BaseModel):
|
||||
img2txt_id: str
|
||||
llm_id: str
|
||||
|
||||
class ForgetOtpRequest(BaseModel):
|
||||
email: str
|
||||
captcha: str
|
||||
|
||||
class ForgetPasswordRequest(BaseModel):
|
||||
email: str
|
||||
otp: str
|
||||
new_password: str
|
||||
confirm_new_password: str
|
||||
|
||||
# 依赖项:获取当前用户 - 从 auth_dependencies 导入
|
||||
|
||||
@router.post("/login")
|
||||
@@ -481,3 +506,357 @@ async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
@router.get("/github_callback")
|
||||
async def github_callback(code: Optional[str] = Query(None)):
|
||||
"""
|
||||
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
||||
|
||||
GitHub OAuth callback endpoint.
|
||||
"""
|
||||
import requests
|
||||
|
||||
if not code:
|
||||
return RedirectResponse(url="/?error=missing_code")
|
||||
|
||||
res = requests.post(
|
||||
settings.GITHUB_OAUTH.get("url"),
|
||||
data={
|
||||
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
|
||||
"code": code,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
res = res.json()
|
||||
if "error" in res:
|
||||
return RedirectResponse(url=f"/?error={res.get('error_description', res.get('error'))}")
|
||||
|
||||
if "user:email" not in res.get("scope", "").split(","):
|
||||
return RedirectResponse(url="/?error=user:email not in scope")
|
||||
|
||||
access_token = res["access_token"]
|
||||
user_info = user_info_from_github(access_token)
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
|
||||
if not users:
|
||||
try:
|
||||
try:
|
||||
avatar = download_img(user_info["avatar_url"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
avatar = ""
|
||||
|
||||
users = user_register(
|
||||
user_id,
|
||||
{
|
||||
"access_token": access_token,
|
||||
"email": email_address,
|
||||
"avatar": avatar,
|
||||
"nickname": user_info["login"],
|
||||
"login_channel": "github",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
},
|
||||
)
|
||||
|
||||
if not users:
|
||||
raise Exception(f"Fail to register {email_address}.")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
|
||||
user = users[0]
|
||||
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
return RedirectResponse(url=f"/?error={str(e)}")
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return RedirectResponse(url="/?error=user_inactive")
|
||||
user.save()
|
||||
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||
|
||||
@router.get("/feishu_callback")
|
||||
async def feishu_callback(code: Optional[str] = Query(None)):
|
||||
"""
|
||||
Feishu OAuth callback endpoint.
|
||||
"""
|
||||
import requests
|
||||
|
||||
if not code:
|
||||
return RedirectResponse(url="/?error=missing_code")
|
||||
|
||||
app_access_token_res = requests.post(
|
||||
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"app_id": settings.FEISHU_OAUTH.get("app_id"),
|
||||
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
|
||||
}
|
||||
),
|
||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
)
|
||||
app_access_token_res = app_access_token_res.json()
|
||||
if app_access_token_res.get("code") != 0:
|
||||
return RedirectResponse(url=f"/?error={app_access_token_res}")
|
||||
|
||||
res = requests.post(
|
||||
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
|
||||
"code": code,
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {app_access_token_res['app_access_token']}",
|
||||
},
|
||||
)
|
||||
res = res.json()
|
||||
if res.get("code") != 0:
|
||||
return RedirectResponse(url=f"/?error={res.get('message', 'unknown_error')}")
|
||||
|
||||
if "contact:user.email:readonly" not in res.get("data", {}).get("scope", "").split():
|
||||
return RedirectResponse(url="/?error=contact:user.email:readonly not in scope")
|
||||
|
||||
access_token = res["data"]["access_token"]
|
||||
user_info = user_info_from_feishu(access_token)
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
|
||||
if not users:
|
||||
try:
|
||||
try:
|
||||
avatar = download_img(user_info["avatar_url"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
avatar = ""
|
||||
|
||||
users = user_register(
|
||||
user_id,
|
||||
{
|
||||
"access_token": access_token,
|
||||
"email": email_address,
|
||||
"avatar": avatar,
|
||||
"nickname": user_info["en_name"],
|
||||
"login_channel": "feishu",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
},
|
||||
)
|
||||
|
||||
if not users:
|
||||
raise Exception(f"Fail to register {email_address}.")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
|
||||
user = users[0]
|
||||
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
return RedirectResponse(url=f"/?error={str(e)}")
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return RedirectResponse(url="/?error=user_inactive")
|
||||
user.access_token = get_uuid()
|
||||
user.save()
|
||||
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||
|
||||
def user_info_from_feishu(access_token):
|
||||
"""从飞书获取用户信息"""
|
||||
import requests
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
user_info = res.json()["data"]
|
||||
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
||||
return user_info
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
"""从GitHub获取用户信息"""
|
||||
import requests
|
||||
|
||||
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers,
|
||||
).json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return user_info
|
||||
|
||||
@router.get("/forget/captcha")
|
||||
async def forget_get_captcha(email: str = Query(...)):
|
||||
"""
|
||||
GET /forget/captcha?email=<email>
|
||||
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = 60 seconds.
|
||||
- Returns the captcha as a JPEG image.
|
||||
"""
|
||||
if not email:
|
||||
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required")
|
||||
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
|
||||
|
||||
# Generate captcha text
|
||||
allowed = string.ascii_uppercase + string.digits
|
||||
captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
|
||||
REDIS_CONN.set(captcha_key(email), captcha_text, 60) # Valid for 60 seconds
|
||||
|
||||
from captcha.image import ImageCaptcha
|
||||
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
|
||||
img_bytes = image.generate(captcha_text).read()
|
||||
|
||||
return Response(content=img_bytes, media_type="image/JPEG")
|
||||
|
||||
@router.post("/forget/otp")
|
||||
async def forget_send_otp(request: ForgetOtpRequest):
|
||||
"""
|
||||
POST /forget/otp
|
||||
- Verify the image captcha stored at captcha:{email} (case-insensitive).
|
||||
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
|
||||
"""
|
||||
email = request.email or ""
|
||||
captcha = (request.captcha or "").strip()
|
||||
|
||||
if not email or not captcha:
|
||||
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required")
|
||||
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
|
||||
|
||||
stored_captcha = REDIS_CONN.get(captcha_key(email))
|
||||
if not stored_captcha:
|
||||
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
|
||||
if (stored_captcha or "").strip().lower() != captcha.lower():
|
||||
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")
|
||||
|
||||
# Delete captcha to prevent reuse
|
||||
REDIS_CONN.delete(captcha_key(email))
|
||||
|
||||
k_code, k_attempts, k_last, k_lock = otp_keys(email)
|
||||
now = int(time.time())
|
||||
last_ts = REDIS_CONN.get(k_last)
|
||||
if last_ts:
|
||||
try:
|
||||
elapsed = now - int(last_ts)
|
||||
except Exception:
|
||||
elapsed = RESEND_COOLDOWN_SECONDS
|
||||
remaining = RESEND_COOLDOWN_SECONDS - elapsed
|
||||
if remaining > 0:
|
||||
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds")
|
||||
|
||||
# Generate OTP (uppercase letters only) and store hashed
|
||||
otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
|
||||
salt = os.urandom(16)
|
||||
code_hash = hash_code(otp, salt)
|
||||
REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS)
|
||||
REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS)
|
||||
REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS)
|
||||
REDIS_CONN.delete(k_lock)
|
||||
|
||||
ttl_min = OTP_TTL_SECONDS // 60
|
||||
|
||||
if not smtp_mail_server:
|
||||
logging.warning("SMTP mail server not initialized; skip sending email.")
|
||||
else:
|
||||
try:
|
||||
send_email_html(
|
||||
subject="Your Password Reset Code",
|
||||
to_email=email,
|
||||
template_key="reset_code",
|
||||
code=otp,
|
||||
ttl_min=ttl_min,
|
||||
)
|
||||
except Exception:
|
||||
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email")
|
||||
|
||||
return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent")
|
||||
|
||||
@router.post("/forget")
|
||||
async def forget(request: ForgetPasswordRequest):
|
||||
"""
|
||||
POST: Verify email + OTP and reset password, then log the user in.
|
||||
Request JSON: { email, otp, new_password, confirm_new_password }
|
||||
"""
|
||||
email = request.email or ""
|
||||
otp = (request.otp or "").strip()
|
||||
new_pwd = request.new_password
|
||||
new_pwd2 = request.confirm_new_password
|
||||
|
||||
if not all([email, otp, new_pwd, new_pwd2]):
|
||||
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")
|
||||
|
||||
# For reset, passwords are provided as-is (no decrypt needed)
|
||||
if new_pwd != new_pwd2:
|
||||
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match")
|
||||
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
|
||||
|
||||
user = users[0]
|
||||
# Verify OTP from Redis
|
||||
k_code, k_attempts, k_last, k_lock = otp_keys(email)
|
||||
if REDIS_CONN.get(k_lock):
|
||||
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later")
|
||||
|
||||
stored = REDIS_CONN.get(k_code)
|
||||
if not stored:
|
||||
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp")
|
||||
|
||||
try:
|
||||
stored_hash, salt_hex = str(stored).split(":", 1)
|
||||
salt = bytes.fromhex(salt_hex)
|
||||
except Exception:
|
||||
return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted")
|
||||
|
||||
# Case-insensitive verification: OTP generated uppercase
|
||||
calc = hash_code(otp.upper(), salt)
|
||||
if calc != stored_hash:
|
||||
# bump attempts
|
||||
try:
|
||||
attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
|
||||
except Exception:
|
||||
attempts = 1
|
||||
REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
|
||||
if attempts >= ATTEMPT_LIMIT:
|
||||
REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
|
||||
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp")
|
||||
|
||||
# Success: consume OTP and reset password
|
||||
REDIS_CONN.delete(k_code)
|
||||
REDIS_CONN.delete(k_attempts)
|
||||
REDIS_CONN.delete(k_last)
|
||||
REDIS_CONN.delete(k_lock)
|
||||
|
||||
try:
|
||||
UserService.update_user_password(user.id, new_pwd)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password")
|
||||
|
||||
# Auto login (reuse login flow)
|
||||
user.access_token = get_uuid()
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.save()
|
||||
msg = "Password reset successful. Logged in."
|
||||
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
|
||||
@@ -8,6 +8,10 @@ minio:
|
||||
user: 'rag_flow'
|
||||
password: 'infini_rag_flow'
|
||||
host: 'localhost:9000'
|
||||
es:
|
||||
hosts: 'http://localhost:1200'
|
||||
username: 'elastic'
|
||||
password: 'infini_rag_flow'
|
||||
os:
|
||||
hosts: 'http://localhost:1201'
|
||||
username: 'admin'
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# - `elasticsearch` (default)
|
||||
# - `infinity` (https://github.com/infiniflow/infinity)
|
||||
# - `opensearch` (https://github.com/opensearch-project/OpenSearch)
|
||||
DOC_ENGINE=opensearch
|
||||
DOC_ENGINE=elasticsearch
|
||||
|
||||
# ------------------------------
|
||||
# docker env var for specifying vector db type at startup
|
||||
@@ -98,7 +98,7 @@ ADMIN_SVR_HTTP_PORT=9381
|
||||
|
||||
# The RAGFlow Docker image to download.
|
||||
# Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models.
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi-web
|
||||
#
|
||||
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
|
||||
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1
|
||||
|
||||
@@ -1,36 +1,34 @@
|
||||
services:
|
||||
|
||||
opensearch01:
|
||||
container_name: ragflow-opensearch-01
|
||||
es01:
|
||||
container_name: ragflow-es-01
|
||||
profiles:
|
||||
- opensearch
|
||||
image: hub.icert.top/opensearchproject/opensearch:2.19.1
|
||||
- elasticsearch
|
||||
image: elasticsearch:${STACK_VERSION}
|
||||
volumes:
|
||||
- osdata01:/usr/share/opensearch/data
|
||||
- esdata01:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- ${OS_PORT}:9201
|
||||
- ${ES_PORT}:9200
|
||||
env_file: .env
|
||||
environment:
|
||||
- node.name=opensearch01
|
||||
- OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD}
|
||||
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_PASSWORD}
|
||||
- node.name=es01
|
||||
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
||||
- bootstrap.memory_lock=false
|
||||
- discovery.type=single-node
|
||||
- plugins.security.disabled=false
|
||||
- plugins.security.ssl.http.enabled=false
|
||||
- plugins.security.ssl.transport.enabled=true
|
||||
- xpack.security.enabled=true
|
||||
- xpack.security.http.ssl.enabled=false
|
||||
- xpack.security.transport.ssl.enabled=false
|
||||
- cluster.routing.allocation.disk.watermark.low=5gb
|
||||
- cluster.routing.allocation.disk.watermark.high=3gb
|
||||
- cluster.routing.allocation.disk.watermark.flood_stage=2gb
|
||||
- TZ=${TIMEZONE}
|
||||
- http.port=9201
|
||||
mem_limit: ${MEM_LIMIT}
|
||||
ulimits:
|
||||
memlock:
|
||||
soft: -1
|
||||
hard: -1
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "curl http://localhost:9201"]
|
||||
test: ["CMD-SHELL", "curl http://localhost:9200"]
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 120
|
||||
|
||||
@@ -64,7 +64,8 @@ class Dealer:
|
||||
if key in req and req[key] is not None:
|
||||
condition[field] = req[key]
|
||||
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
||||
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
|
||||
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
|
||||
"removed_kwd"]:
|
||||
if key in req and req[key] is not None:
|
||||
condition[key] = req[key]
|
||||
return condition
|
||||
@@ -135,7 +136,8 @@ class Dealer:
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
orderBy, offset, limit, idx_names, kb_ids,
|
||||
rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
||||
|
||||
@@ -212,8 +214,9 @@ class Dealer:
|
||||
ans_v, _ = embd_mdl.encode(pieces_)
|
||||
for i in range(len(chunk_v)):
|
||||
if len(ans_v[0]) != len(chunk_v[i]):
|
||||
chunk_v[i] = [0.0]*len(ans_v[0])
|
||||
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
||||
chunk_v[i] = [0.0] * len(ans_v[0])
|
||||
logging.warning(
|
||||
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
||||
|
||||
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
||||
len(ans_v[0]), len(chunk_v[0]))
|
||||
@@ -267,7 +270,7 @@ class Dealer:
|
||||
if not query_rfea:
|
||||
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
||||
|
||||
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||
for i in search_res.ids:
|
||||
nor, denor = 0, 0
|
||||
if not search_res.field[i].get(TAG_FLD):
|
||||
@@ -280,8 +283,8 @@ class Dealer:
|
||||
if denor == 0:
|
||||
rank_fea.append(0)
|
||||
else:
|
||||
rank_fea.append(nor/np.sqrt(denor)/q_denor)
|
||||
return np.array(rank_fea)*10. + pageranks
|
||||
rank_fea.append(nor / np.sqrt(denor) / q_denor)
|
||||
return np.array(rank_fea) * 10. + pageranks
|
||||
|
||||
def rerank(self, sres, query, tkweight=0.3,
|
||||
vtweight=0.7, cfield="content_ltks",
|
||||
@@ -343,7 +346,7 @@ class Dealer:
|
||||
## For rank feature(tag_fea) scores.
|
||||
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
||||
|
||||
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
|
||||
return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
|
||||
|
||||
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
||||
return self.qryr.hybrid_similarity(ans_embd,
|
||||
@@ -360,13 +363,13 @@ class Dealer:
|
||||
return ranks
|
||||
|
||||
# Ensure RERANK_LIMIT is multiple of page_size
|
||||
RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1
|
||||
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
|
||||
RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
|
||||
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size * page / RERANK_LIMIT),
|
||||
"size": RERANK_LIMIT,
|
||||
"question": question, "vector": True, "topk": top,
|
||||
"similarity": similarity_threshold,
|
||||
"available_int": 1}
|
||||
|
||||
|
||||
if isinstance(tenant_ids, str):
|
||||
tenant_ids = tenant_ids.split(",")
|
||||
|
||||
@@ -392,8 +395,8 @@ class Dealer:
|
||||
tsim = sim
|
||||
vsim = sim
|
||||
# Already paginated in search function
|
||||
begin = ((page % (RERANK_LIMIT//page_size)) - 1) * page_size
|
||||
sim = sim[begin : begin + page_size]
|
||||
begin = ((page % (RERANK_LIMIT // page_size)) - 1) * page_size
|
||||
sim = sim[begin: begin + page_size]
|
||||
sim_np = np.array(sim)
|
||||
idx = np.argsort(sim_np * -1)
|
||||
dim = len(sres.query_vector)
|
||||
@@ -505,13 +508,14 @@ class Dealer:
|
||||
|
||||
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
|
||||
idx_nm = index_name(tenant_id)
|
||||
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
|
||||
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
|
||||
keywords_topn)
|
||||
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
|
||||
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
||||
if not aggs:
|
||||
return False
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||
key=lambda x: x[1] * -1)[:topn_tags]
|
||||
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
|
||||
return True
|
||||
@@ -527,11 +531,11 @@ class Dealer:
|
||||
if not aggs:
|
||||
return {}
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||
key=lambda x: x[1] * -1)[:topn_tags]
|
||||
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
|
||||
|
||||
def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6):
|
||||
def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
|
||||
if not chunks:
|
||||
return []
|
||||
idx_nms = [index_name(tid) for tid in tenant_ids]
|
||||
@@ -541,9 +545,10 @@ class Dealer:
|
||||
ranks[ck["doc_id"]] = 0
|
||||
ranks[ck["doc_id"]] += ck["similarity"]
|
||||
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
|
||||
doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
|
||||
doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
|
||||
kb_ids = [doc_id2kb_id[doc_id]]
|
||||
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
|
||||
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
|
||||
OrderByExpr(), 0, 128, idx_nms,
|
||||
kb_ids)
|
||||
toc = []
|
||||
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
|
||||
@@ -555,7 +560,7 @@ class Dealer:
|
||||
if not toc:
|
||||
return chunks
|
||||
|
||||
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
|
||||
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2)
|
||||
if not ids:
|
||||
return chunks
|
||||
|
||||
@@ -589,4 +594,4 @@ class Dealer:
|
||||
break
|
||||
chunks.append(d)
|
||||
|
||||
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
|
||||
return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]
|
||||
|
||||
35
rag/res/deepdoc/.gitattributes
vendored
Normal file
35
rag/res/deepdoc/.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
BIN
rag/res/deepdoc/layout.laws.onnx
Normal file
BIN
rag/res/deepdoc/layout.laws.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/layout.manual.onnx
Normal file
BIN
rag/res/deepdoc/layout.manual.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/layout.onnx
Normal file
BIN
rag/res/deepdoc/layout.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/layout.paper.onnx
Normal file
BIN
rag/res/deepdoc/layout.paper.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/tsr.onnx
Normal file
BIN
rag/res/deepdoc/tsr.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user