# Source pasted from a file viewer on 2025-09-26 17:15:54 +08:00 (216 lines, 9.9 KiB, Python).
import json
import os
import re
import time
from pathlib import Path
from urllib.parse import urlparse, urlunparse
import base64
import uuid
from openai import AzureOpenAI
from azure.storage.blob import ContainerClient
from azure.ai.documentintelligence import DocumentIntelligenceClient
from azure.ai.documentintelligence.models import DocumentContentFormat, AnalyzeResult, \
DocumentAnalysisFeature, AnalyzeOutputOption, DocumentSpan
from entity_models import DiResult, Document, FigureFlat
from utils import TOKEN_ESTIMATOR, custom_serializer, resize_image, file_rename
from resilient_http_pool import get_ai_inference_client
def di_extract(source_file_path:str, di_client: DocumentIntelligenceClient, directory_path:str, figure_sas_url:str, language:str="zh-Hans") -> DiResult:
    """Run Azure Document Intelligence layout analysis on a file and return a DiResult.

    The file is preprocessed first (images resized via ``resize_image``,
    renamed via ``file_rename``), then analyzed with the ``prebuilt-layout``
    model producing Markdown content, and its figures are extracted via
    ``extract_figures``. The raw Markdown is also written to
    ``<directory_path>/.extracted/<file_name>/_merged_origin.md``.

    Args:
        source_file_path: Path of the file to analyze.
        di_client: Document Intelligence client used for the analysis call.
        directory_path: Root directory; artifacts go under ``.extracted``.
        figure_sas_url: Container SAS URL extracted figure images are uploaded to.
        language: Language tag stored on the returned DiResult (default ``zh-Hans``).

    Returns:
        DiResult carrying the Markdown content, figure list, relative file
        path and language.
    """
    di_features: list[str | DocumentAnalysisFeature] = []
    # Extensions (';'-separated, lowercase) for which optional DI add-on features may be enabled.
    allow_features_exts: list[str] = os.getenv("di_allow_features_ext", "").lower().split(';')
    file_name = os.path.basename(source_file_path)
    di_source_file_path = source_file_path
    # Extension of the file name, lowercased; '' when the name has no dot.
    # os.path.basename also copes with Windows separators, unlike split('/').
    file_ext: str = file_name.split('.')[-1].lower() if '.' in file_name else ''
    # Raster/vector images are resized first so they fit DI input limits.
    if file_ext in ['jpg', 'jpeg', 'jpe', 'jfif', 'pjpeg', 'pjp', 'png', 'gif', 'webp', 'tif', 'tiff', 'bmp', 'dib', 'heif', 'heic', 'avif', 'apng', 'svg']:
        di_source_file_path = resize_image(source_file_path)
    # doc to docx
    di_source_file_path = file_rename(di_source_file_path)
    # Optional DI add-ons, gated by env switches AND the per-extension allow list.
    if os.getenv("di-hiRes", '').lower() == "true" and file_ext in allow_features_exts:
        di_features.append(DocumentAnalysisFeature.OCR_HIGH_RESOLUTION)
    if os.getenv("di-Formulas", '').lower() == "true" and file_ext in allow_features_exts:
        di_features.append(DocumentAnalysisFeature.FORMULAS)
    print(f"di_features: {di_features},file_path:{file_name}")
    with open(di_source_file_path, "rb") as file:
        poller = di_client.begin_analyze_document(model_id="prebuilt-layout", body=file,
                                                  features=di_features, output_content_format=DocumentContentFormat.MARKDOWN, output=[AnalyzeOutputOption.FIGURES])  # type: ignore
        result: AnalyzeResult = poller.result()
    source_rel_file_path = os.path.relpath(source_file_path, directory_path)
    result_content: str = result.content
    # The operation id is required to later query individual figures
    operation_id: str = str(poller.details.get("operation_id"))
    output_folder = os.path.join(directory_path, ".extracted", file_name)
    os.makedirs(output_folder, exist_ok=True)
    # Persist the raw Markdown for debugging/auditing.
    with open(f"{output_folder}/_merged_origin.md", "w", encoding="utf-8") as doc_meta_file:
        doc_meta_file.write(result_content)
    # Download and process images
    figures = extract_figures(di_client, result, operation_id, directory_path, file_name, figure_sas_url)
    # NOTE(review): the original also built an unused `Document()` instance
    # here (filepath/content assigned, never read) — removed as dead code.
    di_result: DiResult = DiResult(
        figures=figures,
        di_content=result_content,
        filepath=source_rel_file_path,
        language=language
    )
    return di_result
