2025-10-31 14:38:37 +08:00
|
|
|
|
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
|
|
|
|
|
"""
|
|
|
|
|
|
OCR PDF处理服务的主程序入口
独立运行，不依赖RAGFlow的其他部分
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import os
|
|
|
|
|
|
import sys
|
|
|
|
|
|
import signal
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
2025-11-03 10:22:28 +08:00
|
|
|
|
# Make the project root (the grandparent directory of this file) importable,
# so that `ocr.*` imports resolve even when this file is run as a plain script.
_current_file = Path(__file__).resolve()
_project_root = _current_file.parent.parent
if str(_project_root) not in sys.path:
    # Prepend rather than append so the local `ocr` package is found before
    # anything of the same name later on the search path.
    sys.path[:0] = [str(_project_root)]
|
|
|
|
|
|
|
2025-10-31 14:38:37 +08:00
|
|
|
|
import uvicorn
|
|
|
|
|
|
from fastapi import FastAPI
|
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
2025-10-31 17:50:25 +08:00
|
|
|
|
from ocr.api import ocr_router
|
|
|
|
|
|
from ocr.config import MODEL_DIR
|
2025-10-31 14:38:37 +08:00
|
|
|
|
|
|
|
|
|
|
# Configure root logging at import time: INFO level, timestamped records,
# written to standard output. (logging.basicConfig is a no-op when the root
# logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_app() -> FastAPI:
    """Build the FastAPI application for the standalone OCR PDF service.

    The returned app serves Swagger UI at ``/apidocs``, ReDoc at ``/redoc``
    and the OpenAPI schema at ``/openapi.json``. It mounts every route of
    ``ocr_router`` and adds a ``GET /`` endpoint that describes the service.

    Returns:
        A fully configured ``FastAPI`` instance.
    """
    api_version = "1.0.0"
    swagger_path = "/apidocs"

    application = FastAPI(
        title="OCR PDF Parser API",
        description="独立的OCR PDF处理服务,提供PDF文档的OCR识别功能",
        version=api_version,
        docs_url=swagger_path,       # Swagger UI
        redoc_url="/redoc",          # ReDoc (alternative docs UI)
        openapi_url="/openapi.json",  # OpenAPI JSON schema
    )

    # CORS is fully open: any origin, method and header, with credentials.
    # NOTE(review): in production, allow_origins should list concrete domains.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Mount the OCR API routes.
    application.include_router(ocr_router)

    # Service descriptor at the root path. The handler stays named `root` and
    # has no docstring, because FastAPI derives the OpenAPI operation id,
    # summary and description from the function's name and docstring.
    @application.get("/")
    async def root():
        return {
            "service": "OCR PDF Parser",
            "version": api_version,
            "docs": swagger_path,
            "health": "/api/v1/ocr/health",
        }

    return application
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def signal_handler(sig, frame):
    """Log the shutdown request and exit the process with status 0.

    Installed by ``main`` for SIGINT and SIGTERM.

    Args:
        sig: The received signal number (unused).
        frame: The interrupted stack frame (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    logger.info("Received shutdown signal, exiting...")
    raise SystemExit(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _display_host(host: str) -> str:
    """Return a host string that forms a usable URL in the startup banner.

    Wildcard bind addresses (``0.0.0.0`` / ``::``) cannot be opened in a
    browser, so they are shown as ``localhost``. Bare IPv6 literals are
    wrapped in brackets, as URLs require.
    """
    if host in ("0.0.0.0", "::"):
        return "localhost"
    if ":" in host and not host.startswith("["):
        return f"[{host}]"
    return host


def _app_import_string() -> str:
    """Return the ``"<package>.<module>:create_app"`` import string for this file.

    uvicorn can only reload or run several worker processes when it is given
    an import string, so that each child process can import the app itself.
    The module path is computed relative to this file's grandparent directory,
    which is the directory this module places on ``sys.path`` at import time.
    NOTE(review): assumes the file sits directly inside a package under that
    root (e.g. ``ocr/main.py`` -> ``ocr.main``); confirm against the layout.
    """
    module_file = Path(__file__).resolve()
    relative = module_file.relative_to(module_file.parent.parent).with_suffix("")
    return ".".join(relative.parts) + ":create_app"


def main():
    """Command-line entry point: parse options and serve the OCR API with uvicorn.

    Invalid options make argparse exit with status 2. An unexpected server
    error is logged and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="OCR PDF处理服务")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="服务器监听地址 (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="服务器端口 (default: 8000)"
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="开发模式:自动重载代码"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="工作进程数 (default: 1)"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
        help="日志级别 (default: info)"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help=f"OCR模型目录路径 (default: {MODEL_DIR})"
    )

    args = parser.parse_args()

    # argparse accepts any int; a worker count below 1 is meaningless.
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    # Expose a custom model directory through the environment. Child
    # processes (reload / workers) inherit it.
    # NOTE(review): MODEL_DIR was imported from ocr.config before this line
    # runs; confirm the OCR code re-reads OCR_MODEL_DIR at use time.
    if args.model_dir:
        os.environ["OCR_MODEL_DIR"] = args.model_dir
        logger.info(f"Using custom model directory: {args.model_dir}")

    # Warn early if the model directory is missing.
    model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR)
    if model_dir and not os.path.exists(model_dir):
        logger.warning(f"Model directory does not exist: {model_dir}")
        logger.info("Models will be downloaded on first use")

    # Register shutdown signal handlers.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Reload mode does not support multiple worker processes.
    workers = 1 if args.reload else args.workers
    # uvicorn refuses to reload or spawn workers from an application object,
    # so those modes get an import string and each process builds its own app.
    use_import_string = args.reload or workers > 1

    # Startup banner.
    url_base = f"http://{_display_host(args.host)}:{args.port}"
    logger.info("=" * 60)
    logger.info("OCR PDF Parser Service")
    logger.info("=" * 60)
    logger.info(f"Host: {args.host}")
    logger.info(f"Port: {args.port}")
    logger.info(f"Model Directory: {model_dir}")
    logger.info(f"Workers: {workers}")
    logger.info(f"Reload: {args.reload}")
    logger.info(f"Log Level: {args.log_level}")
    logger.info("=" * 60)
    logger.info(f"API Documentation (Swagger): {url_base}/apidocs")
    logger.info(f"API Documentation (ReDoc): {url_base}/redoc")
    logger.info(f"Health Check: {url_base}/api/v1/ocr/health")
    logger.info("=" * 60)

    run_options = {
        "host": args.host,
        "port": args.port,
        "log_level": args.log_level,
        "reload": args.reload,
        "workers": workers,
        "access_log": True,
    }
    if use_import_string:
        target = _app_import_string()
        # create_app is a factory: uvicorn calls it in each server process.
        run_options["factory"] = True
    else:
        target = create_app()

    # Start the server.
    try:
        uvicorn.run(target, **run_options)
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|