Files
catonline_ai/vw-document-ai-indexer/vllm_extractor.py
2025-09-26 17:15:54 +08:00

483 lines
21 KiB
Python

import json
import os
import time
from typing import Any, List
import base64
from app_config import ApplicationConfig
from azure_index_service import get_cloud_api_client
from pdf2image import convert_from_path # type: ignore
import numpy as np
from PIL import Image
from langchain_openai import ChatOpenAI ,AzureChatOpenAI
from langchain.schema.messages import SystemMessage
from langchain_core.messages import AIMessage,HumanMessage,ToolMessage
from di_extractor import FigureFlat
from entity_models import DiResult, Document, UnsupportedFormatError
from resilient_http_pool import get_ai_inference_client
# Number of attempts made by the retry loops in this module
# (extract_from_image and understand_with_langchain both iterate range(RETRY_COUNT)).
RETRY_COUNT = 3
def vision_extract(pdf_file_path:str, file_format:str, directory_path:str, vllm_endpoint:str, vllm_key:str) -> List[Document]:
    """
    Extract a PDF page by page with a vision model.

    The PDF is rendered to trimmed ``.webp`` page images under
    ``<directory_path>/.images/<relative pdf path>`` and every page image is
    passed to ``extract_from_image``. The first page is extracted on its own;
    each later page receives the Document produced for the page before it.

    Args:
        pdf_file_path: Path of the PDF to process.
        file_format: Format of the file; only ``"pdf"`` is accepted.
        directory_path: Root directory used for the relative path and outputs.
        vllm_endpoint: Endpoint of the vision model.
        vllm_key: API key for the vision model.

    Returns:
        One Document per page image, in sorted file-name order.

    Raises:
        UnsupportedFormatError: If ``file_format`` is not ``"pdf"``.
    """
    if file_format != "pdf":
        raise UnsupportedFormatError(f"Unsupported file format: {file_format}")

    source_rel_file_path = os.path.relpath(pdf_file_path, directory_path)
    image_dir = directory_path + "/.images/" + source_rel_file_path

    print(f"Converting to images: {pdf_file_path}")
    pdf_to_images(pdf_file_path, image_dir)
    print(f"Converted to images: {pdf_file_path}")

    page_docs: List[Document] = []
    for name in sorted(os.listdir(image_dir)):
        # Anything that is not a page image (e.g. sub-folders) is skipped.
        if not name.endswith(".webp"):
            continue

        print(f"extracting: {image_dir}/{name}")
        page_path = os.path.join(image_dir, name)
        # The page index equals the number of pages already extracted.
        page_no = len(page_docs)
        if page_docs:
            page_doc = extract_from_image(page_path, vllm_endpoint, vllm_key, directory_path, source_rel_file_path, page_no, page_docs[-1])
        else:
            page_doc = extract_from_image(page_path, vllm_endpoint, vllm_key, directory_path, source_rel_file_path, page_no)
        page_docs.append(page_doc)

    return page_docs
def pdf_to_images(pdf_path, output_folder, dpi=250):
    """
    Render every page of a PDF to a trimmed WEBP image.

    Pages are first rendered as PNG files into ``<output_folder>/.untrimed``,
    then each PNG is trimmed with ``trim_image``, saved to
    ``<output_folder>/<6-digit page number>.webp`` (numbering starts at 1) and
    the intermediate PNG is deleted.

    Args:
        pdf_path: Path of the PDF file to render.
        output_folder: Folder that receives the ``.webp`` page images.
        dpi: Rendering resolution passed to ``convert_from_path``.
    """
    untrimed_folder = output_folder+"/.untrimed"
    os.makedirs(untrimed_folder, exist_ok=True)

    # Convert PDF to images (written to disk; the returned paths are not used).
    convert_from_path(pdf_path, dpi=dpi, output_folder=untrimed_folder, fmt="png", paths_only=True)

    # Pages are numbered by the sorted order of the rendered file names.
    for page_number, image_filename in enumerate(sorted(os.listdir(untrimed_folder)), start=1):
        untrimmed_path = f"{untrimed_folder}/{image_filename}"
        # Fixed-width, 6-digit page number keeps the output files sortable.
        index = str(page_number).zfill(6)
        image_path = f"{output_folder}/{index}.webp"

        # Close the source file before deleting it: an open handle leaks file
        # descriptors and makes os.remove fail on Windows.
        with Image.open(untrimmed_path) as image:
            trimmed_image = trim_image(image)
            trimmed_image.save(image_path, format="WEBP")
        os.remove(untrimmed_path)
