Spaces:
Running
Running
| import os | |
| import json | |
| import signal | |
| import sys | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple, Optional, Sequence, Set, Any | |
| from multiprocessing import Pool, cpu_count | |
| from functools import partial | |
| import fitz # PyMuPDF (Still needed for drawing output PDF) | |
| import pypdfium2 as pdfium | |
| import torch | |
| from doclayout_yolo import YOLOv10 | |
| from huggingface_hub import hf_hub_download | |
| from loguru import logger | |
| from PIL import Image | |
| import numpy as np | |
| try: | |
| import pymupdf4llm # type: ignore | |
| except ImportError: # pragma: no cover - optional dependency | |
| pymupdf4llm = None # type: ignore | |
# ----------------------------------------------------------------------
# CONFIGURATION
# ----------------------------------------------------------------------
# Run inference on CUDA when torch reports it is available, otherwise on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model options: the input size also selects which weights file is fetched.
MODEL_SIZE = 1024
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
WEIGHTS_FILE = "doclayout_yolo_docstructbench_imgsz{}.pt".format(MODEL_SIZE)

# Detection settings: minimum confidence passed to model.predict().
CONF_THRESHOLD = 0.25

# Multiprocessing settings.
# NUM_WORKERS: None = auto (described as cpu_count - 1; the pool code that
# interprets it is not in this section), or a specific number like 4.
NUM_WORKERS = None
# Set to False to disable parallel processing entirely.
USE_MULTIPROCESSING = True
# ----------------------------------------------------------------------
# Color map for the layout classes
# ----------------------------------------------------------------------
# RGB tuples (0-255) keyed by detection class name. draw_layout_pdf uses
# black for any name that is missing here.
CLASS_COLORS = {
    "text": (0, 128, 0),              # Dark Green
    "title": (192, 0, 0),             # Dark Red
    "figure": (0, 0, 192),            # Dark Blue
    "table": (218, 165, 32),          # Goldenrod (Dark Yellow)
    "list": (128, 0, 128),            # Purple
    "header": (0, 128, 128),          # Teal
    "footer": (100, 100, 100),        # Dark Gray
    "figure_caption": (0, 0, 128),    # Navy
    "table_caption": (139, 69, 19),   # Saddle Brown
    "table_footnote": (128, 0, 128),  # Purple
    # NOTE(review): label names assumed from the upstream DocLayout-YOLO
    # DocStructBench demo. Confirm against results[0].names. Without these
    # entries such boxes are drawn in black.
    "plain text": (0, 128, 0),        # Dark Green (same as "text")
    "abandon": (100, 100, 100),       # Dark Gray (same as "footer")
    "isolate_formula": (199, 21, 133),  # Medium Violet Red
    "formula_caption": (72, 61, 139),   # Dark Slate Blue
}
# Per-process model cache; get_model() fills it on first use, so each
# process (including pool workers) holds its own instance.
_model = None
# Set by signal_handler on the first signal; checked before new work starts.
_shutdown_requested = False
# ----------------------------------------------------------------------
# Signal handler for graceful shutdown
# ----------------------------------------------------------------------
def signal_handler(signum, frame):
    """Handle an interrupt signal: the first one requests shutdown, the second exits.

    The first call only sets the module-level ``_shutdown_requested`` flag,
    which ``process_page`` and ``process_pdf_with_pool`` check before
    starting new work. Any later call exits with status 1 via ``sys.exit``.

    Args:
        signum: Signal number (unused).
        frame: Current stack frame (unused).
    """
    global _shutdown_requested
    if not _shutdown_requested:
        _shutdown_requested = True
        # The emoji below replace literals that held the same UTF-8 bytes
        # mis-decoded as ISO-8859-7 (mojibake with invisible control chars).
        logger.warning("\n⚠️ Interrupt received! Finishing current page and shutting down gracefully...")
        logger.warning("Press Ctrl+C again to force quit (may leave incomplete files)")
    else:
        logger.error("\n❌ Force quit requested. Exiting immediately.")
        sys.exit(1)
def setup_signal_handlers():
    """Install ``signal_handler`` for both SIGINT and SIGTERM."""
    for signal_number in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signal_number, signal_handler)
# ----------------------------------------------------------------------
# Model loader function
# ----------------------------------------------------------------------
def get_model():
    """Return this process's YOLOv10 layout model, loading it on first use.

    The weights file ``WEIGHTS_FILE`` is obtained from ``REPO_ID`` with
    ``hf_hub_download``. The model is memoised in the module-level
    ``_model``, so later calls in the same process return the same instance.
    """
    global _model
    if _model is None:
        weights_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)
        _model = YOLOv10(weights_path)
        # Emoji restored from an ISO-8859-7 mojibake literal.
        logger.info(f"✅ Model loaded in worker process (PID: {os.getpid()})")
    return _model
# ----------------------------------------------------------------------
# Worker initialization function
# ----------------------------------------------------------------------
def init_worker():
    """Load the layout model eagerly in a freshly started worker process.

    Presumably used as a ``multiprocessing.Pool`` initializer (the pool
    setup is outside this section). A load failure is logged with the
    worker PID and re-raised.
    """
    worker_pid = os.getpid()
    try:
        get_model()
        logger.success(f"Worker {worker_pid} ready")
    except Exception as err:
        logger.error(f"Failed to initialize worker {worker_pid}: {err}")
        raise
# ----------------------------------------------------------------------
# Run layout detection on a single page image (YOLO)
# ----------------------------------------------------------------------
def detect_page(pil_img: Image.Image) -> List[dict]:
    """Detect layout elements on one rendered page with the YOLO model.

    Args:
        pil_img: The rendered page image.

    Returns:
        One dict per detected box with keys ``name`` (class label),
        ``bbox`` ([x0, y0, x1, y1] in pixels of *pil_img*), ``conf``,
        ``source`` (always "yolo") and ``index`` (position in the model's
        box list, used elsewhere as the detection id).
    """
    model = get_model()  # returns the already-loaded model in workers
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    # Ultralytics-style predictors interpret ndarray input as BGR (OpenCV
    # order), while PIL pixels are RGB. Flip the channel axis, and make the
    # array contiguous for OpenCV-based preprocessing.
    img_bgr = np.ascontiguousarray(np.asarray(pil_img)[:, :, ::-1])
    results = model.predict(
        img_bgr,
        imgsz=MODEL_SIZE,
        conf=CONF_THRESHOLD,
        device=DEVICE,
        verbose=False,
    )
    result = results[0]
    dets = []
    for i, box in enumerate(result.boxes):
        cls_id = int(box.cls.item())
        x0, y0, x1, y1 = box.xyxy[0].cpu().numpy().tolist()
        dets.append({
            "name": result.names[cls_id],
            "bbox": [x0, y0, x1, y1],
            "conf": float(box.conf.item()),
            "source": "yolo",
            "index": i,
        })
    return dets
