v0.21.1-fastapi
This commit is contained in:
@@ -52,6 +52,40 @@ def create_app() -> FastAPI:
|
||||
openapi_url="/apispec.json"
|
||||
)
|
||||
|
||||
# 自定义 OpenAPI schema 以支持 Bearer Token 认证
|
||||
def custom_openapi():
    """Return the application's OpenAPI schema, building and caching it on first use.

    The generated schema is adjusted so that its ``securitySchemes`` section
    has no ``CustomHTTPBearer`` entry and carries an ``HTTPBearer`` entry
    describing an HTTP bearer (JWT) scheme. The result is stored on
    ``app.openapi_schema`` and reused by later calls.
    """
    cached_schema = app.openapi_schema
    if cached_schema:
        return cached_schema

    from fastapi.openapi.utils import get_openapi

    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Create the "components" / "securitySchemes" containers when the
    # generated schema does not already have them.
    security_schemes = schema.setdefault("components", {}).setdefault("securitySchemes", {})

    # Drop CustomHTTPBearer (if present) so that only HTTPBearer remains.
    security_schemes.pop("CustomHTTPBearer", None)

    # Add or overwrite the HTTPBearer security scheme.
    security_schemes["HTTPBearer"] = {
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
        "description": "输入 token,可以不带 'Bearer ' 前缀,系统会自动添加"
    }

    app.openapi_schema = schema
    return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -124,22 +158,20 @@ def setup_routes(app: FastAPI):
|
||||
from api.apps.user_app_fastapi import router as user_router
|
||||
from api.apps.kb_app import router as kb_router
|
||||
from api.apps.document_app import router as document_router
|
||||
from api.apps.file_app import router as file_router
|
||||
from api.apps.file2document_app import router as file2document_router
|
||||
from api.apps.mcp_server_app import router as mcp_router
|
||||
from api.apps.tenant_app import router as tenant_router
|
||||
from api.apps.llm_app import router as llm_router
|
||||
from api.apps.chunk_app import router as chunk_router
|
||||
from api.apps.mcp_server_app import router as mcp_router
|
||||
from api.apps.canvas_app import router as canvas_router
|
||||
|
||||
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
|
||||
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
|
||||
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KnowledgeBase"])
|
||||
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
|
||||
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
|
||||
app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
|
||||
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
|
||||
app.include_router(tenant_router, prefix=f"/{API_VERSION}/tenant", tags=["Tenant"])
|
||||
app.include_router(llm_router, prefix=f"/{API_VERSION}/llm", tags=["LLM"])
|
||||
app.include_router(chunk_router, prefix=f"/{API_VERSION}/chunk", tags=["Chunk"])
|
||||
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
|
||||
app.include_router(canvas_router, prefix=f"/{API_VERSION}/canvas", tags=["Canvas"])
|
||||
|
||||
|
||||
|
||||
def get_current_user_from_token(authorization: str):
|
||||
"""从token获取当前用户"""
|
||||
|
||||
@@ -490,7 +490,7 @@ def upload():
|
||||
|
||||
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id")
|
||||
async def upload_parse():
|
||||
def upload_parse():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
@@ -507,7 +507,7 @@ async def upload_parse():
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = await doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@@ -536,7 +536,7 @@ def list_chunks():
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
@@ -884,7 +884,7 @@ def retrieval():
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
|
||||
rank_feature=label_question(question, kbs))
|
||||
|
||||
@@ -18,11 +18,12 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
from fastapi import APIRouter, Depends, Query, Header, UploadFile, File, Form
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from agent.component import LLM
|
||||
from api import settings
|
||||
@@ -36,7 +37,24 @@ from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.settings import RetCode
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from api.utils.api_utils import get_json_result, server_error_response, get_data_error_result
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.canvas_models import (
|
||||
DeleteCanvasRequest,
|
||||
SaveCanvasRequest,
|
||||
CompletionRequest,
|
||||
RerunRequest,
|
||||
ResetCanvasRequest,
|
||||
InputFormQuery,
|
||||
DebugRequest,
|
||||
TestDBConnectRequest,
|
||||
ListCanvasQuery,
|
||||
SettingRequest,
|
||||
TraceQuery,
|
||||
ListSessionsQuery,
|
||||
DownloadQuery,
|
||||
UploadQuery,
|
||||
)
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
@@ -47,18 +65,25 @@ from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def templates():
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
@router.get('/templates')
async def templates(
    current_user=Depends(get_current_user)
):
    """Return all canvas templates as a JSON result.

    ``current_user`` is declared through the ``get_current_user`` dependency;
    its value is not read in the body.
    """
    template_dicts = []
    for template in CanvasTemplateService.get_all():
        template_dicts.append(template.to_dict())
    return get_json_result(data=template_dicts)
