修复 画布,mcp服务,搜索,文档的接口
This commit is contained in:
@@ -165,17 +165,23 @@ def setup_routes(app: FastAPI):
|
|||||||
from api.apps.tenant_app import router as tenant_router
|
from api.apps.tenant_app import router as tenant_router
|
||||||
from api.apps.dialog_app import router as dialog_router
|
from api.apps.dialog_app import router as dialog_router
|
||||||
from api.apps.system_app import router as system_router
|
from api.apps.system_app import router as system_router
|
||||||
|
from api.apps.search_app import router as search_router
|
||||||
|
from api.apps.conversation_app import router as conversation_router
|
||||||
|
from api.apps.file_app import router as file_router
|
||||||
|
|
||||||
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
|
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
|
||||||
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"])
|
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"])
|
||||||
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
|
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
|
||||||
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
|
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
|
||||||
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
|
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
|
||||||
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
|
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp_server", tags=["MCP"])
|
||||||
app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"])
|
app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"])
|
||||||
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
|
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
|
||||||
app.include_router(dialog_router, prefix=f"/{API_VERSION}/dialog", tags=["Dialog"])
|
app.include_router(dialog_router, prefix=f"/{API_VERSION}/dialog", tags=["Dialog"])
|
||||||
app.include_router(system_router, prefix=f"/{API_VERSION}/system", tags=["System"])
|
app.include_router(system_router, prefix=f"/{API_VERSION}/system", tags=["System"])
|
||||||
|
app.include_router(search_router, prefix=f"/{API_VERSION}/search", tags=["Search"])
|
||||||
|
app.include_router(conversation_router, prefix=f"/{API_VERSION}/conversation", tags=["Conversation"])
|
||||||
|
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,10 @@ import json
|
|||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from flask import Response, request
|
from typing import Optional
|
||||||
from flask_login import current_user, login_required
|
from fastapi import APIRouter, Depends, Query, Header, HTTPException, status
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
@@ -28,15 +30,35 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
|
||||||
|
from api.utils import get_uuid
|
||||||
from rag.prompts.template import load_prompt
|
from rag.prompts.template import load_prompt
|
||||||
from rag.prompts.generator import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
|
|
||||||
|
from api.apps.models.auth_dependencies import get_current_user
|
||||||
|
from api.apps.models.conversation_models import (
|
||||||
|
SetConversationRequest,
|
||||||
|
DeleteConversationsRequest,
|
||||||
|
CompletionRequest,
|
||||||
|
TTSRequest,
|
||||||
|
DeleteMessageRequest,
|
||||||
|
ThumbupRequest,
|
||||||
|
AskRequest,
|
||||||
|
MindmapRequest,
|
||||||
|
RelatedQuestionsRequest,
|
||||||
|
)
|
||||||
|
|
||||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
# 创建路由器
|
||||||
@login_required
|
router = APIRouter()
|
||||||
def set_conversation():
|
|
||||||
req = request.json
|
|
||||||
|
@router.post('/set')
|
||||||
|
async def set_conversation(
|
||||||
|
request: SetConversationRequest,
|
||||||
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""设置对话"""
|
||||||
|
req = request.model_dump(exclude_unset=True)
|
||||||
conv_id = req.get("conversation_id")
|
conv_id = req.get("conversation_id")
|
||||||
is_new = req.get("is_new")
|
is_new = req.get("is_new")
|
||||||
name = req.get("name", "New conversation")
|
name = req.get("name", "New conversation")
|
||||||
@@ -45,9 +67,9 @@ def set_conversation():
|
|||||||
if len(name) > 255:
|
if len(name) > 255:
|
||||||
name = name[0:255]
|
name = name[0:255]
|
||||||
|
|
||||||
del req["is_new"]
|
|
||||||
if not is_new:
|
if not is_new:
|
||||||
del req["conversation_id"]
|
if not conv_id:
|
||||||
|
return get_data_error_result(message="conversation_id is required when is_new is False!")
|
||||||
try:
|
try:
|
||||||
if not ConversationService.update_by_id(conv_id, req):
|
if not ConversationService.update_by_id(conv_id, req):
|
||||||
return get_data_error_result(message="Conversation not found!")
|
return get_data_error_result(message="Conversation not found!")
|
||||||
@@ -64,7 +86,7 @@ def set_conversation():
|
|||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Dialog not found")
|
return get_data_error_result(message="Dialog not found")
|
||||||
conv = {
|
conv = {
|
||||||
"id": conv_id,
|
"id": conv_id or get_uuid(),
|
||||||
"dialog_id": req["dialog_id"],
|
"dialog_id": req["dialog_id"],
|
||||||
"name": name,
|
"name": name,
|
||||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
|
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
|
||||||
@@ -77,12 +99,14 @@ def set_conversation():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
@router.get('/get')
|
||||||
@login_required
|
async def get(
|
||||||
def get():
|
conversation_id: str = Query(..., description="对话ID"),
|
||||||
conv_id = request.args["conversation_id"]
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取对话"""
|
||||||
try:
|
try:
|
||||||
e, conv = ConversationService.get_by_id(conv_id)
|
e, conv = ConversationService.get_by_id(conversation_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Conversation not found!")
|
return get_data_error_result(message="Conversation not found!")
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
@@ -107,15 +131,27 @@ def get():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
|
@router.get('/getsse/{dialog_id}')
|
||||||
def getsse(dialog_id):
|
async def getsse(
|
||||||
token = request.headers.get("Authorization").split()
|
dialog_id: str,
|
||||||
if len(token) != 2:
|
authorization: Optional[str] = Header(None, alias="Authorization")
|
||||||
|
):
|
||||||
|
"""通过 SSE 获取对话(使用 API token 认证)"""
|
||||||
|
if not authorization:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authorization header is required"
|
||||||
|
)
|
||||||
|
|
||||||
|
token_parts = authorization.split()
|
||||||
|
if len(token_parts) != 2:
|
||||||
return get_data_error_result(message='Authorization is not valid!"')
|
return get_data_error_result(message='Authorization is not valid!"')
|
||||||
token = token[1]
|
token = token_parts[1]
|
||||||
|
|
||||||
objs = APIToken.query(beta=token)
|
objs = APIToken.query(beta=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
e, conv = DialogService.get_by_id(dialog_id)
|
e, conv = DialogService.get_by_id(dialog_id)
|
||||||
if not e:
|
if not e:
|
||||||
@@ -128,12 +164,14 @@ def getsse(dialog_id):
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
@router.post('/rm')
|
||||||
@login_required
|
async def rm(
|
||||||
def rm():
|
request: DeleteConversationsRequest,
|
||||||
conv_ids = request.json["conversation_ids"]
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""删除对话"""
|
||||||
try:
|
try:
|
||||||
for cid in conv_ids:
|
for cid in request.conversation_ids:
|
||||||
exist, conv = ConversationService.get_by_id(cid)
|
exist, conv = ConversationService.get_by_id(cid)
|
||||||
if not exist:
|
if not exist:
|
||||||
return get_data_error_result(message="Conversation not found!")
|
return get_data_error_result(message="Conversation not found!")
|
||||||
@@ -149,10 +187,12 @@ def rm():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
@router.get('/list')
|
||||||
@login_required
|
async def list_conversation(
|
||||||
def list_conversation():
|
dialog_id: str = Query(..., description="对话ID"),
|
||||||
dialog_id = request.args["dialog_id"]
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""列出对话"""
|
||||||
try:
|
try:
|
||||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||||
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||||
@@ -164,11 +204,13 @@ def list_conversation():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
@router.post('/completion')
|
||||||
@login_required
|
async def completion(
|
||||||
@validate_request("conversation_id", "messages")
|
request: CompletionRequest,
|
||||||
def completion():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
|
"""完成请求(聊天完成)"""
|
||||||
|
req = request.model_dump(exclude_unset=True)
|
||||||
msg = []
|
msg = []
|
||||||
for m in req["messages"]:
|
for m in req["messages"]:
|
||||||
if m["role"] == "system":
|
if m["role"] == "system":
|
||||||
@@ -176,6 +218,10 @@ def completion():
|
|||||||
if m["role"] == "assistant" and not msg:
|
if m["role"] == "assistant" and not msg:
|
||||||
continue
|
continue
|
||||||
msg.append(m)
|
msg.append(m)
|
||||||
|
|
||||||
|
if not msg:
|
||||||
|
return get_data_error_result(message="No valid messages found!")
|
||||||
|
|
||||||
message_id = msg[-1].get("id")
|
message_id = msg[-1].get("id")
|
||||||
chat_model_id = req.get("llm_id", "")
|
chat_model_id = req.get("llm_id", "")
|
||||||
req.pop("llm_id", None)
|
req.pop("llm_id", None)
|
||||||
@@ -217,6 +263,7 @@ def completion():
|
|||||||
dia.llm_setting = chat_model_config
|
dia.llm_setting = chat_model_config
|
||||||
|
|
||||||
is_embedded = bool(chat_model_id)
|
is_embedded = bool(chat_model_id)
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
nonlocal dia, msg, req, conv
|
nonlocal dia, msg, req, conv
|
||||||
try:
|
try:
|
||||||
@@ -230,14 +277,18 @@ def completion():
|
|||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
if req.get("stream", True):
|
stream_enabled = request.stream if request.stream is not None else True
|
||||||
resp = Response(stream(), mimetype="text/event-stream")
|
if stream_enabled:
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
return StreamingResponse(
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
stream(),
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
media_type="text/event-stream",
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
headers={
|
||||||
return resp
|
"Cache-control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Content-Type": "text/event-stream; charset=utf-8"
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, **req):
|
for ans in chat(dia, msg, **req):
|
||||||
@@ -250,11 +301,13 @@ def completion():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
@router.post('/tts')
|
||||||
@login_required
|
async def tts(
|
||||||
def tts():
|
request: TTSRequest,
|
||||||
req = request.json
|
current_user = Depends(get_current_user)
|
||||||
text = req["text"]
|
):
|
||||||
|
"""文本转语音"""
|
||||||
|
text = request.text
|
||||||
|
|
||||||
tenants = TenantService.get_info_by(current_user.id)
|
tenants = TenantService.get_info_by(current_user.id)
|
||||||
if not tenants:
|
if not tenants:
|
||||||
@@ -274,28 +327,32 @@ def tts():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||||
|
|
||||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
return StreamingResponse(
|
||||||
resp.headers.add_header("Cache-Control", "no-cache")
|
stream_audio(),
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
media_type="audio/mpeg",
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
return resp
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
@router.post('/delete_msg')
|
||||||
@login_required
|
async def delete_msg(
|
||||||
@validate_request("conversation_id", "message_id")
|
request: DeleteMessageRequest,
|
||||||
def delete_msg():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
"""删除消息"""
|
||||||
|
e, conv = ConversationService.get_by_id(request.conversation_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Conversation not found!")
|
return get_data_error_result(message="Conversation not found!")
|
||||||
|
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
for i, msg in enumerate(conv["message"]):
|
for i, msg in enumerate(conv["message"]):
|
||||||
if req["message_id"] != msg.get("id", ""):
|
if request.message_id != msg.get("id", ""):
|
||||||
continue
|
continue
|
||||||
assert conv["message"][i + 1]["id"] == req["message_id"]
|
assert conv["message"][i + 1]["id"] == request.message_id
|
||||||
conv["message"].pop(i)
|
conv["message"].pop(i)
|
||||||
conv["message"].pop(i)
|
conv["message"].pop(i)
|
||||||
conv["reference"].pop(max(0, i // 2 - 1))
|
conv["reference"].pop(max(0, i // 2 - 1))
|
||||||
@@ -305,19 +362,21 @@ def delete_msg():
|
|||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
@router.post('/thumbup')
|
||||||
@login_required
|
async def thumbup(
|
||||||
@validate_request("conversation_id", "message_id")
|
request: ThumbupRequest,
|
||||||
def thumbup():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
"""点赞/点踩"""
|
||||||
|
e, conv = ConversationService.get_by_id(request.conversation_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Conversation not found!")
|
return get_data_error_result(message="Conversation not found!")
|
||||||
up_down = req.get("thumbup")
|
|
||||||
feedback = req.get("feedback", "")
|
up_down = request.thumbup
|
||||||
|
feedback = request.feedback or ""
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
for i, msg in enumerate(conv["message"]):
|
for i, msg in enumerate(conv["message"]):
|
||||||
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
|
if request.message_id == msg.get("id", "") and msg.get("role", "") == "assistant":
|
||||||
if up_down:
|
if up_down:
|
||||||
msg["thumbup"] = True
|
msg["thumbup"] = True
|
||||||
if "feedback" in msg:
|
if "feedback" in msg:
|
||||||
@@ -332,14 +391,15 @@ def thumbup():
|
|||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
@router.post('/ask')
|
||||||
@login_required
|
async def ask_about(
|
||||||
@validate_request("question", "kb_ids")
|
request: AskRequest,
|
||||||
def ask_about():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
|
"""提问"""
|
||||||
uid = current_user.id
|
uid = current_user.id
|
||||||
|
|
||||||
search_id = req.get("search_id", "")
|
search_id = request.search_id or ""
|
||||||
search_app = None
|
search_app = None
|
||||||
search_config = {}
|
search_config = {}
|
||||||
if search_id:
|
if search_id:
|
||||||
@@ -348,53 +408,58 @@ def ask_about():
|
|||||||
search_config = search_app.get("search_config", {})
|
search_config = search_app.get("search_config", {})
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
nonlocal req, uid
|
nonlocal request, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
for ans in ask(request.question, request.kb_ids, uid, search_config=search_config):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
resp = Response(stream(), mimetype="text/event-stream")
|
return StreamingResponse(
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
stream(),
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
media_type="text/event-stream",
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
headers={
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
"Cache-control": "no-cache",
|
||||||
return resp
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Content-Type": "text/event-stream; charset=utf-8"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
@router.post('/mindmap')
|
||||||
@login_required
|
async def mindmap(
|
||||||
@validate_request("question", "kb_ids")
|
request: MindmapRequest,
|
||||||
def mindmap():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
search_id = req.get("search_id", "")
|
"""思维导图"""
|
||||||
|
search_id = request.search_id or ""
|
||||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||||
kb_ids = search_config.get("kb_ids", [])
|
kb_ids = search_config.get("kb_ids", [])
|
||||||
kb_ids.extend(req["kb_ids"])
|
kb_ids.extend(request.kb_ids)
|
||||||
kb_ids = list(set(kb_ids))
|
kb_ids = list(set(kb_ids))
|
||||||
|
|
||||||
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
mind_map = gen_mindmap(request.question, kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||||
if "error" in mind_map:
|
if "error" in mind_map:
|
||||||
return server_error_response(Exception(mind_map["error"]))
|
return server_error_response(Exception(mind_map["error"]))
|
||||||
return get_json_result(data=mind_map)
|
return get_json_result(data=mind_map)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
@router.post('/related_questions')
|
||||||
@login_required
|
async def related_questions(
|
||||||
@validate_request("question")
|
request: RelatedQuestionsRequest,
|
||||||
def related_questions():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
|
"""相关问题"""
|
||||||
search_id = req.get("search_id", "")
|
search_id = request.search_id or ""
|
||||||
search_config = {}
|
search_config = {}
|
||||||
if search_id:
|
if search_id:
|
||||||
if search_app := SearchService.get_detail(search_id):
|
if search_app := SearchService.get_detail(search_id):
|
||||||
search_config = search_app.get("search_config", {})
|
search_config = search_app.get("search_config", {})
|
||||||
|
|
||||||
question = req["question"]
|
question = request.question
|
||||||
|
|
||||||
chat_id = search_config.get("chat_id", "")
|
chat_id = search_config.get("chat_id", "")
|
||||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
|
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
|
||||||
|
|||||||
@@ -72,18 +72,17 @@ router = APIRouter()
|
|||||||
@router.post("/upload")
|
@router.post("/upload")
|
||||||
async def upload(
|
async def upload(
|
||||||
kb_id: str = Form(...),
|
kb_id: str = Form(...),
|
||||||
files: List[UploadFile] = File(...),
|
file: UploadFile = File(...),
|
||||||
current_user = Depends(get_current_user)
|
current_user = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""上传文档"""
|
"""上传文档"""
|
||||||
if not files:
|
if not file:
|
||||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
for file_obj in files:
|
if not file.filename or file.filename == "":
|
||||||
if not file_obj.filename or file_obj.filename == "":
|
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
|
if len(file.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||||
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
if not e:
|
if not e:
|
||||||
@@ -91,7 +90,7 @@ async def upload(
|
|||||||
if not check_kb_team_permission(kb, current_user.id):
|
if not check_kb_team_permission(kb, current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
err, uploaded_files = FileService.upload_document(kb, files, current_user.id)
|
err, uploaded_files = FileService.upload_document(kb, [file], current_user.id)
|
||||||
if err:
|
if err:
|
||||||
return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
|
|||||||
@@ -17,15 +17,23 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
import flask
|
from fastapi import APIRouter, Depends, Query, UploadFile, File, Form
|
||||||
from flask import request
|
from fastapi.responses import Response
|
||||||
from flask_login import login_required, current_user
|
|
||||||
|
from api.apps.models.auth_dependencies import get_current_user
|
||||||
|
from api.apps.models.file_models import (
|
||||||
|
CreateFileRequest,
|
||||||
|
DeleteFilesRequest,
|
||||||
|
RenameFileRequest,
|
||||||
|
MoveFilesRequest,
|
||||||
|
)
|
||||||
|
|
||||||
from api.common.check_team_permission import check_file_team_permission
|
from api.common.check_team_permission import check_file_team_permission
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import FileType, FileSource
|
from api.db import FileType, FileSource
|
||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
@@ -36,35 +44,41 @@ from api.utils.file_utils import filename_type
|
|||||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
|
||||||
|
# 创建路由器
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
|
||||||
@login_required
|
@router.post('/upload')
|
||||||
# @validate_request("parent_id")
|
async def upload(
|
||||||
def upload():
|
files: List[UploadFile] = File(...),
|
||||||
pf_id = request.form.get("parent_id")
|
parent_id: Optional[str] = Form(None),
|
||||||
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""上传文件"""
|
||||||
|
pf_id = parent_id
|
||||||
|
|
||||||
if not pf_id:
|
if not pf_id:
|
||||||
root_folder = FileService.get_root_folder(current_user.id)
|
root_folder = FileService.get_root_folder(current_user.id)
|
||||||
pf_id = root_folder["id"]
|
pf_id = root_folder["id"]
|
||||||
|
|
||||||
if 'file' not in request.files:
|
if not files:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
file_objs = request.files.getlist('file')
|
|
||||||
|
|
||||||
for file_obj in file_objs:
|
for file_obj in files:
|
||||||
if file_obj.filename == '':
|
if not file_obj.filename or file_obj.filename == '':
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
file_res = []
|
file_res = []
|
||||||
try:
|
try:
|
||||||
e, pf_folder = FileService.get_by_id(pf_id)
|
e, pf_folder = FileService.get_by_id(pf_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result( message="Can't find this folder!")
|
return get_data_error_result(message="Can't find this folder!")
|
||||||
for file_obj in file_objs:
|
for file_obj in files:
|
||||||
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
|
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
|
||||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
return get_data_error_result(message="Exceed the maximum file number of a free user!")
|
||||||
|
|
||||||
# split file name path
|
# split file name path
|
||||||
if not file_obj.filename:
|
if not file_obj.filename:
|
||||||
@@ -97,7 +111,7 @@ def upload():
|
|||||||
location = file_obj_names[file_len - 1]
|
location = file_obj_names[file_len - 1]
|
||||||
while STORAGE_IMPL.obj_exist(last_folder.id, location):
|
while STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||||
location += "_"
|
location += "_"
|
||||||
blob = file_obj.read()
|
blob = await file_obj.read()
|
||||||
filename = duplicate_name(
|
filename = duplicate_name(
|
||||||
FileService.query,
|
FileService.query,
|
||||||
name=file_obj_names[file_len - 1],
|
name=file_obj_names[file_len - 1],
|
||||||
@@ -120,13 +134,16 @@ def upload():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
@router.post('/create')
|
||||||
@login_required
|
async def create(
|
||||||
@validate_request("name")
|
request: CreateFileRequest,
|
||||||
def create():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
pf_id = request.json.get("parent_id")
|
"""创建文件/文件夹"""
|
||||||
input_file_type = request.json.get("type")
|
req = request.model_dump(exclude_unset=True)
|
||||||
|
pf_id = req.get("parent_id")
|
||||||
|
input_file_type = req.get("type")
|
||||||
|
|
||||||
if not pf_id:
|
if not pf_id:
|
||||||
root_folder = FileService.get_root_folder(current_user.id)
|
root_folder = FileService.get_root_folder(current_user.id)
|
||||||
pf_id = root_folder["id"]
|
pf_id = root_folder["id"]
|
||||||
@@ -160,17 +177,22 @@ def create():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
@router.get('/list')
|
||||||
@login_required
|
async def list_files(
|
||||||
def list_files():
|
parent_id: Optional[str] = Query(None, description="父文件夹ID"),
|
||||||
pf_id = request.args.get("parent_id")
|
keywords: Optional[str] = Query("", description="搜索关键词"),
|
||||||
|
page: Optional[int] = Query(1, description="页码"),
|
||||||
|
page_size: Optional[int] = Query(15, description="每页数量"),
|
||||||
|
orderby: Optional[str] = Query("create_time", description="排序字段"),
|
||||||
|
desc: Optional[bool] = Query(True, description="是否降序"),
|
||||||
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""列出文件"""
|
||||||
|
pf_id = parent_id
|
||||||
|
|
||||||
keywords = request.args.get("keywords", "")
|
page_number = int(page) if page else 1
|
||||||
|
items_per_page = int(page_size) if page_size else 15
|
||||||
|
|
||||||
page_number = int(request.args.get("page", 1))
|
|
||||||
items_per_page = int(request.args.get("page_size", 15))
|
|
||||||
orderby = request.args.get("orderby", "create_time")
|
|
||||||
desc = request.args.get("desc", True)
|
|
||||||
if not pf_id:
|
if not pf_id:
|
||||||
root_folder = FileService.get_root_folder(current_user.id)
|
root_folder = FileService.get_root_folder(current_user.id)
|
||||||
pf_id = root_folder["id"]
|
pf_id = root_folder["id"]
|
||||||
@@ -192,9 +214,11 @@ def list_files():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/root_folder', methods=['GET']) # noqa: F821
|
@router.get('/root_folder')
|
||||||
@login_required
|
async def get_root_folder(
|
||||||
def get_root_folder():
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取根文件夹"""
|
||||||
try:
|
try:
|
||||||
root_folder = FileService.get_root_folder(current_user.id)
|
root_folder = FileService.get_root_folder(current_user.id)
|
||||||
return get_json_result(data={"root_folder": root_folder})
|
return get_json_result(data={"root_folder": root_folder})
|
||||||
@@ -202,10 +226,12 @@ def get_root_folder():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/parent_folder', methods=['GET']) # noqa: F821
|
@router.get('/parent_folder')
|
||||||
@login_required
|
async def get_parent_folder(
|
||||||
def get_parent_folder():
|
file_id: str = Query(..., description="文件ID"),
|
||||||
file_id = request.args.get("file_id")
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取父文件夹"""
|
||||||
try:
|
try:
|
||||||
e, file = FileService.get_by_id(file_id)
|
e, file = FileService.get_by_id(file_id)
|
||||||
if not e:
|
if not e:
|
||||||
@@ -217,10 +243,12 @@ def get_parent_folder():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/all_parent_folder', methods=['GET']) # noqa: F821
|
@router.get('/all_parent_folder')
|
||||||
@login_required
|
async def get_all_parent_folders(
|
||||||
def get_all_parent_folders():
|
file_id: str = Query(..., description="文件ID"),
|
||||||
file_id = request.args.get("file_id")
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取所有父文件夹"""
|
||||||
try:
|
try:
|
||||||
e, file = FileService.get_by_id(file_id)
|
e, file = FileService.get_by_id(file_id)
|
||||||
if not e:
|
if not e:
|
||||||
@@ -235,12 +263,13 @@ def get_all_parent_folders():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
@router.post("/rm")
|
||||||
@login_required
|
async def rm(
|
||||||
@validate_request("file_ids")
|
request: DeleteFilesRequest,
|
||||||
def rm():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
file_ids = req["file_ids"]
|
"""删除文件"""
|
||||||
|
file_ids = request.file_ids
|
||||||
|
|
||||||
def _delete_single_file(file):
|
def _delete_single_file(file):
|
||||||
try:
|
try:
|
||||||
@@ -296,11 +325,13 @@ def rm():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/rename', methods=['POST']) # noqa: F821
|
@router.post('/rename')
|
||||||
@login_required
|
async def rename(
|
||||||
@validate_request("file_id", "name")
|
request: RenameFileRequest,
|
||||||
def rename():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
|
"""重命名文件"""
|
||||||
|
req = request.model_dump()
|
||||||
try:
|
try:
|
||||||
e, file = FileService.get_by_id(req["file_id"])
|
e, file = FileService.get_by_id(req["file_id"])
|
||||||
if not e:
|
if not e:
|
||||||
@@ -314,8 +345,8 @@ def rename():
|
|||||||
data=False,
|
data=False,
|
||||||
message="The extension of file can't be changed",
|
message="The extension of file can't be changed",
|
||||||
code=settings.RetCode.ARGUMENT_ERROR)
|
code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||||
if file.name == req["name"]:
|
if existing_file.name == req["name"]:
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Duplicated file name in the same folder.")
|
message="Duplicated file name in the same folder.")
|
||||||
|
|
||||||
@@ -336,9 +367,12 @@ def rename():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
|
@router.get('/get/{file_id}')
|
||||||
@login_required
|
async def get(
|
||||||
def get(file_id):
|
file_id: str,
|
||||||
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取文件内容"""
|
||||||
try:
|
try:
|
||||||
e, file = FileService.get_by_id(file_id)
|
e, file = FileService.get_by_id(file_id)
|
||||||
if not e:
|
if not e:
|
||||||
@@ -351,25 +385,28 @@ def get(file_id):
|
|||||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||||
blob = STORAGE_IMPL.get(b, n)
|
blob = STORAGE_IMPL.get(b, n)
|
||||||
|
|
||||||
response = flask.make_response(blob)
|
|
||||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||||
ext = ext.group(1) if ext else None
|
ext = ext.group(1) if ext else None
|
||||||
|
|
||||||
|
content_type = "application/octet-stream"
|
||||||
if ext:
|
if ext:
|
||||||
if file.type == FileType.VISUAL.value:
|
if file.type == FileType.VISUAL.value:
|
||||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||||
else:
|
else:
|
||||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||||
response.headers.set("Content-Type", content_type)
|
|
||||||
return response
|
return Response(content=blob, media_type=content_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/mv", methods=["POST"]) # noqa: F821
|
@router.post("/mv")
|
||||||
@login_required
|
async def move(
|
||||||
@validate_request("src_file_ids", "dest_file_id")
|
request: MoveFilesRequest,
|
||||||
def move():
|
current_user = Depends(get_current_user)
|
||||||
req = request.json
|
):
|
||||||
|
"""移动文件"""
|
||||||
|
req = request.model_dump()
|
||||||
try:
|
try:
|
||||||
file_ids = req["src_file_ids"]
|
file_ids = req["src_file_ids"]
|
||||||
dest_parent_id = req["dest_file_id"]
|
dest_parent_id = req["dest_file_id"]
|
||||||
|
|||||||
@@ -169,14 +169,17 @@ async def update(
|
|||||||
):
|
):
|
||||||
"""更新知识库"""
|
"""更新知识库"""
|
||||||
req = request.model_dump(exclude_unset=True)
|
req = request.model_dump(exclude_unset=True)
|
||||||
if not isinstance(req["name"], str):
|
|
||||||
return get_data_error_result(message="Dataset name must be string.")
|
# 验证 name 字段(如果提供)
|
||||||
if req["name"].strip() == "":
|
if "name" in req:
|
||||||
return get_data_error_result(message="Dataset name can't be empty.")
|
if not isinstance(req["name"], str):
|
||||||
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
|
return get_data_error_result(message="Dataset name must be string.")
|
||||||
return get_data_error_result(
|
if req["name"].strip() == "":
|
||||||
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
|
return get_data_error_result(message="Dataset name can't be empty.")
|
||||||
req["name"] = req["name"].strip()
|
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||||
|
return get_data_error_result(
|
||||||
|
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
|
||||||
|
req["name"] = req["name"].strip()
|
||||||
|
|
||||||
# 验证不允许的参数
|
# 验证不允许的参数
|
||||||
not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"]
|
not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"]
|
||||||
@@ -202,7 +205,8 @@ async def update(
|
|||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Can't find this knowledgebase!")
|
message="Can't find this knowledgebase!")
|
||||||
|
|
||||||
if req["name"].lower() != kb.name.lower() \
|
# 检查名称重复(仅在提供新名称时)
|
||||||
|
if "name" in req and req["name"].lower() != kb.name.lower() \
|
||||||
and len(
|
and len(
|
||||||
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
|
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
|
|||||||
84
api/apps/models/conversation_models.py
Normal file
84
api/apps/models/conversation_models.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SetConversationRequest(BaseModel):
|
||||||
|
"""设置对话请求"""
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
is_new: bool
|
||||||
|
name: Optional[str] = Field(default="New conversation", max_length=255)
|
||||||
|
dialog_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteConversationsRequest(BaseModel):
|
||||||
|
"""删除对话请求"""
|
||||||
|
conversation_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionRequest(BaseModel):
|
||||||
|
"""完成请求(聊天完成)"""
|
||||||
|
conversation_id: str
|
||||||
|
messages: List[Dict[str, Any]]
|
||||||
|
llm_id: Optional[str] = None
|
||||||
|
stream: Optional[bool] = True
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
frequency_penalty: Optional[float] = None
|
||||||
|
presence_penalty: Optional[float] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TTSRequest(BaseModel):
|
||||||
|
"""文本转语音请求"""
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteMessageRequest(BaseModel):
|
||||||
|
"""删除消息请求"""
|
||||||
|
conversation_id: str
|
||||||
|
message_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ThumbupRequest(BaseModel):
|
||||||
|
"""点赞/点踩请求"""
|
||||||
|
conversation_id: str
|
||||||
|
message_id: str
|
||||||
|
thumbup: Optional[bool] = None
|
||||||
|
feedback: Optional[str] = ""
|
||||||
|
|
||||||
|
|
||||||
|
class AskRequest(BaseModel):
|
||||||
|
"""提问请求"""
|
||||||
|
question: str
|
||||||
|
kb_ids: List[str]
|
||||||
|
search_id: Optional[str] = ""
|
||||||
|
|
||||||
|
|
||||||
|
class MindmapRequest(BaseModel):
|
||||||
|
"""思维导图请求"""
|
||||||
|
question: str
|
||||||
|
kb_ids: List[str]
|
||||||
|
search_id: Optional[str] = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedQuestionsRequest(BaseModel):
|
||||||
|
"""相关问题请求"""
|
||||||
|
question: str
|
||||||
|
search_id: Optional[str] = ""
|
||||||
|
|
||||||
43
api/apps/models/file_models.py
Normal file
43
api/apps/models/file_models.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class CreateFileRequest(BaseModel):
|
||||||
|
"""创建文件/文件夹请求"""
|
||||||
|
name: str
|
||||||
|
parent_id: Optional[str] = None
|
||||||
|
type: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteFilesRequest(BaseModel):
|
||||||
|
"""删除文件请求"""
|
||||||
|
file_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class RenameFileRequest(BaseModel):
|
||||||
|
"""重命名文件请求"""
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class MoveFilesRequest(BaseModel):
|
||||||
|
"""移动文件请求"""
|
||||||
|
src_file_ids: List[str]
|
||||||
|
dest_file_id: str
|
||||||
|
|
||||||
@@ -60,9 +60,16 @@ class CreateKnowledgeBaseRequest(BaseModel):
|
|||||||
class UpdateKnowledgeBaseRequest(BaseModel):
|
class UpdateKnowledgeBaseRequest(BaseModel):
|
||||||
"""更新知识库请求"""
|
"""更新知识库请求"""
|
||||||
kb_id: str
|
kb_id: str
|
||||||
name: str
|
name: Optional[str] = None
|
||||||
description: str
|
avatar: Optional[str] = None
|
||||||
parser_id: str
|
language: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
permission: Optional[str] = None
|
||||||
|
doc_num: Optional[int] = None
|
||||||
|
token_num: Optional[int] = None
|
||||||
|
chunk_num: Optional[int] = None
|
||||||
|
parser_id: Optional[str] = None
|
||||||
|
embd_id: Optional[str] = None
|
||||||
pagerank: Optional[int] = None
|
pagerank: Optional[int] = None
|
||||||
# 其他可选字段,但排除 id, tenant_id, created_by, create_time, update_time, create_date, update_date
|
# 其他可选字段,但排除 id, tenant_id, created_by, create_time, update_time, create_date, update_date
|
||||||
|
|
||||||
|
|||||||
53
api/apps/models/search_models.py
Normal file
53
api/apps/models/search_models.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class CreateSearchRequest(BaseModel):
|
||||||
|
"""创建搜索应用请求"""
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = ""
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSearchRequest(BaseModel):
|
||||||
|
"""更新搜索应用请求"""
|
||||||
|
search_id: str
|
||||||
|
name: str
|
||||||
|
search_config: Dict[str, Any]
|
||||||
|
tenant_id: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteSearchRequest(BaseModel):
|
||||||
|
"""删除搜索应用请求"""
|
||||||
|
search_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ListSearchAppsQuery(BaseModel):
|
||||||
|
"""列出搜索应用查询参数"""
|
||||||
|
keywords: Optional[str] = ""
|
||||||
|
page: Optional[int] = 0
|
||||||
|
page_size: Optional[int] = 0
|
||||||
|
orderby: Optional[str] = "create_time"
|
||||||
|
desc: Optional[str] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class ListSearchAppsBody(BaseModel):
|
||||||
|
"""列出搜索应用请求体"""
|
||||||
|
owner_ids: Optional[List[str]] = []
|
||||||
|
|
||||||
@@ -14,8 +14,18 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from flask import request
|
from typing import Optional
|
||||||
from flask_login import current_user, login_required
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
from api.apps.models.auth_dependencies import get_current_user
|
||||||
|
from api.apps.models.search_models import (
|
||||||
|
CreateSearchRequest,
|
||||||
|
UpdateSearchRequest,
|
||||||
|
DeleteSearchRequest,
|
||||||
|
ListSearchAppsQuery,
|
||||||
|
ListSearchAppsBody,
|
||||||
|
)
|
||||||
|
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.constants import DATASET_NAME_LIMIT
|
from api.constants import DATASET_NAME_LIMIT
|
||||||
@@ -25,14 +35,23 @@ from api.db.services import duplicate_name
|
|||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
|
from api.utils.api_utils import (
|
||||||
|
get_data_error_result,
|
||||||
|
get_json_result,
|
||||||
|
server_error_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建路由器
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/create", methods=["post"]) # noqa: F821
|
@router.post('/create')
|
||||||
@login_required
|
async def create(
|
||||||
@validate_request("name")
|
request: CreateSearchRequest,
|
||||||
def create():
|
current_user = Depends(get_current_user)
|
||||||
req = request.get_json()
|
):
|
||||||
|
"""创建搜索应用"""
|
||||||
|
req = request.model_dump(exclude_unset=True)
|
||||||
search_name = req["name"]
|
search_name = req["name"]
|
||||||
description = req.get("description", "")
|
description = req.get("description", "")
|
||||||
if not isinstance(search_name, str):
|
if not isinstance(search_name, str):
|
||||||
@@ -62,12 +81,13 @@ def create():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/update", methods=["post"]) # noqa: F821
|
@router.post('/update')
|
||||||
@login_required
|
async def update(
|
||||||
@validate_request("search_id", "name", "search_config", "tenant_id")
|
request: UpdateSearchRequest,
|
||||||
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
current_user = Depends(get_current_user)
|
||||||
def update():
|
):
|
||||||
req = request.get_json()
|
"""更新搜索应用"""
|
||||||
|
req = request.model_dump(exclude_unset=True)
|
||||||
if not isinstance(req["name"], str):
|
if not isinstance(req["name"], str):
|
||||||
return get_data_error_result(message="Search name must be string.")
|
return get_data_error_result(message="Search name must be string.")
|
||||||
if req["name"].strip() == "":
|
if req["name"].strip() == "":
|
||||||
@@ -84,6 +104,12 @@ def update():
|
|||||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
|
# 验证不允许的参数
|
||||||
|
not_allowed = ["id", "created_by", "create_time", "update_time", "create_date", "update_date"]
|
||||||
|
for key in not_allowed:
|
||||||
|
if key in req:
|
||||||
|
del req[key]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0]
|
search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0]
|
||||||
if not search_app:
|
if not search_app:
|
||||||
@@ -119,10 +145,12 @@ def update():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/detail", methods=["GET"]) # noqa: F821
|
@router.get('/detail')
|
||||||
@login_required
|
async def detail(
|
||||||
def detail():
|
search_id: str = Query(..., description="搜索应用ID"),
|
||||||
search_id = request.args["search_id"]
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取搜索应用详情"""
|
||||||
try:
|
try:
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
@@ -139,20 +167,23 @@ def detail():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
@router.post('/list')
|
||||||
@login_required
|
async def list_search_app(
|
||||||
def list_search_app():
|
query: ListSearchAppsQuery = Depends(),
|
||||||
keywords = request.args.get("keywords", "")
|
body: Optional[ListSearchAppsBody] = None,
|
||||||
page_number = int(request.args.get("page", 0))
|
current_user = Depends(get_current_user)
|
||||||
items_per_page = int(request.args.get("page_size", 0))
|
):
|
||||||
orderby = request.args.get("orderby", "create_time")
|
"""列出搜索应用"""
|
||||||
if request.args.get("desc", "true").lower() == "false":
|
if body is None:
|
||||||
desc = False
|
body = ListSearchAppsBody()
|
||||||
else:
|
|
||||||
desc = True
|
|
||||||
|
|
||||||
req = request.get_json()
|
keywords = query.keywords or ""
|
||||||
owner_ids = req.get("owner_ids", [])
|
page_number = int(query.page or 0)
|
||||||
|
items_per_page = int(query.page_size or 0)
|
||||||
|
orderby = query.orderby or "create_time"
|
||||||
|
desc = query.desc.lower() == "true" if query.desc else True
|
||||||
|
|
||||||
|
owner_ids = body.owner_ids or [] if body else []
|
||||||
try:
|
try:
|
||||||
if not owner_ids:
|
if not owner_ids:
|
||||||
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||||
@@ -171,12 +202,13 @@ def list_search_app():
|
|||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/rm", methods=["post"]) # noqa: F821
|
@router.post('/rm')
|
||||||
@login_required
|
async def rm(
|
||||||
@validate_request("search_id")
|
request: DeleteSearchRequest,
|
||||||
def rm():
|
current_user = Depends(get_current_user)
|
||||||
req = request.get_json()
|
):
|
||||||
search_id = req["search_id"]
|
"""删除搜索应用"""
|
||||||
|
search_id = request.search_id
|
||||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,15 @@
|
|||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
|
import string
|
||||||
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, Query, status
|
||||||
from api.apps.models.auth_dependencies import get_current_user
|
from api.apps.models.auth_dependencies import get_current_user
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
@@ -60,6 +63,19 @@ from api.utils.api_utils import (
|
|||||||
validate_request,
|
validate_request,
|
||||||
)
|
)
|
||||||
from api.utils.crypt import decrypt
|
from api.utils.crypt import decrypt
|
||||||
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
|
from api.apps import smtp_mail_server
|
||||||
|
from api.utils.web_utils import (
|
||||||
|
send_email_html,
|
||||||
|
OTP_LENGTH,
|
||||||
|
OTP_TTL_SECONDS,
|
||||||
|
ATTEMPT_LIMIT,
|
||||||
|
ATTEMPT_LOCK_SECONDS,
|
||||||
|
RESEND_COOLDOWN_SECONDS,
|
||||||
|
otp_keys,
|
||||||
|
hash_code,
|
||||||
|
captcha_key,
|
||||||
|
)
|
||||||
|
|
||||||
# 创建路由器
|
# 创建路由器
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -77,9 +93,8 @@ class RegisterRequest(BaseModel):
|
|||||||
password: str
|
password: str
|
||||||
|
|
||||||
class UserSettingRequest(BaseModel):
|
class UserSettingRequest(BaseModel):
|
||||||
nickname: Optional[str] = None
|
language: Optional[str] = None
|
||||||
password: Optional[str] = None
|
|
||||||
new_password: Optional[str] = None
|
|
||||||
|
|
||||||
class TenantInfoRequest(BaseModel):
|
class TenantInfoRequest(BaseModel):
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
@@ -88,6 +103,16 @@ class TenantInfoRequest(BaseModel):
|
|||||||
img2txt_id: str
|
img2txt_id: str
|
||||||
llm_id: str
|
llm_id: str
|
||||||
|
|
||||||
|
class ForgetOtpRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
captcha: str
|
||||||
|
|
||||||
|
class ForgetPasswordRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
otp: str
|
||||||
|
new_password: str
|
||||||
|
confirm_new_password: str
|
||||||
|
|
||||||
# 依赖项:获取当前用户 - 从 auth_dependencies 导入
|
# 依赖项:获取当前用户 - 从 auth_dependencies 导入
|
||||||
|
|
||||||
@router.post("/login")
|
@router.post("/login")
|
||||||
@@ -481,3 +506,357 @@ async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=str(e)
|
detail=str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.get("/github_callback")
|
||||||
|
async def github_callback(code: Optional[str] = Query(None)):
|
||||||
|
"""
|
||||||
|
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
||||||
|
|
||||||
|
GitHub OAuth callback endpoint.
|
||||||
|
"""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
return RedirectResponse(url="/?error=missing_code")
|
||||||
|
|
||||||
|
res = requests.post(
|
||||||
|
settings.GITHUB_OAUTH.get("url"),
|
||||||
|
data={
|
||||||
|
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||||
|
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
|
||||||
|
"code": code,
|
||||||
|
},
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
res = res.json()
|
||||||
|
if "error" in res:
|
||||||
|
return RedirectResponse(url=f"/?error={res.get('error_description', res.get('error'))}")
|
||||||
|
|
||||||
|
if "user:email" not in res.get("scope", "").split(","):
|
||||||
|
return RedirectResponse(url="/?error=user:email not in scope")
|
||||||
|
|
||||||
|
access_token = res["access_token"]
|
||||||
|
user_info = user_info_from_github(access_token)
|
||||||
|
email_address = user_info["email"]
|
||||||
|
users = UserService.query(email=email_address)
|
||||||
|
user_id = get_uuid()
|
||||||
|
|
||||||
|
if not users:
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
avatar = download_img(user_info["avatar_url"])
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
avatar = ""
|
||||||
|
|
||||||
|
users = user_register(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"access_token": access_token,
|
||||||
|
"email": email_address,
|
||||||
|
"avatar": avatar,
|
||||||
|
"nickname": user_info["login"],
|
||||||
|
"login_channel": "github",
|
||||||
|
"last_login_time": get_format_time(),
|
||||||
|
"is_superuser": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not users:
|
||||||
|
raise Exception(f"Fail to register {email_address}.")
|
||||||
|
if len(users) > 1:
|
||||||
|
raise Exception(f"Same email: {email_address} exists!")
|
||||||
|
|
||||||
|
user = users[0]
|
||||||
|
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||||
|
except Exception as e:
|
||||||
|
rollback_user_registration(user_id)
|
||||||
|
logging.exception(e)
|
||||||
|
return RedirectResponse(url=f"/?error={str(e)}")
|
||||||
|
|
||||||
|
# User has already registered, try to log in
|
||||||
|
user = users[0]
|
||||||
|
user.access_token = get_uuid()
|
||||||
|
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||||
|
return RedirectResponse(url="/?error=user_inactive")
|
||||||
|
user.save()
|
||||||
|
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||||
|
|
||||||
|
@router.get("/feishu_callback")
|
||||||
|
async def feishu_callback(code: Optional[str] = Query(None)):
|
||||||
|
"""
|
||||||
|
Feishu OAuth callback endpoint.
|
||||||
|
"""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
return RedirectResponse(url="/?error=missing_code")
|
||||||
|
|
||||||
|
app_access_token_res = requests.post(
|
||||||
|
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||||
|
data=json.dumps(
|
||||||
|
{
|
||||||
|
"app_id": settings.FEISHU_OAUTH.get("app_id"),
|
||||||
|
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||||
|
)
|
||||||
|
app_access_token_res = app_access_token_res.json()
|
||||||
|
if app_access_token_res.get("code") != 0:
|
||||||
|
return RedirectResponse(url=f"/?error={app_access_token_res}")
|
||||||
|
|
||||||
|
res = requests.post(
|
||||||
|
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||||
|
data=json.dumps(
|
||||||
|
{
|
||||||
|
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
|
||||||
|
"code": code,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
"Authorization": f"Bearer {app_access_token_res['app_access_token']}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
res = res.json()
|
||||||
|
if res.get("code") != 0:
|
||||||
|
return RedirectResponse(url=f"/?error={res.get('message', 'unknown_error')}")
|
||||||
|
|
||||||
|
if "contact:user.email:readonly" not in res.get("data", {}).get("scope", "").split():
|
||||||
|
return RedirectResponse(url="/?error=contact:user.email:readonly not in scope")
|
||||||
|
|
||||||
|
access_token = res["data"]["access_token"]
|
||||||
|
user_info = user_info_from_feishu(access_token)
|
||||||
|
email_address = user_info["email"]
|
||||||
|
users = UserService.query(email=email_address)
|
||||||
|
user_id = get_uuid()
|
||||||
|
|
||||||
|
if not users:
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
avatar = download_img(user_info["avatar_url"])
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
avatar = ""
|
||||||
|
|
||||||
|
users = user_register(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"access_token": access_token,
|
||||||
|
"email": email_address,
|
||||||
|
"avatar": avatar,
|
||||||
|
"nickname": user_info["en_name"],
|
||||||
|
"login_channel": "feishu",
|
||||||
|
"last_login_time": get_format_time(),
|
||||||
|
"is_superuser": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not users:
|
||||||
|
raise Exception(f"Fail to register {email_address}.")
|
||||||
|
if len(users) > 1:
|
||||||
|
raise Exception(f"Same email: {email_address} exists!")
|
||||||
|
|
||||||
|
user = users[0]
|
||||||
|
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||||
|
except Exception as e:
|
||||||
|
rollback_user_registration(user_id)
|
||||||
|
logging.exception(e)
|
||||||
|
return RedirectResponse(url=f"/?error={str(e)}")
|
||||||
|
|
||||||
|
# User has already registered, try to log in
|
||||||
|
user = users[0]
|
||||||
|
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||||
|
return RedirectResponse(url="/?error=user_inactive")
|
||||||
|
user.access_token = get_uuid()
|
||||||
|
user.save()
|
||||||
|
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||||
|
|
||||||
|
def user_info_from_feishu(access_token):
    """Fetch the authenticated user's profile from Feishu.

    Calls the Feishu `authen/v1/user_info` endpoint with the given user
    access token and returns the `data` payload, normalizing an
    empty-string email to ``None`` so callers can treat it as absent.
    """
    import requests

    response = requests.get(
        "https://open.feishu.cn/open-apis/authen/v1/user_info",
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        },
    )
    profile = response.json()["data"]
    # An empty string means "no email"; anything else passes through unchanged.
    profile["email"] = profile["email"] if profile.get("email") != "" else None
    return profile
|
||||||
|
|
||||||
|
def user_info_from_github(access_token):
    """Fetch the authenticated user's profile from GitHub.

    Returns the `/user` payload with an added ``email`` key resolved from
    the `/user/emails` endpoint (the address marked ``primary``), or
    ``None`` when no primary email is available — mirroring the Feishu
    helper, whose callers already handle a ``None`` email.
    """
    import requests

    headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
    res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
    user_info = res.json()
    email_info = requests.get(
        f"https://api.github.com/user/emails?access_token={access_token}",
        headers=headers,
    ).json()
    # Bug fix: the old one-liner subscripted the result of next(..., None),
    # crashing with TypeError when no primary email exists. Also use .get()
    # so a missing "primary" key does not raise KeyError.
    primary = next((entry for entry in email_info if entry.get("primary")), None)
    user_info["email"] = primary["email"] if primary else None
    return user_info
|
||||||
|
|
||||||
|
@router.get("/forget/captcha")
|
||||||
|
async def forget_get_captcha(email: str = Query(...)):
|
||||||
|
"""
|
||||||
|
GET /forget/captcha?email=<email>
|
||||||
|
- Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = 60 seconds.
|
||||||
|
- Returns the captcha as a JPEG image.
|
||||||
|
"""
|
||||||
|
if not email:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required")
|
||||||
|
|
||||||
|
users = UserService.query(email=email)
|
||||||
|
if not users:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
|
||||||
|
|
||||||
|
# Generate captcha text
|
||||||
|
allowed = string.ascii_uppercase + string.digits
|
||||||
|
captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
|
||||||
|
REDIS_CONN.set(captcha_key(email), captcha_text, 60) # Valid for 60 seconds
|
||||||
|
|
||||||
|
from captcha.image import ImageCaptcha
|
||||||
|
image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
|
||||||
|
img_bytes = image.generate(captcha_text).read()
|
||||||
|
|
||||||
|
return Response(content=img_bytes, media_type="image/JPEG")
|
||||||
|
|
||||||
|
@router.post("/forget/otp")
|
||||||
|
async def forget_send_otp(request: ForgetOtpRequest):
|
||||||
|
"""
|
||||||
|
POST /forget/otp
|
||||||
|
- Verify the image captcha stored at captcha:{email} (case-insensitive).
|
||||||
|
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
|
||||||
|
"""
|
||||||
|
email = request.email or ""
|
||||||
|
captcha = (request.captcha or "").strip()
|
||||||
|
|
||||||
|
if not email or not captcha:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required")
|
||||||
|
|
||||||
|
users = UserService.query(email=email)
|
||||||
|
if not users:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
|
||||||
|
|
||||||
|
stored_captcha = REDIS_CONN.get(captcha_key(email))
|
||||||
|
if not stored_captcha:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
|
||||||
|
if (stored_captcha or "").strip().lower() != captcha.lower():
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")
|
||||||
|
|
||||||
|
# Delete captcha to prevent reuse
|
||||||
|
REDIS_CONN.delete(captcha_key(email))
|
||||||
|
|
||||||
|
k_code, k_attempts, k_last, k_lock = otp_keys(email)
|
||||||
|
now = int(time.time())
|
||||||
|
last_ts = REDIS_CONN.get(k_last)
|
||||||
|
if last_ts:
|
||||||
|
try:
|
||||||
|
elapsed = now - int(last_ts)
|
||||||
|
except Exception:
|
||||||
|
elapsed = RESEND_COOLDOWN_SECONDS
|
||||||
|
remaining = RESEND_COOLDOWN_SECONDS - elapsed
|
||||||
|
if remaining > 0:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds")
|
||||||
|
|
||||||
|
# Generate OTP (uppercase letters only) and store hashed
|
||||||
|
otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
|
||||||
|
salt = os.urandom(16)
|
||||||
|
code_hash = hash_code(otp, salt)
|
||||||
|
REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS)
|
||||||
|
REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS)
|
||||||
|
REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS)
|
||||||
|
REDIS_CONN.delete(k_lock)
|
||||||
|
|
||||||
|
ttl_min = OTP_TTL_SECONDS // 60
|
||||||
|
|
||||||
|
if not smtp_mail_server:
|
||||||
|
logging.warning("SMTP mail server not initialized; skip sending email.")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
send_email_html(
|
||||||
|
subject="Your Password Reset Code",
|
||||||
|
to_email=email,
|
||||||
|
template_key="reset_code",
|
||||||
|
code=otp,
|
||||||
|
ttl_min=ttl_min,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email")
|
||||||
|
|
||||||
|
return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent")
|
||||||
|
|
||||||
|
@router.post("/forget")
|
||||||
|
async def forget(request: ForgetPasswordRequest):
|
||||||
|
"""
|
||||||
|
POST: Verify email + OTP and reset password, then log the user in.
|
||||||
|
Request JSON: { email, otp, new_password, confirm_new_password }
|
||||||
|
"""
|
||||||
|
email = request.email or ""
|
||||||
|
otp = (request.otp or "").strip()
|
||||||
|
new_pwd = request.new_password
|
||||||
|
new_pwd2 = request.confirm_new_password
|
||||||
|
|
||||||
|
if not all([email, otp, new_pwd, new_pwd2]):
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")
|
||||||
|
|
||||||
|
# For reset, passwords are provided as-is (no decrypt needed)
|
||||||
|
if new_pwd != new_pwd2:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match")
|
||||||
|
|
||||||
|
users = UserService.query(email=email)
|
||||||
|
if not users:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")
|
||||||
|
|
||||||
|
user = users[0]
|
||||||
|
# Verify OTP from Redis
|
||||||
|
k_code, k_attempts, k_last, k_lock = otp_keys(email)
|
||||||
|
if REDIS_CONN.get(k_lock):
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later")
|
||||||
|
|
||||||
|
stored = REDIS_CONN.get(k_code)
|
||||||
|
if not stored:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp")
|
||||||
|
|
||||||
|
try:
|
||||||
|
stored_hash, salt_hex = str(stored).split(":", 1)
|
||||||
|
salt = bytes.fromhex(salt_hex)
|
||||||
|
except Exception:
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted")
|
||||||
|
|
||||||
|
# Case-insensitive verification: OTP generated uppercase
|
||||||
|
calc = hash_code(otp.upper(), salt)
|
||||||
|
if calc != stored_hash:
|
||||||
|
# bump attempts
|
||||||
|
try:
|
||||||
|
attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
|
||||||
|
except Exception:
|
||||||
|
attempts = 1
|
||||||
|
REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
|
||||||
|
if attempts >= ATTEMPT_LIMIT:
|
||||||
|
REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp")
|
||||||
|
|
||||||
|
# Success: consume OTP and reset password
|
||||||
|
REDIS_CONN.delete(k_code)
|
||||||
|
REDIS_CONN.delete(k_attempts)
|
||||||
|
REDIS_CONN.delete(k_last)
|
||||||
|
REDIS_CONN.delete(k_lock)
|
||||||
|
|
||||||
|
try:
|
||||||
|
UserService.update_user_password(user.id, new_pwd)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password")
|
||||||
|
|
||||||
|
# Auto login (reuse login flow)
|
||||||
|
user.access_token = get_uuid()
|
||||||
|
user.update_time = (current_timestamp(),)
|
||||||
|
user.update_date = (datetime_format(datetime.now()),)
|
||||||
|
user.save()
|
||||||
|
msg = "Password reset successful. Logged in."
|
||||||
|
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ minio:
|
|||||||
user: 'rag_flow'
|
user: 'rag_flow'
|
||||||
password: 'infini_rag_flow'
|
password: 'infini_rag_flow'
|
||||||
host: 'localhost:9000'
|
host: 'localhost:9000'
|
||||||
|
es:
|
||||||
|
hosts: 'http://localhost:1200'
|
||||||
|
username: 'elastic'
|
||||||
|
password: 'infini_rag_flow'
|
||||||
os:
|
os:
|
||||||
hosts: 'http://localhost:1201'
|
hosts: 'http://localhost:1201'
|
||||||
username: 'admin'
|
username: 'admin'
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
# - `elasticsearch` (default)
|
# - `elasticsearch` (default)
|
||||||
# - `infinity` (https://github.com/infiniflow/infinity)
|
# - `infinity` (https://github.com/infiniflow/infinity)
|
||||||
# - `opensearch` (https://github.com/opensearch-project/OpenSearch)
|
# - `opensearch` (https://github.com/opensearch-project/OpenSearch)
|
||||||
DOC_ENGINE=opensearch
|
DOC_ENGINE=elasticsearch
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# docker env var for specifying vector db type at startup
|
# docker env var for specifying vector db type at startup
|
||||||
@@ -98,7 +98,7 @@ ADMIN_SVR_HTTP_PORT=9381
|
|||||||
|
|
||||||
# The RAGFlow Docker image to download.
|
# The RAGFlow Docker image to download.
|
||||||
# Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models.
|
# Defaults to the v0.21.1-slim edition, which is the RAGFlow Docker image without embedding models.
|
||||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi
|
RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1-fastapi-web
|
||||||
#
|
#
|
||||||
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
|
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
|
||||||
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1
|
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.21.1
|
||||||
|
|||||||
@@ -1,36 +1,34 @@
|
|||||||
services:
|
services:
|
||||||
|
|
||||||
opensearch01:
|
es01:
|
||||||
container_name: ragflow-opensearch-01
|
container_name: ragflow-es-01
|
||||||
profiles:
|
profiles:
|
||||||
- opensearch
|
- elasticsearch
|
||||||
image: hub.icert.top/opensearchproject/opensearch:2.19.1
|
image: elasticsearch:${STACK_VERSION}
|
||||||
volumes:
|
volumes:
|
||||||
- osdata01:/usr/share/opensearch/data
|
- esdata01:/usr/share/elasticsearch/data
|
||||||
ports:
|
ports:
|
||||||
- ${OS_PORT}:9201
|
- ${ES_PORT}:9200
|
||||||
env_file: .env
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- node.name=opensearch01
|
- node.name=es01
|
||||||
- OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD}
|
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
||||||
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_PASSWORD}
|
|
||||||
- bootstrap.memory_lock=false
|
- bootstrap.memory_lock=false
|
||||||
- discovery.type=single-node
|
- discovery.type=single-node
|
||||||
- plugins.security.disabled=false
|
- xpack.security.enabled=true
|
||||||
- plugins.security.ssl.http.enabled=false
|
- xpack.security.http.ssl.enabled=false
|
||||||
- plugins.security.ssl.transport.enabled=true
|
- xpack.security.transport.ssl.enabled=false
|
||||||
- cluster.routing.allocation.disk.watermark.low=5gb
|
- cluster.routing.allocation.disk.watermark.low=5gb
|
||||||
- cluster.routing.allocation.disk.watermark.high=3gb
|
- cluster.routing.allocation.disk.watermark.high=3gb
|
||||||
- cluster.routing.allocation.disk.watermark.flood_stage=2gb
|
- cluster.routing.allocation.disk.watermark.flood_stage=2gb
|
||||||
- TZ=${TIMEZONE}
|
- TZ=${TIMEZONE}
|
||||||
- http.port=9201
|
|
||||||
mem_limit: ${MEM_LIMIT}
|
mem_limit: ${MEM_LIMIT}
|
||||||
ulimits:
|
ulimits:
|
||||||
memlock:
|
memlock:
|
||||||
soft: -1
|
soft: -1
|
||||||
hard: -1
|
hard: -1
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "curl http://localhost:9201"]
|
test: ["CMD-SHELL", "curl http://localhost:9200"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 120
|
retries: 120
|
||||||
|
|||||||
@@ -64,7 +64,8 @@ class Dealer:
|
|||||||
if key in req and req[key] is not None:
|
if key in req and req[key] is not None:
|
||||||
condition[field] = req[key]
|
condition[field] = req[key]
|
||||||
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
||||||
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
|
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd",
|
||||||
|
"removed_kwd"]:
|
||||||
if key in req and req[key] is not None:
|
if key in req and req[key] is not None:
|
||||||
condition[key] = req[key]
|
condition[key] = req[key]
|
||||||
return condition
|
return condition
|
||||||
@@ -135,7 +136,8 @@ class Dealer:
|
|||||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||||
matchDense.extra_options["similarity"] = 0.17
|
matchDense.extra_options["similarity"] = 0.17
|
||||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||||
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
orderBy, offset, limit, idx_names, kb_ids,
|
||||||
|
rank_feature=rank_feature)
|
||||||
total = self.dataStore.getTotal(res)
|
total = self.dataStore.getTotal(res)
|
||||||
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
||||||
|
|
||||||
@@ -212,8 +214,9 @@ class Dealer:
|
|||||||
ans_v, _ = embd_mdl.encode(pieces_)
|
ans_v, _ = embd_mdl.encode(pieces_)
|
||||||
for i in range(len(chunk_v)):
|
for i in range(len(chunk_v)):
|
||||||
if len(ans_v[0]) != len(chunk_v[i]):
|
if len(ans_v[0]) != len(chunk_v[i]):
|
||||||
chunk_v[i] = [0.0]*len(ans_v[0])
|
chunk_v[i] = [0.0] * len(ans_v[0])
|
||||||
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
logging.warning(
|
||||||
|
"The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
|
||||||
|
|
||||||
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
||||||
len(ans_v[0]), len(chunk_v[0]))
|
len(ans_v[0]), len(chunk_v[0]))
|
||||||
@@ -267,7 +270,7 @@ class Dealer:
|
|||||||
if not query_rfea:
|
if not query_rfea:
|
||||||
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
||||||
|
|
||||||
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
|
q_denor = np.sqrt(np.sum([s * s for t, s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||||
for i in search_res.ids:
|
for i in search_res.ids:
|
||||||
nor, denor = 0, 0
|
nor, denor = 0, 0
|
||||||
if not search_res.field[i].get(TAG_FLD):
|
if not search_res.field[i].get(TAG_FLD):
|
||||||
@@ -280,8 +283,8 @@ class Dealer:
|
|||||||
if denor == 0:
|
if denor == 0:
|
||||||
rank_fea.append(0)
|
rank_fea.append(0)
|
||||||
else:
|
else:
|
||||||
rank_fea.append(nor/np.sqrt(denor)/q_denor)
|
rank_fea.append(nor / np.sqrt(denor) / q_denor)
|
||||||
return np.array(rank_fea)*10. + pageranks
|
return np.array(rank_fea) * 10. + pageranks
|
||||||
|
|
||||||
def rerank(self, sres, query, tkweight=0.3,
|
def rerank(self, sres, query, tkweight=0.3,
|
||||||
vtweight=0.7, cfield="content_ltks",
|
vtweight=0.7, cfield="content_ltks",
|
||||||
@@ -343,7 +346,7 @@ class Dealer:
|
|||||||
## For rank feature(tag_fea) scores.
|
## For rank feature(tag_fea) scores.
|
||||||
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
||||||
|
|
||||||
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
|
return tkweight * (np.array(tksim) + rank_fea) + vtweight * vtsim, tksim, vtsim
|
||||||
|
|
||||||
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
||||||
return self.qryr.hybrid_similarity(ans_embd,
|
return self.qryr.hybrid_similarity(ans_embd,
|
||||||
@@ -360,13 +363,13 @@ class Dealer:
|
|||||||
return ranks
|
return ranks
|
||||||
|
|
||||||
# Ensure RERANK_LIMIT is multiple of page_size
|
# Ensure RERANK_LIMIT is multiple of page_size
|
||||||
RERANK_LIMIT = math.ceil(64/page_size) * page_size if page_size>1 else 1
|
RERANK_LIMIT = math.ceil(64 / page_size) * page_size if page_size > 1 else 1
|
||||||
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
|
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size * page / RERANK_LIMIT),
|
||||||
|
"size": RERANK_LIMIT,
|
||||||
"question": question, "vector": True, "topk": top,
|
"question": question, "vector": True, "topk": top,
|
||||||
"similarity": similarity_threshold,
|
"similarity": similarity_threshold,
|
||||||
"available_int": 1}
|
"available_int": 1}
|
||||||
|
|
||||||
|
|
||||||
if isinstance(tenant_ids, str):
|
if isinstance(tenant_ids, str):
|
||||||
tenant_ids = tenant_ids.split(",")
|
tenant_ids = tenant_ids.split(",")
|
||||||
|
|
||||||
@@ -392,15 +395,15 @@ class Dealer:
|
|||||||
tsim = sim
|
tsim = sim
|
||||||
vsim = sim
|
vsim = sim
|
||||||
# Already paginated in search function
|
# Already paginated in search function
|
||||||
begin = ((page % (RERANK_LIMIT//page_size)) - 1) * page_size
|
begin = ((page % (RERANK_LIMIT // page_size)) - 1) * page_size
|
||||||
sim = sim[begin : begin + page_size]
|
sim = sim[begin: begin + page_size]
|
||||||
sim_np = np.array(sim)
|
sim_np = np.array(sim)
|
||||||
idx = np.argsort(sim_np * -1)
|
idx = np.argsort(sim_np * -1)
|
||||||
dim = len(sres.query_vector)
|
dim = len(sres.query_vector)
|
||||||
vector_column = f"q_{dim}_vec"
|
vector_column = f"q_{dim}_vec"
|
||||||
zero_vector = [0.0] * dim
|
zero_vector = [0.0] * dim
|
||||||
filtered_count = (sim_np >= similarity_threshold).sum()
|
filtered_count = (sim_np >= similarity_threshold).sum()
|
||||||
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
||||||
for i in idx:
|
for i in idx:
|
||||||
if sim[i] < similarity_threshold:
|
if sim[i] < similarity_threshold:
|
||||||
break
|
break
|
||||||
@@ -447,8 +450,8 @@ class Dealer:
|
|||||||
ranks["doc_aggs"] = [{"doc_name": k,
|
ranks["doc_aggs"] = [{"doc_name": k,
|
||||||
"doc_id": v["doc_id"],
|
"doc_id": v["doc_id"],
|
||||||
"count": v["count"]} for k,
|
"count": v["count"]} for k,
|
||||||
v in sorted(ranks["doc_aggs"].items(),
|
v in sorted(ranks["doc_aggs"].items(),
|
||||||
key=lambda x: x[1]["count"] * -1)]
|
key=lambda x: x[1]["count"] * -1)]
|
||||||
ranks["chunks"] = ranks["chunks"][:page_size]
|
ranks["chunks"] = ranks["chunks"][:page_size]
|
||||||
|
|
||||||
return ranks
|
return ranks
|
||||||
@@ -505,13 +508,14 @@ class Dealer:
|
|||||||
|
|
||||||
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
|
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
|
||||||
idx_nm = index_name(tenant_id)
|
idx_nm = index_name(tenant_id)
|
||||||
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
|
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []),
|
||||||
|
keywords_topn)
|
||||||
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
|
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
|
||||||
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
||||||
if not aggs:
|
if not aggs:
|
||||||
return False
|
return False
|
||||||
cnt = np.sum([c for _, c in aggs])
|
cnt = np.sum([c for _, c in aggs])
|
||||||
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||||
key=lambda x: x[1] * -1)[:topn_tags]
|
key=lambda x: x[1] * -1)[:topn_tags]
|
||||||
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
|
doc[TAG_FLD] = {a.replace(".", "_"): c for a, c in tag_fea if c > 0}
|
||||||
return True
|
return True
|
||||||
@@ -527,11 +531,11 @@ class Dealer:
|
|||||||
if not aggs:
|
if not aggs:
|
||||||
return {}
|
return {}
|
||||||
cnt = np.sum([c for _, c in aggs])
|
cnt = np.sum([c for _, c in aggs])
|
||||||
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / max(1e-6, all_tags.get(a, 0.0001)))) for a, c in aggs],
|
||||||
key=lambda x: x[1] * -1)[:topn_tags]
|
key=lambda x: x[1] * -1)[:topn_tags]
|
||||||
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
|
return {a.replace(".", "_"): max(1, c) for a, c in tag_fea}
|
||||||
|
|
||||||
def retrieval_by_toc(self, query:str, chunks:list[dict], tenant_ids:list[str], chat_mdl, topn: int=6):
|
def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6):
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return []
|
return []
|
||||||
idx_nms = [index_name(tid) for tid in tenant_ids]
|
idx_nms = [index_name(tid) for tid in tenant_ids]
|
||||||
@@ -541,9 +545,10 @@ class Dealer:
|
|||||||
ranks[ck["doc_id"]] = 0
|
ranks[ck["doc_id"]] = 0
|
||||||
ranks[ck["doc_id"]] += ck["similarity"]
|
ranks[ck["doc_id"]] += ck["similarity"]
|
||||||
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
|
doc_id2kb_id[ck["doc_id"]] = ck["kb_id"]
|
||||||
doc_id = sorted(ranks.items(), key=lambda x: x[1]*-1.)[0][0]
|
doc_id = sorted(ranks.items(), key=lambda x: x[1] * -1.)[0][0]
|
||||||
kb_ids = [doc_id2kb_id[doc_id]]
|
kb_ids = [doc_id2kb_id[doc_id]]
|
||||||
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
|
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [],
|
||||||
|
OrderByExpr(), 0, 128, idx_nms,
|
||||||
kb_ids)
|
kb_ids)
|
||||||
toc = []
|
toc = []
|
||||||
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
|
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
|
||||||
@@ -555,7 +560,7 @@ class Dealer:
|
|||||||
if not toc:
|
if not toc:
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
|
ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2)
|
||||||
if not ids:
|
if not ids:
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
@@ -589,4 +594,4 @@ class Dealer:
|
|||||||
break
|
break
|
||||||
chunks.append(d)
|
chunks.append(d)
|
||||||
|
|
||||||
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
|
return sorted(chunks, key=lambda x: x["similarity"] * -1)[:topn]
|
||||||
|
|||||||
35
rag/res/deepdoc/.gitattributes
vendored
Normal file
35
rag/res/deepdoc/.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
BIN
rag/res/deepdoc/layout.laws.onnx
Normal file
BIN
rag/res/deepdoc/layout.laws.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/layout.manual.onnx
Normal file
BIN
rag/res/deepdoc/layout.manual.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/layout.onnx
Normal file
BIN
rag/res/deepdoc/layout.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/layout.paper.onnx
Normal file
BIN
rag/res/deepdoc/layout.paper.onnx
Normal file
Binary file not shown.
BIN
rag/res/deepdoc/tsr.onnx
Normal file
BIN
rag/res/deepdoc/tsr.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user