# ----------------------------------------------------------------------
# Crop & save figure/table regions (with captions)
# ----------------------------------------------------------------------
def get_union_box(box1: List[float], box2: List[float]) -> List[float]:
    """Return the smallest [x0, y0, x1, y1] box that contains both inputs."""
    top_left = [min(a, b) for a, b in zip(box1[:2], box2[:2])]
    bottom_right = [max(a, b) for a, b in zip(box1[2:4], box2[2:4])]
    return top_left + bottom_right
def collect_caption_elements(
    element: Dict,
    all_dets: List[Dict],
    target_name: str,
    max_vertical_gap: float = 60.0,
    min_overlap: float = 0.25,
) -> List[Dict]:
    """Collect contiguous caption detections directly below a figure/table.

    Detections named *target_name* whose top edge is at or below the
    element's bottom edge (with a 5px slack) are walked top-to-bottom. Each
    accepted caption must start within *max_vertical_gap* pixels of the
    previous accepted box (the element itself for the first caption) and
    overlap it horizontally by at least *min_overlap*.

    Returns:
        The accepted caption detections, top-to-bottom.
    """
    base_box = element["bbox"]
    base_bottom = base_box[3]
    relevant = sorted(
        (
            d for d in all_dets
            if d["name"] == target_name and d["bbox"][1] >= base_bottom - 5
        ),
        key=lambda d: d["bbox"][1],
    )
    selected: List[Dict] = []
    last_box = base_box
    last_bottom = base_bottom
    for cand in relevant:
        cand_box = cand["bbox"]
        # Apply the gap limit to the first caption as well. Otherwise a
        # caption far below (e.g. the next figure's) could be absorbed and
        # stretch this crop over unrelated content. Candidates are sorted by
        # top edge, so every later one is at least as far away.
        if cand_box[1] - last_bottom > max_vertical_gap:
            break
        if _horizontal_overlap_ratio(last_box, cand_box) < min_overlap:
            continue
        selected.append(cand)
        last_box = cand_box
        last_bottom = cand_box[3]
    return selected
def collect_title_and_text_segments(
    element: Dict,
    all_dets: List[Dict],
    processed_indices: Set[int],
    settings: Optional[Dict[str, float]] = None,
) -> Tuple[List[Dict], List[Dict]]:
    """Find a title below *element* and the text blocks chained beneath it.

    Only detections with a bbox and an index not in *processed_indices* are
    considered, in top-to-bottom order. The first title that starts below
    the element (5px slack), within ``max_title_gap`` of its bottom edge
    and overlapping it horizontally by ``min_overlap`` is chosen. Text
    blocks after that title are then collected until another title, a gap
    larger than ``max_text_gap``, or the end of the list.

    NOTE(review): only detections named exactly "text" are chained. If the
    model labels body text differently (e.g. "plain text"), none are
    collected; confirm against the model's class names.

    Returns:
        (titles, texts), where titles holds at most one detection.
    """
    cfg = TITLE_TEXT_ASSOCIATION if settings is None else settings
    anchor = element.get("bbox")
    if not anchor:
        return [], []
    anchor_bottom = anchor[3]
    ordered = sorted(
        (d for d in all_dets if d.get("bbox") and d["index"] not in processed_indices),
        key=lambda d: d["bbox"][1],
    )
    for pos, det in enumerate(ordered):
        if det["name"] != "title":
            continue
        title_box = det["bbox"]
        if title_box[1] < anchor_bottom - 5:
            continue  # starts above the element
        if title_box[1] - anchor_bottom > cfg["max_title_gap"]:
            break  # sorted by top edge: every later title is farther away
        if _horizontal_overlap_ratio(anchor, title_box) < cfg["min_overlap"]:
            continue
        chained: List[Dict] = []
        chain_bottom = title_box[3]
        for follower in ordered[pos + 1:]:
            label = follower["name"]
            if label == "title":
                break
            if label != "text":
                continue
            text_box = follower["bbox"]
            if text_box[1] < title_box[1]:
                continue
            if text_box[1] - chain_bottom > cfg["max_text_gap"]:
                break
            if _horizontal_overlap_ratio(title_box, text_box) < cfg["min_overlap"]:
                continue
            chained.append(follower)
            chain_bottom = text_box[3]
        return [det], chained
    return [], []
def _segment_record(seg: Dict, page_number: int) -> Dict:
    """Serialize an absorbed caption/title/text detection for element metadata."""
    return {
        "bbox": seg["bbox"],
        "conf": seg.get("conf"),
        "index": seg["index"],
        "source": seg.get("source"),
        "page": page_number,
    }