|
||||
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: DeleteCanvasRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除画布"""
|
||||
for i in request.canvas_ids:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@@ -67,16 +92,18 @@ def rm():
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
@router.post('/set')
|
||||
async def save(
|
||||
request: SaveCanvasRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""保存画布"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
cate = req.get("canvas_category", CanvasCategory.Agent)
|
||||
if "id" not in req:
|
||||
if "id" not in req or not req.get("id"):
|
||||
req["user_id"] = current_user.id
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
|
||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||
@@ -95,21 +122,28 @@ def save():
|
||||
return get_json_result(data=req)
|
||||
|
||||
|
||||
@manager.route('/get/<canvas_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(canvas_id):
|
||||
@router.get('/get/{canvas_id}')
async def get(
    canvas_id: str,
    current_user=Depends(get_current_user)
):
    """Return the details of a single canvas.

    Args:
        canvas_id: Canvas identifier taken from the URL path.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        A "canvas not found." error result when the canvas is not accessible
        to the user or the lookup reports failure; otherwise a JSON result
        wrapping the canvas data.
    """
    if not UserCanvasService.accessible(canvas_id, current_user.id):
        return get_data_error_result(message="canvas not found.")

    found, canvas = UserCanvasService.get_by_canvas_id(canvas_id)
    # Honour the lookup flag rather than returning an empty payload as a
    # success; the upload handler in this module checks the same flag.
    if not found:
        return get_data_error_result(message="canvas not found.")
    return get_json_result(data=canvas)
|
||||
|
||||
|
||||
@manager.route('/getsse/<canvas_id>', methods=['GET']) # type: ignore # noqa: F821
|
||||
def getsse(canvas_id):
|
||||
token = request.headers.get('Authorization').split()
|
||||
if len(token) != 2:
|
||||
@router.get('/getsse/{canvas_id}')
|
||||
async def getsse(
|
||||
canvas_id: str,
|
||||
authorization: str = Header(..., description="Authorization header")
|
||||
):
|
||||
"""获取画布详情(SSE,通过API token认证)"""
|
||||
token_parts = authorization.split()
|
||||
if len(token_parts) != 2:
|
||||
return get_data_error_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
token = token_parts[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
if not objs:
|
||||
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
||||
@@ -126,21 +160,23 @@ def getsse(canvas_id):
|
||||
return get_json_result(data=c.to_dict())
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@router.post('/completion')
|
||||
async def run(
|
||||
request: CompletionRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""运行画布(完成/执行)"""
|
||||
query = request.query or ""
|
||||
files = request.files or []
|
||||
inputs = request.inputs or {}
|
||||
user_id = request.user_id or current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(request.id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
e, cvs = UserCanvasService.get_by_id(request.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@@ -149,43 +185,48 @@ def run():
|
||||
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=request.id)
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=request.id, task_id=task_id, file=files[0] if files else None, priority=0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
|
||||
canvas = Canvas(cvs.dsl, current_user.id, request.id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
def sse():
|
||||
async def sse():
|
||||
nonlocal canvas, user_id
|
||||
try:
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
UserCanvasService.update_by_id(request.id, cvs.to_dict())
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
return StreamingResponse(
|
||||
sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Type": "text/event-stream; charset=utf-8"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
@router.post('/rerun')
|
||||
async def rerun(
|
||||
request: RerunRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""重新运行流水线"""
|
||||
doc = PipelineOperationLogService.get_documents_info(request.id)
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
doc = doc[0]
|
||||
@@ -198,19 +239,22 @@ def rerun():
|
||||
doc["chunk_num"] = 0
|
||||
doc["token_num"] = 0
|
||||
DocumentService.clear_chunk_num_when_rerun(doc["id"])
|
||||
DocumentService.update_by_id(id, doc)
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
DocumentService.update_by_id(request.id, doc)
|
||||
TaskService.filter_delete([Task.doc_id == request.id])
|
||||
|
||||
dsl = req["dsl"]
|
||||
dsl["path"] = [req["component_id"]]
|
||||
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
dsl = request.dsl
|
||||
dsl["path"] = [request.component_id]
|
||||
PipelineOperationLogService.update_by_id(request.id, {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=request.id, task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
def cancel(task_id):
|
||||
@router.put('/cancel/{task_id}')
|
||||
async def cancel(
|
||||
task_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""取消任务"""
|
||||
try:
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
except Exception as e:
|
||||
@@ -218,36 +262,43 @@ def cancel(task_id):
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@router.post('/reset')
|
||||
async def reset(
|
||||
request: ResetCanvasRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""重置画布"""
|
||||
if not UserCanvasService.accessible(request.id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
e, user_canvas = UserCanvasService.get_by_id(request.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
canvas.reset()
|
||||
req["dsl"] = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
|
||||
return get_json_result(data=req["dsl"])
|
||||
dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(request.id, {"dsl": dsl})
|
||||
return get_json_result(data=dsl)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
@router.post("/upload/{canvas_id}")
|
||||
async def upload(
|
||||
canvas_id: str,
|
||||
url: Optional[str] = Query(None, description="URL(可选,用于从URL下载)"),
|
||||
file: Optional[UploadFile] = File(None, description="上传的文件")
|
||||
):
|
||||
"""上传文件到画布"""
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
@@ -267,7 +318,7 @@ def upload(canvas_id):
|
||||
"preview_url": None
|
||||
}
|
||||
|
||||
if request.args.get("url"):
|
||||
if url:
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
@@ -277,7 +328,6 @@ def upload(canvas_id):
|
||||
CrawlResult
|
||||
)
|
||||
try:
|
||||
url = request.args.get("url")
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
@@ -301,61 +351,67 @@ def upload(canvas_id):
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers.get("content-type", "application/pdf")))
|
||||
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers.get("content-type", "text/html"), user_id))
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files['file']
|
||||
if not file:
|
||||
return get_data_error_result(message="No file provided.")
|
||||
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
blob = await file.read()
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), blob, file.content_type or "application/octet-stream"))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/input_form', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def input_form():
|
||||
cvs_id = request.args.get("id")
|
||||
cpn_id = request.args.get("component_id")
|
||||
@router.get('/input_form')
|
||||
async def input_form(
|
||||
id: str = Query(..., description="画布ID"),
|
||||
component_id: str = Query(..., description="组件ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取组件输入表单"""
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(cvs_id)
|
||||
e, user_canvas = UserCanvasService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
if not UserCanvasService.query(user_id=current_user.id, id=cvs_id):
|
||||
if not UserCanvasService.query(user_id=current_user.id, id=id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
return get_json_result(data=canvas.get_component_input_form(cpn_id))
|
||||
return get_json_result(data=canvas.get_component_input_form(component_id))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/debug', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@router.post('/debug')
|
||||
async def debug(
|
||||
request: DebugRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""调试组件"""
|
||||
if not UserCanvasService.accessible(request.id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
e, user_canvas = UserCanvasService.get_by_id(request.id)
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
canvas.reset()
|
||||
canvas.message_id = get_uuid()
|
||||
component = canvas.get_component(req["component_id"])["obj"]
|
||||
component = canvas.get_component(request.component_id)["obj"]
|
||||
component.reset()
|
||||
|
||||
if isinstance(component, LLM):
|
||||
component.set_debug_inputs(req["params"])
|
||||
component.invoke(**{k: o["value"] for k,o in req["params"].items()})
|
||||
component.set_debug_inputs(request.params)
|
||||
component.invoke(**{k: o["value"] for k,o in request.params.items()})
|
||||
outputs = component.output()
|
||||
for k in outputs.keys():
|
||||
if isinstance(outputs[k], partial):
|
||||
@@ -368,11 +424,13 @@ def debug():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
def test_db_connect():
|
||||
req = request.json
|
||||
@router.post('/test_db_connect')
|
||||
async def test_db_connect(
|
||||
request: TestDBConnectRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""测试数据库连接"""
|
||||
req = request.model_dump()
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
@@ -409,6 +467,49 @@ def test_db_connect():
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
elif req["db_type"] == 'trino':
|
||||
def _parse_catalog_schema(db: str):
|
||||
if not db:
|
||||
return None, None
|
||||
if "." in db:
|
||||
c, s = db.split(".", 1)
|
||||
elif "/" in db:
|
||||
c, s = db.split("/", 1)
|
||||
else:
|
||||
c, s = db, "default"
|
||||
return c, s
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception:
|
||||
return server_error_response("Missing dependency 'trino'. Please install: pip install trino")
|
||||
|
||||
catalog, schema = _parse_catalog_schema(req["database"])
|
||||
if not catalog:
|
||||
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
|
||||
|
||||
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||
|
||||
auth = None
|
||||
if http_scheme == "https" and req.get("password"):
|
||||
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
port=int(req["port"] or 8080),
|
||||
user=req["username"] or "ragflow",
|
||||
catalog=catalog,
|
||||
schema=schema or "default",
|
||||
http_scheme=http_scheme,
|
||||
auth=auth
|
||||
)
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT 1")
|
||||
cur.fetchall()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
else:
|
||||
return server_error_response("Unsupported database type.")
|
||||
if req["db_type"] != 'mssql':
|
||||
@@ -421,42 +522,48 @@ def test_db_connect():
|
||||
|
||||
|
||||
#api get list version dsl of canvas
|
||||
@manager.route('/getlistversion/<canvas_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
@router.get('/getlistversion/{canvas_id}')
|
||||
async def getlistversion(
|
||||
canvas_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取画布版本列表"""
|
||||
try:
|
||||
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=list)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
|
||||
#api get version dsl of canvas
|
||||
@manager.route('/getversion/<version_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def getversion( version_id):
|
||||
@router.get('/getversion/{version_id}')
|
||||
async def getversion(
|
||||
version_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取画布版本详情"""
|
||||
try:
|
||||
|
||||
e, version = UserCanvasVersionService.get_by_id(version_id)
|
||||
if version:
|
||||
return get_json_result(data=version.to_dict())
|
||||
return get_data_error_result(message="Version not found.")
|
||||
except Exception as e:
|
||||
return get_json_result(data=f"Error getting history file: {e}")
|
||||
return get_data_error_result(message=f"Error getting history file: {e}")
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_canvas():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
canvas_category = request.args.get("canvas_category")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
|
||||
@router.get('/list')
|
||||
async def list_canvas(
|
||||
query: ListCanvasQuery = Depends(),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出画布"""
|
||||
keywords = query.keywords or ""
|
||||
page_number = int(query.page or 0)
|
||||
items_per_page = int(query.page_size or 0)
|
||||
orderby = query.orderby or "create_time"
|
||||
canvas_category = query.canvas_category
|
||||
# Sort descending unless the caller explicitly passes desc=false (any letter
# case); a missing or empty value keeps the descending default.
desc = query.desc.lower() != "false" if query.desc else True
|
||||
owner_ids = [id for id in (query.owner_ids or "").strip().split(",") if id]
|
||||
|
||||
if not owner_ids:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
tenants = [m["tenant_id"] for m in tenants]
|
||||
@@ -472,68 +579,73 @@ def list_canvas():
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
|
||||
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
@router.post('/setting')
|
||||
async def setting(
|
||||
request: SettingRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""设置画布"""
|
||||
if not UserCanvasService.accessible(request.id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e,flow = UserCanvasService.get_by_id(req["id"])
|
||||
e, flow = UserCanvasService.get_by_id(request.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
flow = flow.to_dict()
|
||||
flow["title"] = req["title"]
|
||||
flow["title"] = request.title
|
||||
|
||||
for key in ["description", "permission", "avatar"]:
|
||||
if value := req.get(key):
|
||||
value = getattr(request, key, None)
|
||||
if value:
|
||||
flow[key] = value
|
||||
|
||||
num= UserCanvasService.update_by_id(req["id"], flow)
|
||||
num = UserCanvasService.update_by_id(request.id, flow)
|
||||
return get_json_result(data=num)
|
||||
|
||||
|
||||
@manager.route('/trace', methods=['GET']) # noqa: F821
|
||||
def trace():
|
||||
cvs_id = request.args.get("canvas_id")
|
||||
msg_id = request.args.get("message_id")
|
||||
@router.get('/trace')
|
||||
async def trace(
|
||||
canvas_id: str = Query(..., description="画布ID"),
|
||||
message_id: str = Query(..., description="消息ID")
|
||||
):
|
||||
"""追踪日志"""
|
||||
try:
|
||||
bin = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs")
|
||||
bin = REDIS_CONN.get(f"{canvas_id}-{message_id}-logs")
|
||||
if not bin:
|
||||
return get_json_result(data={})
|
||||
|
||||
return get_json_result(data=json.loads(bin.encode("utf-8")))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/<canvas_id>/sessions', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def sessions(canvas_id):
|
||||
@router.get('/{canvas_id}/sessions')
|
||||
async def sessions(
|
||||
canvas_id: str,
|
||||
query: ListSessionsQuery = Depends(),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出画布会话"""
|
||||
tenant_id = current_user.id
|
||||
if not UserCanvasService.accessible(canvas_id, tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
user_id = request.args.get("user_id")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
keywords = request.args.get("keywords")
|
||||
from_date = request.args.get("from_date")
|
||||
to_date = request.args.get("to_date")
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
user_id = query.user_id
|
||||
page_number = int(query.page or 1)
|
||||
items_per_page = int(query.page_size or 30)
|
||||
keywords = query.keywords
|
||||
from_date = query.from_date
|
||||
to_date = query.to_date
|
||||
orderby = query.orderby or "update_time"
|
||||
# Sort descending unless the caller explicitly passes desc=false (any letter
# case); a missing or empty value keeps the descending default.
desc = query.desc.lower() != "false" if query.desc else True
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
include_dsl = query.dsl.lower() not in ["false", "False"] if query.dsl else True
|
||||
|
||||
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
|
||||
None, user_id, include_dsl, keywords, from_date, to_date)
|
||||
try:
|
||||
@@ -542,9 +654,11 @@ def sessions(canvas_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def prompts():
|
||||
@router.get('/prompts')
|
||||
async def prompts(
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取提示词模板"""
|
||||
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
return get_json_result(data={
|
||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||
@@ -556,9 +670,11 @@ def prompts():
|
||||
})
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
# NOTE(review): this route declares no get_current_user dependency, unlike the other
# handlers migrated in this change — confirm that unauthenticated download is intended.
@router.get('/download')
|
||||
async def download(
|
||||
id: str = Query(..., description="文件ID"),
|
||||
created_by: str = Query(..., description="创建者ID")
|
||||
):
|
||||
"""Return the blob FileService.get_blob yields for (created_by, id) as the raw response body."""
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
return Response(content=blob)
|
||||
@@ -16,10 +16,19 @@
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, List
|
||||
|
||||
import xxhash
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.chunk_models import (
|
||||
ListChunksRequest,
|
||||
SetChunkRequest,
|
||||
SwitchChunksRequest,
|
||||
DeleteChunksRequest,
|
||||
CreateChunkRequest,
|
||||
RetrievalTestRequest,
|
||||
)
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType, ParserType
|
||||
@@ -29,17 +38,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.models.chunk_models import (
|
||||
ListChunkRequest,
|
||||
GetChunkRequest,
|
||||
SetChunkRequest,
|
||||
SwitchChunkRequest,
|
||||
RemoveChunkRequest,
|
||||
CreateChunkRequest,
|
||||
RetrievalTestRequest,
|
||||
KnowledgeGraphRequest
|
||||
)
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_current_user
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@@ -47,19 +46,20 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
|
||||
# 创建 FastAPI 路由器
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/list')
|
||||
async def list_chunk(
|
||||
request: ListChunkRequest,
|
||||
request: ListChunksRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出文档块"""
|
||||
doc_id = request.doc_id
|
||||
page = request.page
|
||||
size = request.size
|
||||
question = request.keywords
|
||||
page = request.page or 1
|
||||
size = request.size or 30
|
||||
question = request.keywords or ""
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
@@ -73,7 +73,7 @@ async def list_chunk(
|
||||
}
|
||||
if request.available_int is not None:
|
||||
query["available_int"] = int(request.available_int)
|
||||
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@@ -105,6 +105,7 @@ async def get(
|
||||
chunk_id: str = Query(..., description="块ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取文档块"""
|
||||
try:
|
||||
chunk = None
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
@@ -138,6 +139,7 @@ async def set(
|
||||
request: SetChunkRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""设置文档块"""
|
||||
d = {
|
||||
"id": request.chunk_id,
|
||||
"content_with_weight": request.content_with_weight}
|
||||
@@ -192,9 +194,10 @@ async def set(
|
||||
|
||||
@router.post('/switch')
|
||||
async def switch(
|
||||
request: SwitchChunkRequest,
|
||||
request: SwitchChunksRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""切换文档块状态"""
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(request.doc_id)
|
||||
if not e:
|
||||
@@ -212,9 +215,10 @@ async def switch(
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: RemoveChunkRequest,
|
||||
request: DeleteChunksRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除文档块"""
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(request.doc_id)
|
||||
@@ -240,15 +244,16 @@ async def create(
|
||||
request: CreateChunkRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""创建文档块"""
|
||||
chunck_id = xxhash.xxh64((request.content_with_weight + request.doc_id).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(request.content_with_weight),
|
||||
"content_with_weight": request.content_with_weight}
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
d["important_kwd"] = request.important_kwd
|
||||
d["important_kwd"] = request.important_kwd or []
|
||||
if not isinstance(d["important_kwd"], list):
|
||||
return get_data_error_result(message="`important_kwd` is required to be a list")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
d["question_kwd"] = request.question_kwd
|
||||
d["question_kwd"] = request.question_kwd or []
|
||||
if not isinstance(d["question_kwd"], list):
|
||||
return get_data_error_result(message="`question_kwd` is required to be a list")
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
@@ -296,8 +301,9 @@ async def retrieval_test(
|
||||
request: RetrievalTestRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
page = request.page
|
||||
size = request.size
|
||||
"""检索测试"""
|
||||
page = request.page or 1
|
||||
size = request.size or 30
|
||||
question = request.question
|
||||
kb_ids = request.kb_id
|
||||
if isinstance(kb_ids, str):
|
||||
@@ -306,10 +312,10 @@ async def retrieval_test(
|
||||
return get_json_result(data=False, message='Please specify dataset firstly.',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
|
||||
doc_ids = request.doc_ids
|
||||
use_kg = request.use_kg
|
||||
top = request.top_k
|
||||
langs = request.cross_languages
|
||||
doc_ids = request.doc_ids or []
|
||||
use_kg = request.use_kg or False
|
||||
top = request.top_k or 1024
|
||||
langs = request.cross_languages or []
|
||||
tenant_ids = []
|
||||
|
||||
if request.search_id:
|
||||
@@ -358,15 +364,16 @@ async def retrieval_test(
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(request.similarity_threshold),
|
||||
float(request.vector_similarity_weight),
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(request.similarity_threshold or 0.0),
|
||||
float(request.vector_similarity_weight or 0.3),
|
||||
top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=request.highlight,
|
||||
doc_ids, rerank_mdl=rerank_mdl,
|
||||
highlight=request.highlight or False,
|
||||
rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question,
|
||||
ck = settings.kg_retriever.retrieval(question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
@@ -391,13 +398,14 @@ async def knowledge_graph(
|
||||
doc_id: str = Query(..., description="文档ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取知识图谱"""
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
req = {
|
||||
"doc_ids": [doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
||||
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
|
||||
@@ -17,15 +17,30 @@ import json
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from api.utils.api_utils import security
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile, File, Form
|
||||
from fastapi.responses import Response
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.document_models import (
|
||||
CreateDocumentRequest,
|
||||
WebCrawlRequest,
|
||||
ListDocumentsQuery,
|
||||
ListDocumentsBody,
|
||||
FilterDocumentsRequest,
|
||||
GetDocumentInfosRequest,
|
||||
ChangeStatusRequest,
|
||||
DeleteDocumentRequest,
|
||||
RunDocumentRequest,
|
||||
RenameDocumentRequest,
|
||||
ChangeParserRequest,
|
||||
ChangeParserSimpleRequest,
|
||||
UploadAndParseRequest,
|
||||
ParseRequest,
|
||||
SetMetaRequest,
|
||||
)
|
||||
from api import settings
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
@@ -43,159 +58,29 @@ from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
|
||||
from deepdoc.parser.html_parser import RAGFlowHtmlParser
|
||||
from rag.nlp import search
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from pydantic import BaseModel
|
||||
from api.db.db_models import User
|
||||
|
||||
# Security
|
||||
|
||||
# Pydantic models for request/response
|
||||
class WebCrawlRequest(BaseModel):
|
||||
kb_id: str
|
||||
name: str
|
||||
url: str
|
||||
|
||||
class CreateDocumentRequest(BaseModel):
|
||||
name: str
|
||||
kb_id: str
|
||||
|
||||
class DocumentListRequest(BaseModel):
|
||||
run_status: List[str] = []
|
||||
types: List[str] = []
|
||||
suffix: List[str] = []
|
||||
|
||||
class DocumentFilterRequest(BaseModel):
|
||||
kb_id: str
|
||||
keywords: str = ""
|
||||
run_status: List[str] = []
|
||||
types: List[str] = []
|
||||
suffix: List[str] = []
|
||||
|
||||
class DocumentInfosRequest(BaseModel):
|
||||
doc_ids: List[str]
|
||||
|
||||
class ChangeStatusRequest(BaseModel):
|
||||
doc_ids: List[str]
|
||||
status: str
|
||||
|
||||
class RemoveDocumentRequest(BaseModel):
|
||||
doc_id: List[str]
|
||||
|
||||
class RunDocumentRequest(BaseModel):
|
||||
doc_ids: List[str]
|
||||
run: int
|
||||
delete: bool = False
|
||||
|
||||
class RenameDocumentRequest(BaseModel):
|
||||
doc_id: str
|
||||
name: str
|
||||
|
||||
class ChangeParserRequest(BaseModel):
|
||||
doc_id: str
|
||||
parser_id: str
|
||||
pipeline_id: Optional[str] = None
|
||||
parser_config: Optional[dict] = None
|
||||
|
||||
class UploadAndParseRequest(BaseModel):
|
||||
conversation_id: str
|
||||
|
||||
class ParseRequest(BaseModel):
|
||||
url: Optional[str] = None
|
||||
|
||||
class SetMetaRequest(BaseModel):
|
||||
doc_id: str
|
||||
meta: str
|
||||
|
||||
|
||||
# Dependency injection
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.user_service import UserService
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
# Create router
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload(
|
||||
kb_id: str = Form(...),
|
||||
file: List[UploadFile] = File(...),
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if not file:
|
||||
"""上传文档"""
|
||||
if not files:
|
||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
# Use UploadFile directly - file is already a list from multiple file fields
|
||||
file_objs = file
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
for file_obj in files:
|
||||
if not file_obj.filename or file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
@@ -206,29 +91,30 @@ async def upload(
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, files = await FileService.upload_document(kb, file_objs, current_user.id)
|
||||
err, uploaded_files = FileService.upload_document(kb, files, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=uploaded_files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
if not files:
|
||||
return get_json_result(data=files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR)
|
||||
files = [f[0] for f in files] # remove the blob
|
||||
if not uploaded_files:
|
||||
return get_json_result(data=uploaded_files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR)
|
||||
files_result = [f[0] for f in uploaded_files] # remove the blob
|
||||
|
||||
return get_json_result(data=files)
|
||||
return get_json_result(data=files_result)
|
||||
|
||||
|
||||
@router.post("/web_crawl")
|
||||
async def web_crawl(
|
||||
req: WebCrawlRequest,
|
||||
request: WebCrawlRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = req.kb_id
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
name = req.name
|
||||
url = req.url
|
||||
"""网页爬取"""
|
||||
kb_id = request.kb_id
|
||||
name = request.name
|
||||
url = request.url
|
||||
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
@@ -259,6 +145,7 @@ async def web_crawl(
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": filetype,
|
||||
@@ -285,25 +172,26 @@ async def web_crawl(
|
||||
|
||||
@router.post("/create")
|
||||
async def create(
|
||||
req: CreateDocumentRequest,
|
||||
request: CreateDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = req.kb_id
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if len(req.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
"""创建文档"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
kb_id = req["kb_id"]
|
||||
|
||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if req.name.strip() == "":
|
||||
if req["name"].strip() == "":
|
||||
return get_json_result(data=False, message="File name can't be empty.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
req.name = req.name.strip()
|
||||
req["name"] = req["name"].strip()
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Can't find this knowledgebase!")
|
||||
|
||||
if DocumentService.query(name=req.name, kb_id=kb_id):
|
||||
if DocumentService.query(name=req["name"], kb_id=kb_id):
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
kb_root_folder = FileService.get_kb_folder(kb.tenant_id)
|
||||
@@ -326,8 +214,8 @@ async def create(
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": FileType.VIRTUAL,
|
||||
"name": req.name,
|
||||
"suffix": Path(req.name).suffix.lstrip("."),
|
||||
"name": req["name"],
|
||||
"suffix": Path(req["name"]).suffix.lstrip("."),
|
||||
"location": "",
|
||||
"size": 0,
|
||||
}
|
||||
@@ -342,47 +230,46 @@ async def create(
|
||||
|
||||
@router.post("/list")
|
||||
async def list_docs(
|
||||
kb_id: str = Query(...),
|
||||
keywords: str = Query(""),
|
||||
page: int = Query(0),
|
||||
page_size: int = Query(0),
|
||||
orderby: str = Query("create_time"),
|
||||
desc: str = Query("true"),
|
||||
create_time_from: int = Query(0),
|
||||
create_time_to: int = Query(0),
|
||||
req: DocumentListRequest = None,
|
||||
query: ListDocumentsQuery = Depends(),
|
||||
body: Optional[ListDocumentsBody] = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
"""列出文档"""
|
||||
if body is None:
|
||||
body = ListDocumentsBody()
|
||||
|
||||
kb_id = query.kb_id
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
keywords = query.keywords or ""
|
||||
page_number = int(query.page or 0)
|
||||
items_per_page = int(query.page_size or 0)
|
||||
orderby = query.orderby or "create_time"
|
||||
# Sort descending unless the client explicitly sends desc=false (any casing).
# Comparing against "false" rather than "true" keeps the earlier contract, where any
# other value (or no value) left the descending default in place.
desc = query.desc.lower() != "false" if query.desc else True
|
||||
create_time_from = int(query.create_time_from or 0)
|
||||
create_time_to = int(query.create_time_to or 0)
|
||||
|
||||
if desc.lower() == "false":
|
||||
desc_bool = False
|
||||
else:
|
||||
desc_bool = True
|
||||
|
||||
run_status = req.run_status if req else []
|
||||
run_status = body.run_status or []
|
||||
if run_status:
|
||||
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.types if req else []
|
||||
types = body.types or []
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = req.suffix if req else []
|
||||
suffix = body.suffix or []
|
||||
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(kb_id, page, page_size, orderby, desc_bool, keywords, run_status, types, suffix)
|
||||
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
@@ -403,12 +290,11 @@ async def list_docs(
|
||||
|
||||
@router.post("/filter")
|
||||
async def get_filter(
|
||||
req: DocumentFilterRequest,
|
||||
request: FilterDocumentsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = req.kb_id
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
"""过滤文档"""
|
||||
kb_id = request.kb_id
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
@@ -416,15 +302,16 @@ async def get_filter(
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
keywords = req.keywords
|
||||
suffix = req.suffix
|
||||
run_status = req.run_status
|
||||
keywords = request.keywords or ""
|
||||
suffix = request.suffix or []
|
||||
run_status = request.run_status or []
|
||||
|
||||
if run_status:
|
||||
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.types
|
||||
types = request.types or []
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
@@ -439,10 +326,11 @@ async def get_filter(
|
||||
|
||||
@router.post("/infos")
|
||||
async def docinfos(
|
||||
req: DocumentInfosRequest,
|
||||
request: GetDocumentInfosRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
doc_ids = req.doc_ids
|
||||
"""获取文档信息"""
|
||||
doc_ids = request.doc_ids
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
@@ -452,8 +340,9 @@ async def docinfos(
|
||||
|
||||
@router.get("/thumbnails")
|
||||
async def thumbnails(
|
||||
doc_ids: List[str] = Query(...)
|
||||
doc_ids: List[str] = Query(..., description="文档ID列表"),
|
||||
):
|
||||
"""获取文档缩略图"""
|
||||
if not doc_ids:
|
||||
return get_json_result(data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
@@ -471,14 +360,12 @@ async def thumbnails(
|
||||
|
||||
@router.post("/change_status")
|
||||
async def change_status(
|
||||
req: ChangeStatusRequest,
|
||||
request: ChangeStatusRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
doc_ids = req.doc_ids
|
||||
status = str(req.status)
|
||||
|
||||
if status not in ["0", "1"]:
|
||||
return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
"""修改文档状态"""
|
||||
doc_ids = request.doc_ids
|
||||
status = request.status
|
||||
|
||||
result = {}
|
||||
for doc_id in doc_ids:
|
||||
@@ -511,10 +398,11 @@ async def change_status(
|
||||
|
||||
@router.post("/rm")
|
||||
async def rm(
|
||||
req: RemoveDocumentRequest,
|
||||
request: DeleteDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
doc_ids = req.doc_id
|
||||
"""删除文档"""
|
||||
doc_ids = request.doc_id
|
||||
if isinstance(doc_ids, str):
|
||||
doc_ids = [doc_ids]
|
||||
|
||||
@@ -570,17 +458,19 @@ async def rm(
|
||||
|
||||
@router.post("/run")
|
||||
async def run(
|
||||
req: RunDocumentRequest,
|
||||
request: RunDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
for doc_id in req.doc_ids:
|
||||
"""运行文档解析"""
|
||||
for doc_id in request.doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
kb_table_num_map = {}
|
||||
for id in req.doc_ids:
|
||||
info = {"run": str(req.run), "progress": 0}
|
||||
if req.run == int(TaskStatus.RUNNING.value) and req.delete:
|
||||
for id in request.doc_ids:
|
||||
info = {"run": str(request.run), "progress": 0}
|
||||
if str(request.run) == TaskStatus.RUNNING.value and request.delete:
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
@@ -592,21 +482,21 @@ async def run(
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if req.run == int(TaskStatus.CANCEL.value):
|
||||
if str(request.run) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([req.delete, req.run == int(TaskStatus.RUNNING.value), str(doc.run) == TaskStatus.DONE.value]):
|
||||
if all([not request.delete, str(request.run) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.delete:
|
||||
if request.delete:
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
if req.run == int(TaskStatus.RUNNING.value):
|
||||
if str(request.run) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
|
||||
@@ -628,37 +518,53 @@ async def run(
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
async def rename(
|
||||
req: RenameDocumentRequest,
|
||||
request: RenameDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not DocumentService.accessible(req.doc_id, current_user.id):
|
||||
"""重命名文档"""
|
||||
if not DocumentService.accessible(request.doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req.doc_id)
|
||||
e, doc = DocumentService.get_by_id(request.doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req.name.lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
if pathlib.Path(request.name.lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if len(req.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
if len(request.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
for d in DocumentService.query(name=req.name, kb_id=doc.kb_id):
|
||||
if d.name == req.name:
|
||||
for d in DocumentService.query(name=request.name, kb_id=doc.kb_id):
|
||||
if d.name == request.name:
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
if not DocumentService.update_by_id(req.doc_id, {"name": req.name}):
|
||||
if not DocumentService.update_by_id(request.doc_id, {"name": request.name}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
|
||||
informs = File2DocumentService.get_by_document_id(req.doc_id)
|
||||
informs = File2DocumentService.get_by_document_id(request.doc_id)
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req.name})
|
||||
FileService.update_by_id(file.id, {"name": request.name})
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(request.doc_id)
|
||||
title_tks = rag_tokenizer.tokenize(request.name)
|
||||
es_body = {
|
||||
"docnm_kwd": request.name,
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": request.doc_id},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@@ -667,6 +573,7 @@ async def rename(
|
||||
|
||||
@router.get("/get/{doc_id}")
|
||||
async def get(doc_id: str):
|
||||
"""获取文档文件"""
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
@@ -677,70 +584,79 @@ async def get(doc_id: str):
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
|
||||
content_type = "application/octet-stream"
|
||||
if ext:
|
||||
if doc.type == FileType.VISUAL.value:
|
||||
media_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
media_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f"attachment; filename={doc.name}"}
|
||||
)
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
|
||||
return Response(content=content, media_type=content_type)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/change_parser")
|
||||
async def change_parser(
|
||||
req: ChangeParserRequest,
|
||||
request: ChangeParserSimpleRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not DocumentService.accessible(req.doc_id, current_user.id):
|
||||
"""修改文档解析器"""
|
||||
if not DocumentService.accessible(request.doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req.doc_id)
|
||||
e, doc = DocumentService.get_by_id(request.doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
def reset_doc():
|
||||
def reset_doc(update_data_override=None):
|
||||
nonlocal doc
|
||||
e = DocumentService.update_by_id(doc.id, {"parser_id": req.parser_id, "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
|
||||
update_data = update_data_override or {}
|
||||
if request.pipeline_id is not None:
|
||||
update_data["pipeline_id"] = request.pipeline_id
|
||||
if request.parser_id is not None:
|
||||
update_data["parser_id"] = request.parser_id
|
||||
update_data.update({
|
||||
"progress": 0,
|
||||
"progress_msg": "",
|
||||
"run": TaskStatus.UNSTART.value
|
||||
})
|
||||
e = DocumentService.update_by_id(doc.id, update_data)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if doc.token_num > 0:
|
||||
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(req.doc_id)
|
||||
tenant_id = DocumentService.get_tenant_id(request.doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
try:
|
||||
if req.pipeline_id:
|
||||
if doc.pipeline_id == req.pipeline_id:
|
||||
if request.pipeline_id is not None and request.pipeline_id != "":
|
||||
if doc.pipeline_id == request.pipeline_id:
|
||||
return get_json_result(data=True)
|
||||
DocumentService.update_by_id(doc.id, {"pipeline_id": req.pipeline_id})
|
||||
reset_doc()
|
||||
reset_doc({"pipeline_id": request.pipeline_id})
|
||||
return get_json_result(data=True)
|
||||
|
||||
if doc.parser_id.lower() == req.parser_id.lower():
|
||||
if req.parser_config:
|
||||
if req.parser_config == doc.parser_config:
|
||||
if request.parser_id is None:
|
||||
return get_json_result(data=False, message="缺少 parser_id 或 pipeline_id", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if doc.parser_id.lower() == request.parser_id.lower():
|
||||
if request.parser_config is not None:
|
||||
if request.parser_config == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if (doc.type == FileType.VISUAL and req.parser_id != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req.parser_id != "presentation"):
|
||||
if (doc.type == FileType.VISUAL and request.parser_id != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and request.parser_id != "presentation"):
|
||||
return get_data_error_result(message="Not supported yet!")
|
||||
if req.parser_config:
|
||||
DocumentService.update_parser_config(doc.id, req.parser_config)
|
||||
|
||||
if request.parser_config is not None:
|
||||
DocumentService.update_parser_config(doc.id, request.parser_config)
|
||||
|
||||
reset_doc()
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@@ -749,16 +665,14 @@ async def change_parser(
|
||||
|
||||
@router.get("/image/{image_id}")
|
||||
async def get_image(image_id: str):
|
||||
"""获取图片"""
|
||||
try:
|
||||
arr = image_id.split("-")
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
content = STORAGE_IMPL.get(bkt, nm)
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="image/JPEG"
|
||||
)
|
||||
return Response(content=content, media_type="image/JPEG")
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@@ -769,28 +683,28 @@ async def upload_and_parse(
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""上传并解析"""
|
||||
if not files:
|
||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
# Use UploadFile directly
|
||||
file_objs = files
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
for file_obj in files:
|
||||
if not file_obj.filename or file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = await doc_upload_and_parse(conversation_id, file_objs, current_user.id)
|
||||
doc_ids = doc_upload_and_parse(conversation_id, files, current_user.id)
|
||||
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@router.post("/parse")
|
||||
async def parse(
|
||||
req: ParseRequest = None,
|
||||
files: List[UploadFile] = File(None),
|
||||
request: Optional[ParseRequest] = None,
|
||||
files: Optional[List[UploadFile]] = File(None),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
url = req.url if req else ""
|
||||
"""解析文档"""
|
||||
url = request.url if request else None
|
||||
|
||||
if url:
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
@@ -828,46 +742,37 @@ async def parse(
|
||||
if not r or not r.group(1):
|
||||
return get_json_result(data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
f = File(r.group(1), os.path.join(download_path, r.group(1)))
|
||||
txt = await FileService.parse_docs([f], current_user.id)
|
||||
txt = FileService.parse_docs([f], current_user.id)
|
||||
return get_json_result(data=txt)
|
||||
|
||||
if not files:
|
||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
# Use UploadFile directly
|
||||
file_objs = files
|
||||
txt = await FileService.parse_docs(file_objs, current_user.id)
|
||||
txt = FileService.parse_docs(files, current_user.id)
|
||||
|
||||
return get_json_result(data=txt)
|
||||
|
||||
|
||||
@router.post("/set_meta")
|
||||
async def set_meta(
|
||||
req: SetMetaRequest,
|
||||
request: SetMetaRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not DocumentService.accessible(req.doc_id, current_user.id):
|
||||
"""设置元数据"""
|
||||
if not DocumentService.accessible(request.doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
meta = json.loads(req.meta)
|
||||
if not isinstance(meta, dict):
|
||||
return get_json_result(data=False, message="Only dictionary type supported.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
for k, v in meta.items():
|
||||
if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
|
||||
return get_json_result(data=False, message=f"The type is not supported: {v}", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
except Exception as e:
|
||||
return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if not isinstance(meta, dict):
|
||||
return get_json_result(data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req.doc_id)
|
||||
meta = json.loads(request.meta)
|
||||
|
||||
e, doc = DocumentService.get_by_id(request.doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if not DocumentService.update_by_id(req.doc_id, {"meta_fields": meta}):
|
||||
if not DocumentService.update_by_id(request.doc_id, {"meta_fields": meta}):
|
||||
return get_data_error_result(message="Database error (meta updates)!")
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@@ -15,15 +15,12 @@
|
||||
#
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from api.utils.api_utils import security
|
||||
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
@@ -31,91 +28,15 @@ from api.db import FileType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api import settings
|
||||
from api.utils.api_utils import get_json_result
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Security
|
||||
|
||||
# Pydantic models for request/response
|
||||
class ConvertRequest(BaseModel):
|
||||
file_ids: List[str]
|
||||
kb_ids: List[str]
|
||||
|
||||
class RemoveFile2DocumentRequest(BaseModel):
|
||||
file_ids: List[str]
|
||||
|
||||
# Dependency injection
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.user_service import UserService
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/convert')
|
||||
async def convert(
|
||||
req: ConvertRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_ids = req.kb_ids
|
||||
file_ids = req.file_ids
|
||||
@manager.route('/convert', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids", "kb_ids")
|
||||
def convert():
|
||||
req = request.json
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
|
||||
try:
|
||||
@@ -179,12 +100,12 @@ async def convert(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
req: RemoveFile2DocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
file_ids = req.file_ids
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
file_ids = req["file_ids"]
|
||||
if not file_ids:
|
||||
return get_json_result(
|
||||
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
@@ -13,15 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from api.utils.api_utils import security
|
||||
import flask
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
@@ -36,110 +35,22 @@ from api.utils.api_utils import get_json_result
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Security
|
||||
|
||||
# Pydantic models for request/response
|
||||
class CreateFileRequest(BaseModel):
|
||||
name: str
|
||||
parent_id: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
|
||||
class RemoveFileRequest(BaseModel):
|
||||
file_ids: List[str]
|
||||
|
||||
class RenameFileRequest(BaseModel):
|
||||
file_id: str
|
||||
name: str
|
||||
|
||||
class MoveFileRequest(BaseModel):
|
||||
src_file_ids: List[str]
|
||||
dest_file_id: str
|
||||
|
||||
# Dependency injection
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.user_service import UserService
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/upload')
|
||||
async def upload(
|
||||
parent_id: Optional[str] = Form(None),
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
pf_id = parent_id
|
||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
# @validate_request("parent_id")
|
||||
def upload():
|
||||
pf_id = request.form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if not files:
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = files
|
||||
file_objs = request.files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
@@ -186,7 +97,7 @@ async def upload(
|
||||
location = file_obj_names[file_len - 1]
|
||||
while STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||
location += "_"
|
||||
blob = await file_obj.read()
|
||||
blob = file_obj.read()
|
||||
filename = duplicate_name(
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
@@ -209,13 +120,13 @@ async def upload(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/create')
|
||||
async def create(
|
||||
req: CreateFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
pf_id = req.parent_id
|
||||
input_file_type = req.type
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@@ -224,7 +135,7 @@ async def create(
|
||||
if not FileService.is_parent_folder_exist(pf_id):
|
||||
return get_json_result(
|
||||
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
|
||||
if FileService.query(name=req.name, parent_id=pf_id):
|
||||
if FileService.query(name=req["name"], parent_id=pf_id):
|
||||
return get_data_error_result(
|
||||
message="Duplicated folder name in the same folder.")
|
||||
|
||||
@@ -238,7 +149,7 @@ async def create(
|
||||
"parent_id": pf_id,
|
||||
"tenant_id": current_user.id,
|
||||
"created_by": current_user.id,
|
||||
"name": req.name,
|
||||
"name": req["name"],
|
||||
"location": "",
|
||||
"size": 0,
|
||||
"type": file_type
|
||||
@@ -249,18 +160,17 @@ async def create(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/list')
|
||||
async def list_files(
|
||||
parent_id: Optional[str] = Query(None),
|
||||
keywords: str = Query(""),
|
||||
page: int = Query(1),
|
||||
page_size: int = Query(15),
|
||||
orderby: str = Query("create_time"),
|
||||
desc: bool = Query(True),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
pf_id = parent_id
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_files():
|
||||
pf_id = request.args.get("parent_id")
|
||||
|
||||
keywords = request.args.get("keywords", "")
|
||||
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 15))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
@@ -271,7 +181,7 @@ async def list_files(
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
|
||||
files, total = FileService.get_by_pf_id(
|
||||
current_user.id, pf_id, page, page_size, orderby, desc, keywords)
|
||||
current_user.id, pf_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(pf_id)
|
||||
if not parent_folder:
|
||||
@@ -282,8 +192,9 @@ async def list_files(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/root_folder')
|
||||
async def get_root_folder(current_user = Depends(get_current_user)):
|
||||
@manager.route('/root_folder', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get_root_folder():
|
||||
try:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
return get_json_result(data={"root_folder": root_folder})
|
||||
@@ -291,11 +202,10 @@ async def get_root_folder(current_user = Depends(get_current_user)):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/parent_folder')
|
||||
async def get_parent_folder(
|
||||
file_id: str = Query(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route('/parent_folder', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get_parent_folder():
|
||||
file_id = request.args.get("file_id")
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@@ -307,11 +217,10 @@ async def get_parent_folder(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/all_parent_folder')
|
||||
async def get_all_parent_folders(
|
||||
file_id: str = Query(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route('/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get_all_parent_folders():
|
||||
file_id = request.args.get("file_id")
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@@ -326,90 +235,99 @@ async def get_all_parent_folders(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
req: RemoveFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
file_ids = req.file_ids
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
file_ids = req["file_ids"]
|
||||
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
if file.location:
|
||||
STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e and doc:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if tenant_id:
|
||||
DocumentService.remove_document(doc, tenant_id)
|
||||
File2DocumentService.delete_by_file_id(file.id)
|
||||
|
||||
FileService.delete(file)
|
||||
|
||||
def _delete_folder_recursive(folder, tenant_id):
|
||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||
for sub_file in sub_files:
|
||||
if sub_file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(sub_file, tenant_id)
|
||||
else:
|
||||
_delete_single_file(sub_file)
|
||||
|
||||
FileService.delete(folder)
|
||||
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
if not e or not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for inner_file_id in file_id_list:
|
||||
e, file = FileService.get_by_id(inner_file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="File not found!")
|
||||
STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
FileService.delete_folder_by_pf_id(current_user.id, file_id)
|
||||
else:
|
||||
STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
if not FileService.delete(file):
|
||||
return get_data_error_result(
|
||||
message="Database error (File removal)!")
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
continue
|
||||
|
||||
# delete file2document
|
||||
informs = File2DocumentService.get_by_file_id(file_id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
File2DocumentService.delete_by_file_id(file_id)
|
||||
_delete_single_file(file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/rename')
|
||||
async def rename(
|
||||
req: RenameFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route('/rename', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_id", "name")
|
||||
def rename():
|
||||
req = request.json
|
||||
try:
|
||||
e, file = FileService.get_by_id(req.file_id)
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="File not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.type != FileType.FOLDER.value \
|
||||
and pathlib.Path(req.name.lower()).suffix != pathlib.Path(
|
||||
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="The extension of file can't be changed",
|
||||
code=settings.RetCode.ARGUMENT_ERROR)
|
||||
for file in FileService.query(name=req.name, pf_id=file.parent_id):
|
||||
if file.name == req.name:
|
||||
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if file.name == req["name"]:
|
||||
return get_data_error_result(
|
||||
message="Duplicated file name in the same folder.")
|
||||
|
||||
if not FileService.update_by_id(
|
||||
req.file_id, {"name": req.name}):
|
||||
req["file_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(
|
||||
message="Database error (File rename)!")
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(req.file_id)
|
||||
informs = File2DocumentService.get_by_file_id(req["file_id"])
|
||||
if informs:
|
||||
if not DocumentService.update_by_id(
|
||||
informs[0].document_id, {"name": req.name}):
|
||||
informs[0].document_id, {"name": req["name"]}):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document rename)!")
|
||||
|
||||
@@ -418,8 +336,9 @@ async def rename(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/get/{file_id}')
|
||||
async def get(file_id: str, current_user = Depends(get_current_user)):
|
||||
@manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(file_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
@@ -432,6 +351,7 @@ async def get(file_id: str, current_user = Depends(get_current_user)):
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = STORAGE_IMPL.get(b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
@@ -439,43 +359,95 @@ async def get(file_id: str, current_user = Depends(get_current_user)):
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
else:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
return StreamingResponse(
|
||||
iter([blob]),
|
||||
media_type=content_type,
|
||||
headers={"Content-Disposition": f"attachment; filename={file.name}"}
|
||||
)
|
||||
response.headers.set("Content-Type", content_type)
|
||||
return response
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/mv')
|
||||
async def move(
|
||||
req: MoveFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route("/mv", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("src_file_ids", "dest_file_id")
|
||||
def move():
|
||||
req = request.json
|
||||
try:
|
||||
file_ids = req.src_file_ids
|
||||
parent_id = req.dest_file_id
|
||||
file_ids = req["src_file_ids"]
|
||||
dest_parent_id = req["dest_file_id"]
|
||||
|
||||
ok, dest_folder = FileService.get_by_id(dest_parent_id)
|
||||
if not ok or not dest_folder:
|
||||
return get_data_error_result(message="Parent Folder not found!")
|
||||
|
||||
files = FileService.get_by_ids(file_ids)
|
||||
files_dict = {}
|
||||
for file in files:
|
||||
files_dict[file.id] = file
|
||||
if not files:
|
||||
return get_data_error_result(message="Source files not found!")
|
||||
|
||||
files_dict = {f.id: f for f in files}
|
||||
|
||||
for file_id in file_ids:
|
||||
file = files_dict[file_id]
|
||||
file = files_dict.get(file_id)
|
||||
if not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_data_error_result(message="Parent Folder not found!")
|
||||
FileService.move_file(file_ids, parent_id)
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="No authorization.",
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||
)
|
||||
|
||||
def _move_entry_recursive(source_file_entry, dest_folder):
    """Move one file or folder entry into ``dest_folder``.

    Folders are merged rather than relocated: a folder with the same name
    under ``dest_folder`` is reused if ``FileService.query`` finds one,
    otherwise a new folder record is inserted. Every child returned by
    ``FileService.list_all_files_by_parent_id`` is then moved recursively and
    the source folder record is deleted.

    Plain files are moved in storage via ``STORAGE_IMPL.move`` and their
    record is updated with the new ``parent_id`` and ``location``.

    Args:
        source_file_entry: File record to move; the code reads its ``id``,
            ``type``, ``name``, ``tenant_id``, ``parent_id`` and ``location``.
        dest_folder: Folder record to move into; only its ``id`` is read.

    Raises:
        RuntimeError: If the storage layer fails to move a file. The original
            storage exception is attached as ``__cause__``.
    """
    # The entry already lives directly in the destination, so there is nothing
    # to move. Without this guard the folder lookup below would match the
    # source folder itself, "merge" it into itself and then delete its record,
    # orphaning the children; a file would be renamed with a trailing "_".
    if source_file_entry.parent_id == dest_folder.id:
        return

    # NOTE(review): moving a folder into one of its own descendants is not
    # detected here and looks like it would recurse without end - confirm
    # whether callers already rule this out.
    if source_file_entry.type == FileType.FOLDER.value:
        existing_folder = FileService.query(name=source_file_entry.name, parent_id=dest_folder.id)
        if existing_folder:
            # Merge into the same-named folder already under the destination.
            new_folder = existing_folder[0]
        else:
            new_folder = FileService.insert(
                {
                    "id": get_uuid(),
                    "parent_id": dest_folder.id,
                    "tenant_id": source_file_entry.tenant_id,
                    # current_user comes from the enclosing request handler.
                    "created_by": current_user.id,
                    "name": source_file_entry.name,
                    "location": "",
                    "size": 0,
                    "type": FileType.FOLDER.value,
                }
            )

        sub_files = FileService.list_all_files_by_parent_id(source_file_entry.id)
        for sub_file in sub_files:
            _move_entry_recursive(sub_file, new_folder)

        # The children have been re-homed, so drop the now-empty source folder.
        FileService.delete_by_id(source_file_entry.id)
        return

    old_parent_id = source_file_entry.parent_id
    old_location = source_file_entry.location

    # Pick a storage location that does not collide with an existing object
    # in the destination by appending underscores to the file name.
    new_location = source_file_entry.name
    while STORAGE_IMPL.obj_exist(dest_folder.id, new_location):
        new_location += "_"

    try:
        STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location)
    except Exception as storage_err:
        # Chain the cause so the storage traceback is not lost.
        raise RuntimeError(f"Move file failed at storage layer: {storage_err}") from storage_err

    FileService.update_by_id(
        source_file_entry.id,
        {
            "parent_id": dest_folder.id,
            "location": new_location,
        },
    )
|
||||
|
||||
for file in files:
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
@@ -15,27 +15,29 @@
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from api.models.kb_models import (
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.kb_models import (
|
||||
CreateKnowledgeBaseRequest,
|
||||
UpdateKnowledgeBaseRequest,
|
||||
DeleteKnowledgeBaseRequest,
|
||||
ListKnowledgeBasesRequest,
|
||||
ListKnowledgeBasesQuery,
|
||||
ListKnowledgeBasesBody,
|
||||
RemoveTagsRequest,
|
||||
RenameTagRequest,
|
||||
RunGraphRAGRequest,
|
||||
ListPipelineLogsQuery,
|
||||
ListPipelineLogsBody,
|
||||
ListPipelineDatasetLogsQuery,
|
||||
ListPipelineDatasetLogsBody,
|
||||
DeletePipelineLogsQuery,
|
||||
DeletePipelineLogsBody,
|
||||
RunGraphragRequest,
|
||||
RunRaptorRequest,
|
||||
RunMindmapRequest,
|
||||
ListPipelineLogsRequest,
|
||||
ListPipelineDatasetLogsRequest,
|
||||
DeletePipelineLogsRequest,
|
||||
UnbindTaskRequest
|
||||
)
|
||||
from api.utils.api_utils import get_current_user
|
||||
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
@@ -44,7 +46,12 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, get_json_result
|
||||
from api.utils.api_utils import (
|
||||
get_error_data_result,
|
||||
server_error_response,
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
)
|
||||
from api.utils import get_uuid
|
||||
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@@ -53,9 +60,10 @@ from api import settings
|
||||
from rag.nlp import search
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
# 创建 FastAPI 路由器
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -64,7 +72,14 @@ async def create(
|
||||
request: CreateKnowledgeBaseRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
dataset_name = request.name
|
||||
"""创建知识库
|
||||
|
||||
支持两种解析类型:
|
||||
- parse_type=1: 使用内置解析器,需要 parser_id
|
||||
- parse_type=2: 使用自定义 pipeline,需要 pipeline_id
|
||||
"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
dataset_name = req["name"]
|
||||
if not isinstance(dataset_name, str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if dataset_name.strip() == "":
|
||||
@@ -80,55 +95,66 @@ async def create(
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value)
|
||||
try:
|
||||
req = {
|
||||
"id": get_uuid(),
|
||||
"name": dataset_name,
|
||||
"tenant_id": current_user.id,
|
||||
"created_by": current_user.id,
|
||||
"parser_id": request.parser_id or "naive",
|
||||
"description": request.description
|
||||
}
|
||||
# 根据 parse_type 处理 parser_id 和 pipeline_id
|
||||
parse_type = req.pop("parse_type", 1) # 移除 parse_type,不需要存储到数据库
|
||||
|
||||
if parse_type == 1:
|
||||
# 使用内置解析器,需要 parser_id
|
||||
# 验证器已经确保 parser_id 不为空,但保留默认值逻辑以防万一
|
||||
if not req.get("parser_id") or req["parser_id"].strip() == "":
|
||||
req["parser_id"] = "naive"
|
||||
# 清空 pipeline_id(设置为 None,数据库字段允许为 null)
|
||||
req["pipeline_id"] = None
|
||||
elif parse_type == 2:
|
||||
# 使用自定义 pipeline,需要 pipeline_id
|
||||
# 验证器已经确保 pipeline_id 不为空
|
||||
# parser_id 应该为空字符串,但数据库字段不允许 null,所以不设置 parser_id
|
||||
# 让数据库使用默认值 "naive"(虽然用户传入的是空字符串,但数据库会处理)
|
||||
# 如果用户明确传入了空字符串,我们也不设置它,让数据库使用默认值
|
||||
if "parser_id" in req and (not req["parser_id"] or req["parser_id"].strip() == ""):
|
||||
# 移除空字符串的 parser_id,让数据库使用默认值
|
||||
req.pop("parser_id")
|
||||
# pipeline_id 保留在 req 中,会被保存到数据库
|
||||
|
||||
req["id"] = get_uuid()
|
||||
req["name"] = dataset_name
|
||||
req["tenant_id"] = current_user.id
|
||||
req["created_by"] = current_user.id
|
||||
|
||||
# embd_id 已经在模型中定义为必需字段,直接使用
|
||||
|
||||
e, t = TenantService.get_by_id(current_user.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
|
||||
# 设置 embd_id 默认值
|
||||
if not request.embd_id:
|
||||
req["embd_id"] = t.embd_id
|
||||
else:
|
||||
req["embd_id"] = request.embd_id
|
||||
|
||||
if request.parser_config:
|
||||
req["parser_config"] = request.parser_config
|
||||
else:
|
||||
req["parser_config"] = {
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 512,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False,
|
||||
"topn_tags": 3,
|
||||
"raptor": {
|
||||
"use_raptor": True,
|
||||
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
|
||||
"max_token": 256,
|
||||
"threshold": 0.1,
|
||||
"max_cluster": 64,
|
||||
"random_seed": 0
|
||||
},
|
||||
"graphrag": {
|
||||
"use_graphrag": True,
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
"geo",
|
||||
"event",
|
||||
"category"
|
||||
],
|
||||
"method": "light"
|
||||
}
|
||||
|
||||
req["parser_config"] = {
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 512,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False,
|
||||
"topn_tags": 3,
|
||||
"raptor": {
|
||||
"use_raptor": True,
|
||||
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
|
||||
"max_token": 256,
|
||||
"threshold": 0.1,
|
||||
"max_cluster": 64,
|
||||
"random_seed": 0
|
||||
},
|
||||
"graphrag": {
|
||||
"use_graphrag": True,
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
"geo",
|
||||
"event",
|
||||
"category"
|
||||
],
|
||||
"method": "light"
|
||||
}
|
||||
}
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": req["id"]})
|
||||
@@ -141,16 +167,24 @@ async def update(
|
||||
request: UpdateKnowledgeBaseRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not isinstance(request.name, str):
|
||||
"""更新知识库"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if request.name.strip() == "":
|
||||
if req["name"].strip() == "":
|
||||
return get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(request.name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return get_data_error_result(
|
||||
message=f"Dataset name length is {len(request.name)} which is large than {DATASET_NAME_LIMIT}")
|
||||
name = request.name.strip()
|
||||
message=f"Dataset name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
|
||||
req["name"] = req["name"].strip()
|
||||
|
||||
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
|
||||
# 验证不允许的参数
|
||||
not_allowed = ["id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date"]
|
||||
for key in not_allowed:
|
||||
if key in req:
|
||||
del req[key]
|
||||
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
@@ -158,48 +192,29 @@ async def update(
|
||||
)
|
||||
try:
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=request.kb_id):
|
||||
created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(request.kb_id)
|
||||
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
|
||||
if name.lower() != kb.name.lower() \
|
||||
if req["name"].lower() != kb.name.lower() \
|
||||
and len(
|
||||
KnowledgebaseService.query(name=name, tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
|
||||
KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
|
||||
return get_data_error_result(
|
||||
message="Duplicated knowledgebase name.")
|
||||
|
||||
# 构建更新数据,包含所有可更新的字段
|
||||
update_data = {
|
||||
"name": name,
|
||||
"pagerank": request.pagerank
|
||||
}
|
||||
|
||||
# 添加可选字段(如果提供了的话)
|
||||
if request.description is not None:
|
||||
update_data["description"] = request.description
|
||||
if request.permission is not None:
|
||||
update_data["permission"] = request.permission
|
||||
if request.avatar is not None:
|
||||
update_data["avatar"] = request.avatar
|
||||
if request.parser_id is not None:
|
||||
update_data["parser_id"] = request.parser_id
|
||||
if request.embd_id is not None:
|
||||
update_data["embd_id"] = request.embd_id
|
||||
if request.parser_config is not None:
|
||||
update_data["parser_config"] = request.parser_config
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, update_data):
|
||||
kb_id = req.pop("kb_id")
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_data_error_result()
|
||||
|
||||
if kb.pagerank != request.pagerank:
|
||||
if request.pagerank > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: request.pagerank},
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
@@ -211,26 +226,7 @@ async def update(
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase rename)!")
|
||||
kb = kb.to_dict()
|
||||
|
||||
# 使用完整的请求数据更新返回结果,保持与原来代码的一致性
|
||||
request_data = {
|
||||
"name": name,
|
||||
"pagerank": request.pagerank
|
||||
}
|
||||
if request.description is not None:
|
||||
request_data["description"] = request.description
|
||||
if request.permission is not None:
|
||||
request_data["permission"] = request.permission
|
||||
if request.avatar is not None:
|
||||
request_data["avatar"] = request.avatar
|
||||
if request.parser_id is not None:
|
||||
request_data["parser_id"] = request.parser_id
|
||||
if request.embd_id is not None:
|
||||
request_data["embd_id"] = request.embd_id
|
||||
if request.parser_config is not None:
|
||||
request_data["parser_config"] = request.parser_config
|
||||
|
||||
kb.update(request_data)
|
||||
kb.update(req)
|
||||
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
@@ -242,6 +238,7 @@ async def detail(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取知识库详情"""
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
@@ -257,6 +254,9 @@ async def detail(
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
|
||||
for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
|
||||
if finish_at := kb.get(key):
|
||||
kb[key] = finish_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@@ -264,18 +264,22 @@ async def detail(
|
||||
|
||||
@router.post('/list')
|
||||
async def list_kbs(
|
||||
request: ListKnowledgeBasesRequest,
|
||||
keywords: str = Query("", description="关键词"),
|
||||
page: int = Query(0, description="页码"),
|
||||
page_size: int = Query(0, description="每页大小"),
|
||||
parser_id: Optional[str] = Query(None, description="解析器ID"),
|
||||
orderby: str = Query("create_time", description="排序字段"),
|
||||
desc: bool = Query(True, description="是否降序"),
|
||||
query: ListKnowledgeBasesQuery = Depends(),
|
||||
body: Optional[ListKnowledgeBasesBody] = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
page_number = page
|
||||
items_per_page = page_size
|
||||
owner_ids = request.owner_ids
|
||||
"""列出知识库"""
|
||||
if body is None:
|
||||
body = ListKnowledgeBasesBody()
|
||||
|
||||
keywords = query.keywords or ""
|
||||
page_number = int(query.page or 0)
|
||||
items_per_page = int(query.page_size or 0)
|
||||
parser_id = query.parser_id
|
||||
orderby = query.orderby or "create_time"
|
||||
desc = query.desc.lower() == "true" if query.desc else True
|
||||
|
||||
owner_ids = body.owner_ids or [] if body else []
|
||||
try:
|
||||
if not owner_ids:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
@@ -296,11 +300,13 @@ async def list_kbs(
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: DeleteKnowledgeBaseRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除知识库"""
|
||||
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -343,6 +349,7 @@ async def list_tags(
|
||||
kb_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出知识库标签"""
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -353,17 +360,18 @@ async def list_tags(
|
||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
tags = []
|
||||
for tenant in tenants:
|
||||
tags += settings.retrievaler.all_tags(tenant["tenant_id"], [kb_id])
|
||||
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@router.get('/tags')
|
||||
async def list_tags_from_kbs(
|
||||
kb_ids: str = Query(..., description="知识库ID列表,用逗号分隔"),
|
||||
kb_ids: str = Query(..., description="知识库ID列表,逗号分隔"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_ids = kb_ids.split(",")
|
||||
for kb_id in kb_ids:
|
||||
"""从多个知识库列出标签"""
|
||||
kb_id_list = kb_ids.split(",")
|
||||
for kb_id in kb_id_list:
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -374,7 +382,7 @@ async def list_tags_from_kbs(
|
||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
tags = []
|
||||
for tenant in tenants:
|
||||
tags += settings.retrievaler.all_tags(tenant["tenant_id"], kb_ids)
|
||||
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_id_list)
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@@ -384,6 +392,7 @@ async def rm_tags(
|
||||
request: RemoveTagsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除知识库标签"""
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -406,6 +415,7 @@ async def rename_tags(
|
||||
request: RenameTagRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""重命名知识库标签"""
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -426,6 +436,7 @@ async def knowledge_graph(
|
||||
kb_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取知识图谱"""
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -441,7 +452,7 @@ async def knowledge_graph(
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
|
||||
return get_json_result(data=obj)
|
||||
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
if not len(sres.ids):
|
||||
return get_json_result(data=obj)
|
||||
|
||||
@@ -468,6 +479,7 @@ async def delete_knowledge_graph(
|
||||
kb_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除知识图谱"""
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -482,18 +494,19 @@ async def delete_knowledge_graph(
|
||||
|
||||
@router.get("/get_meta")
|
||||
async def get_meta(
|
||||
kb_ids: str = Query(..., description="知识库ID列表,用逗号分隔"),
|
||||
kb_ids: str = Query(..., description="知识库ID列表,逗号分隔"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_ids = kb_ids.split(",")
|
||||
for kb_id in kb_ids:
|
||||
"""获取知识库元数据"""
|
||||
kb_id_list = kb_ids.split(",")
|
||||
for kb_id in kb_id_list:
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
|
||||
return get_json_result(data=DocumentService.get_meta_by_kbs(kb_id_list))
|
||||
|
||||
|
||||
@router.get("/basic_info")
|
||||
@@ -501,6 +514,7 @@ async def get_basic_info(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取知识库基本信息"""
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -515,42 +529,45 @@ async def get_basic_info(
|
||||
|
||||
@router.post("/list_pipeline_logs")
|
||||
async def list_pipeline_logs(
|
||||
request: ListPipelineLogsRequest,
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
keywords: str = Query("", description="关键词"),
|
||||
page: int = Query(0, description="页码"),
|
||||
page_size: int = Query(0, description="每页大小"),
|
||||
orderby: str = Query("create_time", description="排序字段"),
|
||||
desc: bool = Query(True, description="是否降序"),
|
||||
create_date_from: str = Query("", description="创建日期开始"),
|
||||
create_date_to: str = Query("", description="创建日期结束"),
|
||||
query: ListPipelineLogsQuery = Depends(),
|
||||
body: Optional[ListPipelineLogsBody] = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
"""列出流水线日志"""
|
||||
if not query.kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = page
|
||||
items_per_page = page_size
|
||||
|
||||
if body is None:
|
||||
body = ListPipelineLogsBody()
|
||||
|
||||
keywords = query.keywords or ""
|
||||
page_number = int(query.page or 0)
|
||||
items_per_page = int(query.page_size or 0)
|
||||
orderby = query.orderby or "create_time"
|
||||
desc = query.desc.lower() == "true" if query.desc else True
|
||||
create_date_from = query.create_date_from or ""
|
||||
create_date_to = query.create_date_to or ""
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
operation_status = request.operation_status
|
||||
operation_status = body.operation_status or []
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = request.types
|
||||
types = body.types or []
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = request.suffix
|
||||
suffix = body.suffix or []
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(
|
||||
query.kb_id, page_number, items_per_page, orderby, desc, keywords,
|
||||
operation_status, types, suffix, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@@ -558,33 +575,36 @@ async def list_pipeline_logs(
|
||||
|
||||
@router.post("/list_pipeline_dataset_logs")
|
||||
async def list_pipeline_dataset_logs(
|
||||
request: ListPipelineDatasetLogsRequest,
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
page: int = Query(0, description="页码"),
|
||||
page_size: int = Query(0, description="每页大小"),
|
||||
orderby: str = Query("create_time", description="排序字段"),
|
||||
desc: bool = Query(True, description="是否降序"),
|
||||
create_date_from: str = Query("", description="创建日期开始"),
|
||||
create_date_to: str = Query("", description="创建日期结束"),
|
||||
query: ListPipelineDatasetLogsQuery = Depends(),
|
||||
body: Optional[ListPipelineDatasetLogsBody] = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
"""列出流水线数据集日志"""
|
||||
if not query.kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = page
|
||||
items_per_page = page_size
|
||||
|
||||
if body is None:
|
||||
body = ListPipelineDatasetLogsBody()
|
||||
|
||||
page_number = int(query.page or 0)
|
||||
items_per_page = int(query.page_size or 0)
|
||||
orderby = query.orderby or "create_time"
|
||||
desc = query.desc.lower() == "true" if query.desc else True
|
||||
create_date_from = query.create_date_from or ""
|
||||
create_date_to = query.create_date_to or ""
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
operation_status = request.operation_status
|
||||
operation_status = body.operation_status or []
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(
|
||||
query.kb_id, page_number, items_per_page, orderby, desc,
|
||||
operation_status, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@@ -592,14 +612,18 @@ async def list_pipeline_dataset_logs(
|
||||
|
||||
@router.post("/delete_pipeline_logs")
|
||||
async def delete_pipeline_logs(
|
||||
request: DeletePipelineLogsRequest,
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
query: DeletePipelineLogsQuery = Depends(),
|
||||
body: Optional[DeletePipelineLogsBody] = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
"""删除流水线日志"""
|
||||
if not query.kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
log_ids = request.log_ids
|
||||
if body is None:
|
||||
body = DeletePipelineLogsBody(log_ids=[])
|
||||
|
||||
log_ids = body.log_ids or []
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
|
||||
@@ -608,9 +632,10 @@ async def delete_pipeline_logs(
|
||||
|
||||
@router.get("/pipeline_log_detail")
|
||||
async def pipeline_log_detail(
|
||||
log_id: str = Query(..., description="日志ID"),
|
||||
log_id: str = Query(..., description="流水线日志ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取流水线日志详情"""
|
||||
if not log_id:
|
||||
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
@@ -623,9 +648,10 @@ async def pipeline_log_detail(
|
||||
|
||||
@router.post("/run_graphrag")
|
||||
async def run_graphrag(
|
||||
request: RunGraphRAGRequest,
|
||||
request: RunGraphragRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""运行 GraphRAG"""
|
||||
kb_id = request.kb_id
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
@@ -660,7 +686,7 @@ async def run_graphrag(
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||
@@ -673,6 +699,7 @@ async def trace_graphrag(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""追踪 GraphRAG 任务"""
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
@@ -696,6 +723,7 @@ async def run_raptor(
|
||||
request: RunRaptorRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""运行 RAPTOR"""
|
||||
kb_id = request.kb_id
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
@@ -730,7 +758,7 @@ async def run_raptor(
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
|
||||
@@ -743,6 +771,7 @@ async def trace_raptor(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""追踪 RAPTOR 任务"""
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
@@ -766,6 +795,7 @@ async def run_mindmap(
|
||||
request: RunMindmapRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""运行 Mindmap"""
|
||||
kb_id = request.kb_id
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
@@ -800,7 +830,7 @@ async def run_mindmap(
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
|
||||
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
|
||||
@@ -813,6 +843,7 @@ async def trace_mindmap(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""追踪 Mindmap 任务"""
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
@@ -834,33 +865,43 @@ async def trace_mindmap(
|
||||
@router.delete("/unbind_task")
|
||||
async def delete_kb_task(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
pipeline_task_type: str = Query(..., description="管道任务类型"),
|
||||
pipeline_task_type: str = Query(..., description="流水线任务类型"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""解绑任务"""
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
return get_error_data_result(message="Invalid task type")
|
||||
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
kb_task_id = "graphrag_task_id"
|
||||
kb_task_id_field = "graphrag_task_id"
|
||||
task_id = kb.graphrag_task_id
|
||||
kb_task_finish_at = "graphrag_task_finish_at"
|
||||
case PipelineTaskType.RAPTOR:
|
||||
kb_task_id = "raptor_task_id"
|
||||
kb_task_id_field = "raptor_task_id"
|
||||
task_id = kb.raptor_task_id
|
||||
kb_task_finish_at = "raptor_task_finish_at"
|
||||
case PipelineTaskType.MINDMAP:
|
||||
kb_task_id = "mindmap_task_id"
|
||||
kb_task_id_field = "mindmap_task_id"
|
||||
task_id = kb.mindmap_task_id
|
||||
kb_task_finish_at = "mindmap_task_finish_at"
|
||||
case _:
|
||||
return get_error_data_result(message="Internal Error: Invalid task type")
|
||||
|
||||
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id: "", kb_task_finish_at: None})
|
||||
def cancel_task(task_id):
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
cancel_task(task_id)
|
||||
|
||||
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id_field: "", kb_task_finish_at: None})
|
||||
if not ok:
|
||||
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@@ -15,26 +15,36 @@
|
||||
#
|
||||
import logging
|
||||
import json
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.llm_models import (
|
||||
SetApiKeyRequest,
|
||||
AddLLMRequest,
|
||||
DeleteLLMRequest,
|
||||
DeleteFactoryRequest,
|
||||
MyLLMsQuery,
|
||||
ListLLMsQuery,
|
||||
)
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api import settings
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
from api.models.llm_models import SetApiKeyRequest, AddLLMRequest, DeleteLLMRequest, DeleteFactoryRequest
|
||||
from api.utils.api_utils import get_current_user
|
||||
|
||||
# 创建 FastAPI 路由器
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get('/factories')
|
||||
async def factories(current_user = Depends(get_current_user)):
|
||||
async def factories(
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取 LLM 工厂列表"""
|
||||
try:
|
||||
fac = LLMFactoriesService.get_all()
|
||||
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
|
||||
@@ -55,17 +65,22 @@ async def factories(current_user = Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post('/set_api_key')
|
||||
async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_current_user)):
|
||||
async def set_api_key(
|
||||
request: SetApiKeyRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""设置 API Key"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = request.llm_factory
|
||||
factory = req["llm_factory"]
|
||||
extra = {"provider": factory}
|
||||
msg = ""
|
||||
for llm in LLMService.query(fid=factory):
|
||||
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||
mdl = EmbeddingModel[factory](
|
||||
request.api_key, llm.llm_name, base_url=request.base_url)
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0:
|
||||
@@ -76,7 +91,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
|
||||
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||
mdl = ChatModel[factory](
|
||||
request.api_key, llm.llm_name, base_url=request.base_url, **extra)
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""), **extra)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
|
||||
{"temperature": 0.9, 'max_tokens': 50})
|
||||
@@ -89,7 +104,7 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
|
||||
elif not rerank_passed and llm.model_type == LLMType.RERANK:
|
||||
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
|
||||
mdl = RerankModel[factory](
|
||||
request.api_key, llm.llm_name, base_url=request.base_url)
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url", ""))
|
||||
try:
|
||||
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
||||
if len(arr) == 0 or tc == 0:
|
||||
@@ -107,9 +122,12 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
|
||||
return get_data_error_result(message=msg)
|
||||
|
||||
llm_config = {
|
||||
"api_key": request.api_key,
|
||||
"api_base": request.base_url or ""
|
||||
"api_key": req["api_key"],
|
||||
"api_base": req.get("base_url", "")
|
||||
}
|
||||
for n in ["model_type", "llm_name"]:
|
||||
if n in req:
|
||||
llm_config[n] = req[n]
|
||||
|
||||
for llm in LLMService.query(fid=factory):
|
||||
llm_config["max_tokens"]=llm.max_tokens
|
||||
@@ -132,14 +150,19 @@ async def set_api_key(request: SetApiKeyRequest, current_user = Depends(get_curr
|
||||
|
||||
|
||||
@router.post('/add_llm')
|
||||
async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_user)):
|
||||
factory = request.llm_factory
|
||||
api_key = request.api_key or "x"
|
||||
llm_name = request.llm_name
|
||||
async def add_llm(
|
||||
request: AddLLMRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""添加 LLM"""
|
||||
req = request.model_dump(exclude_unset=True)
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
|
||||
def apikey_json(keys):
|
||||
nonlocal request
|
||||
return json.dumps({k: getattr(request, k, "") for k in keys})
|
||||
nonlocal req
|
||||
return json.dumps({k: req.get(k, "") for k in keys})
|
||||
|
||||
if factory == "VolcEngine":
|
||||
# For VolcEngine, due to its special authentication method
|
||||
@@ -147,21 +170,28 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
|
||||
api_key = apikey_json(["ark_api_key", "endpoint_id"])
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
# Create a temporary request object for set_api_key
|
||||
temp_request = SetApiKeyRequest(
|
||||
llm_factory=factory,
|
||||
api_key=apikey_json(["hunyuan_sid", "hunyuan_sk"]),
|
||||
base_url=request.api_base
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
# 创建 SetApiKeyRequest 并调用 set_api_key 逻辑
|
||||
set_api_key_req = SetApiKeyRequest(
|
||||
llm_factory=req["llm_factory"],
|
||||
api_key=req["api_key"],
|
||||
base_url=req.get("api_base", req.get("base_url", "")),
|
||||
model_type=req.get("model_type"),
|
||||
llm_name=req.get("llm_name")
|
||||
)
|
||||
return await set_api_key(temp_request, current_user)
|
||||
return await set_api_key(set_api_key_req, current_user)
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
temp_request = SetApiKeyRequest(
|
||||
llm_factory=factory,
|
||||
api_key=apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]),
|
||||
base_url=request.api_base
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
# 创建 SetApiKeyRequest 并调用 set_api_key 逻辑
|
||||
set_api_key_req = SetApiKeyRequest(
|
||||
llm_factory=req["llm_factory"],
|
||||
api_key=req["api_key"],
|
||||
base_url=req.get("api_base", req.get("base_url", "")),
|
||||
model_type=req.get("model_type"),
|
||||
llm_name=req.get("llm_name")
|
||||
)
|
||||
return await set_api_key(temp_request, current_user)
|
||||
return await set_api_key(set_api_key_req, current_user)
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
@@ -181,9 +211,9 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
|
||||
llm_name += "___VLLM"
|
||||
|
||||
elif factory == "XunFei Spark":
|
||||
if request.model_type == "chat":
|
||||
api_key = request.spark_api_password or ""
|
||||
elif request.model_type == "tts":
|
||||
if req["model_type"] == "chat":
|
||||
api_key = req.get("spark_api_password", "")
|
||||
elif req["model_type"] == "tts":
|
||||
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
|
||||
|
||||
elif factory == "BaiduYiyan":
|
||||
@@ -198,14 +228,17 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
|
||||
elif factory == "Azure-OpenAI":
|
||||
api_key = apikey_json(["api_key", "api_version"])
|
||||
|
||||
elif factory == "OpenRouter":
|
||||
api_key = apikey_json(["api_key", "provider_order"])
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": factory,
|
||||
"model_type": request.model_type,
|
||||
"model_type": req["model_type"],
|
||||
"llm_name": llm_name,
|
||||
"api_base": request.api_base or "",
|
||||
"api_base": req.get("api_base", ""),
|
||||
"api_key": api_key,
|
||||
"max_tokens": request.max_tokens
|
||||
"max_tokens": req.get("max_tokens")
|
||||
}
|
||||
|
||||
msg = ""
|
||||
@@ -295,7 +328,11 @@ async def add_llm(request: AddLLMRequest, current_user = Depends(get_current_use
|
||||
|
||||
|
||||
@router.post('/delete_llm')
|
||||
async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_current_user)):
|
||||
async def delete_llm(
|
||||
request: DeleteLLMRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除 LLM"""
|
||||
TenantLLMService.filter_delete(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == request.llm_factory,
|
||||
TenantLLM.llm_name == request.llm_name])
|
||||
@@ -303,7 +340,11 @@ async def delete_llm(request: DeleteLLMRequest, current_user = Depends(get_curre
|
||||
|
||||
|
||||
@router.post('/delete_factory')
async def delete_factory(
    request: DeleteFactoryRequest,
    current_user = Depends(get_current_user)
):
    """Delete the current tenant's LLM entries for one factory.

    Calls ``TenantLLMService.filter_delete`` with two conditions: the entry's
    tenant is the current user and its factory equals ``request.llm_factory``.
    Always returns ``get_json_result(data=True)``.
    """
    owned_by_tenant = TenantLLM.tenant_id == current_user.id
    from_factory = TenantLLM.llm_factory == request.llm_factory
    TenantLLMService.filter_delete([owned_by_tenant, from_factory])
    return get_json_result(data=True)
|
||||
@@ -311,10 +352,13 @@ async def delete_factory(request: DeleteFactoryRequest, current_user = Depends(g
|
||||
|
||||
@router.get('/my_llms')
|
||||
async def my_llms(
|
||||
include_details: bool = Query(False, description="是否包含详细信息"),
|
||||
query: MyLLMsQuery = Depends(),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取我的 LLMs"""
|
||||
try:
|
||||
include_details = query.include_details.lower() == 'true'
|
||||
|
||||
if include_details:
|
||||
res = {}
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
@@ -362,11 +406,13 @@ async def my_llms(
|
||||
|
||||
@router.get('/list')
|
||||
async def list_app(
|
||||
model_type: Optional[str] = Query(None, description="模型类型"),
|
||||
query: ListLLMsQuery = Depends(),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出 LLMs"""
|
||||
self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
|
||||
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
|
||||
model_type = query.model_type
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
|
||||
|
||||
@@ -13,164 +13,61 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from api.utils.api_utils import security
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from api.apps.models.mcp_models import (
|
||||
ListMCPServersQuery,
|
||||
ListMCPServersBody,
|
||||
CreateMCPServerRequest,
|
||||
UpdateMCPServerRequest,
|
||||
DeleteMCPServersRequest,
|
||||
ImportMCPServersRequest,
|
||||
ExportMCPServersRequest,
|
||||
ListMCPToolsRequest,
|
||||
TestMCPToolRequest,
|
||||
CacheMCPToolsRequest,
|
||||
TestMCPRequest,
|
||||
)
|
||||
|
||||
from api import settings
|
||||
from api.db import VALID_MCP_SERVER_TYPES
|
||||
from api.db.db_models import MCPServer, User
|
||||
from api.db.db_models import MCPServer
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import RetCode
|
||||
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_mcp_tools
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from api.utils.web_utils import safe_json_parse
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Security
|
||||
|
||||
# Pydantic models for request/response
|
||||
class ListMCPRequest(BaseModel):
|
||||
mcp_ids: List[str] = []
|
||||
|
||||
class CreateMCPRequest(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
server_type: str
|
||||
headers: Optional[Dict[str, Any]] = {}
|
||||
variables: Optional[Dict[str, Any]] = {}
|
||||
timeout: Optional[float] = 10
|
||||
|
||||
class UpdateMCPRequest(BaseModel):
|
||||
mcp_id: str
|
||||
name: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
server_type: Optional[str] = None
|
||||
headers: Optional[Dict[str, Any]] = None
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
timeout: Optional[float] = 10
|
||||
|
||||
class RemoveMCPRequest(BaseModel):
|
||||
mcp_ids: List[str]
|
||||
|
||||
class ImportMCPRequest(BaseModel):
|
||||
mcpServers: Dict[str, Dict[str, Any]]
|
||||
timeout: Optional[float] = 10
|
||||
|
||||
class ExportMCPRequest(BaseModel):
|
||||
mcp_ids: List[str]
|
||||
|
||||
|
||||
class ListToolsRequest(BaseModel):
|
||||
mcp_ids: List[str]
|
||||
timeout: Optional[float] = 10
|
||||
|
||||
|
||||
class TestToolRequest(BaseModel):
|
||||
mcp_id: str
|
||||
tool_name: str
|
||||
arguments: Dict[str, Any]
|
||||
timeout: Optional[float] = 10
|
||||
|
||||
|
||||
class CacheToolsRequest(BaseModel):
|
||||
mcp_id: str
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class TestMCPRequest(BaseModel):
|
||||
url: str
|
||||
server_type: str
|
||||
timeout: Optional[float] = 10
|
||||
headers: Optional[Dict[str, Any]] = {}
|
||||
variables: Optional[Dict[str, Any]] = {}
|
||||
|
||||
# Dependency injection
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.user_service import UserService
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
# Create router
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/list")
|
||||
async def list_mcp(
|
||||
request: ListMCPRequest,
|
||||
keywords: str = Query("", description="Search keywords"),
|
||||
page: int = Query(0, description="Page number"),
|
||||
page_size: int = Query(0, description="Items per page"),
|
||||
orderby: str = Query("create_time", description="Order by field"),
|
||||
desc: bool = Query(True, description="Sort descending"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
query: ListMCPServersQuery = Depends(),
|
||||
body: ListMCPServersBody = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出MCP服务器"""
|
||||
if body is None:
|
||||
body = ListMCPServersBody()
|
||||
|
||||
keywords = query.keywords or ""
|
||||
page_number = int(query.page or 0)
|
||||
items_per_page = int(query.page_size or 0)
|
||||
orderby = query.orderby or "create_time"
|
||||
desc = query.desc.lower() == "true" if query.desc else True
|
||||
|
||||
mcp_ids = body.mcp_ids or []
|
||||
try:
|
||||
servers = MCPServerService.get_servers(current_user.id, request.mcp_ids, 0, 0, orderby, desc, keywords) or []
|
||||
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
|
||||
total = len(servers)
|
||||
|
||||
if page and page_size:
|
||||
servers = servers[(page - 1) * page_size : page * page_size]
|
||||
if page_number and items_per_page:
|
||||
servers = servers[(page_number - 1) * items_per_page : page_number * items_per_page]
|
||||
|
||||
return get_json_result(data={"mcp_servers": servers, "total": total})
|
||||
except Exception as e:
|
||||
@@ -179,9 +76,10 @@ async def list_mcp(
|
||||
|
||||
@router.get("/detail")
|
||||
async def detail(
|
||||
mcp_id: str = Query(..., description="MCP server ID"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
mcp_id: str = Query(..., description="MCP服务器ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""获取MCP服务器详情"""
|
||||
try:
|
||||
mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id)
|
||||
|
||||
@@ -195,9 +93,10 @@ async def detail(
|
||||
|
||||
@router.post("/create")
|
||||
async def create(
|
||||
request: CreateMCPRequest,
|
||||
current_user: User = Depends(get_current_user)
|
||||
request: CreateMCPServerRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""创建MCP服务器"""
|
||||
server_type = request.server_type
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
return get_data_error_result(message="Unsupported MCP server type.")
|
||||
@@ -218,10 +117,10 @@ async def create(
|
||||
variables = safe_json_parse(request.variables or {})
|
||||
variables.pop("tools", None)
|
||||
|
||||
timeout = request.timeout or 10
|
||||
timeout = request.timeout or 10.0
|
||||
|
||||
try:
|
||||
req_data = {
|
||||
req = {
|
||||
"id": get_uuid(),
|
||||
"tenant_id": current_user.id,
|
||||
"name": server_name,
|
||||
@@ -229,7 +128,6 @@ async def create(
|
||||
"server_type": server_type,
|
||||
"headers": headers,
|
||||
"variables": variables,
|
||||
"timeout": timeout
|
||||
}
|
||||
|
||||
e, _ = TenantService.get_by_id(current_user.id)
|
||||
@@ -244,46 +142,45 @@ async def create(
|
||||
tools = server_tools[server_name]
|
||||
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
|
||||
variables["tools"] = tools
|
||||
req_data["variables"] = variables
|
||||
req["variables"] = variables
|
||||
|
||||
if not MCPServerService.insert(**req_data):
|
||||
if not MCPServerService.insert(**req):
|
||||
return get_data_error_result("Failed to create MCP server.")
|
||||
|
||||
return get_json_result(data=req_data)
|
||||
return get_json_result(data=req)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/update")
|
||||
async def update(
|
||||
request: UpdateMCPRequest,
|
||||
current_user: User = Depends(get_current_user)
|
||||
request: UpdateMCPServerRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""更新MCP服务器"""
|
||||
mcp_id = request.mcp_id
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
if not e or mcp_server.tenant_id != current_user.id:
|
||||
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||
|
||||
server_type = request.server_type or mcp_server.server_type
|
||||
server_type = request.server_type if request.server_type is not None else mcp_server.server_type
|
||||
if server_type and server_type not in VALID_MCP_SERVER_TYPES:
|
||||
return get_data_error_result(message="Unsupported MCP server type.")
|
||||
|
||||
server_name = request.name or mcp_server.name
|
||||
server_name = request.name if request.name is not None else mcp_server.name
|
||||
if server_name and len(server_name.encode("utf-8")) > 255:
|
||||
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
|
||||
|
||||
url = request.url or mcp_server.url
|
||||
url = request.url if request.url is not None else mcp_server.url
|
||||
if not url:
|
||||
return get_data_error_result(message="Invalid url.")
|
||||
|
||||
headers = safe_json_parse(request.headers or mcp_server.headers)
|
||||
variables = safe_json_parse(request.variables or mcp_server.variables)
|
||||
headers = safe_json_parse(request.headers if request.headers is not None else mcp_server.headers)
|
||||
variables = safe_json_parse(request.variables if request.variables is not None else mcp_server.variables)
|
||||
variables.pop("tools", None)
|
||||
|
||||
timeout = request.timeout or 10
|
||||
timeout = request.timeout or 10.0
|
||||
|
||||
try:
|
||||
req_data = {
|
||||
req = {
|
||||
"tenant_id": current_user.id,
|
||||
"id": mcp_id,
|
||||
"name": server_name,
|
||||
@@ -291,7 +188,6 @@ async def update(
|
||||
"server_type": server_type,
|
||||
"headers": headers,
|
||||
"variables": variables,
|
||||
"timeout": timeout
|
||||
}
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
@@ -302,12 +198,12 @@ async def update(
|
||||
tools = server_tools[server_name]
|
||||
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
|
||||
variables["tools"] = tools
|
||||
req_data["variables"] = variables
|
||||
req["variables"] = variables
|
||||
|
||||
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req_data):
|
||||
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req):
|
||||
return get_data_error_result(message="Failed to updated MCP server.")
|
||||
|
||||
e, updated_mcp = MCPServerService.get_by_id(req_data["id"])
|
||||
e, updated_mcp = MCPServerService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Failed to fetch updated MCP server.")
|
||||
|
||||
@@ -318,9 +214,10 @@ async def update(
|
||||
|
||||
@router.post("/rm")
|
||||
async def rm(
|
||||
request: RemoveMCPRequest,
|
||||
current_user: User = Depends(get_current_user)
|
||||
request: DeleteMCPServersRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""删除MCP服务器"""
|
||||
mcp_ids = request.mcp_ids
|
||||
|
||||
try:
|
||||
@@ -334,14 +231,15 @@ async def rm(
|
||||
|
||||
@router.post("/import")
|
||||
async def import_multiple(
|
||||
request: ImportMCPRequest,
|
||||
current_user: User = Depends(get_current_user)
|
||||
request: ImportMCPServersRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""批量导入MCP服务器"""
|
||||
servers = request.mcpServers
|
||||
if not servers:
|
||||
return get_data_error_result(message="No MCP servers provided.")
|
||||
|
||||
timeout = request.timeout or 10
|
||||
timeout = request.timeout or 10.0
|
||||
|
||||
results = []
|
||||
try:
|
||||
@@ -401,9 +299,10 @@ async def import_multiple(
|
||||
|
||||
@router.post("/export")
|
||||
async def export_multiple(
|
||||
request: ExportMCPRequest,
|
||||
current_user: User = Depends(get_current_user)
|
||||
request: ExportMCPServersRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""批量导出MCP服务器"""
|
||||
mcp_ids = request.mcp_ids
|
||||
|
||||
if not mcp_ids:
|
||||
@@ -432,12 +331,16 @@ async def export_multiple(
|
||||
|
||||
|
||||
@router.post("/list_tools")
|
||||
async def list_tools(req: ListToolsRequest, current_user: User = Depends(get_current_user)):
|
||||
mcp_ids = req.mcp_ids
|
||||
async def list_tools(
|
||||
request: ListMCPToolsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""列出MCP工具"""
|
||||
mcp_ids = request.mcp_ids
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
|
||||
timeout = req.timeout
|
||||
timeout = request.timeout or 10.0
|
||||
|
||||
results = {}
|
||||
tool_call_sessions = []
|
||||
@@ -476,15 +379,19 @@ async def list_tools(req: ListToolsRequest, current_user: User = Depends(get_cur
|
||||
|
||||
|
||||
@router.post("/test_tool")
|
||||
async def test_tool(req: TestToolRequest, current_user: User = Depends(get_current_user)):
|
||||
mcp_id = req.mcp_id
|
||||
async def test_tool(
|
||||
request: TestMCPToolRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""测试MCP工具"""
|
||||
mcp_id = request.mcp_id
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
|
||||
timeout = req.timeout
|
||||
timeout = request.timeout or 10.0
|
||||
|
||||
tool_name = req.tool_name
|
||||
arguments = req.arguments
|
||||
tool_name = request.tool_name
|
||||
arguments = request.arguments
|
||||
if not all([tool_name, arguments]):
|
||||
return get_data_error_result(message="Require provide tool name and arguments.")
|
||||
|
||||
@@ -506,11 +413,15 @@ async def test_tool(req: TestToolRequest, current_user: User = Depends(get_curre
|
||||
|
||||
|
||||
@router.post("/cache_tools")
|
||||
async def cache_tool(req: CacheToolsRequest, current_user: User = Depends(get_current_user)):
|
||||
mcp_id = req.mcp_id
|
||||
async def cache_tool(
|
||||
request: CacheMCPToolsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
"""缓存MCP工具"""
|
||||
mcp_id = request.mcp_id
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
tools = req.tools
|
||||
tools = request.tools
|
||||
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
if not e or mcp_server.tenant_id != current_user.id:
|
||||
@@ -527,18 +438,21 @@ async def cache_tool(req: CacheToolsRequest, current_user: User = Depends(get_cu
|
||||
|
||||
|
||||
@router.post("/test_mcp")
|
||||
async def test_mcp(req: TestMCPRequest):
|
||||
url = req.url
|
||||
async def test_mcp(
|
||||
request: TestMCPRequest
|
||||
):
|
||||
"""测试MCP服务器(不需要登录)"""
|
||||
url = request.url
|
||||
if not url:
|
||||
return get_data_error_result(message="Invalid MCP url.")
|
||||
|
||||
server_type = req.server_type
|
||||
server_type = request.server_type
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
return get_data_error_result(message="Unsupported MCP server type.")
|
||||
|
||||
timeout = req.timeout
|
||||
headers = req.headers
|
||||
variables = req.variables
|
||||
timeout = request.timeout or 10.0
|
||||
headers = safe_json_parse(request.headers or {})
|
||||
variables = safe_json_parse(request.variables or {})
|
||||
|
||||
mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables)
|
||||
|
||||
|
||||
53
api/apps/models/auth_dependencies.py
Normal file
53
api/apps/models/auth_dependencies.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import Depends, Header, Security, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from api import settings
|
||||
from api.utils.api_utils import get_json_result
|
||||
|
||||
# 创建 HTTPBearer 安全方案(auto_error=False 允许我们自定义错误处理)
|
||||
http_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Security(http_bearer)):
    """FastAPI dependency that resolves the bearer token to the current user.

    Replaces Flask's ``login_required`` / ``current_user``. Declaring the
    parameter with ``Security(http_bearer)`` lets FastAPI add the security
    requirement to the OpenAPI schema, so Swagger UI shows an authorization
    input and sends the ``Authorization`` header with requests.

    Args:
        credentials: Parsed ``Authorization`` header. May be ``None`` because
            ``http_bearer`` is created with ``auto_error=False``.

    Returns:
        The user object returned by ``get_current_user_from_token``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            no credentials were supplied or the token resolves to no user.
    """
    # Imported lazily to avoid a circular import.
    from api.apps.__init___fastapi import get_current_user_from_token

    # A 401 response must carry a challenge telling the client which
    # authentication scheme to use.
    challenge = {"WWW-Authenticate": "Bearer"}

    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header is required",
            headers=challenge,
        )

    # HTTPBearer has already stripped the "Bearer " prefix;
    # credentials.credentials is the raw token.
    user = get_current_user_from_token(credentials.credentials)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers=challenge,
        )

    return user
|
||||
|
||||
129
api/apps/models/canvas_models.py
Normal file
129
api/apps/models/canvas_models.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DeleteCanvasRequest(BaseModel):
    """Request for deleting canvases."""

    # Required list of canvas IDs.
    canvas_ids: List[str] = Field(description="画布ID列表")
|
||||
|
||||
|
||||
class SaveCanvasRequest(BaseModel):
    """Request for saving a canvas."""

    # Required: the DSL (string or mapping) and the canvas title.
    dsl: Union[str, Dict[str, Any]] = Field(description="DSL配置")
    title: str = Field(description="画布标题")

    # Optional; each defaults to None. ``id`` is supplied when updating.
    id: Optional[str] = Field(None, description="画布ID(更新时提供)")
    canvas_category: Optional[str] = Field(None, description="画布类别")
    description: Optional[str] = Field(None, description="描述")
    permission: Optional[str] = Field(None, description="权限")
    avatar: Optional[str] = Field(None, description="头像")
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
    """Request for completing / running a canvas."""

    # Required canvas ID.
    id: str = Field(description="画布ID")

    # Optional inputs with their defaults.
    query: Optional[str] = Field("", description="查询内容")
    files: Optional[List[str]] = Field([], description="文件列表")
    inputs: Optional[Dict[str, Any]] = Field({}, description="输入参数")
    user_id: Optional[str] = Field(None, description="用户ID")
|
||||
|
||||
|
||||
class RerunRequest(BaseModel):
    """Request for a re-run."""

    # All three fields are required.
    id: str = Field(description="流水线ID")
    dsl: Dict[str, Any] = Field(description="DSL配置")
    component_id: str = Field(description="组件ID")
|
||||
|
||||
|
||||
class ResetCanvasRequest(BaseModel):
    """Request for resetting a canvas."""

    # Required canvas ID.
    id: str = Field(description="画布ID")
|
||||
|
||||
|
||||
class InputFormQuery(BaseModel):
    """Query parameters for fetching an input form."""

    # Both identifiers are required.
    id: str = Field(description="画布ID")
    component_id: str = Field(description="组件ID")
|
||||
|
||||
|
||||
class DebugRequest(BaseModel):
    """Debug request."""

    # All three fields are required.
    id: str = Field(description="画布ID")
    component_id: str = Field(description="组件ID")
    params: Dict[str, Any] = Field(description="参数")
|
||||
|
||||
|
||||
class TestDBConnectRequest(BaseModel):
    """Request for testing a database connection."""

    # Every field is a required string (including ``port``).
    db_type: str = Field(description="数据库类型")
    database: str = Field(description="数据库名")
    username: str = Field(description="用户名")
    host: str = Field(description="主机")
    port: str = Field(description="端口")
    password: str = Field(description="密码")
|
||||
|
||||
|
||||
class ListCanvasQuery(BaseModel):
    """Query parameters for listing canvases."""

    keywords: Optional[str] = Field("", description="关键词")

    # Paging; both default to 0.
    page: Optional[int] = Field(0, description="页码")
    page_size: Optional[int] = Field(0, description="每页大小")

    # Ordering; ``desc`` is carried as a string, default "true".
    orderby: Optional[str] = Field("create_time", description="排序字段")
    desc: Optional[str] = Field("true", description="是否降序")

    # Filters; ``owner_ids`` is a single comma-separated string.
    canvas_category: Optional[str] = Field(None, description="画布类别")
    owner_ids: Optional[str] = Field("", description="所有者ID列表(逗号分隔)")
|
||||
|
||||
|
||||
class SettingRequest(BaseModel):
    """Canvas settings request."""

    # Required fields.
    id: str = Field(description="画布ID")
    title: str = Field(description="标题")
    permission: str = Field(description="权限")

    # Optional fields, None when omitted.
    description: Optional[str] = Field(None, description="描述")
    avatar: Optional[str] = Field(None, description="头像")
|
||||
|
||||
|
||||
class TraceQuery(BaseModel):
    """Query parameters for a trace lookup."""

    # Both identifiers are required.
    canvas_id: str = Field(description="画布ID")
    message_id: str = Field(description="消息ID")
|
||||
|
||||
|
||||
class ListSessionsQuery(BaseModel):
    """Query parameters for listing sessions."""

    user_id: Optional[str] = Field(None, description="用户ID")

    # Paging defaults: page 1, 30 items per page.
    page: Optional[int] = Field(1, description="页码")
    page_size: Optional[int] = Field(30, description="每页大小")

    # Filters; all default to None.
    keywords: Optional[str] = Field(None, description="关键词")
    from_date: Optional[str] = Field(None, description="开始日期")
    to_date: Optional[str] = Field(None, description="结束日期")

    # Ordering and flags; the flags are strings defaulting to "true".
    orderby: Optional[str] = Field("update_time", description="排序字段")
    desc: Optional[str] = Field("true", description="是否降序")
    dsl: Optional[str] = Field("true", description="是否包含DSL")
|
||||
|
||||
|
||||
class DownloadQuery(BaseModel):
    """Query parameters for a download."""

    # Both fields are required.
    id: str = Field(description="文件ID")
    created_by: str = Field(description="创建者ID")
|
||||
|
||||
|
||||
class UploadQuery(BaseModel):
    """Query parameters for an upload."""

    # Optional URL (per the field description, used to download from a URL).
    url: Optional[str] = Field(None, description="URL(可选,用于从URL下载)")
|
||||
|
||||
80
api/apps/models/chunk_models.py
Normal file
80
api/apps/models/chunk_models.py
Normal file
@@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, List, Any, Union
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class ListChunksRequest(BaseModel):
    """Request for listing document chunks."""

    # Required document ID.
    doc_id: str = Field(description="文档ID")

    # Paging defaults: page 1, 30 per page.
    page: Optional[int] = Field(1, description="页码")
    size: Optional[int] = Field(30, description="每页大小")

    # Optional filters.
    keywords: Optional[str] = Field("", description="关键词")
    available_int: Optional[int] = Field(None, description="可用状态")
|
||||
|
||||
|
||||
class SetChunkRequest(BaseModel):
    """Request for setting (updating) a document chunk."""

    # Required fields.
    doc_id: str = Field(description="文档ID")
    chunk_id: str = Field(description="块ID")
    content_with_weight: str = Field(description="内容")

    # Optional fields; all default to None.
    important_kwd: Optional[List[str]] = Field(None, description="重要关键词列表")
    question_kwd: Optional[List[str]] = Field(None, description="问题关键词列表")
    tag_kwd: Optional[str] = Field(None, description="标签关键词")
    tag_feas: Optional[Any] = Field(None, description="标签特征")
    available_int: Optional[int] = Field(None, description="可用状态")
|
||||
|
||||
|
||||
class SwitchChunksRequest(BaseModel):
    """Request for switching the availability status of chunks."""

    # All three fields are required.
    doc_id: str = Field(description="文档ID")
    chunk_ids: List[str] = Field(description="块ID列表")
    available_int: int = Field(description="可用状态")
|
||||
|
||||
|
||||
class DeleteChunksRequest(BaseModel):
    """Request for deleting document chunks."""

    # Both fields are required.
    doc_id: str = Field(description="文档ID")
    chunk_ids: List[str] = Field(description="块ID列表")
|
||||
|
||||
|
||||
class CreateChunkRequest(BaseModel):
    """Request for creating a document chunk."""

    # Required fields.
    doc_id: str = Field(description="文档ID")
    content_with_weight: str = Field(description="内容")

    # Optional fields; the keyword lists default to empty lists.
    important_kwd: Optional[List[str]] = Field([], description="重要关键词列表")
    question_kwd: Optional[List[str]] = Field([], description="问题关键词列表")
    tag_feas: Optional[Any] = Field(None, description="标签特征")
|
||||
|
||||
|
||||
class RetrievalTestRequest(BaseModel):
    """Request for a retrieval test."""

    # Required: one knowledge-base ID or a list of them, plus the question.
    kb_id: Union[str, List[str]] = Field(description="知识库ID,可以是字符串或列表")
    question: str = Field(description="问题")

    # Paging defaults: page 1, 30 per page.
    page: Optional[int] = Field(1, description="页码")
    size: Optional[int] = Field(30, description="每页大小")

    # Optional retrieval settings with their defaults.
    doc_ids: Optional[List[str]] = Field([], description="文档ID列表")
    use_kg: Optional[bool] = Field(False, description="是否使用知识图谱")
    top_k: Optional[int] = Field(1024, description="Top K")
    cross_languages: Optional[List[str]] = Field([], description="跨语言列表")
    search_id: Optional[str] = Field("", description="搜索ID")
    rerank_id: Optional[str] = Field(None, description="重排序模型ID")
    keyword: Optional[bool] = Field(False, description="是否使用关键词")
    similarity_threshold: Optional[float] = Field(0.0, description="相似度阈值")
    vector_similarity_weight: Optional[float] = Field(0.3, description="向量相似度权重")
    highlight: Optional[bool] = Field(False, description="是否高亮")
|
||||
|
||||
204
api/apps/models/document_models.py
Normal file
204
api/apps/models/document_models.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, Literal, List
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class CreateDocumentRequest(BaseModel):
    """Request for creating a document.

    Two parse types are supported:
    - parse_type=1: built-in parser; parser_id is required, pipeline_id must be empty
    - parse_type=2: custom pipeline; pipeline_id is required, parser_id must be empty
    When parse_type is not provided, the parsing configuration is inherited
    from the knowledge base.
    """

    name: str
    kb_id: str
    parse_type: Optional[Literal[1, 2]] = Field(None, description="解析类型:1=内置解析器,2=自定义pipeline,None=从知识库继承")
    parser_id: Optional[str] = Field("", description="解析器ID,parse_type=1时必需")
    pipeline_id: Optional[str] = Field("", description="流水线ID,parse_type=2时必需")
    parser_config: Optional[dict] = None

    @model_validator(mode="after")
    def validate_parse_type_fields(self):
        """Check parser_id / pipeline_id against the selected parse_type."""
        # Nothing to check when the parse type is left unset.
        if self.parse_type is None:
            return self

        # None and whitespace-only values both count as "empty".
        has_parser = (self.parser_id or "").strip() != ""
        has_pipeline = (self.pipeline_id or "").strip() != ""

        if self.parse_type == 1:
            # Built-in parser: parser_id required, pipeline_id forbidden.
            if not has_parser:
                raise ValueError("parse_type=1时,parser_id不能为空")
            if has_pipeline:
                raise ValueError("parse_type=1时,pipeline_id必须为空")
        elif self.parse_type == 2:
            # Custom pipeline: pipeline_id required, parser_id forbidden.
            if not has_pipeline:
                raise ValueError("parse_type=2时,pipeline_id不能为空")
            if has_parser:
                raise ValueError("parse_type=2时,parser_id必须为空")
        return self
|
||||
|
||||
|
||||
class ChangeParserRequest(BaseModel):
    """Request for changing a document's parser.

    Two parse types are supported:
    - parse_type=1: built-in parser; parser_id is required, pipeline_id must be empty
    - parse_type=2: custom pipeline; pipeline_id is required, parser_id must be empty
    """

    doc_id: str
    parse_type: Literal[1, 2] = Field(description="解析类型:1=内置解析器,2=自定义pipeline")
    parser_id: Optional[str] = Field("", description="解析器ID,parse_type=1时必需")
    pipeline_id: Optional[str] = Field("", description="流水线ID,parse_type=2时必需")
    parser_config: Optional[dict] = None

    @model_validator(mode="after")
    def validate_parse_type_fields(self):
        """Check parser_id / pipeline_id against the selected parse_type."""
        # None and whitespace-only values both count as "empty".
        has_parser = (self.parser_id or "").strip() != ""
        has_pipeline = (self.pipeline_id or "").strip() != ""

        if self.parse_type == 1:
            # Built-in parser: parser_id required, pipeline_id forbidden.
            if not has_parser:
                raise ValueError("parse_type=1时,parser_id不能为空")
            if has_pipeline:
                raise ValueError("parse_type=1时,pipeline_id必须为空")
        elif self.parse_type == 2:
            # Custom pipeline: pipeline_id required, parser_id forbidden.
            if not has_pipeline:
                raise ValueError("parse_type=2时,pipeline_id不能为空")
            if has_parser:
                raise ValueError("parse_type=2时,parser_id必须为空")
        return self
|
||||
|
||||
|
||||
class WebCrawlRequest(BaseModel):
    """Web crawl request."""

    # All three fields are required strings.
    kb_id: str = Field(...)
    name: str = Field(...)
    url: str = Field(...)
|
||||
|
||||
|
||||
class ListDocumentsQuery(BaseModel):
    """Query parameters for listing documents."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)

    keywords: Optional[str] = Field(default="")

    # Paging; both default to 0.
    page: Optional[int] = Field(default=0)
    page_size: Optional[int] = Field(default=0)

    # Ordering; ``desc`` is carried as a string, default "true".
    orderby: Optional[str] = Field(default="create_time")
    desc: Optional[str] = Field(default="true")

    # Creation-time bounds as integers; both default to 0.
    create_time_from: Optional[int] = Field(default=0)
    create_time_to: Optional[int] = Field(default=0)
|
||||
|
||||
|
||||
class ListDocumentsBody(BaseModel):
    """Request body for listing documents."""

    # Optional filter lists; each defaults to an empty list.
    run_status: Optional[List[str]] = Field(default=[])
    types: Optional[List[str]] = Field(default=[])
    suffix: Optional[List[str]] = Field(default=[])
|
||||
|
||||
|
||||
class FilterDocumentsRequest(BaseModel):
    """Request for filtering documents."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)

    keywords: Optional[str] = Field(default="")

    # Optional filter lists; each defaults to an empty list.
    suffix: Optional[List[str]] = Field(default=[])
    run_status: Optional[List[str]] = Field(default=[])
    types: Optional[List[str]] = Field(default=[])
|
||||
|
||||
|
||||
class GetDocumentInfosRequest(BaseModel):
    """Request for fetching document information."""

    # Required list of document IDs.
    doc_ids: List[str] = Field(...)
|
||||
|
||||
|
||||
class ChangeStatusRequest(BaseModel):
    """Request for changing document status."""

    doc_ids: List[str] = Field(...)
    # Must be the string "0" or "1".
    status: str = Field(...)

    @model_validator(mode="after")
    def validate_status(self):
        """Reject any status other than the strings "0" and "1"."""
        if self.status in ("0", "1"):
            return self
        raise ValueError('Status must be either 0 or 1!')
|
||||
|
||||
|
||||
class DeleteDocumentRequest(BaseModel):
    """Request for deleting documents."""

    # Accepts either a single document ID or a list of IDs.
    doc_id: str | List[str] = Field(...)
|
||||
|
||||
|
||||
class RunDocumentRequest(BaseModel):
    """Request for running document parsing."""

    doc_ids: List[str] = Field(...)
    # A TaskStatus value, carried as a string.
    run: str = Field(...)
    delete: Optional[bool] = Field(default=False)
|
||||
|
||||
|
||||
class RenameDocumentRequest(BaseModel):
    """Request for renaming a document."""

    # Both fields are required.
    doc_id: str = Field(...)
    name: str = Field(...)
|
||||
|
||||
|
||||
class ChangeParserSimpleRequest(BaseModel):
    """Simple change-parser request (kept for compatibility with the old logic)."""

    doc_id: str = Field(...)

    # All optional; None when omitted.
    parser_id: Optional[str] = Field(default=None)
    pipeline_id: Optional[str] = Field(default=None)
    parser_config: Optional[dict] = Field(default=None)
|
||||
|
||||
|
||||
class UploadAndParseRequest(BaseModel):
    """Upload-and-parse request (used only to validate conversation_id)."""

    # Required conversation ID.
    conversation_id: str = Field(...)
|
||||
|
||||
|
||||
class ParseRequest(BaseModel):
    """Parse request."""

    # Optional URL; None when omitted.
    url: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class SetMetaRequest(BaseModel):
    """Request for setting document metadata."""

    doc_id: str
    # JSON string; must decode to an object whose values are str/int/float.
    meta: str

    @model_validator(mode='after')
    def validate_meta(self):
        """Validate that ``meta`` is a JSON object with scalar values.

        Raises:
            ValueError: if ``meta`` is not valid JSON, does not decode to a
                dict, or contains a value that is not a str, int or float.
        """
        import json

        # Keep the try body to the one call that can raise JSONDecodeError,
        # and chain the cause so the original parse error is not lost.
        try:
            meta_dict = json.loads(self.meta)
        except json.JSONDecodeError as e:
            raise ValueError(f"Json syntax error: {e}") from e

        if not isinstance(meta_dict, dict):
            raise ValueError("Only dictionary type supported.")

        # NOTE: bool values pass this check because bool is a subclass of int.
        for v in meta_dict.values():
            if not isinstance(v, (str, int, float)):
                raise ValueError(f"The type is not supported: {v}")
        return self
|
||||
|
||||
159
api/apps/models/kb_models.py
Normal file
159
api/apps/models/kb_models.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class CreateKnowledgeBaseRequest(BaseModel):
    """Request for creating a knowledge base.

    Two parse types are supported:
    - parse_type=1: built-in parser; parser_id is required, pipeline_id must be empty
    - parse_type=2: custom pipeline; pipeline_id is required, parser_id must be empty
    """

    name: str
    parse_type: Literal[1, 2] = Field(description="解析类型:1=内置解析器,2=自定义pipeline")
    embd_id: str = Field(description="嵌入模型ID")
    parser_id: Optional[str] = Field("", description="解析器ID,parse_type=1时必需")
    pipeline_id: Optional[str] = Field("", description="流水线ID,parse_type=2时必需")
    description: Optional[str] = None
    pagerank: Optional[int] = None

    @model_validator(mode="after")
    def validate_parse_type_fields(self):
        """Check parser_id / pipeline_id against the selected parse_type."""
        # None and whitespace-only values both count as "empty".
        has_parser = (self.parser_id or "").strip() != ""
        has_pipeline = (self.pipeline_id or "").strip() != ""

        if self.parse_type == 1:
            # Built-in parser: parser_id required, pipeline_id forbidden.
            if not has_parser:
                raise ValueError("parse_type=1时,parser_id不能为空")
            if has_pipeline:
                raise ValueError("parse_type=1时,pipeline_id必须为空")
        elif self.parse_type == 2:
            # Custom pipeline: pipeline_id required, parser_id forbidden.
            if not has_pipeline:
                raise ValueError("parse_type=2时,pipeline_id不能为空")
            if has_parser:
                raise ValueError("parse_type=2时,parser_id必须为空")
        return self
|
||||
|
||||
|
||||
class UpdateKnowledgeBaseRequest(BaseModel):
    """Request for updating a knowledge base."""

    # Required fields.
    kb_id: str = Field(...)
    name: str = Field(...)
    description: str = Field(...)
    parser_id: str = Field(...)

    pagerank: Optional[int] = Field(default=None)
    # Other optional fields, but excluding id, tenant_id, created_by,
    # create_time, update_time, create_date, update_date.
|
||||
|
||||
|
||||
class DeleteKnowledgeBaseRequest(BaseModel):
    """Request for deleting a knowledge base."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)
|
||||
|
||||
|
||||
class ListKnowledgeBasesQuery(BaseModel):
    """Query parameters for listing knowledge bases."""

    keywords: Optional[str] = Field(default="")

    # Paging; both default to 0.
    page: Optional[int] = Field(default=0)
    page_size: Optional[int] = Field(default=0)

    parser_id: Optional[str] = Field(default=None)

    # Ordering; ``desc`` is carried as a string, default "true".
    orderby: Optional[str] = Field(default="create_time")
    desc: Optional[str] = Field(default="true")
|
||||
|
||||
|
||||
class ListKnowledgeBasesBody(BaseModel):
    """Request body for listing knowledge bases."""

    # Optional owner filter; defaults to an empty list.
    owner_ids: Optional[List[str]] = Field(default=[])
|
||||
|
||||
|
||||
class RemoveTagsRequest(BaseModel):
    """Request for deleting tags."""

    # Required list of tags.
    tags: List[str] = Field(...)
|
||||
|
||||
|
||||
class RenameTagRequest(BaseModel):
    """Request for renaming a tag."""

    # Both the old and the new tag name are required.
    from_tag: str = Field(...)
    to_tag: str = Field(...)
|
||||
|
||||
|
||||
class ListPipelineLogsQuery(BaseModel):
    """Query parameters for listing pipeline logs."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)

    keywords: Optional[str] = Field(default="")

    # Paging; both default to 0.
    page: Optional[int] = Field(default=0)
    page_size: Optional[int] = Field(default=0)

    # Ordering; ``desc`` is carried as a string, default "true".
    orderby: Optional[str] = Field(default="create_time")
    desc: Optional[str] = Field(default="true")

    # Creation-date bounds as strings; both default to "".
    create_date_from: Optional[str] = Field(default="")
    create_date_to: Optional[str] = Field(default="")
|
||||
|
||||
|
||||
class ListPipelineLogsBody(BaseModel):
    """Request body for listing pipeline logs."""

    # Optional filter lists; each defaults to an empty list.
    operation_status: Optional[List[str]] = Field(default=[])
    types: Optional[List[str]] = Field(default=[])
    suffix: Optional[List[str]] = Field(default=[])
|
||||
|
||||
|
||||
class ListPipelineDatasetLogsQuery(BaseModel):
    """Query parameters for listing pipeline dataset logs."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)

    # Paging; both default to 0.
    page: Optional[int] = Field(default=0)
    page_size: Optional[int] = Field(default=0)

    # Ordering; ``desc`` is carried as a string, default "true".
    orderby: Optional[str] = Field(default="create_time")
    desc: Optional[str] = Field(default="true")

    # Creation-date bounds as strings; both default to "".
    create_date_from: Optional[str] = Field(default="")
    create_date_to: Optional[str] = Field(default="")
|
||||
|
||||
|
||||
class ListPipelineDatasetLogsBody(BaseModel):
    """Request body for listing pipeline dataset logs."""

    # Optional status filter; defaults to an empty list.
    operation_status: Optional[List[str]] = Field(default=[])
|
||||
|
||||
|
||||
class DeletePipelineLogsQuery(BaseModel):
    """Query parameters for deleting pipeline logs."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)
|
||||
|
||||
|
||||
class DeletePipelineLogsBody(BaseModel):
    """Request body for deleting pipeline logs."""

    # Required list of log IDs.
    log_ids: List[str] = Field(...)
|
||||
|
||||
|
||||
class RunGraphragRequest(BaseModel):
    """Request for running GraphRAG."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)
|
||||
|
||||
|
||||
class RunRaptorRequest(BaseModel):
    """Request for running RAPTOR."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)
|
||||
|
||||
|
||||
class RunMindmapRequest(BaseModel):
    """Request for running Mindmap."""

    # Required knowledge-base ID.
    kb_id: str = Field(...)
|
||||
|
||||
101
api/apps/models/llm_models.py
Normal file
101
api/apps/models/llm_models.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SetApiKeyRequest(BaseModel):
    """Request for setting an API key."""

    # Required fields.
    llm_factory: str = Field(description="LLM 工厂名称")
    api_key: str = Field(description="API Key")

    # Optional fields.
    base_url: Optional[str] = Field("", description="API Base URL")
    model_type: Optional[str] = Field(None, description="模型类型")
    llm_name: Optional[str] = Field(None, description="LLM 名称")
|
||||
|
||||
|
||||
class AddLLMRequest(BaseModel):
    """Request for adding an LLM."""

    # Required fields.
    llm_factory: str = Field(description="LLM 工厂名称")
    model_type: str = Field(description="模型类型")
    llm_name: str = Field(description="LLM 名称")

    # Common optional fields; ``api_key`` defaults to the placeholder "x".
    api_key: Optional[str] = Field("x", description="API Key")
    api_base: Optional[str] = Field("", description="API Base URL")
    max_tokens: Optional[int] = Field(None, description="最大 Token 数")

    # VolcEngine-specific fields
    ark_api_key: Optional[str] = Field(default=None)
    endpoint_id: Optional[str] = Field(default=None)

    # Tencent Hunyuan-specific fields
    hunyuan_sid: Optional[str] = Field(default=None)
    hunyuan_sk: Optional[str] = Field(default=None)

    # Tencent Cloud-specific fields
    tencent_cloud_sid: Optional[str] = Field(default=None)
    tencent_cloud_sk: Optional[str] = Field(default=None)

    # Bedrock-specific fields
    bedrock_ak: Optional[str] = Field(default=None)
    bedrock_sk: Optional[str] = Field(default=None)
    bedrock_region: Optional[str] = Field(default=None)

    # XunFei Spark-specific fields
    spark_api_password: Optional[str] = Field(default=None)
    spark_app_id: Optional[str] = Field(default=None)
    spark_api_secret: Optional[str] = Field(default=None)
    spark_api_key: Optional[str] = Field(default=None)

    # BaiduYiyan-specific fields
    yiyan_ak: Optional[str] = Field(default=None)
    yiyan_sk: Optional[str] = Field(default=None)

    # Fish Audio-specific fields
    fish_audio_ak: Optional[str] = Field(default=None)
    fish_audio_refid: Optional[str] = Field(default=None)

    # Google Cloud-specific fields
    google_project_id: Optional[str] = Field(default=None)
    google_region: Optional[str] = Field(default=None)
    google_service_account_key: Optional[str] = Field(default=None)

    # Azure-OpenAI-specific field
    api_version: Optional[str] = Field(default=None)

    # OpenRouter-specific field
    provider_order: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class DeleteLLMRequest(BaseModel):
    """Request for deleting an LLM."""

    # Both fields are required.
    llm_factory: str = Field(description="LLM 工厂名称")
    llm_name: str = Field(description="LLM 名称")
|
||||
|
||||
|
||||
class DeleteFactoryRequest(BaseModel):
    """Request for deleting a factory."""

    # Required factory name.
    llm_factory: str = Field(description="LLM 工厂名称")
|
||||
|
||||
|
||||
class MyLLMsQuery(BaseModel):
    """Query parameters for fetching my LLMs."""

    # Flag carried as a string; defaults to "false".
    include_details: Optional[str] = Field("false", description="是否包含详细信息")
|
||||
|
||||
|
||||
class ListLLMsQuery(BaseModel):
    """Query parameters for listing LLMs."""

    # Optional model-type filter; None when omitted.
    model_type: Optional[str] = Field(None, description="模型类型过滤")
|
||||
|
||||
99
api/apps/models/mcp_models.py
Normal file
99
api/apps/models/mcp_models.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ListMCPServersQuery(BaseModel):
    """Query parameters for listing MCP servers."""

    keywords: Optional[str] = Field("", description="关键词")

    # Paging; both default to 0.
    page: Optional[int] = Field(0, description="页码")
    page_size: Optional[int] = Field(0, description="每页大小")

    # Ordering; ``desc`` is carried as a string, default "true".
    orderby: Optional[str] = Field("create_time", description="排序字段")
    desc: Optional[str] = Field("true", description="是否降序")
|
||||
|
||||
|
||||
class ListMCPServersBody(BaseModel):
    """Request body for listing MCP servers."""

    # Optional ID filter; defaults to an empty list.
    mcp_ids: Optional[List[str]] = Field([], description="MCP服务器ID列表")
|
||||
|
||||
|
||||
class CreateMCPServerRequest(BaseModel):
    """Request for creating an MCP server."""

    # Required fields.
    name: str = Field(description="服务器名称")
    url: str = Field(description="服务器URL")
    server_type: str = Field(description="服务器类型")

    # Optional fields; mappings default to {}, timeout to 10.0.
    headers: Optional[Dict[str, Any]] = Field({}, description="请求头")
    variables: Optional[Dict[str, Any]] = Field({}, description="变量")
    timeout: Optional[float] = Field(10.0, description="超时时间")
|
||||
|
||||
|
||||
class UpdateMCPServerRequest(BaseModel):
    """Request for updating an MCP server."""

    # Required server ID.
    mcp_id: str = Field(description="MCP服务器ID")

    # Optional fields; all default to None except ``timeout`` (10.0).
    name: Optional[str] = Field(None, description="服务器名称")
    url: Optional[str] = Field(None, description="服务器URL")
    server_type: Optional[str] = Field(None, description="服务器类型")
    headers: Optional[Dict[str, Any]] = Field(None, description="请求头")
    variables: Optional[Dict[str, Any]] = Field(None, description="变量")
    timeout: Optional[float] = Field(10.0, description="超时时间")
|
||||
|
||||
|
||||
class DeleteMCPServersRequest(BaseModel):
    """Request for deleting MCP servers."""

    # Required list of server IDs.
    mcp_ids: List[str] = Field(description="MCP服务器ID列表")
|
||||
|
||||
|
||||
class ImportMCPServersRequest(BaseModel):
    """Request for bulk-importing MCP servers."""

    # Required mapping of server configurations.
    mcpServers: Dict[str, Any] = Field(description="MCP服务器配置字典")
    timeout: Optional[float] = Field(10.0, description="超时时间")
|
||||
|
||||
|
||||
class ExportMCPServersRequest(BaseModel):
    """Request for bulk-exporting MCP servers."""

    # Required list of server IDs.
    mcp_ids: List[str] = Field(description="MCP服务器ID列表")
|
||||
|
||||
|
||||
class ListMCPToolsRequest(BaseModel):
    """Request for listing MCP tools."""

    # Required list of server IDs.
    mcp_ids: List[str] = Field(description="MCP服务器ID列表")
    timeout: Optional[float] = Field(10.0, description="超时时间")
|
||||
|
||||
|
||||
class TestMCPToolRequest(BaseModel):
    """Request for testing an MCP tool."""

    # Required fields.
    mcp_id: str = Field(description="MCP服务器ID")
    tool_name: str = Field(description="工具名称")
    arguments: Dict[str, Any] = Field(description="工具参数")

    timeout: Optional[float] = Field(10.0, description="超时时间")
|
||||
|
||||
|
||||
class CacheMCPToolsRequest(BaseModel):
    """Request for caching MCP tools."""

    # Both fields are required.
    mcp_id: str = Field(description="MCP服务器ID")
    tools: List[Dict[str, Any]] = Field(description="工具列表")
|
||||
|
||||
|
||||
class TestMCPRequest(BaseModel):
    """Request for testing an MCP server (login not required)."""

    # Required fields.
    url: str = Field(description="服务器URL")
    server_type: str = Field(description="服务器类型")

    # Optional fields; mappings default to {}, timeout to 10.0.
    headers: Optional[Dict[str, Any]] = Field({}, description="请求头")
    variables: Optional[Dict[str, Any]] = Field({}, description="变量")
    timeout: Optional[float] = Field(10.0, description="超时时间")
|
||||
|
||||
@@ -1,8 +1,26 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from flask import Response
|
||||
from flask_login import login_required
|
||||
from api.utils.api_utils import get_json_result
|
||||
from plugin import GlobalPluginManager
|
||||
|
||||
|
||||
@manager.route('/llm_tools', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def llm_tools() -> Response:
|
||||
|
||||
@@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_error_data_result, ge
|
||||
from api.utils.api_utils import get_result
|
||||
from flask import request
|
||||
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_agents(tenant_id):
|
||||
@@ -41,7 +42,7 @@ def list_agents(tenant_id):
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
canvas = UserCanvasService.get_list(tenant_id,page_number,items_per_page,orderby,desc,id,title)
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
|
||||
return get_result(data=canvas)
|
||||
|
||||
|
||||
@@ -93,7 +94,7 @@ def update_agent(tenant_id: str, agent_id: str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
|
||||
|
||||
if req.get("title") is not None:
|
||||
req["title"] = req["title"].strip()
|
||||
|
||||
|
||||
@@ -215,7 +215,8 @@ def delete(tenant_id):
|
||||
continue
|
||||
kb_id_instance_pairs.append((kb_id, kb))
|
||||
if len(error_kb_ids) > 0:
|
||||
return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
|
||||
return get_error_permission_result(
|
||||
message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
@@ -232,7 +233,8 @@ def delete(tenant_id):
|
||||
]
|
||||
)
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
|
||||
if not KnowledgebaseService.delete_by_id(kb_id):
|
||||
errors.append(f"Delete dataset error for {kb_id}")
|
||||
continue
|
||||
@@ -329,7 +331,8 @@ def update(tenant_id, dataset_id):
|
||||
try:
|
||||
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
|
||||
if kb is None:
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
|
||||
return get_error_permission_result(
|
||||
message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
|
||||
|
||||
if req.get("parser_config"):
|
||||
req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
|
||||
@@ -341,7 +344,8 @@ def update(tenant_id, dataset_id):
|
||||
del req["parser_config"]
|
||||
|
||||
if "name" in req and req["name"].lower() != kb.name.lower():
|
||||
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id,
|
||||
status=StatusEnum.VALID.value)
|
||||
if exists:
|
||||
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
|
||||
@@ -349,7 +353,8 @@ def update(tenant_id, dataset_id):
|
||||
if not req["embd_id"]:
|
||||
req["embd_id"] = kb.embd_id
|
||||
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
|
||||
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
|
||||
return get_error_data_result(
|
||||
message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
|
||||
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
|
||||
if not ok:
|
||||
return err
|
||||
@@ -359,10 +364,12 @@ def update(tenant_id, dataset_id):
|
||||
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
|
||||
|
||||
if req["pagerank"] > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_error_data_result(message="Update dataset error.(Database error)")
|
||||
@@ -454,7 +461,7 @@ def list_datasets(tenant_id):
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
|
||||
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
|
||||
kbs = KnowledgebaseService.get_list(
|
||||
kbs, total = KnowledgebaseService.get_list(
|
||||
[m["tenant_id"] for m in tenants],
|
||||
tenant_id,
|
||||
args["page"],
|
||||
@@ -468,14 +475,15 @@ def list_datasets(tenant_id):
|
||||
response_data_list = []
|
||||
for kb in kbs:
|
||||
response_data_list.append(remap_dictionary_keys(kb))
|
||||
return get_result(data=response_data_list)
|
||||
return get_result(data=response_data_list, total=total)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def knowledge_graph(tenant_id,dataset_id):
|
||||
def knowledge_graph(tenant_id, dataset_id):
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
@@ -491,7 +499,7 @@ def knowledge_graph(tenant_id,dataset_id):
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
|
||||
return get_result(data=obj)
|
||||
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
if not len(sres.ids):
|
||||
return get_result(data=obj)
|
||||
|
||||
@@ -507,14 +515,16 @@ def knowledge_graph(tenant_id,dataset_id):
|
||||
if "nodes" in obj["graph"]:
|
||||
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
|
||||
if "edges" in obj["graph"]:
|
||||
node_id_set = { o["id"] for o in obj["graph"]["nodes"] }
|
||||
filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
|
||||
node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
|
||||
filtered_edges = [o for o in obj["graph"]["edges"] if
|
||||
o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
|
||||
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
|
||||
return get_result(data=obj)
|
||||
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
|
||||
@token_required
|
||||
def delete_knowledge_graph(tenant_id,dataset_id):
|
||||
def delete_knowledge_graph(tenant_id, dataset_id):
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
@@ -522,6 +532,7 @@ def delete_knowledge_graph(tenant_id,dataset_id):
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
_, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id)
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
|
||||
search.index_name(kb.tenant_id), dataset_id)
|
||||
|
||||
return get_result(data=True)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -31,6 +31,89 @@ from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
@apikey_required
|
||||
@validate_request("knowledge_id", "query")
|
||||
def retrieval(tenant_id):
|
||||
"""
|
||||
Dify-compatible retrieval API
|
||||
---
|
||||
tags:
|
||||
- SDK
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- knowledge_id
|
||||
- query
|
||||
properties:
|
||||
knowledge_id:
|
||||
type: string
|
||||
description: Knowledge base ID
|
||||
query:
|
||||
type: string
|
||||
description: Query text
|
||||
use_kg:
|
||||
type: boolean
|
||||
description: Whether to use knowledge graph
|
||||
default: false
|
||||
retrieval_setting:
|
||||
type: object
|
||||
description: Retrieval configuration
|
||||
properties:
|
||||
score_threshold:
|
||||
type: number
|
||||
description: Similarity threshold
|
||||
default: 0.0
|
||||
top_k:
|
||||
type: integer
|
||||
description: Number of results to return
|
||||
default: 1024
|
||||
metadata_condition:
|
||||
type: object
|
||||
description: Metadata filter condition
|
||||
properties:
|
||||
conditions:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Field name
|
||||
comparison_operator:
|
||||
type: string
|
||||
description: Comparison operator
|
||||
value:
|
||||
type: string
|
||||
description: Field value
|
||||
responses:
|
||||
200:
|
||||
description: Retrieval succeeded
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
records:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
type: string
|
||||
description: Content text
|
||||
score:
|
||||
type: number
|
||||
description: Similarity score
|
||||
title:
|
||||
type: string
|
||||
description: Document title
|
||||
metadata:
|
||||
type: object
|
||||
description: Metadata info
|
||||
404:
|
||||
description: Knowledge base or document not found
|
||||
"""
|
||||
req = request.json
|
||||
question = req["query"]
|
||||
kb_id = req["knowledge_id"]
|
||||
@@ -38,9 +121,9 @@ def retrieval(tenant_id):
|
||||
retrieval_setting = req.get("retrieval_setting", {})
|
||||
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
|
||||
top = int(retrieval_setting.get("top_k", 1024))
|
||||
metadata_condition = req.get("metadata_condition",{})
|
||||
metadata_condition = req.get("metadata_condition", {})
|
||||
metas = DocumentService.get_meta_by_kbs([kb_id])
|
||||
|
||||
|
||||
doc_ids = []
|
||||
try:
|
||||
|
||||
@@ -50,12 +133,12 @@ def retrieval(tenant_id):
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
print(metadata_condition)
|
||||
print("after",convert_conditions(metadata_condition))
|
||||
# print("after", convert_conditions(metadata_condition))
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
|
||||
print("doc_ids",doc_ids)
|
||||
# print("doc_ids", doc_ids)
|
||||
if not doc_ids and metadata_condition is not None:
|
||||
doc_ids = ['-999']
|
||||
ranks = settings.retrievaler.retrieval(
|
||||
ranks = settings.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
kb.tenant_id,
|
||||
@@ -70,17 +153,17 @@ def retrieval(tenant_id):
|
||||
)
|
||||
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question,
|
||||
[tenant_id],
|
||||
[kb_id],
|
||||
embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retriever.retrieval(question,
|
||||
[tenant_id],
|
||||
[kb_id],
|
||||
embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
|
||||
records = []
|
||||
for c in ranks["chunks"]:
|
||||
e, doc = DocumentService.get_by_id( c["doc_id"])
|
||||
e, doc = DocumentService.get_by_id(c["doc_id"])
|
||||
c.pop("vector", None)
|
||||
meta = getattr(doc, 'meta_fields', {})
|
||||
meta["doc_id"] = c["doc_id"]
|
||||
@@ -100,5 +183,3 @@ def retrieval(tenant_id):
|
||||
)
|
||||
logging.exception(e)
|
||||
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class Chunk(BaseModel):
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def upload(dataset_id, tenant_id):
|
||||
def upload(dataset_id, tenant_id):
|
||||
"""
|
||||
Upload documents to a dataset.
|
||||
---
|
||||
@@ -151,7 +151,7 @@ async def upload(dataset_id, tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not e:
|
||||
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
||||
err, files = await FileService.upload_document(kb, file_objs, tenant_id)
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
||||
if err:
|
||||
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||
# rename key's name
|
||||
@@ -470,6 +470,20 @@ def list_docs(dataset_id, tenant_id):
|
||||
required: false
|
||||
default: 0
|
||||
description: Unix timestamp for filtering documents created before this time. 0 means no filter.
|
||||
- in: query
|
||||
name: suffix
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: false
|
||||
description: Filter by file suffix (e.g., ["pdf", "txt", "docx"]).
|
||||
- in: query
|
||||
name: run
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required: false
|
||||
description: Filter by document run status. Supports both numeric ("0", "1", "2", "3", "4") and text formats ("UNSTART", "RUNNING", "CANCEL", "DONE", "FAIL").
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
@@ -512,63 +526,62 @@ def list_docs(dataset_id, tenant_id):
|
||||
description: Processing status.
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
id = request.args.get("id")
|
||||
name = request.args.get("name")
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
|
||||
if id and not DocumentService.query(id=id, kb_id=dataset_id):
|
||||
return get_error_data_result(message=f"You don't own the document {id}.")
|
||||
q = request.args
|
||||
document_id = q.get("id")
|
||||
name = q.get("name")
|
||||
|
||||
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
if name and not DocumentService.query(name=name, kb_id=dataset_id):
|
||||
return get_error_data_result(message=f"You don't own the document {name}.")
|
||||
|
||||
page = int(request.args.get("page", 1))
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_size = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc") == "False":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
docs, tol = DocumentService.get_list(dataset_id, page, page_size, orderby, desc, keywords, id, name)
|
||||
page = int(q.get("page", 1))
|
||||
page_size = int(q.get("page_size", 30))
|
||||
orderby = q.get("orderby", "create_time")
|
||||
desc = str(q.get("desc", "true")).strip().lower() != "false"
|
||||
keywords = q.get("keywords", "")
|
||||
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
# filters - align with OpenAPI parameter names
|
||||
suffix = q.getlist("suffix")
|
||||
run_status = q.getlist("run")
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
create_time_to = int(q.get("create_time_to", 0))
|
||||
|
||||
# map run status (accept text or numeric) - align with API parameter
|
||||
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
|
||||
run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
|
||||
|
||||
docs, total = DocumentService.get_list(
|
||||
dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted
|
||||
)
|
||||
|
||||
# time range filter (0 means no bound)
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in docs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
docs = filtered_docs
|
||||
docs = [
|
||||
d for d in docs
|
||||
if (create_time_from == 0 or d.get("create_time", 0) >= create_time_from)
|
||||
and (create_time_to == 0 or d.get("create_time", 0) <= create_time_to)
|
||||
]
|
||||
|
||||
# rename key's name
|
||||
renamed_doc_list = []
|
||||
# rename keys + map run status back to text for output
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"kb_id": "dataset_id",
|
||||
"kb_id": "dataset_id",
|
||||
"token_num": "token_count",
|
||||
"parser_id": "chunk_method",
|
||||
}
|
||||
run_mapping = {
|
||||
"0": "UNSTART",
|
||||
"1": "RUNNING",
|
||||
"2": "CANCEL",
|
||||
"3": "DONE",
|
||||
"4": "FAIL",
|
||||
}
|
||||
for doc in docs:
|
||||
renamed_doc = {}
|
||||
for key, value in doc.items():
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_doc[new_key] = value
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(value)
|
||||
renamed_doc_list.append(renamed_doc)
|
||||
return get_result(data={"total": tol, "docs": renamed_doc_list})
|
||||
run_status_numeric_to_text = {"0": "UNSTART", "1": "RUNNING", "2": "CANCEL", "3": "DONE", "4": "FAIL"}
|
||||
|
||||
output_docs = []
|
||||
for d in docs:
|
||||
renamed_doc = {key_mapping.get(k, k): v for k, v in d.items()}
|
||||
if "run" in d:
|
||||
renamed_doc["run"] = run_status_numeric_to_text.get(str(d["run"]), d["run"])
|
||||
output_docs.append(renamed_doc)
|
||||
|
||||
return get_result(data={"total": total, "docs": output_docs})
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/documents", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
@@ -982,7 +995,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
_ = Chunk(**final_chunk)
|
||||
|
||||
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
||||
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
res["total"] = sres.total
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@@ -1446,7 +1459,7 @@ def retrieval_test(tenant_id):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = settings.retrievaler.retrieval(
|
||||
ranks = settings.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
@@ -1462,7 +1475,7 @@ def retrieval_test(tenant_id):
|
||||
rank_feature=label_question(question, kbs),
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
@@ -17,7 +34,8 @@ from api.utils.api_utils import get_json_result
|
||||
from api.utils.file_utils import filename_type
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def upload(tenant_id):
|
||||
"""
|
||||
@@ -44,22 +62,22 @@ def upload(tenant_id):
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: File ID
|
||||
name:
|
||||
type: string
|
||||
description: File name
|
||||
size:
|
||||
type: integer
|
||||
description: File size in bytes
|
||||
type:
|
||||
type: string
|
||||
description: File type (e.g., document, folder)
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: File ID
|
||||
name:
|
||||
type: string
|
||||
description: File name
|
||||
size:
|
||||
type: integer
|
||||
description: File size in bytes
|
||||
type:
|
||||
type: string
|
||||
description: File type (e.g., document, folder)
|
||||
"""
|
||||
pf_id = request.form.get("parent_id")
|
||||
|
||||
@@ -97,12 +115,14 @@ def upload(tenant_id):
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, len_id_list)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, len_id_list)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
@@ -129,7 +149,7 @@ def upload(tenant_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/create', methods=['POST']) # noqa: F821
|
||||
@manager.route('/file/create', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
"""
|
||||
@@ -207,7 +227,7 @@ def create(tenant_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/list', methods=['GET']) # noqa: F821
|
||||
@manager.route('/file/list', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_files(tenant_id):
|
||||
"""
|
||||
@@ -299,7 +319,7 @@ def list_files(tenant_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
|
||||
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_root_folder(tenant_id):
|
||||
"""
|
||||
@@ -335,7 +355,7 @@ def get_root_folder(tenant_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
|
||||
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_parent_folder():
|
||||
"""
|
||||
@@ -380,7 +400,7 @@ def get_parent_folder():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_all_parent_folders(tenant_id):
|
||||
"""
|
||||
@@ -428,7 +448,7 @@ def get_all_parent_folders(tenant_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/rm', methods=['POST']) # noqa: F821
|
||||
@manager.route('/file/rm', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rm(tenant_id):
|
||||
"""
|
||||
@@ -502,7 +522,7 @@ def rm(tenant_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/rename', methods=['POST']) # noqa: F821
|
||||
@manager.route('/file/rename', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rename(tenant_id):
|
||||
"""
|
||||
@@ -542,7 +562,8 @@ def rename(tenant_id):
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
|
||||
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(file.name.lower()).suffix:
|
||||
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=400)
|
||||
|
||||
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
@@ -562,9 +583,9 @@ def rename(tenant_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get(tenant_id,file_id):
|
||||
def get(tenant_id, file_id):
|
||||
"""
|
||||
Download a file.
|
||||
---
|
||||
@@ -610,7 +631,7 @@ def get(tenant_id,file_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def move(tenant_id):
|
||||
"""
|
||||
@@ -669,6 +690,7 @@ def move(tenant_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def convert(tenant_id):
|
||||
@@ -735,4 +757,4 @@ def convert(tenant_id):
|
||||
file2documents.append(file2document.to_json())
|
||||
return get_json_result(data=file2documents)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
||||
@@ -36,7 +36,8 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||
get_result, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
@@ -88,7 +89,8 @@ def create_agent_session(tenant_id, agent_id):
|
||||
canvas.reset()
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id,
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||
API4ConversationService.save(**conv)
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
return get_result(data=conv)
|
||||
@@ -279,7 +281,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||
if reasoning_match:
|
||||
reasoning_part = reasoning_match.group(1)
|
||||
content_part = answer[reasoning_match.end() :]
|
||||
content_part = answer[reasoning_match.end():]
|
||||
else:
|
||||
reasoning_part = ""
|
||||
content_part = answer
|
||||
@@ -324,7 +326,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
response["choices"][0]["delta"]["content"] = None
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
|
||||
"total_tokens": len(prompt) + token_used}
|
||||
if need_reference:
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||
@@ -559,7 +562,8 @@ def list_agent_session(tenant_id, agent_id):
|
||||
desc = True
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
|
||||
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
|
||||
user_id, include_dsl)
|
||||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
@@ -581,7 +585,8 @@ def list_agent_session(tenant_id, agent_id):
|
||||
if message_num != 0 and messages[message_num]["role"] != "user":
|
||||
chunk_list = []
|
||||
# Add boundary and type checks to prevent KeyError
|
||||
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
|
||||
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(
|
||||
conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
|
||||
chunks = conv["reference"][chunk_num]["chunks"]
|
||||
for chunk in chunks:
|
||||
# Ensure chunk is a dictionary before calling get method
|
||||
@@ -639,13 +644,16 @@ def delete(tenant_id, chat_id):
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
return get_result(data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
@@ -691,13 +699,16 @@ def delete_agent_session(tenant_id, agent_id):
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
return get_result(data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
||||
return get_result(
|
||||
message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
|
||||
data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
@@ -730,7 +741,9 @@ def ask_about(tenant_id):
|
||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps(
|
||||
{"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
@@ -882,7 +895,9 @@ def begin_inputs(agent_id):
|
||||
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
||||
|
||||
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
|
||||
return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
|
||||
return get_result(
|
||||
data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"),
|
||||
"prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
|
||||
|
||||
|
||||
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
||||
@@ -911,7 +926,9 @@ def ask_about_embedded():
|
||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps(
|
||||
{"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
@@ -978,7 +995,8 @@ def retrieval_test_embedded():
|
||||
tenant_ids.append(tenant.tenant_id)
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.",
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
if not e:
|
||||
@@ -998,11 +1016,13 @@ def retrieval_test_embedded():
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retrievaler.retrieval(
|
||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
ranks = settings.retriever.retrieval(
|
||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
|
||||
@@ -1013,7 +1033,8 @@ def retrieval_test_embedded():
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
|
||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@@ -1082,7 +1103,8 @@ def detail_share_embedded():
|
||||
if SearchService.query(tenant_id=tenant.tenant_id, id=search_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=False, message="Has no permission for this operation.",
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
search = SearchService.get_detail(search_id)
|
||||
if not search:
|
||||
|
||||
@@ -39,6 +39,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
|
||||
|
||||
@manager.route("/version", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def version():
|
||||
@@ -161,7 +162,7 @@ def status():
|
||||
task_executors = REDIS_CONN.smembers("TASKEXE")
|
||||
now = datetime.now().timestamp()
|
||||
for task_executor_id in task_executors:
|
||||
heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60*30, now)
|
||||
heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now)
|
||||
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
|
||||
task_executor_heartbeats[task_executor_id] = heartbeats
|
||||
except Exception:
|
||||
@@ -273,7 +274,8 @@ def token_list():
|
||||
objs = [o.to_dict() for o in objs]
|
||||
for o in objs:
|
||||
if not o["beta"]:
|
||||
o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace("ragflow-", "")[:32]
|
||||
o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace(
|
||||
"ragflow-", "")[:32]
|
||||
APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
|
||||
return get_json_result(data=objs)
|
||||
except Exception as e:
|
||||
|
||||
@@ -14,9 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from api.models.tenant_models import InviteUserRequest, UserTenantResponse
|
||||
from api.utils.api_utils import get_current_user
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api import settings
|
||||
from api.apps import smtp_mail_server
|
||||
@@ -25,18 +24,13 @@ from api.db.db_models import UserTenant
|
||||
from api.db.services.user_service import UserTenantService, UserService
|
||||
|
||||
from api.utils import get_uuid, delta_seconds
|
||||
from api.utils.api_utils import get_json_result, server_error_response, get_data_error_result
|
||||
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
|
||||
from api.utils.web_utils import send_invite_email
|
||||
|
||||
# 创建 FastAPI 路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{tenant_id}/user/list")
|
||||
async def user_list(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def user_list(tenant_id):
|
||||
if current_user.id != tenant_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -52,19 +46,18 @@ async def user_list(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/{tenant_id}/user')
|
||||
async def create(
|
||||
tenant_id: str,
|
||||
request: InviteUserRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("email")
|
||||
def create(tenant_id):
|
||||
if current_user.id != tenant_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
invite_user_email = request.email
|
||||
req = request.json
|
||||
invite_user_email = req["email"]
|
||||
invite_users = UserService.query(email=invite_user_email)
|
||||
if not invite_users:
|
||||
return get_data_error_result(message="User not found.")
|
||||
@@ -77,7 +70,8 @@ async def create(
|
||||
return get_data_error_result(message=f"{invite_user_email} is already in the team.")
|
||||
if user_tenant_role == UserTenantRole.OWNER:
|
||||
return get_data_error_result(message=f"{invite_user_email} is the owner of the team.")
|
||||
return get_data_error_result(message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")
|
||||
return get_data_error_result(
|
||||
message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")
|
||||
|
||||
UserTenantService.save(
|
||||
id=get_uuid(),
|
||||
@@ -107,12 +101,9 @@ async def create(
|
||||
return get_json_result(data=usr)
|
||||
|
||||
|
||||
@router.delete('/{tenant_id}/user/{user_id}')
|
||||
async def rm(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route('/<tenant_id>/user/<user_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
def rm(tenant_id, user_id):
|
||||
if current_user.id != tenant_id and current_user.id != user_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@@ -126,10 +117,9 @@ async def rm(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def tenant_list(
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def tenant_list():
|
||||
try:
|
||||
users = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
for u in users:
|
||||
@@ -139,13 +129,12 @@ async def tenant_list(
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.put("/agree/{tenant_id}")
|
||||
async def agree(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
@manager.route("/agree/<tenant_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def agree(tenant_id):
|
||||
try:
|
||||
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
|
||||
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id],
|
||||
{"role": UserTenantRole.NORMAL})
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@@ -15,11 +15,14 @@
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import string
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from flask import redirect, request, session
|
||||
from flask import redirect, request, session, make_response
|
||||
from flask_login import current_user, login_required, login_user, logout_user
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
@@ -46,6 +49,19 @@ from api.utils.api_utils import (
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import smtp_mail_server
|
||||
from api.utils.web_utils import (
|
||||
send_email_html,
|
||||
OTP_LENGTH,
|
||||
OTP_TTL_SECONDS,
|
||||
ATTEMPT_LIMIT,
|
||||
ATTEMPT_LOCK_SECONDS,
|
||||
RESEND_COOLDOWN_SECONDS,
|
||||
otp_keys,
|
||||
hash_code,
|
||||
captcha_key,
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
@@ -825,3 +841,172 @@ def set_tenant_info():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/forget/captcha", methods=["GET"])  # noqa: F821
def forget_get_captcha():
    """
    GET /forget/captcha?email=<email>

    Generate an image captcha for the password-reset flow.

    - Responds with ARGUMENT_ERROR when ``email`` is missing and with
      DATA_ERROR when ``UserService.query`` finds no user for it.
    - Draws OTP_LENGTH random characters from A-Z / 0-9 and caches them in
      Redis under ``captcha_key(email)`` for 60 seconds.
    - Returns the rendered captcha as a PNG image.
    """
    email = request.args.get("email") or ""
    if not email:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email is required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    # Generate captcha text with a CSPRNG (secrets), not random.
    allowed = string.ascii_uppercase + string.digits
    captcha_text = "".join(secrets.choice(allowed) for _ in range(OTP_LENGTH))
    REDIS_CONN.set(captcha_key(email), captcha_text, 60)  # Valid for 60 seconds

    # Imported lazily so the module still loads when the captcha package is absent.
    from captcha.image import ImageCaptcha

    image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70])
    # Request PNG explicitly so the payload always matches the declared Content-Type
    # (the previous header said "image/JPEG" while the bytes were PNG).
    img_bytes = image.generate(captcha_text, format="png").read()
    response = make_response(img_bytes)
    response.headers.set("Content-Type", "image/png")
    return response
|
||||
|
||||
|
||||
@manager.route("/forget/otp", methods=["POST"])  # noqa: F821
def forget_send_otp():
    """
    POST /forget/otp

    Request JSON: { email, captcha }

    - Verify the image captcha stored at ``captcha_key(email)`` (case-insensitive)
      and delete it so it cannot be reused.
    - Enforce the resend cooldown (RESEND_COOLDOWN_SECONDS) based on the
      timestamp stored under the "last sent" key.
    - On success, generate an email OTP (A-Z with length = OTP_LENGTH), store
      hash + salt (and timestamp) in Redis with TTL, reset attempts and lock,
      and send the OTP via email.

    Returns a JSON result: ARGUMENT_ERROR for missing fields, DATA_ERROR for an
    unknown email, NOT_EFFECTIVE for a missing/expired captcha or an active
    cooldown, AUTHENTICATION_ERROR for a wrong captcha, SERVER_ERROR when the
    email cannot be sent, SUCCESS otherwise.
    """
    req = request.get_json()
    if not isinstance(req, dict):
        # A JSON body that is null / a list / a scalar has no fields to read;
        # treat it as empty so the caller gets the regular argument error
        # instead of an AttributeError.
        req = {}
    email = req.get("email") or ""
    captcha = (req.get("captcha") or "").strip()

    if not email or not captcha:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email and captcha required")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    stored_captcha = REDIS_CONN.get(captcha_key(email))
    if not stored_captcha:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="invalid or expired captcha")
    if (stored_captcha or "").strip().lower() != captcha.lower():
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="invalid or expired captcha")

    # Delete captcha to prevent reuse
    REDIS_CONN.delete(captcha_key(email))

    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    now = int(time.time())
    last_ts = REDIS_CONN.get(k_last)
    if last_ts:
        try:
            elapsed = now - int(last_ts)
        except (TypeError, ValueError):
            # Unparseable timestamp: treat the cooldown as already elapsed.
            elapsed = RESEND_COOLDOWN_SECONDS
        remaining = RESEND_COOLDOWN_SECONDS - elapsed
        if remaining > 0:
            return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message=f"you still have to wait {remaining} seconds")

    # Generate OTP (uppercase letters only) and store hashed
    otp = "".join(secrets.choice(string.ascii_uppercase) for _ in range(OTP_LENGTH))
    salt = os.urandom(16)
    code_hash = hash_code(otp, salt)
    # Stored as "<hash>:<salt hex>"; forget() splits on the first ":".
    REDIS_CONN.set(k_code, f"{code_hash}:{salt.hex()}", OTP_TTL_SECONDS)
    REDIS_CONN.set(k_attempts, 0, OTP_TTL_SECONDS)
    REDIS_CONN.set(k_last, now, OTP_TTL_SECONDS)
    REDIS_CONN.delete(k_lock)

    ttl_min = OTP_TTL_SECONDS // 60

    if not smtp_mail_server:
        logging.warning("SMTP mail server not initialized; skip sending email.")
    else:
        try:
            send_email_html(
                subject="Your Password Reset Code",
                to_email=email,
                template_key="reset_code",
                code=otp,
                ttl_min=ttl_min,
            )
        except Exception:
            # Keep the generic client-facing error, but record the cause so a
            # misconfigured mail server is diagnosable from the logs.
            logging.exception("failed to send password reset email")
            return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="failed to send email")

    return get_json_result(data=True, code=settings.RetCode.SUCCESS, message="verification passed, email sent")
|
||||
|
||||
|
||||
@manager.route("/forget", methods=["POST"])  # noqa: F821
def forget():
    """
    POST: Verify email + OTP and reset password, then log the user in.

    Request JSON: { email, otp, new_password, confirm_new_password }

    Flow:
    - Validate that all fields are present and the two passwords match.
    - Refuse while the per-email lock key is set (too many failed attempts).
    - Recompute the salted hash of the submitted OTP (upper-cased) and compare
      it with the value stored in Redis as "<hash>:<salt hex>".
    - On mismatch, bump the attempt counter and set the lock once
      ATTEMPT_LIMIT is reached.
    - On match, clear all OTP keys, update the password, issue a fresh access
      token and log the user in.

    Returns a JSON error result on any failure, otherwise the response built
    by ``construct_response`` for the logged-in user.
    """
    req = request.get_json()
    if not isinstance(req, dict):
        # A JSON body that is null / a list / a scalar has no fields to read;
        # treat it as empty so the caller gets the regular argument error
        # instead of an AttributeError.
        req = {}
    email = req.get("email") or ""
    otp = (req.get("otp") or "").strip()
    new_pwd = req.get("new_password")
    new_pwd2 = req.get("confirm_new_password")

    if not all([email, otp, new_pwd, new_pwd2]):
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required")

    # For reset, passwords are provided as-is (no decrypt needed)
    if new_pwd != new_pwd2:
        return get_json_result(data=False, code=settings.RetCode.ARGUMENT_ERROR, message="passwords do not match")

    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, code=settings.RetCode.DATA_ERROR, message="invalid email")

    user = users[0]
    # Verify OTP from Redis
    k_code, k_attempts, k_last, k_lock = otp_keys(email)
    if REDIS_CONN.get(k_lock):
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="too many attempts, try later")

    stored = REDIS_CONN.get(k_code)
    if not stored:
        return get_json_result(data=False, code=settings.RetCode.NOT_EFFECTIVE, message="expired otp")

    try:
        # Both a missing ":" (unpack failure) and a non-hex salt raise ValueError.
        stored_hash, salt_hex = str(stored).split(":", 1)
        salt = bytes.fromhex(salt_hex)
    except ValueError:
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="otp storage corrupted")

    # Case-insensitive verification: OTP generated uppercase
    calc = hash_code(otp.upper(), salt)
    if calc != stored_hash:
        # bump attempts
        try:
            attempts = int(REDIS_CONN.get(k_attempts) or 0) + 1
        except (TypeError, ValueError):
            attempts = 1
        REDIS_CONN.set(k_attempts, attempts, OTP_TTL_SECONDS)
        if attempts >= ATTEMPT_LIMIT:
            REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS)
        # NOTE(review): the message says "expired otp" although this branch is a
        # wrong code; kept as-is because clients may rely on it — confirm intent.
        return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="expired otp")

    # Success: consume OTP and reset password
    REDIS_CONN.delete(k_code)
    REDIS_CONN.delete(k_attempts)
    REDIS_CONN.delete(k_last)
    REDIS_CONN.delete(k_lock)

    try:
        UserService.update_user_password(user.id, new_pwd)
    except Exception as e:
        logging.exception(e)
        return get_json_result(data=False, code=settings.RetCode.EXCEPTION_ERROR, message="failed to reset password")

    # Auto login (reuse login flow)
    user.access_token = get_uuid()
    login_user(user)
    # Assign scalars: the previous trailing commas stored one-element tuples
    # in these timestamp fields.
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Password reset successful. Logged in."
    return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
|
||||
@@ -21,8 +21,7 @@ from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from api.utils.api_utils import security
|
||||
from api.apps.models.auth_dependencies import get_current_user
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, EmailStr
|
||||
try:
|
||||
@@ -89,63 +88,7 @@ class TenantInfoRequest(BaseModel):
|
||||
img2txt_id: str
|
||||
llm_id: str
|
||||
|
||||
# 依赖项:获取当前用户
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
# Dependency: get_current_user is imported from api.apps.models.auth_dependencies
|
||||
|
||||
@router.post("/login")
|
||||
async def login(request: LoginRequest):
|
||||
|
||||
Reference in New Issue
Block a user