def extract_figures(di_client: DocumentIntelligenceClient, result:AnalyzeResult, result_id:str, directory_path:str, file_name:str, figure_sas_url:str)->list[FigureFlat]:
    """Extracts figures and their metadata from the analyzed result.

    For every figure in *result*, the rendered crop is fetched from the DI
    service, written to ``<directory_path>/.extracted/<file_name>/.images``,
    uploaded to blob storage via ``upload_figure`` and collected as a
    FigureFlat (first-span offset/length, blob URL, base64 image payload and
    caption).

    Args:
        di_client: Client used to download the rendered figure crops.
        result: Analysis result that may contain figures.
        result_id: Operation id of the analyze call (needed to fetch figures).
        directory_path: Root directory for extraction artifacts.
        file_name: Name of the analyzed file (used as artifact subfolder).
        figure_sas_url: Container SAS URL the figure images are uploaded to.

    Returns:
        List of FigureFlat entries; empty when the result has no figures.
    """
    figures: list[FigureFlat] = []
    base_path: Path = Path(os.path.join(directory_path, ".extracted", file_name, ".images"))
    base_path.mkdir(parents=True, exist_ok=True)
    # Keep the full DI result next to the images for debugging.
    with open(f"{base_path}/result.json", "w", encoding="utf-8") as figures_file:
        json.dump(result, figures_file, default=custom_serializer, ensure_ascii=False, indent=4)
    for figure in result.figures if result.figures is not None else []:
        # Skip figures without any span — they cannot be anchored in the content.
        if not any(figure.spans):
            continue
        span: DocumentSpan = figure.spans[0]
        # Image extraction: the service returns the crop as an iterable of byte chunks.
        stream = di_client.get_analyze_result_figure(model_id=result.model_id, result_id=result_id, figure_id=figure.id)
        image_bytes = b"".join(stream)  # join the chunks directly — no need to build a list first
        path_image: Path = Path(os.path.join(base_path, f"figure_{figure.id}.png"))
        path_image.write_bytes(image_bytes)
        blob_url = upload_figure(figure_sas_url, f"figure_{figure.id}.png", image_bytes)
        image_str: str = base64.b64encode(image_bytes).decode('utf-8')
        figures.append(FigureFlat(offset=span.offset, length=span.length, url=blob_url, content="", image=image_str, understand_flag=False, caption=figure.caption.content if figure.caption else ""))
    return figures
# Page-chrome markers emitted in Document Intelligence Markdown output.
_PAGE_MARKERS = (
    r'PageFooter="[^"]*"',
    r'PageNumber="[^"]*"',
    r'PageBreak',
    r'PageHeader="[^"]*"',
)
# Compiled once at import time so repeated calls reuse the same pattern.
_specific_comments = re.compile(r"<!--\s*(?:" + "|".join(_PAGE_MARKERS) + r")\s*-->")

def remove_specific_comments(text: str) -> str:
    """Strip DI page header/footer/number/break HTML comments from *text*."""
    return _specific_comments.sub('', text)
def retry_get_embedding(text: str, embedding_model_key:str, embedding_endpoint:str,min_chunk_size:int=10,retry_num:int = 3):
    """Retries getting embedding for the provided text until it succeeds or reaches the retry limit.

    Texts estimated below *min_chunk_size* tokens are skipped (returns None).
    Failed attempts are retried after a 10 s pause, up to *retry_num* attempts.

    Args:
        text: Text to embed.
        embedding_model_key: API key forwarded to ``get_embedding``.
        embedding_endpoint: Endpoint forwarded to ``get_embedding``.
        min_chunk_size: Minimum estimated token count worth embedding.
        retry_num: Maximum number of attempts.

    Returns:
        The embedding vector, or None when the text is too short.

    Raises:
        Exception: When all attempts fail; chained to the last underlying error.
    """
    full_metadata_size = TOKEN_ESTIMATOR.estimate_tokens(text)
    if full_metadata_size < min_chunk_size:
        return None
    last_error: Exception | None = None
    for i in range(retry_num):
        try:
            return get_embedding(text, embedding_model_key=embedding_model_key, embedding_model_endpoint=embedding_endpoint)
        except Exception as e:
            last_error = e
            print(f"Error getting embedding for full_metadata_vector with error={e}, retrying, currently at {i + 1} retry, {retry_num - (i + 1)} retries left")
            # Don't pause after the final attempt — we're about to raise anyway.
            if i + 1 < retry_num:
                time.sleep(10)
    # Chain the last error so the root cause isn't lost.
    raise Exception(f"Error getting embedding for full_metadata_vector={text}") from last_error
