293 lines
9.1 KiB
Python
293 lines
9.1 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import fitz
|
|
import pytesseract
|
|
from pdf2image import convert_from_path
|
|
from PIL import Image
|
|
|
|
|
|
@dataclass
class LayoutOCRResult:
    """Container for the output of a layout OCR run.

    Holds the OCR engine identity together with the per-page records and
    knows how to serialise itself into the analysis JSON envelope.
    """

    # Identifier of the OCR engine that produced the pages.
    engine_name: str
    # Version string of that engine.
    engine_version: str
    # One dictionary per processed page.
    pages: list[dict[str, Any]]

    def to_analysis_json(self) -> dict[str, Any]:
        """Return the result wrapped in the schema-version-1 analysis envelope.

        The ``pages`` list is referenced, not copied.
        """
        engine_info = {
            "name": self.engine_name,
            "version": self.engine_version,
        }
        payload: dict[str, Any] = {
            "schema_version": 1,
            "analysis_type": "canonical",
        }
        payload["engine"] = engine_info
        payload["pages"] = self.pages
        return payload
|
|
|
|
|
|
|
|
|
|
def _safe_float(value, default=0.0):
|
|
try:
|
|
return float(value)
|
|
except Exception:
|
|
return float(default)
|
|
|
|
|
|
def _bbox_union(items: list[dict[str, Any]]) -> list[float]:
    """Return the smallest ``[x1, y1, x2, y2]`` box enclosing every item's bbox.

    Args:
        items: Dictionaries that may carry a ``"bbox"`` entry holding a
            4-element list or tuple.

    Returns:
        ``[min x1, min y1, max x2, max y2]`` over all usable boxes, or
        ``[0.0, 0.0, 0.0, 0.0]`` when no item has a usable box.

    Items whose bbox is missing, empty, not a list/tuple, or not exactly four
    elements long are ignored. Missing boxes used to be substituted with an
    all-zero box, which pulled the union towards the origin even though
    malformed boxes were already being skipped.
    """
    boxes: list[tuple[float, ...]] = []
    for item in items:
        bbox = item.get("bbox")
        if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
            continue
        # Coordinates that cannot be parsed fall back to 0.0 via _safe_float.
        boxes.append(tuple(_safe_float(coord) for coord in bbox))

    if not boxes:
        return [0.0, 0.0, 0.0, 0.0]

    lefts, tops, rights, bottoms = zip(*boxes)
    return [min(lefts), min(tops), max(rights), max(bottoms)]
|
|
|
|
|
|
def _word_center_x(word: dict[str, Any]) -> float:
    """Return the horizontal midpoint of ``word["bbox"]``.

    The midpoint is the mean of bbox elements 0 and 2. A missing or empty
    bbox yields ``0.0``. A bbox that is too short or not subscriptable also
    yields ``0.0`` instead of raising, so the function stays usable as a
    sort key on imperfect data.
    """
    bbox = word.get("bbox") or (0.0, 0.0, 0.0, 0.0)
    try:
        left, right = bbox[0], bbox[2]
    except (IndexError, KeyError, TypeError):
        # Malformed box: use the same value a missing box produces.
        return 0.0
    return (_safe_float(left) + _safe_float(right)) / 2.0
|
|
|
|
|
|
def _word_center_y(word: dict[str, Any]) -> float:
    """Return the vertical midpoint of ``word["bbox"]``.

    The midpoint is the mean of bbox elements 1 and 3. A missing or empty
    bbox yields ``0.0``. A bbox that is too short or not subscriptable also
    yields ``0.0`` instead of raising, so the function stays usable as a
    sort key on imperfect data.
    """
    bbox = word.get("bbox") or (0.0, 0.0, 0.0, 0.0)
    try:
        top, bottom = bbox[1], bbox[3]
    except (IndexError, KeyError, TypeError):
        # Malformed box: use the same value a missing box produces.
        return 0.0
    return (_safe_float(top) + _safe_float(bottom)) / 2.0
