"""
Fix mismatches between parent and child header levels in Markdown documents:
cases where a child header's '#' count has been raised to equal its parent's,
or is even smaller, placing the child at or above its parent in the hierarchy.
"""
import re
from typing import Any, List, Dict, Optional
class HeaderInfo:
"""Title information"""
def __init__(self, line_number: int, original_line: str, hash_count: int,
level: int, number_pattern: str, title_text: str):
self.line_number = line_number
self.original_line = original_line
self.hash_count = hash_count
self.level = level
self.number_pattern = number_pattern
self.title_text = title_text
self.correct_hash_count = hash_count # Will be updated by Fixer
class HierarchyFixer:
"""Special fixer for title hierarchy # number mismatch issues"""
def __init__(self):
# Number pattern matching - supports both formats with and without trailing dots
self.number_patterns = [
r'^(\d+)\.?$', # 1 or 1.
r'^(\d+)\.(\d+)\.?$', # 1.1 or 1.1.
r'^(\d+)\.(\d+)\.(\d+)\.?$', # 1.1.1 or 1.1.1.
r'^(\d+)\.(\d+)\.(\d+)\.(\d+)\.?$', # 1.1.1.1 or 1.1.1.1.
r'^(\d+)\.(\d+)\.(\d+)\.(\d+)\.(\d+)\.?$', # 1.1.1.1.1 or 1.1.1.1.1.
r'^(\d+)\.(\d+)\.(\d+)\.(\d+)\.(\d+)\.(\d+)\.?$', # 1.1.1.1.1.1 or 1.1.1.1.1.1.
]
# Letter+number pattern matching - supports both "A.x.x.x" and "C. x.x.x" formats
self.letter_number_patterns = [
# Single letter: A, B, C (followed by space or end)
(r'^([A-Z])(?:\s|$)', 1),
# Letter + space + numbers: "C. 1", "A. 2"
(r'^([A-Z])\.\s+(\d+)(?:\s|$)', 2),
(r'^([A-Z])\.\s+(\d+)\.(\d+)(?:\s|$)', 3), # C. 1.1, A. 2.3
(r'^([A-Z])\.\s+(\d+)\.(\d+)\.(\d+)(?:\s|$)', 4), # C. 1.1.1, A. 2.3.4
(r'^([A-Z])\.\s+(\d+)\.(\d+)\.(\d+)\.(\d+)(?:\s|$)', 5), # C. 1.1.1.1, A. 2.3.4.5
(r'^([A-Z])\.\s+(\d+)\.(\d+)\.(\d+)\.(\d+)\.(\d+)(?:\s|$)', 6), # C. 1.1.1.1.1, A. 2.3.4.5.6
# Compact format (no space): A.1, A.1.2, A.1.2.3 etc.
(r'^([A-Z])\.(\d+)(?:\s|$|[^\d\.])', 2), # A.1, A.2
(r'^([A-Z])\.(\d+)\.(\d+)(?:\s|$|[^\d\.])', 3), # A.1.2, A.1.3
(r'^([A-Z])\.(\d+)\.(\d+)\.(\d+)(?:\s|$|[^\d\.])', 4), # A.1.2.3
(r'^([A-Z])\.(\d+)\.(\d+)\.(\d+)\.(\d+)(?:\s|$|[^\d\.])', 5), # A.1.2.3.4
(r'^([A-Z])\.(\d+)\.(\d+)\.(\d+)\.(\d+)\.(\d+)(?:\s|$|[^\d\.])', 6), # A.1.2.3.4.5
]
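        # A few made-up example headings (illustrative only) and the level the
        # patterns above assign to their numbering:
        #   "1.2.3 Scope"    -> numeric pattern        -> level 3
        #   "C. 1.1 Intro"   -> letter + space format  -> level 3
        #   "A.1.2.3 Rules"  -> compact letter format  -> level 4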
def detect_headers(self, content: str) -> List[HeaderInfo]:
"""Detect all headers and determine their logical levels"""
lines = content.split('\n')
headers: List[HeaderInfo] = []
for line_num, line in enumerate(lines):
if line.strip().startswith('#'):
header_info = self._parse_header_line(line_num, line)
if header_info:
headers.append(header_info)
return headers
def _parse_header_line(self, line_num: int, line: str) -> Optional[HeaderInfo]:
"""Analyze the title line"""
line = line.strip()
# Count the number of # characters
hash_count = 0
for char in line:
if char == '#':
hash_count += 1
else:
break
if hash_count == 0:
return None
# Extract title content
title_content = line[hash_count:].strip()
# Try to match number pattern
level = 1
number_pattern = ""
# Check for letter+number patterns first (A.1.2.3 format)
for pattern, expected_level in self.letter_number_patterns:
match = re.match(pattern, title_content)
if match:
level = expected_level
# Extract the complete matched numbering pattern
matched_text = match.group(0)
# For space-separated patterns like "C. 1.1", we need to extract the full pattern
if '. ' in matched_text:
# This is a space-separated pattern like "C. 1.1"
# The match already contains the complete pattern we want
number_pattern = matched_text.rstrip() # Remove trailing space if any
else:
# This is a compact pattern like "A.1.2.3"
number_pattern = matched_text
return HeaderInfo(
line_number=line_num,
original_line=line,
hash_count=hash_count,
level=level,
number_pattern=number_pattern,
title_text=title_content
)
# If no letter+number pattern, try traditional number patterns
if title_content:
# First, try to identify and extract the complete numbering part
# Look for patterns like "1.2.3", "1 . 2 . 3", "1. 2. 3", etc.
words = title_content.split()
numbering_words = []
            # Collect the leading words that look like part of the numbering (digits and dots)
for word in words:
if re.match(r'^[\d\.]+$', word) or word == '.':
numbering_words.append(word)
else:
break # Stop at first non-numbering word
if numbering_words:
# Join and normalize the numbering part
numbering_text = ' '.join(numbering_words)
# Normalize: "1 . 2 . 3" -> "1.2.3", "1. 2. 3" -> "1.2.3"
normalized = re.sub(r'\s*\.\s*', '.', numbering_text)
normalized = re.sub(r'\.+$', '', normalized) # Remove trailing dots
normalized = normalized.strip()
# Try to match the normalized pattern
for i, pattern in enumerate(self.number_patterns, 1):
match = re.match(pattern, normalized)
if match:
level = i
number_pattern = normalized
break
else:
# If no numbering pattern found in separate words, try the first word directly
first_word = words[0] if words else ""
for i, pattern in enumerate(self.number_patterns, 1):
match = re.match(pattern, first_word)
if match:
level = i
number_pattern = match.group(0).rstrip('.')
break
# If no number pattern is found, infer level from # count
if not number_pattern:
level = hash_count
return HeaderInfo(
line_number=line_num,
original_line=line,
hash_count=hash_count,
level=level,
number_pattern=number_pattern,
title_text=title_content
)
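    # Worked example (illustrative): parsing the made-up line "## 1.2 Overview"
    # yields hash_count=2, number_pattern="1.2" and level=2, while a header with
    # no recognizable numbering, such as "### Appendix", falls back to
    # level = hash_count = 3.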
def find_hierarchy_problems(self, headers: List[HeaderInfo]) -> List[Dict]:
"""Find problems with mismatched # counts using adaptive analysis"""
problems = []
        # First, derive the document's adaptive level-to-'#' mapping
level_hash_mapping = self._analyze_document_hash_pattern(headers)
# 1. Check for level-hash mismatch based on adaptive mapping
for header in headers:
if header.number_pattern: # Only check numbered headers
expected_hash_count = level_hash_mapping.get(header.level, header.level)
if header.hash_count != expected_hash_count:
problems.append({
'type': 'level_hash_mismatch',
'line': header.line_number + 1,
'level': header.level,
'current_hash': header.hash_count,
'expected_hash': expected_hash_count,
'title': header.title_text[:50],
'pattern': header.number_pattern,
'problem': f"Level {header.level} header '{header.number_pattern}' uses {header.hash_count} #, but document pattern suggests {expected_hash_count} #"
})
# 2. Check for parent-child hierarchy issues
for i in range(len(headers) - 1):
current = headers[i]
next_header = headers[i + 1]
# Only consider headers with a clear number pattern
if current.number_pattern and next_header.number_pattern:
# Check if the child header's # count is less than or equal to the parent header's
if next_header.level > current.level: # Child header
expected_parent_hash = level_hash_mapping.get(current.level, current.level)
expected_child_hash = level_hash_mapping.get(next_header.level, next_header.level)
if next_header.hash_count <= current.hash_count:
problems.append({
'type': 'hierarchy_violation',
'parent_line': current.line_number + 1,
'parent_level': current.level,
'parent_hash': current.hash_count,
'parent_title': current.title_text[:50],
'child_line': next_header.line_number + 1,
'child_level': next_header.level,
'child_hash': next_header.hash_count,
'child_title': next_header.title_text[:50],
'problem': f"Child header ({next_header.level} level) # count ({next_header.hash_count}) should be greater than parent header ({current.level} level, {current.hash_count} #). Expected pattern: parent {expected_parent_hash}#, child {expected_child_hash}#"
})
# 3. Check for significant inconsistency within same level (now less strict)
same_level_problems = self._find_same_level_inconsistency(headers)
problems.extend(same_level_problems)
return problems
def _find_same_level_inconsistency(self, headers: List[HeaderInfo]) -> List[Dict]:
"""Check the problem of inconsistent number of titles # numbers at the same level"""
problems = []
# Group by level, only numbered titles
level_groups = {}
for header in headers:
if header.number_pattern: # Only numbered titles
if header.level not in level_groups:
level_groups[header.level] = []
level_groups[header.level].append(header)
# Check the consistency of # numbers within each level
for level, group_headers in level_groups.items():
if len(group_headers) < 2:
continue # Only one header, no need to check
# Count the usage of different # numbers within the same level
hash_count_stats = {}
for header in group_headers:
hash_count = header.hash_count
if hash_count not in hash_count_stats:
hash_count_stats[hash_count] = []
hash_count_stats[hash_count].append(header)
# If there are different # numbers in the same level
if len(hash_count_stats) > 1:
# Find the most common # number as the standard
most_common_hash_count = max(hash_count_stats.keys(),
key=lambda x: len(hash_count_stats[x]))
# Report titles that do not meet the standard
for hash_count, headers_with_this_count in hash_count_stats.items():
if hash_count != most_common_hash_count:
for header in headers_with_this_count:
problems.append({
'type': 'same_level_inconsistency',
'line': header.line_number + 1,
'level': header.level,
'current_hash': header.hash_count,
'expected_hash': most_common_hash_count,
'title': header.title_text[:50],
'pattern': header.number_pattern,
'problem': f"{header.level} level header uses {header.hash_count} #, but the majority of siblings use {most_common_hash_count} #"
})
return problems
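    # Example (illustrative): if three level-2 headers use '##' and a fourth uses
    # '###', the '###' header is reported with expected_hash=2.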
    def fix_hierarchy(self, content: str) -> Dict[str, Any]:
"""Fix hierarchy issues"""
headers = self.detect_headers(content)
if not headers:
return {
'fixed_content': content,
'problems_found': [],
'fixes_applied': 0,
'message': 'No headers detected'
}
# Check for problems
problems = self.find_hierarchy_problems(headers)
if not problems:
return {
'fixed_content': content,
'problems_found': [],
'fixes_applied': 0,
'message': 'No hierarchy issues found'
}
# Apply fixes
lines = content.split('\n')
fixes_applied = 0
# To ensure child headers have more # than parent headers, we need to recalculate the # count for each header
fixed_headers = self._calculate_correct_hash_counts(headers)
# Apply fixes
for header in fixed_headers:
if header.hash_count != header.correct_hash_count:
old_line = lines[header.line_number]
new_hash = '#' * header.correct_hash_count
                # Replace the leading run of '#' characters, preserving any indentation before it
                new_line = re.sub(r'^(\s*)#+', r'\g<1>' + new_hash, old_line)
lines[header.line_number] = new_line
fixes_applied += 1
fixed_content = '\n'.join(lines)
return {
'fixed_content': fixed_content,
'original_content': content,
'problems_found': problems,
'fixes_applied': fixes_applied,
'fixed_headers': [(h.line_number + 1, h.hash_count, h.correct_hash_count, h.title_text[:30])
for h in fixed_headers if h.hash_count != h.correct_hash_count]
}
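    # Example (illustrative): for a made-up document where "# 1 Intro" is followed
    # by "# 1.1 Background", fix_hierarchy() reports a level_hash_mismatch and a
    # hierarchy_violation, and rewrites the second line to "## 1.1 Background".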
def _calculate_correct_hash_counts(self, headers: List[HeaderInfo]) -> List[HeaderInfo]:
"""Calculate the correct number of #'s based on adaptive analysis of the document"""
if not headers:
return []
        # 1. Analyze how many '#'s each level uses in this document (adaptive analysis)
level_hash_mapping = self._analyze_document_hash_pattern(headers)
# Create copies with the correct number of #'s
        fixed_headers: List[HeaderInfo] = []
for header in headers:
# Copy original information
fixed_header = HeaderInfo(
line_number=header.line_number,
original_line=header.original_line,
hash_count=header.hash_count,
level=header.level,
number_pattern=header.number_pattern,
title_text=header.title_text
)
if fixed_header.number_pattern:
# For numbered headers, use the adaptive mapping
if fixed_header.level in level_hash_mapping:
fixed_header.correct_hash_count = level_hash_mapping[fixed_header.level]
else:
# Fallback: extrapolate from existing pattern
fixed_header.correct_hash_count = self._extrapolate_hash_count(
fixed_header.level, level_hash_mapping)
else:
# For non-numbered headers, keep the original # count
fixed_header.correct_hash_count = fixed_header.hash_count
fixed_headers.append(fixed_header)
return fixed_headers
def _analyze_document_hash_pattern(self, headers: List[HeaderInfo]) -> Dict[int, int]:
"""Analyze the document's # pattern to determine the adaptive mapping"""
# Count the number of #'s used at each level
level_hash_stats = {}
for header in headers:
if header.number_pattern: # Only numbered titles are considered
level = header.level
hash_count = header.hash_count
if level not in level_hash_stats:
level_hash_stats[level] = {}
if hash_count not in level_hash_stats[level]:
level_hash_stats[level][hash_count] = 0
level_hash_stats[level][hash_count] += 1
        # Find the most commonly used '#' count for each level
level_hash_mapping = {}
for level, hash_stats in level_hash_stats.items():
most_common_hash = max(hash_stats.keys(), key=lambda x: hash_stats[x])
level_hash_mapping[level] = most_common_hash
        # Adjust the mapping so that deeper levels always use strictly more '#'s
level_hash_mapping = self._ensure_monotonic_mapping(level_hash_mapping)
return level_hash_mapping
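    # Example (illustrative): if most level-1 headers in a document use '##' and
    # most level-2 headers use '###', the resulting mapping is {1: 2, 2: 3}.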
def _ensure_monotonic_mapping(self, level_hash_mapping: Dict[int, int]) -> Dict[int, int]:
"""Ensure that the level mapping is monotonically increasing (higher level = more #'s)"""
if not level_hash_mapping:
return level_hash_mapping
# Sort by level
sorted_levels = sorted(level_hash_mapping.keys())
adjusted_mapping = {}
# Ensure that the # count for each level is at least 1 more than the previous level
for i, level in enumerate(sorted_levels):
current_hash = level_hash_mapping[level]
if i == 0:
# First level, use as is
adjusted_mapping[level] = current_hash
else:
# Ensure at least 1 more # than the previous level
prev_level = sorted_levels[i-1]
min_required_hash = adjusted_mapping[prev_level] + 1
adjusted_mapping[level] = max(current_hash, min_required_hash)
return adjusted_mapping
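    # Example (illustrative): {1: 1, 2: 1, 3: 2} is adjusted to {1: 1, 2: 2, 3: 3},
    # so every deeper level ends up with at least one more '#' than the level above it.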
def _extrapolate_hash_count(self, level: int, level_hash_mapping: Dict[int, int]) -> int:
"""Infer the number of # numbers for the hierarchy that have not appeared"""
if not level_hash_mapping:
return level # Fallback to simple 1:1 mapping
sorted_levels = sorted(level_hash_mapping.keys())
if level < sorted_levels[0]:
            # Below the smallest known level: extrapolate downward
diff = sorted_levels[0] - level
return max(1, level_hash_mapping[sorted_levels[0]] - diff)
elif level > sorted_levels[-1]:
            # Above the largest known level: extrapolate upward
diff = level - sorted_levels[-1]
return level_hash_mapping[sorted_levels[-1]] + diff
else:
            # Between two known levels: interpolate linearly
for i in range(len(sorted_levels) - 1):
if sorted_levels[i] < level < sorted_levels[i + 1]:
# Simple linear interpolation
lower_level = sorted_levels[i]
upper_level = sorted_levels[i + 1]
lower_hash = level_hash_mapping[lower_level]
upper_hash = level_hash_mapping[upper_level]
# Linear interpolation
ratio = (level - lower_level) / (upper_level - lower_level)
return int(lower_hash + ratio * (upper_hash - lower_hash))
return level # Fallback
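    # Examples (illustrative), given a known mapping of {2: 2, 4: 4}:
    #   level 1 -> 1 (extrapolated below), level 5 -> 5 (extrapolated above),
    #   level 3 -> 3 (linear interpolation between the two known levels).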
def _fix_same_level_inconsistency(self, headers: List[HeaderInfo]) -> None:
"""Fix inconsistency of # count at the same level"""
# Group by level, only process headers with a numbering pattern
level_groups = {}
for header in headers:
if header.number_pattern: # Only process headers with a numbering pattern
if header.level not in level_groups:
level_groups[header.level] = []
level_groups[header.level].append(header)
# Fix inconsistency of # count within each level
for level, group_headers in level_groups.items():
if len(group_headers) < 2:
continue # Only one header, no need to fix
# Count the usage of different # counts within the same level
hash_count_stats = {}
for header in group_headers:
hash_count = header.correct_hash_count
if hash_count not in hash_count_stats:
hash_count_stats[hash_count] = []
hash_count_stats[hash_count].append(header)
# If different # counts exist at the same level
if len(hash_count_stats) > 1:
# Find the most common # count as the standard
most_common_hash_count = max(hash_count_stats.keys(),
key=lambda x: len(hash_count_stats[x]))
                # Normalize every header at this level to the most common '#' count
for header in group_headers:
header.correct_hash_count = most_common_hash_count
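

# Minimal usage sketch (illustrative; the sample Markdown below is made up).
# It shows a "1.1" child header that wrongly reuses its parent's single '#'
# and is rewritten to use '##'.
if __name__ == "__main__":
    sample_md = "\n".join([
        "# 1 Introduction",
        "Some text.",
        "# 1.1 Background",  # same '#' count as its parent; fixed to "## 1.1 Background"
        "More text.",
        "# 2 Methods",
    ])
    fixer = HierarchyFixer()
    result = fixer.fix_hierarchy(sample_md)
    print(f"Fixes applied: {result['fixes_applied']}")
    for problem in result['problems_found']:
        print(f"- {problem['type']}: {problem['problem']}")
    print(result['fixed_content'])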