216 lines
9.9 KiB
Python
216 lines
9.9 KiB
Python
import json
|
|
import os
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from urllib.parse import urlparse, urlunparse
|
|
import base64
|
|
import uuid
|
|
from openai import AzureOpenAI
|
|
from azure.storage.blob import ContainerClient
|
|
from azure.ai.documentintelligence import DocumentIntelligenceClient
|
|
from azure.ai.documentintelligence.models import DocumentContentFormat, AnalyzeResult, \
|
|
DocumentAnalysisFeature, AnalyzeOutputOption, DocumentSpan
|
|
from entity_models import DiResult, Document, FigureFlat
|
|
from utils import TOKEN_ESTIMATOR, custom_serializer, resize_image, file_rename
|
|
from resilient_http_pool import get_ai_inference_client
|
|
|
|
|
|
def di_extract(source_file_path: str, di_client: DocumentIntelligenceClient, directory_path: str,
               figure_sas_url: str, language: str = "zh-Hans") -> DiResult:
    """Run Azure Document Intelligence layout analysis on a file and return the result.

    Analyzes ``source_file_path`` with the prebuilt-layout model, producing
    markdown content and extracted figures. Image files are resized before
    submission; optional DI features (high-resolution OCR, formulas) are
    enabled per environment configuration. The raw markdown is persisted to
    ``<directory_path>/.extracted/<file_name>/_merged_origin.md``.

    Args:
        source_file_path: Path of the file to analyze.
        di_client: Initialized Document Intelligence client.
        directory_path: Root directory; used to compute the relative filepath
            and to locate the ``.extracted`` output folder.
        figure_sas_url: Container SAS URL where extracted figures are uploaded.
        language: Language tag stored on the returned ``DiResult``.

    Returns:
        A ``DiResult`` carrying the markdown content, figures, relative
        file path and language.
    """
    di_features: list[str | DocumentAnalysisFeature] = []
    # Semicolon-separated allow-list of extensions for which optional DI features may be turned on.
    allow_features_exts: list[str] = os.getenv("di_allow_features_ext", "").lower().split(';')

    # File name (with extension) from source_file_path.
    file_name = os.path.basename(source_file_path)
    di_source_file_path = source_file_path

    # Extension of the file itself — dots in parent directory names must not count.
    file_ext: str = file_name.rsplit('.', 1)[-1].lower() if '.' in file_name else ''

    # Supported raster/vector image formats are resized before submission
    # (PDF / JPEG / JPG / PNG / BMP / TIFF / HEIF and friends).
    if file_ext in ['jpg', 'jpeg', 'jpe', 'jfif', 'pjpeg', 'pjp', 'png', 'gif', 'webp', 'tif',
                    'tiff', 'bmp', 'dib', 'heif', 'heic', 'avif', 'apng', 'svg']:
        di_source_file_path = resize_image(source_file_path)

    # doc to docx (legacy-format rename handled by the shared helper)
    di_source_file_path = file_rename(di_source_file_path)

    # Features are gated both by their env flag and by the extension allow-list.
    if os.getenv("di-hiRes", '').lower() == "true" and file_ext in allow_features_exts:
        di_features.append(DocumentAnalysisFeature.OCR_HIGH_RESOLUTION)
    if os.getenv("di-Formulas", '').lower() == "true" and file_ext in allow_features_exts:
        di_features.append(DocumentAnalysisFeature.FORMULAS)

    print(f"di_features: {di_features},file_path:{file_name}")
    with open(di_source_file_path, "rb") as file:
        poller = di_client.begin_analyze_document(
            model_id="prebuilt-layout", body=file, features=di_features,
            output_content_format=DocumentContentFormat.MARKDOWN,
            output=[AnalyzeOutputOption.FIGURES])  # type: ignore

    result: AnalyzeResult = poller.result()
    # NOTE(review): extracted_doc is populated but never returned or persisted
    # here — kept for behavior compatibility; confirm whether it is still needed.
    extracted_doc = Document()

    source_rel_file_path = os.path.relpath(source_file_path, directory_path)
    extracted_doc.filepath = source_rel_file_path

    result_content: str = result.content
    # The operation id is required to later query individual figures
    operation_id: str = str(poller.details.get("operation_id"))

    # os.path.join instead of manual "/" concatenation for portability.
    output_folder = os.path.join(directory_path, ".extracted", file_name)
    os.makedirs(output_folder, exist_ok=True)
    extracted_doc.content = result_content

    with open(os.path.join(output_folder, "_merged_origin.md"), "w", encoding="utf-8") as doc_meta_file:
        doc_meta_file.write(result_content)

    # Download and process images
    figures = extract_figures(di_client, result, operation_id, directory_path, file_name, figure_sas_url)
    di_result: DiResult = DiResult(
        figures=figures,
        di_content=result_content,
        filepath=source_rel_file_path,
        language=language,
    )
    return di_result