|
|
|
|
|
|
def _group_words_into_lines_local(words: list[dict[str, Any]], y_tol: float = 12.0) -> list[dict[str, Any]]:
    """Cluster words into text lines by vertical center and describe each line.

    Words are visited in (center-y, left-edge) order. Each word joins the
    first existing cluster whose mean center-y is within ``y_tol`` of its own
    center-y, otherwise it starts a new cluster.

    Args:
        words: Word dictionaries; ``"bbox"``, ``"text"`` and ``"word_id"``
            are read from each.
        y_tol: Maximum distance between a word's center-y and a cluster's
            mean center-y for the word to join that cluster.

    Returns:
        One dictionary per cluster that has non-blank text. ``line_id`` is the
        1-based cluster index, so skipped clusters leave gaps in the ids.
    """
    if not words:
        return []

    def _left_edge(entry: dict[str, Any]) -> float:
        return _safe_float((entry.get("bbox") or [0, 0, 0, 0])[0])

    def _height(entry: dict[str, Any]) -> float:
        return _safe_float((entry.get("bbox") or [0, 0, 0, 0])[3]) - _safe_float(
            (entry.get("bbox") or [0, 0, 0, 0])[1]
        )

    clusters: list[list[dict[str, Any]]] = []
    # Center-y values of each cluster's members, kept parallel to `clusters`
    # so the running mean does not need to re-derive them from the bboxes.
    cluster_centers: list[list[float]] = []

    reading_order = sorted(words, key=lambda entry: (_word_center_y(entry), _left_edge(entry)))
    for candidate in reading_order:
        center = _word_center_y(candidate)
        for members, centers in zip(clusters, cluster_centers):
            if abs(center - sum(centers) / len(centers)) <= y_tol:
                members.append(candidate)
                centers.append(center)
                break
        else:
            # No cluster was close enough: start a new one.
            clusters.append([candidate])
            cluster_centers.append([center])

    result: list[dict[str, Any]] = []
    for line_no, members in enumerate(clusters, start=1):
        row = sorted(members, key=_left_edge)
        tokens = [(entry.get("text") or "").strip() for entry in row]
        joined = " ".join(token for token in tokens if token).strip()
        if not joined:
            continue

        line_box = _bbox_union(row)
        mean_height = max(1.0, sum(_height(entry) for entry in row) / len(row))
        word_ids = [entry.get("word_id") for entry in row if entry.get("word_id") is not None]

        result.append(
            {
                "line_id": line_no,
                "text": joined,
                "bbox": line_box,
                "confidence": None,
                "font_family_guess": "Helvetica",
                "font_size_guess": max(6.0, mean_height * 0.75),
                "text_color_guess": "#000000",
                "word_ids": word_ids,
                "words": row,
            }
        )
    return result
|
|
|
|
|
|
def _build_regions_from_words(words: list[dict[str, Any]], page_w: float) -> list[dict[str, Any]]:
    """Split a page's words into one or two column-like regions.

    The widest horizontal gap between consecutive word centers is taken as a
    candidate column boundary. It is used only when it is at least
    ``max(80.0, page_w * 0.10)`` wide; otherwise all words form one region.

    Args:
        words: Word dictionaries. Entries with blank text, or whose ``"bbox"``
            is not a 4-element list/tuple, are ignored.
        page_w: Page width, used to scale the minimum gap.

    Returns:
        Region dictionaries (``region_id``, ``bbox``, ``words``, ``lines``)
        ordered by the left edge of their bounding box, or ``[]`` when no
        usable words remain.
    """
    usable: list[dict[str, Any]] = []
    for entry in words:
        box = entry.get("bbox")
        if (entry.get("text") or "").strip() and isinstance(box, (list, tuple)) and len(box) == 4:
            usable.append(entry)
    if not usable:
        return []

    centers = sorted(_word_center_x(entry) for entry in usable)

    # Find the first occurrence of the widest gap between neighbouring centers.
    widest_gap = 0.0
    cut = None
    for pos, (lo, hi) in enumerate(zip(centers, centers[1:])):
        distance = hi - lo
        if distance > widest_gap:
            widest_gap, cut = distance, pos

    threshold = max(80.0, page_w * 0.10)

    if cut is None or widest_gap < threshold:
        columns = [usable]
    else:
        boundary = (centers[cut] + centers[cut + 1]) / 2.0
        left_side = [entry for entry in usable if _word_center_x(entry) <= boundary]
        right_side = [entry for entry in usable if _word_center_x(entry) > boundary]
        columns = [column for column in (left_side, right_side) if column]
        columns.sort(key=lambda column: _bbox_union(column)[0])

    regions: list[dict[str, Any]] = []
    for region_no, column in enumerate(columns, start=1):
        # Top-to-bottom, then left-to-right within the region.
        members = sorted(column, key=lambda entry: (_word_center_y(entry), _word_center_x(entry)))
        regions.append(
            {
                "region_id": region_no,
                "bbox": _bbox_union(members),
                "words": members,
                "lines": _group_words_into_lines_local(members),
            }
        )
    return regions
