import base64
import json
import os
import re
import time
import uuid
from pathlib import Path
from urllib.parse import urlparse, urlunparse

from openai import AzureOpenAI
from azure.storage.blob import ContainerClient
from azure.ai.documentintelligence import DocumentIntelligenceClient
from azure.ai.documentintelligence.models import (
    AnalyzeOutputOption,
    AnalyzeResult,
    DocumentAnalysisFeature,
    DocumentContentFormat,
    DocumentSpan,
)

from entity_models import DiResult, Document, FigureFlat
from utils import TOKEN_ESTIMATOR, custom_serializer, resize_image, file_rename
from resilient_http_pool import get_ai_inference_client


def di_extract(source_file_path: str, di_client: DocumentIntelligenceClient, directory_path: str,
               figure_sas_url: str, language: str = "zh-Hans") -> DiResult:
    di_features: list[str | DocumentAnalysisFeature] = []
    allow_features_exts: list[str] = os.getenv("di_allow_features_ext", "").lower().split(';')

    # Get the file name (extension included) from source_file_path.
    file_name = os.path.basename(source_file_path)
    di_source_file_path = source_file_path

    # Supported inputs include PDF and image formats (JPEG/JPG, PNG, BMP, TIFF, HEIF, ...).
    file_ext: str = (source_file_path.split('.')[-1]
                     if '.' in source_file_path.split('/')[-1] else '').lower()
    if file_ext in ['jpg', 'jpeg', 'jpe', 'jfif', 'pjpeg', 'pjp', 'png', 'gif', 'webp', 'tif',
                    'tiff', 'bmp', 'dib', 'heif', 'heic', 'avif', 'apng', 'svg']:
        # Resize image inputs before sending them to Document Intelligence.
        di_source_file_path = resize_image(source_file_path)

    # Rename doc to docx before analysis.
    di_source_file_path = file_rename(di_source_file_path)

    if os.getenv("di-hiRes", '').lower() == "true" and file_ext in allow_features_exts:
        di_features.append(DocumentAnalysisFeature.OCR_HIGH_RESOLUTION)
    if os.getenv("di-Formulas", '').lower() == "true" and file_ext in allow_features_exts:
        di_features.append(DocumentAnalysisFeature.FORMULAS)
    print(f"di_features: {di_features}, file_path: {file_name}")

    with open(di_source_file_path, "rb") as file:
        poller = di_client.begin_analyze_document(
            model_id="prebuilt-layout",
            body=file,
            features=di_features,
            output_content_format=DocumentContentFormat.MARKDOWN,
            output=[AnalyzeOutputOption.FIGURES],
        )  # type: ignore
        result: AnalyzeResult = poller.result()

    extracted_doc = Document()
    source_rel_file_path = os.path.relpath(source_file_path, directory_path)
    extracted_doc.filepath = source_rel_file_path
    result_content: str = result.content

    # The operation id is required to later query individual figures.
    operation_id: str = str(poller.details.get("operation_id"))

    output_folder = directory_path + "/.extracted/" + file_name
    os.makedirs(output_folder, exist_ok=True)
    extracted_doc.content = result_content
    with open(f"{output_folder}/_merged_origin.md", "w", encoding="utf-8") as doc_meta_file:
        doc_meta_file.write(result_content)

    # Download and process figure images.
    figures = extract_figures(di_client, result, operation_id, directory_path, file_name,
                              figure_sas_url)

    di_result: DiResult = DiResult(
        figures=figures,
        di_content=result_content,
        filepath=source_rel_file_path,
        language=language,
    )
    return di_result


def extract_figures(di_client: DocumentIntelligenceClient, result: AnalyzeResult, result_id: str,
                    directory_path: str, file_name: str, figure_sas_url: str) -> list[FigureFlat]:
    """Extracts figures and their metadata from the analyzed result."""
    figures: list[FigureFlat] = []
    base_path: Path = Path(os.path.join(directory_path, ".extracted", file_name, ".images"))
    base_path.mkdir(parents=True, exist_ok=True)

    # Persist the raw analyze result next to the extracted images.
    with open(f"{base_path}/result.json", "w", encoding="utf-8") as figures_file:
        json.dump(result, figures_file, default=custom_serializer, ensure_ascii=False, indent=4)

    for figure in (result.figures if result.figures is not None else []):
        if not figure.spans:
            continue
        span: DocumentSpan = figure.spans[0]

        # Image extraction: download the rendered figure for this analyze operation.
        stream = di_client.get_analyze_result_figure(
            model_id=result.model_id, result_id=result_id, figure_id=figure.id)
        image_bytes = b"".join(stream)
        path_image: Path = Path(os.path.join(base_path, f"figure_{figure.id}.png"))
        path_image.write_bytes(image_bytes)

        blob_url = upload_figure(figure_sas_url, f"figure_{figure.id}.png", image_bytes)
        image_str: str = base64.b64encode(image_bytes).decode('utf-8')
        figures.append(FigureFlat(
            offset=span.offset,
            length=span.length,
            url=blob_url,
            content="",
            image=image_str,
            understand_flag=False,
            caption=figure.caption.content if figure.caption else "",
        ))
    return figures
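
# Illustrative environment configuration for di_extract (a sketch only; the values below are
# examples, not defaults shipped with the pipeline):
#   di_allow_features_ext="pdf;jpg;png"   # extensions for which extra DI features may be enabled
#   di-hiRes="true"                       # adds DocumentAnalysisFeature.OCR_HIGH_RESOLUTION
#   di-Formulas="true"                    # adds DocumentAnalysisFeature.FORMULAS
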
# Compile once for efficiency.
_specific_comments = re.compile(
    r"""
    # closing
    """,
    flags=re.VERBOSE,
)


def remove_specific_comments(text: str) -> str:
    return _specific_comments.sub('', text)


def retry_get_embedding(text: str, embedding_model_key: str, embedding_endpoint: str,
                        min_chunk_size: int = 10, retry_num: int = 3):
    """Retries getting an embedding for the provided text until it succeeds or reaches the retry limit."""
    full_metadata_size = TOKEN_ESTIMATOR.estimate_tokens(text)
    if full_metadata_size >= min_chunk_size:
        for i in range(retry_num):
            try:
                return get_embedding(text, embedding_model_key=embedding_model_key,
                                     embedding_model_endpoint=embedding_endpoint)
            except Exception as e:
                print(f"Error getting embedding for full_metadata_vector with error={e}, "
                      f"retrying, currently at {i + 1} retry, {retry_num - (i + 1)} retries left")
                time.sleep(10)
        raise Exception(f"Error getting embedding for full_metadata_vector={text}")
    return None


def get_embedding(text: str, embedding_model_endpoint: str = "", embedding_model_key: str = "",
                  azure_credential=None):
    endpoint = embedding_model_endpoint if embedding_model_endpoint else os.environ.get("EMBEDDING_MODEL_ENDPOINT")
    FLAG_EMBEDDING_MODEL = os.getenv("FLAG_EMBEDDING_MODEL", "AOAI")
    FLAG_COHERE = os.getenv("FLAG_COHERE", "ENGLISH")
    FLAG_AOAI = os.getenv("FLAG_AOAI", "V3")

    if azure_credential is None and (endpoint is None or embedding_model_key is None):
        raise Exception("EMBEDDING_MODEL_ENDPOINT and EMBEDDING_MODEL_KEY are required for embedding")

    try:
        if FLAG_EMBEDDING_MODEL == "AOAI":
            # Split the deployment-style endpoint into base URL, deployment id and api-version.
            endpoint_parts = endpoint.split("/openai/deployments/")
            base_url = endpoint_parts[0]
            deployment_id = endpoint_parts[1].split("/embeddings")[0]
            api_version = endpoint_parts[1].split("api-version=")[1].split("&")[0]

            if azure_credential is not None:
                api_key = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token
            else:
                api_key = embedding_model_key if embedding_model_key else os.getenv("AZURE_OPENAI_API_KEY")

            client = AzureOpenAI(api_version=api_version, azure_endpoint=base_url, api_key=api_key)
            if FLAG_AOAI == "V2":
                embeddings = client.embeddings.create(model=deployment_id, input=text, timeout=120)
            elif FLAG_AOAI == "V3":
                embeddings = client.embeddings.create(model=deployment_id, input=text,
                                                      dimensions=int(os.getenv("VECTOR_DIMENSION", 1536)),
                                                      timeout=120)
            return embeddings.model_dump()['data'][0]['embedding']

        if FLAG_EMBEDDING_MODEL == "COHERE":
            raise Exception("COHERE is not supported for now")
            # if FLAG_COHERE == "MULTILINGUAL":
            #     key = embedding_model_key if embedding_model_key else os.getenv("COHERE_MULTILINGUAL_API_KEY")
            # elif FLAG_COHERE == "ENGLISH":
            #     key = embedding_model_key if embedding_model_key else os.getenv("COHERE_ENGLISH_API_KEY")
            # data, headers = get_payload_and_headers_cohere(text, key)
            # with httpx.Client() as client:
            #     response = client.post(endpoint, json=data, headers=headers)
            #     result_content = response.json()
            #     return result_content["embeddings"][0]

        if FLAG_EMBEDDING_MODEL:
            # Any other non-empty model name goes through the shared AI inference HTTP client.
            headers = {
                'Content-Type': 'application/json',
                'Authorization': f'Bearer {embedding_model_key}'
            }
            data = {"model": FLAG_EMBEDDING_MODEL, "input": text}
            client = get_ai_inference_client()
            response = client.post(endpoint, json=data, headers=headers)
            result_content = response.json()
            return result_content["data"][0]["embedding"]
    except Exception as e:
        print(f"Error getting embeddings with endpoint={endpoint} with error={e}")
        raise Exception(f"Error getting embeddings with endpoint={endpoint} with error={e}")
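
# Minimal usage sketch for get_embedding (the endpoint, deployment and key below are
# placeholders, not real values; the function expects the full Azure OpenAI deployment-style
# embeddings URL so it can split out the base URL, deployment id and api-version):
#
#   vector = get_embedding(
#       "sample text",
#       embedding_model_endpoint=(
#           "https://<my-aoai>.openai.azure.com/openai/deployments/"
#           "<embedding-deployment>/embeddings?api-version=<api-version>"
#       ),
#       embedding_model_key="<api-key>",
#   )
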
def upload_figure(blob_sas_url: str, origin_file_name: str, data: bytes) -> str:
    for i in range(3):
        try:
            # Upload the figure image to Azure Blob Storage under a generated name.
            file_name = generate_filename()
            container_client = ContainerClient.from_container_url(blob_sas_url)
            blob = container_client.upload_blob(name=f"{file_name}.png", data=data)
            # Strip the SAS query string and fragment before returning the blob URL.
            return urlunparse(urlparse(blob.url)._replace(query='', fragment=''))
        except Exception as e:
            print(f"Error uploading figure with error={e}, retrying, currently at {i + 1} retry, "
                  f"{3 - (i + 1)} retries left")
            time.sleep(3)
    raise Exception(f"Error uploading figure for: {origin_file_name}")


def generate_filename(length: int = 8) -> str:
    """Generate a short unique file name: a 6-hex-digit millisecond-timestamp fragment plus a UUID fragment."""
    t = int(time.time() * 1000) % 1000000
    base = uuid.uuid4().hex[:length]
    return f"{t:06x}{base}"
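
# End-to-end usage sketch (illustrative only): runs di_extract against a local file using
# placeholder endpoint, key and SAS URL values. Everything in angle brackets is an assumption
# to be replaced; this block is not invoked by the pipeline itself.
if __name__ == "__main__":
    from azure.core.credentials import AzureKeyCredential

    sample_di_client = DocumentIntelligenceClient(
        endpoint="https://<your-di-resource>.cognitiveservices.azure.com/",
        credential=AzureKeyCredential("<your-di-key>"),
    )
    extraction = di_extract(
        source_file_path="./data/sample.pdf",
        di_client=sample_di_client,
        directory_path="./data",
        figure_sas_url="https://<account>.blob.core.windows.net/<container>?<sas-token>",
    )
    print(f"Extracted {len(extraction.figures)} figures from {extraction.filepath}")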