volaris-pdf-tool / main.py
import os
import json
import signal
import sys
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Sequence, Set, Any
from multiprocessing import Pool, cpu_count
from functools import partial
import fitz # PyMuPDF (Still needed for drawing output PDF)
import pypdfium2 as pdfium
import torch
from doclayout_yolo import YOLOv10
from huggingface_hub import hf_hub_download
from loguru import logger
from PIL import Image
import numpy as np
try:
import pymupdf4llm # type: ignore
except ImportError: # pragma: no cover - optional dependency
pymupdf4llm = None # type: ignore
# ----------------------------------------------------------------------
# CONFIGURATION
# ----------------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Model options
MODEL_SIZE = 1024
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
WEIGHTS_FILE = f"doclayout_yolo_docstructbench_imgsz{MODEL_SIZE}.pt"
# Detection settings
CONF_THRESHOLD = 0.25
# Multiprocessing settings
NUM_WORKERS = None # None = auto (cpu_count - 1), or set to specific number like 4
USE_MULTIPROCESSING = True # Set to False to disable parallel processing entirely
# ----------------------------------------------------------------------
# Color map for the layout classes
# ----------------------------------------------------------------------
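# Keys are assumed to match the class names reported by the model (results[0].names);
# any class missing from this map is drawn in black by draw_layout_pdf().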
CLASS_COLORS = {
"text": (0, 128, 0), # Dark Green
"title": (192, 0, 0), # Dark Red
"figure": (0, 0, 192), # Dark Blue
"table": (218, 165, 32), # Goldenrod (Dark Yellow)
"list": (128, 0, 128), # Purple
"header": (0, 128, 128), # Teal
"footer": (100, 100, 100), # Dark Gray
"figure_caption": (0, 0, 128), # Navy
"table_caption": (139, 69, 19), # Saddle Brown
"table_footnote": (128, 0, 128), # Purple
}
# Global model instance (will be None in worker processes until loaded)
_model = None
_shutdown_requested = False
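# Note: _shutdown_requested is per-process. Spawned pool workers never run
# setup_signal_handlers(), so the flag is effectively only set in the main process.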
# ----------------------------------------------------------------------
# Signal handler for graceful shutdown
# ----------------------------------------------------------------------
def signal_handler(signum, frame):
"""Handle interrupt signals gracefully."""
global _shutdown_requested
if not _shutdown_requested:
_shutdown_requested = True
logger.warning("\n⚠️ Interrupt received! Finishing current page and shutting down gracefully...")
logger.warning("Press Ctrl+C again to force quit (may leave incomplete files)")
else:
logger.error("\n❌ Force quit requested. Exiting immediately.")
sys.exit(1)
def setup_signal_handlers():
"""Setup signal handlers for graceful shutdown."""
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# ----------------------------------------------------------------------
# Model loader function
# ----------------------------------------------------------------------
def get_model():
"""Lazy load the model (only once per process)."""
global _model
if _model is None:
weights_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)
_model = YOLOv10(weights_path)
logger.info(f"βœ“ Model loaded in worker process (PID: {os.getpid()})")
return _model
# ----------------------------------------------------------------------
# Worker initialization function
# ----------------------------------------------------------------------
def init_worker():
"""Initialize worker process - loads model once at startup."""
try:
get_model()
logger.success(f"Worker {os.getpid()} ready")
except Exception as e:
logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
raise
# ----------------------------------------------------------------------
# Run layout detection on a single page image (YOLO)
# ----------------------------------------------------------------------
def detect_page(pil_img: Image.Image) -> List[dict]:
"""Detect layout elements using YOLO model."""
model = get_model() # Will return already-loaded model in worker
img_cv = np.array(pil_img)
results = model.predict(
img_cv,
imgsz=MODEL_SIZE,
conf=CONF_THRESHOLD,
device=DEVICE,
verbose=False
)
dets = []
for i, box in enumerate(results[0].boxes):
cls_id = int(box.cls.item())
name = results[0].names[cls_id]
conf = float(box.conf.item())
x0, y0, x1, y1 = box.xyxy[0].cpu().numpy().tolist()
dets.append({
"name": name,
"bbox": [x0, y0, x1, y1],
"conf": conf,
"source": "yolo",
"index": i
})
return dets
# ----------------------------------------------------------------------
# Crop & save figure/table regions (with captions)
# ----------------------------------------------------------------------
def get_union_box(box1: List[float], box2: List[float]) -> List[float]:
"""Get the bounding box enclosing two boxes."""
x0 = min(box1[0], box2[0])
y0 = min(box1[1], box2[1])
x1 = max(box1[2], box2[2])
y1 = max(box1[3], box2[3])
return [x0, y0, x1, y1]
def collect_caption_elements(
element: Dict,
all_dets: List[Dict],
target_name: str,
max_vertical_gap: float = 60.0,
min_overlap: float = 0.25,
) -> List[Dict]:
"""
Collect contiguous caption detections directly below a figure/table.
"""
base_box = element["bbox"]
base_bottom = base_box[3]
selected: List[Dict] = []
last_bottom = base_bottom
relevant = [
d for d in all_dets
if d["name"] == target_name and d["bbox"][1] >= base_bottom - 5
]
relevant.sort(key=lambda d: d["bbox"][1])
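    # Walk candidates top-to-bottom; stop at the first large vertical gap and require
    # horizontal overlap with the previously accepted box (or the element itself).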
for cand in relevant:
cand_box = cand["bbox"]
top = cand_box[1]
if selected and top - last_bottom > max_vertical_gap:
break
if selected:
overlap = _horizontal_overlap_ratio(selected[-1]["bbox"], cand_box)
else:
overlap = _horizontal_overlap_ratio(base_box, cand_box)
if overlap < min_overlap:
continue
selected.append(cand)
last_bottom = cand_box[3]
return selected
def collect_title_and_text_segments(
element: Dict,
all_dets: List[Dict],
processed_indices: Set[int],
settings: Optional[Dict[str, float]] = None,
) -> Tuple[List[Dict], List[Dict]]:
"""
Locate a title below the element and any contiguous text blocks directly beneath it.
"""
if settings is None:
settings = TITLE_TEXT_ASSOCIATION
if not element.get("bbox"):
return [], []
figure_box = element["bbox"]
figure_bottom = figure_box[3]
candidates = [
d for d in all_dets
if d.get("bbox") and d["index"] not in processed_indices
]
candidates.sort(key=lambda d: d["bbox"][1])
titles: List[Dict] = []
texts: List[Dict] = []
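    # Two-phase search: take the first qualifying title below the element, then gather
    # the contiguous text blocks beneath it until a gap or another title appears.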
for idx, det in enumerate(candidates):
if det["name"] != "title":
continue
title_box = det["bbox"]
if title_box[1] < figure_bottom - 5:
continue
vertical_gap = title_box[1] - figure_bottom
if vertical_gap > settings["max_title_gap"]:
break
overlap = _horizontal_overlap_ratio(figure_box, title_box)
if overlap < settings["min_overlap"]:
continue
titles.append(det)
last_bottom = title_box[3]
for follower in candidates[idx + 1 :]:
if follower["name"] == "title":
break
if follower["name"] != "text":
continue
text_box = follower["bbox"]
if text_box[1] < title_box[1]:
continue
gap = text_box[1] - last_bottom
if gap > settings["max_text_gap"]:
break
if _horizontal_overlap_ratio(title_box, text_box) < settings["min_overlap"]:
continue
texts.append(follower)
last_bottom = text_box[3]
break
return titles, texts
def save_layout_elements(pil_img: Image.Image, page_num: int,
dets: List[dict], out_dir: Path) -> List[dict]:
"""Save figure and table crops, merging captions."""
fig_dir = out_dir / "figures"
tab_dir = out_dir / "tables"
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(tab_dir, exist_ok=True)
infos = []
fig_count = 0
tab_count = 0
processed_indices = set()
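    # Captions, titles, and trailing text merged into a crop are recorded here so that
    # later elements on the same page do not reuse them.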
for i, d in enumerate(dets):
if d["index"] in processed_indices:
continue
name = d["name"].lower()
final_box = d["bbox"]
caption_segments: List[Dict] = []
title_segments: List[Dict] = []
text_segments: List[Dict] = []
if name == "figure":
elem_type = "figure"
path_template = fig_dir / f"page_{page_num + 1}_fig_{fig_count}.png"
fig_count += 1
caption_segments = collect_caption_elements(d, dets, "figure_caption")
for cap in caption_segments:
final_box = get_union_box(final_box, cap["bbox"])
processed_indices.add(cap["index"])
title_segments, text_segments = collect_title_and_text_segments(
d, dets, processed_indices
)
for seg in title_segments + text_segments:
final_box = get_union_box(final_box, seg["bbox"])
processed_indices.add(seg["index"])
elif name == "table":
elem_type = "table"
path_template = tab_dir / f"page_{page_num + 1}_tab_{tab_count}.png"
tab_count += 1
caption_segments = collect_caption_elements(d, dets, "table_caption")
for cap in caption_segments:
final_box = get_union_box(final_box, cap["bbox"])
processed_indices.add(cap["index"])
else:
continue
x0, y0, x1, y1 = map(int, final_box)
crop = pil_img.crop((x0, y0, x1, y1))
if crop.mode == "CMYK":
crop = crop.convert("RGB")
crop.save(path_template)
info_data = {
"type": elem_type,
"page": page_num + 1,
"bbox_pixels": final_box,
"conf": d["conf"],
"source": d.get("source", "yolo"),
"image_path": str(path_template.relative_to(out_dir)),
"width": int(x1 - x0),
"height": int(y1 - y0),
"page_width": pil_img.width,
"page_height": pil_img.height,
}
if caption_segments:
info_data["captions"] = [
{
"bbox": cap["bbox"],
"conf": cap.get("conf"),
"index": cap["index"],
"source": cap.get("source"),
"page": page_num + 1,
}
for cap in caption_segments
]
if title_segments:
info_data["titles"] = [
{
"bbox": seg["bbox"],
"conf": seg.get("conf"),
"index": seg["index"],
"source": seg.get("source"),
"page": page_num + 1,
}
for seg in title_segments
]
if text_segments:
info_data["texts"] = [
{
"bbox": seg["bbox"],
"conf": seg.get("conf"),
"index": seg["index"],
"source": seg.get("source"),
"page": page_num + 1,
}
for seg in text_segments
]
infos.append(info_data)
return infos
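# The tolerances below are in rendered-pixel units at the page render scale
# used in process_pdf_with_pool (scale = 2.0).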
TABLE_STITCH_TOLERANCES = {
"x_tol": 60,
"y_tol": 60,
"width_tol": 120,
"height_tol": 120,
}
CROSS_PAGE_CAPTION_THRESHOLDS = {
"max_top_ratio": 0.35,
"max_top_pixels": 220,
"x_tol": 120,
"width_tol": 200,
"min_overlap": 0.05,
}
TITLE_TEXT_ASSOCIATION = {
"max_title_gap": 220,
"max_text_gap": 160,
"min_overlap": 0.2,
}
def _horizontal_overlap_ratio(box1: List[float], box2: List[float]) -> float:
"""Compute horizontal overlap ratio between two bounding boxes."""
x_left = max(box1[0], box2[0])
x_right = min(box1[2], box2[2])
overlap = max(0.0, x_right - x_left)
if overlap <= 0:
return 0.0
width_union = max(box1[2], box2[2]) - min(box1[0], box2[0])
if width_union <= 0:
return 0.0
return overlap / width_union
def _bbox_to_rect(bbox: List[float]) -> Tuple[int, int, int, int]:
"""Convert [x0, y0, x1, y1] into (x, y, w, h)."""
x0, y0, x1, y1 = bbox
return int(x0), int(y0), int(x1 - x0), int(y1 - y0)
def _open_table_image(elem: Dict, out_dir: Path) -> Optional[Image.Image]:
"""Open a table image relative to the output directory."""
image_path = out_dir / elem["image_path"]
if not image_path.exists():
logger.warning(f"Missing table crop for stitching: {image_path}")
return None
img = Image.open(image_path)
if img.mode != "RGB":
img = img.convert("RGB")
return img
def _pad_width(img: Image.Image, target_width: int) -> Image.Image:
if img.width >= target_width:
return img
canvas = Image.new("RGB", (target_width, img.height), color=(255, 255, 255))
canvas.paste(img, (0, 0))
return canvas
def _pad_height(img: Image.Image, target_height: int) -> Image.Image:
if img.height >= target_height:
return img
canvas = Image.new("RGB", (img.width, target_height), color=(255, 255, 255))
canvas.paste(img, (0, 0))
return canvas
def _append_segment_image(
base_img: Image.Image,
segment_img: Image.Image,
resize_to_base: bool = False,
) -> Image.Image:
"""Append segment image below base image with optional width alignment."""
if base_img.mode != "RGB":
base_img = base_img.convert("RGB")
if segment_img.mode != "RGB":
segment_img = segment_img.convert("RGB")
if resize_to_base and segment_img.width > 0 and base_img.width > 0:
segment_img = segment_img.resize(
(
base_img.width,
max(1, int(segment_img.height * (base_img.width / segment_img.width))),
),
Image.Resampling.LANCZOS,
)
target_width = max(base_img.width, segment_img.width)
base_img = _pad_width(base_img, target_width)
segment_img = _pad_width(segment_img, target_width)
stitched = Image.new(
"RGB",
(target_width, base_img.height + segment_img.height),
color=(255, 255, 255),
)
stitched.paste(base_img, (0, 0))
stitched.paste(segment_img, (0, base_img.height))
return stitched
def _render_pdf_page(
pdf_doc: pdfium.PdfDocument,
page_index: int,
scale: float,
cache: Dict[int, Image.Image],
) -> Optional[Image.Image]:
"""Render a PDF page to a PIL image with caching."""
if page_index in cache:
return cache[page_index]
try:
page = pdf_doc[page_index]
bitmap = page.render(scale=scale)
pil_img = bitmap.to_pil()
page.close()
except Exception as exc:
logger.error(f"Failed to render page {page_index + 1} for caption stitching: {exc}")
return None
cache[page_index] = pil_img
return pil_img
def _crop_pdf_region(
page_img: Optional[Image.Image], bbox: List[float]
) -> Optional[Image.Image]:
"""Crop a region from a rendered PDF page."""
if page_img is None:
return None
x0, y0, x1, y1 = map(int, bbox)
x0 = max(0, x0)
y0 = max(0, y0)
x1 = min(page_img.width, max(x0 + 1, x1))
y1 = min(page_img.height, max(y0 + 1, y1))
if x0 >= x1 or y0 >= y1:
return None
crop = page_img.crop((x0, y0, x1, y1))
if crop.mode == "CMYK":
crop = crop.convert("RGB")
return crop
def write_markdown_document(pdf_path: Path, out_dir: Path) -> Optional[Path]:
"""
Extract markdown text from a PDF using PyMuPDF4LLM and write it to disk.
"""
if pymupdf4llm is None:
logger.warning(
"Skipping markdown extraction for %s because pymupdf4llm is not installed.",
pdf_path.name,
)
return None
try:
markdown_content = pymupdf4llm.to_markdown(str(pdf_path))
except Exception as exc:
logger.error(f" Failed to create markdown for {pdf_path.name}: {exc}")
return None
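    # Some pymupdf4llm configurations return a list of per-page chunks rather than a
    # single string; keep only the string parts and join them.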
if isinstance(markdown_content, list):
markdown_content = "\n\n".join(
part for part in markdown_content if isinstance(part, str)
)
if not isinstance(markdown_content, str):
logger.error(
f" Unexpected markdown output type {type(markdown_content)} for {pdf_path.name}"
)
return None
markdown_content = markdown_content.strip()
if not markdown_content:
logger.warning(f" No textual content extracted from {pdf_path.name}")
return None
if not markdown_content.endswith("\n"):
markdown_content += "\n"
md_path = out_dir / f"{pdf_path.stem}.md"
md_path.write_text(markdown_content, encoding="utf-8")
logger.info(f" Saved markdown to {md_path.name}")
return md_path
def _collect_text_under_title_cross_page(
title_det: Dict,
sorted_dets: List[Dict],
start_idx: int,
page_idx: int,
used_indices: Set[Tuple[int, int]],
settings: Optional[Dict[str, float]] = None,
) -> List[Dict]:
"""Collect text elements directly below a title on the next page."""
if settings is None:
settings = TITLE_TEXT_ASSOCIATION
texts: List[Dict] = []
title_box = title_det["bbox"]
last_bottom = title_box[3]
for follower in sorted_dets[start_idx + 1 :]:
det_index = follower.get("index")
if det_index is None or (page_idx, det_index) in used_indices:
continue
if follower["name"] == "title":
break
if follower["name"] != "text":
continue
text_box = follower["bbox"]
if text_box[1] < title_box[1]:
continue
gap = text_box[1] - last_bottom
if gap > settings["max_text_gap"]:
break
if _horizontal_overlap_ratio(title_box, text_box) < settings["min_overlap"]:
continue
texts.append(follower)
last_bottom = text_box[3]
return texts
def attach_cross_page_figure_captions(
elements: List[Dict],
all_dets: Sequence[Optional[List[Dict[str, Any]]]],
pdf_bytes: bytes,
out_dir: Path,
scale: float,
) -> List[Dict]:
"""
If a figure caption appears on the next page, stitch it to the prior figure.
"""
figures = [elem for elem in elements if elem.get("type") == "figure"]
if not figures or not all_dets:
return elements
try:
pdf_doc = pdfium.PdfDocument(pdf_bytes)
except Exception as exc:
logger.error(f"Unable to reopen PDF for figure caption stitching: {exc}")
return elements
page_cache: Dict[int, Image.Image] = {}
used_following_ids: Set[Tuple[int, int]] = set()
# Mark existing caption/title/text detections as used
for elem in figures:
for key in ("captions", "titles", "texts"):
for seg in elem.get(key, []) or []:
idx = seg.get("index")
page_no = seg.get("page")
if idx is None or page_no is None:
continue
used_following_ids.add((page_no - 1, idx))
for elem in figures:
page_no = elem.get("page")
bbox = elem.get("bbox_pixels")
if page_no is None or bbox is None:
continue
current_idx = page_no - 1
next_idx = current_idx + 1
if next_idx >= len(all_dets):
continue
next_dets = all_dets[next_idx]
if not next_dets:
continue
fig_width = bbox[2] - bbox[0]
page_img = _render_pdf_page(pdf_doc, next_idx, scale, page_cache)
if page_img is None:
continue
next_page_height = page_img.height
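        # Only detections near the top of the following page are considered as
        # continuations of this figure.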
max_top_allowed = min(
CROSS_PAGE_CAPTION_THRESHOLDS["max_top_pixels"],
int(next_page_height * CROSS_PAGE_CAPTION_THRESHOLDS["max_top_ratio"]),
)
sorted_next = sorted(
[det for det in next_dets if det.get("bbox")],
key=lambda det: det["bbox"][1],
)
caption_candidate: Optional[Tuple[Dict, int]] = None
caption_candidates = []
for det in sorted_next:
if det.get("name") != "figure_caption":
continue
det_index = det.get("index")
if det_index is None or (next_idx, det_index) in used_following_ids:
continue
det_bbox = det.get("bbox")
if not det_bbox or det_bbox[1] > max_top_allowed:
continue
overlap = _horizontal_overlap_ratio(bbox, det_bbox)
x_diff = abs(bbox[0] - det_bbox[0])
width_diff = abs((bbox[2] - bbox[0]) - (det_bbox[2] - det_bbox[0]))
if overlap < CROSS_PAGE_CAPTION_THRESHOLDS["min_overlap"]:
if (
x_diff > CROSS_PAGE_CAPTION_THRESHOLDS["x_tol"]
or width_diff > CROSS_PAGE_CAPTION_THRESHOLDS["width_tol"]
):
continue
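            # Rank candidates by how closely their width and left edge match the figure.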
score = width_diff + 0.5 * x_diff
caption_candidates.append((score, det, det_index))
if caption_candidates:
caption_candidates.sort(key=lambda item: item[0])
_, best_det, best_index = caption_candidates[0]
caption_candidate = (best_det, best_index)
title_candidate: Optional[Tuple[Dict, int]] = None
title_texts: List[Dict] = []
for idx_sorted, det in enumerate(sorted_next):
if det.get("name") != "title":
continue
det_index = det.get("index")
if det_index is None or (next_idx, det_index) in used_following_ids:
continue
det_bbox = det.get("bbox")
if not det_bbox or det_bbox[1] > max_top_allowed:
continue
overlap = _horizontal_overlap_ratio(bbox, det_bbox)
x_diff = abs(bbox[0] - det_bbox[0])
if (
overlap < TITLE_TEXT_ASSOCIATION["min_overlap"]
and x_diff > CROSS_PAGE_CAPTION_THRESHOLDS["x_tol"]
):
continue
title_candidate = (det, det_index)
title_texts = _collect_text_under_title_cross_page(
det, sorted_next, idx_sorted, next_idx, used_following_ids
)
break
if not caption_candidate and not title_candidate and not title_texts:
continue
figure_path = out_dir / elem["image_path"]
if not figure_path.exists():
continue
figure_img = Image.open(figure_path)
if figure_img.mode == "CMYK":
figure_img = figure_img.convert("RGB")
segments_added = False
if caption_candidate:
cap_det, cap_index = caption_candidate
caption_crop = _crop_pdf_region(page_img, cap_det["bbox"])
if caption_crop is not None:
figure_img = _append_segment_image(
figure_img, caption_crop, resize_to_base=True
)
elem.setdefault("captions", [])
elem["captions"].append(
{
"bbox": cap_det["bbox"],
"conf": cap_det.get("conf"),
"index": cap_index,
"source": cap_det.get("source"),
"page": next_idx + 1,
}
)
used_following_ids.add((next_idx, cap_index))
segments_added = True
if title_candidate:
title_det, title_index = title_candidate
title_crop = _crop_pdf_region(page_img, title_det["bbox"])
if title_crop is not None:
figure_img = _append_segment_image(figure_img, title_crop)
elem.setdefault("titles", [])
elem["titles"].append(
{
"bbox": title_det["bbox"],
"conf": title_det.get("conf"),
"index": title_index,
"source": title_det.get("source"),
"page": next_idx + 1,
}
)
used_following_ids.add((next_idx, title_index))
segments_added = True
for text_det in title_texts:
text_index = text_det.get("index")
text_crop = _crop_pdf_region(page_img, text_det["bbox"])
if text_crop is None:
continue
figure_img = _append_segment_image(figure_img, text_crop)
elem.setdefault("texts", [])
elem["texts"].append(
{
"bbox": text_det["bbox"],
"conf": text_det.get("conf"),
"index": text_index,
"source": text_det.get("source"),
"page": next_idx + 1,
}
)
if text_index is not None:
used_following_ids.add((next_idx, text_index))
segments_added = True
if not segments_added:
continue
figure_img.save(figure_path)
elem["width"] = figure_img.width
elem["height"] = figure_img.height
span = elem.get("page_span")
if span:
if next_idx + 1 not in span:
span.append(next_idx + 1)
else:
base_page = elem.get("page")
new_span = [page for page in (base_page, next_idx + 1) if page is not None]
elem["page_span"] = new_span
pdf_doc.close()
return elements
def _stitch_table_pair(
base_elem: Dict,
candidate_elem: Dict,
out_dir: Path,
merge_index: int,
stitch_type: str,
) -> Optional[Dict]:
"""Stitch two table crops either vertically or horizontally."""
base_img = _open_table_image(base_elem, out_dir)
candidate_img = _open_table_image(candidate_elem, out_dir)
if base_img is None or candidate_img is None:
return None
tables_dir = out_dir / "tables"
tables_dir.mkdir(parents=True, exist_ok=True)
if stitch_type == "vertical":
target_width = max(base_img.width, candidate_img.width)
base_img = _pad_width(base_img, target_width)
candidate_img = _pad_width(candidate_img, target_width)
merged_height = base_img.height + candidate_img.height
stitched = Image.new("RGB", (target_width, merged_height), color=(255, 255, 255))
stitched.paste(base_img, (0, 0))
stitched.paste(candidate_img, (0, base_img.height))
else:
target_height = max(base_img.height, candidate_img.height)
base_img = _pad_height(base_img, target_height)
candidate_img = _pad_height(candidate_img, target_height)
merged_width = base_img.width + candidate_img.width
stitched = Image.new("RGB", (merged_width, target_height), color=(255, 255, 255))
stitched.paste(base_img, (0, 0))
stitched.paste(candidate_img, (base_img.width, 0))
merged_name = (
f"page_{base_elem['page']}_to_{candidate_elem['page']}_"
f"table_merged_{merge_index}.png"
)
merged_path = tables_dir / merged_name
stitched.save(merged_path)
# Remove original partial crops to avoid duplicates
(out_dir / base_elem["image_path"]).unlink(missing_ok=True)
(out_dir / candidate_elem["image_path"]).unlink(missing_ok=True)
new_bbox = [
min(base_elem["bbox_pixels"][0], candidate_elem["bbox_pixels"][0]),
min(base_elem["bbox_pixels"][1], candidate_elem["bbox_pixels"][1]),
max(base_elem["bbox_pixels"][2], candidate_elem["bbox_pixels"][2]),
max(base_elem["bbox_pixels"][3], candidate_elem["bbox_pixels"][3]),
]
merged_elem = base_elem.copy()
merged_elem["page_span"] = [base_elem["page"], candidate_elem["page"]]
merged_elem["box_refs"] = [
{"page": base_elem["page"], "image_path": base_elem["image_path"]},
{"page": candidate_elem["page"], "image_path": candidate_elem["image_path"]},
]
merged_elem["bbox_pixels"] = new_bbox
merged_elem["image_path"] = str(merged_path.relative_to(out_dir))
merged_elem["width"] = stitched.width
merged_elem["height"] = stitched.height
merged_elem["page_height"] = stitched.height
merged_elem["conf"] = min(
base_elem.get("conf", 1.0), candidate_elem.get("conf", 1.0)
)
return merged_elem
def merge_spanning_tables(elements: List[Dict], out_dir: Path) -> List[Dict]:
"""
Stitch table crops that continue across adjacent pages using the heuristic
from the legacy OpenCV-based extractor.
"""
if not elements:
return elements
tables_by_page: Dict[int, List[Dict]] = {}
non_tables: List[Dict] = []
for elem in elements:
if elem.get("type") != "table":
non_tables.append(elem)
continue
page = elem.get("page")
if not isinstance(page, int):
non_tables.append(elem)
continue
tables_by_page.setdefault(page, []).append(elem)
merged_results: List[Dict] = []
    used_next: Dict[int, Set[int]] = {}
merge_counter = 0
for page in sorted(tables_by_page.keys()):
current_tables = tables_by_page.get(page, [])
next_page_tables = tables_by_page.get(page + 1, [])
next_used_indices = used_next.get(page + 1, set())
current_used_indices = used_next.get(page, set())
for idx_current, table_elem in enumerate(current_tables):
if idx_current in current_used_indices:
continue
if not next_page_tables:
merged_results.append(table_elem)
continue
x, y, w, h = _bbox_to_rect(table_elem["bbox_pixels"])
matched = False
for idx, candidate in enumerate(next_page_tables):
if idx in next_used_indices:
continue
if candidate.get("type") != "table":
continue
cx, cy, cw, ch = _bbox_to_rect(candidate["bbox_pixels"])
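                # Matching left/right edges => the table continues below (stitch vertically);
                # matching top/bottom edges => the table is split side-by-side (stitch horizontally).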
vertical_match = (
abs(x - cx) <= TABLE_STITCH_TOLERANCES["x_tol"]
and abs((x + w) - (cx + cw)) <= TABLE_STITCH_TOLERANCES["width_tol"]
)
horizontal_match = (
abs(y - cy) <= TABLE_STITCH_TOLERANCES["y_tol"]
and abs((y + h) - (cy + ch))
<= TABLE_STITCH_TOLERANCES["height_tol"]
)
stitch_type = "vertical" if vertical_match else None
if not stitch_type and horizontal_match:
stitch_type = "horizontal"
if not stitch_type:
continue
merge_counter += 1
merged_elem = _stitch_table_pair(
table_elem, candidate, out_dir, merge_counter, stitch_type
)
if merged_elem is None:
continue
merged_results.append(merged_elem)
next_used_indices.add(idx)
matched = True
break
if not matched:
merged_results.append(table_elem)
used_next[page + 1] = next_used_indices
merged_results.extend(non_tables)
return merged_results
# ----------------------------------------------------------------------
# Draw layout boxes on the original PDF
# ----------------------------------------------------------------------
def draw_layout_pdf(pdf_bytes: bytes, all_dets: List[List[dict]],
scale: float, out_path: Path):
"""Annotate PDF with semi-transparent bounding boxes and labels."""
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
for page_no, dets in enumerate(all_dets):
page = doc[page_no]
for d in dets:
rgb = CLASS_COLORS.get(d["name"], (0, 0, 0))
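            # Detections are in rendered-pixel space; dividing by the render scale maps
            # them back to PDF point coordinates.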
rect = fitz.Rect([c / scale for c in d["bbox"]])
border_color = [c / 255 for c in rgb]
fill_color = [c / 255 for c in rgb]
fill_opacity = 0.15
border_width = 1.5
page.draw_rect(
rect,
color=border_color,
fill=fill_color,
width=border_width,
overlay=True,
fill_opacity=fill_opacity
)
label = f"{d['name']} {d['conf']:.2f}"
if d.get("source"):
label += f" [{d['source'][0].upper()}]"
text_bg = fitz.Rect(rect.x0, rect.y0 - 10, rect.x0 + 60, rect.y0)
            page.draw_rect(text_bg, color=None, fill=(1, 1, 1), fill_opacity=0.6, overlay=True)
page.insert_text(
(rect.x0 + 2, rect.y0 - 8),
label,
fontsize=6.5,
color=border_color,
overlay=True
)
doc.save(str(out_path))
doc.close()
# ----------------------------------------------------------------------
# Process a single PDF Page (for parallel execution)
# ----------------------------------------------------------------------
def process_page(task_data: Tuple[int, bytes, float, Path, str]) -> Optional[Tuple[int, List[dict], List[dict]]]:
"""
Process a single page of a PDF in a worker process.
Returns: (page_number, detections, elements) or None on failure
"""
pno, pdf_bytes, scale, out_dir, pdf_name = task_data
if _shutdown_requested:
return None
pdf_pdfium = None
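    # Each task re-opens the PDF from raw bytes: pdfium document handles cannot be
    # shared with (or pickled to) spawned worker processes.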
try:
pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
page = pdf_pdfium[pno]
bitmap = page.render(scale=scale)
pil = bitmap.to_pil()
dets = detect_page(pil)
elements = save_layout_elements(pil, pno, dets, out_dir)
page_figures = len([d for d in dets if d['name'] == 'figure'])
page_tables = len([d for d in dets if d['name'] == 'table'])
logger.info(f" [{pdf_name}] Page {pno + 1}: {page_figures} figs, {page_tables} tables")
page.close()
pdf_pdfium.close()
return (pno, dets, elements)
except Exception as e:
logger.error(f"Failed to process page {pno + 1} of {pdf_name}: {e}")
if pdf_pdfium:
pdf_pdfium.close()
return None
# ----------------------------------------------------------------------
# Process a full PDF using the persistent worker pool
# ----------------------------------------------------------------------
def process_pdf_with_pool(
pdf_path: Path,
out_dir: Path,
pool: Optional[Pool] = None,
*,
extract_images: bool = True,
extract_markdown: bool = True,
):
"""
Main processing pipeline for a PDF file.
If pool is provided, uses it. Otherwise processes serially.
"""
if _shutdown_requested:
logger.warning(f"Skipping {pdf_path.name} due to shutdown request")
return
stem = pdf_path.stem
logger.info(f"Processing {pdf_path.name}")
pdf_bytes = pdf_path.read_bytes()
doc = None
try:
doc = pdfium.PdfDocument(pdf_bytes)
page_count = len(doc)
except Exception as e:
logger.error(f"Failed to open PDF {pdf_path.name}: {e}. Skipping.")
return
finally:
if doc is not None:
doc.close()
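    # Render pages at 2x (~144 DPI); all detection bboxes and saved crops are
    # expressed in this pixel space.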
scale = 2.0
all_elements: List[Dict] = []
filtered_dets: List[List[dict]] = []
if extract_images:
all_dets: List[Optional[List[dict]]] = [None] * page_count
if pool is not None and USE_MULTIPROCESSING:
logger.info(f" Using worker pool for {page_count} pages...")
tasks = [
(pno, pdf_bytes, scale, out_dir, pdf_path.name)
for pno in range(page_count)
]
try:
results = pool.map(process_page, tasks)
for res in results:
if res:
pno, dets, elements = res
all_dets[pno] = dets
all_elements.extend(elements)
except KeyboardInterrupt:
logger.warning("Processing interrupted during parallel execution")
raise
else:
logger.info("Using serial processing...")
try:
pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
for pno in range(page_count):
if _shutdown_requested:
logger.warning(
f"Stopping at page {pno + 1}/{page_count} due to shutdown request"
)
break
try:
logger.info(f" Processing page {pno + 1}/{page_count}")
page = pdf_pdfium[pno]
bitmap = page.render(scale=scale)
pil = bitmap.to_pil()
dets = detect_page(pil)
all_dets[pno] = dets
elements = save_layout_elements(pil, pno, dets, out_dir)
all_elements.extend(elements)
page_figures = len([d for d in dets if d["name"] == "figure"])
page_tables = len([d for d in dets if d["name"] == "table"])
logger.info(
f" Found {page_figures} figures and {page_tables} tables"
)
page.close()
except Exception as e:
logger.error(f"Failed to process page {pno + 1}: {e}. Skipping page.")
pdf_pdfium.close()
except Exception as e:
logger.error(f"Fatal error processing {pdf_path.name}: {e}")
if "pdf_pdfium" in locals() and pdf_pdfium:
pdf_pdfium.close()
return
        # Keep one entry per page (None/[] for failed or empty pages) so that page
        # indices stay aligned when drawing the annotated PDF.
        dets_per_page: List[Optional[List[Dict[str, Any]]]] = list(all_dets)
        filtered_dets = [d if d is not None else [] for d in all_dets]
if all_elements:
all_elements = merge_spanning_tables(all_elements, out_dir)
all_elements = attach_cross_page_figure_captions(
all_elements, dets_per_page, pdf_bytes, out_dir, scale
)
if all_elements:
content_list_path = out_dir / f"{stem}_content_list.json"
with open(content_list_path, "w", encoding="utf-8") as f:
json.dump(all_elements, f, ensure_ascii=False, indent=4)
logger.info(f" Saved {len(all_elements)} elements to JSON")
        if any(filtered_dets):
draw_layout_pdf(
pdf_bytes, filtered_dets, scale, out_dir / f"{stem}_layout.pdf"
)
logger.info(" Generated annotated PDF")
else:
logger.warning(f"No detections found for {stem}. Skipping layout PDF.")
else:
logger.info(" Image extraction skipped per configuration.")
markdown_path = None
if extract_markdown:
markdown_path = write_markdown_document(pdf_path, out_dir)
if markdown_path is None:
logger.warning(f" Markdown extraction yielded no content for {stem}.")
if _shutdown_requested:
logger.warning(f"⚠️ Partial results saved for {stem} β†’ {out_dir}")
else:
if extract_images:
logger.success(
f"βœ“ {stem} β†’ {out_dir} ({len(all_elements)} elements extracted)"
)
else:
logger.success(f"βœ“ {stem} β†’ {out_dir} (image extraction skipped)")
# ----------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------
if __name__ == "__main__":
# Important for multiprocessing on Windows/macOS
torch.multiprocessing.set_start_method('spawn', force=True)
# Setup signal handlers for graceful shutdown
setup_signal_handlers()
INPUT_DIR = Path("./pdfs")
OUTPUT_DIR = Path("./output")
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
pdf_files = list(INPUT_DIR.glob("*.pdf"))
if not pdf_files:
logger.warning("No PDF files found in ./pdfs")
logger.info("Please add PDF files to the ./pdfs directory")
logger.info("The script will exit gracefully. No errors occurred.")
sys.exit(0)
logger.info(f"Found {len(pdf_files)} PDF file(s) to process")
logger.info(f"Settings: MODEL_SIZE={MODEL_SIZE}, CONF={CONF_THRESHOLD}")
# Determine worker count
total_cpus = cpu_count()
if NUM_WORKERS is None:
num_workers = max(1, total_cpus - 1)
else:
num_workers = max(1, min(NUM_WORKERS, total_cpus))
# Decide whether to use multiprocessing
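    # A worker pool only pays off for CPU inference; with a GPU the model runs
    # serially in the main process instead.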
use_pool = USE_MULTIPROCESSING and DEVICE == "cpu" and total_cpus >= 4
if use_pool:
logger.info(f"πŸš€ Creating persistent worker pool with {num_workers} workers...")
else:
if not USE_MULTIPROCESSING:
logger.info("Multiprocessing disabled by configuration")
elif DEVICE != "cpu":
logger.info(f"Using serial GPU processing (device: {DEVICE})")
else:
logger.info(f"Using serial CPU processing (CPU count {total_cpus} too low)")
pool = None
try:
# Create persistent pool ONCE for all PDFs
if use_pool:
pool = Pool(processes=num_workers, initializer=init_worker)
logger.success(f"βœ“ Worker pool ready with {num_workers} workers\n")
else:
# Load model in main process for serial execution
logger.info("Initializing model in main process...")
get_model()
logger.success(f"βœ“ Model loaded (device: {DEVICE})\n")
# Process all PDFs using the same pool
for i, pdf_path in enumerate(pdf_files, 1):
if _shutdown_requested:
logger.warning(f"\nShutdown requested. Processed {i-1}/{len(pdf_files)} files.")
break
logger.info(f"\n{'='*60}")
logger.info(f"πŸ“„ File {i}/{len(pdf_files)}: {pdf_path.name}")
logger.info(f"{'='*60}")
sub_out = OUTPUT_DIR / pdf_path.stem
os.makedirs(sub_out, exist_ok=True)
try:
process_pdf_with_pool(pdf_path, sub_out, pool)
except KeyboardInterrupt:
logger.warning(f"\nInterrupted while processing {pdf_path.name}")
break
except Exception as e:
logger.error(f"Error processing {pdf_path.name}: {e}")
if _shutdown_requested:
break
logger.info("Continuing with next file...")
continue
if _shutdown_requested:
logger.warning(f"\n⚠️ Processing interrupted. Partial results saved in {OUTPUT_DIR}")
else:
logger.success(f"\n✨ All done! Results are in {OUTPUT_DIR}")
except KeyboardInterrupt:
logger.error("\n❌ Processing interrupted by user")
sys.exit(1)
except Exception as e:
logger.error(f"\n❌ Fatal error: {e}")
sys.exit(1)
finally:
# Clean up pool if it exists
if pool is not None:
logger.info("\n🧹 Shutting down worker pool...")
pool.close()
pool.join()
logger.success("βœ“ Worker pool closed cleanly")