|
|
|
|
def _group_words_into_lines(words: list[dict[str, Any]], y_tol: float = 12.0) -> list[dict[str, Any]]:
|
|
if not words:
|
|
return []
|
|
|
|
words = sorted(words, key=lambda w: (w["bbox"][1], w["bbox"][0]))
|
|
groups: list[list[dict[str, Any]]] = []
|
|
|
|
for word in words:
|
|
placed = False
|
|
wy = word["bbox"][1]
|
|
for group in groups:
|
|
gy = sum(item["bbox"][1] for item in group) / len(group)
|
|
if abs(wy - gy) <= y_tol:
|
|
group.append(word)
|
|
placed = True
|
|
break
|
|
if not placed:
|
|
groups.append([word])
|
|
|
|
lines: list[dict[str, Any]] = []
|
|
for group in groups:
|
|
group = sorted(group, key=lambda w: w["bbox"][0])
|
|
text = " ".join((w.get("text") or "").strip() for w in group).strip()
|
|
if not text:
|
|
continue
|
|
left = min(w["bbox"][0] for w in group)
|
|
top = min(w["bbox"][1] for w in group)
|
|
right = max(w["bbox"][2] for w in group)
|
|
bottom = max(w["bbox"][3] for w in group)
|
|
avg_height = max(1.0, sum((w["bbox"][3] - w["bbox"][1]) for w in group) / len(group))
|
|
lines.append(
|
|
{
|
|
"text": text,
|
|
"bbox": [left, top, right, bottom],
|
|
"confidence": None,
|
|
"font_family_guess": "Helvetica",
|
|
"font_size_guess": max(6.0, avg_height * 0.75),
|
|
"text_color_guess": "#000000",
|
|
"words": group,
|
|
}
|
|
)
|
|
return lines
|
|
|
|
|
|
def _words_from_tesseract_data(
    data: dict[str, Any], scale_x: float, scale_y: float
) -> list[dict[str, Any]]:
    """Convert a ``pytesseract.image_to_data`` dict into word records.

    Args:
        data: Column-oriented dict with parallel ``"text"``, ``"conf"``,
            ``"left"``, ``"top"``, ``"width"`` and ``"height"`` lists.
        scale_x: Factor applied to horizontal pixel coordinates.
        scale_y: Factor applied to vertical pixel coordinates.

    Returns:
        One dictionary per entry with non-blank text and a positive pixel
        width and height. ``word_id`` is 1-based and counts kept words only.
    """
    words: list[dict[str, Any]] = []
    for i, raw_text in enumerate(data.get("text", [])):
        text = (raw_text or "").strip()
        if not text:
            continue

        conf: float | None
        try:
            conf = float(data["conf"][i])
        except (KeyError, IndexError, TypeError, ValueError):
            # A missing or unparsable confidence is recorded as unknown.
            conf = None

        left_px = float(data["left"][i])
        top_px = float(data["top"][i])
        width_px = float(data["width"][i])
        height_px = float(data["height"][i])
        if width_px <= 0 or height_px <= 0:
            continue

        words.append(
            {
                "word_id": len(words) + 1,
                "text": text,
                "bbox": [
                    left_px * scale_x,
                    top_px * scale_y,
                    (left_px + width_px) * scale_x,
                    (top_px + height_px) * scale_y,
                ],
                "confidence": conf,
            }
        )
    return words


def run_layout_ocr(pdf_path: str | Path, dpi: int = 300) -> LayoutOCRResult:
    """Run Tesseract over every page of a PDF and collect layout information.

    Each page is rendered to an image at ``dpi``, OCR'd with
    ``--oem 3 --psm 6``, and the recognised word boxes are rescaled from
    image pixels into the coordinate space of the PDF page rectangle.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Resolution used when rendering pages to images.

    Returns:
        A ``LayoutOCRResult`` with one entry per page containing the page
        size plus its words, lines and regions.

    Raises:
        FileNotFoundError: If ``pdf_path`` does not exist.
        ValueError: If a rendered page is not a PIL image.
    """
    pdf_path = Path(pdf_path)
    if not pdf_path.exists():
        raise FileNotFoundError(f"PDF not found: {pdf_path}")

    pages: list[dict[str, Any]] = []

    doc = fitz.open(pdf_path)
    try:
        pil_pages = convert_from_path(str(pdf_path), dpi=dpi)

        # zip() stops at the shorter sequence, so a page-count mismatch
        # between the two renderers silently drops the trailing pages.
        for idx, (pdf_page, pil_img) in enumerate(zip(doc, pil_pages), start=1):
            page_w = float(pdf_page.rect.width)
            page_h = float(pdf_page.rect.height)

            if not isinstance(pil_img, Image.Image):
                raise ValueError(f"Rendered page {idx} is not a PIL image")

            img_w, img_h = pil_img.size
            # Pixel -> page-space conversion factors.
            scale_x = page_w / float(img_w)
            scale_y = page_h / float(img_h)

            data = pytesseract.image_to_data(
                pil_img,
                output_type=pytesseract.Output.DICT,
                config="--oem 3 --psm 6",
            )
            words = _words_from_tesseract_data(data, scale_x, scale_y)

            pages.append(
                {
                    "page": idx,
                    "page_width": page_w,
                    "page_height": page_h,
                    # NOTE(review): image_* repeat the page dimensions rather
                    # than the rendered pixel size; bboxes are already in page
                    # space. Confirm this is what consumers expect.
                    "image_width": page_w,
                    "image_height": page_h,
                    "lines": _group_words_into_lines(words),
                    "words": words,
                    "regions": _build_regions_from_words(words, page_w),
                }
            )
    finally:
        # Release the PyMuPDF document even when rendering or OCR fails.
        doc.close()

    return LayoutOCRResult(
        engine_name="tesseract_layout",
        engine_version=str(pytesseract.get_tesseract_version()),
        pages=pages,
    )
|