198 lines
7.6 KiB
Python
198 lines
7.6 KiB
Python
"""
|
|
Refactored configuration management system
|
|
Uses dependency injection and config classes instead of global variables
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional, Dict, Any
|
|
import os
|
|
import yaml
|
|
from azure.ai.formrecognizer import DocumentAnalysisClient
|
|
from azure.core.credentials import AzureKeyCredential
|
|
from sqlalchemy import create_engine
|
|
|
|
@dataclass
class DatabaseConfig:
    """Database connection settings.

    The values are forwarded unchanged to ``sqlalchemy.create_engine`` by
    ``ServiceFactory.get_database_engine``.
    """

    # Connection URL; passed as the first positional argument to create_engine().
    uri: str
    # Forwarded as create_engine(pool_size=...).
    pool_size: int = 5
    # Forwarded as create_engine(max_overflow=...).
    max_overflow: int = 10
    # Forwarded as create_engine(pool_timeout=...).
    pool_timeout: int = 30
|
|
|
|
|
|
@dataclass
class AzureServiceConfig:
    """Endpoints and keys for the Azure services used by the application.

    The four leading fields are required; every other field is optional and
    defaults to ``None``.
    """

    # Form Recognizer endpoint and key; ServiceFactory uses these to build the
    # DocumentAnalysisClient, and ApplicationConfig.validate() rejects empty values.
    form_recognizer_endpoint: str
    form_recognizer_key: str

    # Search service identity and admin key.
    search_service_name: str
    search_admin_key: str

    # Optional embedding model endpoint/key.
    embedding_model_endpoint: Optional[str] = None
    embedding_model_key: Optional[str] = None

    # Optional captioning model endpoint/key.
    captioning_model_endpoint: Optional[str] = None
    captioning_model_key: Optional[str] = None

    # Optional blob storage account URLs.
    # NOTE(review): presumably one account holds Document Intelligence output
    # and the other extracted figures - confirm against the callers.
    di_blob_account_url: Optional[str] = None
    figure_blob_account_url: Optional[str] = None
|
|
|
|
|
|
@dataclass
class CaptionServiceConfig:
    """Settings for the figure caption service.

    Every field has a default, so ``CaptionServiceConfig()`` is valid.
    """

    # NOTE(review): presumably toggles whether Document Intelligence content is
    # included alongside the image when captioning - confirm against the caller.
    include_di_content: bool = True
    # NOTE(review): presumably an upper bound on images that get a generated
    # description; the meaning of 0 is not visible here - confirm.
    description_gen_max_images: int = 0

    # Model connection settings.
    model_endpoint: Optional[str] = None
    model_key: Optional[str] = None
    model: Optional[str] = None
    azure_deployment: Optional[str] = None
    api_version: Optional[str] = None

    # Prompt definitions for captioning, keyed by name.
    prompts: Optional[dict[str, Any]] = None
|
|
|
|
|
|
|
|
@dataclass
class ProcessingConfig:
    """Tunable parameters for document processing.

    Every field has a default, so ``ProcessingConfig()`` is valid.
    """

    # Number of parallel workers; ApplicationConfig.validate() requires >= 1.
    max_workers: int = 8

    # Chunking parameters.
    # NOTE(review): presumably measured in tokens - confirm against the chunker.
    chunk_size: int = 2048
    token_overlap: int = 128
    min_chunk_size: int = 10

    # Retry policy.
    # NOTE(review): retry_delay is presumably in seconds - confirm.
    retry_count: int = 3
    retry_delay: int = 15

    # Base directory for temporary files.
    tmp_directory: str = '/tmp'
|
|
|
|
|
|
@dataclass
class LoggingConfig:
    """Logging settings, with separate options for console output.

    Every field has a default, so ``LoggingConfig()`` is valid.
    """

    # General log level name and record format string.
    level: str = "INFO"
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Optional log file path.
    file_path: Optional[str] = None

    # Console settings.
    console_output: bool = True
    console_level: str = "WARNING"  # Console only shows WARNING and above
    console_format: str = "%(message)s"  # Simplified format for console
    console_progress_only: bool = True  # Only show progress and key info in console
|
|
|
|
|
|
@dataclass
class ApplicationConfig:
    """Main application configuration.

    Aggregates the per-area config objects. Instances are normally built with
    :meth:`from_env_and_config_files` and then checked with :meth:`validate`.
    """

    database: DatabaseConfig
    azure_services: AzureServiceConfig
    processing: ProcessingConfig
    # Entries loaded from the business config YAML (always a list).
    data_configs: list[Dict[str, Any]] = field(default_factory=list)
    # Not populated by from_env_and_config_files; left for callers to set.
    current_tmp_directory: str = ''
    caption: Optional[CaptionServiceConfig] = None
    # Raw content of the env YAML file ({} when the file is missing or empty).
    env_data: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_env_and_config_files(
        cls,
        config_yaml_path: str,
        env_yaml_path: str = "env.yaml",
        prompt_path: str = "prompt.yaml",
    ) -> 'ApplicationConfig':
        """Load configuration from environment variable file and config file.

        Args:
            config_yaml_path: Path to the business config YAML. Must exist.
            env_yaml_path: Path to the env YAML. Optional; when the file is
                missing, only the process environment and defaults are used.
            prompt_path: Path to the prompt YAML. Optional; when the file is
                missing, the caption prompts are ``None``.

        Returns:
            A populated ``ApplicationConfig``. ``validate()`` is not called.

        Raises:
            OSError: If ``config_yaml_path`` cannot be opened.
            ValueError: If the env YAML is not a mapping, or a numeric
                environment value cannot be parsed as ``int``.
        """
        # 1. Load the env YAML first so the os.getenv() calls below see its values.
        env_data = cls._load_env_yaml(cls, env_yaml_path)

        # 2. Load business config file. An empty file yields no entries rather
        #    than a list containing None.
        with open(config_yaml_path, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f)
        if config_data is None:
            data_configs = []
        elif isinstance(config_data, list):
            data_configs = config_data
        else:
            data_configs = [config_data]

        # 3. Load prompt config file. prompt_data must be bound even when the
        #    file does not exist.
        prompt_data = None
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt_data = yaml.safe_load(f)

        # 4. Build config object
        return cls(
            database=DatabaseConfig(
                uri=os.getenv('DB_URI', 'sqlite:///app.db'),
                pool_size=int(os.getenv('DB_POOL_SIZE', '5')),
                max_overflow=int(os.getenv('DB_MAX_OVERFLOW', '10')),
                pool_timeout=int(os.getenv('DB_POOL_TIMEOUT', '30')),
            ),
            azure_services=AzureServiceConfig(
                form_recognizer_endpoint=os.getenv('form_rec_resource', ''),
                form_recognizer_key=os.getenv('form_rec_key', ''),
                search_service_name=os.getenv('search_service_name', ''),
                search_admin_key=os.getenv('search_admin_key', ''),
                embedding_model_endpoint=os.getenv('embedding_model_endpoint'),
                embedding_model_key=os.getenv('embedding_model_key'),
                captioning_model_endpoint=os.getenv('captioning_model_endpoint'),
                captioning_model_key=os.getenv('captioning_model_key'),
                di_blob_account_url=os.getenv('DI_BLOB_ACCOUNT_URL', None),
                figure_blob_account_url=os.getenv('FIGURE_BLOB_ACCOUNT_URL', ''),
            ),
            processing=ProcessingConfig(
                max_workers=int(os.getenv('njobs', '8')),
                retry_count=int(os.getenv('RETRY_COUNT', '3')),
                retry_delay=int(os.getenv('RETRY_DELAY', '15')),
                tmp_directory=os.getenv('TMP_DIRECTORY', '/tmp'),
            ),
            caption=cls._build_caption_config(env_data, prompt_data),
            data_configs=data_configs,
            # Copy so the instance does not alias the class-level dict that
            # _load_env_yaml stores on the class.
            env_data=dict(env_data),
        )

    @staticmethod
    def _build_caption_config(env_data: Dict[str, Any], prompt_data: Any) -> CaptionServiceConfig:
        """Build the caption config from the env YAML's ``figure_caption`` section.

        Keys missing from the section (or a missing section) fall back to the
        ``CaptionServiceConfig`` defaults instead of raising.
        """
        section = env_data.get("figure_caption") or {}
        defaults = CaptionServiceConfig()
        # Only a mapping can carry a "caption" entry; anything else means no prompts.
        prompts = prompt_data.get("caption") if isinstance(prompt_data, dict) else None
        return CaptionServiceConfig(
            description_gen_max_images=int(
                section.get("description_gen_max_images", defaults.description_gen_max_images)
            ),
            include_di_content=section.get("include_di_content", defaults.include_di_content),
            model_endpoint=section.get("model_endpoint", defaults.model_endpoint),
            model_key=section.get("model_key", defaults.model_key),
            model=section.get("model", defaults.model),
            azure_deployment=section.get("azure_deployment", defaults.azure_deployment),
            api_version=section.get("api_version", defaults.api_version),
            prompts=prompts,
        )

    @staticmethod
    def _load_env_yaml(self, env_yaml_path: str) -> Dict[str, Any]:
        """Load environment variable YAML file.

        Declared as a static method whose first parameter is the object that
        receives ``env_data``; ``from_env_and_config_files`` passes the class.

        Each top-level key is exported to ``os.environ`` as a string (booleans
        as ``"true"``/``"false"``). Keys whose value is ``None`` are skipped so
        that ``os.getenv`` falls back to its default instead of returning the
        literal string ``"None"``.

        Returns:
            The parsed mapping, or ``{}`` when the file is missing or empty.

        Raises:
            ValueError: If the YAML top level is not a mapping.
        """
        if not os.path.exists(env_yaml_path):
            return {}

        with open(env_yaml_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            loaded = yaml.safe_load(f) or {}

        if not isinstance(loaded, dict):
            raise ValueError(
                f"Environment YAML '{env_yaml_path}' must contain a mapping at the top level"
            )

        self.env_data = loaded

        # Set environment variables to system environment
        for key, value in loaded.items():
            if value is None:
                continue
            if isinstance(value, bool):
                value = str(value).lower()
            os.environ[str(key)] = str(value)

        return loaded

    def validate(self) -> None:
        """Validate configuration.

        Raises:
            ValueError: If the database URI, the Form Recognizer endpoint or
                key is empty, or ``processing.max_workers`` is below 1.
        """
        if not self.database.uri:
            raise ValueError("Database URI cannot be empty")

        if not self.azure_services.form_recognizer_endpoint:
            raise ValueError("Form Recognizer endpoint cannot be empty")

        if not self.azure_services.form_recognizer_key:
            raise ValueError("Form Recognizer key cannot be empty")

        if self.processing.max_workers < 1:
            raise ValueError("Number of worker threads must be greater than 0")
|
|
|
|
|
|
|
|
class ServiceFactory:
    """Service factory class, responsible for creating and managing various service instances.

    Each service is created lazily on first request and then reused for the
    lifetime of the factory.
    """

    def __init__(self, config: ApplicationConfig):
        """Store the application config; no service is created yet."""
        self.config = config
        self._form_recognizer_client = None
        self._database_engine = None

    def get_form_recognizer_client(self) -> DocumentAnalysisClient:
        """Get Form Recognizer client (singleton).

        Built on first call from ``config.azure_services`` and cached.
        """
        if self._form_recognizer_client is None:
            azure = self.config.azure_services
            self._form_recognizer_client = DocumentAnalysisClient(
                endpoint=azure.form_recognizer_endpoint,
                credential=AzureKeyCredential(azure.form_recognizer_key),
            )
        return self._form_recognizer_client

    def get_database_engine(self):
        """Get database engine (singleton).

        The engine owns a connection pool, so it is created once from
        ``config.database`` and reused; building a fresh engine on every call
        would create a new pool each time.
        """
        if self._database_engine is None:
            db = self.config.database
            self._database_engine = create_engine(
                db.uri,
                pool_size=db.pool_size,
                max_overflow=db.max_overflow,
                pool_timeout=db.pool_timeout,
            )
        return self._database_engine
|