init
This commit is contained in:
197
vw-document-ai-indexer/app_config.py
Normal file
197
vw-document-ai-indexer/app_config.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Refactored configuration management system
|
||||
Uses dependency injection and config classes instead of global variables
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any
|
||||
import os
|
||||
import yaml
|
||||
from azure.ai.formrecognizer import DocumentAnalysisClient
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database configuration"""
|
||||
uri: str
|
||||
pool_size: int = 5
|
||||
max_overflow: int = 10
|
||||
pool_timeout: int = 30
|
||||
|
||||
|
||||
@dataclass
|
||||
class AzureServiceConfig:
|
||||
"""Azure service configuration"""
|
||||
form_recognizer_endpoint: str
|
||||
form_recognizer_key: str
|
||||
search_service_name: str
|
||||
search_admin_key: str
|
||||
embedding_model_endpoint: Optional[str] = None
|
||||
embedding_model_key: Optional[str] = None
|
||||
captioning_model_endpoint: Optional[str] = None
|
||||
captioning_model_key: Optional[str] = None
|
||||
di_blob_account_url: Optional[str] = None
|
||||
figure_blob_account_url: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CaptionServiceConfig:
|
||||
"""Caption service configuration"""
|
||||
include_di_content:bool = True
|
||||
description_gen_max_images:int = 0
|
||||
model_endpoint: Optional[str] = None
|
||||
model_key: Optional[str] = None
|
||||
model:Optional[str] = None
|
||||
azure_deployment:Optional[str] = None
|
||||
api_version:Optional[str] = None
|
||||
prompts:Optional[dict[str,Any]] = None
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingConfig:
|
||||
"""Processing configuration"""
|
||||
max_workers: int = 8
|
||||
chunk_size: int = 2048
|
||||
token_overlap: int = 128
|
||||
min_chunk_size: int = 10
|
||||
retry_count: int = 3
|
||||
retry_delay: int = 15
|
||||
tmp_directory: str = '/tmp'
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
"""Logging configuration"""
|
||||
level: str = "INFO"
|
||||
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
file_path: Optional[str] = None
|
||||
console_output: bool = True
|
||||
console_level: str = "WARNING" # Console only shows WARNING and above
|
||||
console_format: str = "%(message)s" # Simplified format for console
|
||||
console_progress_only: bool = True # Only show progress and key info in console
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApplicationConfig:
|
||||
"""Main application configuration"""
|
||||
database: DatabaseConfig
|
||||
azure_services: AzureServiceConfig
|
||||
processing: ProcessingConfig
|
||||
data_configs: list[Dict[str, Any]] = field(default_factory= list[Dict[str, Any]])
|
||||
current_tmp_directory: str = ''
|
||||
caption: CaptionServiceConfig = None
|
||||
env_data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_env_and_config_files(cls, config_yaml_path: str, env_yaml_path: str = "env.yaml",prompt_path:str="prompt.yaml") -> 'ApplicationConfig':
|
||||
"""Load configuration from environment variable file and config file."""
|
||||
# 1. Load environment variable config file first
|
||||
cls._load_env_yaml(cls,env_yaml_path)
|
||||
|
||||
# 2. Load business config file
|
||||
with open(config_yaml_path, 'r', encoding='utf-8') as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
|
||||
|
||||
# 3. Load prompt config file
|
||||
if os.path.exists(prompt_path):
|
||||
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||
prompt_data = yaml.safe_load(f)
|
||||
|
||||
# 4. Build config object
|
||||
return cls(
|
||||
database=DatabaseConfig(
|
||||
uri=os.getenv('DB_URI', 'sqlite:///app.db'),
|
||||
pool_size=int(os.getenv('DB_POOL_SIZE', '5')),
|
||||
max_overflow=int(os.getenv('DB_MAX_OVERFLOW', '10')),
|
||||
pool_timeout=int(os.getenv('DB_POOL_TIMEOUT', '30'))
|
||||
),
|
||||
azure_services=AzureServiceConfig(
|
||||
form_recognizer_endpoint=os.getenv('form_rec_resource', ''),
|
||||
form_recognizer_key=os.getenv('form_rec_key', ''),
|
||||
search_service_name=os.getenv('search_service_name', ''),
|
||||
search_admin_key=os.getenv('search_admin_key', ''),
|
||||
embedding_model_endpoint=os.getenv('embedding_model_endpoint'),
|
||||
embedding_model_key=os.getenv('embedding_model_key'),
|
||||
captioning_model_endpoint=os.getenv('captioning_model_endpoint'),
|
||||
captioning_model_key=os.getenv('captioning_model_key'),
|
||||
di_blob_account_url=os.getenv('DI_BLOB_ACCOUNT_URL',None),
|
||||
figure_blob_account_url=os.getenv('FIGURE_BLOB_ACCOUNT_URL', '')
|
||||
),
|
||||
processing=ProcessingConfig(
|
||||
max_workers=int(os.getenv('njobs', '8')),
|
||||
retry_count=int(os.getenv('RETRY_COUNT', '3')),
|
||||
retry_delay=int(os.getenv('RETRY_DELAY', '15')),
|
||||
tmp_directory=os.getenv('TMP_DIRECTORY', '/tmp')
|
||||
),
|
||||
caption=CaptionServiceConfig(
|
||||
description_gen_max_images= int(cls.env_data["figure_caption"]["description_gen_max_images"]),
|
||||
include_di_content = cls.env_data["figure_caption"]["include_di_content"],
|
||||
model_endpoint= cls.env_data["figure_caption"]["model_endpoint"],
|
||||
model_key= cls.env_data["figure_caption"]["model_key"],
|
||||
model= cls.env_data["figure_caption"]["model"],
|
||||
azure_deployment= cls.env_data["figure_caption"]["azure_deployment"],
|
||||
api_version=cls.env_data["figure_caption"]["api_version"],
|
||||
prompts=prompt_data["caption"] if prompt_data and "caption" in prompt_data else None
|
||||
),
|
||||
data_configs=config_data if isinstance(config_data, list) else [config_data]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _load_env_yaml(self,env_yaml_path: str):
|
||||
"""Load environment variable YAML file."""
|
||||
if not os.path.exists(env_yaml_path):
|
||||
return
|
||||
|
||||
with open(env_yaml_path, 'r', encoding='utf-8') as f:
|
||||
self.env_data = yaml.safe_load(f)
|
||||
|
||||
# Set environment variables to system environment
|
||||
if self.env_data:
|
||||
for key, value in self.env_data.items():
|
||||
if isinstance(value, bool):
|
||||
value = str(value).lower()
|
||||
os.environ[str(key)] = str(value)
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration."""
|
||||
if not self.database.uri:
|
||||
raise ValueError("Database URI cannot be empty")
|
||||
|
||||
if not self.azure_services.form_recognizer_endpoint:
|
||||
raise ValueError("Form Recognizer endpoint cannot be empty")
|
||||
|
||||
if not self.azure_services.form_recognizer_key:
|
||||
raise ValueError("Form Recognizer key cannot be empty")
|
||||
|
||||
if self.processing.max_workers < 1:
|
||||
raise ValueError("Number of worker threads must be greater than 0")
|
||||
|
||||
|
||||
|
||||
class ServiceFactory:
|
||||
"""Service factory class, responsible for creating and managing various service instances."""
|
||||
|
||||
def __init__(self, config: ApplicationConfig):
|
||||
self.config = config
|
||||
self._form_recognizer_client = None
|
||||
|
||||
def get_form_recognizer_client(self) -> DocumentAnalysisClient:
|
||||
"""Get Form Recognizer client (singleton)."""
|
||||
if self._form_recognizer_client is None:
|
||||
self._form_recognizer_client = DocumentAnalysisClient(
|
||||
endpoint=self.config.azure_services.form_recognizer_endpoint,
|
||||
credential=AzureKeyCredential(self.config.azure_services.form_recognizer_key)
|
||||
)
|
||||
return self._form_recognizer_client
|
||||
|
||||
def get_database_engine(self):
|
||||
"""Get database engine."""
|
||||
return create_engine(
|
||||
self.config.database.uri,
|
||||
pool_size=self.config.database.pool_size,
|
||||
max_overflow=self.config.database.max_overflow,
|
||||
pool_timeout=self.config.database.pool_timeout
|
||||
)
|
||||
Reference in New Issue
Block a user