Classify and crop unmatched vision regions

This commit is contained in:
Sean McElwain 2026-05-30 19:31:10 -05:00
parent f3e61e877b
commit 05c9b6964a
1 changed files with 129 additions and 0 deletions

View File

@ -423,6 +423,129 @@ def score_vision_regions_against_layout(
},
}
def _classify_region_geometry(region: dict[str, Any], *, page_width: float, page_height: float) -> dict[str, Any]:
bbox = region.get("bbox") or [0, 0, 0, 0]
x1, y1, x2, y2 = [float(v) for v in bbox[:4]]
w = max(0.0, x2 - x1)
h = max(0.0, y2 - y1)
area = w * h
page_area = max(1.0, page_width * page_height)
aspect = w / h if h else 0.0
label = "unknown_region"
confidence = 0.20
if area > page_area * 0.18:
label = "large_document_region"
confidence = 0.35
elif w > page_width * 0.70 and aspect > 4:
label = "wide_text_band"
confidence = 0.45
elif h > page_height * 0.10 and w > page_width * 0.35:
label = "large_text_block"
confidence = 0.40
elif aspect > 8:
label = "horizontal_rule_or_text_band"
confidence = 0.35
elif w < page_width * 0.12 and h < page_height * 0.06:
label = "small_symbol_or_short_text"
confidence = 0.30
item = dict(region)
item["geometry_class"] = label
item["geometry_confidence"] = confidence
item["geometry_features"] = {
"width": w,
"height": h,
"area_ratio": area / page_area,
"aspect_ratio": aspect,
}
return item
def _write_region_crop(
png_path: str | Path,
region: dict[str, Any],
*,
crop_index: int,
padding_px: int = 8,
) -> str | None:
if cv2 is None:
return None
img = cv2.imread(str(png_path))
if img is None:
return None
height, width = img.shape[:2]
bbox = region.get("rendered_bbox") or region.get("bbox")
if not bbox:
return None
x1, y1, x2, y2 = [int(round(float(v))) for v in bbox[:4]]
x1 = max(0, x1 - padding_px)
y1 = max(0, y1 - padding_px)
x2 = min(width, x2 + padding_px)
y2 = min(height, y2 + padding_px)
if x2 <= x1 or y2 <= y1:
return None
crop = img[y1:y2, x1:x2]
crop_dir = Path(png_path).parent / "crops" / Path(png_path).stem
crop_dir.mkdir(parents=True, exist_ok=True)
crop_path = crop_dir / f"region_{crop_index:04d}.png"
cv2.imwrite(str(crop_path), crop)
return str(crop_path)
def classify_and_crop_unmatched_regions(
vision_result: dict[str, Any],
layout_json: dict[str, Any] | None,
region_score: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Classify unmatched CV regions and write region crop images for later VLM analysis.
"""
pages = (layout_json or {}).get("pages") or []
rendered_pages = vision_result.get("rendered_pages") or []
if not pages or not rendered_pages:
return {
"schema_version": "vision_region_classification_v1",
"status": "not_enough_data",
"classified_regions": [],
}
page = pages[0]
page_width = float(page.get("page_width") or page.get("width") or 1)
page_height = float(page.get("page_height") or page.get("height") or 1)
png_path = rendered_pages[0].get("png_path")
unmatched_regions: list[dict[str, Any]] = []
if region_score:
for page_score in region_score.get("page_scores") or []:
unmatched_regions.extend(page_score.get("unmatched_regions") or [])
if not unmatched_regions:
unmatched_regions = ((vision_result.get("layers") or {}).get("vision_regions")) or []
classified: list[dict[str, Any]] = []
for idx, region in enumerate(unmatched_regions):
item = _classify_region_geometry(region, page_width=page_width, page_height=page_height)
if png_path:
item["crop_path"] = _write_region_crop(png_path, item, crop_index=idx)
item["classification_source"] = "opencv_geometry_classifier"
classified.append(item)
return {
"schema_version": "vision_region_classification_v1",
"status": "classified",
"classified_region_count": len(classified),
"classified_regions": classified,
}
def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_result: dict[str, Any]) -> dict[str, Any]:
"""
Convert vision analysis into normal layout_json.
@ -437,6 +560,11 @@ def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_re
normalized_vision = normalize_vision_regions_to_layout(vision_result, layout)
region_score = score_vision_regions_against_layout(normalized_vision, layout)
region_classification = classify_and_crop_unmatched_regions(
normalized_vision,
layout,
region_score,
)
layout["vision_assisted"] = True
layout["vision_assisted_status"] = normalized_vision.get("status", "unknown")
@ -444,6 +572,7 @@ def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_re
layout["vision_model_name"] = normalized_vision.get("model_name")
layout["vision_coordinate_normalization"] = normalized_vision.get("coordinate_normalization")
layout["vision_region_score"] = region_score
layout["vision_region_classification"] = region_classification
layout["layout_sync_source"] = "vision_assisted"
layout["layout_needs_review"] = True
return layout