Files
2025-09-26 17:15:54 +08:00

198 lines
7.6 KiB
Python

"""
Refactored configuration management system
Uses dependency injection and config classes instead of global variables
"""
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
import os
import yaml
from azure.ai.formrecognizer import DocumentAnalysisClient
from azure.core.credentials import AzureKeyCredential
from sqlalchemy import create_engine
@dataclass
class DatabaseConfig:
    """Connection settings for the application's database.

    Attributes:
        uri: Database URI. Required; there is no default.
        pool_size: Supplied as ``pool_size`` when the engine is created.
        max_overflow: Supplied as ``max_overflow`` when the engine is created.
        pool_timeout: Supplied as ``pool_timeout`` when the engine is created.
    """

    # The only mandatory setting; everything below has a default.
    uri: str

    # Connection-pool tuning knobs.
    pool_size: int = 5
    max_overflow: int = 10
    pool_timeout: int = 30
@dataclass
class AzureServiceConfig:
    """Endpoints and keys for the Azure services the application talks to.

    The Form Recognizer and search settings are required positional fields.
    Every other field is optional and defaults to ``None``.
    """

    # --- Required: Form Recognizer (document analysis) ---
    form_recognizer_endpoint: str
    form_recognizer_key: str

    # --- Required: search service ---
    search_service_name: str
    search_admin_key: str

    # --- Optional: embedding model ---
    embedding_model_endpoint: Optional[str] = None
    embedding_model_key: Optional[str] = None

    # --- Optional: captioning model ---
    captioning_model_endpoint: Optional[str] = None
    captioning_model_key: Optional[str] = None

    # --- Optional: blob storage account URLs ---
    di_blob_account_url: Optional[str] = None
    figure_blob_account_url: Optional[str] = None
@dataclass
class CaptionServiceConfig:
    """Settings for the figure-caption service.

    Every field has a default, so the class can be built with no arguments.
    """

    # Behaviour switches.
    # NOTE(review): the exact meaning of these two values is defined by the
    # caption service, which is not part of this file -- confirm there.
    include_di_content: bool = True
    description_gen_max_images: int = 0

    # Model connection details; all optional.
    model_endpoint: Optional[str] = None
    model_key: Optional[str] = None
    model: Optional[str] = None
    azure_deployment: Optional[str] = None
    api_version: Optional[str] = None

    # Prompt definitions for captioning, or None when no prompts were loaded.
    prompts: Optional[dict[str, Any]] = None
@dataclass
class ProcessingConfig:
    """Tunable parameters for document processing.

    Every field has a default, so the class can be built with no arguments.
    """

    # Worker count; ApplicationConfig.validate() rejects values below 1.
    max_workers: int = 8

    # Chunking parameters.
    # NOTE(review): units (tokens vs. characters) are not shown in this file
    # -- confirm against the chunking code.
    chunk_size: int = 2048
    token_overlap: int = 128
    min_chunk_size: int = 10

    # Retry policy.
    # NOTE(review): retry_delay is presumably in seconds -- confirm with caller.
    retry_count: int = 3
    retry_delay: int = 15

    # Directory used for temporary files.
    tmp_directory: str = '/tmp'
@dataclass
class LoggingConfig:
    """Settings for application logging (general and console output).

    Every field has a default, so the class can be built with no arguments.
    """

    # General logging: level name, record format and optional log file path.
    level: str = "INFO"
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    file_path: Optional[str] = None

    # Console output toggle.
    console_output: bool = True
    # Console only shows WARNING and above.
    console_level: str = "WARNING"
    # Simplified format for console.
    console_format: str = "%(message)s"
    # Only show progress and key info in console.
    console_progress_only: bool = True
