Classify and crop unmatched vision regions
This commit is contained in:
parent
f3e61e877b
commit
05c9b6964a
|
|
@ -423,6 +423,129 @@ def score_vision_regions_against_layout(
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
def _classify_region_geometry(region: dict[str, Any], *, page_width: float, page_height: float) -> dict[str, Any]:
|
||||
bbox = region.get("bbox") or [0, 0, 0, 0]
|
||||
x1, y1, x2, y2 = [float(v) for v in bbox[:4]]
|
||||
w = max(0.0, x2 - x1)
|
||||
h = max(0.0, y2 - y1)
|
||||
area = w * h
|
||||
page_area = max(1.0, page_width * page_height)
|
||||
aspect = w / h if h else 0.0
|
||||
|
||||
label = "unknown_region"
|
||||
confidence = 0.20
|
||||
|
||||
if area > page_area * 0.18:
|
||||
label = "large_document_region"
|
||||
confidence = 0.35
|
||||
elif w > page_width * 0.70 and aspect > 4:
|
||||
label = "wide_text_band"
|
||||
confidence = 0.45
|
||||
elif h > page_height * 0.10 and w > page_width * 0.35:
|
||||
label = "large_text_block"
|
||||
confidence = 0.40
|
||||
elif aspect > 8:
|
||||
label = "horizontal_rule_or_text_band"
|
||||
confidence = 0.35
|
||||
elif w < page_width * 0.12 and h < page_height * 0.06:
|
||||
label = "small_symbol_or_short_text"
|
||||
confidence = 0.30
|
||||
|
||||
item = dict(region)
|
||||
item["geometry_class"] = label
|
||||
item["geometry_confidence"] = confidence
|
||||
item["geometry_features"] = {
|
||||
"width": w,
|
||||
"height": h,
|
||||
"area_ratio": area / page_area,
|
||||
"aspect_ratio": aspect,
|
||||
}
|
||||
return item
|
||||
|
||||
|
||||
def _write_region_crop(
|
||||
png_path: str | Path,
|
||||
region: dict[str, Any],
|
||||
*,
|
||||
crop_index: int,
|
||||
padding_px: int = 8,
|
||||
) -> str | None:
|
||||
if cv2 is None:
|
||||
return None
|
||||
|
||||
img = cv2.imread(str(png_path))
|
||||
if img is None:
|
||||
return None
|
||||
|
||||
height, width = img.shape[:2]
|
||||
bbox = region.get("rendered_bbox") or region.get("bbox")
|
||||
if not bbox:
|
||||
return None
|
||||
|
||||
x1, y1, x2, y2 = [int(round(float(v))) for v in bbox[:4]]
|
||||
x1 = max(0, x1 - padding_px)
|
||||
y1 = max(0, y1 - padding_px)
|
||||
x2 = min(width, x2 + padding_px)
|
||||
y2 = min(height, y2 + padding_px)
|
||||
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
return None
|
||||
|
||||
crop = img[y1:y2, x1:x2]
|
||||
crop_dir = Path(png_path).parent / "crops" / Path(png_path).stem
|
||||
crop_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
crop_path = crop_dir / f"region_{crop_index:04d}.png"
|
||||
cv2.imwrite(str(crop_path), crop)
|
||||
return str(crop_path)
|
||||
|
||||
|
||||
def classify_and_crop_unmatched_regions(
|
||||
vision_result: dict[str, Any],
|
||||
layout_json: dict[str, Any] | None,
|
||||
region_score: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Classify unmatched CV regions and write region crop images for later VLM analysis.
|
||||
"""
|
||||
pages = (layout_json or {}).get("pages") or []
|
||||
rendered_pages = vision_result.get("rendered_pages") or []
|
||||
if not pages or not rendered_pages:
|
||||
return {
|
||||
"schema_version": "vision_region_classification_v1",
|
||||
"status": "not_enough_data",
|
||||
"classified_regions": [],
|
||||
}
|
||||
|
||||
page = pages[0]
|
||||
page_width = float(page.get("page_width") or page.get("width") or 1)
|
||||
page_height = float(page.get("page_height") or page.get("height") or 1)
|
||||
png_path = rendered_pages[0].get("png_path")
|
||||
|
||||
unmatched_regions: list[dict[str, Any]] = []
|
||||
if region_score:
|
||||
for page_score in region_score.get("page_scores") or []:
|
||||
unmatched_regions.extend(page_score.get("unmatched_regions") or [])
|
||||
|
||||
if not unmatched_regions:
|
||||
unmatched_regions = ((vision_result.get("layers") or {}).get("vision_regions")) or []
|
||||
|
||||
classified: list[dict[str, Any]] = []
|
||||
for idx, region in enumerate(unmatched_regions):
|
||||
item = _classify_region_geometry(region, page_width=page_width, page_height=page_height)
|
||||
if png_path:
|
||||
item["crop_path"] = _write_region_crop(png_path, item, crop_index=idx)
|
||||
item["classification_source"] = "opencv_geometry_classifier"
|
||||
classified.append(item)
|
||||
|
||||
return {
|
||||
"schema_version": "vision_region_classification_v1",
|
||||
"status": "classified",
|
||||
"classified_region_count": len(classified),
|
||||
"classified_regions": classified,
|
||||
}
|
||||
|
||||
def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_result: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Convert vision analysis into normal layout_json.
|
||||
|
|
@ -437,6 +560,11 @@ def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_re
|
|||
|
||||
normalized_vision = normalize_vision_regions_to_layout(vision_result, layout)
|
||||
region_score = score_vision_regions_against_layout(normalized_vision, layout)
|
||||
region_classification = classify_and_crop_unmatched_regions(
|
||||
normalized_vision,
|
||||
layout,
|
||||
region_score,
|
||||
)
|
||||
|
||||
layout["vision_assisted"] = True
|
||||
layout["vision_assisted_status"] = normalized_vision.get("status", "unknown")
|
||||
|
|
@ -444,6 +572,7 @@ def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_re
|
|||
layout["vision_model_name"] = normalized_vision.get("model_name")
|
||||
layout["vision_coordinate_normalization"] = normalized_vision.get("coordinate_normalization")
|
||||
layout["vision_region_score"] = region_score
|
||||
layout["vision_region_classification"] = region_classification
|
||||
layout["layout_sync_source"] = "vision_assisted"
|
||||
layout["layout_needs_review"] = True
|
||||
return layout
|
||||
|
|
|
|||
Loading…
Reference in New Issue