Files
TERES_fastapi_backend/api/apps/__init___fastapi.py

234 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import logging
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.sessions import SessionMiddleware
try:
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
except ImportError:
# 如果没有itsdangerous使用jwt作为替代
import jwt
Serializer = jwt
from api.db import StatusEnum
from api.db.db_models import close_connection
from api.db.services import UserService
from api.utils.json import CustomJSONEncoder
from api.utils import commands
from api import settings
from api.utils.api_utils import server_error_response
from api.constants import API_VERSION
# Only the module-level application instance is part of the public API.
__all__ = ["app"]
def create_app(*, allow_origins=None, allowed_hosts=None) -> FastAPI:
    """Create and configure the FastAPI application instance.

    The returned app exposes Swagger UI at ``/apidocs/``, ReDoc at
    ``/redoc/`` and the OpenAPI document at ``/apispec.json``. It has a
    customised OpenAPI schema declaring a single ``HTTPBearer`` security
    scheme. CORS, trusted-host and session middleware are installed. A
    catch-all exception handler delegates to ``server_error_response``.
    Routers are not registered here (see ``setup_routes``).

    Args:
        allow_origins: Origins accepted by the CORS middleware. ``None``
            (the default) means ``["*"]``, i.e. any origin, which matches the
            previous hard-coded behaviour. Pass explicit origins in
            production.
        allowed_hosts: Host names accepted by ``TrustedHostMiddleware``.
            ``None`` (the default) means ``["*"]``, i.e. any host.

    Returns:
        The configured ``FastAPI`` application.
    """
    app = FastAPI(
        title="RAGFlow API",
        description="RAGFlow API Server",
        version="1.0.0",
        docs_url="/apidocs/",
        redoc_url="/redoc/",
        openapi_url="/apispec.json",
    )

    def custom_openapi():
        """Build the OpenAPI schema once, with Bearer-token auth declared."""
        # The schema is cached on the app after the first build.
        if app.openapi_schema:
            return app.openapi_schema
        from fastapi.openapi.utils import get_openapi

        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        # Create components / securitySchemes on demand.
        security_schemes = openapi_schema.setdefault("components", {}).setdefault(
            "securitySchemes", {}
        )
        # Keep a single scheme under FastAPI's default "HTTPBearer" name and
        # drop any "CustomHTTPBearer" entry (presumably emitted by a custom
        # security dependency -- confirm).
        security_schemes.pop("CustomHTTPBearer", None)
        security_schemes["HTTPBearer"] = {
            "type": "http",
            "scheme": "bearer",
            "bearerFormat": "JWT",
            "description": "输入 token可以不带 'Bearer ' 前缀,系统会自动添加",
        }
        app.openapi_schema = openapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi

    thirty_days = 30 * 24 * 60 * 60  # 2,592,000 seconds
    cors_origins = ["*"] if allow_origins is None else list(allow_origins)
    trusted_hosts = ["*"] if allowed_hosts is None else list(allowed_hosts)

    # NOTE: combined with allow_credentials=True, a wildcard origin lets any
    # website make credentialed (cookie-bearing) cross-origin requests.
    # Restrict allow_origins in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        max_age=thirty_days,  # how long browsers may cache preflight results
    )
    # Rejects requests whose Host header is not in trusted_hosts.
    app.add_middleware(
        TrustedHostMiddleware,
        allowed_hosts=trusted_hosts,
    )
    # Cookie-based sessions signed with the application secret.
    app.add_middleware(
        SessionMiddleware,
        secret_key=settings.SECRET_KEY,
        max_age=thirty_days,
    )

    @app.exception_handler(Exception)
    async def global_exception_handler(request, exc):
        # Any exception not handled elsewhere is turned into a response by
        # the shared server_error_response helper.
        return server_error_response(exc)

    return app
def search_pages_path(pages_dir):
    """Collect the route-module files found under *pages_dir*.

    Two groups are gathered, in this order:

    1. files directly in *pages_dir* matching ``*_app_fastapi.py``;
    2. ``*.py`` files inside any sub-directory whose name ends in ``sdk``.

    Entries whose file name starts with a dot are skipped in both groups.

    Args:
        pages_dir: ``pathlib.Path`` of the directory to scan.

    Returns:
        A list of ``Path`` objects: app modules first, then SDK modules.
    """

    def _visible(pattern):
        # Glob the pattern and drop dot-files.
        return [
            candidate
            for candidate in pages_dir.glob(pattern)
            if not candidate.name.startswith(".")
        ]

    return _visible("*_app_fastapi.py") + _visible("*sdk/*.py")