def trim_image(input_image: Image.Image) -> Image.Image:
    """
    Trim the light margins of a page image.

    Every pixel whose grayscale value is below a fixed threshold counts as
    content. The image is cropped to the bounding box of all content pixels,
    enlarged by a small margin and clamped to the image bounds. No noise
    filtering is performed: a single dark pixel is enough to extend the box.

    Args:
        input_image (Image.Image): The input PIL Image object.

    Returns:
        Image.Image: The cropped PIL Image object, or ``input_image`` itself
        when no pixel is darker than the threshold.
    """
    # Work on a grayscale copy; the crop is applied to the original image.
    grayscale_image = input_image.convert("L")
    image_array = np.array(grayscale_image)

    # Pixels darker than the threshold are treated as content.
    threshold = 240  # Adjust this value if needed
    content_mask = image_array < threshold

    rows = np.any(content_mask, axis=1)
    cols = np.any(content_mask, axis=0)
    if not rows.any() or not cols.any():
        # Nothing darker than the threshold: there is nothing to trim to.
        return input_image

    height, width = content_mask.shape
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]

    # Add a small margin around the content, clamped to valid pixel indices.
    # The upper clamp is the LAST valid index (size - 1): the crop box below
    # adds 1 to make the bound exclusive, and a box reaching past the image
    # would be padded by Pillow with an extra black row/column.
    margin = 10
    top = max(0, int(ymin) - margin)
    bottom = min(height - 1, int(ymax) + margin)
    left = max(0, int(xmin) - margin)
    right = min(width - 1, int(xmax) + margin)

    # Pillow crop boxes are (left, upper, right, lower) with exclusive right/lower.
    return input_image.crop((left, top, right + 1, bottom + 1))
