将flask改成fastapi
This commit is contained in:
181
api/apps/__init__.py
Normal file
181
api/apps/__init__.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flask_cors import CORS
|
||||
from flasgger import Swagger
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services import UserService
|
||||
from api.utils.json import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from api import settings
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.constants import API_VERSION
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Flask(__name__)
|
||||
smtp_mail_server = Mail()
|
||||
|
||||
# Add this at the beginning of your file to configure Swagger UI
|
||||
swagger_config = {
|
||||
"headers": [],
|
||||
"specs": [
|
||||
{
|
||||
"endpoint": "apispec",
|
||||
"route": "/apispec.json",
|
||||
"rule_filter": lambda rule: True, # Include all endpoints
|
||||
"model_filter": lambda tag: True, # Include all models
|
||||
}
|
||||
],
|
||||
"static_url_path": "/flasgger_static",
|
||||
"swagger_ui": True,
|
||||
"specs_route": "/apidocs/",
|
||||
}
|
||||
|
||||
swagger = Swagger(
|
||||
app,
|
||||
config=swagger_config,
|
||||
template={
|
||||
"swagger": "2.0",
|
||||
"info": {
|
||||
"title": "RAGFlow API",
|
||||
"description": "",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"securityDefinitions": {
|
||||
"ApiKeyAuth": {"type": "apiKey", "name": "Authorization", "in": "header"}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
CORS(app, supports_credentials=True, max_age=2592000)
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
|
||||
## convince for dev and debug
|
||||
# app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
app.config["SESSION_TYPE"] = "filesystem"
|
||||
app.config["MAX_CONTENT_LENGTH"] = int(
|
||||
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
|
||||
)
|
||||
|
||||
Session(app)
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
commands.register_commands(app)
|
||||
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
app_path_list = [
|
||||
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
|
||||
]
|
||||
api_path_list = [
|
||||
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(api_path_list)
|
||||
return app_path_list
|
||||
|
||||
|
||||
def register_page(page_path):
|
||||
path = f"{page_path}"
|
||||
|
||||
page_name = page_path.stem.removesuffix("_app")
|
||||
module_name = ".".join(
|
||||
page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
|
||||
)
|
||||
|
||||
spec = spec_from_file_location(module_name, page_path)
|
||||
page = module_from_spec(spec)
|
||||
page.app = app
|
||||
page.manager = Blueprint(page_name, module_name)
|
||||
sys.modules[module_name] = page
|
||||
spec.loader.exec_module(page)
|
||||
page_name = getattr(page, "page_name", page_name)
|
||||
sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
|
||||
url_prefix = (
|
||||
f"/api/{API_VERSION}" if sdk_path in path else f"/{API_VERSION}/{page_name}"
|
||||
)
|
||||
|
||||
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
||||
return url_prefix
|
||||
|
||||
|
||||
pages_dir = [
|
||||
Path(__file__).parent,
|
||||
Path(__file__).parent.parent / "api" / "apps",
|
||||
Path(__file__).parent.parent / "api" / "apps" / "sdk",
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
|
||||
]
|
||||
|
||||
|
||||
@login_manager.request_loader
|
||||
def load_user(web_request):
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = web_request.headers.get("Authorization")
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return user[0]
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exc):
|
||||
close_connection()
|
||||
181
api/apps/__init___fastapi.py
Normal file
181
api/apps/__init___fastapi.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services import UserService
|
||||
from api.utils.json import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from api import settings
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.constants import API_VERSION
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""创建FastAPI应用实例"""
|
||||
app = FastAPI(
|
||||
title="RAGFlow API",
|
||||
description="RAGFlow API Server",
|
||||
version="1.0.0",
|
||||
docs_url="/apidocs/",
|
||||
redoc_url="/redoc/",
|
||||
openapi_url="/apispec.json"
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境中应该设置具体的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
max_age=2592000
|
||||
)
|
||||
|
||||
# 添加信任主机中间件
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware,
|
||||
allowed_hosts=["*"] # 生产环境中应该设置具体的域名
|
||||
)
|
||||
|
||||
# 添加会话中间件
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.SECRET_KEY,
|
||||
max_age=2592000
|
||||
)
|
||||
|
||||
# 设置错误处理器
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request, exc):
|
||||
return server_error_response(exc)
|
||||
|
||||
return app
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
"""搜索页面路径"""
|
||||
app_path_list = [
|
||||
path for path in pages_dir.glob("*_app_fastapi.py") if not path.name.startswith(".")
|
||||
]
|
||||
api_path_list = [
|
||||
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(api_path_list)
|
||||
return app_path_list
|
||||
|
||||
def register_page(app: FastAPI, page_path):
|
||||
"""注册页面路由"""
|
||||
path = f"{page_path}"
|
||||
|
||||
page_name = page_path.stem.removesuffix("_app_fastapi")
|
||||
module_name = ".".join(
|
||||
page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
|
||||
)
|
||||
|
||||
spec = spec_from_file_location(module_name, page_path)
|
||||
page = module_from_spec(spec)
|
||||
page.app = app
|
||||
page.router = None # FastAPI使用router而不是Blueprint
|
||||
sys.modules[module_name] = page
|
||||
spec.loader.exec_module(page)
|
||||
page_name = getattr(page, "page_name", page_name)
|
||||
sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
|
||||
url_prefix = (
|
||||
f"/api/{API_VERSION}" if sdk_path in path else f"/{API_VERSION}/{page_name}"
|
||||
)
|
||||
|
||||
# 在FastAPI中,我们需要检查是否有router属性
|
||||
if hasattr(page, 'router') and page.router:
|
||||
app.include_router(page.router, prefix=url_prefix)
|
||||
return url_prefix
|
||||
|
||||
def setup_routes(app: FastAPI):
|
||||
"""设置路由 - 注册所有接口"""
|
||||
from api.apps.user_app_fastapi import router as user_router
|
||||
from api.apps.kb_app import router as kb_router
|
||||
from api.apps.document_app import router as document_router
|
||||
from api.apps.file_app import router as file_router
|
||||
from api.apps.file2document_app import router as file2document_router
|
||||
|
||||
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
|
||||
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
|
||||
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
|
||||
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
|
||||
app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
|
||||
|
||||
def get_current_user_from_token(authorization: str):
|
||||
"""从token获取当前用户"""
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
logging.warning("Authentication attempt with empty access token")
|
||||
return None
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars")
|
||||
return None
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
logging.warning(f"User {user[0].email} has empty access_token in database")
|
||||
return None
|
||||
return user[0]
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app()
|
||||
|
||||
@app.middleware("http")
|
||||
async def db_close_middleware(request, call_next):
|
||||
"""数据库连接关闭中间件"""
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
finally:
|
||||
close_connection()
|
||||
|
||||
setup_routes(app)
|
||||
898
api/apps/api_app.py
Normal file
898
api/apps/api_app.py
Normal file
@@ -0,0 +1,898 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from flask import request, Response
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, LLMType, ParserType, FileSource
|
||||
from api.db.db_models import APIToken, Task, File
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import queue_tasks, TaskService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api import settings
|
||||
from api.utils import get_uuid, current_timestamp, datetime_format
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
generate_confirmation_token
|
||||
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import keyword_extraction
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from agent.canvas import Canvas
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def new_token():
|
||||
req = request.json
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
tenant_id = tenants[0].tenant_id
|
||||
obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id),
|
||||
"create_time": current_timestamp(),
|
||||
"create_date": datetime_format(datetime.now()),
|
||||
"update_time": None,
|
||||
"update_date": None
|
||||
}
|
||||
if req.get("canvas_id"):
|
||||
obj["dialog_id"] = req["canvas_id"]
|
||||
obj["source"] = "agent"
|
||||
else:
|
||||
obj["dialog_id"] = req["dialog_id"]
|
||||
|
||||
if not APITokenService.save(**obj):
|
||||
return get_data_error_result(message="Fail to new a dialog!")
|
||||
|
||||
return get_json_result(data=obj)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/token_list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def token_list():
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
id = request.args["dialog_id"] if "dialog_id" in request.args else request.args["canvas_id"]
|
||||
objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=id)
|
||||
return get_json_result(data=[o.to_dict() for o in objs])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("tokens", "tenant_id")
|
||||
@login_required
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
for token in req["tokens"]:
|
||||
APITokenService.filter_delete(
|
||||
[APIToken.tenant_id == req["tenant_id"], APIToken.token == token])
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/stats', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def stats():
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
objs = API4ConversationService.stats(
|
||||
tenants[0].tenant_id,
|
||||
request.args.get(
|
||||
"from_date",
|
||||
(datetime.now() -
|
||||
timedelta(
|
||||
days=7)).strftime("%Y-%m-%d 00:00:00")),
|
||||
request.args.get(
|
||||
"to_date",
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
|
||||
"agent" if "canvas_id" in request.args else None)
|
||||
res = {
|
||||
"pv": [(o["dt"], o["pv"]) for o in objs],
|
||||
"uv": [(o["dt"], o["uv"]) for o in objs],
|
||||
"speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs],
|
||||
"tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs],
|
||||
"round": [(o["dt"], o["round"]) for o in objs],
|
||||
"thumb_up": [(o["dt"], o["thumb_up"]) for o in objs]
|
||||
}
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
|
||||
def set_conversation():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
if objs[0].source == "agent":
|
||||
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent"
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
else:
|
||||
e, dia = DialogService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": dia.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = False
|
||||
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
def rename_field(ans):
|
||||
reference = ans['reference']
|
||||
if not isinstance(reference, dict):
|
||||
return
|
||||
for chunk_i in reference.get('chunks', []):
|
||||
if 'docnm_kwd' in chunk_i:
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
stream = req.get("stream", True)
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "content": ""}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=stream)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
if stream:
|
||||
assert isinstance(answer, partial), "Nothing. Is it over?"
|
||||
|
||||
def sse():
|
||||
nonlocal answer, cvs, conv
|
||||
try:
|
||||
for ans in answer():
|
||||
for k in ans.keys():
|
||||
final_ans[k] = ans[k]
|
||||
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
canvas.history.append(("assistant", final_ans["content"]))
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
fillin_conv(result)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
rename_field(result)
|
||||
return get_json_result(data=result)
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **req):
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
answer = ans
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
rename_field(answer)
|
||||
return get_json_result(data=answer)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_conversation(conversation_id):
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
e, conv = API4ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
||||
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for referenct_i in conv['reference']:
|
||||
if referenct_i is None or len(referenct_i) == 0:
|
||||
continue
|
||||
for chunk_i in referenct_i['chunks']:
|
||||
if 'docnm_kwd' in chunk_i.keys():
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/upload', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_name")
|
||||
def upload():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
kb_name = request.form.get("kb_name").strip()
|
||||
tenant_id = objs[0].tenant_id
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
kb_root_folder = FileService.get_kb_folder(tenant_id)
|
||||
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
try:
|
||||
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
|
||||
return get_data_error_result(
|
||||
message="Exceed the maximum file number of a free user!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
name=file.filename,
|
||||
kb_id=kb_id)
|
||||
filetype = filename_type(filename)
|
||||
if not filetype:
|
||||
return get_data_error_result(
|
||||
message="This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while STORAGE_IMPL.obj_exist(kb_id, location):
|
||||
location += "_"
|
||||
blob = request.files['file'].read()
|
||||
STORAGE_IMPL.put(kb_id, location, blob)
|
||||
doc = {
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": kb.tenant_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
|
||||
form_data = request.form
|
||||
if "parser_id" in form_data.keys():
|
||||
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
|
||||
doc["parser_id"] = request.form.get("parser_id").strip()
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
doc["parser_id"] = ParserType.EMAIL.value
|
||||
|
||||
doc_result = DocumentService.insert(doc)
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if "run" in form_data.keys():
|
||||
if request.form.get("run").strip() == "1":
|
||||
try:
|
||||
info = {"run": 1, "progress": 0}
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
DocumentService.update_by_id(doc["id"], info)
|
||||
# if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
tenant_id = DocumentService.get_tenant_id(doc["id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
# e, doc = DocumentService.get_by_id(doc["id"])
|
||||
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
||||
e, doc = DocumentService.get_by_id(doc["id"])
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=doc_result.to_json())
|
||||
|
||||
|
||||
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id")
|
||||
async def upload_parse():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist('file')
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = await doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_chunks():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
|
||||
try:
|
||||
if "doc_name" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
|
||||
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
|
||||
|
||||
elif "doc_id" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
|
||||
doc_id = req['doc_id']
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_name or doc_id"
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
"doc_name": res_item["docnm_kwd"],
|
||||
"image_id": res_item["img_id"]
|
||||
} for res_item in res
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=res)
|
||||
|
||||
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_chunk(chunk_id):
|
||||
from rag.nlp import search
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
if chunk is None:
|
||||
return server_error_response(Exception("Chunk not found"))
|
||||
k = []
|
||||
for n in chunk.keys():
|
||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||
k.append(n)
|
||||
for n in k:
|
||||
del chunk[n]
|
||||
|
||||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_kb_docs():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_name = req.get("kb_name", "").strip()
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
page_number = int(req.get("page", 1))
|
||||
items_per_page = int(req.get("page_size", 15))
|
||||
orderby = req.get("orderby", "create_time")
|
||||
desc = req.get("desc", True)
|
||||
keywords = req.get("keywords", "")
|
||||
status = req.get("status", [])
|
||||
if status:
|
||||
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
|
||||
)
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
|
||||
)
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
|
||||
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/infos', methods=['POST']) # noqa: F821
|
||||
@validate_request("doc_ids")
|
||||
def docinfos():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
doc_ids = req["doc_ids"]
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
return get_json_result(data=list(docs.dicts()))
|
||||
|
||||
|
||||
@manager.route('/document', methods=['DELETE']) # noqa: F821
|
||||
# @login_required
|
||||
def document_rm():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
try:
|
||||
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
|
||||
for doc_id in req.get("doc_ids", []):
|
||||
if doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
|
||||
if not doc_ids:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_names or doc_ids"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
|
||||
errors = ""
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
doc_dic = {}
|
||||
for doc in docs:
|
||||
doc_dic[doc.id] = doc
|
||||
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
if doc_id not in doc_dic:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
doc = doc_dic[doc_id]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
|
||||
STORAGE_IMPL.rm(b, n)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
|
||||
@validate_request("Authorization", "conversation_id", "word")
|
||||
def completion_faq():
|
||||
import base64
|
||||
req = request.json
|
||||
|
||||
token = req["Authorization"]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = True
|
||||
|
||||
msg = []
|
||||
msg.append({"role": "user", "content": req["word"]})
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "doc_aggs": []}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=False)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
ans = ""
|
||||
for a in chat(dia, msg, stream=False, **req):
|
||||
ans = a
|
||||
break
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
kb_ids = req.get("kb_id", [])
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
question = req.get("question")
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("page_size", 30))
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
top = int(req.get("top_k", 1024))
|
||||
highlight = bool(req.get("highlight", False))
|
||||
|
||||
try:
|
||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embd_nms) != 1:
|
||||
return get_json_result(
|
||||
data=False, message='Knowledge bases use different embedding models or does not exist."',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
|
||||
rank_feature=label_question(question, kbs))
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
76
api/apps/auth/README.md
Normal file
76
api/apps/auth/README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Auth
|
||||
|
||||
The Auth module provides implementations of OAuth2 and OpenID Connect (OIDC) authentication for integration with third-party identity providers.
|
||||
|
||||
**Features**
|
||||
|
||||
- Supports both OAuth2 and OIDC authentication protocols
|
||||
- Automatic OIDC configuration discovery (via `/.well-known/openid-configuration`)
|
||||
- JWT token validation
|
||||
- Unified user information handling
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# OAuth2 configuration
|
||||
oauth_config = {
|
||||
"type": "oauth2",
|
||||
"client_id": "your_client_id",
|
||||
"client_secret": "your_client_secret",
|
||||
"authorization_url": "https://your-oauth-provider.com/oauth/authorize",
|
||||
"token_url": "https://your-oauth-provider.com/oauth/token",
|
||||
"userinfo_url": "https://your-oauth-provider.com/oauth/userinfo",
|
||||
"redirect_uri": "https://your-app.com/v1/user/oauth/callback/<channel>"
|
||||
}
|
||||
|
||||
# OIDC configuration
|
||||
oidc_config = {
|
||||
"type": "oidc",
|
||||
"issuer": "https://your-oauth-provider.com/oidc",
|
||||
"client_id": "your_client_id",
|
||||
"client_secret": "your_client_secret",
|
||||
"redirect_uri": "https://your-app.com/v1/user/oauth/callback/<channel>"
|
||||
}
|
||||
|
||||
# Github OAuth configuration
|
||||
github_config = {
|
||||
"type": "github"
|
||||
"client_id": "your_client_id",
|
||||
"client_secret": "your_client_secret",
|
||||
"redirect_uri": "https://your-app.com/v1/user/oauth/callback/<channel>"
|
||||
}
|
||||
|
||||
# Get client instance
|
||||
client = get_auth_client(oauth_config)
|
||||
```
|
||||
|
||||
### Authentication Flow
|
||||
|
||||
1. Get authorization URL:
|
||||
```python
|
||||
auth_url = client.get_authorization_url()
|
||||
```
|
||||
|
||||
2. After user authorization, exchange authorization code for token:
|
||||
```python
|
||||
token_response = client.exchange_code_for_token(authorization_code)
|
||||
access_token = token_response["access_token"]
|
||||
```
|
||||
|
||||
3. Fetch user information:
|
||||
```python
|
||||
user_info = client.fetch_user_info(access_token)
|
||||
```
|
||||
|
||||
## User Information Structure
|
||||
|
||||
All authentication methods return user information following this structure:
|
||||
|
||||
```python
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"username": "username",
|
||||
"nickname": "User Name",
|
||||
"avatar_url": "https://example.com/avatar.jpg"
|
||||
}
|
||||
```
|
||||
40
api/apps/auth/__init__.py
Normal file
40
api/apps/auth/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from .oauth import OAuthClient
|
||||
from .oidc import OIDCClient
|
||||
from .github import GithubOAuthClient
|
||||
|
||||
|
||||
CLIENT_TYPES = {
|
||||
"oauth2": OAuthClient,
|
||||
"oidc": OIDCClient,
|
||||
"github": GithubOAuthClient
|
||||
}
|
||||
|
||||
|
||||
def get_auth_client(config)->OAuthClient:
|
||||
channel_type = str(config.get("type", "")).lower()
|
||||
if channel_type == "":
|
||||
if config.get("issuer"):
|
||||
channel_type = "oidc"
|
||||
else:
|
||||
channel_type = "oauth2"
|
||||
client_class = CLIENT_TYPES.get(channel_type)
|
||||
if not client_class:
|
||||
raise ValueError(f"Unsupported type: {channel_type}")
|
||||
|
||||
return client_class(config)
|
||||
63
api/apps/auth/github.py
Normal file
63
api/apps/auth/github.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
from .oauth import OAuthClient, UserInfo
|
||||
|
||||
|
||||
class GithubOAuthClient(OAuthClient):
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Initialize the GithubOAuthClient with the provider's configuration.
|
||||
"""
|
||||
config.update({
|
||||
"authorization_url": "https://github.com/login/oauth/authorize",
|
||||
"token_url": "https://github.com/login/oauth/access_token",
|
||||
"userinfo_url": "https://api.github.com/user",
|
||||
"scope": "user:email"
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
def fetch_user_info(self, access_token, **kwargs):
|
||||
"""
|
||||
Fetch github user info.
|
||||
"""
|
||||
user_info = {}
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
# user info
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info.update(response.json())
|
||||
# email info
|
||||
response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
email_info = response.json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email["primary"]), None
|
||||
)["email"]
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to fetch github user info: {e}")
|
||||
|
||||
|
||||
def normalize_user_info(self, user_info):
|
||||
email = user_info.get("email")
|
||||
username = user_info.get("login", str(email).split("@")[0])
|
||||
nickname = user_info.get("name", username)
|
||||
avatar_url = user_info.get("avatar_url", "")
|
||||
return UserInfo(email=email, username=username, nickname=nickname, avatar_url=avatar_url)
|
||||
110
api/apps/auth/oauth.py
Normal file
110
api/apps/auth/oauth.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
import urllib.parse
|
||||
|
||||
|
||||
class UserInfo:
|
||||
def __init__(self, email, username, nickname, avatar_url):
|
||||
self.email = email
|
||||
self.username = username
|
||||
self.nickname = nickname
|
||||
self.avatar_url = avatar_url
|
||||
|
||||
def to_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items()}
|
||||
|
||||
|
||||
class OAuthClient:
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Initialize the OAuthClient with the provider's configuration.
|
||||
"""
|
||||
self.client_id = config["client_id"]
|
||||
self.client_secret = config["client_secret"]
|
||||
self.authorization_url = config["authorization_url"]
|
||||
self.token_url = config["token_url"]
|
||||
self.userinfo_url = config["userinfo_url"]
|
||||
self.redirect_uri = config["redirect_uri"]
|
||||
self.scope = config.get("scope", None)
|
||||
|
||||
self.http_request_timeout = 7
|
||||
|
||||
|
||||
def get_authorization_url(self, state=None):
|
||||
"""
|
||||
Generate the authorization URL for user login.
|
||||
"""
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
}
|
||||
if self.scope:
|
||||
params["scope"] = self.scope
|
||||
if state:
|
||||
params["state"] = state
|
||||
authorization_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url
|
||||
|
||||
|
||||
def exchange_code_for_token(self, code):
|
||||
"""
|
||||
Exchange authorization code for access token.
|
||||
"""
|
||||
try:
|
||||
payload = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code"
|
||||
}
|
||||
response = requests.post(
|
||||
self.token_url,
|
||||
data=payload,
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=self.http_request_timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to exchange authorization code for token: {e}")
|
||||
|
||||
|
||||
def fetch_user_info(self, access_token, **kwargs):
|
||||
"""
|
||||
Fetch user information using access token.
|
||||
"""
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to fetch user info: {e}")
|
||||
|
||||
|
||||
def normalize_user_info(self, user_info):
|
||||
email = user_info.get("email")
|
||||
username = user_info.get("username", str(email).split("@")[0])
|
||||
nickname = user_info.get("nickname", username)
|
||||
avatar_url = user_info.get("avatar_url", None)
|
||||
if avatar_url is None:
|
||||
avatar_url = user_info.get("picture", "")
|
||||
return UserInfo(email=email, username=username, nickname=nickname, avatar_url=avatar_url)
|
||||
99
api/apps/auth/oidc.py
Normal file
99
api/apps/auth/oidc.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from .oauth import OAuthClient
|
||||
|
||||
|
||||
class OIDCClient(OAuthClient):
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Initialize the OIDCClient with the provider's configuration.
|
||||
Use `issuer` as the single source of truth for configuration discovery.
|
||||
"""
|
||||
self.issuer = config.get("issuer")
|
||||
if not self.issuer:
|
||||
raise ValueError("Missing issuer in configuration.")
|
||||
|
||||
oidc_metadata = self._load_oidc_metadata(self.issuer)
|
||||
config.update({
|
||||
'issuer': oidc_metadata['issuer'],
|
||||
'jwks_uri': oidc_metadata['jwks_uri'],
|
||||
'authorization_url': oidc_metadata['authorization_endpoint'],
|
||||
'token_url': oidc_metadata['token_endpoint'],
|
||||
'userinfo_url': oidc_metadata['userinfo_endpoint']
|
||||
})
|
||||
|
||||
super().__init__(config)
|
||||
self.issuer = config['issuer']
|
||||
self.jwks_uri = config['jwks_uri']
|
||||
|
||||
|
||||
def _load_oidc_metadata(self, issuer):
|
||||
"""
|
||||
Load OIDC metadata from `/.well-known/openid-configuration`.
|
||||
"""
|
||||
try:
|
||||
metadata_url = f"{issuer}/.well-known/openid-configuration"
|
||||
response = requests.get(metadata_url, timeout=7)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Failed to fetch OIDC metadata: {e}")
|
||||
|
||||
|
||||
def parse_id_token(self, id_token):
|
||||
"""
|
||||
Parse and validate OIDC ID Token (JWT format) with signature verification.
|
||||
"""
|
||||
try:
|
||||
# Decode JWT header without verifying signature
|
||||
headers = jwt.get_unverified_header(id_token)
|
||||
|
||||
# OIDC usually uses `RS256` for signing
|
||||
alg = headers.get("alg", "RS256")
|
||||
|
||||
# Use PyJWT's PyJWKClient to fetch JWKS and find signing key
|
||||
jwks_cli = jwt.PyJWKClient(self.jwks_uri)
|
||||
signing_key = jwks_cli.get_signing_key_from_jwt(id_token).key
|
||||
|
||||
# Decode and verify signature
|
||||
decoded_token = jwt.decode(
|
||||
id_token,
|
||||
key=signing_key,
|
||||
algorithms=[alg],
|
||||
audience=str(self.client_id),
|
||||
issuer=self.issuer,
|
||||
)
|
||||
return decoded_token
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error parsing ID Token: {e}")
|
||||
|
||||
|
||||
def fetch_user_info(self, access_token, id_token=None, **kwargs):
|
||||
"""
|
||||
Fetch user info.
|
||||
"""
|
||||
user_info = {}
|
||||
if id_token:
|
||||
user_info = self.parse_id_token(id_token)
|
||||
user_info.update(super().fetch_user_info(access_token).to_dict())
|
||||
return self.normalize_user_info(user_info)
|
||||
|
||||
|
||||
def normalize_user_info(self, user_info):
|
||||
return super().normalize_user_info(user_info)
|
||||
564
api/apps/canvas_app.py
Normal file
564
api/apps/canvas_app.py
Normal file
@@ -0,0 +1,564 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import flask
|
||||
import trio
|
||||
from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from agent.component import LLM
|
||||
from api import settings
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.settings import RetCode
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
import time
|
||||
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def templates():
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.delete_by_id(i)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
cate = req.get("canvas_category", CanvasCategory.Agent)
|
||||
if "id" not in req:
|
||||
req["user_id"] = current_user.id
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
|
||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||
req["id"] = get_uuid()
|
||||
if not UserCanvasService.save(**req):
|
||||
return get_data_error_result(message="Fail to save canvas.")
|
||||
else:
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.update_by_id(req["id"], req)
|
||||
# save version
|
||||
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
||||
UserCanvasVersionService.delete_all_versions(req["id"])
|
||||
return get_json_result(data=req)
|
||||
|
||||
|
||||
@manager.route('/get/<canvas_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@manager.route('/getsse/<canvas_id>', methods=['GET']) # type: ignore # noqa: F821
|
||||
def getsse(canvas_id):
|
||||
token = request.headers.get('Authorization').split()
|
||||
if len(token) != 2:
|
||||
return get_data_error_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
if not objs:
|
||||
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=canvas_id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR
|
||||
)
|
||||
e, c = UserCanvasService.get_by_id(canvas_id)
|
||||
if not e or c.user_id != tenant_id:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
return get_json_result(data=c.to_dict())
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
def sse():
|
||||
nonlocal canvas, user_id
|
||||
try:
|
||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
def rerun():
|
||||
req = request.json
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
doc = doc[0]
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||
doc["progress_msg"] = ""
|
||||
doc["chunk_num"] = 0
|
||||
doc["token_num"] = 0
|
||||
DocumentService.clear_chunk_num_when_rerun(doc["id"])
|
||||
DocumentService.update_by_id(id, doc)
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
|
||||
dsl = req["dsl"]
|
||||
dsl["path"] = [req["component_id"]]
|
||||
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
def cancel(task_id):
|
||||
try:
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
canvas.reset()
|
||||
req["dsl"] = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
|
||||
return get_json_result(data=req["dsl"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None
|
||||
}
|
||||
|
||||
if request.args.get("url"):
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
DefaultMarkdownGenerator,
|
||||
PruningContentFilter,
|
||||
CrawlResult
|
||||
)
|
||||
try:
|
||||
url = request.args.get("url")
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter()
|
||||
),
|
||||
pdf=True,
|
||||
screenshot=False
|
||||
)
|
||||
result: CrawlResult = await crawler.arun(
|
||||
url=url,
|
||||
config=crawler_config
|
||||
)
|
||||
return result
|
||||
page = trio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
|
||||
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files['file']
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/input_form', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def input_form():
|
||||
cvs_id = request.args.get("id")
|
||||
cpn_id = request.args.get("component_id")
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(cvs_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
if not UserCanvasService.query(user_id=current_user.id, id=cvs_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
return get_json_result(data=canvas.get_component_input_form(cpn_id))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/debug', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
canvas.reset()
|
||||
canvas.message_id = get_uuid()
|
||||
component = canvas.get_component(req["component_id"])["obj"]
|
||||
component.reset()
|
||||
|
||||
if isinstance(component, LLM):
|
||||
component.set_debug_inputs(req["params"])
|
||||
component.invoke(**{k: o["value"] for k,o in req["params"].items()})
|
||||
outputs = component.output()
|
||||
for k in outputs.keys():
|
||||
if isinstance(outputs[k], partial):
|
||||
txt = ""
|
||||
for c in outputs[k]():
|
||||
txt += c
|
||||
outputs[k] = txt
|
||||
return get_json_result(data=outputs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
def test_db_connect():
|
||||
req = request.json
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
password=req["password"])
|
||||
elif req["db_type"] == 'postgres':
|
||||
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
password=req["password"])
|
||||
elif req["db_type"] == 'mssql':
|
||||
import pyodbc
|
||||
connection_string = (
|
||||
f"DRIVER={{ODBC Driver 17 for SQL Server}};"
|
||||
f"SERVER={req['host']},{req['port']};"
|
||||
f"DATABASE={req['database']};"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
db = pyodbc.connect(connection_string)
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
elif req["db_type"] == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
logging.info(conn_str)
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
else:
|
||||
return server_error_response("Unsupported database type.")
|
||||
if req["db_type"] != 'mssql':
|
||||
db.connect()
|
||||
db.close()
|
||||
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
#api get list version dsl of canvas
|
||||
@manager.route('/getlistversion/<canvas_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=list)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
|
||||
#api get version dsl of canvas
|
||||
@manager.route('/getversion/<version_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def getversion( version_id):
|
||||
try:
|
||||
|
||||
e, version = UserCanvasVersionService.get_by_id(version_id)
|
||||
if version:
|
||||
return get_json_result(data=version.to_dict())
|
||||
except Exception as e:
|
||||
return get_json_result(data=f"Error getting history file: {e}")
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_canvas():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
canvas_category = request.args.get("canvas_category")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
|
||||
if not owner_ids:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
tenants = [m["tenant_id"] for m in tenants]
|
||||
tenants.append(current_user.id)
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
tenants, current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, canvas_category)
|
||||
else:
|
||||
tenants = owner_ids
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
tenants, current_user.id, 0,
|
||||
0, orderby, desc, keywords, canvas_category)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
|
||||
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e,flow = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
flow = flow.to_dict()
|
||||
flow["title"] = req["title"]
|
||||
|
||||
for key in ["description", "permission", "avatar"]:
|
||||
if value := req.get(key):
|
||||
flow[key] = value
|
||||
|
||||
num= UserCanvasService.update_by_id(req["id"], flow)
|
||||
return get_json_result(data=num)
|
||||
|
||||
|
||||
@manager.route('/trace', methods=['GET']) # noqa: F821
|
||||
def trace():
|
||||
cvs_id = request.args.get("canvas_id")
|
||||
msg_id = request.args.get("message_id")
|
||||
try:
|
||||
bin = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs")
|
||||
if not bin:
|
||||
return get_json_result(data={})
|
||||
|
||||
return get_json_result(data=json.loads(bin.encode("utf-8")))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
|
||||
@manager.route('/<canvas_id>/sessions', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def sessions(canvas_id):
|
||||
tenant_id = current_user.id
|
||||
if not UserCanvasService.accessible(canvas_id, tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
user_id = request.args.get("user_id")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
keywords = request.args.get("keywords")
|
||||
from_date = request.args.get("from_date")
|
||||
to_date = request.args.get("to_date")
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
|
||||
None, user_id, include_dsl, keywords, from_date, to_date)
|
||||
try:
|
||||
return get_json_result(data={"total": total, "sessions": sess})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def prompts():
|
||||
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
return get_json_result(data={
|
||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||
"plan_generation": NEXT_STEP,
|
||||
"reflection": REFLECT,
|
||||
#"context_summary": SUMMARY4MEMORY,
|
||||
#"context_ranking": RANK_MEMORY,
|
||||
"citation_guidelines": CITATION_PROMPT_TEMPLATE
|
||||
})
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return flask.make_response(blob)
|
||||
415
api/apps/chunk_app.py
Normal file
415
api/apps/chunk_app.py
Normal file
@@ -0,0 +1,415 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
|
||||
import xxhash
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def list_chunk():
|
||||
req = request.json
|
||||
doc_id = req["doc_id"]
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req.get("keywords", "")
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
query = {
|
||||
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
|
||||
}
|
||||
if "available_int" in req:
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[
|
||||
id].get(
|
||||
"content_with_weight", ""),
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
"question_kwd": sres.field[id].get("question_kwd", []),
|
||||
"image_id": sres.field[id].get("img_id", ""),
|
||||
"available_int": int(sres.field[id].get("available_int", 1)),
|
||||
"positions": sres.field[id].get("position_int", []),
|
||||
}
|
||||
assert isinstance(d["positions"], list)
|
||||
assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
|
||||
res["chunks"].append(d)
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found!',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
chunk_id = request.args["chunk_id"]
|
||||
try:
|
||||
chunk = None
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
for tenant in tenants:
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
|
||||
if chunk:
|
||||
break
|
||||
if chunk is None:
|
||||
return server_error_response(Exception("Chunk not found"))
|
||||
|
||||
k = []
|
||||
for n in chunk.keys():
|
||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||
k.append(n)
|
||||
for n in k:
|
||||
del chunk[n]
|
||||
|
||||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
if str(e).find("NotFoundError") >= 0:
|
||||
return get_json_result(data=False, message='Chunk not found!',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
def set():
|
||||
req = request.json
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if "important_kwd" in req:
|
||||
if not isinstance(req["important_kwd"], list):
|
||||
return get_data_error_result(message="`important_kwd` should be a list")
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
|
||||
if "question_kwd" in req:
|
||||
if not isinstance(req["question_kwd"], list):
|
||||
return get_data_error_result(message="`question_kwd` should be a list")
|
||||
d["question_kwd"] = req["question_kwd"]
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
|
||||
if "tag_kwd" in req:
|
||||
d["tag_kwd"] = req["tag_kwd"]
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
if "available_int" in req:
|
||||
d["available_int"] = req["available_int"]
|
||||
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/switch', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||
def switch():
|
||||
req = request.json
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
def rm():
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
req = request.json
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
def create():
|
||||
req = request.json
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
d["important_kwd"] = req.get("important_kwd", [])
|
||||
if not isinstance(d["important_kwd"], list):
|
||||
return get_data_error_result(message="`important_kwd` is required to be a list")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
d["question_kwd"] = req.get("question_kwd", [])
|
||||
if not isinstance(d["question_kwd"], list):
|
||||
return get_data_error_result(message="`question_kwd` is required to be a list")
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
def retrieval_test():
|
||||
req = request.json
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
kb_ids = req["kb_id"]
|
||||
if isinstance(kb_ids, str):
|
||||
kb_ids = [kb_ids]
|
||||
if not kb_ids:
|
||||
return get_json_result(data=False, message='Please specify dataset firstly.',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
tenant_ids = []
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
tenant_id=tenant.tenant_id, id=kb_id):
|
||||
tenant_ids.append(tenant.tenant_id)
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
rerank_mdl = None
|
||||
if req.get("rerank_id"):
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
||||
rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
|
||||
for c in ranks["chunks"]:
|
||||
c.pop("vector", None)
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
code=settings.RetCode.DATA_ERROR)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def knowledge_graph():
|
||||
doc_id = request.args["doc_id"]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
req = {
|
||||
"doc_ids": [doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
try:
|
||||
content_json = json.loads(sres.field[id]["content_with_weight"])
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ty == 'mind_map':
|
||||
node_dict = {}
|
||||
|
||||
def repeat_deal(content_json, node_dict):
|
||||
if 'id' in content_json:
|
||||
if content_json['id'] in node_dict:
|
||||
node_name = content_json['id']
|
||||
content_json['id'] += f"({node_dict[content_json['id']]})"
|
||||
node_dict[node_name] += 1
|
||||
else:
|
||||
node_dict[content_json['id']] = 1
|
||||
if 'children' in content_json and content_json['children']:
|
||||
for item in content_json['children']:
|
||||
repeat_deal(item, node_dict)
|
||||
|
||||
repeat_deal(content_json, node_dict)
|
||||
|
||||
obj[ty] = content_json
|
||||
|
||||
return get_json_result(data=obj)
|
||||
419
api/apps/conversation_app.py
Normal file
419
api/apps/conversation_app.py
Normal file
@@ -0,0 +1,419 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_conversation():
|
||||
req = request.json
|
||||
conv_id = req.get("conversation_id")
|
||||
is_new = req.get("is_new")
|
||||
name = req.get("name", "New conversation")
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if len(name) > 255:
|
||||
name = name[0:255]
|
||||
|
||||
del req["is_new"]
|
||||
if not is_new:
|
||||
del req["conversation_id"]
|
||||
try:
|
||||
if not ConversationService.update_by_id(conv_id, req):
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Fail to update a conversation!")
|
||||
conv = conv.to_dict()
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
try:
|
||||
e, dia = DialogService.get_by_id(req["dialog_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": conv_id,
|
||||
"dialog_id": req["dialog_id"],
|
||||
"name": name,
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
|
||||
"user_id": current_user.id,
|
||||
"reference": [],
|
||||
}
|
||||
ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
conv_id = request.args["conversation_id"]
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
avatar = None
|
||||
for tenant in tenants:
|
||||
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
|
||||
if dialog and len(dialog) > 0:
|
||||
avatar = dialog[0].icon
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
for ref in conv.reference:
|
||||
if isinstance(ref, list):
|
||||
continue
|
||||
ref["chunks"] = chunks_format(ref)
|
||||
|
||||
conv = conv.to_dict()
|
||||
conv["avatar"] = avatar
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
|
||||
def getsse(dialog_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_data_error_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
if not objs:
|
||||
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
||||
try:
|
||||
e, conv = DialogService.get_by_id(dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
conv = conv.to_dict()
|
||||
conv["avatar"] = conv["icon"]
|
||||
del conv["icon"]
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm():
|
||||
conv_ids = request.json["conversation_ids"]
|
||||
try:
|
||||
for cid in conv_ids:
|
||||
exist, conv = ConversationService.get_by_id(cid)
|
||||
if not exist:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
ConversationService.delete_by_id(cid)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_conversation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||
return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
|
||||
|
||||
convs = [d.to_dict() for d in convs]
|
||||
return get_json_result(data=convs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/completion", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
req = request.json
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
message_id = msg[-1].get("id")
|
||||
chat_model_id = req.get("llm_id", "")
|
||||
req.pop("llm_id", None)
|
||||
|
||||
chat_model_config = {}
|
||||
for model_config in [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"max_tokens",
|
||||
]:
|
||||
config = req.get(model_config)
|
||||
if config:
|
||||
chat_model_config[model_config] = config
|
||||
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
conv.message = deepcopy(req["messages"])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.reference = [r for r in conv.reference if r]
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
if chat_model_id:
|
||||
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
|
||||
req.pop("chat_model_id", None)
|
||||
req.pop("chat_model_config", None)
|
||||
return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
|
||||
dia.llm_id = chat_model_id
|
||||
dia.llm_setting = chat_model_config
|
||||
|
||||
is_embedded = bool(chat_model_id)
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **req):
|
||||
ans = structure_answer(conv, ans, message_id, conv.id)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
if not is_embedded:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
else:
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
answer = structure_answer(conv, ans, message_id, conv.id)
|
||||
if not is_embedded:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
break
|
||||
return get_json_result(data=answer)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def tts():
|
||||
req = request.json
|
||||
text = req["text"]
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
tts_id = tenants[0]["tts_id"]
|
||||
if not tts_id:
|
||||
return get_data_error_result(message="No default TTS model is set")
|
||||
|
||||
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
||||
|
||||
def stream_audio():
|
||||
try:
|
||||
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
|
||||
for chunk in tts_mdl.tts(txt):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route("/delete_msg", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def delete_msg():
|
||||
req = request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
for i, msg in enumerate(conv["message"]):
|
||||
if req["message_id"] != msg.get("id", ""):
|
||||
continue
|
||||
assert conv["message"][i + 1]["id"] == req["message_id"]
|
||||
conv["message"].pop(i)
|
||||
conv["message"].pop(i)
|
||||
conv["reference"].pop(max(0, i // 2 - 1))
|
||||
break
|
||||
|
||||
ConversationService.update_by_id(conv["id"], conv)
|
||||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route("/thumbup", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
def thumbup():
|
||||
req = request.json
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
up_down = req.get("thumbup")
|
||||
feedback = req.get("feedback", "")
|
||||
conv = conv.to_dict()
|
||||
for i, msg in enumerate(conv["message"]):
|
||||
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
|
||||
if up_down:
|
||||
msg["thumbup"] = True
|
||||
if "feedback" in msg:
|
||||
del msg["feedback"]
|
||||
else:
|
||||
msg["thumbup"] = False
|
||||
if feedback:
|
||||
msg["feedback"] = feedback
|
||||
break
|
||||
|
||||
ConversationService.update_by_id(conv["id"], conv)
|
||||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route("/ask", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def ask_about():
|
||||
req = request.json
|
||||
uid = current_user.id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = None
|
||||
search_config = {}
|
||||
if search_id:
|
||||
search_app = SearchService.get_detail(search_id)
|
||||
if search_app:
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
def stream():
|
||||
nonlocal req, uid
|
||||
try:
|
||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route("/mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
def mindmap():
|
||||
req = request.json
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||
kb_ids = search_config.get("kb_ids", [])
|
||||
kb_ids.extend(req["kb_ids"])
|
||||
kb_ids = list(set(kb_ids))
|
||||
|
||||
mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
||||
|
||||
@manager.route("/related_questions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
def related_questions():
|
||||
req = request.json
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
if search_id:
|
||||
if search_app := SearchService.get_detail(search_id):
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
question = req["question"]
|
||||
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
|
||||
|
||||
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
||||
if "parameter" in gen_conf:
|
||||
del gen_conf["parameter"]
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Keywords: {question}
|
||||
Related search terms:
|
||||
""",
|
||||
}
|
||||
],
|
||||
gen_conf,
|
||||
)
|
||||
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
|
||||
227
api/apps/dialog_app.py
Normal file
227
api/apps/dialog_app.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api import settings
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("prompt_config")
|
||||
@login_required
|
||||
def set_dialog():
|
||||
req = request.json
|
||||
dialog_id = req.get("dialog_id", "")
|
||||
is_create = not dialog_id
|
||||
name = req.get("name", "New Dialog")
|
||||
if not isinstance(name, str):
|
||||
return get_data_error_result(message="Dialog name must be string.")
|
||||
if name.strip() == "":
|
||||
return get_data_error_result(message="Dialog name can't be empty.")
|
||||
if len(name.encode("utf-8")) > 255:
|
||||
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
|
||||
|
||||
if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
|
||||
name = name.strip()
|
||||
name = duplicate_name(
|
||||
DialogService.query,
|
||||
name=name,
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value)
|
||||
|
||||
description = req.get("description", "A helpful dialog")
|
||||
icon = req.get("icon", "")
|
||||
top_n = req.get("top_n", 6)
|
||||
top_k = req.get("top_k", 1024)
|
||||
rerank_id = req.get("rerank_id", "")
|
||||
if not rerank_id:
|
||||
req["rerank_id"] = ""
|
||||
similarity_threshold = req.get("similarity_threshold", 0.1)
|
||||
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
||||
llm_setting = req.get("llm_setting", {})
|
||||
meta_data_filter = req.get("meta_data_filter", {})
|
||||
prompt_config = req["prompt_config"]
|
||||
|
||||
if not is_create:
|
||||
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
|
||||
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no knowledge base / Tavily used here.")
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["optional"]:
|
||||
continue
|
||||
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_data_error_result(
|
||||
message="Parameter '{}' is not used".format(p["key"]))
|
||||
|
||||
try:
|
||||
e, tenant = TenantService.get_by_id(current_user.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids", []))
|
||||
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||
embd_count = len(set(embd_ids))
|
||||
if embd_count > 1:
|
||||
return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')
|
||||
|
||||
llm_id = req.get("llm_id", tenant.llm_id)
|
||||
if not dialog_id:
|
||||
dia = {
|
||||
"id": get_uuid(),
|
||||
"tenant_id": current_user.id,
|
||||
"name": name,
|
||||
"kb_ids": req.get("kb_ids", []),
|
||||
"description": description,
|
||||
"llm_id": llm_id,
|
||||
"llm_setting": llm_setting,
|
||||
"prompt_config": prompt_config,
|
||||
"meta_data_filter": meta_data_filter,
|
||||
"top_n": top_n,
|
||||
"top_k": top_k,
|
||||
"rerank_id": rerank_id,
|
||||
"similarity_threshold": similarity_threshold,
|
||||
"vector_similarity_weight": vector_similarity_weight,
|
||||
"icon": icon
|
||||
}
|
||||
if not DialogService.save(**dia):
|
||||
return get_data_error_result(message="Fail to new a dialog!")
|
||||
return get_json_result(data=dia)
|
||||
else:
|
||||
del req["dialog_id"]
|
||||
if "kb_names" in req:
|
||||
del req["kb_names"]
|
||||
if not DialogService.update_by_id(dialog_id, req):
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Fail to update a dialog!")
|
||||
dia = dia.to_dict()
|
||||
dia.update(req)
|
||||
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
|
||||
return get_json_result(data=dia)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/get', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
e, dia = DialogService.get_by_id(dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
dia = dia.to_dict()
|
||||
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
|
||||
return get_json_result(data=dia)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
def get_kb_names(kb_ids):
|
||||
ids, nms = [], []
|
||||
for kid in kb_ids:
|
||||
e, kb = KnowledgebaseService.get_by_id(kid)
|
||||
if not e or kb.status != StatusEnum.VALID.value:
|
||||
continue
|
||||
ids.append(kid)
|
||||
nms.append(kb.name)
|
||||
return ids, nms
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_dialogs():
|
||||
try:
|
||||
diags = DialogService.query(
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value,
|
||||
reverse=True,
|
||||
order_by=DialogService.model.create_time)
|
||||
diags = [d.to_dict() for d in diags]
|
||||
for d in diags:
|
||||
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
|
||||
return get_json_result(data=diags)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/next', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
def list_dialogs_next():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
parser_id = request.args.get("parser_id")
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
# tenants = [tenant["tenant_id"] for tenant in tenants]
|
||||
tenants = [] # keep it here
|
||||
dialogs, total = DialogService.get_by_tenant_ids(
|
||||
tenants, current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, parser_id)
|
||||
else:
|
||||
tenants = owner_ids
|
||||
dialogs, total = DialogService.get_by_tenant_ids(
|
||||
tenants, current_user.id, 0,
|
||||
0, orderby, desc, keywords, parser_id)
|
||||
dialogs = [dialog for dialog in dialogs if dialog["tenant_id"] in tenants]
|
||||
total = len(dialogs)
|
||||
if page_number and items_per_page:
|
||||
dialogs = dialogs[(page_number-1)*items_per_page:page_number*items_per_page]
|
||||
return get_json_result(data={"dialogs": dialogs, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("dialog_ids")
|
||||
def rm():
|
||||
req = request.json
|
||||
dialog_list=[]
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
try:
|
||||
for id in req["dialog_ids"]:
|
||||
for tenant in tenants:
|
||||
if DialogService.query(tenant_id=tenant.tenant_id, id=id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of dialog authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
|
||||
DialogService.update_many_by_id(dialog_list)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
873
api/apps/document_app.py
Normal file
873
api/apps/document_app.py
Normal file
@@ -0,0 +1,873 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from api import settings
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus
|
||||
from api.db.db_models import File, Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
|
||||
from deepdoc.parser.html_parser import RAGFlowHtmlParser
|
||||
from rag.nlp import search
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from pydantic import BaseModel
|
||||
from api.db.db_models import User
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
|
||||
# Pydantic models for request/response
|
||||
class WebCrawlRequest(BaseModel):
|
||||
kb_id: str
|
||||
name: str
|
||||
url: str
|
||||
|
||||
class CreateDocumentRequest(BaseModel):
|
||||
name: str
|
||||
kb_id: str
|
||||
|
||||
class DocumentListRequest(BaseModel):
|
||||
run_status: List[str] = []
|
||||
types: List[str] = []
|
||||
suffix: List[str] = []
|
||||
|
||||
class DocumentFilterRequest(BaseModel):
|
||||
kb_id: str
|
||||
keywords: str = ""
|
||||
run_status: List[str] = []
|
||||
types: List[str] = []
|
||||
suffix: List[str] = []
|
||||
|
||||
class DocumentInfosRequest(BaseModel):
|
||||
doc_ids: List[str]
|
||||
|
||||
class ChangeStatusRequest(BaseModel):
|
||||
doc_ids: List[str]
|
||||
status: str
|
||||
|
||||
class RemoveDocumentRequest(BaseModel):
|
||||
doc_id: List[str]
|
||||
|
||||
class RunDocumentRequest(BaseModel):
|
||||
doc_ids: List[str]
|
||||
run: str
|
||||
delete: bool = False
|
||||
|
||||
class RenameDocumentRequest(BaseModel):
|
||||
doc_id: str
|
||||
name: str
|
||||
|
||||
class ChangeParserRequest(BaseModel):
|
||||
doc_id: str
|
||||
parser_id: str
|
||||
pipeline_id: Optional[str] = None
|
||||
parser_config: Optional[dict] = None
|
||||
|
||||
class UploadAndParseRequest(BaseModel):
|
||||
conversation_id: str
|
||||
|
||||
class ParseRequest(BaseModel):
|
||||
url: Optional[str] = None
|
||||
|
||||
class SetMetaRequest(BaseModel):
|
||||
doc_id: str
|
||||
meta: str
|
||||
|
||||
|
||||
# Dependency injection
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.user_service import UserService
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload(
|
||||
kb_id: str = Form(...),
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if not files:
|
||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
# Use UploadFile directly
|
||||
file_objs = files
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, files = await FileService.upload_document(kb, file_objs, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
if not files:
|
||||
return get_json_result(data=files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR)
|
||||
files = [f[0] for f in files] # remove the blob
|
||||
|
||||
return get_json_result(data=files)
|
||||
|
||||
|
||||
@router.post("/web_crawl")
|
||||
async def web_crawl(
|
||||
req: WebCrawlRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = req.kb_id
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
name = req.name
|
||||
url = req.url
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = html2pdf(url)
|
||||
if not blob:
|
||||
return server_error_response(ValueError("Download failure."))
|
||||
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, current_user.id)
|
||||
kb_root_folder = FileService.get_kb_folder(current_user.id)
|
||||
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
try:
|
||||
filename = duplicate_name(DocumentService.query, name=name + ".pdf", kb_id=kb.id)
|
||||
filetype = filename_type(filename)
|
||||
if filetype == FileType.OTHER.value:
|
||||
raise RuntimeError("This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while STORAGE_IMPL.obj_exist(kb_id, location):
|
||||
location += "_"
|
||||
STORAGE_IMPL.put(kb_id, location, blob)
|
||||
doc = {
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
doc["parser_id"] = ParserType.EMAIL.value
|
||||
DocumentService.insert(doc)
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
async def create(
|
||||
req: CreateDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = req.kb_id
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if len(req.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if req.name.strip() == "":
|
||||
return get_json_result(data=False, message="File name can't be empty.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
req.name = req.name.strip()
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Can't find this knowledgebase!")
|
||||
|
||||
if DocumentService.query(name=req.name, kb_id=kb_id):
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
kb_root_folder = FileService.get_kb_folder(kb.tenant_id)
|
||||
if not kb_root_folder:
|
||||
return get_data_error_result(message="Cannot find the root folder.")
|
||||
kb_folder = FileService.new_a_file_from_kb(
|
||||
kb.tenant_id,
|
||||
kb.name,
|
||||
kb_root_folder["id"],
|
||||
)
|
||||
if not kb_folder:
|
||||
return get_data_error_result(message="Cannot find the kb folder for this file.")
|
||||
|
||||
doc = DocumentService.insert(
|
||||
{
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": FileType.VIRTUAL,
|
||||
"name": req.name,
|
||||
"suffix": Path(req.name).suffix.lstrip("."),
|
||||
"location": "",
|
||||
"size": 0,
|
||||
}
|
||||
)
|
||||
|
||||
FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], kb.tenant_id)
|
||||
|
||||
return get_json_result(data=doc.to_json())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/list")
|
||||
async def list_docs(
|
||||
kb_id: str = Query(...),
|
||||
keywords: str = Query(""),
|
||||
page: int = Query(0),
|
||||
page_size: int = Query(0),
|
||||
orderby: str = Query("create_time"),
|
||||
desc: str = Query("true"),
|
||||
create_time_from: int = Query(0),
|
||||
create_time_to: int = Query(0),
|
||||
req: DocumentListRequest = None,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
if desc.lower() == "false":
|
||||
desc_bool = False
|
||||
else:
|
||||
desc_bool = True
|
||||
|
||||
run_status = req.run_status if req else []
|
||||
if run_status:
|
||||
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.types if req else []
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = req.suffix if req else []
|
||||
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(kb_id, page, page_size, orderby, desc_bool, keywords, run_status, types, suffix)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in docs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
docs = filtered_docs
|
||||
|
||||
for doc_item in docs:
|
||||
if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
|
||||
doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/filter")
|
||||
async def get_filter(
|
||||
req: DocumentFilterRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = req.kb_id
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
keywords = req.keywords
|
||||
suffix = req.suffix
|
||||
run_status = req.run_status
|
||||
if run_status:
|
||||
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.types
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
try:
|
||||
filter, total = DocumentService.get_filter_by_kb_id(kb_id, keywords, run_status, types, suffix)
|
||||
return get_json_result(data={"total": total, "filter": filter})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/infos")
|
||||
async def docinfos(
|
||||
req: DocumentInfosRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
doc_ids = req.doc_ids
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
return get_json_result(data=list(docs.dicts()))
|
||||
|
||||
|
||||
@router.get("/thumbnails")
|
||||
async def thumbnails(
|
||||
doc_ids: List[str] = Query(...)
|
||||
):
|
||||
if not doc_ids:
|
||||
return get_json_result(data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
try:
|
||||
docs = DocumentService.get_thumbnails(doc_ids)
|
||||
|
||||
for doc_item in docs:
|
||||
if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
|
||||
doc_item["thumbnail"] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}"
|
||||
|
||||
return get_json_result(data={d["id"]: d["thumbnail"] for d in docs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/change_status")
|
||||
async def change_status(
|
||||
req: ChangeStatusRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
doc_ids = req.doc_ids
|
||||
status = str(req.status)
|
||||
|
||||
if status not in ["0", "1"]:
|
||||
return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
result = {}
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
result[doc_id] = {"error": "No authorization."}
|
||||
continue
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
result[doc_id] = {"error": "No authorization."}
|
||||
continue
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
result[doc_id] = {"error": "Can't find this knowledgebase!"}
|
||||
continue
|
||||
if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
|
||||
result[doc_id] = {"error": "Database error (Document update)!"}
|
||||
continue
|
||||
|
||||
status_int = int(status)
|
||||
if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
|
||||
result[doc_id] = {"error": "Database error (docStore update)!"}
|
||||
result[doc_id] = {"status": status}
|
||||
except Exception as e:
|
||||
result[doc_id] = {"error": f"Internal server error: {str(e)}"}
|
||||
|
||||
return get_json_result(data=result)
|
||||
|
||||
|
||||
@router.post("/rm")
|
||||
async def rm(
|
||||
req: RemoveDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
doc_ids = req.doc_id
|
||||
if isinstance(doc_ids, str):
|
||||
doc_ids = [doc_ids]
|
||||
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, current_user.id)
|
||||
errors = ""
|
||||
kb_table_num_map = {}
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
TaskService.filter_delete([Task.doc_id == doc_id])
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
deleted_file_count = 0
|
||||
if f2d:
|
||||
deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
if deleted_file_count > 0:
|
||||
STORAGE_IMPL.rm(b, n)
|
||||
|
||||
doc_parser = doc.parser_id
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.kb_id
|
||||
if kb_id not in kb_table_num_map:
|
||||
counts = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = counts
|
||||
kb_table_num_map[kb_id] -= 1
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@router.post("/run")
|
||||
async def run(
|
||||
req: RunDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
for doc_id in req.doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
kb_table_num_map = {}
|
||||
for id in req.doc_ids:
|
||||
info = {"run": str(req.run), "progress": 0}
|
||||
if str(req.run) == TaskStatus.RUNNING.value and req.delete:
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if str(req.run) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([req.delete, str(req.run) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.delete:
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
if str(req.run) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
|
||||
doc_parser = doc.get("parser_id", ParserType.NAIVE)
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
if kb_id not in kb_table_num_map:
|
||||
count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/rename")
|
||||
async def rename(
|
||||
req: RenameDocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not DocumentService.accessible(req.doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req.doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req.name.lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if len(req.name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
for d in DocumentService.query(name=req.name, kb_id=doc.kb_id):
|
||||
if d.name == req.name:
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
if not DocumentService.update_by_id(req.doc_id, {"name": req.name}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
|
||||
informs = File2DocumentService.get_by_document_id(req.doc_id)
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req.name})
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get("/get/{doc_id}")
|
||||
async def get(doc_id: str):
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
content = STORAGE_IMPL.get(b, n)
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
|
||||
if ext:
|
||||
if doc.type == FileType.VISUAL.value:
|
||||
media_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
media_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f"attachment; filename={doc.name}"}
|
||||
)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/change_parser")
|
||||
async def change_parser(
|
||||
req: ChangeParserRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not DocumentService.accessible(req.doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req.doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
def reset_doc():
|
||||
nonlocal doc
|
||||
e = DocumentService.update_by_id(doc.id, {"parser_id": req.parser_id, "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if doc.token_num > 0:
|
||||
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(req.doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
try:
|
||||
if req.pipeline_id:
|
||||
if doc.pipeline_id == req.pipeline_id:
|
||||
return get_json_result(data=True)
|
||||
DocumentService.update_by_id(doc.id, {"pipeline_id": req.pipeline_id})
|
||||
reset_doc()
|
||||
return get_json_result(data=True)
|
||||
|
||||
if doc.parser_id.lower() == req.parser_id.lower():
|
||||
if req.parser_config:
|
||||
if req.parser_config == doc.parser_config:
|
||||
return get_json_result(data=True)
|
||||
else:
|
||||
return get_json_result(data=True)
|
||||
|
||||
if (doc.type == FileType.VISUAL and req.parser_id != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req.parser_id != "presentation"):
|
||||
return get_data_error_result(message="Not supported yet!")
|
||||
if req.parser_config:
|
||||
DocumentService.update_parser_config(doc.id, req.parser_config)
|
||||
reset_doc()
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get("/image/{image_id}")
|
||||
async def get_image(image_id: str):
|
||||
try:
|
||||
arr = image_id.split("-")
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
content = STORAGE_IMPL.get(bkt, nm)
|
||||
return StreamingResponse(
|
||||
iter([content]),
|
||||
media_type="image/JPEG"
|
||||
)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/upload_and_parse")
|
||||
async def upload_and_parse(
|
||||
conversation_id: str = Form(...),
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not files:
|
||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
# Use UploadFile directly
|
||||
file_objs = files
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == "":
|
||||
return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = await doc_upload_and_parse(conversation_id, file_objs, current_user.id)
|
||||
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@router.post("/parse")
|
||||
async def parse(
|
||||
req: ParseRequest = None,
|
||||
files: List[UploadFile] = File(None),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
url = req.url if req else ""
|
||||
if url:
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
download_path = os.path.join(get_project_base_directory(), "logs/downloads")
|
||||
os.makedirs(download_path, exist_ok=True)
|
||||
from seleniumwire.webdriver import Chrome, ChromeOptions
|
||||
|
||||
options = ChromeOptions()
|
||||
options.add_argument("--headless")
|
||||
options.add_argument("--disable-gpu")
|
||||
options.add_argument("--no-sandbox")
|
||||
options.add_argument("--disable-dev-shm-usage")
|
||||
options.add_experimental_option("prefs", {"download.default_directory": download_path, "download.prompt_for_download": False, "download.directory_upgrade": True, "safebrowsing.enabled": True})
|
||||
driver = Chrome(options=options)
|
||||
driver.get(url)
|
||||
res_headers = [r.response.headers for r in driver.requests if r and r.response]
|
||||
if len(res_headers) > 1:
|
||||
sections = RAGFlowHtmlParser().parser_txt(driver.page_source)
|
||||
driver.quit()
|
||||
return get_json_result(data="\n".join(sections))
|
||||
|
||||
class File:
|
||||
filename: str
|
||||
filepath: str
|
||||
|
||||
def __init__(self, filename, filepath):
|
||||
self.filename = filename
|
||||
self.filepath = filepath
|
||||
|
||||
def read(self):
|
||||
with open(self.filepath, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
r = re.search(r"filename=\"([^\"]+)\"", str(res_headers))
|
||||
if not r or not r.group(1):
|
||||
return get_json_result(data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
f = File(r.group(1), os.path.join(download_path, r.group(1)))
|
||||
txt = await FileService.parse_docs([f], current_user.id)
|
||||
return get_json_result(data=txt)
|
||||
|
||||
if not files:
|
||||
return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
# Use UploadFile directly
|
||||
file_objs = files
|
||||
txt = await FileService.parse_docs(file_objs, current_user.id)
|
||||
|
||||
return get_json_result(data=txt)
|
||||
|
||||
|
||||
@router.post("/set_meta")
|
||||
async def set_meta(
|
||||
req: SetMetaRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not DocumentService.accessible(req.doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
meta = json.loads(req.meta)
|
||||
if not isinstance(meta, dict):
|
||||
return get_json_result(data=False, message="Only dictionary type supported.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
for k, v in meta.items():
|
||||
if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
|
||||
return get_json_result(data=False, message=f"The type is not supported: {v}", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
except Exception as e:
|
||||
return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
|
||||
if not isinstance(meta, dict):
|
||||
return get_json_result(data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req.doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if not DocumentService.update_by_id(req.doc_id, {"meta_fields": meta}):
|
||||
return get_data_error_result(message="Database error (meta updates)!")
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
212
api/apps/file2document_app.py
Normal file
212
api/apps/file2document_app.py
Normal file
@@ -0,0 +1,212 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.db import FileType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api import settings
|
||||
from api.utils.api_utils import get_json_result
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
|
||||
# Pydantic models for request/response
|
||||
class ConvertRequest(BaseModel):
|
||||
file_ids: List[str]
|
||||
kb_ids: List[str]
|
||||
|
||||
class RemoveFile2DocumentRequest(BaseModel):
|
||||
file_ids: List[str]
|
||||
|
||||
# Dependency injection
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.user_service import UserService
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/convert')
|
||||
async def convert(
|
||||
req: ConvertRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_ids = req.kb_ids
|
||||
file_ids = req.file_ids
|
||||
file2documents = []
|
||||
|
||||
try:
|
||||
files = FileService.get_by_ids(file_ids)
|
||||
files_set = dict({file.id: file for file in files})
|
||||
for file_id in file_ids:
|
||||
file = files_set[file_id]
|
||||
if not file:
|
||||
return get_data_error_result(message="File not found!")
|
||||
file_ids_list = [file_id]
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for id in file_ids_list:
|
||||
informs = File2DocumentService.get_by_file_id(id)
|
||||
# delete
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
File2DocumentService.delete_by_file_id(id)
|
||||
|
||||
# insert
|
||||
for kb_id in kb_ids:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this file!")
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": file.type,
|
||||
"name": file.name,
|
||||
"suffix": Path(file.name).suffix.lstrip("."),
|
||||
"location": file.location,
|
||||
"size": file.size
|
||||
})
|
||||
file2document = File2DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"file_id": id,
|
||||
"document_id": doc.id,
|
||||
})
|
||||
|
||||
file2documents.append(file2document.to_json())
|
||||
return get_json_result(data=file2documents)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
req: RemoveFile2DocumentRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
file_ids = req.file_ids
|
||||
if not file_ids:
|
||||
return get_json_result(
|
||||
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
informs = File2DocumentService.get_by_file_id(file_id)
|
||||
if not informs:
|
||||
return get_data_error_result(message="Inform not found!")
|
||||
for inform in informs:
|
||||
if not inform:
|
||||
return get_data_error_result(message="Inform not found!")
|
||||
File2DocumentService.delete_by_file_id(file_id)
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
481
api/apps/file_app.py
Normal file
481
api/apps/file_app.py
Normal file
@@ -0,0 +1,481 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.db import FileType, FileSource
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api import settings
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
|
||||
# Pydantic models for request/response
|
||||
class CreateFileRequest(BaseModel):
|
||||
name: str
|
||||
parent_id: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
|
||||
class RemoveFileRequest(BaseModel):
|
||||
file_ids: List[str]
|
||||
|
||||
class RenameFileRequest(BaseModel):
|
||||
file_id: str
|
||||
name: str
|
||||
|
||||
class MoveFileRequest(BaseModel):
|
||||
src_file_ids: List[str]
|
||||
dest_file_id: str
|
||||
|
||||
# Dependency injection
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.user_service import UserService
|
||||
from fastapi import HTTPException, status
|
||||
import logging
|
||||
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/upload')
|
||||
async def upload(
|
||||
parent_id: Optional[str] = Form(None),
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
pf_id = parent_id
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if not files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = files
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
file_res = []
|
||||
try:
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_data_error_result( message="Can't find this folder!")
|
||||
for file_obj in file_objs:
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
|
||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||
|
||||
# split file name path
|
||||
if not file_obj.filename:
|
||||
file_obj_names = [pf_folder.name, file_obj.filename]
|
||||
else:
|
||||
full_path = '/' + file_obj.filename
|
||||
file_obj_names = full_path.split('/')
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# get folder
|
||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# create folder
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
# file type
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
while STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||
location += "_"
|
||||
blob = await file_obj.read()
|
||||
filename = duplicate_name(
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
parent_id=last_folder.id)
|
||||
STORAGE_IMPL.put(last_folder.id, location, blob)
|
||||
file = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": last_folder.id,
|
||||
"tenant_id": current_user.id,
|
||||
"created_by": current_user.id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
}
|
||||
file = FileService.insert(file)
|
||||
file_res.append(file.to_json())
|
||||
return get_json_result(data=file_res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/create')
|
||||
async def create(
|
||||
req: CreateFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
pf_id = req.parent_id
|
||||
input_file_type = req.type
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
try:
|
||||
if not FileService.is_parent_folder_exist(pf_id):
|
||||
return get_json_result(
|
||||
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
|
||||
if FileService.query(name=req.name, parent_id=pf_id):
|
||||
return get_data_error_result(
|
||||
message="Duplicated folder name in the same folder.")
|
||||
|
||||
if input_file_type == FileType.FOLDER.value:
|
||||
file_type = FileType.FOLDER.value
|
||||
else:
|
||||
file_type = FileType.VIRTUAL.value
|
||||
|
||||
file = FileService.insert({
|
||||
"id": get_uuid(),
|
||||
"parent_id": pf_id,
|
||||
"tenant_id": current_user.id,
|
||||
"created_by": current_user.id,
|
||||
"name": req.name,
|
||||
"location": "",
|
||||
"size": 0,
|
||||
"type": file_type
|
||||
})
|
||||
|
||||
return get_json_result(data=file.to_json())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/list')
|
||||
async def list_files(
|
||||
parent_id: Optional[str] = Query(None),
|
||||
keywords: str = Query(""),
|
||||
page: int = Query(1),
|
||||
page_size: int = Query(15),
|
||||
orderby: str = Query("create_time"),
|
||||
desc: bool = Query(True),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
pf_id = parent_id
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, current_user.id)
|
||||
try:
|
||||
e, file = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
|
||||
files, total = FileService.get_by_pf_id(
|
||||
current_user.id, pf_id, page, page_size, orderby, desc, keywords)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(pf_id)
|
||||
if not parent_folder:
|
||||
return get_json_result(message="File not found!")
|
||||
|
||||
return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/root_folder')
|
||||
async def get_root_folder(current_user = Depends(get_current_user)):
|
||||
try:
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
return get_json_result(data={"root_folder": root_folder})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/parent_folder')
|
||||
async def get_parent_folder(
|
||||
file_id: str = Query(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
|
||||
parent_folder = FileService.get_parent_folder(file_id)
|
||||
return get_json_result(data={"parent_folder": parent_folder.to_json()})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/all_parent_folder')
|
||||
async def get_all_parent_folders(
|
||||
file_id: str = Query(...),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
|
||||
parent_folders = FileService.get_all_parent_folders(file_id)
|
||||
parent_folders_res = []
|
||||
for parent_folder in parent_folders:
|
||||
parent_folders_res.append(parent_folder.to_json())
|
||||
return get_json_result(data={"parent_folders": parent_folders_res})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
req: RemoveFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
file_ids = req.file_ids
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for inner_file_id in file_id_list:
|
||||
e, file = FileService.get_by_id(inner_file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="File not found!")
|
||||
STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
FileService.delete_folder_by_pf_id(current_user.id, file_id)
|
||||
else:
|
||||
STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
if not FileService.delete(file):
|
||||
return get_data_error_result(
|
||||
message="Database error (File removal)!")
|
||||
|
||||
# delete file2document
|
||||
informs = File2DocumentService.get_by_file_id(file_id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
File2DocumentService.delete_by_file_id(file_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/rename')
|
||||
async def rename(
|
||||
req: RenameFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
try:
|
||||
e, file = FileService.get_by_id(req.file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="File not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.type != FileType.FOLDER.value \
|
||||
and pathlib.Path(req.name.lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="The extension of file can't be changed",
|
||||
code=settings.RetCode.ARGUMENT_ERROR)
|
||||
for file in FileService.query(name=req.name, pf_id=file.parent_id):
|
||||
if file.name == req.name:
|
||||
return get_data_error_result(
|
||||
message="Duplicated file name in the same folder.")
|
||||
|
||||
if not FileService.update_by_id(
|
||||
req.file_id, {"name": req.name}):
|
||||
return get_data_error_result(
|
||||
message="Database error (File rename)!")
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(req.file_id)
|
||||
if informs:
|
||||
if not DocumentService.update_by_id(
|
||||
informs[0].document_id, {"name": req.name}):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document rename)!")
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/get/{file_id}')
|
||||
async def get(file_id: str, current_user = Depends(get_current_user)):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = STORAGE_IMPL.get(b, n)
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
if file.type == FileType.VISUAL.value:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
else:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
return StreamingResponse(
|
||||
iter([blob]),
|
||||
media_type=content_type,
|
||||
headers={"Content-Disposition": f"attachment; filename={file.name}"}
|
||||
)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/mv')
|
||||
async def move(
|
||||
req: MoveFileRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
try:
|
||||
file_ids = req.src_file_ids
|
||||
parent_id = req.dest_file_id
|
||||
files = FileService.get_by_ids(file_ids)
|
||||
files_dict = {}
|
||||
for file in files:
|
||||
files_dict[file.id] = file
|
||||
|
||||
for file_id in file_ids:
|
||||
file = files_dict[file_id]
|
||||
if not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_data_error_result(message="Parent Folder not found!")
|
||||
FileService.move_file(file_ids, parent_id)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
831
api/apps/kb_app.py
Normal file
831
api/apps/kb_app.py
Normal file
@@ -0,0 +1,831 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from api.models.kb_models import (
|
||||
CreateKnowledgeBaseRequest,
|
||||
UpdateKnowledgeBaseRequest,
|
||||
DeleteKnowledgeBaseRequest,
|
||||
ListKnowledgeBasesRequest,
|
||||
RemoveTagsRequest,
|
||||
RenameTagRequest,
|
||||
RunGraphRAGRequest,
|
||||
RunRaptorRequest,
|
||||
RunMindmapRequest,
|
||||
ListPipelineLogsRequest,
|
||||
ListPipelineDatasetLogsRequest,
|
||||
DeletePipelineLogsRequest,
|
||||
UnbindTaskRequest
|
||||
)
|
||||
from api.utils.api_utils import get_current_user
|
||||
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, get_json_result
|
||||
from api.utils import get_uuid
|
||||
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
# 创建 FastAPI 路由器
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post('/create')
|
||||
async def create(
|
||||
request: CreateKnowledgeBaseRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
dataset_name = request.name
|
||||
if not isinstance(dataset_name, str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if dataset_name.strip() == "":
|
||||
return get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return get_data_error_result(
|
||||
message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
|
||||
dataset_name = dataset_name.strip()
|
||||
dataset_name = duplicate_name(
|
||||
KnowledgebaseService.query,
|
||||
name=dataset_name,
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value)
|
||||
try:
|
||||
req = {
|
||||
"id": get_uuid(),
|
||||
"name": dataset_name,
|
||||
"tenant_id": current_user.id,
|
||||
"created_by": current_user.id,
|
||||
"parser_id": request.parser_id or "naive",
|
||||
"description": request.description
|
||||
}
|
||||
e, t = TenantService.get_by_id(current_user.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
|
||||
# 设置 embd_id 默认值
|
||||
if not request.embd_id:
|
||||
req["embd_id"] = t.embd_id
|
||||
else:
|
||||
req["embd_id"] = request.embd_id
|
||||
|
||||
if request.parser_config:
|
||||
req["parser_config"] = request.parser_config
|
||||
else:
|
||||
req["parser_config"] = {
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 512,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False,
|
||||
"topn_tags": 3,
|
||||
"raptor": {
|
||||
"use_raptor": True,
|
||||
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
|
||||
"max_token": 256,
|
||||
"threshold": 0.1,
|
||||
"max_cluster": 64,
|
||||
"random_seed": 0
|
||||
},
|
||||
"graphrag": {
|
||||
"use_graphrag": True,
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
"geo",
|
||||
"event",
|
||||
"category"
|
||||
],
|
||||
"method": "light"
|
||||
}
|
||||
}
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": req["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/update')
|
||||
async def update(
|
||||
request: UpdateKnowledgeBaseRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not isinstance(request.name, str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if request.name.strip() == "":
|
||||
return get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(request.name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return get_data_error_result(
|
||||
message=f"Dataset name length is {len(request.name)} which is large than {DATASET_NAME_LIMIT}")
|
||||
name = request.name.strip()
|
||||
|
||||
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
try:
|
||||
if not KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=request.kb_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(request.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
|
||||
if name.lower() != kb.name.lower() \
|
||||
and len(
|
||||
KnowledgebaseService.query(name=name, tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
|
||||
return get_data_error_result(
|
||||
message="Duplicated knowledgebase name.")
|
||||
|
||||
update_data = {
|
||||
"name": name,
|
||||
"pagerank": request.pagerank
|
||||
}
|
||||
if not KnowledgebaseService.update_by_id(kb.id, update_data):
|
||||
return get_data_error_result()
|
||||
|
||||
if kb.pagerank != request.pagerank:
|
||||
if request.pagerank > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: request.pagerank},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase rename)!")
|
||||
kb = kb.to_dict()
|
||||
kb.update(update_data)
|
||||
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/detail')
|
||||
async def detail(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
kb = KnowledgebaseService.get_detail(kb_id)
|
||||
if not kb:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
|
||||
return get_json_result(data=kb)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post('/list')
|
||||
async def list_kbs(
|
||||
request: ListKnowledgeBasesRequest,
|
||||
keywords: str = Query("", description="关键词"),
|
||||
page: int = Query(0, description="页码"),
|
||||
page_size: int = Query(0, description="每页大小"),
|
||||
parser_id: Optional[str] = Query(None, description="解析器ID"),
|
||||
orderby: str = Query("create_time", description="排序字段"),
|
||||
desc: bool = Query(True, description="是否降序"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
page_number = page
|
||||
items_per_page = page_size
|
||||
owner_ids = request.owner_ids
|
||||
try:
|
||||
if not owner_ids:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
tenants = [m["tenant_id"] for m in tenants]
|
||||
kbs, total = KnowledgebaseService.get_by_tenant_ids(
|
||||
tenants, current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, parser_id)
|
||||
else:
|
||||
tenants = owner_ids
|
||||
kbs, total = KnowledgebaseService.get_by_tenant_ids(
|
||||
tenants, current_user.id, 0,
|
||||
0, orderby, desc, keywords, parser_id)
|
||||
kbs = [kb for kb in kbs if kb["tenant_id"] in tenants]
|
||||
total = len(kbs)
|
||||
if page_number and items_per_page:
|
||||
kbs = kbs[(page_number-1)*items_per_page:page_number*items_per_page]
|
||||
return get_json_result(data={"kbs": kbs, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@router.post('/rm')
|
||||
async def rm(
|
||||
request: DeleteKnowledgeBaseRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not KnowledgebaseService.accessible4deletion(request.kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
try:
|
||||
kbs = KnowledgebaseService.query(
|
||||
created_by=current_user.id, id=request.kb_id)
|
||||
if not kbs:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
for doc in DocumentService.query(kb_id=request.kb_id):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
if f2d:
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(request.kb_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(STORAGE_IMPL, 'remove_bucket'):
|
||||
STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.get('/{kb_id}/tags')
|
||||
async def list_tags(
|
||||
kb_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
tags = []
|
||||
for tenant in tenants:
|
||||
tags += settings.retrievaler.all_tags(tenant["tenant_id"], [kb_id])
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@router.get('/tags')
|
||||
async def list_tags_from_kbs(
|
||||
kb_ids: str = Query(..., description="知识库ID列表,用逗号分隔"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_ids = kb_ids.split(",")
|
||||
for kb_id in kb_ids:
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
tags = []
|
||||
for tenant in tenants:
|
||||
tags += settings.retrievaler.all_tags(tenant["tenant_id"], kb_ids)
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@router.post('/{kb_id}/rm_tags')
|
||||
async def rm_tags(
|
||||
kb_id: str,
|
||||
request: RemoveTagsRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
|
||||
for t in request.tags:
|
||||
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
|
||||
{"remove": {"tag_kwd": t}},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@router.post('/{kb_id}/rename_tag')
|
||||
async def rename_tags(
|
||||
kb_id: str,
|
||||
request: RenameTagRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
|
||||
settings.docStoreConn.update({"tag_kwd": request.from_tag, "kb_id": [kb_id]},
|
||||
{"remove": {"tag_kwd": request.from_tag.strip()}, "add": {"tag_kwd": request.to_tag}},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@router.get('/{kb_id}/knowledge_graph')
|
||||
async def knowledge_graph(
|
||||
kb_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
req = {
|
||||
"kb_id": [kb_id],
|
||||
"knowledge_graph_kwd": ["graph"]
|
||||
}
|
||||
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
|
||||
return get_json_result(data=obj)
|
||||
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
if not len(sres.ids):
|
||||
return get_json_result(data=obj)
|
||||
|
||||
for id in sres.ids[:1]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
try:
|
||||
content_json = json.loads(sres.field[id]["content_with_weight"])
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
obj[ty] = content_json
|
||||
|
||||
if "nodes" in obj["graph"]:
|
||||
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
|
||||
if "edges" in obj["graph"]:
|
||||
node_id_set = { o["id"] for o in obj["graph"]["nodes"] }
|
||||
filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
|
||||
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
|
||||
return get_json_result(data=obj)
|
||||
|
||||
|
||||
@router.delete('/{kb_id}/knowledge_graph')
|
||||
async def delete_knowledge_graph(
|
||||
kb_id: str,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@router.get("/get_meta")
|
||||
async def get_meta(
|
||||
kb_ids: str = Query(..., description="知识库ID列表,用逗号分隔"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_ids = kb_ids.split(",")
|
||||
for kb_id in kb_ids:
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
|
||||
|
||||
|
||||
@router.get("/basic_info")
|
||||
async def get_basic_info(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
basic_info = DocumentService.knowledgebase_basic_info(kb_id)
|
||||
|
||||
return get_json_result(data=basic_info)
|
||||
|
||||
|
||||
@router.post("/list_pipeline_logs")
|
||||
async def list_pipeline_logs(
|
||||
request: ListPipelineLogsRequest,
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
keywords: str = Query("", description="关键词"),
|
||||
page: int = Query(0, description="页码"),
|
||||
page_size: int = Query(0, description="每页大小"),
|
||||
orderby: str = Query("create_time", description="排序字段"),
|
||||
desc: bool = Query(True, description="是否降序"),
|
||||
create_date_from: str = Query("", description="创建日期开始"),
|
||||
create_date_to: str = Query("", description="创建日期结束"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = page
|
||||
items_per_page = page_size
|
||||
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
operation_status = request.operation_status
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = request.types
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = request.suffix
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/list_pipeline_dataset_logs")
|
||||
async def list_pipeline_dataset_logs(
|
||||
request: ListPipelineDatasetLogsRequest,
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
page: int = Query(0, description="页码"),
|
||||
page_size: int = Query(0, description="每页大小"),
|
||||
orderby: str = Query("create_time", description="排序字段"),
|
||||
desc: bool = Query(True, description="是否降序"),
|
||||
create_date_from: str = Query("", description="创建日期开始"),
|
||||
create_date_to: str = Query("", description="创建日期结束"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
page_number = page
|
||||
items_per_page = page_size
|
||||
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
operation_status = request.operation_status
|
||||
if operation_status:
|
||||
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@router.post("/delete_pipeline_logs")
|
||||
async def delete_pipeline_logs(
|
||||
request: DeletePipelineLogsRequest,
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
log_ids = request.log_ids
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@router.get("/pipeline_log_detail")
|
||||
async def pipeline_log_detail(
|
||||
log_id: str = Query(..., description="日志ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not log_id:
|
||||
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
|
||||
ok, log = PipelineOperationLogService.get_by_id(log_id)
|
||||
if not ok:
|
||||
return get_data_error_result(message="Invalid pipeline log ID")
|
||||
|
||||
return get_json_result(data=log.to_dict())
|
||||
|
||||
|
||||
@router.post("/run_graphrag")
|
||||
async def run_graphrag(
|
||||
request: RunGraphRAGRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = request.kb_id
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@router.get("/trace_graphrag")
|
||||
async def trace_graphrag(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@router.post("/run_raptor")
|
||||
async def run_raptor(
|
||||
request: RunRaptorRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = request.kb_id
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"raptor_task_id": task_id})
|
||||
|
||||
|
||||
@router.get("/trace_raptor")
|
||||
async def trace_raptor(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@router.post("/run_mindmap")
|
||||
async def run_mindmap(
|
||||
request: RunMindmapRequest,
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
kb_id = request.kb_id
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
|
||||
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"mindmap_task_id": task_id})
|
||||
|
||||
|
||||
@router.get("/trace_mindmap")
|
||||
async def trace_mindmap(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Mindmap Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@router.delete("/unbind_task")
|
||||
async def delete_kb_task(
|
||||
kb_id: str = Query(..., description="知识库ID"),
|
||||
pipeline_task_type: str = Query(..., description="管道任务类型"),
|
||||
current_user = Depends(get_current_user)
|
||||
):
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_json_result(data=True)
|
||||
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
return get_error_data_result(message="Invalid task type")
|
||||
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
kb_task_id = "graphrag_task_id"
|
||||
kb_task_finish_at = "graphrag_task_finish_at"
|
||||
case PipelineTaskType.RAPTOR:
|
||||
kb_task_id = "raptor_task_id"
|
||||
kb_task_finish_at = "raptor_task_finish_at"
|
||||
case PipelineTaskType.MINDMAP:
|
||||
kb_task_id = "mindmap_task_id"
|
||||
kb_task_finish_at = "mindmap_task_finish_at"
|
||||
case _:
|
||||
return get_error_data_result(message="Internal Error: Invalid task type")
|
||||
|
||||
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id: "", kb_task_finish_at: None})
|
||||
if not ok:
|
||||
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
|
||||
|
||||
return get_json_result(data=True)
|
||||
97
api/apps/langfuse_app.py
Normal file
97
api/apps/langfuse_app.py
Normal file
@@ -0,0 +1,97 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
from langfuse import Langfuse
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("secret_key", "public_key", "host")
|
||||
def set_api_key():
|
||||
req = request.get_json()
|
||||
secret_key = req.get("secret_key", "")
|
||||
public_key = req.get("public_key", "")
|
||||
host = req.get("host", "")
|
||||
if not all([secret_key, public_key, host]):
|
||||
return get_error_data_result(message="Missing required fields")
|
||||
|
||||
langfuse_keys = dict(
|
||||
tenant_id=current_user.id,
|
||||
secret_key=secret_key,
|
||||
public_key=public_key,
|
||||
host=host,
|
||||
)
|
||||
|
||||
langfuse = Langfuse(public_key=langfuse_keys["public_key"], secret_key=langfuse_keys["secret_key"], host=langfuse_keys["host"])
|
||||
if not langfuse.auth_check():
|
||||
return get_error_data_result(message="Invalid Langfuse keys")
|
||||
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
||||
with DB.atomic():
|
||||
try:
|
||||
if not langfuse_entry:
|
||||
TenantLangfuseService.save(**langfuse_keys)
|
||||
else:
|
||||
TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
|
||||
return get_json_result(data=langfuse_keys)
|
||||
except Exception as e:
|
||||
server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/api_key", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request()
|
||||
def get_api_key():
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
|
||||
if not langfuse_entry:
|
||||
return get_json_result(message="Have not record any Langfuse keys.")
|
||||
|
||||
langfuse = Langfuse(public_key=langfuse_entry["public_key"], secret_key=langfuse_entry["secret_key"], host=langfuse_entry["host"])
|
||||
try:
|
||||
if not langfuse.auth_check():
|
||||
return get_error_data_result(message="Invalid Langfuse keys loaded")
|
||||
except langfuse.api.core.api_error.ApiError as api_err:
|
||||
return get_json_result(message=f"Error from Langfuse: {api_err}")
|
||||
except Exception as e:
|
||||
server_error_response(e)
|
||||
|
||||
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
|
||||
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
|
||||
|
||||
return get_json_result(data=langfuse_entry)
|
||||
|
||||
|
||||
@manager.route("/api_key", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request()
|
||||
def delete_api_key():
|
||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
||||
if not langfuse_entry:
|
||||
return get_json_result(message="Have not record any Langfuse keys.")
|
||||
|
||||
with DB.atomic():
|
||||
try:
|
||||
TenantLangfuseService.delete_model(langfuse_entry)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
server_error_response(e)
|
||||
396
api/apps/llm_app.py
Normal file
396
api/apps/llm_app.py
Normal file
@@ -0,0 +1,396 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import json
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api import settings
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
|
||||
|
||||
@manager.route('/factories', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def factories():
|
||||
try:
|
||||
fac = LLMFactoriesService.get_all()
|
||||
fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]
|
||||
llms = LLMService.get_all()
|
||||
mdl_types = {}
|
||||
for m in llms:
|
||||
if m.status != StatusEnum.VALID.value:
|
||||
continue
|
||||
if m.fid not in mdl_types:
|
||||
mdl_types[m.fid] = set([])
|
||||
mdl_types[m.fid].add(m.model_type)
|
||||
for f in fac:
|
||||
f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK,
|
||||
LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS]))
|
||||
return get_json_result(data=fac)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/set_api_key', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
extra = {"provider": factory}
|
||||
msg = ""
|
||||
for llm in LLMService.query(fid=factory):
|
||||
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||
mdl = EmbeddingModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0:
|
||||
raise Exception("Fail")
|
||||
embd_passed = True
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
|
||||
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||
mdl = ChatModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
|
||||
{"temperature": 0.9, 'max_tokens': 50})
|
||||
if m.find("**ERROR**") >= 0:
|
||||
raise Exception(m)
|
||||
chat_passed = True
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
elif not rerank_passed and llm.model_type == LLMType.RERANK:
|
||||
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
|
||||
mdl = RerankModel[factory](
|
||||
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||
try:
|
||||
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
||||
if len(arr) == 0 or tc == 0:
|
||||
raise Exception("Fail")
|
||||
rerank_passed = True
|
||||
logging.debug(f'passed model rerank {llm.llm_name}')
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(
|
||||
e)
|
||||
if any([embd_passed, chat_passed, rerank_passed]):
|
||||
msg = ''
|
||||
break
|
||||
|
||||
if msg:
|
||||
return get_data_error_result(message=msg)
|
||||
|
||||
llm_config = {
|
||||
"api_key": req["api_key"],
|
||||
"api_base": req.get("base_url", "")
|
||||
}
|
||||
for n in ["model_type", "llm_name"]:
|
||||
if n in req:
|
||||
llm_config[n] = req[n]
|
||||
|
||||
for llm in LLMService.query(fid=factory):
|
||||
llm_config["max_tokens"]=llm.max_tokens
|
||||
if not TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id,
|
||||
TenantLLM.llm_factory == factory,
|
||||
TenantLLM.llm_name == llm.llm_name],
|
||||
llm_config):
|
||||
TenantLLMService.save(
|
||||
tenant_id=current_user.id,
|
||||
llm_factory=factory,
|
||||
llm_name=llm.llm_name,
|
||||
model_type=llm.model_type,
|
||||
api_key=llm_config["api_key"],
|
||||
api_base=llm_config["api_base"],
|
||||
max_tokens=llm_config["max_tokens"]
|
||||
)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/add_llm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def add_llm():
|
||||
req = request.json
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
|
||||
def apikey_json(keys):
|
||||
nonlocal req
|
||||
return json.dumps({k: req.get(k, "") for k in keys})
|
||||
|
||||
if factory == "VolcEngine":
|
||||
# For VolcEngine, due to its special authentication method
|
||||
# Assemble ark_api_key endpoint_id into api_key
|
||||
api_key = apikey_json(["ark_api_key", "endpoint_id"])
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
|
||||
return set_api_key()
|
||||
|
||||
elif factory == "Tencent Cloud":
|
||||
req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
|
||||
return set_api_key()
|
||||
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
||||
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
||||
|
||||
elif factory == "LocalAI":
|
||||
llm_name += "___LocalAI"
|
||||
|
||||
elif factory == "HuggingFace":
|
||||
llm_name += "___HuggingFace"
|
||||
|
||||
elif factory == "OpenAI-API-Compatible":
|
||||
llm_name += "___OpenAI-API"
|
||||
|
||||
elif factory == "VLLM":
|
||||
llm_name += "___VLLM"
|
||||
|
||||
elif factory == "XunFei Spark":
|
||||
if req["model_type"] == "chat":
|
||||
api_key = req.get("spark_api_password", "")
|
||||
elif req["model_type"] == "tts":
|
||||
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
|
||||
|
||||
elif factory == "BaiduYiyan":
|
||||
api_key = apikey_json(["yiyan_ak", "yiyan_sk"])
|
||||
|
||||
elif factory == "Fish Audio":
|
||||
api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"])
|
||||
|
||||
elif factory == "Google Cloud":
|
||||
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
|
||||
|
||||
elif factory == "Azure-OpenAI":
|
||||
api_key = apikey_json(["api_key", "api_version"])
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": factory,
|
||||
"model_type": req["model_type"],
|
||||
"llm_name": llm_name,
|
||||
"api_base": req.get("api_base", ""),
|
||||
"api_key": api_key,
|
||||
"max_tokens": req.get("max_tokens")
|
||||
}
|
||||
|
||||
msg = ""
|
||||
mdl_nm = llm["llm_name"].split("___")[0]
|
||||
extra = {"provider": factory}
|
||||
if llm["model_type"] == LLMType.EMBEDDING.value:
|
||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||
mdl = EmbeddingModel[factory](
|
||||
key=llm['api_key'],
|
||||
model_name=mdl_nm,
|
||||
base_url=llm["api_base"])
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0:
|
||||
raise Exception("Fail")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.CHAT.value:
|
||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||
mdl = ChatModel[factory](
|
||||
key=llm['api_key'],
|
||||
model_name=mdl_nm,
|
||||
base_url=llm["api_base"],
|
||||
**extra,
|
||||
)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
|
||||
"temperature": 0.9})
|
||||
if not tc and m.find("**ERROR**:") >= 0:
|
||||
raise Exception(m)
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(
|
||||
e)
|
||||
elif llm["model_type"] == LLMType.RERANK:
|
||||
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
|
||||
try:
|
||||
mdl = RerankModel[factory](
|
||||
key=llm["api_key"],
|
||||
model_name=mdl_nm,
|
||||
base_url=llm["api_base"]
|
||||
)
|
||||
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
|
||||
if len(arr) == 0:
|
||||
raise Exception("Not known.")
|
||||
except KeyError:
|
||||
msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(
|
||||
e)
|
||||
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
|
||||
assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
|
||||
mdl = CvModel[factory](
|
||||
key=llm["api_key"],
|
||||
model_name=mdl_nm,
|
||||
base_url=llm["api_base"]
|
||||
)
|
||||
try:
|
||||
image_data = test_image
|
||||
m, tc = mdl.describe(image_data)
|
||||
if not m and not tc:
|
||||
raise Exception(m)
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
elif llm["model_type"] == LLMType.TTS:
|
||||
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
|
||||
mdl = TTSModel[factory](
|
||||
key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]
|
||||
)
|
||||
try:
|
||||
for resp in mdl.tts("Hello~ RAGFlower!"):
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||
else:
|
||||
# TODO: check other type of models
|
||||
pass
|
||||
|
||||
if msg:
|
||||
return get_data_error_result(message=msg)
|
||||
|
||||
if not TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory,
|
||||
TenantLLM.llm_name == llm["llm_name"]], llm):
|
||||
TenantLLMService.save(**llm)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/delete_llm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def delete_llm():
|
||||
req = request.json
|
||||
TenantLLMService.filter_delete(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
|
||||
TenantLLM.llm_name == req["llm_name"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/delete_factory', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
def delete_factory():
|
||||
req = request.json
|
||||
TenantLLMService.filter_delete(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/my_llms', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def my_llms():
|
||||
try:
|
||||
include_details = request.args.get('include_details', 'false').lower() == 'true'
|
||||
|
||||
if include_details:
|
||||
res = {}
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
factories = LLMFactoriesService.query(status=StatusEnum.VALID.value)
|
||||
|
||||
for o in objs:
|
||||
o_dict = o.to_dict()
|
||||
factory_tags = None
|
||||
for f in factories:
|
||||
if f.name == o_dict["llm_factory"]:
|
||||
factory_tags = f.tags
|
||||
break
|
||||
|
||||
if o_dict["llm_factory"] not in res:
|
||||
res[o_dict["llm_factory"]] = {
|
||||
"tags": factory_tags,
|
||||
"llm": []
|
||||
}
|
||||
|
||||
res[o_dict["llm_factory"]]["llm"].append({
|
||||
"type": o_dict["model_type"],
|
||||
"name": o_dict["llm_name"],
|
||||
"used_token": o_dict["used_tokens"],
|
||||
"api_base": o_dict["api_base"] or "",
|
||||
"max_tokens": o_dict["max_tokens"] or 8192
|
||||
})
|
||||
else:
|
||||
res = {}
|
||||
for o in TenantLLMService.get_my_llms(current_user.id):
|
||||
if o["llm_factory"] not in res:
|
||||
res[o["llm_factory"]] = {
|
||||
"tags": o["tags"],
|
||||
"llm": []
|
||||
}
|
||||
res[o["llm_factory"]]["llm"].append({
|
||||
"type": o["model_type"],
|
||||
"name": o["llm_name"],
|
||||
"used_token": o["used_tokens"]
|
||||
})
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_app():
|
||||
self_deployed = ["Youdao", "FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio", "GPUStack"]
|
||||
weighted = ["Youdao", "FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
|
||||
model_type = request.args.get("model_type")
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict()
|
||||
for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
|
||||
|
||||
llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
|
||||
for o in objs:
|
||||
if o.llm_name + "@" + o.llm_factory in llm_set:
|
||||
continue
|
||||
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
|
||||
|
||||
res = {}
|
||||
for m in llms:
|
||||
if model_type and m["model_type"].find(model_type) < 0:
|
||||
continue
|
||||
if m["fid"] not in res:
|
||||
res[m["fid"]] = []
|
||||
res[m["fid"]].append(m)
|
||||
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
444
api/apps/mcp_server_app.py
Normal file
444
api/apps/mcp_server_app.py
Normal file
@@ -0,0 +1,444 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from api.db import VALID_MCP_SERVER_TYPES
|
||||
from api.db.db_models import MCPServer
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import RetCode
|
||||
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_mcp_tools
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_mcp() -> Response:
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
try:
|
||||
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
|
||||
total = len(servers)
|
||||
|
||||
if page_number and items_per_page:
|
||||
servers = servers[(page_number - 1) * items_per_page : page_number * items_per_page]
|
||||
|
||||
return get_json_result(data={"mcp_servers": servers, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/detail", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def detail() -> Response:
|
||||
mcp_id = request.args["mcp_id"]
|
||||
try:
|
||||
mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id)
|
||||
|
||||
if mcp_server is None:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, data=None)
|
||||
|
||||
return get_json_result(data=mcp_server.to_dict())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "url", "server_type")
|
||||
def create() -> Response:
|
||||
req = request.get_json()
|
||||
|
||||
server_type = req.get("server_type", "")
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
return get_data_error_result(message="Unsupported MCP server type.")
|
||||
|
||||
server_name = req.get("name", "")
|
||||
if not server_name or len(server_name.encode("utf-8")) > 255:
|
||||
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
|
||||
|
||||
e, _ = MCPServerService.get_by_name_and_tenant(name=server_name, tenant_id=current_user.id)
|
||||
if e:
|
||||
return get_data_error_result(message="Duplicated MCP server name.")
|
||||
|
||||
url = req.get("url", "")
|
||||
if not url:
|
||||
return get_data_error_result(message="Invalid url.")
|
||||
|
||||
headers = safe_json_parse(req.get("headers", {}))
|
||||
req["headers"] = headers
|
||||
variables = safe_json_parse(req.get("variables", {}))
|
||||
variables.pop("tools", None)
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
|
||||
try:
|
||||
req["id"] = get_uuid()
|
||||
req["tenant_id"] = current_user.id
|
||||
|
||||
e, _ = TenantService.get_by_id(current_user.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
tools = server_tools[server_name]
|
||||
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
|
||||
variables["tools"] = tools
|
||||
req["variables"] = variables
|
||||
|
||||
if not MCPServerService.insert(**req):
|
||||
return get_data_error_result("Failed to create MCP server.")
|
||||
|
||||
return get_json_result(data=req)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/update", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id")
|
||||
def update() -> Response:
|
||||
req = request.get_json()
|
||||
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
if not e or mcp_server.tenant_id != current_user.id:
|
||||
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||
|
||||
server_type = req.get("server_type", mcp_server.server_type)
|
||||
if server_type and server_type not in VALID_MCP_SERVER_TYPES:
|
||||
return get_data_error_result(message="Unsupported MCP server type.")
|
||||
server_name = req.get("name", mcp_server.name)
|
||||
if server_name and len(server_name.encode("utf-8")) > 255:
|
||||
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
|
||||
url = req.get("url", mcp_server.url)
|
||||
if not url:
|
||||
return get_data_error_result(message="Invalid url.")
|
||||
|
||||
headers = safe_json_parse(req.get("headers", mcp_server.headers))
|
||||
req["headers"] = headers
|
||||
|
||||
variables = safe_json_parse(req.get("variables", mcp_server.variables))
|
||||
variables.pop("tools", None)
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
|
||||
try:
|
||||
req["tenant_id"] = current_user.id
|
||||
req.pop("mcp_id", None)
|
||||
req["id"] = mcp_id
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
tools = server_tools[server_name]
|
||||
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
|
||||
variables["tools"] = tools
|
||||
req["variables"] = variables
|
||||
|
||||
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req):
|
||||
return get_data_error_result(message="Failed to updated MCP server.")
|
||||
|
||||
e, updated_mcp = MCPServerService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Failed to fetch updated MCP server.")
|
||||
|
||||
return get_json_result(data=updated_mcp.to_dict())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def rm() -> Response:
|
||||
req = request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
try:
|
||||
req["tenant_id"] = current_user.id
|
||||
|
||||
if not MCPServerService.delete_by_ids(mcp_ids):
|
||||
return get_data_error_result(message=f"Failed to delete MCP servers {mcp_ids}")
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/import", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcpServers")
|
||||
def import_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
servers = req.get("mcpServers", {})
|
||||
if not servers:
|
||||
return get_data_error_result(message="No MCP servers provided.")
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
|
||||
results = []
|
||||
try:
|
||||
for server_name, config in servers.items():
|
||||
if not all(key in config for key in {"type", "url"}):
|
||||
results.append({"server": server_name, "success": False, "message": "Missing required fields (type or url)"})
|
||||
continue
|
||||
|
||||
if not server_name or len(server_name.encode("utf-8")) > 255:
|
||||
results.append({"server": server_name, "success": False, "message": f"Invalid MCP name or length is {len(server_name)} which is large than 255."})
|
||||
continue
|
||||
|
||||
base_name = server_name
|
||||
new_name = base_name
|
||||
counter = 0
|
||||
|
||||
while True:
|
||||
e, _ = MCPServerService.get_by_name_and_tenant(name=new_name, tenant_id=current_user.id)
|
||||
if not e:
|
||||
break
|
||||
new_name = f"{base_name}_{counter}"
|
||||
counter += 1
|
||||
|
||||
create_data = {
|
||||
"id": get_uuid(),
|
||||
"tenant_id": current_user.id,
|
||||
"name": new_name,
|
||||
"url": config["url"],
|
||||
"server_type": config["type"],
|
||||
"variables": {"authorization_token": config.get("authorization_token", "")},
|
||||
}
|
||||
|
||||
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
|
||||
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
|
||||
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
|
||||
server_tools, err_message = get_mcp_tools([mcp_server], timeout)
|
||||
if err_message:
|
||||
results.append({"server": base_name, "success": False, "message": err_message})
|
||||
continue
|
||||
|
||||
tools = server_tools[new_name]
|
||||
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
|
||||
create_data["variables"]["tools"] = tools
|
||||
|
||||
if MCPServerService.insert(**create_data):
|
||||
result = {"server": server_name, "success": True, "action": "created", "id": create_data["id"], "new_name": new_name}
|
||||
if new_name != base_name:
|
||||
result["message"] = f"Renamed from '{base_name}' to '{new_name}' avoid duplication"
|
||||
results.append(result)
|
||||
else:
|
||||
results.append({"server": server_name, "success": False, "message": "Failed to create MCP server."})
|
||||
|
||||
return get_json_result(data={"results": results})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/export", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def export_multiple() -> Response:
|
||||
req = request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
|
||||
try:
|
||||
exported_servers = {}
|
||||
|
||||
for mcp_id in mcp_ids:
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
|
||||
if e and mcp_server.tenant_id == current_user.id:
|
||||
server_key = mcp_server.name
|
||||
|
||||
exported_servers[server_key] = {
|
||||
"type": mcp_server.server_type,
|
||||
"url": mcp_server.url,
|
||||
"name": mcp_server.name,
|
||||
"authorization_token": mcp_server.variables.get("authorization_token", ""),
|
||||
"tools": mcp_server.variables.get("tools", {}),
|
||||
}
|
||||
|
||||
return get_json_result(data={"mcpServers": exported_servers})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def list_tools() -> Response:
|
||||
req = request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
|
||||
results = {}
|
||||
tool_call_sessions = []
|
||||
try:
|
||||
for mcp_id in mcp_ids:
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
|
||||
if e and mcp_server.tenant_id == current_user.id:
|
||||
server_key = mcp_server.id
|
||||
|
||||
cached_tools = mcp_server.variables.get("tools", {})
|
||||
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
|
||||
try:
|
||||
tools = tool_call_session.get_tools(timeout)
|
||||
except Exception as e:
|
||||
tools = []
|
||||
return get_data_error_result(message=f"MCP list tools error: {e}")
|
||||
|
||||
results[server_key] = []
|
||||
for tool in tools:
|
||||
tool_dict = tool.model_dump()
|
||||
cached_tool = cached_tools.get(tool_dict["name"], {})
|
||||
|
||||
tool_dict["enabled"] = cached_tool.get("enabled", True)
|
||||
results[server_key].append(tool_dict)
|
||||
|
||||
return get_json_result(data=results)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||
|
||||
|
||||
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tool_name", "arguments")
|
||||
def test_tool() -> Response:
|
||||
req = request.get_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
|
||||
tool_name = req.get("tool_name", "")
|
||||
arguments = req.get("arguments", {})
|
||||
if not all([tool_name, arguments]):
|
||||
return get_data_error_result(message="Require provide tool name and arguments.")
|
||||
|
||||
tool_call_sessions = []
|
||||
try:
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
if not e or mcp_server.tenant_id != current_user.id:
|
||||
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
result = tool_call_session.tool_call(tool_name, arguments, timeout)
|
||||
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tools")
|
||||
def cache_tool() -> Response:
|
||||
req = request.get_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
tools = req.get("tools", [])
|
||||
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
if not e or mcp_server.tenant_id != current_user.id:
|
||||
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||
|
||||
variables = mcp_server.variables
|
||||
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
|
||||
variables["tools"] = tools
|
||||
|
||||
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], {"variables": variables}):
|
||||
return get_data_error_result(message="Failed to updated MCP server.")
|
||||
|
||||
return get_json_result(data=tools)
|
||||
|
||||
|
||||
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
|
||||
@validate_request("url", "server_type")
|
||||
def test_mcp() -> Response:
|
||||
req = request.get_json()
|
||||
|
||||
url = req.get("url", "")
|
||||
if not url:
|
||||
return get_data_error_result(message="Invalid MCP url.")
|
||||
|
||||
server_type = req.get("server_type", "")
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
return get_data_error_result(message="Unsupported MCP server type.")
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
headers = safe_json_parse(req.get("headers", {}))
|
||||
variables = safe_json_parse(req.get("variables", {}))
|
||||
|
||||
mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables)
|
||||
|
||||
result = []
|
||||
try:
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
|
||||
try:
|
||||
tools = tool_call_session.get_tools(timeout)
|
||||
except Exception as e:
|
||||
tools = []
|
||||
return get_data_error_result(message=f"Test MCP error: {e}")
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions([tool_call_session])
|
||||
|
||||
for tool in tools:
|
||||
tool_dict = tool.model_dump()
|
||||
tool_dict["enabled"] = True
|
||||
result.append(tool_dict)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
12
api/apps/plugin_app.py
Normal file
12
api/apps/plugin_app.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from flask import Response
|
||||
from flask_login import login_required
|
||||
from api.utils.api_utils import get_json_result
|
||||
from plugin import GlobalPluginManager
|
||||
|
||||
@manager.route('/llm_tools', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def llm_tools() -> Response:
|
||||
tools = GlobalPluginManager.get_llm_tools()
|
||||
tools_metadata = [t.get_metadata() for t in tools]
|
||||
|
||||
return get_json_result(data=tools_metadata)
|
||||
128
api/apps/sdk/agent.py
Normal file
128
api/apps/sdk/agent.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, cast
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.settings import RetCode
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from flask import request
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_agents(tenant_id):
|
||||
id = request.args.get("id")
|
||||
title = request.args.get("title")
|
||||
if id or title:
|
||||
canvas = UserCanvasService.query(id=id, title=title, user_id=tenant_id)
|
||||
if not canvas:
|
||||
return get_error_data_result("The agent doesn't exist.")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
canvas = UserCanvasService.get_list(tenant_id,page_number,items_per_page,orderby,desc,id,title)
|
||||
return get_result(data=canvas)
|
||||
|
||||
|
||||
@manager.route("/agents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], request.json)
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
else:
|
||||
return get_json_result(data=False, message="No DSL data in request.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if req.get("title") is not None:
|
||||
req["title"] = req["title"].strip()
|
||||
else:
|
||||
return get_json_result(data=False, message="No title in request.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if UserCanvasService.query(user_id=tenant_id, title=req["title"]):
|
||||
return get_data_error_result(message=f"Agent with title {req['title']} already exists.")
|
||||
|
||||
agent_id = get_uuid()
|
||||
req["id"] = agent_id
|
||||
|
||||
if not UserCanvasService.save(**req):
|
||||
return get_data_error_result(message="Fail to create agent.")
|
||||
|
||||
UserCanvasVersionService.insert(
|
||||
user_canvas_id=agent_id,
|
||||
title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")),
|
||||
dsl=req["dsl"]
|
||||
)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None}
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
|
||||
if req.get("title") is not None:
|
||||
req["title"] = req["title"].strip()
|
||||
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_json_result(
|
||||
data=False, message="Only owner of canvas authorized for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
UserCanvasService.update_by_id(agent_id, req)
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
UserCanvasVersionService.insert(
|
||||
user_canvas_id=agent_id,
|
||||
title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")),
|
||||
dsl=req["dsl"]
|
||||
)
|
||||
|
||||
UserCanvasVersionService.delete_all_versions(agent_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent(tenant_id: str, agent_id: str):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_json_result(
|
||||
data=False, message="Only owner of canvas authorized for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
UserCanvasService.delete_by_id(agent_id)
|
||||
return get_json_result(data=True)
|
||||
325
api/apps/sdk/chat.py
Normal file
325
api/apps/sdk/chat.py
Normal file
@@ -0,0 +1,325 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
from api import settings
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||
|
||||
|
||||
@manager.route("/chats", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
req = request.json
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
if not kbs:
|
||||
return get_error_data_result(f"You don't own the dataset {kb_id}")
|
||||
kbs = KnowledgebaseService.query(id=kb_id)
|
||||
kb = kbs[0]
|
||||
if kb.chunk_num == 0:
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(ids) if ids else []
|
||||
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||
embd_count = list(set(embd_ids))
|
||||
if len(embd_count) > 1:
|
||||
return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req["kb_ids"] = ids
|
||||
# llm
|
||||
llm = req.get("llm")
|
||||
if llm:
|
||||
if "model_name" in llm:
|
||||
req["llm_id"] = llm.pop("model_name")
|
||||
if req.get("llm_id") is not None:
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
|
||||
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
|
||||
req["llm_setting"] = req.pop("llm")
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
# prompt
|
||||
prompt = req.get("prompt")
|
||||
key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
|
||||
if prompt:
|
||||
for new_key, old_key in key_mapping.items():
|
||||
if old_key in prompt:
|
||||
prompt[new_key] = prompt.pop(old_key)
|
||||
for key in key_list:
|
||||
if key in prompt:
|
||||
req[key] = prompt.pop(key)
|
||||
req["prompt_config"] = req.pop("prompt")
|
||||
# init
|
||||
req["id"] = get_uuid()
|
||||
req["description"] = req.get("description", "A helpful Assistant")
|
||||
req["icon"] = req.get("avatar", "")
|
||||
req["top_n"] = req.get("top_n", 6)
|
||||
req["top_k"] = req.get("top_k", 1024)
|
||||
req["rerank_id"] = req.get("rerank_id", "")
|
||||
if req.get("rerank_id"):
|
||||
value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
|
||||
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
|
||||
return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
|
||||
if not req.get("llm_id"):
|
||||
req["llm_id"] = tenant.llm_id
|
||||
if not req.get("name"):
|
||||
return get_error_data_result(message="`name` is required.")
|
||||
if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="Duplicated chat name in creating chat.")
|
||||
# tenant_id
|
||||
if req.get("tenant_id"):
|
||||
return get_error_data_result(message="`tenant_id` must not be provided.")
|
||||
req["tenant_id"] = tenant_id
|
||||
# prompt more parameter
|
||||
default_prompt = {
|
||||
"system": """You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.
|
||||
Here is the knowledge base:
|
||||
{knowledge}
|
||||
The above is the knowledge base.""",
|
||||
"prologue": "Hi! I'm your assistant. What can I do for you?",
|
||||
"parameters": [{"key": "knowledge", "optional": False}],
|
||||
"empty_response": "Sorry! No relevant content was found in the knowledge base!",
|
||||
"quote": True,
|
||||
"tts": False,
|
||||
"refine_multiturn": True,
|
||||
}
|
||||
key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"]
|
||||
if "prompt_config" not in req:
|
||||
req["prompt_config"] = {}
|
||||
for key in key_list_2:
|
||||
temp = req["prompt_config"].get(key)
|
||||
if (not temp and key == "system") or (key not in req["prompt_config"]):
|
||||
req["prompt_config"][key] = default_prompt[key]
|
||||
for p in req["prompt_config"]["parameters"]:
|
||||
if p["optional"]:
|
||||
continue
|
||||
if req["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_error_data_result(message="Parameter '{}' is not used".format(p["key"]))
|
||||
# save
|
||||
if not DialogService.save(**req):
|
||||
return get_error_data_result(message="Fail to new a chat!")
|
||||
# response
|
||||
e, res = DialogService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_error_data_result(message="Fail to new a chat!")
|
||||
res = res.to_json()
|
||||
renamed_dict = {}
|
||||
for key, value in res["prompt_config"].items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_dict[new_key] = value
|
||||
res["prompt"] = renamed_dict
|
||||
del res["prompt_config"]
|
||||
new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
|
||||
res["prompt"].update(new_dict)
|
||||
for key in key_list:
|
||||
del res[key]
|
||||
res["llm"] = res.pop("llm_setting")
|
||||
res["llm"]["model_name"] = res.pop("llm_id")
|
||||
del res["kb_ids"]
|
||||
res["dataset_ids"] = req.get("dataset_ids", [])
|
||||
res["avatar"] = res.pop("icon")
|
||||
return get_result(data=res)
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
req = request.json
|
||||
ids = req.get("dataset_ids", [])
|
||||
if "show_quotation" in req:
|
||||
req["do_refer"] = req.pop("show_quotation")
|
||||
if ids:
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
if not kbs:
|
||||
return get_error_data_result(f"You don't own the dataset {kb_id}")
|
||||
kbs = KnowledgebaseService.query(id=kb_id)
|
||||
kb = kbs[0]
|
||||
if kb.chunk_num == 0:
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(ids)
|
||||
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||
embd_count = list(set(embd_ids))
|
||||
if len(embd_count) > 1:
|
||||
return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req["kb_ids"] = ids
|
||||
llm = req.get("llm")
|
||||
if llm:
|
||||
if "model_name" in llm:
|
||||
req["llm_id"] = llm.pop("model_name")
|
||||
if req.get("llm_id") is not None:
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
|
||||
if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
|
||||
return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
|
||||
req["llm_setting"] = req.pop("llm")
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
# prompt
|
||||
prompt = req.get("prompt")
|
||||
key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
|
||||
if prompt:
|
||||
for new_key, old_key in key_mapping.items():
|
||||
if old_key in prompt:
|
||||
prompt[new_key] = prompt.pop(old_key)
|
||||
for key in key_list:
|
||||
if key in prompt:
|
||||
req[key] = prompt.pop(key)
|
||||
req["prompt_config"] = req.pop("prompt")
|
||||
e, res = DialogService.get_by_id(chat_id)
|
||||
res = res.to_json()
|
||||
if req.get("rerank_id"):
|
||||
value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
|
||||
if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
|
||||
return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
|
||||
if "name" in req:
|
||||
if not req.get("name"):
|
||||
return get_error_data_result(message="`name` cannot be empty.")
|
||||
if req["name"].lower() != res["name"].lower() and len(DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
|
||||
return get_error_data_result(message="Duplicated chat name in updating chat.")
|
||||
if "prompt_config" in req:
|
||||
res["prompt_config"].update(req["prompt_config"])
|
||||
for p in res["prompt_config"]["parameters"]:
|
||||
if p["optional"]:
|
||||
continue
|
||||
if res["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
|
||||
return get_error_data_result(message="Parameter '{}' is not used".format(p["key"]))
|
||||
if "llm_setting" in req:
|
||||
res["llm_setting"].update(req["llm_setting"])
|
||||
req["prompt_config"] = res["prompt_config"]
|
||||
req["llm_setting"] = res["llm_setting"]
|
||||
# avatar
|
||||
if "avatar" in req:
|
||||
req["icon"] = req.pop("avatar")
|
||||
if "dataset_ids" in req:
|
||||
req.pop("dataset_ids")
|
||||
if not DialogService.update_by_id(chat_id, req):
|
||||
return get_error_data_result(message="Chat not found!")
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route("/chats", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
if not req:
|
||||
ids = None
|
||||
else:
|
||||
ids = req.get("ids")
|
||||
if not ids:
|
||||
id_list = []
|
||||
dias = DialogService.query(tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
for dia in dias:
|
||||
id_list.append(dia.id)
|
||||
else:
|
||||
id_list = ids
|
||||
|
||||
unique_id_list, duplicate_messages = check_duplicate_ids(id_list, "assistant")
|
||||
|
||||
for id in unique_id_list:
|
||||
if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value):
|
||||
errors.append(f"Assistant({id}) not found.")
|
||||
continue
|
||||
temp_dict = {"status": StatusEnum.INVALID.value}
|
||||
DialogService.update_by_id(id, temp_dict)
|
||||
success_count += 1
|
||||
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} chats with {len(errors)} errors")
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
return get_result(message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
|
||||
else:
|
||||
return get_error_data_result(message=";".join(duplicate_messages))
|
||||
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route("/chats", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_chat(tenant_id):
|
||||
id = request.args.get("id")
|
||||
name = request.args.get("name")
|
||||
if id or name:
|
||||
chat = DialogService.query(id=id, name=name, status=StatusEnum.VALID.value, tenant_id=tenant_id)
|
||||
if not chat:
|
||||
return get_error_data_result(message="The chat doesn't exist")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name)
|
||||
if not chats:
|
||||
return get_result(data=[])
|
||||
list_assts = []
|
||||
key_mapping = {
|
||||
"parameters": "variables",
|
||||
"prologue": "opener",
|
||||
"quote": "show_quote",
|
||||
"system": "prompt",
|
||||
"rerank_id": "rerank_model",
|
||||
"vector_similarity_weight": "keywords_similarity_weight",
|
||||
"do_refer": "show_quotation",
|
||||
}
|
||||
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
|
||||
for res in chats:
|
||||
renamed_dict = {}
|
||||
for key, value in res["prompt_config"].items():
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_dict[new_key] = value
|
||||
res["prompt"] = renamed_dict
|
||||
del res["prompt_config"]
|
||||
new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
|
||||
res["prompt"].update(new_dict)
|
||||
for key in key_list:
|
||||
del res[key]
|
||||
res["llm"] = res.pop("llm_setting")
|
||||
res["llm"]["model_name"] = res.pop("llm_id")
|
||||
kb_list = []
|
||||
for kb_id in res["kb_ids"]:
|
||||
kb = KnowledgebaseService.query(id=kb_id)
|
||||
if not kb:
|
||||
logging.warning(f"The kb {kb_id} does not exist.")
|
||||
continue
|
||||
kb_list.append(kb[0].to_json())
|
||||
del res["kb_ids"]
|
||||
res["datasets"] = kb_list
|
||||
res["avatar"] = res.pop("icon")
|
||||
list_assts.append(res)
|
||||
return get_result(data=list_assts)
|
||||
527
api/apps/sdk/dataset.py
Normal file
527
api/apps/sdk/dataset.py
Normal file
@@ -0,0 +1,527 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from flask import request
|
||||
from peewee import OperationalError
|
||||
from api import settings
|
||||
from api.db import FileSource, StatusEnum
|
||||
from api.db.db_models import File
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
deep_merge,
|
||||
get_error_argument_result,
|
||||
get_error_data_result,
|
||||
get_error_operating_result,
|
||||
get_error_permission_result,
|
||||
get_parser_config,
|
||||
get_result,
|
||||
remap_dictionary_keys,
|
||||
token_required,
|
||||
verify_embedding_availability,
|
||||
)
|
||||
from api.utils.validation_utils import (
|
||||
CreateDatasetReq,
|
||||
DeleteDatasetReq,
|
||||
ListDatasetReq,
|
||||
UpdateDatasetReq,
|
||||
validate_and_parse_json_request,
|
||||
validate_and_parse_request_args,
|
||||
)
|
||||
from rag.nlp import search
|
||||
from rag.settings import PAGERANK_FLD
|
||||
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
"""
|
||||
Create a new dataset.
|
||||
---
|
||||
tags:
|
||||
- Datasets
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
description: Bearer token for authentication.
|
||||
- in: body
|
||||
name: body
|
||||
description: Dataset creation parameters.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Name of the dataset.
|
||||
avatar:
|
||||
type: string
|
||||
description: Base64 encoding of the avatar.
|
||||
description:
|
||||
type: string
|
||||
description: Description of the dataset.
|
||||
embedding_model:
|
||||
type: string
|
||||
description: Embedding model Name.
|
||||
permission:
|
||||
type: string
|
||||
enum: ['me', 'team']
|
||||
description: Dataset permission.
|
||||
chunk_method:
|
||||
type: string
|
||||
enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
|
||||
"picture", "presentation", "qa", "table", "tag"
|
||||
]
|
||||
description: Chunking method.
|
||||
parser_config:
|
||||
type: object
|
||||
description: Parser configuration.
|
||||
responses:
|
||||
200:
|
||||
description: Successful operation.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: object
|
||||
"""
|
||||
# Field name transformations during model dump:
|
||||
# | Original | Dump Output |
|
||||
# |----------------|-------------|
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
try:
|
||||
if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_operating_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
|
||||
req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"])
|
||||
req["id"] = get_uuid()
|
||||
req["tenant_id"] = tenant_id
|
||||
req["created_by"] = tenant_id
|
||||
|
||||
ok, t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
return get_error_permission_result(message="Tenant not found")
|
||||
|
||||
if not req.get("embd_id"):
|
||||
req["embd_id"] = t.embd_id
|
||||
else:
|
||||
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
|
||||
if not ok:
|
||||
return err
|
||||
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_error_data_result(message="Create dataset error.(Database error)")
|
||||
|
||||
ok, k = KnowledgebaseService.get_by_id(req["id"])
|
||||
if not ok:
|
||||
return get_error_data_result(message="Dataset created failed")
|
||||
|
||||
response_data = remap_dictionary_keys(k.to_dict())
|
||||
return get_result(data=response_data)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
|
||||
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id):
|
||||
"""
|
||||
Delete datasets.
|
||||
---
|
||||
tags:
|
||||
- Datasets
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
description: Bearer token for authentication.
|
||||
- in: body
|
||||
name: body
|
||||
description: Dataset deletion parameters.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- ids
|
||||
properties:
|
||||
ids:
|
||||
type: array or null
|
||||
items:
|
||||
type: string
|
||||
description: |
|
||||
Specifies the datasets to delete:
|
||||
- If `null`, all datasets will be deleted.
|
||||
- If an array of IDs, only the specified datasets will be deleted.
|
||||
- If an empty array, no datasets will be deleted.
|
||||
responses:
|
||||
200:
|
||||
description: Successful operation.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req, err = validate_and_parse_json_request(request, DeleteDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
try:
|
||||
kb_id_instance_pairs = []
|
||||
if req["ids"] is None:
|
||||
kbs = KnowledgebaseService.query(tenant_id=tenant_id)
|
||||
for kb in kbs:
|
||||
kb_id_instance_pairs.append((kb.id, kb))
|
||||
|
||||
else:
|
||||
error_kb_ids = []
|
||||
for kb_id in req["ids"]:
|
||||
kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id)
|
||||
if kb is None:
|
||||
error_kb_ids.append(kb_id)
|
||||
continue
|
||||
kb_id_instance_pairs.append((kb_id, kb))
|
||||
if len(error_kb_ids) > 0:
|
||||
return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""")
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
for kb_id, kb in kb_id_instance_pairs:
|
||||
for doc in DocumentService.query(kb_id=kb_id):
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'")
|
||||
continue
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[
|
||||
File.source_type == FileSource.KNOWLEDGEBASE,
|
||||
File.id == f2d[0].file_id,
|
||||
]
|
||||
)
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
|
||||
if not KnowledgebaseService.delete_by_id(kb_id):
|
||||
errors.append(f"Delete dataset error for {kb_id}")
|
||||
continue
|
||||
success_count += 1
|
||||
|
||||
if not errors:
|
||||
return get_result()
|
||||
|
||||
error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..."
|
||||
if success_count == 0:
|
||||
return get_error_data_result(message=error_message)
|
||||
|
||||
return get_result(data={"success_count": success_count, "errors": errors[:5]}, message=error_message)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, dataset_id):
|
||||
"""
|
||||
Update a dataset.
|
||||
---
|
||||
tags:
|
||||
- Datasets
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: dataset_id
|
||||
type: string
|
||||
required: true
|
||||
description: ID of the dataset to update.
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
description: Bearer token for authentication.
|
||||
- in: body
|
||||
name: body
|
||||
description: Dataset update parameters.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: New name of the dataset.
|
||||
avatar:
|
||||
type: string
|
||||
description: Updated base64 encoding of the avatar.
|
||||
description:
|
||||
type: string
|
||||
description: Updated description of the dataset.
|
||||
embedding_model:
|
||||
type: string
|
||||
description: Updated embedding model Name.
|
||||
permission:
|
||||
type: string
|
||||
enum: ['me', 'team']
|
||||
description: Updated dataset permission.
|
||||
chunk_method:
|
||||
type: string
|
||||
enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
|
||||
"picture", "presentation", "qa", "table", "tag"
|
||||
]
|
||||
description: Updated chunking method.
|
||||
pagerank:
|
||||
type: integer
|
||||
description: Updated page rank.
|
||||
parser_config:
|
||||
type: object
|
||||
description: Updated parser configuration.
|
||||
responses:
|
||||
200:
|
||||
description: Successful operation.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
# Field name transformations during model dump:
|
||||
# | Original | Dump Output |
|
||||
# |----------------|-------------|
|
||||
# | embedding_model| embd_id |
|
||||
# | chunk_method | parser_id |
|
||||
extras = {"dataset_id": dataset_id}
|
||||
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
if not req:
|
||||
return get_error_argument_result(message="No properties were modified")
|
||||
|
||||
try:
|
||||
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
|
||||
if kb is None:
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
|
||||
|
||||
if req.get("parser_config"):
|
||||
req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
|
||||
|
||||
if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id:
|
||||
if not req.get("parser_config"):
|
||||
req["parser_config"] = get_parser_config(chunk_method, None)
|
||||
elif "parser_config" in req and not req["parser_config"]:
|
||||
del req["parser_config"]
|
||||
|
||||
if "name" in req and req["name"].lower() != kb.name.lower():
|
||||
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
if exists:
|
||||
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
|
||||
|
||||
if "embd_id" in req:
|
||||
if not req["embd_id"]:
|
||||
req["embd_id"] = kb.embd_id
|
||||
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
|
||||
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
|
||||
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
|
||||
if not ok:
|
||||
return err
|
||||
|
||||
if "pagerank" in req and req["pagerank"] != kb.pagerank:
|
||||
if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
|
||||
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
|
||||
|
||||
if req["pagerank"] > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_error_data_result(message="Update dataset error.(Database error)")
|
||||
|
||||
ok, k = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Dataset created failed")
|
||||
|
||||
response_data = remap_dictionary_keys(k.to_dict())
|
||||
return get_result(data=response_data)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
|
||||
@manager.route("/datasets", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_datasets(tenant_id):
|
||||
"""
|
||||
List datasets.
|
||||
---
|
||||
tags:
|
||||
- Datasets
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: query
|
||||
name: id
|
||||
type: string
|
||||
required: false
|
||||
description: Dataset ID to filter.
|
||||
- in: query
|
||||
name: name
|
||||
type: string
|
||||
required: false
|
||||
description: Dataset name to filter.
|
||||
- in: query
|
||||
name: page
|
||||
type: integer
|
||||
required: false
|
||||
default: 1
|
||||
description: Page number.
|
||||
- in: query
|
||||
name: page_size
|
||||
type: integer
|
||||
required: false
|
||||
default: 30
|
||||
description: Number of items per page.
|
||||
- in: query
|
||||
name: orderby
|
||||
type: string
|
||||
required: false
|
||||
default: "create_time"
|
||||
description: Field to order by.
|
||||
- in: query
|
||||
name: desc
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
description: Order in descending.
|
||||
- in: header
|
||||
name: Authorization
|
||||
type: string
|
||||
required: true
|
||||
description: Bearer token for authentication.
|
||||
responses:
|
||||
200:
|
||||
description: Successful operation.
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
"""
|
||||
args, err = validate_and_parse_request_args(request, ListDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
try:
|
||||
kb_id = request.args.get("id")
|
||||
name = args.get("name")
|
||||
if kb_id:
|
||||
kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id)
|
||||
|
||||
if not kbs:
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'")
|
||||
if name:
|
||||
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
|
||||
if not kbs:
|
||||
return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'")
|
||||
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
|
||||
kbs = KnowledgebaseService.get_list(
|
||||
[m["tenant_id"] for m in tenants],
|
||||
tenant_id,
|
||||
args["page"],
|
||||
args["page_size"],
|
||||
args["orderby"],
|
||||
args["desc"],
|
||||
kb_id,
|
||||
name,
|
||||
)
|
||||
|
||||
response_data_list = []
|
||||
for kb in kbs:
|
||||
response_data_list.append(remap_dictionary_keys(kb))
|
||||
return get_result(data=response_data_list)
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return get_error_data_result(message="Database operation failed")
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def knowledge_graph(tenant_id,dataset_id):
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
_, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
req = {
|
||||
"kb_id": [dataset_id],
|
||||
"knowledge_graph_kwd": ["graph"]
|
||||
}
|
||||
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
|
||||
return get_result(data=obj)
|
||||
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
if not len(sres.ids):
|
||||
return get_result(data=obj)
|
||||
|
||||
for id in sres.ids[:1]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
try:
|
||||
content_json = json.loads(sres.field[id]["content_with_weight"])
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
obj[ty] = content_json
|
||||
|
||||
if "nodes" in obj["graph"]:
|
||||
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
|
||||
if "edges" in obj["graph"]:
|
||||
node_id_set = { o["id"] for o in obj["graph"]["nodes"] }
|
||||
filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
|
||||
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
|
||||
return get_result(data=obj)
|
||||
|
||||
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
|
||||
@token_required
|
||||
def delete_knowledge_graph(tenant_id,dataset_id):
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
_, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id)
|
||||
|
||||
return get_result(data=True)
|
||||
104
api/apps/sdk/dify_retrieval.py
Normal file
104
api/apps/sdk/dify_retrieval.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
|
||||
from flask import request, jsonify
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
||||
from rag.app.tag import label_question
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
|
||||
|
||||
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
|
||||
@apikey_required
|
||||
@validate_request("knowledge_id", "query")
|
||||
def retrieval(tenant_id):
|
||||
req = request.json
|
||||
question = req["query"]
|
||||
kb_id = req["knowledge_id"]
|
||||
use_kg = req.get("use_kg", False)
|
||||
retrieval_setting = req.get("retrieval_setting", {})
|
||||
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
|
||||
top = int(retrieval_setting.get("top_k", 1024))
|
||||
metadata_condition = req.get("metadata_condition",{})
|
||||
metas = DocumentService.get_meta_by_kbs([kb_id])
|
||||
|
||||
doc_ids = []
|
||||
try:
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
print(metadata_condition)
|
||||
print("after",convert_conditions(metadata_condition))
|
||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition)))
|
||||
print("doc_ids",doc_ids)
|
||||
if not doc_ids and metadata_condition is not None:
|
||||
doc_ids = ['-999']
|
||||
ranks = settings.retrievaler.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
kb.tenant_id,
|
||||
[kb_id],
|
||||
page=1,
|
||||
page_size=top,
|
||||
similarity_threshold=similarity_threshold,
|
||||
vector_similarity_weight=0.3,
|
||||
top=top,
|
||||
doc_ids=doc_ids,
|
||||
rank_feature=label_question(question, [kb])
|
||||
)
|
||||
|
||||
if use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(question,
|
||||
[tenant_id],
|
||||
[kb_id],
|
||||
embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
|
||||
records = []
|
||||
for c in ranks["chunks"]:
|
||||
e, doc = DocumentService.get_by_id( c["doc_id"])
|
||||
c.pop("vector", None)
|
||||
meta = getattr(doc, 'meta_fields', {})
|
||||
meta["doc_id"] = c["doc_id"]
|
||||
records.append({
|
||||
"content": c["content_with_weight"],
|
||||
"score": c["similarity"],
|
||||
"title": c["docnm_kwd"],
|
||||
"metadata": meta
|
||||
})
|
||||
|
||||
return jsonify({"records": records})
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return build_error_result(
|
||||
message='No chunk found! Check the chunk status please!',
|
||||
code=settings.RetCode.NOT_FOUND
|
||||
)
|
||||
logging.exception(e)
|
||||
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
1497
api/apps/sdk/doc.py
Normal file
1497
api/apps/sdk/doc.py
Normal file
File diff suppressed because it is too large
Load Diff
738
api/apps/sdk/files.py
Normal file
738
api/apps/sdk/files.py
Normal file
@@ -0,0 +1,738 @@
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from pathlib import Path
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, token_required
|
||||
from api.utils import get_uuid
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.file_utils import filename_type
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def upload(tenant_id):
|
||||
"""
|
||||
Upload a file to the system.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: formData
|
||||
name: file
|
||||
type: file
|
||||
required: true
|
||||
description: The file to upload
|
||||
- in: formData
|
||||
name: parent_id
|
||||
type: string
|
||||
description: Parent folder ID where the file will be uploaded. Optional.
|
||||
responses:
|
||||
200:
|
||||
description: Successfully uploaded the file.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: File ID
|
||||
name:
|
||||
type: string
|
||||
description: File name
|
||||
size:
|
||||
type: integer
|
||||
description: File size in bytes
|
||||
type:
|
||||
type: string
|
||||
description: File type (e.g., document, folder)
|
||||
"""
|
||||
pf_id = request.form.get("parent_id")
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(data=False, message='No file part!', code=400)
|
||||
file_objs = request.files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(data=False, message='No selected file!', code=400)
|
||||
|
||||
file_res = []
|
||||
|
||||
try:
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
||||
|
||||
for file_obj in file_objs:
|
||||
# Handle file path
|
||||
full_path = '/' + file_obj.filename
|
||||
file_obj_names = full_path.split('/')
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# Get folder path ID
|
||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# Crete file folder
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, len_id_list)
|
||||
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
while STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||
location += "_"
|
||||
blob = file_obj.read()
|
||||
filename = duplicate_name(FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id)
|
||||
|
||||
file = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": last_folder.id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
}
|
||||
file = FileService.insert(file)
|
||||
STORAGE_IMPL.put(last_folder.id, location, blob)
|
||||
file_res.append(file.to_json())
|
||||
return get_json_result(data=file_res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/create', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def create(tenant_id):
|
||||
"""
|
||||
Create a new file or folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: File creation parameters
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Name of the file/folder
|
||||
parent_id:
|
||||
type: string
|
||||
description: Parent folder ID. Optional.
|
||||
type:
|
||||
type: string
|
||||
enum: ["FOLDER", "VIRTUAL"]
|
||||
description: Type of the file
|
||||
responses:
|
||||
200:
|
||||
description: File created successfully.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
"""
|
||||
req = request.json
|
||||
pf_id = request.json.get("parent_id")
|
||||
input_file_type = request.json.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
try:
|
||||
if not FileService.is_parent_folder_exist(pf_id):
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400)
|
||||
if FileService.query(name=req["name"], parent_id=pf_id):
|
||||
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
|
||||
|
||||
if input_file_type == FileType.FOLDER.value:
|
||||
file_type = FileType.FOLDER.value
|
||||
else:
|
||||
file_type = FileType.VIRTUAL.value
|
||||
|
||||
file = FileService.insert({
|
||||
"id": get_uuid(),
|
||||
"parent_id": pf_id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"name": req["name"],
|
||||
"location": "",
|
||||
"size": 0,
|
||||
"type": file_type
|
||||
})
|
||||
|
||||
return get_json_result(data=file.to_json())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/list', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_files(tenant_id):
|
||||
"""
|
||||
List files under a specific folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: query
|
||||
name: parent_id
|
||||
type: string
|
||||
description: Folder ID to list files from
|
||||
- in: query
|
||||
name: keywords
|
||||
type: string
|
||||
description: Search keyword filter
|
||||
- in: query
|
||||
name: page
|
||||
type: integer
|
||||
default: 1
|
||||
description: Page number
|
||||
- in: query
|
||||
name: page_size
|
||||
type: integer
|
||||
default: 15
|
||||
description: Number of results per page
|
||||
- in: query
|
||||
name: orderby
|
||||
type: string
|
||||
default: "create_time"
|
||||
description: Sort by field
|
||||
- in: query
|
||||
name: desc
|
||||
type: boolean
|
||||
default: true
|
||||
description: Descending order
|
||||
responses:
|
||||
200:
|
||||
description: Successfully retrieved file list.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
total:
|
||||
type: integer
|
||||
files:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
size:
|
||||
type: integer
|
||||
create_time:
|
||||
type: string
|
||||
format: date-time
|
||||
"""
|
||||
pf_id = request.args.get("parent_id")
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 15))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
|
||||
try:
|
||||
e, file = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
|
||||
files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(pf_id)
|
||||
if not parent_folder:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
|
||||
return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_root_folder(tenant_id):
|
||||
"""
|
||||
Get user's root folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
200:
|
||||
description: Root folder information
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: object
|
||||
properties:
|
||||
root_folder:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
"""
|
||||
try:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
return get_json_result(data={"root_folder": root_folder})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_parent_folder():
|
||||
"""
|
||||
Get parent folder info of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: query
|
||||
name: file_id
|
||||
type: string
|
||||
required: true
|
||||
description: Target file ID
|
||||
responses:
|
||||
200:
|
||||
description: Parent folder information
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: object
|
||||
properties:
|
||||
parent_folder:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
"""
|
||||
file_id = request.args.get("file_id")
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(file_id)
|
||||
return get_json_result(data={"parent_folder": parent_folder.to_json()})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_all_parent_folders(tenant_id):
|
||||
"""
|
||||
Get all parent folders of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: query
|
||||
name: file_id
|
||||
type: string
|
||||
required: true
|
||||
description: Target file ID
|
||||
responses:
|
||||
200:
|
||||
description: All parent folders of the file
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: object
|
||||
properties:
|
||||
parent_folders:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
"""
|
||||
file_id = request.args.get("file_id")
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
|
||||
parent_folders = FileService.get_all_parent_folders(file_id)
|
||||
parent_folders_res = [folder.to_json() for folder in parent_folders]
|
||||
return get_json_result(data={"parent_folders": parent_folders_res})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/rm', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rm(tenant_id):
|
||||
"""
|
||||
Delete one or multiple files/folders.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: Files to delete
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
file_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: List of file IDs to delete
|
||||
responses:
|
||||
200:
|
||||
description: Successfully deleted files
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
file_ids = req["file_ids"]
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for inner_file_id in file_id_list:
|
||||
e, file = FileService.get_by_id(inner_file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
FileService.delete_folder_by_pf_id(tenant_id, file_id)
|
||||
else:
|
||||
STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
if not FileService.delete(file):
|
||||
return get_json_result(message="Database error (File removal)!", code=500)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file_id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(message="Database error (Document removal)!", code=500)
|
||||
File2DocumentService.delete_by_file_id(file_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/rename', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def rename(tenant_id):
|
||||
"""
|
||||
Rename a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: Rename file
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
file_id:
|
||||
type: string
|
||||
description: Target file ID
|
||||
name:
|
||||
type: string
|
||||
description: New name for the file
|
||||
responses:
|
||||
200:
|
||||
description: File renamed successfully
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
|
||||
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(file.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=400)
|
||||
|
||||
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if existing_file.name == req["name"]:
|
||||
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
|
||||
|
||||
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (File rename)!", code=500)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(req["file_id"])
|
||||
if informs:
|
||||
if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (Document rename)!", code=500)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/get/<file_id>', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get(tenant_id,file_id):
|
||||
"""
|
||||
Download a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
produces:
|
||||
- application/octet-stream
|
||||
parameters:
|
||||
- in: path
|
||||
name: file_id
|
||||
type: string
|
||||
required: true
|
||||
description: File ID to download
|
||||
responses:
|
||||
200:
|
||||
description: File stream
|
||||
schema:
|
||||
type: file
|
||||
404:
|
||||
description: File not found
|
||||
"""
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
|
||||
blob = STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = STORAGE_IMPL.get(b, n)
|
||||
|
||||
response = flask.make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name)
|
||||
if ext:
|
||||
if file.type == FileType.VISUAL.value:
|
||||
response.headers.set('Content-Type', 'image/%s' % ext.group(1))
|
||||
else:
|
||||
response.headers.set('Content-Type', 'application/%s' % ext.group(1))
|
||||
return response
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def move(tenant_id):
|
||||
"""
|
||||
Move one or multiple files to another folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: Move operation
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
src_file_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: Source file IDs
|
||||
dest_file_id:
|
||||
type: string
|
||||
description: Destination folder ID
|
||||
responses:
|
||||
200:
|
||||
description: Files moved successfully
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = request.json
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
parent_id = req["dest_file_id"]
|
||||
files = FileService.get_by_ids(file_ids)
|
||||
files_dict = {f.id: f for f in files}
|
||||
|
||||
for file_id in file_ids:
|
||||
file = files_dict[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_json_result(message="Parent Folder not found!", code=404)
|
||||
|
||||
FileService.move_file(file_ids, parent_id)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def convert(tenant_id):
|
||||
req = request.json
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
|
||||
try:
|
||||
files = FileService.get_by_ids(file_ids)
|
||||
files_set = dict({file.id: file for file in files})
|
||||
for file_id in file_ids:
|
||||
file = files_set[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
file_ids_list = [file_id]
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for id in file_ids_list:
|
||||
informs = File2DocumentService.get_by_file_id(id)
|
||||
# delete
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(
|
||||
message="Database error (Document removal)!", code=404)
|
||||
File2DocumentService.delete_by_file_id(id)
|
||||
|
||||
# insert
|
||||
for kb_id in kb_ids:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this knowledgebase!", code=404)
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this file!", code=404)
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": tenant_id,
|
||||
"type": file.type,
|
||||
"name": file.name,
|
||||
"suffix": Path(file.name).suffix.lstrip("."),
|
||||
"location": file.location,
|
||||
"size": file.size
|
||||
})
|
||||
file2document = File2DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"file_id": id,
|
||||
"document_id": doc.id,
|
||||
})
|
||||
|
||||
file2documents.append(file2document.to_json())
|
||||
return get_json_result(data=file2documents)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
1115
api/apps/sdk/session.py
Normal file
1115
api/apps/sdk/session.py
Normal file
File diff suppressed because it is too large
Load Diff
188
api/apps/search_app.py
Normal file
188
api/apps/search_app.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from api import settings
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import DB
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/create", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create():
|
||||
req = request.get_json()
|
||||
search_name = req["name"]
|
||||
description = req.get("description", "")
|
||||
if not isinstance(search_name, str):
|
||||
return get_data_error_result(message="Search name must be string.")
|
||||
if search_name.strip() == "":
|
||||
return get_data_error_result(message="Search name can't be empty.")
|
||||
if len(search_name.encode("utf-8")) > 255:
|
||||
return get_data_error_result(message=f"Search name length is {len(search_name)} which is large than 255.")
|
||||
e, _ = TenantService.get_by_id(current_user.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Authorized identity.")
|
||||
|
||||
search_name = search_name.strip()
|
||||
search_name = duplicate_name(SearchService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value)
|
||||
|
||||
req["id"] = get_uuid()
|
||||
req["name"] = search_name
|
||||
req["description"] = description
|
||||
req["tenant_id"] = current_user.id
|
||||
req["created_by"] = current_user.id
|
||||
with DB.atomic():
|
||||
try:
|
||||
if not SearchService.save(**req):
|
||||
return get_data_error_result()
|
||||
return get_json_result(data={"search_id": req["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/update", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("search_id", "name", "search_config", "tenant_id")
|
||||
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
def update():
|
||||
req = request.get_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Search name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
return get_data_error_result(message="Search name can't be empty.")
|
||||
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return get_data_error_result(message=f"Search name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
|
||||
req["name"] = req["name"].strip()
|
||||
tenant_id = req["tenant_id"]
|
||||
e, _ = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Authorized identity.")
|
||||
|
||||
search_id = req["search_id"]
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0]
|
||||
if not search_app:
|
||||
return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR)
|
||||
|
||||
if req["name"].lower() != search_app.name.lower() and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1:
|
||||
return get_data_error_result(message="Duplicated search name.")
|
||||
|
||||
if "search_config" in req:
|
||||
current_config = search_app.search_config or {}
|
||||
new_config = req["search_config"]
|
||||
|
||||
if not isinstance(new_config, dict):
|
||||
return get_data_error_result(message="search_config must be a JSON object")
|
||||
|
||||
updated_config = {**current_config, **new_config}
|
||||
req["search_config"] = updated_config
|
||||
|
||||
req.pop("search_id", None)
|
||||
req.pop("tenant_id", None)
|
||||
|
||||
updated = SearchService.update_by_id(search_id, req)
|
||||
if not updated:
|
||||
return get_data_error_result(message="Failed to update search")
|
||||
|
||||
e, updated_search = SearchService.get_by_id(search_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Failed to fetch updated search")
|
||||
|
||||
return get_json_result(data=updated_search.to_dict())
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/detail", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def detail():
|
||||
search_id = request.args["search_id"]
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if SearchService.query(tenant_id=tenant.tenant_id, id=search_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
search = SearchService.get_detail(search_id)
|
||||
if not search:
|
||||
return get_data_error_result(message="Can't find this Search App!")
|
||||
return get_json_result(data=search)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def list_search_app():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = request.get_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
# tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
# tenants = [m["tenant_id"] for m in tenants]
|
||||
tenants = []
|
||||
search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords)
|
||||
else:
|
||||
tenants = owner_ids
|
||||
search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords)
|
||||
search_apps = [search_app for search_app in search_apps if search_app["tenant_id"] in tenants]
|
||||
total = len(search_apps)
|
||||
if page_number and items_per_page:
|
||||
search_apps = search_apps[(page_number - 1) * items_per_page : page_number * items_per_page]
|
||||
return get_json_result(data={"search_apps": search_apps, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("search_id")
|
||||
def rm():
|
||||
req = request.get_json()
|
||||
search_id = req["search_id"]
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
if not SearchService.delete_by_id(search_id):
|
||||
return get_data_error_result(message=f"Failed to delete search App {search_id}")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
334
api/apps/system_app.py
Normal file
334
api/apps/system_app.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api import settings
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
from api.utils.api_utils import (
|
||||
get_json_result,
|
||||
get_data_error_result,
|
||||
server_error_response,
|
||||
generate_confirmation_token,
|
||||
)
|
||||
from api.versions import get_ragflow_version
|
||||
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
|
||||
@manager.route("/version", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def version():
|
||||
"""
|
||||
Get the current version of the application.
|
||||
---
|
||||
tags:
|
||||
- System
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
200:
|
||||
description: Version retrieved successfully.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
version:
|
||||
type: string
|
||||
description: Version number.
|
||||
"""
|
||||
return get_json_result(data=get_ragflow_version())
|
||||
|
||||
|
||||
@manager.route("/status", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def status():
|
||||
"""
|
||||
Get the system status.
|
||||
---
|
||||
tags:
|
||||
- System
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
200:
|
||||
description: System is operational.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
es:
|
||||
type: object
|
||||
description: Elasticsearch status.
|
||||
storage:
|
||||
type: object
|
||||
description: Storage status.
|
||||
database:
|
||||
type: object
|
||||
description: Database status.
|
||||
503:
|
||||
description: Service unavailable.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
description: Error message.
|
||||
"""
|
||||
res = {}
|
||||
st = timer()
|
||||
try:
|
||||
res["doc_engine"] = settings.docStoreConn.health()
|
||||
res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
||||
except Exception as e:
|
||||
res["doc_engine"] = {
|
||||
"type": "unknown",
|
||||
"status": "red",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
st = timer()
|
||||
try:
|
||||
STORAGE_IMPL.health()
|
||||
res["storage"] = {
|
||||
"storage": STORAGE_IMPL_TYPE.lower(),
|
||||
"status": "green",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
}
|
||||
except Exception as e:
|
||||
res["storage"] = {
|
||||
"storage": STORAGE_IMPL_TYPE.lower(),
|
||||
"status": "red",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
st = timer()
|
||||
try:
|
||||
KnowledgebaseService.get_by_id("x")
|
||||
res["database"] = {
|
||||
"database": settings.DATABASE_TYPE.lower(),
|
||||
"status": "green",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
}
|
||||
except Exception as e:
|
||||
res["database"] = {
|
||||
"database": settings.DATABASE_TYPE.lower(),
|
||||
"status": "red",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
st = timer()
|
||||
try:
|
||||
if not REDIS_CONN.health():
|
||||
raise Exception("Lost connection!")
|
||||
res["redis"] = {
|
||||
"status": "green",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
}
|
||||
except Exception as e:
|
||||
res["redis"] = {
|
||||
"status": "red",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
task_executor_heartbeats = {}
|
||||
try:
|
||||
task_executors = REDIS_CONN.smembers("TASKEXE")
|
||||
now = datetime.now().timestamp()
|
||||
for task_executor_id in task_executors:
|
||||
heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60*30, now)
|
||||
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
|
||||
task_executor_heartbeats[task_executor_id] = heartbeats
|
||||
except Exception:
|
||||
logging.exception("get task executor heartbeats failed!")
|
||||
res["task_executor_heartbeats"] = task_executor_heartbeats
|
||||
|
||||
return get_json_result(data=res)
|
||||
|
||||
|
||||
@manager.route("/healthz", methods=["GET"]) # noqa: F821
|
||||
def healthz():
|
||||
result, all_ok = run_health_checks()
|
||||
return jsonify(result), (200 if all_ok else 500)
|
||||
|
||||
|
||||
@manager.route("/ping", methods=["GET"]) # noqa: F821
|
||||
def ping():
|
||||
return "pong", 200
|
||||
|
||||
|
||||
@manager.route("/new_token", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def new_token():
|
||||
"""
|
||||
Generate a new API token.
|
||||
---
|
||||
tags:
|
||||
- API Tokens
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: query
|
||||
name: name
|
||||
type: string
|
||||
required: false
|
||||
description: Name of the token.
|
||||
responses:
|
||||
200:
|
||||
description: Token generated successfully.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
description: The generated API token.
|
||||
"""
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
|
||||
obj = {
|
||||
"tenant_id": tenant_id,
|
||||
"token": generate_confirmation_token(tenant_id),
|
||||
"beta": generate_confirmation_token(generate_confirmation_token(tenant_id)).replace("ragflow-", "")[:32],
|
||||
"create_time": current_timestamp(),
|
||||
"create_date": datetime_format(datetime.now()),
|
||||
"update_time": None,
|
||||
"update_date": None,
|
||||
}
|
||||
|
||||
if not APITokenService.save(**obj):
|
||||
return get_data_error_result(message="Fail to new a dialog!")
|
||||
|
||||
return get_json_result(data=obj)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/token_list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def token_list():
|
||||
"""
|
||||
List all API tokens for the current user.
|
||||
---
|
||||
tags:
|
||||
- API Tokens
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
200:
|
||||
description: List of API tokens.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
tokens:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
description: The API token.
|
||||
name:
|
||||
type: string
|
||||
description: Name of the token.
|
||||
create_time:
|
||||
type: string
|
||||
description: Token creation time.
|
||||
"""
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
|
||||
objs = APITokenService.query(tenant_id=tenant_id)
|
||||
objs = [o.to_dict() for o in objs]
|
||||
for o in objs:
|
||||
if not o["beta"]:
|
||||
o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace("ragflow-", "")[:32]
|
||||
APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o)
|
||||
return get_json_result(data=objs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/token/<token>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
def rm(token):
|
||||
"""
|
||||
Remove an API token.
|
||||
---
|
||||
tags:
|
||||
- API Tokens
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: token
|
||||
type: string
|
||||
required: true
|
||||
description: The API token to remove.
|
||||
responses:
|
||||
200:
|
||||
description: Token removed successfully.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
success:
|
||||
type: boolean
|
||||
description: Deletion status.
|
||||
"""
|
||||
APITokenService.filter_delete(
|
||||
[APIToken.tenant_id == current_user.id, APIToken.token == token]
|
||||
)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/config', methods=['GET']) # noqa: F821
|
||||
def get_config():
|
||||
"""
|
||||
Get system configuration.
|
||||
---
|
||||
tags:
|
||||
- System
|
||||
responses:
|
||||
200:
|
||||
description: Return system configuration
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
registerEnable:
|
||||
type: integer 0 means disabled, 1 means enabled
|
||||
description: Whether user registration is enabled
|
||||
"""
|
||||
return get_json_result(data={
|
||||
"registerEnabled": settings.REGISTER_ENABLED
|
||||
})
|
||||
138
api/apps/tenant_app.py
Normal file
138
api/apps/tenant_app.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api import settings
|
||||
from api.apps import smtp_mail_server
|
||||
from api.db import UserTenantRole, StatusEnum
|
||||
from api.db.db_models import UserTenant
|
||||
from api.db.services.user_service import UserTenantService, UserService
|
||||
|
||||
from api.utils import get_uuid, delta_seconds
|
||||
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
|
||||
from api.utils.web_utils import send_invite_email
|
||||
|
||||
|
||||
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def user_list(tenant_id):
|
||||
if current_user.id != tenant_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
users = UserTenantService.get_by_tenant_id(tenant_id)
|
||||
for u in users:
|
||||
u["delta_seconds"] = delta_seconds(str(u["update_date"]))
|
||||
return get_json_result(data=users)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("email")
|
||||
def create(tenant_id):
|
||||
if current_user.id != tenant_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
invite_user_email = req["email"]
|
||||
invite_users = UserService.query(email=invite_user_email)
|
||||
if not invite_users:
|
||||
return get_data_error_result(message="User not found.")
|
||||
|
||||
user_id_to_invite = invite_users[0].id
|
||||
user_tenants = UserTenantService.query(user_id=user_id_to_invite, tenant_id=tenant_id)
|
||||
if user_tenants:
|
||||
user_tenant_role = user_tenants[0].role
|
||||
if user_tenant_role == UserTenantRole.NORMAL:
|
||||
return get_data_error_result(message=f"{invite_user_email} is already in the team.")
|
||||
if user_tenant_role == UserTenantRole.OWNER:
|
||||
return get_data_error_result(message=f"{invite_user_email} is the owner of the team.")
|
||||
return get_data_error_result(message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")
|
||||
|
||||
UserTenantService.save(
|
||||
id=get_uuid(),
|
||||
user_id=user_id_to_invite,
|
||||
tenant_id=tenant_id,
|
||||
invited_by=current_user.id,
|
||||
role=UserTenantRole.INVITE,
|
||||
status=StatusEnum.VALID.value)
|
||||
|
||||
if smtp_mail_server and settings.SMTP_CONF:
|
||||
from threading import Thread
|
||||
|
||||
user_name = ""
|
||||
_, user = UserService.get_by_id(current_user.id)
|
||||
if user:
|
||||
user_name = user.nickname
|
||||
|
||||
Thread(
|
||||
target=send_invite_email,
|
||||
args=(invite_user_email, settings.MAIL_FRONTEND_URL, tenant_id, user_name or current_user.email),
|
||||
daemon=True
|
||||
).start()
|
||||
|
||||
usr = invite_users[0].to_dict()
|
||||
usr = {k: v for k, v in usr.items() if k in ["id", "avatar", "email", "nickname"]}
|
||||
|
||||
return get_json_result(data=usr)
|
||||
|
||||
|
||||
@manager.route('/<tenant_id>/user/<user_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
def rm(tenant_id, user_id):
|
||||
if current_user.id != tenant_id and current_user.id != user_id:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
UserTenantService.filter_delete([UserTenant.tenant_id == tenant_id, UserTenant.user_id == user_id])
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def tenant_list():
|
||||
try:
|
||||
users = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
for u in users:
|
||||
u["delta_seconds"] = delta_seconds(str(u["update_date"]))
|
||||
return get_json_result(data=users)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/agree/<tenant_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def agree(tenant_id):
|
||||
try:
|
||||
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
827
api/apps/user_app.py
Normal file
827
api/apps/user_app.py
Normal file
@@ -0,0 +1,827 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
|
||||
from flask import redirect, request, session
|
||||
from flask_login import current_user, login_required, login_user, logout_user
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
from api import settings
|
||||
from api.apps.auth import get_auth_client
|
||||
from api.db import FileType, UserTenantRole
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from api.utils import (
|
||||
current_timestamp,
|
||||
datetime_format,
|
||||
download_img,
|
||||
get_format_time,
|
||||
get_uuid,
|
||||
)
|
||||
from api.utils.api_utils import (
|
||||
construct_response,
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
def login():
|
||||
"""
|
||||
User login endpoint.
|
||||
---
|
||||
tags:
|
||||
- User
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: Login credentials.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
email:
|
||||
type: string
|
||||
description: User email.
|
||||
password:
|
||||
type: string
|
||||
description: User password.
|
||||
responses:
|
||||
200:
|
||||
description: Login successful.
|
||||
schema:
|
||||
type: object
|
||||
401:
|
||||
description: Authentication failed.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
if not request.json:
|
||||
return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
email = request.json.get("email", "")
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||
message=f"Email: {email} is not registered!",
|
||||
)
|
||||
|
||||
password = request.json.get("password")
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.FORBIDDEN,
|
||||
message="This account has been disabled, please contact the administrator!",
|
||||
)
|
||||
elif user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||
message="Email and password do not match!",
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
||||
def get_login_channels():
|
||||
"""
|
||||
Get all supported authentication channels.
|
||||
"""
|
||||
try:
|
||||
channels = []
|
||||
for channel, config in settings.OAUTH_CONFIG.items():
|
||||
channels.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"display_name": config.get("display_name", channel.title()),
|
||||
"icon": config.get("icon", "sso"),
|
||||
}
|
||||
)
|
||||
return get_json_result(data=channels)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=settings.RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_login(channel):
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
raise ValueError(f"Invalid channel name: {channel}")
|
||||
auth_cli = get_auth_client(channel_config)
|
||||
|
||||
state = get_uuid()
|
||||
session["oauth_state"] = state
|
||||
auth_url = auth_cli.get_authorization_url(state)
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_callback(channel):
|
||||
"""
|
||||
Handle the OAuth/OIDC callback for various channels dynamically.
|
||||
"""
|
||||
try:
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
raise ValueError(f"Invalid channel name: {channel}")
|
||||
auth_cli = get_auth_client(channel_config)
|
||||
|
||||
# Check the state
|
||||
state = request.args.get("state")
|
||||
if not state or state != session.get("oauth_state"):
|
||||
return redirect("/?error=invalid_state")
|
||||
session.pop("oauth_state", None)
|
||||
|
||||
# Obtain the authorization code
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return redirect("/?error=missing_code")
|
||||
|
||||
# Exchange authorization code for access token
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
access_token = token_info.get("access_token")
|
||||
if not access_token:
|
||||
return redirect("/?error=token_failed")
|
||||
|
||||
id_token = token_info.get("id_token")
|
||||
|
||||
# Fetch user info
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if not user_info.email:
|
||||
return redirect("/?error=email_missing")
|
||||
|
||||
# Login or register
|
||||
users = UserService.query(email=user_info.email)
|
||||
user_id = get_uuid()
|
||||
|
||||
if not users:
|
||||
try:
|
||||
try:
|
||||
avatar = download_img(user_info.avatar_url)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
avatar = ""
|
||||
|
||||
users = user_register(
|
||||
user_id,
|
||||
{
|
||||
"access_token": get_uuid(),
|
||||
"email": user_info.email,
|
||||
"avatar": avatar,
|
||||
"nickname": user_info.nickname,
|
||||
"login_channel": channel,
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
},
|
||||
)
|
||||
|
||||
if not users:
|
||||
raise Exception(f"Failed to register {user_info.email}")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {user_info.email} exists!")
|
||||
|
||||
# Try to log in
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return redirect(f"/?auth={user.get_id()}")
|
||||
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
return redirect(f"/?error={str(e)}")
|
||||
|
||||
# User exists, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect(f"/?auth={user.get_id()}")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return redirect(f"/?error={str(e)}")
|
||||
|
||||
|
||||
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
|
||||
def github_callback():
|
||||
"""
|
||||
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
||||
|
||||
GitHub OAuth callback endpoint.
|
||||
---
|
||||
tags:
|
||||
- OAuth
|
||||
parameters:
|
||||
- in: query
|
||||
name: code
|
||||
type: string
|
||||
required: true
|
||||
description: Authorization code from GitHub.
|
||||
responses:
|
||||
200:
|
||||
description: Authentication successful.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
res = requests.post(
|
||||
settings.GITHUB_OAUTH.get("url"),
|
||||
data={
|
||||
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
|
||||
"code": request.args.get("code"),
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
res = res.json()
|
||||
if "error" in res:
|
||||
return redirect("/?error=%s" % res["error_description"])
|
||||
|
||||
if "user:email" not in res["scope"].split(","):
|
||||
return redirect("/?error=user:email not in scope")
|
||||
|
||||
session["access_token"] = res["access_token"]
|
||||
session["access_token_from"] = "github"
|
||||
user_info = user_info_from_github(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
if not users:
|
||||
# User isn't try to register
|
||||
try:
|
||||
try:
|
||||
avatar = download_img(user_info["avatar_url"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
avatar = ""
|
||||
users = user_register(
|
||||
user_id,
|
||||
{
|
||||
"access_token": session["access_token"],
|
||||
"email": email_address,
|
||||
"avatar": avatar,
|
||||
"nickname": user_info["login"],
|
||||
"login_channel": "github",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
},
|
||||
)
|
||||
if not users:
|
||||
raise Exception(f"Fail to register {email_address}.")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
|
||||
# Try to log in
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
return redirect("/?error=%s" % str(e))
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
|
||||
|
||||
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
|
||||
def feishu_callback():
|
||||
"""
|
||||
Feishu OAuth callback endpoint.
|
||||
---
|
||||
tags:
|
||||
- OAuth
|
||||
parameters:
|
||||
- in: query
|
||||
name: code
|
||||
type: string
|
||||
required: true
|
||||
description: Authorization code from Feishu.
|
||||
responses:
|
||||
200:
|
||||
description: Authentication successful.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
app_access_token_res = requests.post(
|
||||
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"app_id": settings.FEISHU_OAUTH.get("app_id"),
|
||||
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
|
||||
}
|
||||
),
|
||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
)
|
||||
app_access_token_res = app_access_token_res.json()
|
||||
if app_access_token_res["code"] != 0:
|
||||
return redirect("/?error=%s" % app_access_token_res)
|
||||
|
||||
res = requests.post(
|
||||
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
|
||||
"code": request.args.get("code"),
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {app_access_token_res['app_access_token']}",
|
||||
},
|
||||
)
|
||||
res = res.json()
|
||||
if res["code"] != 0:
|
||||
return redirect("/?error=%s" % res["message"])
|
||||
|
||||
if "contact:user.email:readonly" not in res["data"]["scope"].split():
|
||||
return redirect("/?error=contact:user.email:readonly not in scope")
|
||||
session["access_token"] = res["data"]["access_token"]
|
||||
session["access_token_from"] = "feishu"
|
||||
user_info = user_info_from_feishu(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
if not users:
|
||||
# User isn't try to register
|
||||
try:
|
||||
try:
|
||||
avatar = download_img(user_info["avatar_url"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
avatar = ""
|
||||
users = user_register(
|
||||
user_id,
|
||||
{
|
||||
"access_token": session["access_token"],
|
||||
"email": email_address,
|
||||
"avatar": avatar,
|
||||
"nickname": user_info["en_name"],
|
||||
"login_channel": "feishu",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
},
|
||||
)
|
||||
if not users:
|
||||
raise Exception(f"Fail to register {email_address}.")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
|
||||
# Try to log in
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
return redirect("/?error=%s" % str(e))
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
|
||||
|
||||
def user_info_from_feishu(access_token):
|
||||
import requests
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
user_info = res.json()["data"]
|
||||
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
import requests
|
||||
|
||||
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers,
|
||||
).json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@manager.route("/logout", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def log_out():
|
||||
"""
|
||||
User logout endpoint.
|
||||
---
|
||||
tags:
|
||||
- User
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
200:
|
||||
description: Logout successful.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
current_user.access_token = f"INVALID_{secrets.token_hex(16)}"
|
||||
current_user.save()
|
||||
logout_user()
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/setting", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def setting_user():
|
||||
"""
|
||||
Update user settings.
|
||||
---
|
||||
tags:
|
||||
- User
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: User settings to update.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
nickname:
|
||||
type: string
|
||||
description: New nickname.
|
||||
email:
|
||||
type: string
|
||||
description: New email.
|
||||
responses:
|
||||
200:
|
||||
description: Settings updated successfully.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
update_dict = {}
|
||||
request_data = request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||
message="Password error!",
|
||||
)
|
||||
|
||||
if new_password:
|
||||
update_dict["password"] = generate_password_hash(decrypt(new_password))
|
||||
|
||||
for k in request_data.keys():
|
||||
if k in [
|
||||
"password",
|
||||
"new_password",
|
||||
"email",
|
||||
"status",
|
||||
"is_superuser",
|
||||
"login_channel",
|
||||
"is_anonymous",
|
||||
"is_active",
|
||||
"is_authenticated",
|
||||
"last_login_time",
|
||||
]:
|
||||
continue
|
||||
update_dict[k] = request_data[k]
|
||||
|
||||
try:
|
||||
UserService.update_by_id(current_user.id, update_dict)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def user_profile():
|
||||
"""
|
||||
Get user profile information.
|
||||
---
|
||||
tags:
|
||||
- User
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
200:
|
||||
description: User profile retrieved successfully.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: User ID.
|
||||
nickname:
|
||||
type: string
|
||||
description: User nickname.
|
||||
email:
|
||||
type: string
|
||||
description: User email.
|
||||
"""
|
||||
return get_json_result(data=current_user.to_dict())
|
||||
|
||||
|
||||
def rollback_user_registration(user_id):
|
||||
try:
|
||||
UserService.delete_by_id(user_id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
TenantService.delete_by_id(user_id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
u = UserTenantService.query(tenant_id=user_id)
|
||||
if u:
|
||||
UserTenantService.delete_by_id(u[0].id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def user_register(user_id, user):
|
||||
user["id"] = user_id
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
"name": user["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
"rerank_id": settings.RERANK_MDL,
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
"user_id": user_id,
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER,
|
||||
}
|
||||
file_id = get_uuid()
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": user_id,
|
||||
"created_by": user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user):
|
||||
return
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
FileService.insert(file)
|
||||
return UserService.query(email=user["email"])
|
||||
|
||||
|
||||
@manager.route("/register", methods=["POST"]) # noqa: F821
|
||||
@validate_request("nickname", "email", "password")
|
||||
def user_add():
|
||||
"""
|
||||
Register a new user.
|
||||
---
|
||||
tags:
|
||||
- User
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: Registration details.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
nickname:
|
||||
type: string
|
||||
description: User nickname.
|
||||
email:
|
||||
type: string
|
||||
description: User email.
|
||||
password:
|
||||
type: string
|
||||
description: User password.
|
||||
responses:
|
||||
200:
|
||||
description: Registration successful.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
|
||||
if not settings.REGISTER_ENABLED:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="User registration is disabled!",
|
||||
code=settings.RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
req = request.json
|
||||
email_address = req["email"]
|
||||
|
||||
# Validate the email address
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message=f"Invalid email address: {email_address}!",
|
||||
code=settings.RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
# Check if the email address is already used
|
||||
if UserService.query(email=email_address):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message=f"Email: {email_address} has already registered!",
|
||||
code=settings.RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
# Construct user info data
|
||||
nickname = req["nickname"]
|
||||
user_dict = {
|
||||
"access_token": get_uuid(),
|
||||
"email": email_address,
|
||||
"nickname": nickname,
|
||||
"password": decrypt(req["password"]),
|
||||
"login_channel": "password",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
}
|
||||
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register(user_id, user_dict)
|
||||
if not users:
|
||||
raise Exception(f"Fail to register {email_address}.")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return construct_response(
|
||||
data=user.to_json(),
|
||||
auth=user.get_id(),
|
||||
message=f"{nickname}, welcome aboard!",
|
||||
)
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message=f"User registration failure, error: {str(e)}",
|
||||
code=settings.RetCode.EXCEPTION_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def tenant_info():
|
||||
"""
|
||||
Get tenant information.
|
||||
---
|
||||
tags:
|
||||
- Tenant
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
200:
|
||||
description: Tenant information retrieved successfully.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
tenant_id:
|
||||
type: string
|
||||
description: Tenant ID.
|
||||
name:
|
||||
type: string
|
||||
description: Tenant name.
|
||||
llm_id:
|
||||
type: string
|
||||
description: LLM ID.
|
||||
embd_id:
|
||||
type: string
|
||||
description: Embedding model ID.
|
||||
"""
|
||||
try:
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
return get_json_result(data=tenants[0])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
|
||||
def set_tenant_info():
|
||||
"""
|
||||
Update tenant information.
|
||||
---
|
||||
tags:
|
||||
- Tenant
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: Tenant information to update.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
tenant_id:
|
||||
type: string
|
||||
description: Tenant ID.
|
||||
llm_id:
|
||||
type: string
|
||||
description: LLM ID.
|
||||
embd_id:
|
||||
type: string
|
||||
description: Embedding model ID.
|
||||
asr_id:
|
||||
type: string
|
||||
description: ASR model ID.
|
||||
img2txt_id:
|
||||
type: string
|
||||
description: Image to Text model ID.
|
||||
responses:
|
||||
200:
|
||||
description: Tenant information updated successfully.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = request.json
|
||||
try:
|
||||
tid = req.pop("tenant_id")
|
||||
TenantService.update_by_id(tid, req)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
540
api/apps/user_app_fastapi.py
Normal file
540
api/apps/user_app_fastapi.py
Normal file
@@ -0,0 +1,540 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel, EmailStr
|
||||
try:
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
except ImportError:
|
||||
# 如果没有werkzeug,使用passlib作为替代
|
||||
from passlib.context import CryptContext
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def check_password_hash(hashed, password):
|
||||
return pwd_context.verify(password, hashed)
|
||||
|
||||
def generate_password_hash(password):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
from api import settings
|
||||
from api.apps.auth import get_auth_client
|
||||
from api.db import FileType, UserTenantRole
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from api.utils import (
|
||||
current_timestamp,
|
||||
datetime_format,
|
||||
download_img,
|
||||
get_format_time,
|
||||
get_uuid,
|
||||
)
|
||||
from api.utils.api_utils import (
|
||||
construct_response,
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
# 安全方案
|
||||
security = HTTPBearer()
|
||||
|
||||
# Pydantic模型
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
nickname: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
class UserSettingRequest(BaseModel):
|
||||
nickname: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
new_password: Optional[str] = None
|
||||
|
||||
class TenantInfoRequest(BaseModel):
|
||||
tenant_id: str
|
||||
asr_id: str
|
||||
embd_id: str
|
||||
img2txt_id: str
|
||||
llm_id: str
|
||||
|
||||
# 依赖项:获取当前用户
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""获取当前用户"""
|
||||
from api.db import StatusEnum
|
||||
try:
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
except ImportError:
|
||||
# 如果没有itsdangerous,使用jwt作为替代
|
||||
import jwt
|
||||
Serializer = jwt
|
||||
|
||||
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||
authorization = credentials.credentials
|
||||
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
|
||||
if not access_token or not access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication attempt with empty access token"
|
||||
)
|
||||
|
||||
# Access tokens should be UUIDs (32 hex characters)
|
||||
if len(access_token.strip()) < 32:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Authentication attempt with invalid token format: {len(access_token)} chars"
|
||||
)
|
||||
|
||||
user = UserService.query(
|
||||
access_token=access_token, status=StatusEnum.VALID.value
|
||||
)
|
||||
if user:
|
||||
if not user[0].access_token or not user[0].access_token.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"User {user[0].email} has empty access_token in database"
|
||||
)
|
||||
return user[0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"load_user got exception {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authorization header required"
|
||||
)
|
||||
|
||||
@router.post("/login")
|
||||
async def login(request: LoginRequest):
|
||||
"""
|
||||
用户登录端点
|
||||
"""
|
||||
email = request.email
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Email: {email} is not registered!"
|
||||
)
|
||||
|
||||
password = request.password
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except BaseException:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Fail to crypt password"
|
||||
)
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="This account has been disabled, please contact the administrator!"
|
||||
)
|
||||
elif user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return construct_response(data=response_data, auth=user.get_id(), message=msg)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Email and password do not match!"
|
||||
)
|
||||
|
||||
@router.get("/login/channels")
|
||||
async def get_login_channels():
|
||||
"""
|
||||
获取所有支持的身份验证渠道
|
||||
"""
|
||||
try:
|
||||
channels = []
|
||||
for channel, config in settings.OAUTH_CONFIG.items():
|
||||
channels.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"display_name": config.get("display_name", channel.title()),
|
||||
"icon": config.get("icon", "sso"),
|
||||
}
|
||||
)
|
||||
return get_json_result(data=channels)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Load channels failure, error: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/login/{channel}")
|
||||
async def oauth_login(channel: str, request: Request):
|
||||
"""OAuth登录"""
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid channel name: {channel}"
|
||||
)
|
||||
|
||||
auth_cli = get_auth_client(channel_config)
|
||||
state = get_uuid()
|
||||
|
||||
# 在FastAPI中,我们需要使用session来存储state
|
||||
# 这里简化处理,实际应该使用FastAPI的session管理
|
||||
auth_url = auth_cli.get_authorization_url(state)
|
||||
return RedirectResponse(url=auth_url)
|
||||
|
||||
@router.get("/oauth/callback/{channel}")
|
||||
async def oauth_callback(channel: str, request: Request):
|
||||
"""
|
||||
处理各种渠道的OAuth/OIDC回调
|
||||
"""
|
||||
try:
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid channel name: {channel}"
|
||||
)
|
||||
|
||||
auth_cli = get_auth_client(channel_config)
|
||||
|
||||
# 检查state
|
||||
state = request.query_params.get("state")
|
||||
# 在实际应用中,应该从session中获取state进行比较
|
||||
if not state:
|
||||
return RedirectResponse(url="/?error=invalid_state")
|
||||
|
||||
# 获取授权码
|
||||
code = request.query_params.get("code")
|
||||
if not code:
|
||||
return RedirectResponse(url="/?error=missing_code")
|
||||
|
||||
# 交换授权码获取访问令牌
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
access_token = token_info.get("access_token")
|
||||
if not access_token:
|
||||
return RedirectResponse(url="/?error=token_failed")
|
||||
|
||||
id_token = token_info.get("id_token")
|
||||
|
||||
# 获取用户信息
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if not user_info.email:
|
||||
return RedirectResponse(url="/?error=email_missing")
|
||||
|
||||
# 登录或注册
|
||||
users = UserService.query(email=user_info.email)
|
||||
user_id = get_uuid()
|
||||
|
||||
if not users:
|
||||
try:
|
||||
try:
|
||||
avatar = download_img(user_info.avatar_url)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
avatar = ""
|
||||
|
||||
users = user_register(
|
||||
user_id,
|
||||
{
|
||||
"access_token": get_uuid(),
|
||||
"email": user_info.email,
|
||||
"avatar": avatar,
|
||||
"nickname": user_info.nickname,
|
||||
"login_channel": channel,
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
},
|
||||
)
|
||||
|
||||
if not users:
|
||||
raise Exception(f"Failed to register {user_info.email}")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {user_info.email} exists!")
|
||||
|
||||
# 尝试登录
|
||||
user = users[0]
|
||||
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
return RedirectResponse(url=f"/?error={str(e)}")
|
||||
|
||||
# 用户存在,尝试登录
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return RedirectResponse(url="/?error=user_inactive")
|
||||
|
||||
user.save()
|
||||
return RedirectResponse(url=f"/?auth={user.get_id()}")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return RedirectResponse(url=f"/?error={str(e)}")
|
||||
|
||||
@router.get("/logout")
|
||||
async def log_out(current_user = Depends(get_current_user)):
|
||||
"""
|
||||
用户登出端点
|
||||
"""
|
||||
current_user.access_token = f"INVALID_{secrets.token_hex(16)}"
|
||||
current_user.save()
|
||||
return get_json_result(data=True)
|
||||
|
||||
@router.post("/setting")
|
||||
async def setting_user(request: UserSettingRequest, current_user = Depends(get_current_user)):
|
||||
"""
|
||||
更新用户设置
|
||||
"""
|
||||
update_dict = {}
|
||||
request_data = request.dict()
|
||||
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Password error!"
|
||||
)
|
||||
|
||||
if new_password:
|
||||
update_dict["password"] = generate_password_hash(decrypt(new_password))
|
||||
|
||||
for k in request_data.keys():
|
||||
if k in [
|
||||
"password",
|
||||
"new_password",
|
||||
"email",
|
||||
"status",
|
||||
"is_superuser",
|
||||
"login_channel",
|
||||
"is_anonymous",
|
||||
"is_active",
|
||||
"is_authenticated",
|
||||
"last_login_time",
|
||||
]:
|
||||
continue
|
||||
update_dict[k] = request_data[k]
|
||||
|
||||
try:
|
||||
UserService.update_by_id(current_user.id, update_dict)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Update failure!"
|
||||
)
|
||||
|
||||
@router.get("/info")
|
||||
async def user_profile(current_user = Depends(get_current_user)):
|
||||
"""
|
||||
获取用户配置文件信息
|
||||
"""
|
||||
return get_json_result(data=current_user.to_dict())
|
||||
|
||||
def rollback_user_registration(user_id):
|
||||
"""回滚用户注册"""
|
||||
try:
|
||||
UserService.delete_by_id(user_id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
TenantService.delete_by_id(user_id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
u = UserTenantService.query(tenant_id=user_id)
|
||||
if u:
|
||||
UserTenantService.delete_by_id(u[0].id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def user_register(user_id, user):
|
||||
"""用户注册"""
|
||||
user["id"] = user_id
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
"name": user["nickname"] + "'s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
"rerank_id": settings.RERANK_MDL,
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
"user_id": user_id,
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER,
|
||||
}
|
||||
file_id = get_uuid()
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": user_id,
|
||||
"created_by": user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user):
|
||||
return
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
FileService.insert(file)
|
||||
return UserService.query(email=user["email"])
|
||||
|
||||
@router.post("/register")
|
||||
async def user_add(request: RegisterRequest):
|
||||
"""
|
||||
注册新用户
|
||||
"""
|
||||
if not settings.REGISTER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User registration is disabled!"
|
||||
)
|
||||
|
||||
email_address = request.email
|
||||
|
||||
# 验证邮箱地址
|
||||
if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid email address: {email_address}!"
|
||||
)
|
||||
|
||||
# 检查邮箱地址是否已被使用
|
||||
if UserService.query(email=email_address):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Email: {email_address} has already registered!"
|
||||
)
|
||||
|
||||
# 构建用户信息数据
|
||||
nickname = request.nickname
|
||||
user_dict = {
|
||||
"access_token": get_uuid(),
|
||||
"email": email_address,
|
||||
"nickname": nickname,
|
||||
"password": decrypt(request.password),
|
||||
"login_channel": "password",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
}
|
||||
|
||||
user_id = get_uuid()
|
||||
try:
|
||||
users = user_register(user_id, user_dict)
|
||||
if not users:
|
||||
raise Exception(f"Fail to register {email_address}.")
|
||||
if len(users) > 1:
|
||||
raise Exception(f"Same email: {email_address} exists!")
|
||||
user = users[0]
|
||||
return construct_response(
|
||||
data=user.to_json(),
|
||||
auth=user.get_id(),
|
||||
message=f"{nickname}, welcome aboard!",
|
||||
)
|
||||
except Exception as e:
|
||||
rollback_user_registration(user_id)
|
||||
logging.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"User registration failure, error: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/tenant_info")
|
||||
async def tenant_info(current_user = Depends(get_current_user)):
|
||||
"""
|
||||
获取租户信息
|
||||
"""
|
||||
try:
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tenant not found!"
|
||||
)
|
||||
return get_json_result(data=tenants[0])
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
@router.post("/set_tenant_info")
|
||||
async def set_tenant_info(request: TenantInfoRequest, current_user = Depends(get_current_user)):
|
||||
"""
|
||||
更新租户信息
|
||||
"""
|
||||
try:
|
||||
req_dict = request.dict()
|
||||
tid = req_dict.pop("tenant_id")
|
||||
TenantService.update_by_id(tid, req_dict)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
Reference in New Issue
Block a user