@dataclass
class ApplicationConfig:
    """Main application configuration.

    Bundles the database, Azure service, processing and caption settings with
    the business configuration entries loaded from YAML.

    Attributes:
        database: Database connection settings.
        azure_services: Azure endpoints and keys.
        processing: Processing settings.
        data_configs: Entries from the business config file. A top-level
            mapping is wrapped in a one-element list.
        current_tmp_directory: Defaults to ``''``; not populated by
            ``from_env_and_config_files``.
        caption: Caption service settings, or ``None`` when not configured.
        env_data: Parsed contents of the environment YAML file.
    """

    database: DatabaseConfig
    azure_services: AzureServiceConfig
    processing: ProcessingConfig
    data_configs: list[Dict[str, Any]] = field(default_factory=list)
    current_tmp_directory: str = ''
    caption: Optional[CaptionServiceConfig] = None
    env_data: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_env_and_config_files(
        cls,
        config_yaml_path: str,
        env_yaml_path: str = "env.yaml",
        prompt_path: str = "prompt.yaml",
    ) -> 'ApplicationConfig':
        """Load configuration from environment variable file and config file.

        Args:
            config_yaml_path: Path to the business config YAML. Must exist.
            env_yaml_path: Path to the environment YAML. If the file is
                missing, only the existing process environment is consulted.
            prompt_path: Path to the prompt YAML. If the file is missing, the
                caption prompts are ``None``.

        Returns:
            A populated ``ApplicationConfig``. ``validate()`` is not called.

        Raises:
            FileNotFoundError: If ``config_yaml_path`` does not exist.
            ValueError: If a numeric setting cannot be parsed as an integer.
        """
        # 1. Load the environment YAML first so that the os.getenv() lookups
        #    below see the values it exports. The parsed data is kept in a
        #    local and handed to the instance; nothing is stored on the class.
        env_data = cls._read_env_yaml(env_yaml_path)

        # 2. Load business config file. An empty file parses to None, which
        #    is treated as "no entries" rather than a list containing None.
        with open(config_yaml_path, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f)
        if config_data is None:
            data_configs = []
        elif isinstance(config_data, list):
            data_configs = config_data
        else:
            data_configs = [config_data]

        # 3. Load prompt config file. prompt_data must be bound even when the
        #    file is absent, otherwise the lookup below would fail.
        prompt_data = None
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt_data = yaml.safe_load(f)
        prompts = prompt_data.get("caption") if isinstance(prompt_data, dict) else None

        # 4. Build the caption settings from the "figure_caption" section.
        #    Keys that are absent fall back to the CaptionServiceConfig
        #    defaults instead of raising.
        caption_section = env_data.get("figure_caption")
        if not isinstance(caption_section, dict):
            caption_section = {}
        caption_keys = (
            "description_gen_max_images",
            "include_di_content",
            "model_endpoint",
            "model_key",
            "model",
            "azure_deployment",
            "api_version",
        )
        caption_kwargs = {
            name: caption_section[name]
            for name in caption_keys
            if name in caption_section
        }
        if "description_gen_max_images" in caption_kwargs:
            caption_kwargs["description_gen_max_images"] = int(
                caption_kwargs["description_gen_max_images"]
            )

        # 5. Build config object
        return cls(
            database=DatabaseConfig(
                uri=os.getenv('DB_URI', 'sqlite:///app.db'),
                pool_size=int(os.getenv('DB_POOL_SIZE', '5')),
                max_overflow=int(os.getenv('DB_MAX_OVERFLOW', '10')),
                pool_timeout=int(os.getenv('DB_POOL_TIMEOUT', '30')),
            ),
            azure_services=AzureServiceConfig(
                form_recognizer_endpoint=os.getenv('form_rec_resource', ''),
                form_recognizer_key=os.getenv('form_rec_key', ''),
                search_service_name=os.getenv('search_service_name', ''),
                search_admin_key=os.getenv('search_admin_key', ''),
                embedding_model_endpoint=os.getenv('embedding_model_endpoint'),
                embedding_model_key=os.getenv('embedding_model_key'),
                captioning_model_endpoint=os.getenv('captioning_model_endpoint'),
                captioning_model_key=os.getenv('captioning_model_key'),
                di_blob_account_url=os.getenv('DI_BLOB_ACCOUNT_URL', None),
                figure_blob_account_url=os.getenv('FIGURE_BLOB_ACCOUNT_URL', ''),
            ),
            processing=ProcessingConfig(
                max_workers=int(os.getenv('njobs', '8')),
                retry_count=int(os.getenv('RETRY_COUNT', '3')),
                retry_delay=int(os.getenv('RETRY_DELAY', '15')),
                tmp_directory=os.getenv('TMP_DIRECTORY', '/tmp'),
            ),
            caption=CaptionServiceConfig(prompts=prompts, **caption_kwargs),
            data_configs=data_configs,
            env_data=env_data,
        )

    @staticmethod
    def _read_env_yaml(env_yaml_path: str) -> Dict[str, Any]:
        """Parse the environment YAML and export its top-level keys.

        Each top-level key is written to ``os.environ`` as a string; booleans
        are written lower-cased (``true`` / ``false``).

        Args:
            env_yaml_path: Path to the environment YAML file.

        Returns:
            The parsed mapping, or an empty dict when the file is missing or
            empty.
        """
        if not os.path.exists(env_yaml_path):
            return {}
        with open(env_yaml_path, 'r', encoding='utf-8') as f:
            env_data = yaml.safe_load(f) or {}
        # Set environment variables to system environment.
        # NOTE(review): nested mappings and None values are stringified as-is
        # (e.g. "None"); confirm whether consumers of os.getenv() expect that.
        for key, value in env_data.items():
            if isinstance(value, bool):
                value = str(value).lower()
            os.environ[str(key)] = str(value)
        return env_data

    @staticmethod
    def _load_env_yaml(self, env_yaml_path: str):
        """Load environment variable YAML file onto ``self.env_data``.

        Kept for backward compatibility with the original call shape
        (``_load_env_yaml(target, path)``). ``from_env_and_config_files`` no
        longer uses it, so loading does not mutate class-level state.

        Args:
            self: Object that receives the ``env_data`` attribute.
            env_yaml_path: Path to the environment YAML file. If it does not
                exist, ``self`` is left untouched.
        """
        if not os.path.exists(env_yaml_path):
            return
        self.env_data = ApplicationConfig._read_env_yaml(env_yaml_path)

    def validate(self) -> None:
        """Validate configuration.

        Raises:
            ValueError: If the database URI, the Form Recognizer endpoint or
                the Form Recognizer key is empty, or if ``max_workers`` is
                less than 1.
        """
        if not self.database.uri:
            raise ValueError("Database URI cannot be empty")
        if not self.azure_services.form_recognizer_endpoint:
            raise ValueError("Form Recognizer endpoint cannot be empty")
        if not self.azure_services.form_recognizer_key:
            raise ValueError("Form Recognizer key cannot be empty")
        if self.processing.max_workers < 1:
            raise ValueError("Number of worker threads must be greater than 0")