# Domain hint that is interpolated (via {tips}) at the end of each system prompt below.
tips = "- The document is about standard/regulatory for a automobile industry company to refer. So prioritize extracting content about standards/regulatory/compliance carefully"
# Define the messages for the chat
# System prompt used by extract_from_image for non-cover pages when the
# previous page's document_schema is "flow". It asks for markdown output and
# for "[]" on table-of-contents or empty pages.
SYS_MSG_Flow_Layout = f"""# Role
You are specialized in extracting content from screenshots of document.
# Rules
- You will receive a page screenshot from a multi-pages document. Extract content into a structured markdown format.
- Identify if the page is Table of Contents(目录, 目次) or empty page(after ignoring watermarks)
- If yes, just ignore the whole page, and output "[]" only
- If no, you should follow below rules to extract content
- Recognize hierarchical section header, and use appropriate markdown symbols "#" to reflect its hierarchy level.
- Detection:
- Identify line of section header that beginning with a hierarchical section numbering part and optionally followed by a text part. The section numbering part conatains only numbers, alphabets, and dots. The section numbering part is a tiered (multi-level) numbering system. For example: "2.3.17 示例标题", "1 Sample Title", "6.1.2.5", "A.14.8.9 示例标题".
- Each section header is just one line, and the section number is at the beginning of the line.
- Header Hierarchy Level Mapping:
- The section numbering part is a tiered (multi-level) numbering system. Section number at each hierarchy level in section numbering part is seperated by dot(.), so the count of separated section number reflects its the section header's hierarchy levels. For example, the header "4.13.2 Sample" should be considered as an H3 level.
- Use appropriate markdown symbols "#" to reflect section headers's hierarchy levels. **The number of "#" symbols should correspond to the depth of the section level.** For instance:
- "1 section Title" should be output as "# 1 section Title"
- "2.3.17 section Title" should be output as "### 2.3.17 section Title"
- "A.14.8.9 section Title" should be output as "#### A.14.8.9 section Title"
- **Table title or picture title should NOT be considered as a section header, even if it is at beginning of the page. Output them as format "[table/picture titles]", for example: "[表 1.2 示例]", "[图5 示例]")**
- IMPORTANT: The screenshot is taken from one page of a multi-page document, note that it represents only a single page, not the entire document.**The beginning area of the page may not fall under a section header. Nevertheless, ensure that you still extract content from this area, even if it is not explicitly labeled under a section header.**
- Embedded Pictures/Graphs/Diagram:
- If the embedded picture/graph/diagram is major content and can be understood clearly, descript it as caption, using format: `![<caption>](picture)`
- Otherwise, just use a placeholder: `![](picture)`
# Tips
- Carefully recognize scientific symbols and formulas, and output them professionally and accurately.
- If a table is not a blank template, you should extract using markdown table markup
- Accurately recognize the content according to the screenshot, and do not speculate any content.
- Ignore any diagonally arranged watermarks present in the document.
- The page footer and header can be ignored.
{tips}
"""
# System prompt used by extract_from_image for non-cover pages when the
# previous page's document_schema is "slides". It asks for markdown output
# and for "[]" on pages that are empty after ignoring watermarks.
SYS_MSG_Slides_Layout = f"""# Role
You are specialized in extracting content from screenshots of a slides deck like PPT.
# Rules
- You will receive a page screenshot from a multi-pages deck. Extract content into a structured markdown format.
- Recognize title headers from the page and use appropriate markdown symbols "#" to reflect their hierarchy levels. Every page should have one H1 title header.
- Embedded Pictures/Graphs/Diagram: If there are embedded pictures/figures, try your best to understand them, and descript them into caption paragraphs.
# Tips
- Carefully recognize scientific symbols and formulas, and output them professionally and accurately.
- If a table is not a blank template, you should extract using markdown table markup
- Accurately recognize the content according to the screenshot, and do not speculate any content.
- Ignore any diagonally arranged watermarks present in the document. Identify if the page is empty after ignoring watermarks. If yes, just ignore this page, and output "[]" only
{tips}
"""
# System prompt used by extract_from_image for the first page (page_index == 0).
# It asks for a JSON object; extract_from_image parses that JSON and reads
# document_schema, the title/publisher/code/category/language fields and
# whole_page from it.
# NOTE(review): the "publised_date" key requested here is not read by
# extract_from_image in this file — confirm whether any other consumer uses it.
SYS_MSG_Cover = f"""# Role
You are specialized in extracting content from screenshots of document.
# Rules
- You will receive the cover page from a multi-pages document. Extract content into a structured JSON format.
- Recognize what type of Document Schema it is, there are the two below types of document layout schema:
- flow: Like a page of Office Words document, mainly in flow document layout.
- slides: Like a page of Office PowerPoint document, mainly in a presenting slide layout.
- other: Not looks like either of abvoe document layout schema type
- The cover page may contain the following information: main_title, sub_title, publisher, publised_date, document_code, document_category.
- Detect the primary and secondary language of the document. Use language code as their values. The default primary language is `zh-Hans`. If there are titles in secondary language, they should also be included as well.
- Whole page should be extracted as markdown string and stored in the `whole_page` field.
- The output JSON schema:
- document_schema
- main_title
- sub_title
- publisher
- publised_date
- document_code
- document_category
- main_title_sec_language
- sub_title_sec_language
- primary_language
- secondary_language
- whole_page
# Tips
- Accurately recognize the text content according to the screenshot, and do not speculate any content.
- Ignore any diagonally arranged watermarks present in the document.
- Don't use horizontal divider("---") or simmilar markdown syntax to separate the content.
{tips}
"""
# User-turn text sent together with the page image by extract_from_image.
# NOTE(review): the same text is sent for every page, not only the cover
# page that the wording mentions — confirm this is intended.
USER_MSG = """# task
Recognize screenshot of this document cover page, return the result
"""
def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if it has one."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return text
    # Drop the opening fence line (which may carry a language tag).
    lines = stripped.split("\n")[1:]
    # Drop the closing fence only when it really is one, so that a reply
    # without a closing fence does not lose its last content line.
    if lines and lines[-1].strip() == "```":
        lines = lines[:-1]
    return "\n".join(lines)


def _is_content_policy_violation(response) -> bool:
    """Tell whether an HTTP 400 response body reports a ResponsibleAIPolicyViolation."""
    try:
        error = response.json()["error"]
        # Accept both spellings of the nested error object: the original code
        # read "inner_error"; Azure OpenAI documents the key as "innererror".
        inner_error = error.get("inner_error") or error.get("innererror") or {}
        return inner_error.get("code") == "ResponsibleAIPolicyViolation"
    except (ValueError, KeyError, TypeError, AttributeError):
        # Body is not JSON or does not have the expected shape.
        return False