def save_layout_elements(pil_img: Image.Image, page_num: int,
                         dets: List[dict], out_dir: Path) -> List[dict]:
    """Save figure and table crops from one page, merging their captions.

    Figures go to ``out_dir/figures/page_<n>_fig_<k>.png`` and tables to
    ``out_dir/tables/page_<n>_tab_<k>.png`` (n = page_num + 1, k = per-page
    counter). Captions found below an element are merged into its crop.
    Figures also absorb a title and the text blocks beneath it. Every
    absorbed detection is recorded in ``processed_indices``, so it is
    neither saved on its own nor claimed by a later element.

    Args:
        pil_img: Rendered page image the detection boxes refer to.
        page_num: Zero-based page index.
        dets: Detections for this page (as produced by ``detect_page``).
        out_dir: Output root; image paths in the result are relative to it.

    Returns:
        One metadata dict per saved crop.
    """
    fig_dir = out_dir / "figures"
    tab_dir = out_dir / "tables"
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(tab_dir, exist_ok=True)
    page_number = page_num + 1
    infos: List[dict] = []
    fig_count = 0
    tab_count = 0
    processed_indices: Set[int] = set()

    for d in dets:
        if d["index"] in processed_indices:
            continue
        name = d["name"].lower()
        if name == "figure":
            elem_type = "figure"
            path_template = fig_dir / f"page_{page_number}_fig_{fig_count}.png"
            fig_count += 1
            caption_label = "figure_caption"
        elif name == "table":
            elem_type = "table"
            path_template = tab_dir / f"page_{page_number}_tab_{tab_count}.png"
            tab_count += 1
            caption_label = "table_caption"
        else:
            continue

        final_box = d["bbox"]
        # Skip captions an earlier element already absorbed. Titles/texts are
        # filtered the same way through processed_indices.
        caption_segments = [
            cap for cap in collect_caption_elements(d, dets, caption_label)
            if cap["index"] not in processed_indices
        ]
        for cap in caption_segments:
            final_box = get_union_box(final_box, cap["bbox"])
            processed_indices.add(cap["index"])

        title_segments: List[Dict] = []
        text_segments: List[Dict] = []
        if elem_type == "figure":
            title_segments, text_segments = collect_title_and_text_segments(
                d, dets, processed_indices
            )
            for seg in title_segments + text_segments:
                final_box = get_union_box(final_box, seg["bbox"])
                processed_indices.add(seg["index"])

        x0, y0, x1, y1 = map(int, final_box)
        # Clamp to the page and keep at least one pixel. Truncating a very
        # thin or off-page box could otherwise yield an empty crop that may
        # fail to save and take the whole page down with it.
        x0 = min(max(x0, 0), pil_img.width - 1)
        y0 = min(max(y0, 0), pil_img.height - 1)
        x1 = min(max(x1, x0 + 1), pil_img.width)
        y1 = min(max(y1, y0 + 1), pil_img.height)
        crop = pil_img.crop((x0, y0, x1, y1))
        if crop.mode == "CMYK":
            crop = crop.convert("RGB")
        crop.save(path_template)

        info_data = {
            "type": elem_type,
            "page": page_number,
            "bbox_pixels": final_box,
            "conf": d["conf"],
            "source": d.get("source", "yolo"),
            "image_path": str(path_template.relative_to(out_dir)),
            "width": int(x1 - x0),
            "height": int(y1 - y0),
            "page_width": pil_img.width,
            "page_height": pil_img.height,
        }
        for key, segments in (
            ("captions", caption_segments),
            ("titles", title_segments),
            ("texts", text_segments),
        ):
            if segments:
                info_data[key] = [_segment_record(seg, page_number) for seg in segments]
        infos.append(info_data)
    return infos
