Files
TERES_fastapi_backend/ocr/main.py
2025-10-31 17:50:25 +08:00

186 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
OCR PDF处理服务的主程序入口
独立运行不依赖RAGFlow的其他部分
"""
import argparse
import logging
import os
import sys
import signal
from pathlib import Path
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ocr.api import ocr_router
from ocr.config import MODEL_DIR
# Logging setup: INFO-level records, timestamped and tagged with the logger
# name and level, are written to standard output.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def create_app() -> FastAPI:
    """Build and return the FastAPI application for the OCR service.

    The returned app has CORS middleware installed, the OCR router
    included, and a ``GET /`` endpoint that returns basic service metadata.
    """
    application = FastAPI(
        title="OCR PDF Parser API",
        description="独立的OCR PDF处理服务提供PDF文档的OCR识别功能",
        version="1.0.0",
        docs_url="/apidocs",  # Swagger UI location
        redoc_url="/redoc",  # ReDoc location (alternative docs UI)
        openapi_url="/openapi.json",  # OpenAPI JSON schema location
    )

    # CORS: every origin, method and header is currently accepted.
    # NOTE(review): production deployments should replace the wildcard
    # origin with the specific domains that are allowed to call this API,
    # especially since credentials are allowed.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Mount the OCR endpoints.
    application.include_router(ocr_router)

    async def root():
        """Return service name, version, and the docs / health-check paths."""
        return dict(
            service="OCR PDF Parser",
            version="1.0.0",
            docs="/apidocs",
            health="/api/v1/ocr/health",
        )

    # Register the root path handler (same effect as decorating with @app.get("/")).
    application.get("/")(root)

    return application
def signal_handler(sig, frame):
    """Log that a shutdown signal arrived, then exit with status 0.

    Takes the ``(signal number, stack frame)`` pair that ``signal.signal``
    handlers receive; neither argument is inspected.
    """
    logger.info("Received shutdown signal, exiting...")
    # Equivalent to sys.exit(0): unwinds the interpreter with a success code.
    raise SystemExit(0)
def main():
    """Parse command-line options and run the OCR service under uvicorn.

    Command-line options:
        --host        Address to bind (default ``0.0.0.0``).
        --port        TCP port (default ``8000``).
        --reload      Development mode: restart on code changes.
        --workers     Number of worker processes (default ``1``); forced to 1
                      when ``--reload`` is given.
        --log-level   Log level passed to uvicorn (default ``info``).
        --model-dir   OCR model directory; when given it is exported as the
                      ``OCR_MODEL_DIR`` environment variable.

    Exits with status 1 if the server raises an unexpected exception.
    """
    parser = argparse.ArgumentParser(description="OCR PDF处理服务")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="服务器监听地址 (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="服务器端口 (default: 8000)"
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="开发模式:自动重载代码"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="工作进程数 (default: 1)"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
        help="日志级别 (default: info)"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help=f"OCR模型目录路径 (default: {MODEL_DIR})"
    )
    args = parser.parse_args()

    # Export the model directory (if supplied) through the environment.
    # NOTE(review): ocr.config has already been imported at this point, so
    # whether this override is honoured depends on when that package reads
    # OCR_MODEL_DIR - confirm.
    if args.model_dir:
        os.environ["OCR_MODEL_DIR"] = args.model_dir
        logger.info(f"Using custom model directory: {args.model_dir}")

    # Warn early if the model directory is missing.
    model_dir = os.environ.get("OCR_MODEL_DIR", MODEL_DIR)
    if model_dir and not os.path.exists(model_dir):
        logger.warning(f"Model directory does not exist: {model_dir}")
        logger.info("Models will be downloaded on first use")

    # Register handlers so SIGINT / SIGTERM terminate the process cleanly.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Startup banner.
    logger.info("=" * 60)
    logger.info("OCR PDF Parser Service")
    logger.info("=" * 60)
    logger.info(f"Host: {args.host}")
    logger.info(f"Port: {args.port}")
    logger.info(f"Model Directory: {model_dir}")
    logger.info(f"Workers: {args.workers}")
    logger.info(f"Reload: {args.reload}")
    logger.info(f"Log Level: {args.log_level}")
    logger.info("=" * 60)
    logger.info(f"API Documentation (Swagger): http://{args.host}:{args.port}/apidocs")
    logger.info(f"API Documentation (ReDoc): http://{args.host}:{args.port}/redoc")
    logger.info(f"Health Check: http://{args.host}:{args.port}/api/v1/ocr/health")
    logger.info("=" * 60)

    # Reload mode does not support multiple worker processes.
    worker_count = 1 if args.reload else args.workers

    # uvicorn only supports reload / multiple workers when the application is
    # given as an import string: child processes must be able to import it
    # themselves. Handing it an app instance in those modes makes uvicorn log
    # a warning and exit with status 1, so build an import string that points
    # at the create_app factory instead.
    use_import_string = args.reload or worker_count > 1
    if use_import_string:
        module_name = __spec__.name if __spec__ is not None else __name__
        if module_name == "__main__":
            # Executed as a plain script: fall back to the package path that
            # this file's own imports (ocr.api, ocr.config) imply.
            module_name = "ocr.main"
        app_target = f"{module_name}:create_app"
    else:
        # Single process without reload: an app instance is fine.
        app_target = create_app()

    # Start the server.
    try:
        uvicorn.run(
            app_target,
            host=args.host,
            port=args.port,
            log_level=args.log_level,
            reload=args.reload,
            workers=worker_count,
            factory=use_import_string,  # create_app is a factory, not an app
            access_log=True
        )
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
        sys.exit(1)
# Run the service only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()