def extract_from_image(image_path, vllm_endpoint, vllm_key, directory_path, source_rel_file_path, page_index, pre_document: Document | None = None) -> Document:
    """
    Extract the content of one page image with the vision model.

    The first page (``page_index == 0``) is sent with the cover prompt and the
    JSON reply is parsed into the document metadata. Every later page is sent
    with the flow or slides prompt, selected by ``pre_document.document_schema``,
    and inherits the metadata of ``pre_document``. The raw model reply is also
    written to ``<directory_path>/.extracted/<source_rel_file_path>/`` as
    ``<6-digit page number>.json`` (first page) or ``.md`` (other pages).

    Args:
        image_path: Path of the page image to send.
        vllm_endpoint: URL the chat payload is POSTed to.
        vllm_key: Value of the ``api-key`` request header.
        directory_path: Root directory under which the output is written.
        source_rel_file_path: Source file path relative to ``directory_path``.
        page_index: Zero-based page index.
        pre_document: Document of the previous page; required when
            ``page_index`` is not 0.

    Returns:
        The Document for this page. A page rejected with a
        ResponsibleAIPolicyViolation yields the content ``"[]"``.

    Raises:
        ValueError: If ``pre_document`` is missing for a non-cover page or its
            schema is neither ``"flow"`` nor ``"slides"``.
        RuntimeError: If no response was obtained after all attempts or the
            model endpoint answered with an unrecoverable error status.
    """
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode("ascii")
    file_ext = image_path.split(".")[-1]

    if page_index == 0:
        system_msg = SYS_MSG_Cover
    else:
        if pre_document is None:
            raise ValueError(f"pre_document is required for page_index={page_index}")
        if pre_document.document_schema == "flow":
            system_msg = SYS_MSG_Flow_Layout
        elif pre_document.document_schema == "slides":
            system_msg = SYS_MSG_Slides_Layout
        else:
            raise ValueError(f"schema = {pre_document.document_schema}, not supported")

    headers = {
        "Content-Type": "application/json",
        "api-key": vllm_key,
    }
    payload = {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_msg}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": USER_MSG},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/{file_ext};base64,{encoded_image}"},
                    },
                ],
            },
        ],
        "temperature": 0,
    }

    response = None
    last_error = None
    for attempt in range(1, RETRY_COUNT + 1):
        try:
            client = get_ai_inference_client()
            response = client.post(vllm_endpoint, headers=headers, json=payload, timeout=180)
            if response.status_code == 400:
                # A 400 is handled below (content-policy rejection or hard
                # error); retrying the identical request only costs time.
                break
            response.raise_for_status()  # Raises if the HTTP request returned an unsuccessful status code
            break
        except Exception as e:
            last_error = e
            print(f"Error extract_from_image {image_path} with error={e}, retrying, current at {attempt} retry, {RETRY_COUNT - attempt} retries left")
            if attempt < RETRY_COUNT:
                time.sleep(15)

    # Compare against None explicitly: some HTTP clients make an error
    # response falsy, which would silently skip the error handling.
    if response is None:
        raise RuntimeError(f"Error extract_from_image {image_path}: no response after {RETRY_COUNT} attempts") from last_error

    if response.status_code != 200:
        if response.status_code == 400 and _is_content_policy_violation(response):
            # The page was rejected by the content filter: treat it like an ignored page.
            rslt = "[]"
            print(f"Ignored: {image_path}. Error extract_from_image with status_code={response.status_code}\n {response.text}")
        else:
            raise RuntimeError(f"Error extract_from_image {image_path} with status_code={response.status_code}\n {response.text}")
    else:
        rslt = response.json()["choices"][0]["message"]["content"] or ""

    # The model may wrap its reply in ```markdown / ```json fences.
    rslt = _strip_code_fence(rslt)

    page_index_output = str(page_index + 1).zfill(6)
    output_folder = directory_path + "/.extracted/" + source_rel_file_path
    os.makedirs(output_folder, exist_ok=True)

    if page_index == 0:
        # Explicit UTF-8: the extracted text is frequently non-ASCII and the
        # platform default encoding may not be able to represent it.
        with open(f"{output_folder}/{page_index_output}.json", "w", encoding="utf-8") as file:
            file.write(rslt)
        rsltObj = json.loads(rslt)
        if not isinstance(rsltObj, dict):
            # e.g. "[]" for a filtered cover page: continue with empty metadata.
            rsltObj = {}

        document_schema = str(rsltObj.get("document_schema") or "flow").lower()
        if document_schema == "other":
            document_schema = "flow"
        document = Document(
            document_schema=document_schema,
            main_title=rsltObj.get("main_title", "") or "",
            sub_title=rsltObj.get("sub_title", "") or "",
            publisher=rsltObj.get("publisher", "") or "",
            document_code=rsltObj.get("document_code", "") or "",
            document_category=rsltObj.get("document_category", "") or "",
            main_title_sec_language=rsltObj.get("main_title_sec_language", "") or "",
            sub_title_sec_language=rsltObj.get("sub_title_sec_language", "") or "",
            primary_language=rsltObj.get("primary_language", ""),
            secondary_language=rsltObj.get("secondary_language", ""),
        )
        if document.sub_title != "":
            document.title = f"{document.main_title}-{document.sub_title}"
        else:
            document.title = document.main_title
        document.doc_metadata = f"{document.main_title}, {document.sub_title}, {document.document_code}, {document.main_title_sec_language}, {document.sub_title_sec_language}"
        document.filepath = source_rel_file_path
        document.content = rsltObj.get("whole_page", "")
    else:
        with open(f"{output_folder}/{page_index_output}.md", "w", encoding="utf-8") as file:
            file.write(rslt)
        # Later pages carry over the metadata detected on the cover page.
        document = Document(
            document_schema=pre_document.document_schema,
            main_title=pre_document.main_title,
            sub_title=pre_document.sub_title,
            publisher=pre_document.publisher,
            document_code=pre_document.document_code,
            document_category=pre_document.document_category,
            main_title_sec_language=pre_document.main_title_sec_language,
            sub_title_sec_language=pre_document.sub_title_sec_language,
            primary_language=pre_document.primary_language,
            secondary_language=pre_document.secondary_language,
            title=pre_document.title,
            doc_metadata=pre_document.doc_metadata,
            filepath=pre_document.filepath,
        )
        document.content = rslt
    return document