|
|
|
|
|
|
|
|
def extract_figures(di_client: DocumentIntelligenceClient, result: AnalyzeResult, result_id: str,
                    directory_path: str, file_name: str, figure_sas_url: str) -> list[FigureFlat]:
    """Extracts figures and their metadata from the analyzed result.

    For each figure in ``result``: downloads the rendered PNG from the DI
    service, saves it under ``.extracted/<file_name>/.images``, uploads it to
    blob storage, and collects a ``FigureFlat`` record with the figure's span,
    blob URL, base64-encoded image and caption.

    Args:
        di_client: Client used to fetch the rendered figure images.
        result: Completed analysis result containing the figure metadata.
        result_id: Operation id of the analysis (required by the figure API).
        directory_path: Root directory that hosts the ``.extracted`` folder.
        file_name: Name of the analyzed file (used as the output subfolder).
        figure_sas_url: Container SAS URL for figure uploads.

    Returns:
        List of ``FigureFlat`` entries, one per figure that has a span.
    """
    figures: list[FigureFlat] = []

    base_path: Path = Path(directory_path, ".extracted", file_name, ".images")
    base_path.mkdir(parents=True, exist_ok=True)

    # Persist the raw analysis result alongside the images for debugging/reprocessing.
    with open(base_path / "result.json", "w", encoding="utf-8") as figures_file:
        json.dump(result, figures_file, default=custom_serializer, ensure_ascii=False, indent=4)

    for figure in (result.figures if result.figures is not None else []):
        # A figure without a document span cannot be anchored in the content — skip it.
        if not figure.spans:
            continue

        span: DocumentSpan = figure.spans[0]

        # Image extraction: stream the rendered PNG from the DI service.
        stream = di_client.get_analyze_result_figure(model_id=result.model_id,
                                                     result_id=result_id, figure_id=figure.id)
        # join() consumes the iterator directly; no need to materialize it into a list.
        image_bytes = b"".join(stream)
        path_image: Path = base_path / f"figure_{figure.id}.png"
        path_image.write_bytes(image_bytes)

        blob_url = upload_figure(figure_sas_url, f"figure_{figure.id}.png", image_bytes)
        image_str: str = base64.b64encode(image_bytes).decode('utf-8')
        figures.append(FigureFlat(offset=span.offset, length=span.length, url=blob_url,
                                  content="", image=image_str, understand_flag=False,
                                  caption=figure.caption.content if figure.caption else ""))
    return figures
|
|
|
|
|
|
|
|
# Compiled once at import time: matches the page-layout marker comments that
# Document Intelligence injects into its markdown output.
_specific_comments = re.compile(
    r"""<!--\s* # opening
    (?:PageFooter="[^"]*" # PageFooter="…"
    |PageNumber="[^"]*" # PageNumber="…"
    |PageBreak # PageBreak
    |PageHeader="[^"]*") # PageHeader="…"
    \s*--> # closing
    """,
    flags=re.VERBOSE,
)


def remove_specific_comments(text: str) -> str:
    """Return *text* with page header/footer/number/break marker comments removed."""
    return re.sub(_specific_comments, '', text)