def register_page(app: FastAPI, page_path):
    """Load one route module from *page_path* and mount its router on *app*.

    The file is imported under a dotted name built from the path components
    starting at the ``api`` directory. For example,
    ``api/apps/user_app_fastapi.py`` becomes ``api.apps.user``. Before the
    module body runs, the module object is given an ``app`` attribute and
    ``router = None``. Modules that want to be mounted assign their router to
    ``router``.

    Args:
        app: Application the router is mounted on.
        page_path: ``pathlib.Path`` of the module file.

    Returns:
        The URL prefix computed for the module. Files under an ``sdk``
        directory get ``/api/{API_VERSION}``. Others get
        ``/{API_VERSION}/<name>``, where ``<name>`` is the module's
        ``page_name`` attribute if defined, else the file stem without
        ``_app_fastapi``.

    Raises:
        ValueError: If no path component is named ``api``.
        ImportError: If no import spec/loader can be created for the file.
        Exception: Anything raised while executing the module body. The
            half-initialised module is removed from ``sys.modules`` first.
    """
    page_name = page_path.stem.removesuffix("_app_fastapi")
    # NOTE(review): index() picks the FIRST "api" component; a checkout path
    # that itself contains an "api" directory would give a wrong module name.
    package_parts = page_path.parts[page_path.parts.index("api"): -1]
    module_name = ".".join(package_parts + (page_name,))

    spec = spec_from_file_location(module_name, page_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {page_path}")
    page = module_from_spec(spec)
    # Injected before execution so the module body can refer to them.
    page.app = app
    page.router = None  # FastAPI uses routers instead of Flask Blueprints
    sys.modules[module_name] = page
    try:
        spec.loader.exec_module(page)
    except BaseException:
        # Mirror importlib: never leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise

    page_name = getattr(page, "page_name", page_name)
    sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/"
    url_prefix = (
        f"/api/{API_VERSION}"
        if sdk_path in str(page_path)
        else f"/{API_VERSION}/{page_name}"
    )
    router = getattr(page, "router", None)
    if router:
        app.include_router(router, prefix=url_prefix)
    return url_prefix
def setup_routes(app: FastAPI):
    """Register every application router on *app*.

    Each router is mounted under ``/{API_VERSION}/<segment>`` with a single
    OpenAPI tag, in the order listed below. The router modules are imported
    at call time rather than at module import time (presumably to avoid
    import cycles with this package -- confirm).
    """
    from api.apps.user_app_fastapi import router as user_api
    from api.apps.kb_app import router as kb_api
    from api.apps.document_app import router as document_api
    from api.apps.llm_app import router as llm_api
    from api.apps.chunk_app import router as chunk_api
    from api.apps.mcp_server_app import router as mcp_api
    from api.apps.canvas_app import router as canvas_api
    from api.apps.tenant_app import router as tenant_api
    from api.apps.dialog_app import router as dialog_api
    from api.apps.system_app import router as system_api
    from api.apps.search_app import router as search_api
    from api.apps.conversation_app import router as conversation_api
    from api.apps.file_app import router as file_api

    # (router, URL segment, OpenAPI tag)
    mounts = [
        (user_api, "user", "User"),
        (kb_api, "kb", "KnowledgeBase"),
        (document_api, "document", "Document"),
        (llm_api, "llm", "LLM"),
        (chunk_api, "chunk", "Chunk"),
        (mcp_api, "mcp_server", "MCP"),
        (canvas_api, "canvas", "Canvas"),
        (tenant_api, "tenant", "Tenant"),
        (dialog_api, "dialog", "Dialog"),
        (system_api, "system", "System"),
        (search_api, "search", "Search"),
        (conversation_api, "conversation", "Conversation"),
        (file_api, "file", "File"),
    ]
    for sub_router, segment, tag in mounts:
        app.include_router(sub_router, prefix=f"/{API_VERSION}/{segment}", tags=[tag])
def get_current_user_from_token(authorization: str):
    """Look up the active user identified by a signed access token.

    The *authorization* value is decoded with a ``Serializer`` keyed on
    ``settings.SECRET_KEY``. The decoded access token is then validated and
    matched against users whose status is ``StatusEnum.VALID``.

    Args:
        authorization: The signed token string; may be empty or ``None``.

    Returns:
        The first matching user record, or ``None`` in any of these cases:
        *authorization* is empty; the decoded token is blank or shorter than
        32 characters after stripping; no valid user matches; the matched
        user's stored token is blank; or any exception is raised while
        decoding or querying (logged as a warning).

    Note:
        The serializer is built before the ``try`` block, so a failure there
        propagates to the caller. NOTE(review): if the module-level import
        fallback bound ``Serializer`` to the ``jwt`` module, that call raises
        ``TypeError`` because a module is not callable.
    """
    signer = Serializer(secret_key=settings.SECRET_KEY)
    if not authorization:
        return None
    try:
        token = str(signer.loads(authorization))
        stripped = token.strip()
        if not stripped:
            logging.warning("Authentication attempt with empty access token")
            return None
        # Access tokens are expected to be UUID hex strings (32 characters).
        if len(stripped) < 32:
            logging.warning(
                f"Authentication attempt with invalid token format: {len(token)} chars"
            )
            return None
        matches = UserService.query(access_token=token, status=StatusEnum.VALID.value)
        if not matches:
            return None
        found = matches[0]
        stored_token = found.access_token
        if not stored_token or not stored_token.strip():
            logging.warning(f"User {found.email} has empty access_token in database")
            return None
        return found
    except Exception as err:
        logging.warning(f"load_user got exception {err}")
        return None
# Create the module-level application instance (the only name in __all__).
app = create_app()
@app.middleware("http")
async def db_close_middleware(request, call_next):
    """Close the database connection after every HTTP request.

    ``close_connection`` runs in a ``finally`` clause, so it is called whether
    the downstream handler returns normally or raises.
    """
    try:
        return await call_next(request)
    finally:
        # NOTE(review): this runs on the event-loop thread. If close_connection
        # manages per-thread connections, connections opened by sync endpoints
        # in the threadpool may not be the ones closed here -- confirm.
        close_connection()
# Register all API routers on the application instance.
setup_routes(app)