def understand_with_langchain(image:bytes, mime_type: str, captioning_model_endpoint: str, captioning_model_key: str,model:str|None,azure_deployment:str|None=None,api_version:str|None=None,language:str|None=None, prompts: dict[str,Any]|None=None):
    """
    Generate a textual description of an image through a LangChain chat model.

    An ``AzureChatOpenAI`` client is used when the endpoint contains
    ``"openai.azure"``; any other endpoint is addressed with ``ChatOpenAI``.

    Args:
        image: Raw image bytes.
        mime_type: Image subtype placed in the data URL (``data:image/<mime_type>``), e.g. ``"png"``.
        captioning_model_endpoint: Endpoint of the captioning model.
        captioning_model_key: API key for the captioning model.
        model: Model name, used for non-Azure endpoints.
        azure_deployment: Deployment name, used for Azure endpoints.
        api_version: API version, used for Azure endpoints.
        language: Document language; ``"zh-Hans"`` selects the Chinese prompt,
            anything else the English one.
        prompts: Optional mapping ``language key -> {"system": ..., "user": ...}``.
            Built-in prompts are used when it is ``None`` or empty.

    Returns:
        The content of the model response.

    Raises:
        Exception: If all ``RETRY_COUNT`` attempts fail; the last error is chained.
    """
    # Select the prompt based on the language
    lang_key = "zh-Hans" if language == "zh-Hans" else "en"
    if not prompts:
        prompts = {
            "zh-Hans": { "system": "您是一个帮助用户寻找描述性字幕的字幕模型。", "user": "描述此图像就像您将其描述给看不见的人一样。" },
            "en": { "system": "You are a captioning model that helps uses find descriptive captions.", "user": "Describe this image as if you were describing it to someone who can't see it." }
        }
    if lang_key in prompts:
        prompt = prompts[lang_key]
    elif "en" in prompts:
        prompt = prompts["en"]
    else:
        # Fall back to the first configured prompt. dict views cannot be
        # indexed, so take the first value through an iterator.
        prompt = next(iter(prompts.values()))

    # Encode the image as a data URL
    encoded_image = base64.b64encode(image).decode('utf-8')
    image_url = f"data:image/{mime_type};base64,{encoded_image}"
    http_client = get_cloud_api_client()

    # Choose the LangChain client from the endpoint and call the model
    llm:Any=None
    for i in range(RETRY_COUNT):
        try:
            if "openai.azure" in captioning_model_endpoint:
                llm = AzureChatOpenAI(azure_deployment=azure_deployment,api_key=captioning_model_key, azure_endpoint=captioning_model_endpoint,api_version=api_version, temperature=0, http_client=http_client)
            else:
                llm = ChatOpenAI(base_url=captioning_model_endpoint, api_key=captioning_model_key, model=model, temperature=0, http_client=http_client)

            # Build the message
            messages = [
                SystemMessage(content=prompt["system"]),
                HumanMessage(content=[{"type": "text", "text": prompt["user"]}, {"type": "image_url", "image_url": {"url": image_url}} ])
            ]

            # Invoke the model
            response = llm.invoke(messages)
            return response.content
        except Exception as e:
            print(f"Error getting caption with langchain (attempt {i+1}/{RETRY_COUNT}): {e}")
            if i < RETRY_COUNT - 1:
                time.sleep(5)
            else:
                # The last attempt failed; keep the original error as the cause
                raise Exception(f"Failed to get caption after {RETRY_COUNT} attempts: {e}") from e
    # Only reached when RETRY_COUNT is 0
    return ""