|
|
|
|
def retry_get_embedding(text: str, embedding_model_key: str, embedding_endpoint: str,
                        min_chunk_size: int = 10, retry_num: int = 3):
    """Retries getting embedding for the provided text until it succeeds or reaches the retry limit.

    Texts whose estimated token count is below ``min_chunk_size`` are skipped
    (returns ``None``) so trivially small chunks are not embedded.

    Args:
        text: Text to embed.
        embedding_model_key: API key forwarded to ``get_embedding``.
        embedding_endpoint: Endpoint forwarded to ``get_embedding``.
        min_chunk_size: Minimum estimated token count required to embed.
        retry_num: Maximum number of attempts.

    Returns:
        The embedding vector, or ``None`` when the text is too short.

    Raises:
        Exception: when all ``retry_num`` attempts fail.
    """
    full_metadata_size = TOKEN_ESTIMATOR.estimate_tokens(text)
    if full_metadata_size < min_chunk_size:
        return None

    last_error = None
    for i in range(retry_num):
        try:
            return get_embedding(text, embedding_model_key=embedding_model_key,
                                 embedding_model_endpoint=embedding_endpoint)
        except Exception as e:
            last_error = e
            print(f"Error getting embedding for full_metadata_vector with error={e}, retrying, currently at {i + 1} retry, {retry_num - (i + 1)} retries left")
            # Back off before retrying — but not after the final attempt,
            # which previously wasted a 10 s sleep right before raising.
            if i < retry_num - 1:
                time.sleep(10)
    # Chain the last underlying error so callers can see the real cause.
    raise Exception(f"Error getting embedding for full_metadata_vector={text}") from last_error
|
|
|
|
def get_embedding(text: str, embedding_model_endpoint: str = "", embedding_model_key: str = "", azure_credential=None):
    """Get an embedding vector for *text* from the configured embedding backend.

    The backend is selected by the ``FLAG_EMBEDDING_MODEL`` env var:
      * ``"AOAI"`` — Azure OpenAI; the deployment id and api-version are
        parsed out of the endpoint URL. ``FLAG_AOAI`` chooses between the
        V2 (no dimensions) and V3 (explicit dimensions) request shapes.
      * ``"COHERE"`` — currently unsupported (raises).
      * any other truthy value — a generic OpenAI-compatible REST endpoint,
        called through the resilient pooled HTTP client.

    Args:
        text: Text to embed.
        embedding_model_endpoint: Full endpoint URL; falls back to the
            EMBEDDING_MODEL_ENDPOINT env var when empty.
        embedding_model_key: API key; on the AOAI path an empty key falls
            back to AZURE_OPENAI_API_KEY.
        azure_credential: Optional Azure credential for AAD token auth;
            takes precedence over key auth when provided.

    Returns:
        The embedding vector (list of floats), or ``None`` if no backend
        branch matched (``FLAG_EMBEDDING_MODEL`` empty).

    Raises:
        Exception: on configuration errors or request failures (wrapped,
        with the original error chained as the cause).
    """
    endpoint = embedding_model_endpoint if embedding_model_endpoint else os.environ.get("EMBEDDING_MODEL_ENDPOINT")

    FLAG_EMBEDDING_MODEL = os.getenv("FLAG_EMBEDDING_MODEL", "AOAI")
    FLAG_COHERE = os.getenv("FLAG_COHERE", "ENGLISH")  # reserved for the disabled Cohere path
    FLAG_AOAI = os.getenv("FLAG_AOAI", "V3")

    # NOTE(review): embedding_model_key defaults to "" so the `is None` test
    # only fires when None is passed explicitly; empty keys intentionally fall
    # through to the env-var fallbacks below — confirm this is the desired contract.
    if azure_credential is None and (endpoint is None or embedding_model_key is None):
        raise Exception("EMBEDDING_MODEL_ENDPOINT and EMBEDDING_MODEL_KEY are required for embedding")

    try:
        if FLAG_EMBEDDING_MODEL == "AOAI":
            # Endpoint shape:
            # https://<res>.openai.azure.com/openai/deployments/<dep>/embeddings?api-version=<ver>
            endpoint_parts = endpoint.split("/openai/deployments/")
            base_url = endpoint_parts[0]
            deployment_id = endpoint_parts[1].split("/embeddings")[0]
            api_version = endpoint_parts[1].split("api-version=")[1].split("&")[0]
            if azure_credential is not None:
                # AAD token auth takes precedence over key auth.
                api_key = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token
            else:
                api_key = embedding_model_key if embedding_model_key else os.getenv("AZURE_OPENAI_API_KEY")

            client = AzureOpenAI(api_version=api_version, azure_endpoint=base_url, api_key=api_key)
            if FLAG_AOAI == "V2":
                embeddings = client.embeddings.create(model=deployment_id, input=text, timeout=120)
            elif FLAG_AOAI == "V3":
                embeddings = client.embeddings.create(
                    model=deployment_id,
                    input=text,
                    dimensions=int(os.getenv("VECTOR_DIMENSION", "1536")),
                    timeout=120)
            else:
                # Previously an unknown FLAG_AOAI fell through to an unbound
                # local and surfaced as a confusing NameError.
                raise Exception(f"Unsupported FLAG_AOAI value: {FLAG_AOAI}")

            return embeddings.model_dump()['data'][0]['embedding']

        if FLAG_EMBEDDING_MODEL == "COHERE":
            raise Exception("COHERE is not supported for now")

        if FLAG_EMBEDDING_MODEL:
            # Generic OpenAI-compatible embeddings endpoint via the shared pooled client.
            headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {embedding_model_key}'}
            data = {"model": FLAG_EMBEDDING_MODEL, "input": text}

            client = get_ai_inference_client()
            response = client.post(endpoint, json=data, headers=headers)
            result_content = response.json()

            return result_content["data"][0]["embedding"]

    except Exception as e:
        print(f"Error getting embeddings with endpoint={endpoint} with error={e}")
        # Chain the original error so the real cause survives the re-raise.
        raise Exception(f"Error getting embeddings with endpoint={endpoint} with error={e}") from e