def get_embedding(text:str, embedding_model_endpoint:str="", embedding_model_key:str="", azure_credential=None):
    """Compute an embedding vector for *text*.

    The backend is selected by the FLAG_EMBEDDING_MODEL env var:
      * "AOAI" (default): Azure OpenAI. The endpoint must look like
        ``.../openai/deployments/<deployment>/embeddings?api-version=...``;
        FLAG_AOAI picks the plain call ("V2") or the dimensions-aware call
        ("V3", default, using VECTOR_DIMENSION).
      * "COHERE": currently unsupported (raises).
      * any other non-empty value: used as the model name against a generic
        OpenAI-compatible HTTP embeddings endpoint via the resilient client.

    Args:
        text: Text to embed.
        embedding_model_endpoint: Endpoint URL; falls back to EMBEDDING_MODEL_ENDPOINT.
        embedding_model_key: API key; the AOAI branch falls back to AZURE_OPENAI_API_KEY.
        azure_credential: Optional Azure credential; when given, a bearer token
            is used instead of an API key.

    Returns:
        The embedding vector (list of floats).

    Raises:
        Exception: On missing configuration or any backend failure (chained).
    """
    endpoint = embedding_model_endpoint if embedding_model_endpoint else os.environ.get("EMBEDDING_MODEL_ENDPOINT")
    FLAG_EMBEDDING_MODEL = os.getenv("FLAG_EMBEDDING_MODEL", "AOAI")
    FLAG_AOAI = os.getenv("FLAG_AOAI", "V3")
    # NOTE: `embedding_model_key is None` only fires when the caller passes None
    # explicitly; an empty key is allowed because the AOAI branch falls back to
    # AZURE_OPENAI_API_KEY below.
    if azure_credential is None and (endpoint is None or embedding_model_key is None):
        raise Exception("EMBEDDING_MODEL_ENDPOINT and EMBEDDING_MODEL_KEY are required for embedding")
    try:
        if FLAG_EMBEDDING_MODEL == "AOAI":
            # Split the full embeddings URL into base endpoint, deployment id and api-version.
            endpoint_parts = endpoint.split("/openai/deployments/")
            base_url = endpoint_parts[0]
            deployment_id = endpoint_parts[1].split("/embeddings")[0]
            api_version = endpoint_parts[1].split("api-version=")[1].split("&")[0]
            if azure_credential is not None:
                api_key = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token
            else:
                api_key = embedding_model_key if embedding_model_key else os.getenv("AZURE_OPENAI_API_KEY")
            client = AzureOpenAI(api_version=api_version, azure_endpoint=base_url, api_key=api_key)
            if FLAG_AOAI == "V2":
                embeddings = client.embeddings.create(model=deployment_id, input=text, timeout=120)
            elif FLAG_AOAI == "V3":
                embeddings = client.embeddings.create(model=deployment_id,
                                                      input=text,
                                                      dimensions=int(os.getenv("VECTOR_DIMENSION", 1536)), timeout=120)
            else:
                # Previously this path fell through to an UnboundLocalError on `embeddings`.
                raise Exception(f"Unsupported FLAG_AOAI value: {FLAG_AOAI}")
            return embeddings.model_dump()['data'][0]['embedding']
        if FLAG_EMBEDDING_MODEL == "COHERE":
            # Cohere support (FLAG_COHERE key selection + httpx POST) was removed; re-add when needed.
            raise Exception("COHERE is not supported for now")
        if FLAG_EMBEDDING_MODEL:
            # Generic OpenAI-compatible embeddings endpoint.
            headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {embedding_model_key}' }
            data = { "model": FLAG_EMBEDDING_MODEL, "input": text }
            client = get_ai_inference_client()
            response = client.post(endpoint, json=data, headers=headers)
            result_content = response.json()
            return result_content["data"][0]["embedding"]
    except Exception as e:
        print(f"Error getting embeddings with endpoint={endpoint} with error={e}")
        # Chain the original error so the root cause isn't lost.
        raise Exception(f"Error getting embeddings with endpoint={endpoint} with error={e}") from e
def upload_figure(blob_sas_url: str, orgin_file_name: str, data: bytes) -> str:
    """Upload a figure image to Azure Blob Storage, retrying up to 3 times.

    A fresh random blob name is generated for each attempt; the returned URL
    has its query string (the SAS token) and fragment stripped.

    Args:
        blob_sas_url: Container SAS URL to upload into.
        orgin_file_name: Original figure file name (used only in the error message).
        data: PNG image bytes to upload.

    Returns:
        The uploaded blob's URL without query/fragment.

    Raises:
        Exception: When all attempts fail; chained to the last underlying error.
    """
    attempts = 3
    last_error: Exception | None = None
    for i in range(attempts):
        try:
            # Upload image to Azure Blob under a fresh random name.
            file_name = generate_filename()
            container_client = ContainerClient.from_container_url(blob_sas_url)
            blob = container_client.upload_blob(name=f"{file_name}.png", data=data)
            # Strip the SAS token (query) so the stored URL doesn't leak credentials.
            return urlunparse(urlparse(blob.url)._replace(query='', fragment=''))
        except Exception as e:
            last_error = e
            print(
                f"Error uploading figure with error={e}, retrying, currently at {i + 1} retry, {attempts - (i + 1)} retries left")
            # Don't pause after the final attempt — we're about to raise anyway.
            if i + 1 < attempts:
                time.sleep(3)
    raise Exception(f"Error uploading figure for: {orgin_file_name}") from last_error
def generate_filename(length: int = 8) -> str:
    """Generate a quasi-unique lowercase hexadecimal file name.

    The name is a 6-hex-digit prefix (current time in ms, modulo 1,000,000,
    zero-padded) followed by the first *length* hex digits of a random UUID —
    14 characters with the default length of 8. (The original docstring
    claimed a 10-character ID, which was incorrect.)

    Args:
        length: Number of UUID hex digits to append (default 8).

    Returns:
        A ``6 + length``-character hexadecimal string.
    """
    # Timestamp prefix keeps names roughly time-ordered; UUID suffix adds randomness.
    t = int(time.time() * 1000) % 1000000
    base = uuid.uuid4().hex[:length]
    return f"{t:06x}{base}"