class ServiceFactory:
    """Builds service objects (Azure client, database engine) from an ApplicationConfig."""

    def __init__(self, config: ApplicationConfig):
        """Store the configuration; no service object is created yet."""
        self.config = config
        # Created on first request by get_form_recognizer_client().
        self._form_recognizer_client = None

    def get_form_recognizer_client(self) -> DocumentAnalysisClient:
        """Return the Form Recognizer client, creating it on first use.

        The same client instance is returned on every later call.
        """
        existing = self._form_recognizer_client
        if existing is not None:
            return existing

        azure_settings = self.config.azure_services
        endpoint = azure_settings.form_recognizer_endpoint
        credential = AzureKeyCredential(azure_settings.form_recognizer_key)
        self._form_recognizer_client = DocumentAnalysisClient(
            endpoint=endpoint,
            credential=credential,
        )
        return self._form_recognizer_client

    def get_database_engine(self):
        """Create and return a database engine from the database settings.

        A new engine is built on every call; nothing is cached here.
        """
        db_settings = self.config.database
        uri = db_settings.uri
        pool_options = {
            "pool_size": db_settings.pool_size,
            "max_overflow": db_settings.max_overflow,
            "pool_timeout": db_settings.pool_timeout,
        }
        return create_engine(uri, **pool_options)