|
|
|
|
|
|
def upload_figure(blob_sas_url: str, orgin_file_name: str, data: bytes, retry_num: int = 3) -> str:
    """Upload a figure image to Azure Blob Storage, retrying on transient failures.

    Args:
        blob_sas_url: Container SAS URL granting write access.
        orgin_file_name: Original figure file name (used only in error reporting).
        data: Raw PNG bytes to upload.
        retry_num: Maximum number of upload attempts (default 3, matching the
            previously hard-coded value).

    Returns:
        The uploaded blob's URL with the SAS query string and fragment stripped.

    Raises:
        Exception: when all attempts fail; the last error is chained as the cause.
    """
    last_error = None
    for i in range(retry_num):
        try:
            # Upload image to Azure Blob under a generated name to avoid
            # collisions between figures of different documents.
            file_name = generate_filename()
            container_client = ContainerClient.from_container_url(blob_sas_url)
            blob = container_client.upload_blob(name=f"{file_name}.png", data=data)
            # Strip the SAS token (query) so the stored URL does not leak credentials.
            return urlunparse(urlparse(blob.url)._replace(query='', fragment=''))
        except Exception as e:
            last_error = e
            print(
                f"Error uploading figure with error={e}, retrying, currently at {i + 1} retry, {retry_num - (i + 1)} retries left")
            # Back off before retrying, but not after the final attempt —
            # the old code slept 3 s right before raising.
            if i < retry_num - 1:
                time.sleep(3)
    raise Exception(f"Error uploading figure for: {orgin_file_name}") from last_error
|
|
|
|
def generate_filename(length: int = 8) -> str:
    """Generate a pseudo-unique lowercase hex file name.

    Combines the current time (milliseconds modulo 1,000,000, zero-padded to
    six hex digits) with the first ``length`` hex characters of a random UUID.
    The previous docstring claimed a 10-character ID; the actual result is
    ``6 + length`` characters (14 by default).

    Args:
        length: Number of UUID hex characters appended after the time prefix.

    Returns:
        A hex string of ``6 + length`` characters.
    """
    # Max value 999,999 fits in 5 hex digits, so :06x always yields 6 chars.
    t = int(time.time() * 1000) % 1000000
    base = uuid.uuid4().hex[:length]
    return f"{t:06x}{base}"
|