Classify and crop unmatched vision regions
This commit is contained in:
parent
f3e61e877b
commit
05c9b6964a
|
|
@ -423,6 +423,129 @@ def score_vision_regions_against_layout(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_region_geometry(region: dict[str, Any], *, page_width: float, page_height: float) -> dict[str, Any]:
|
||||||
|
bbox = region.get("bbox") or [0, 0, 0, 0]
|
||||||
|
x1, y1, x2, y2 = [float(v) for v in bbox[:4]]
|
||||||
|
w = max(0.0, x2 - x1)
|
||||||
|
h = max(0.0, y2 - y1)
|
||||||
|
area = w * h
|
||||||
|
page_area = max(1.0, page_width * page_height)
|
||||||
|
aspect = w / h if h else 0.0
|
||||||
|
|
||||||
|
label = "unknown_region"
|
||||||
|
confidence = 0.20
|
||||||
|
|
||||||
|
if area > page_area * 0.18:
|
||||||
|
label = "large_document_region"
|
||||||
|
confidence = 0.35
|
||||||
|
elif w > page_width * 0.70 and aspect > 4:
|
||||||
|
label = "wide_text_band"
|
||||||
|
confidence = 0.45
|
||||||
|
elif h > page_height * 0.10 and w > page_width * 0.35:
|
||||||
|
label = "large_text_block"
|
||||||
|
confidence = 0.40
|
||||||
|
elif aspect > 8:
|
||||||
|
label = "horizontal_rule_or_text_band"
|
||||||
|
confidence = 0.35
|
||||||
|
elif w < page_width * 0.12 and h < page_height * 0.06:
|
||||||
|
label = "small_symbol_or_short_text"
|
||||||
|
confidence = 0.30
|
||||||
|
|
||||||
|
item = dict(region)
|
||||||
|
item["geometry_class"] = label
|
||||||
|
item["geometry_confidence"] = confidence
|
||||||
|
item["geometry_features"] = {
|
||||||
|
"width": w,
|
||||||
|
"height": h,
|
||||||
|
"area_ratio": area / page_area,
|
||||||
|
"aspect_ratio": aspect,
|
||||||
|
}
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
def _write_region_crop(
    png_path: str | Path,
    region: dict[str, Any],
    *,
    crop_index: int,
    padding_px: int = 8,
) -> str | None:
    """Crop a region out of a rendered page image and save it as a PNG.

    The crop is written to ``<png dir>/crops/<png stem>/region_<index>.png``.

    Args:
        png_path: Path of the rendered page image to crop from.
        region: Region mapping; ``"rendered_bbox"`` is preferred over
            ``"bbox"``. The box is used directly as pixel coordinates of the
            image. NOTE(review): confirm the fallback ``"bbox"`` is in pixel
            space, otherwise crops will be misaligned.
        crop_index: Index used to build the zero-padded output file name.
        padding_px: Margin added on every side, clamped to the image bounds.

    Returns:
        The crop file path as a string, or ``None`` when OpenCV is
        unavailable, the image cannot be read, the box is missing, malformed
        or empty after clamping, or the crop could not be written.
    """
    if cv2 is None:
        return None

    bbox = region.get("rendered_bbox") or region.get("bbox")
    if not bbox:
        return None

    # Validate the box before decoding the image so that bad regions are cheap.
    try:
        x1, y1, x2, y2 = (int(round(float(v))) for v in bbox[:4])
    except (TypeError, ValueError, KeyError, OverflowError):
        return None

    # NOTE: the page image is decoded once per call; callers that crop many
    # regions from the same page pay that cost for every region.
    img = cv2.imread(str(png_path))
    if img is None:
        return None

    height, width = img.shape[:2]
    x1 = max(0, x1 - padding_px)
    y1 = max(0, y1 - padding_px)
    x2 = min(width, x2 + padding_px)
    y2 = min(height, y2 + padding_px)

    if x2 <= x1 or y2 <= y1:
        return None

    crop = img[y1:y2, x1:x2]
    source = Path(png_path)
    crop_dir = source.parent / "crops" / source.stem
    crop_dir.mkdir(parents=True, exist_ok=True)

    crop_path = crop_dir / f"region_{crop_index:04d}.png"
    # imwrite reports failure through its return value; do not hand back a
    # path to a file that was never written.
    if not cv2.imwrite(str(crop_path), crop):
        return None
    return str(crop_path)
|
||||||
|
|
||||||
|
|
||||||
|
def classify_and_crop_unmatched_regions(
|
||||||
|
vision_result: dict[str, Any],
|
||||||
|
layout_json: dict[str, Any] | None,
|
||||||
|
region_score: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Classify unmatched CV regions and write region crop images for later VLM analysis.
|
||||||
|
"""
|
||||||
|
pages = (layout_json or {}).get("pages") or []
|
||||||
|
rendered_pages = vision_result.get("rendered_pages") or []
|
||||||
|
if not pages or not rendered_pages:
|
||||||
|
return {
|
||||||
|
"schema_version": "vision_region_classification_v1",
|
||||||
|
"status": "not_enough_data",
|
||||||
|
"classified_regions": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
page = pages[0]
|
||||||
|
page_width = float(page.get("page_width") or page.get("width") or 1)
|
||||||
|
page_height = float(page.get("page_height") or page.get("height") or 1)
|
||||||
|
png_path = rendered_pages[0].get("png_path")
|
||||||
|
|
||||||
|
unmatched_regions: list[dict[str, Any]] = []
|
||||||
|
if region_score:
|
||||||
|
for page_score in region_score.get("page_scores") or []:
|
||||||
|
unmatched_regions.extend(page_score.get("unmatched_regions") or [])
|
||||||
|
|
||||||
|
if not unmatched_regions:
|
||||||
|
unmatched_regions = ((vision_result.get("layers") or {}).get("vision_regions")) or []
|
||||||
|
|
||||||
|
classified: list[dict[str, Any]] = []
|
||||||
|
for idx, region in enumerate(unmatched_regions):
|
||||||
|
item = _classify_region_geometry(region, page_width=page_width, page_height=page_height)
|
||||||
|
if png_path:
|
||||||
|
item["crop_path"] = _write_region_crop(png_path, item, crop_index=idx)
|
||||||
|
item["classification_source"] = "opencv_geometry_classifier"
|
||||||
|
classified.append(item)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"schema_version": "vision_region_classification_v1",
|
||||||
|
"status": "classified",
|
||||||
|
"classified_region_count": len(classified),
|
||||||
|
"classified_regions": classified,
|
||||||
|
}
|
||||||
|
|
||||||
def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_result: dict[str, Any]) -> dict[str, Any]:
|
def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_result: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert vision analysis into normal layout_json.
|
Convert vision analysis into normal layout_json.
|
||||||
|
|
@ -437,6 +560,11 @@ def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_re
|
||||||
|
|
||||||
normalized_vision = normalize_vision_regions_to_layout(vision_result, layout)
|
normalized_vision = normalize_vision_regions_to_layout(vision_result, layout)
|
||||||
region_score = score_vision_regions_against_layout(normalized_vision, layout)
|
region_score = score_vision_regions_against_layout(normalized_vision, layout)
|
||||||
|
region_classification = classify_and_crop_unmatched_regions(
|
||||||
|
normalized_vision,
|
||||||
|
layout,
|
||||||
|
region_score,
|
||||||
|
)
|
||||||
|
|
||||||
layout["vision_assisted"] = True
|
layout["vision_assisted"] = True
|
||||||
layout["vision_assisted_status"] = normalized_vision.get("status", "unknown")
|
layout["vision_assisted_status"] = normalized_vision.get("status", "unknown")
|
||||||
|
|
@ -444,6 +572,7 @@ def build_vision_assisted_layout(source_layout: dict[str, Any] | None, vision_re
|
||||||
layout["vision_model_name"] = normalized_vision.get("model_name")
|
layout["vision_model_name"] = normalized_vision.get("model_name")
|
||||||
layout["vision_coordinate_normalization"] = normalized_vision.get("coordinate_normalization")
|
layout["vision_coordinate_normalization"] = normalized_vision.get("coordinate_normalization")
|
||||||
layout["vision_region_score"] = region_score
|
layout["vision_region_score"] = region_score
|
||||||
|
layout["vision_region_classification"] = region_classification
|
||||||
layout["layout_sync_source"] = "vision_assisted"
|
layout["layout_sync_source"] = "vision_assisted"
|
||||||
layout["layout_needs_review"] = True
|
layout["layout_needs_review"] = True
|
||||||
return layout
|
return layout
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue