Files
TERES_fastapi_backend/ocr/main.py

203 lines
5.8 KiB
Python
Raw Normal View History

2025-10-31 14:38:37 +08:00
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理服务的主程序入口
独立运行不依赖RAGFlow的其他部分
"""
import argparse
import logging
import os
import sys
import signal
from pathlib import Path
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Support both ways of launching this entry point:
#   * `python ocr/main.py`  -> executed as a script, `__package__` is None
#   * `python -m ocr.main`  -> executed as a module, `__package__` is 'ocr'
try:
    _package = __package__
except NameError:
    _package = None

if _package is None:
    # Script mode: put the directory that contains this file's parent
    # directory on the import path (once) before importing anything local.
    parent_dir = Path(__file__).parent.parent
    if str(parent_dir) not in sys.path:
        sys.path.insert(0, str(parent_dir))

# Both launch modes load the router and the default model directory through
# the same top-level module names.
# NOTE(review): these are absolute imports, so `api` and `config` must be
# resolvable from sys.path in module mode as well — confirm that
# `python -m ocr.main` works with the deployed directory layout.
from api import router as ocr_router
from config import MODEL_DIR
# Root logging setup: INFO and above, timestamped records written to stdout.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
def create_app(allow_origins=None) -> FastAPI:
    """Build and configure the FastAPI application for the OCR service.

    Args:
        allow_origins: Optional CORS origin allow-list, given either as an
            iterable of origin strings or as one comma-separated string.
            When omitted, the comma-separated ``OCR_CORS_ORIGINS``
            environment variable is consulted; if that is unset or empty,
            every origin (``"*"``) is allowed.

    Returns:
        The application with the CORS middleware, the OCR router and a root
        information endpoint installed.
    """
    service_version = "1.0.0"
    docs_path = "/apidocs"

    app = FastAPI(
        title="OCR PDF Parser API",
        description="独立的OCR PDF处理服务提供PDF文档的OCR识别功能",
        version=service_version,
        docs_url=docs_path,           # Swagger UI
        redoc_url="/redoc",           # ReDoc (alternative documentation UI)
        openapi_url="/openapi.json",  # OpenAPI JSON schema
    )

    # Resolve the CORS allow-list: explicit argument, then environment,
    # then the permissive wildcard default.
    if allow_origins is None:
        allow_origins = os.environ.get("OCR_CORS_ORIGINS", "")
    if isinstance(allow_origins, str):
        allow_origins = allow_origins.split(",")
    origins = [o.strip() for o in allow_origins if o and o.strip()] or ["*"]

    # A wildcard origin must not be combined with credentialed requests:
    # doing so would let any website issue cookie/authorization-bearing calls
    # to this service. Credentials are therefore only enabled when the
    # deployment names its trusted origins explicitly.
    credentials_allowed = "*" not in origins

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=credentials_allowed,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Mount the OCR endpoints.
    app.include_router(ocr_router)

    @app.get("/")
    async def root():
        """Return basic service metadata and pointers to docs and health check."""
        return {
            "service": "OCR PDF Parser",
            "version": service_version,
            "docs": docs_path,
            "health": "/api/v1/ocr/health",
        }

    return app
def signal_handler(sig, frame):
    """Log the shutdown request and terminate the process with status 0.

    Args:
        sig: Signal number passed by the ``signal`` machinery (not inspected).
        frame: Interrupted stack frame (not inspected).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    logger.info("Received shutdown signal, exiting...")
    raise SystemExit(0)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the OCR service launcher."""
    parser = argparse.ArgumentParser(description="OCR PDF处理服务")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="服务器监听地址 (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="服务器端口 (default: 8000)",
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="开发模式:自动重载代码",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="工作进程数 (default: 1)",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
        help="日志级别 (default: info)",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help=f"OCR模型目录路径 (default: {MODEL_DIR})",
    )
    return parser


def _app_import_target():
    """Return ``(import_string, app_dir)`` so uvicorn can re-import ``create_app``."""
    spec = globals().get("__spec__")
    if spec is not None:
        # Launched with `-m` or imported: the real dotted module name is known.
        return f"{spec.name}:create_app", None
    # Launched as a plain script: import by file stem from the script's folder.
    script_path = Path(__file__).resolve()
    return f"{script_path.stem}:create_app", str(script_path.parent)


def main() -> None:
    """Parse the command line, report the configuration and run the server.

    Command-line options: ``--host``, ``--port``, ``--reload``, ``--workers``,
    ``--log-level`` and ``--model-dir``. A given ``--model-dir`` is exported
    as the ``OCR_MODEL_DIR`` environment variable.

    With ``--reload`` or more than one worker, uvicorn has to start the
    application in child processes and therefore needs an import string
    instead of an application object; in that case the ``create_app`` factory
    is handed over by name. Otherwise the application is built here and
    passed directly.

    Raises:
        SystemExit: With status 1 when the server stops on an unexpected
            exception.
    """
    args = _build_arg_parser().parse_args()

    # Export a user-supplied model directory through the environment.
    if args.model_dir:
        os.environ["OCR_MODEL_DIR"] = args.model_dir
        logger.info("Using custom model directory: %s", args.model_dir)

    # A missing model directory is only worth a warning.
    model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR)
    if model_dir and not os.path.exists(model_dir):
        logger.warning("Model directory does not exist: %s", model_dir)
        logger.info("Models will be downloaded on first use")

    # Exit cleanly on SIGINT / SIGTERM.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Startup banner.
    separator = "=" * 60
    base_url = f"http://{args.host}:{args.port}"
    logger.info(separator)
    logger.info("OCR PDF Parser Service")
    logger.info(separator)
    logger.info("Host: %s", args.host)
    logger.info("Port: %s", args.port)
    logger.info("Model Directory: %s", model_dir)
    logger.info("Workers: %s", args.workers)
    logger.info("Reload: %s", args.reload)
    logger.info("Log Level: %s", args.log_level)
    logger.info(separator)
    logger.info("API Documentation (Swagger): %s/apidocs", base_url)
    logger.info("API Documentation (ReDoc): %s/redoc", base_url)
    logger.info("Health Check: %s/api/v1/ocr/health", base_url)
    logger.info(separator)

    # Reload mode cannot be combined with multiple workers.
    worker_count = 1 if args.reload else args.workers

    if args.reload or worker_count > 1:
        # Child processes must be able to import the app themselves, so pass
        # the factory by name rather than an already-built object.
        target, app_dir = _app_import_target()
        extra_options = {"factory": True, "app_dir": app_dir}
    else:
        target = create_app()
        extra_options = {}

    try:
        uvicorn.run(
            target,
            host=args.host,
            port=args.port,
            log_level=args.log_level,
            reload=args.reload,
            workers=worker_count,
            access_log=True,
            **extra_options,
        )
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        # Top-level boundary: record the traceback and signal failure.
        logger.exception("Server error: %s", e)
        sys.exit(1)
# Start the service only when this file is executed as a program; importing
# the module must not launch the server.
if __name__ == "__main__":
    main()