def process_document_figures(di_result:DiResult|None=None,config:ApplicationConfig|None=None) -> DiResult:
    """
    Fill in the ``content`` of every figure of an extracted document.

    Only figures whose span in ``di_result.di_content`` starts with
    ``<figure>`` (after leading whitespace) are processed. For each of them:

    - a figure already flagged ``understand_flag`` keeps its content;
    - otherwise, when ``config.caption.include_di_content`` is set, the content
      is a model-generated description (if the document has fewer figures than
      ``description_gen_max_images``) or the text inside the ``<figure>`` tags,
      collapsed to one line and prefixed with the figure caption;
    - when ``include_di_content`` is not set and the figure has a caption, the
      content is the caption only.

    Args:
        di_result: Extraction result whose figures are updated in place.
        config: Application configuration; the ``caption`` section is read.

    Returns:
        The same ``di_result`` object.

    Raises:
        Exception: If ``di_result`` is None.
        ValueError: If ``config`` is None.
    """
    if di_result is None:
        raise Exception("di_result cannot be None")
    if config is None:
        raise ValueError("config is None")

    description_gen_max_images: int = config.caption.description_gen_max_images
    vllm_endpoint:str = config.caption.model_endpoint
    vllm_key:str = config.caption.model_key
    captioning_model:str = config.caption.model
    api_version:str = config.caption.api_version
    azure_deployment:str = config.caption.azure_deployment
    include_di_content: bool = config.caption.include_di_content

    figures = di_result.figures or []
    content:str = di_result.di_content
    len_figures:int = len(figures)
    language = di_result.language

    for figure in figures:
        figure_content:str = content[figure.offset:figure.offset + figure.length]
        if not figure_content.lstrip().startswith("<figure>"):
            continue

        # Image content generation
        vision_content:str = ""
        if figure.understand_flag:
            vision_content = figure.content
        elif include_di_content:
            if len_figures < description_gen_max_images:
                # Decode the image only when it is actually sent to the model.
                image_bytes = base64.b64decode(figure.image)
                vision_content = understand_with_langchain(image=image_bytes, mime_type="png", captioning_model_endpoint=vllm_endpoint, captioning_model_key=vllm_key, model=captioning_model,azure_deployment=azure_deployment,api_version=api_version, language=language, prompts=config.caption.prompts)
                figure.understand_flag = True
            else:
                # Remove the surrounding tags as whole strings. str.lstrip /
                # str.rstrip take a SET of characters and would also eat
                # leading/trailing letters of the figure text itself.
                vision_content = figure_content.strip().removeprefix("<figure>").removesuffix("</figure>").strip()
            # NOTE(review): the one-line collapse and caption prefix are applied to
            # both branches above — confirm against the original indentation.
            vision_content = ' '.join(line.strip() for line in vision_content.splitlines())
            vision_content = f"<figcaption>{figure.caption}</figcaption>" + vision_content

        if not include_di_content and figure.caption:
            vision_content = f"<figcaption>{figure.caption}</figcaption>"

        figure.content = vision_content

    return di_result