""" Refactored configuration management system Uses dependency injection and config classes instead of global variables """ from dataclasses import dataclass, field from typing import Optional, Dict, Any import os import yaml from azure.ai.formrecognizer import DocumentAnalysisClient from azure.core.credentials import AzureKeyCredential from sqlalchemy import create_engine @dataclass class DatabaseConfig: """Database configuration""" uri: str pool_size: int = 5 max_overflow: int = 10 pool_timeout: int = 30 @dataclass class AzureServiceConfig: """Azure service configuration""" form_recognizer_endpoint: str form_recognizer_key: str search_service_name: str search_admin_key: str embedding_model_endpoint: Optional[str] = None embedding_model_key: Optional[str] = None captioning_model_endpoint: Optional[str] = None captioning_model_key: Optional[str] = None di_blob_account_url: Optional[str] = None figure_blob_account_url: Optional[str] = None @dataclass class CaptionServiceConfig: """Caption service configuration""" include_di_content:bool = True description_gen_max_images:int = 0 model_endpoint: Optional[str] = None model_key: Optional[str] = None model:Optional[str] = None azure_deployment:Optional[str] = None api_version:Optional[str] = None prompts:Optional[dict[str,Any]] = None @dataclass class ProcessingConfig: """Processing configuration""" max_workers: int = 8 chunk_size: int = 2048 token_overlap: int = 128 min_chunk_size: int = 10 retry_count: int = 3 retry_delay: int = 15 tmp_directory: str = '/tmp' @dataclass class LoggingConfig: """Logging configuration""" level: str = "INFO" format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" file_path: Optional[str] = None console_output: bool = True console_level: str = "WARNING" # Console only shows WARNING and above console_format: str = "%(message)s" # Simplified format for console console_progress_only: bool = True # Only show progress and key info in console @dataclass class ApplicationConfig: """Main application configuration""" database: DatabaseConfig azure_services: AzureServiceConfig processing: ProcessingConfig data_configs: list[Dict[str, Any]] = field(default_factory= list[Dict[str, Any]]) current_tmp_directory: str = '' caption: CaptionServiceConfig = None env_data: Dict[str, Any] = field(default_factory=dict) @classmethod def from_env_and_config_files(cls, config_yaml_path: str, env_yaml_path: str = "env.yaml",prompt_path:str="prompt.yaml") -> 'ApplicationConfig': """Load configuration from environment variable file and config file.""" # 1. Load environment variable config file first cls._load_env_yaml(cls,env_yaml_path) # 2. Load business config file with open(config_yaml_path, 'r', encoding='utf-8') as f: config_data = yaml.safe_load(f) # 3. Load prompt config file if os.path.exists(prompt_path): with open(prompt_path, 'r', encoding='utf-8') as f: prompt_data = yaml.safe_load(f) # 4. Build config object return cls( database=DatabaseConfig( uri=os.getenv('DB_URI', 'sqlite:///app.db'), pool_size=int(os.getenv('DB_POOL_SIZE', '5')), max_overflow=int(os.getenv('DB_MAX_OVERFLOW', '10')), pool_timeout=int(os.getenv('DB_POOL_TIMEOUT', '30')) ), azure_services=AzureServiceConfig( form_recognizer_endpoint=os.getenv('form_rec_resource', ''), form_recognizer_key=os.getenv('form_rec_key', ''), search_service_name=os.getenv('search_service_name', ''), search_admin_key=os.getenv('search_admin_key', ''), embedding_model_endpoint=os.getenv('embedding_model_endpoint'), embedding_model_key=os.getenv('embedding_model_key'), captioning_model_endpoint=os.getenv('captioning_model_endpoint'), captioning_model_key=os.getenv('captioning_model_key'), di_blob_account_url=os.getenv('DI_BLOB_ACCOUNT_URL',None), figure_blob_account_url=os.getenv('FIGURE_BLOB_ACCOUNT_URL', '') ), processing=ProcessingConfig( max_workers=int(os.getenv('njobs', '8')), retry_count=int(os.getenv('RETRY_COUNT', '3')), retry_delay=int(os.getenv('RETRY_DELAY', '15')), tmp_directory=os.getenv('TMP_DIRECTORY', '/tmp') ), caption=CaptionServiceConfig( description_gen_max_images= int(cls.env_data["figure_caption"]["description_gen_max_images"]), include_di_content = cls.env_data["figure_caption"]["include_di_content"], model_endpoint= cls.env_data["figure_caption"]["model_endpoint"], model_key= cls.env_data["figure_caption"]["model_key"], model= cls.env_data["figure_caption"]["model"], azure_deployment= cls.env_data["figure_caption"]["azure_deployment"], api_version=cls.env_data["figure_caption"]["api_version"], prompts=prompt_data["caption"] if prompt_data and "caption" in prompt_data else None ), data_configs=config_data if isinstance(config_data, list) else [config_data] ) @staticmethod def _load_env_yaml(self,env_yaml_path: str): """Load environment variable YAML file.""" if not os.path.exists(env_yaml_path): return with open(env_yaml_path, 'r', encoding='utf-8') as f: self.env_data = yaml.safe_load(f) # Set environment variables to system environment if self.env_data: for key, value in self.env_data.items(): if isinstance(value, bool): value = str(value).lower() os.environ[str(key)] = str(value) def validate(self) -> None: """Validate configuration.""" if not self.database.uri: raise ValueError("Database URI cannot be empty") if not self.azure_services.form_recognizer_endpoint: raise ValueError("Form Recognizer endpoint cannot be empty") if not self.azure_services.form_recognizer_key: raise ValueError("Form Recognizer key cannot be empty") if self.processing.max_workers < 1: raise ValueError("Number of worker threads must be greater than 0") class ServiceFactory: """Service factory class, responsible for creating and managing various service instances.""" def __init__(self, config: ApplicationConfig): self.config = config self._form_recognizer_client = None def get_form_recognizer_client(self) -> DocumentAnalysisClient: """Get Form Recognizer client (singleton).""" if self._form_recognizer_client is None: self._form_recognizer_client = DocumentAnalysisClient( endpoint=self.config.azure_services.form_recognizer_endpoint, credential=AzureKeyCredential(self.config.azure_services.form_recognizer_key) ) return self._form_recognizer_client def get_database_engine(self): """Get database engine.""" return create_engine( self.config.database.uri, pool_size=self.config.database.pool_size, max_overflow=self.config.database.max_overflow, pool_timeout=self.config.database.pool_timeout )