# Pixel tolerances (in rendered-page pixels) that merge_spanning_tables uses
# to decide whether a table on one page continues on the next.
TABLE_STITCH_TOLERANCES = dict(
    x_tol=60,
    y_tol=60,
    width_tol=120,
    height_tol=120,
)
# Limits for attach_cross_page_figure_captions when it borrows a caption or
# title from the top of the following page.
CROSS_PAGE_CAPTION_THRESHOLDS = dict(
    max_top_ratio=0.35,   # top edge within this fraction of the page height...
    max_top_pixels=220,   # ...and within this many pixels (smaller limit wins)
    x_tol=120,
    width_tol=200,
    min_overlap=0.05,
)
# Thresholds for pairing a title (and the text blocks under it) with a figure.
TITLE_TEXT_ASSOCIATION = dict(
    max_title_gap=220,
    max_text_gap=160,
    min_overlap=0.2,
)
def _horizontal_overlap_ratio(box1: List[float], box2: List[float]) -> float:
    """Return the horizontal intersection-over-union of two [x0, y0, x1, y1] boxes.

    Only the x-extents are compared. Returns 0.0 when the extents do not
    intersect or the combined span is not positive.
    """
    shared = max(0.0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    if shared <= 0:
        return 0.0
    combined = max(box1[2], box2[2]) - min(box1[0], box2[0])
    return shared / combined if combined > 0 else 0.0
def _bbox_to_rect(bbox: List[float]) -> Tuple[int, int, int, int]:
    """Convert [x0, y0, x1, y1] into integer (x, y, w, h), truncating toward zero."""
    left, top, right, bottom = bbox
    width, height = right - left, bottom - top
    return int(left), int(top), int(width), int(height)
def _open_table_image(elem: Dict, out_dir: Path) -> Optional[Image.Image]:
    """Load a saved table crop (path relative to *out_dir*) as an RGB image.

    Returns None, after logging a warning, when the file is missing. The
    pixels are decoded eagerly and the file is closed before returning.
    Callers such as the table stitcher delete these files afterwards, and a
    lazily opened image would keep the handle open.
    """
    image_path = out_dir / elem["image_path"]
    if not image_path.exists():
        logger.warning(f"Missing table crop for stitching: {image_path}")
        return None
    with Image.open(image_path) as img:
        img.load()
        # convert()/copy() both return a detached in-memory image.
        return img.convert("RGB") if img.mode != "RGB" else img.copy()
def _pad_width(img: Image.Image, target_width: int) -> Image.Image:
    """Right-pad *img* with white up to *target_width*.

    Returns *img* itself when it is already at least that wide. Otherwise
    returns a new RGB canvas with *img* pasted at the top-left.
    """
    if img.width < target_width:
        padded = Image.new("RGB", (target_width, img.height), color=(255, 255, 255))
        padded.paste(img, (0, 0))
        return padded
    return img
def _pad_height(img: Image.Image, target_height: int) -> Image.Image:
    """Bottom-pad *img* with white up to *target_height*.

    Returns *img* itself when it is already at least that tall. Otherwise
    returns a new RGB canvas with *img* pasted at the top-left.
    """
    if img.height < target_height:
        padded = Image.new("RGB", (img.width, target_height), color=(255, 255, 255))
        padded.paste(img, (0, 0))
        return padded
    return img
def _append_segment_image(
    base_img: Image.Image,
    segment_img: Image.Image,
    resize_to_base: bool = False,
) -> Image.Image:
    """Return a new RGB image with *segment_img* stacked below *base_img*.

    With *resize_to_base*, the segment is first scaled (LANCZOS, aspect
    ratio kept, height at least 1px) to the base width. The narrower of the
    two is then right-padded with white so both share one width.
    """
    upper = base_img if base_img.mode == "RGB" else base_img.convert("RGB")
    lower = segment_img if segment_img.mode == "RGB" else segment_img.convert("RGB")
    if resize_to_base and lower.width > 0 and upper.width > 0:
        ratio = upper.width / lower.width
        lower = lower.resize(
            (upper.width, max(1, int(lower.height * ratio))),
            Image.Resampling.LANCZOS,
        )
    shared_width = max(upper.width, lower.width)
    upper = _pad_width(upper, shared_width)
    lower = _pad_width(lower, shared_width)
    canvas = Image.new(
        "RGB", (shared_width, upper.height + lower.height), color=(255, 255, 255)
    )
    canvas.paste(upper, (0, 0))
    canvas.paste(lower, (0, upper.height))
    return canvas
def _render_pdf_page(
    pdf_doc: pdfium.PdfDocument,
    page_index: int,
    scale: float,
    cache: Dict[int, Image.Image],
) -> Optional[Image.Image]:
    """Render a PDF page to a PIL image, memoising results in *cache*.

    Args:
        pdf_doc: Open pypdfium2 document.
        page_index: Zero-based page index.
        scale: Render scale passed to ``page.render``.
        cache: Page index -> rendered image; successful renders are stored.

    Returns:
        The rendered image, or None (after logging) if rendering failed.
    """
    if page_index in cache:
        return cache[page_index]
    page = None
    try:
        page = pdf_doc[page_index]
        pil_img = page.render(scale=scale).to_pil()
    except Exception as exc:
        logger.error(f"Failed to render page {page_index + 1} for caption stitching: {exc}")
        return None
    finally:
        # Release the page handle even when render/to_pil raised.
        if page is not None:
            page.close()
    cache[page_index] = pil_img
    return pil_img
def _crop_pdf_region(
    page_img: Optional[Image.Image], bbox: List[float]
) -> Optional[Image.Image]:
    """Crop *bbox* (truncated to ints, clamped to the page) from a rendered page.

    Returns None when *page_img* is None or the clamped region is empty.
    CMYK crops are converted to RGB.
    """
    if page_img is None:
        return None
    left, top, right, bottom = (int(v) for v in bbox)
    left = max(left, 0)
    top = max(top, 0)
    right = min(max(right, left + 1), page_img.width)
    bottom = min(max(bottom, top + 1), page_img.height)
    if right <= left or bottom <= top:
        return None
    region = page_img.crop((left, top, right, bottom))
    return region.convert("RGB") if region.mode == "CMYK" else region
def write_markdown_document(pdf_path: Path, out_dir: Path) -> Optional[Path]:
    """Extract markdown text from a PDF using PyMuPDF4LLM and write it to disk.

    The result is written to ``out_dir/<pdf stem>.md`` as UTF-8, stripped
    and terminated by a single newline.

    Returns:
        The written path, or None when pymupdf4llm is not installed, the
        conversion raises, the result is not text, or it is empty.
    """
    if pymupdf4llm is None:
        # loguru formats messages with str.format, not %-style; a "%s"
        # placeholder would be printed literally and the name dropped.
        logger.warning(
            f"Skipping markdown extraction for {pdf_path.name} "
            "because pymupdf4llm is not installed."
        )
        return None
    try:
        markdown_content = pymupdf4llm.to_markdown(str(pdf_path))
    except Exception as exc:
        logger.error(f" Failed to create markdown for {pdf_path.name}: {exc}")
        return None
    # A list result is handled defensively: keep only its string parts.
    if isinstance(markdown_content, list):
        markdown_content = "\n\n".join(
            part for part in markdown_content if isinstance(part, str)
        )
    if not isinstance(markdown_content, str):
        logger.error(
            f" Unexpected markdown output type {type(markdown_content)} for {pdf_path.name}"
        )
        return None
    markdown_content = markdown_content.strip()
    if not markdown_content:
        logger.warning(f" No textual content extracted from {pdf_path.name}")
        return None
    markdown_content += "\n"  # strip() removed any trailing newline
    md_path = out_dir / f"{pdf_path.stem}.md"
    md_path.write_text(markdown_content, encoding="utf-8")
    logger.info(f" Saved markdown to {md_path.name}")
    return md_path
def _collect_text_under_title_cross_page(
    title_det: Dict,
    sorted_dets: List[Dict],
    start_idx: int,
    page_idx: int,
    used_indices: Set[Tuple[int, int]],
    settings: Optional[Dict[str, float]] = None,
) -> List[Dict]:
    """Collect the text blocks chained directly below a title on another page.

    Walks *sorted_dets* after position *start_idx*. Detections without an
    index, or whose (page_idx, index) is in *used_indices*, are skipped. The
    walk stops at the next title or at a vertical gap above
    ``max_text_gap``. Each accepted text must overlap the title horizontally
    by ``min_overlap``.

    NOTE(review): only detections named exactly "text" are collected; confirm
    the model's label for body text.
    """
    cfg = TITLE_TEXT_ASSOCIATION if settings is None else settings
    anchor = title_det["bbox"]
    chain_bottom = anchor[3]
    gathered: List[Dict] = []
    for follower in sorted_dets[start_idx + 1:]:
        key = follower.get("index")
        if key is None or (page_idx, key) in used_indices:
            continue
        label = follower["name"]
        if label == "title":
            break
        if label != "text":
            continue
        box = follower["bbox"]
        if box[1] < anchor[1]:
            continue
        if box[1] - chain_bottom > cfg["max_text_gap"]:
            break
        if _horizontal_overlap_ratio(anchor, box) < cfg["min_overlap"]:
            continue
        gathered.append(follower)
        chain_bottom = box[3]
    return gathered
def _cross_page_segment_record(det: Dict, index: Optional[int], page_number: int) -> Dict:
    """Build the metadata entry for a detection borrowed from a following page."""
    return {
        "bbox": det["bbox"],
        "conf": det.get("conf"),
        "index": index,
        "source": det.get("source"),
        "page": page_number,
    }


def _pick_cross_page_caption(
    fig_box: List[float],
    sorted_next: List[Dict],
    next_idx: int,
    used_ids: Set[Tuple[int, int]],
    max_top: int,
) -> Optional[Tuple[Dict, int]]:
    """Return the best unused figure_caption near the top of the next page.

    A caption qualifies when its top edge is at most *max_top* and it either
    overlaps the figure horizontally by ``min_overlap`` or, failing that,
    its left edge and width are within ``x_tol`` / ``width_tol`` of the
    figure's. The lowest ``width_diff + 0.5 * x_diff`` wins. Ties keep the
    earliest candidate in *sorted_next*.
    """
    thresholds = CROSS_PAGE_CAPTION_THRESHOLDS
    best: Optional[Tuple[Dict, int]] = None
    best_score: Optional[float] = None
    for det in sorted_next:
        if det.get("name") != "figure_caption":
            continue
        det_index = det.get("index")
        if det_index is None or (next_idx, det_index) in used_ids:
            continue
        det_bbox = det.get("bbox")
        if not det_bbox or det_bbox[1] > max_top:
            continue
        x_diff = abs(fig_box[0] - det_bbox[0])
        width_diff = abs((fig_box[2] - fig_box[0]) - (det_bbox[2] - det_bbox[0]))
        weak_overlap = _horizontal_overlap_ratio(fig_box, det_bbox) < thresholds["min_overlap"]
        misaligned = x_diff > thresholds["x_tol"] or width_diff > thresholds["width_tol"]
        if weak_overlap and misaligned:
            continue
        score = width_diff + 0.5 * x_diff
        if best_score is None or score < best_score:
            best, best_score = (det, det_index), score
    return best


def _pick_cross_page_title(
    fig_box: List[float],
    sorted_next: List[Dict],
    next_idx: int,
    used_ids: Set[Tuple[int, int]],
    max_top: int,
) -> Tuple[Optional[Tuple[Dict, int]], List[Dict]]:
    """Return the first matching unused title near the next page's top, plus its text blocks.

    A title is rejected only when it both overlaps the figure too little
    and is left-aligned too far from it. Returns (None, []) if none qualifies.
    """
    for pos, det in enumerate(sorted_next):
        if det.get("name") != "title":
            continue
        det_index = det.get("index")
        if det_index is None or (next_idx, det_index) in used_ids:
            continue
        det_bbox = det.get("bbox")
        if not det_bbox or det_bbox[1] > max_top:
            continue
        overlap = _horizontal_overlap_ratio(fig_box, det_bbox)
        x_diff = abs(fig_box[0] - det_bbox[0])
        if (
            overlap < TITLE_TEXT_ASSOCIATION["min_overlap"]
            and x_diff > CROSS_PAGE_CAPTION_THRESHOLDS["x_tol"]
        ):
            continue
        texts = _collect_text_under_title_cross_page(
            det, sorted_next, pos, next_idx, used_ids
        )
        return (det, det_index), texts
    return None, []


def attach_cross_page_figure_captions(
    elements: List[Dict],
    all_dets: Sequence[Optional[List[Dict[str, Any]]]],
    pdf_bytes: bytes,
    out_dir: Path,
    scale: float,
) -> List[Dict]:
    """Stitch a caption/title/text found at the top of the next page onto the prior figure.

    For each figure element, the next page (re-rendered from *pdf_bytes* at
    *scale*) is searched near its top for an unused figure caption and an
    aligned title with its text blocks. Matches are cropped, appended below
    the saved figure image (which is overwritten), and recorded in the
    element's ``captions`` / ``titles`` / ``texts``. ``width`` /
    ``height`` / ``page_span`` are updated too. Detections already attached
    to any figure are never reused.

    Returns:
        *elements* (figure dicts are updated in place).
    """
    figures = [elem for elem in elements if elem.get("type") == "figure"]
    if not figures or not all_dets:
        return elements
    try:
        pdf_doc = pdfium.PdfDocument(pdf_bytes)
    except Exception as exc:
        logger.error(f"Unable to reopen PDF for figure caption stitching: {exc}")
        return elements

    page_cache: Dict[int, Image.Image] = {}
    used_following_ids: Set[Tuple[int, int]] = set()
    # Mark caption/title/text detections already attached to figures as used.
    for elem in figures:
        for key in ("captions", "titles", "texts"):
            for seg in elem.get(key) or []:
                idx = seg.get("index")
                page_no = seg.get("page")
                if idx is not None and page_no is not None:
                    used_following_ids.add((page_no - 1, idx))

    try:
        for elem in figures:
            page_no = elem.get("page")
            bbox = elem.get("bbox_pixels")
            if page_no is None or bbox is None:
                continue
            current_idx = page_no - 1
            next_idx = current_idx + 1
            if next_idx >= len(all_dets):
                continue
            next_dets = all_dets[next_idx]
            if not next_dets:
                continue
            page_img = _render_pdf_page(pdf_doc, next_idx, scale, page_cache)
            if page_img is None:
                continue
            max_top_allowed = min(
                CROSS_PAGE_CAPTION_THRESHOLDS["max_top_pixels"],
                int(page_img.height * CROSS_PAGE_CAPTION_THRESHOLDS["max_top_ratio"]),
            )
            sorted_next = sorted(
                (det for det in next_dets if det.get("bbox")),
                key=lambda det: det["bbox"][1],
            )
            caption_candidate = _pick_cross_page_caption(
                bbox, sorted_next, next_idx, used_following_ids, max_top_allowed
            )
            title_candidate, title_texts = _pick_cross_page_title(
                bbox, sorted_next, next_idx, used_following_ids, max_top_allowed
            )
            if not caption_candidate and not title_candidate and not title_texts:
                continue

            figure_path = out_dir / elem["image_path"]
            if not figure_path.exists():
                continue
            # Decode fully and close the file: it is overwritten below.
            with Image.open(figure_path) as src:
                src.load()
                figure_img = src.convert("RGB") if src.mode == "CMYK" else src.copy()

            next_page_number = next_idx + 1
            segments_added = False
            if caption_candidate:
                cap_det, cap_index = caption_candidate
                caption_crop = _crop_pdf_region(page_img, cap_det["bbox"])
                if caption_crop is not None:
                    figure_img = _append_segment_image(
                        figure_img, caption_crop, resize_to_base=True
                    )
                    elem.setdefault("captions", []).append(
                        _cross_page_segment_record(cap_det, cap_index, next_page_number)
                    )
                    used_following_ids.add((next_idx, cap_index))
                    segments_added = True
            if title_candidate:
                title_det, title_index = title_candidate
                title_crop = _crop_pdf_region(page_img, title_det["bbox"])
                if title_crop is not None:
                    figure_img = _append_segment_image(figure_img, title_crop)
                    elem.setdefault("titles", []).append(
                        _cross_page_segment_record(title_det, title_index, next_page_number)
                    )
                    used_following_ids.add((next_idx, title_index))
                    segments_added = True
            for text_det in title_texts:
                text_index = text_det.get("index")
                text_crop = _crop_pdf_region(page_img, text_det["bbox"])
                if text_crop is None:
                    continue
                figure_img = _append_segment_image(figure_img, text_crop)
                elem.setdefault("texts", []).append(
                    _cross_page_segment_record(text_det, text_index, next_page_number)
                )
                if text_index is not None:
                    used_following_ids.add((next_idx, text_index))
                segments_added = True

            if not segments_added:
                continue
            figure_img.save(figure_path)
            elem["width"] = figure_img.width
            elem["height"] = figure_img.height
            span = elem.get("page_span")
            if span:
                if next_page_number not in span:
                    span.append(next_page_number)
            else:
                elem["page_span"] = [
                    page for page in (elem.get("page"), next_page_number) if page is not None
                ]
    finally:
        # Close the reopened document even if an image operation raised.
        pdf_doc.close()
    return elements
def _stitch_table_pair(
    base_elem: Dict,
    candidate_elem: Dict,
    out_dir: Path,
    merge_index: int,
    stitch_type: str,
) -> Optional[Dict]:
    """Stitch two table crops into one image and return the merged element.

    ``stitch_type == "vertical"`` stacks the candidate below the base, with
    widths padded in white. Any other value places it to the right, with
    heights padded. The result is saved as
    ``tables/page_<a>_to_<b>_table_merged_<merge_index>.png`` and both
    partial crops are deleted.

    Returns:
        A copy of *base_elem* describing the merged image, or None when
        either crop file is missing.
    """
    base_src = _open_table_image(base_elem, out_dir)
    cand_src = _open_table_image(candidate_elem, out_dir)
    if base_src is None or cand_src is None:
        for img in (base_src, cand_src):
            if img is not None:
                img.close()
        return None

    try:
        tables_dir = out_dir / "tables"
        tables_dir.mkdir(parents=True, exist_ok=True)
        if stitch_type == "vertical":
            target_width = max(base_src.width, cand_src.width)
            first = _pad_width(base_src, target_width)
            second = _pad_width(cand_src, target_width)
            stitched = Image.new(
                "RGB", (target_width, first.height + second.height), color=(255, 255, 255)
            )
            stitched.paste(first, (0, 0))
            stitched.paste(second, (0, first.height))
        else:
            target_height = max(base_src.height, cand_src.height)
            first = _pad_height(base_src, target_height)
            second = _pad_height(cand_src, target_height)
            stitched = Image.new(
                "RGB", (first.width + second.width, target_height), color=(255, 255, 255)
            )
            stitched.paste(first, (0, 0))
            stitched.paste(second, (first.width, 0))
    finally:
        # Release any file handles before the source crops are unlinked.
        base_src.close()
        cand_src.close()

    merged_name = (
        f"page_{base_elem['page']}_to_{candidate_elem['page']}_"
        f"table_merged_{merge_index}.png"
    )
    merged_path = tables_dir / merged_name
    stitched.save(merged_path)
    # Remove original partial crops to avoid duplicates.
    (out_dir / base_elem["image_path"]).unlink(missing_ok=True)
    (out_dir / candidate_elem["image_path"]).unlink(missing_ok=True)

    base_box = base_elem["bbox_pixels"]
    cand_box = candidate_elem["bbox_pixels"]
    merged_elem = base_elem.copy()
    merged_elem["page_span"] = [base_elem["page"], candidate_elem["page"]]
    # Paths of the pre-merge crops (these files were just deleted).
    merged_elem["box_refs"] = [
        {"page": base_elem["page"], "image_path": base_elem["image_path"]},
        {"page": candidate_elem["page"], "image_path": candidate_elem["image_path"]},
    ]
    # Union of boxes from two different pages, kept for reference only.
    merged_elem["bbox_pixels"] = [
        min(base_box[0], cand_box[0]),
        min(base_box[1], cand_box[1]),
        max(base_box[2], cand_box[2]),
        max(base_box[3], cand_box[3]),
    ]
    merged_elem["image_path"] = str(merged_path.relative_to(out_dir))
    merged_elem["width"] = stitched.width
    merged_elem["height"] = stitched.height
    # NOTE(review): records the stitched image height, not a page height;
    # confirm downstream consumers expect this.
    merged_elem["page_height"] = stitched.height
    merged_elem["conf"] = min(
        base_elem.get("conf", 1.0), candidate_elem.get("conf", 1.0)
    )
    # Keep caption metadata from both halves. A plain copy of base_elem
    # would drop the candidate's captions and alias base's list.
    captions = list(base_elem.get("captions") or []) + list(
        candidate_elem.get("captions") or []
    )
    if captions:
        merged_elem["captions"] = captions
    return merged_elem
| def merge_spanning_tables(elements: List[Dict], out_dir: Path) -> List[Dict]: | |
| """ | |
| Stitch table crops that continue across adjacent pages using the heuristic | |
| from the legacy OpenCV-based extractor. | |
| """ | |
| if not elements: | |
| return elements | |
| tables_by_page: Dict[int, List[Dict]] = {} | |
| non_tables: List[Dict] = [] | |
| for elem in elements: | |
| if elem.get("type") != "table": | |
| non_tables.append(elem) | |
| continue | |
| page = elem.get("page") | |
| if not isinstance(page, int): | |
| non_tables.append(elem) | |
| continue | |
| tables_by_page.setdefault(page, []).append(elem) | |
| merged_results: List[Dict] = [] | |
| used_next: Dict[int, set[int]] = {} | |
| merge_counter = 0 | |
| for page in sorted(tables_by_page.keys()): | |
| current_tables = tables_by_page.get(page, []) | |
| next_page_tables = tables_by_page.get(page + 1, []) | |
| next_used_indices = used_next.get(page + 1, set()) | |
| current_used_indices = used_next.get(page, set()) | |
| for idx_current, table_elem in enumerate(current_tables): | |
| if idx_current in current_used_indices: | |
| continue | |
| if not next_page_tables: | |
| merged_results.append(table_elem) | |
| continue | |
| x, y, w, h = _bbox_to_rect(table_elem["bbox_pixels"]) | |
| matched = False | |
| for idx, candidate in enumerate(next_page_tables): | |
| if idx in next_used_indices: | |
| continue | |
| if candidate.get("type") != "table": | |
| continue | |
| cx, cy, cw, ch = _bbox_to_rect(candidate["bbox_pixels"]) | |
| vertical_match = ( | |
| abs(x - cx) <= TABLE_STITCH_TOLERANCES["x_tol"] | |
| and abs((x + w) - (cx + cw)) <= TABLE_STITCH_TOLERANCES["width_tol"] | |
| ) | |
| horizontal_match = ( | |
| abs(y - cy) <= TABLE_STITCH_TOLERANCES["y_tol"] | |
| and abs((y + h) - (cy + ch)) | |
| <= TABLE_STITCH_TOLERANCES["height_tol"] | |
| ) | |
| stitch_type = "vertical" if vertical_match else None | |
| if not stitch_type and horizontal_match: | |
| stitch_type = "horizontal" | |
| if not stitch_type: | |
| continue | |
| merge_counter += 1 | |
| merged_elem = _stitch_table_pair( | |
| table_elem, candidate, out_dir, merge_counter, stitch_type | |
| ) | |
| if merged_elem is None: | |
| continue | |
| merged_results.append(merged_elem) | |
| next_used_indices.add(idx) | |
| matched = True | |
| break | |
| if not matched: | |
| merged_results.append(table_elem) | |
| used_next[page + 1] = next_used_indices | |
| merged_results.extend(non_tables) | |
| return merged_results | |
# ----------------------------------------------------------------------
# Draw layout boxes on the original PDF
# ----------------------------------------------------------------------
def draw_layout_pdf(pdf_bytes: bytes, all_dets: List[List[dict]],
                    scale: float, out_path: Path):
    """Annotate the PDF with semi-transparent boxes and labels and save it to *out_path*.

    Detection boxes are in rendered-pixel units and are divided by *scale*
    to map back to PDF points. Each box is filled at 15% opacity in its
    CLASS_COLORS color (black if unknown). It is labelled
    "<name> <conf>" plus "[S]" (the first letter of ``source``) on a
    translucent white backdrop. A None entry in *all_dets* is treated as a
    page without detections.
    """
    fill_opacity = 0.15
    border_width = 1.5
    label_fontsize = 6.5
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        for page_no, dets in enumerate(all_dets):
            page = doc[page_no]
            for d in dets or []:
                rgb = CLASS_COLORS.get(d["name"], (0, 0, 0))
                color = [c / 255 for c in rgb]
                rect = fitz.Rect([c / scale for c in d["bbox"]])
                page.draw_rect(
                    rect,
                    color=color,
                    fill=color,
                    width=border_width,
                    overlay=True,
                    fill_opacity=fill_opacity,
                )
                label = f"{d['name']} {d['conf']:.2f}"
                if d.get("source"):
                    label += f" [{d['source'][0].upper()}]"
                # Size the backdrop to the label instead of a fixed 60pt.
                label_width = fitz.get_text_length(label, fontsize=label_fontsize) + 4
                text_bg = fitz.Rect(rect.x0, rect.y0 - 10, rect.x0 + label_width, rect.y0)
                # A 4-component color is read as CMYK by PyMuPDF, so the old
                # (1, 1, 1, 0.6) painted a near-black box. Use RGB white plus
                # an explicit fill opacity.
                page.draw_rect(
                    text_bg, color=None, fill=(1, 1, 1), fill_opacity=0.6, overlay=True
                )
                # insert_text positions the baseline, so place it near the
                # bottom of the 10pt backdrop for the glyphs to sit inside it.
                page.insert_text(
                    (rect.x0 + 2, rect.y0 - 2.5),
                    label,
                    fontsize=label_fontsize,
                    color=color,
                    overlay=True,
                )
        doc.save(str(out_path))
    finally:
        doc.close()
# ----------------------------------------------------------------------
# Process a single PDF Page (for parallel execution)
# ----------------------------------------------------------------------
def process_page(task_data: Tuple[int, bytes, float, Path, str]) -> Optional[Tuple[int, List[dict], List[dict]]]:
    """Process a single page of a PDF in a worker process.

    Args:
        task_data: (zero-based page index, PDF bytes, render scale, output
            directory, PDF display name).

    Returns:
        (page_number, detections, elements), or None when shutdown was
        requested or processing failed (the failure is logged).
    """
    pno, pdf_bytes, scale, out_dir, pdf_name = task_data
    if _shutdown_requested:
        return None
    pdf_pdfium = None
    page = None
    try:
        pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
        page = pdf_pdfium[pno]
        pil = page.render(scale=scale).to_pil()
        dets = detect_page(pil)
        elements = save_layout_elements(pil, pno, dets, out_dir)
        page_figures = sum(1 for d in dets if d["name"] == "figure")
        page_tables = sum(1 for d in dets if d["name"] == "table")
        logger.info(f" [{pdf_name}] Page {pno + 1}: {page_figures} figs, {page_tables} tables")
        return (pno, dets, elements)
    except Exception as e:
        logger.error(f"Failed to process page {pno + 1} of {pdf_name}: {e}")
        return None
    finally:
        # Close on every path. Use `is not None`: PdfDocument defines
        # __len__, so a zero-page document is falsy and would be skipped.
        if page is not None:
            page.close()
        if pdf_pdfium is not None:
            pdf_pdfium.close()
# ----------------------------------------------------------------------
# Process a full PDF using the persistent worker pool
# ----------------------------------------------------------------------
def process_pdf_with_pool(
    pdf_path: Path,
    out_dir: Path,
    pool: Optional[Pool] = None,
    *,
    extract_images: bool = True,
    extract_markdown: bool = True,
):
    """Run the main processing pipeline for one PDF file.

    Pages go through ``process_page`` on ``pool`` when a pool is given and
    ``USE_MULTIPROCESSING`` is enabled. Otherwise they are rendered and
    detected serially in this process.

    Args:
        pdf_path: PDF file to process.
        out_dir: Directory that receives the extracted elements,
            ``<stem>_content_list.json`` and the ``<stem>_layout.pdf`` preview.
        pool: Optional persistent worker pool.
        extract_images: Run layout detection / element extraction when True.
        extract_markdown: Write a markdown rendition via
            ``write_markdown_document`` when True.

    Returns:
        None. Returns early when a shutdown was requested, the PDF cannot be
        opened, or the serial path hits a fatal error.

    Raises:
        KeyboardInterrupt: Re-raised if interrupted during ``pool.map``.
    """
    if _shutdown_requested:
        logger.warning(f"Skipping {pdf_path.name} due to shutdown request")
        return
    stem = pdf_path.stem
    logger.info(f"Processing {pdf_path.name}")
    pdf_bytes = pdf_path.read_bytes()

    # Open once only to learn the page count (and to reject unreadable PDFs).
    doc = None
    try:
        doc = pdfium.PdfDocument(pdf_bytes)
        page_count = len(doc)
    except Exception as e:
        logger.error(f"Failed to open PDF {pdf_path.name}: {e}. Skipping.")
        return
    finally:
        if doc is not None:
            doc.close()

    scale = 2.0
    all_elements: List[Dict] = []
    if extract_images:
        # Indexed by page number; None marks a page that failed or was skipped.
        all_dets: List[Optional[List[dict]]] = [None] * page_count
        if pool is not None and USE_MULTIPROCESSING:
            logger.info(f" Using worker pool for {page_count} pages...")
            tasks = [
                (pno, pdf_bytes, scale, out_dir, pdf_path.name)
                for pno in range(page_count)
            ]
            try:
                results = pool.map(process_page, tasks)
                for res in results:
                    if res:
                        pno, dets, elements = res
                        all_dets[pno] = dets
                        all_elements.extend(elements)
            except KeyboardInterrupt:
                logger.warning("Processing interrupted during parallel execution")
                raise
        else:
            logger.info("Using serial processing...")
            pdf_pdfium = None
            try:
                pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
                for pno in range(page_count):
                    if _shutdown_requested:
                        logger.warning(
                            f"Stopping at page {pno + 1}/{page_count} due to shutdown request"
                        )
                        break
                    page = None
                    try:
                        logger.info(f" Processing page {pno + 1}/{page_count}")
                        page = pdf_pdfium[pno]
                        bitmap = page.render(scale=scale)
                        pil = bitmap.to_pil()
                        dets = detect_page(pil)
                        all_dets[pno] = dets
                        elements = save_layout_elements(pil, pno, dets, out_dir)
                        all_elements.extend(elements)
                        page_figures = sum(1 for d in dets if d["name"] == "figure")
                        page_tables = sum(1 for d in dets if d["name"] == "table")
                        logger.info(
                            f" Found {page_figures} figures and {page_tables} tables"
                        )
                    except Exception as e:
                        logger.error(f"Failed to process page {pno + 1}: {e}. Skipping page.")
                    finally:
                        # Close the page even when a step above failed.
                        if page is not None:
                            page.close()
            except Exception as e:
                logger.error(f"Fatal error processing {pdf_path.name}: {e}")
                return
            finally:
                # `is not None`: PdfDocument defines __len__, so a zero-page
                # document would be falsy and never closed.
                if pdf_pdfium is not None:
                    pdf_pdfium.close()

        dets_per_page: List[Optional[List[Dict[str, Any]]]] = list(all_dets)
        if all_elements:
            all_elements = merge_spanning_tables(all_elements, out_dir)
            all_elements = attach_cross_page_figure_captions(
                all_elements, dets_per_page, pdf_bytes, out_dir, scale
            )
        if all_elements:
            content_list_path = out_dir / f"{stem}_content_list.json"
            with open(content_list_path, "w", encoding="utf-8") as f:
                json.dump(all_elements, f, ensure_ascii=False, indent=4)
            logger.info(f" Saved {len(all_elements)} elements to JSON")
        if any(d is not None for d in all_dets):
            # draw_layout_pdf uses the list index as the page number, so keep
            # one entry per page (empty for failed pages). Dropping the None
            # entries would shift later pages' boxes onto the wrong pages.
            page_aligned_dets = [d if d is not None else [] for d in all_dets]
            draw_layout_pdf(
                pdf_bytes, page_aligned_dets, scale, out_dir / f"{stem}_layout.pdf"
            )
            logger.info(" Generated annotated PDF")
        else:
            logger.warning(f"No detections found for {stem}. Skipping layout PDF.")
    else:
        logger.info(" Image extraction skipped per configuration.")

    markdown_path = None
    if extract_markdown:
        markdown_path = write_markdown_document(pdf_path, out_dir)
        if markdown_path is None:
            logger.warning(f" Markdown extraction yielded no content for {stem}.")

    if _shutdown_requested:
        logger.warning(f"β οΈ Partial results saved for {stem} β {out_dir}")
    elif extract_images:
        logger.success(
            f"β {stem} β {out_dir} ({len(all_elements)} elements extracted)"
        )
    else:
        logger.success(f"β {stem} β {out_dir} (image extraction skipped)")
# ----------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Important for multiprocessing on Windows/macOS
    torch.multiprocessing.set_start_method('spawn', force=True)
    # Setup signal handlers for graceful shutdown
    setup_signal_handlers()

    INPUT_DIR = Path("./pdfs")
    OUTPUT_DIR = Path("./output")
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    pdf_files = list(INPUT_DIR.glob("*.pdf"))
    if not pdf_files:
        logger.warning("No PDF files found in ./pdfs")
        logger.info("Please add PDF files to the ./pdfs directory")
        logger.info("The script will exit gracefully. No errors occurred.")
        sys.exit(0)

    logger.info(f"Found {len(pdf_files)} PDF file(s) to process")
    logger.info(f"Settings: MODEL_SIZE={MODEL_SIZE}, CONF={CONF_THRESHOLD}")

    # Determine worker count
    total_cpus = cpu_count()
    if NUM_WORKERS is None:
        num_workers = max(1, total_cpus - 1)
    else:
        num_workers = max(1, min(NUM_WORKERS, total_cpus))

    # Decide whether to use multiprocessing (CPU only, enough cores)
    use_pool = USE_MULTIPROCESSING and DEVICE == "cpu" and total_cpus >= 4
    if use_pool:
        logger.info(f"π Creating persistent worker pool with {num_workers} workers...")
    elif not USE_MULTIPROCESSING:
        logger.info("Multiprocessing disabled by configuration")
    elif DEVICE != "cpu":
        logger.info(f"Using serial GPU processing (device: {DEVICE})")
    else:
        logger.info(f"Using serial CPU processing (CPU count {total_cpus} too low)")

    pool = None
    # Set when processing stops early (Ctrl+C or fatal error). It picks
    # pool.terminate() over close() and the final log message.
    interrupted = False
    try:
        # Create persistent pool ONCE for all PDFs
        if use_pool:
            pool = Pool(processes=num_workers, initializer=init_worker)
            logger.success(f"β Worker pool ready with {num_workers} workers\n")
        else:
            # Load model in main process for serial execution
            logger.info("Initializing model in main process...")
            get_model()
            logger.success(f"β Model loaded (device: {DEVICE})\n")

        # Process all PDFs using the same pool
        for i, pdf_path in enumerate(pdf_files, 1):
            if _shutdown_requested:
                logger.warning(f"\nShutdown requested. Processed {i-1}/{len(pdf_files)} files.")
                break
            logger.info(f"\n{'='*60}")
            logger.info(f"π File {i}/{len(pdf_files)}: {pdf_path.name}")
            logger.info(f"{'='*60}")
            sub_out = OUTPUT_DIR / pdf_path.stem
            os.makedirs(sub_out, exist_ok=True)
            try:
                process_pdf_with_pool(pdf_path, sub_out, pool)
            except KeyboardInterrupt:
                logger.warning(f"\nInterrupted while processing {pdf_path.name}")
                interrupted = True
                break
            except Exception as e:
                logger.error(f"Error processing {pdf_path.name}: {e}")
                if _shutdown_requested:
                    break
                logger.info("Continuing with next file...")
                continue

        # An in-loop KeyboardInterrupt must not be reported as "All done".
        if _shutdown_requested or interrupted:
            logger.warning(f"\nβ οΈ Processing interrupted. Partial results saved in {OUTPUT_DIR}")
        else:
            logger.success(f"\n⨠All done! Results are in {OUTPUT_DIR}")
    except KeyboardInterrupt:
        interrupted = True
        logger.error("\nβ Processing interrupted by user")
        sys.exit(1)
    except Exception as e:
        interrupted = True
        logger.error(f"\nβ Fatal error: {e}")
        sys.exit(1)
    finally:
        # Clean up pool if it exists
        if pool is not None:
            logger.info("\nπ§Ή Shutting down worker pool...")
            if interrupted or _shutdown_requested:
                # close() would let workers drain every still-queued task
                # before join() returns; stop them immediately instead.
                pool.terminate()
                pool.join()
                logger.warning("Worker pool terminated")
            else:
                pool.close()
                pool.join()
                logger.success("β Worker pool closed cleanly")