document-processor/app/logic/extraction.py

893 lines
27 KiB
Python

from __future__ import annotations
import json
import re
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal, InvalidOperation
from sqlalchemy.orm import Session
from app.models.document import Document
from app.models.extracted_field import ExtractedField
from app.models.receipt_line_item import ReceiptLineItem
from app.models.text_version import TextVersion
# A money amount: digits, a dot and exactly two decimals, not glued to other
# digits on either side. A leading minus sign is not captured.
MONEY_RE = re.compile(r"(?<!\d)([0-9]+(?:\.[0-9]{2}))(?!\d)")
# Date layouts used by _parse_date / _looks_like_date_line, in this order:
# M/D/YYYY, M/D/YY, YYYY-MM-DD.
DATE_PATTERNS = [
    re.compile(r"\b(\d{1,2})/(\d{1,2})/(\d{4})\b"),
    re.compile(r"\b(\d{1,2})/(\d{1,2})/(\d{2})\b"),
    re.compile(r"\b(\d{4})-(\d{2})-(\d{2})\b"),
]
# 12-hour clock times with an AM/PM suffix (all-upper or all-lower case only).
# NOTE(review): the second pattern only accepts strings the first already
# accepts (the first makes the seconds optional), so it never adds a match.
TIME_PATTERNS = [
    re.compile(r"\b(\d{1,2}:\d{2}(?::\d{2})?\s?(?:AM|PM|am|pm))\b"),
    re.compile(r"\b(\d{1,2}:\d{2}\s?(?:am|pm|AM|PM))\b"),
]
# A labelled identifier ("order #", "invoice number", "tran seq no", ...);
# group 1 is the letters/digits/hyphens token that follows the label.
# After a label ending in "#" no word boundary is required: "#" followed by a
# space or ":" (e.g. "Order # 123") is not a boundary, so a plain \b there
# rejected those common layouts. Word-ending labels still need \b so that
# e.g. "order numbers" is not taken for "order number".
REFERENCE_NUM_RE = re.compile(
    r"\b(?:order\s+number|order\s*#|receipt\s+number|receipt\s*#|invoice\s+number|invoice\s*#|check\s+number|check\s*#|transaction\s+number|transaction\s*#|confirmation\s+number|confirmation\s*#|reference\s+number|reference\s*#|ticket\s*#|tran\s+seq\s+no)(?:(?<=#)|\b)[:\s]*([A-Za-z0-9\-]+)",
    re.IGNORECASE,
)
# Card network / tender keywords, matched case-insensitively as whole words.
PAYMENT_METHOD_RE = re.compile(
    r"\b(visa|mastercard|discover|amex|american express|cash|debit)\b",
    re.IGNORECASE,
)
# Masked card number: four or more asterisks, then the last four digits.
CARD_LAST4_RE = re.compile(r"\*{4,}\s*([0-9]{4})")
# '#' followed by 3+ digits (leading zeros are left out of the capture).
# NOTE(review): nothing ties this to a store label, so numbers such as
# "Check #1234" can be captured as a store number too.
STORE_NUM_RE = re.compile(r"#\s*0*([0-9]{3,})")
# Street-suffix / unit words used as a hint that a line is an address.
ADDRESS_HINT_RE = re.compile(
    r"\b(st|street|ave|avenue|rd|road|dr|drive|blvd|boulevard|ln|lane|hwy|highway|suite|ste)\b",
    re.IGNORECASE,
)
# Ten digits grouped 3-3-4, with optional parentheses and -, . or space separators.
PHONE_RE = re.compile(r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")
# Leading quantity (integer or decimal) separated by whitespace from the rest
# of the description; group 2 excludes trailing whitespace.
QTY_PREFIX_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s+(.+?)\s*$")
# Line ending in an amount: group 1 is the (possibly empty) text before it,
# group 2 the trailing two-decimal amount.
ITEM_LINE_RE = re.compile(r"^(.*?)([0-9]+\.[0-9]{2})\s*$")
@dataclass
class DocumentLine:
    """One non-blank line of document text plus optional layout metadata.

    Instances are built by ``_build_lines_from_layout`` (from OCR layout JSON)
    or ``_build_lines_from_text`` (from plain text, in which case ``page``,
    ``bbox`` and ``confidence`` are all ``None``).
    """

    # Page value copied from the layout JSON; None for plain-text input.
    page: int | None
    # Position of the line across the whole document, counting only kept
    # (non-blank) lines.
    line_index: int
    # Line text with surrounding whitespace stripped.
    text: str
    # Lowercased, whitespace-collapsed copy of ``text`` used for keyword checks.
    normalized: str
    # Bounding box copied verbatim from the layout JSON.
    # NOTE(review): coordinate order/units are not visible here — confirm.
    bbox: list[int] | None
    # OCR confidence copied verbatim from the layout JSON (scale not visible here).
    confidence: float | None
def _get_current_reviewed_text(document: Document) -> TextVersion | None:
    """Pick the text version that extraction should read.

    A current ``reviewed`` version is preferred; otherwise a current
    ``raw_ocr`` version is used. Within a type the version with the highest
    ``(version_number, created_at)`` wins (first one on ties). Returns
    ``None`` when neither type has a current version.
    """

    def _newest(version_type: str) -> TextVersion | None:
        candidates = [
            tv
            for tv in document.text_versions
            if tv.version_type == version_type and tv.is_current
        ]
        if not candidates:
            return None
        return max(candidates, key=lambda tv: (tv.version_number, tv.created_at))

    for preferred_type in ("reviewed", "raw_ocr"):
        chosen = _newest(preferred_type)
        if chosen is not None:
            return chosen
    return None
def _normalize_line(text: str) -> str:
    """Return *text* trimmed, with whitespace runs collapsed to one space, lowercased."""
    return " ".join(text.split()).lower()
def _clean_lines(text: str) -> list[str]:
    """Split *text* into lines, trim each one and drop those left blank."""
    trimmed = (raw.strip() for raw in text.splitlines())
    return [line for line in trimmed if line]
def _build_lines_from_layout(layout_json: dict | None) -> list[DocumentLine]:
    """Flatten OCR layout JSON into ``DocumentLine`` objects.

    Walks ``layout_json["pages"][*]["lines"]``, skipping lines whose text is
    missing or blank. ``line_index`` numbers the kept lines consecutively across
    all pages. Returns an empty list when *layout_json* is falsy.
    """
    result: list[DocumentLine] = []
    if not layout_json:
        return result
    for page in layout_json.get("pages", []):
        page_number = page.get("page")
        for raw_line in page.get("lines", []):
            content = (raw_line.get("text") or "").strip()
            if not content:
                continue
            # len(result) is the number of lines kept so far, i.e. the next index.
            result.append(
                DocumentLine(
                    page=page_number,
                    line_index=len(result),
                    text=content,
                    normalized=_normalize_line(content),
                    bbox=raw_line.get("bbox"),
                    confidence=raw_line.get("confidence"),
                )
            )
    return result
def _build_lines_from_text(text: str) -> list[DocumentLine]:
    """Turn plain text into ``DocumentLine`` objects with no layout metadata.

    Blank lines are dropped (via ``_clean_lines``); ``page``, ``bbox`` and
    ``confidence`` are ``None`` for every line.
    """
    kept_lines = _clean_lines(text)

    def _to_document_line(position: int, content: str) -> DocumentLine:
        return DocumentLine(
            page=None,
            line_index=position,
            text=content,
            normalized=_normalize_line(content),
            bbox=None,
            confidence=None,
        )

    return [_to_document_line(position, content) for position, content in enumerate(kept_lines)]
def _get_document_lines(text_version: TextVersion) -> list[DocumentLine]:
    """Return the version's lines, preferring layout JSON over plain text.

    Falls back to ``text_content`` when the layout yields no lines.
    """
    return _build_lines_from_layout(text_version.layout_json) or _build_lines_from_text(
        text_version.text_content or ""
    )
def _normalize_time_ocr(text: str) -> str:
    """Rewrite known OCR misreads of the AM/PM suffix.

    Whole words ``pie``/``pni`` become ``pm`` and ``aie`` becomes ``am``,
    case-insensitively, anywhere in *text*.
    """
    for misread, meridiem in (("pie", "pm"), ("pni", "pm"), ("aie", "am")):
        text = re.sub(rf"\b{misread}\b", meridiem, text, flags=re.IGNORECASE)
    return text
def _parse_date(text: str):
    """Return the first parseable date in *text* as a ``date``, or ``None``.

    The ``DATE_PATTERNS`` layouts are tried in order (M/D/YYYY, M/D/YY,
    YYYY-MM-DD). Every match of a pattern is tried, so an impossible value
    such as ``13/45/2023`` earlier in the text no longer hides a valid date
    later on.
    """
    for pattern in DATE_PATTERNS:
        for match in pattern.finditer(text):
            first, second, third = match.groups()
            try:
                if len(first) == 4:
                    # Only the ISO pattern starts with a four-digit group.
                    return datetime.strptime(f"{first}-{second}-{third}", "%Y-%m-%d").date()
                year_directive = "%y" if len(third) == 2 else "%Y"
                return datetime.strptime(
                    f"{first}/{second}/{third}", f"%m/%d/{year_directive}"
                ).date()
            except ValueError:
                # Not a real calendar date; keep scanning.
                continue
    return None
def _parse_time(text: str) -> str | None:
    """Return the first AM/PM time found in *text* after OCR suffix repair, else ``None``."""
    repaired = _normalize_time_ocr(text)
    searches = (pattern.search(repaired) for pattern in TIME_PATTERNS)
    found = next((match for match in searches if match), None)
    return found.group(1).strip() if found else None
def _to_decimal(value: str | None) -> Decimal | None:
if value is None:
return None
try:
return Decimal(str(value).strip())
except (InvalidOperation, TypeError):
return None
def _extract_line_amount(line: DocumentLine) -> Decimal | None:
    """Return the right-most money amount on *line* (commas ignored), or ``None``."""
    amounts = MONEY_RE.findall(line.text.replace(",", ""))
    return _to_decimal(amounts[-1]) if amounts else None
def _money_match_count(text: str) -> int:
    """Count the money amounts in *text* once thousands separators are removed."""
    without_commas = text.replace(",", "")
    return sum(1 for _ in MONEY_RE.finditer(without_commas))
def _source_span(line: DocumentLine | None) -> dict | None:
    """Describe where a value came from: page, index, text, bbox and confidence of *line*.

    Returns ``None`` when there is no source line.
    """
    if line is None:
        return None
    span_fields = ("page", "line_index", "text", "bbox", "confidence")
    return {field_name: getattr(line, field_name) for field_name in span_fields}
def _clean_merchant_name(line: str) -> str:
    """Strip a leading greeting ("welcome to ", "thank you for ...") from a merchant line.

    The greeting is matched case-insensitively; at most one is removed and the
    result is whitespace-trimmed.
    """
    greetings = (
        "welcome to ",
        "thank you for shopping at ",
        "thank you for visiting ",
    )
    name = line.strip()
    lowered = name.lower()
    greeting = next((g for g in greetings if lowered.startswith(g)), None)
    if greeting is None:
        return name
    return name[len(greeting):].strip()
def _looks_like_address(line: str) -> bool:
    """Heuristic: a street-suffix word, or digits together with a comma.

    Lines carrying a ``date:``/``time:`` label are never treated as addresses.
    """
    lowered = line.lower()
    if any(label in lowered for label in ("date:", "time:")):
        return False
    if ADDRESS_HINT_RE.search(line):
        return True
    return "," in line and any(ch.isdigit() for ch in line)
def _looks_like_phone(line: str) -> bool:
    """Return True when *line* contains something shaped like a phone number."""
    return PHONE_RE.search(line) is not None
def _looks_like_date_line(line: str) -> bool:
    """Return True for lines with a ``date:``/``time:`` label or any ``DATE_PATTERNS`` match."""
    lowered = line.lower()
    has_label = "date:" in lowered or "time:" in lowered
    return has_label or any(pattern.search(line) for pattern in DATE_PATTERNS)
def _is_price_only_line(line: DocumentLine) -> bool:
    """Return True when *line* is nothing but a single amount (optionally with ``$``)."""
    candidate = line.text.strip().replace(",", "")
    if not candidate or _money_match_count(candidate) != 1:
        return False
    bare_amount = candidate.replace("$", "").strip()
    return re.fullmatch(r"[0-9]+\.[0-9]{2}", bare_amount) is not None
def _guess_merchant(lines: list[DocumentLine]) -> tuple[str | None, DocumentLine | None]:
    """Guess the merchant name from the top of the document.

    Returns the first of the first five lines that is at least 3 characters
    long and does not look like a phone number, address or date line, with
    any greeting prefix removed. Falls back to the very first line, and to
    ``(None, None)`` when there are no lines.
    """

    def _is_header_noise(text: str) -> bool:
        return (
            len(text) < 3
            or _looks_like_phone(text)
            or _looks_like_address(text)
            or _looks_like_date_line(text)
        )

    for candidate in lines[:5]:
        stripped = candidate.text.strip()
        if not _is_header_noise(stripped):
            return _clean_merchant_name(stripped), candidate
    if not lines:
        return None, None
    first_line = lines[0]
    return _clean_merchant_name(first_line.text), first_line
def _guess_location(lines: list[DocumentLine]) -> tuple[str | None, DocumentLine | None]:
    """Return the first address-like line among lines 1..7 (the first line is skipped).

    Phone and date lines are ignored. ``_looks_like_date_line`` already
    covers ``date:``/``time:`` labels, so no separate label check is needed.
    Returns ``(None, None)`` when nothing qualifies.
    """
    for candidate in lines[1:8]:
        stripped = candidate.text.strip()
        if _looks_like_phone(stripped) or _looks_like_date_line(stripped):
            continue
        if _looks_like_address(stripped):
            return stripped, candidate
    return None, None
def _extract_extra(lines: list[DocumentLine], text: str) -> dict:
    """Collect optional receipt metadata for ``extra_json``.

    Keys are only present when found: ``card_last4`` and ``store_number``
    come from regex searches over the whole *text*. ``cashier`` is the full
    text of the first line mentioning "cashier", and ``cashier_source`` is
    that line's source span.
    """
    extra: dict = {}
    card_match = CARD_LAST4_RE.search(text)
    if card_match:
        extra["card_last4"] = card_match.group(1)
    store_match = STORE_NUM_RE.search(text)
    if store_match:
        extra["store_number"] = store_match.group(1)
    cashier_line = next(
        (ln for ln in lines if re.search(r"\bcashier\b", ln.text, re.IGNORECASE)),
        None,
    )
    if cashier_line is not None and cashier_line.text:
        extra["cashier"] = cashier_line.text
        extra["cashier_source"] = _source_span(cashier_line)
    return extra
def _score_total_line(line: DocumentLine, total_lines: int) -> float:
    """Score how likely *line* is the receipt's final total (higher is better).

    A "grand total" label or a whole-word "total" raises the score. Subtotal,
    tax and tip wording lowers it. Lines with an amount, and lines further
    down the document, score slightly higher. All three subtotal spellings
    ("subtotal", "sub-total", "sub total") are penalised. Previously the
    hyphenated form was not, so "Sub-Total 10.00" was scored like a Total.
    """
    score = 0.0
    text = line.normalized
    if "subtotal" in text or "sub-total" in text or "sub total" in text:
        score -= 8.0
    if "tax" in text:
        score -= 5.0
    if "tip" in text:
        score -= 2.0
    if "grand total" in text:
        score += 8.0
    elif re.search(r"\btotal\b", text):
        score += 6.0
    if _extract_line_amount(line) is not None:
        score += 2.0
    if total_lines > 0:
        # Totals are usually printed near the bottom: bonus in [0, 2).
        score += (line.line_index / total_lines) * 2.0
    return score
def _score_subtotal_line(line: DocumentLine) -> float:
    """Score how likely *line* is the subtotal line (higher is better).

    Only lines with a subtotal label ("subtotal", "sub-total", "sub total")
    can score positively. The +2 bonus for carrying an amount now applies
    only to labelled lines. Before, any priced item line scored +2, so a
    receipt without a subtotal reported its first item's price as the
    subtotal. A plain "total" label or tax wording lowers the score.
    """
    score = 0.0
    text = line.normalized
    has_label = "subtotal" in text or "sub-total" in text or "sub total" in text
    if has_label:
        score += 8.0
    elif re.search(r"\btotal\b", text):
        score -= 3.0
    if "tax" in text:
        score -= 3.0
    if has_label and _extract_line_amount(line) is not None:
        score += 2.0
    return score
def _score_tax_line(line: DocumentLine) -> float:
    """Score how likely *line* is the tax line (higher is better).

    "sales tax" > whole-word "tax" > whole-word "vat"/"gst". Lines that also
    mention a (non-sub) total are penalised. Two fixes:

    * "vat"/"gst" must be whole words; as substrings they matched item names
      such as "private" or "innovative".
    * The +2 amount bonus applies only to lines with a tax label. Before, any
      priced line scored +2, so a receipt without a tax line reported its
      first item's price as the tax.
    """
    score = 0.0
    text = line.normalized
    if "sales tax" in text:
        score += 8.0
    elif re.search(r"\btax\b", text):
        score += 7.0
    elif re.search(r"\b(?:vat|gst)\b", text):
        score += 6.0
    has_label = score > 0
    is_subtotal = "subtotal" in text or "sub-total" in text or "sub total" in text
    if "total" in text and not is_subtotal:
        score -= 2.0
    if has_label and _extract_line_amount(line) is not None:
        score += 2.0
    return score
def _pick_best_line(lines: list[DocumentLine], scorer) -> DocumentLine | None:
    """Return the line with the highest ``scorer(line)``, or ``None``.

    The first line wins ties. ``None`` is returned for an empty list or when
    the best score is not positive.
    """
    if not lines:
        return None
    best_score, best_line = max(
        ((scorer(candidate), candidate) for candidate in lines),
        key=lambda pair: pair[0],
    )
    return None if best_score <= 0 else best_line
def _extract_labeled_amount(
    lines: list[DocumentLine], scorer
) -> tuple[Decimal | None, DocumentLine | None]:
    """Find the best-scoring label line and the amount that belongs to it.

    Returns ``(amount, label_line)``. The amount is read from the label line
    itself or, failing that, from the line with the next ``line_index``. The
    second case covers amounts printed below their label. Returns
    ``(None, None)`` when no line scores positively, and ``(None, label_line)``
    when no amount is found.
    """
    best = _pick_best_line(lines, scorer)
    if best is None:
        return None, None
    amount = _extract_line_amount(best)
    if amount is not None:
        return amount, best
    following = next(
        (line for line in lines if line.line_index == best.line_index + 1), None
    )
    if following is None:
        return None, best
    return _extract_line_amount(following), best


def _extract_total(lines: list[DocumentLine]) -> tuple[Decimal | None, DocumentLine | None]:
    """Return ``(total_amount, total_label_line)``; see ``_extract_labeled_amount``."""
    return _extract_labeled_amount(lines, lambda line: _score_total_line(line, len(lines)))


def _extract_subtotal(lines: list[DocumentLine]) -> tuple[Decimal | None, DocumentLine | None]:
    """Return ``(subtotal_amount, subtotal_label_line)``; see ``_extract_labeled_amount``."""
    return _extract_labeled_amount(lines, _score_subtotal_line)


def _extract_tax(lines: list[DocumentLine]) -> tuple[Decimal | None, DocumentLine | None]:
    """Return ``(tax_amount, tax_label_line)``; see ``_extract_labeled_amount``."""
    return _extract_labeled_amount(lines, _score_tax_line)
def _is_non_item_line(normalized: str) -> bool:
    """Return True when a normalized line is receipt boilerplate rather than an item.

    Boilerplate means totals, tax, tips, staff/table/order labels, tenders or
    sign-off text, plus percentage lines containing "% =". Matching is by
    plain substring.
    NOTE(review): substring matching means e.g. "cash" also matches "cashew"
    and "tip" matches "multiple" — confirm this is acceptable.
    """
    blocked_terms = (
        "subtotal",
        "sub total",
        "sub-total",
        "total",
        "tax",
        "service fee",
        "tip",
        "pay this amount",
        "recommended gratuity",
        "gratuity",
        "cashier",
        "server",
        "guest",
        "table #",
        "table:",
        "date:",
        "time:",
        "order #",
        "order:",
        "invoice #",
        "invoice:",
        "reference #",
        "confirmation #",
        "receipt",
        "visa",
        "mastercard",
        "discover",
        "amex",
        "cash",
        "debit",
        "thank you",
        "regresen pronto",
        "gracias",
    )
    return "% =" in normalized or any(term in normalized for term in blocked_terms)
def _normalize_item_description(text: str) -> str:
    """Return the display form of an item description: cleaned, then title-cased.

    Delegates the cleanup to ``_clean_item_description`` so the two stay in
    step. The earlier copy-pasted cleanup skipped the final strip, so a
    trailing space could survive into the normalized value.
    """
    return _clean_item_description(text).title()
def _clean_item_description(text: str) -> str:
    """Tidy a raw item description taken from a receipt line.

    The steps are:

    1. Collapse whitespace runs to single spaces.
    2. Drop every trailing ``$`` left behind when the price was split off,
       together with the spaces around it.
    3. Trim leading/trailing ``-``, ``:`` and spaces.

    Removing the ``$`` before trimming means separators in front of it
    ("Fries: $") are also trimmed, and repeated markers ("Soda $$") are fully
    removed. Both cases previously leaked into the description.
    """
    collapsed = re.sub(r"\s+", " ", text.strip())
    without_currency = re.sub(r"[\s$]+$", "", collapsed)
    return without_currency.strip("-: ")
def _infer_item_category(text: str) -> str | None:
normalized = text.lower()
cocktail_terms = [
"margarita",
"old fashioned",
"oldfashion",
"picante",
"martini",
"negroni",
"spritz",
"mezcal",
"tequila",
"paloma",
"manhattan",
"mojito",
"cocktail",
]
food_terms = [
"dip",
"burger",
"fries",
"taco",
"nachos",
"quesadilla",
"salad",
"enchilada",
"steak",
"burrito",
"sandwich",
]
modifier_terms = [
"add ",
"extra ",
"side ",
"sauce",
"cheese",
"espinaca",
"jalape",
"onion ring",
]
if any(term in normalized for term in cocktail_terms):
return "cocktail"
if any(term in normalized for term in food_terms):
return "food"
if any(term in normalized for term in modifier_terms):
return "modifier"
if "beer" in normalized:
return "beer"
if "wine" in normalized:
return "wine"
return None
def _candidate_item_description_line(line: DocumentLine) -> bool:
    """Return True when *line* could be the description half of a split item.

    Such a line is at least 3 characters long and is not boilerplate, an
    address, phone or date line. It has at most one amount and is not just a
    bare price.
    """
    stripped = line.text.strip()
    if len(stripped) < 3 or _is_non_item_line(line.normalized):
        return False
    disqualifiers = (_looks_like_address, _looks_like_phone, _looks_like_date_line)
    if any(check(stripped) for check in disqualifiers):
        return False
    return _money_match_count(stripped) <= 1 and not _is_price_only_line(line)
# Labels whose lines (and the line right after each) hold summary amounts,
# so they are never used as item descriptions or item prices.
_PROTECTED_AMOUNT_LABELS = ("subtotal", "sub-total", "tax", "service fee", "total", "pay this amount")


def _protected_amount_indexes(lines: list[DocumentLine]) -> set[int]:
    """Return ``line_index`` values of summary-label lines and of the line after each."""
    protected: set[int] = set()
    for position, line in enumerate(lines):
        if any(label in line.normalized for label in _PROTECTED_AMOUNT_LABELS):
            protected.add(line.line_index)
            if position + 1 < len(lines):
                protected.add(lines[position + 1].line_index)
    return protected


def _split_quantity(description: str) -> tuple[Decimal | None, str]:
    """Split a leading quantity off a description, e.g. ``"2 Tacos"`` -> ``(2, "Tacos")``.

    Returns ``(None, description)`` unchanged when there is no numeric prefix.
    """
    qty_match = QTY_PREFIX_RE.match(description)
    if not qty_match:
        return None, description
    return _to_decimal(qty_match.group(1)), qty_match.group(2).strip()


def _make_line_item(
    line: DocumentLine,
    description: str,
    quantity: Decimal | None,
    line_total: Decimal,
    confidence: Decimal,
    match_type: str,
    price_line: DocumentLine | None = None,
) -> dict:
    """Build the serialisable item dict (numbers as strings, missing values as "")."""
    extra_json: dict = {"page": line.page, "bbox": line.bbox}
    if price_line is not None:
        extra_json["price_line_index"] = price_line.line_index
        extra_json["price_bbox"] = price_line.bbox
        extra_json["price_text"] = price_line.text
    extra_json["source_text"] = line.text
    extra_json["source_confidence"] = line.confidence
    extra_json["match_type"] = match_type
    return {
        "line_index": line.line_index,
        "raw_description": description,
        "normalized_description": _normalize_item_description(description),
        "quantity": str(quantity) if quantity is not None else "",
        "unit_price": "",
        "line_total": str(line_total),
        "item_category": _infer_item_category(description) or "",
        "confidence": str(confidence),
        "extra_json": extra_json,
    }


def _same_line_item(line: DocumentLine) -> dict | None:
    """Build an item from a line that holds both a description and a trailing price."""
    match = ITEM_LINE_RE.match(line.text.strip().replace(",", ""))
    if not match:
        return None
    description_part = match.group(1).strip()
    # Nothing (or only a "$") precedes the amount: no description on this line.
    if not description_part or description_part == "$":
        return None
    quantity, description = _split_quantity(description_part)
    description = _clean_item_description(description)
    line_total = _to_decimal(match.group(2).strip())
    if not description or line_total is None:
        return None
    if description.lower() in {"total", "subtotal", "tax"}:
        return None
    confidence = Decimal("90.00") if quantity is not None else Decimal("85.00")
    return _make_line_item(line, description, quantity, line_total, confidence, "same_line")


def _paired_item(
    line: DocumentLine,
    price_line: DocumentLine,
    confidence: Decimal,
    qty_confidence: Decimal,
    match_type: str,
) -> dict | None:
    """Build an item whose description is *line* and whose price is *price_line*.

    *qty_confidence* is used instead of *confidence* when a quantity prefix is found.
    """
    quantity, description = _split_quantity(line.text.strip())
    description = _clean_item_description(description)
    line_total = _extract_line_amount(price_line)
    if not description or line_total is None:
        return None
    chosen = qty_confidence if quantity is not None else confidence
    return _make_line_item(line, description, quantity, line_total, chosen, match_type, price_line)


def _is_free_price_line(
    candidate: DocumentLine | None, used: set[int], protected: set[int]
) -> bool:
    """True when *candidate* is an unclaimed, unprotected, price-only, non-boilerplate line."""
    return (
        candidate is not None
        and candidate.line_index not in used
        and candidate.line_index not in protected
        and _is_price_only_line(candidate)
        and not _is_non_item_line(candidate.normalized)
    )


def _extract_receipt_line_items(lines: list[DocumentLine]) -> list[dict]:
    """Extract purchasable items from receipt lines.

    Three layouts are recognised. Each line is used at most once and lines
    around summary labels are skipped (see ``_protected_amount_indexes``).

    1. Pass 1, per line: description and price on the same line
       (``same_line``). Otherwise a description followed by a price-only
       line (``paired_next_line``).
    2. Pass 2: a description preceded by a price-only line
       (``paired_prev_line``).

    Returns item dicts sorted by the description's ``line_index``.
    NOTE(review): ITEM_LINE_RE/MONEY_RE ignore a leading minus, so discount
    lines such as "Coupon -2.00" are recorded with a positive amount.
    """
    items: list[dict] = []
    used: set[int] = set()
    protected = _protected_amount_indexes(lines)

    # Pass 1: same-line items, else description + price on the next line.
    for idx, line in enumerate(lines):
        if line.line_index in used or line.line_index in protected:
            continue
        text = line.text.strip()
        if len(text) < 3 or _is_non_item_line(line.normalized):
            continue
        if _looks_like_address(text) or _looks_like_phone(text) or _looks_like_date_line(text):
            continue
        if _money_match_count(text) > 1:
            continue
        item = _same_line_item(line)
        if item is not None:
            items.append(item)
            used.add(line.line_index)
            continue
        if not _candidate_item_description_line(line):
            continue
        next_line = lines[idx + 1] if idx + 1 < len(lines) else None
        if not _is_free_price_line(next_line, used, protected):
            continue
        item = _paired_item(line, next_line, Decimal("88.00"), Decimal("92.00"), "paired_next_line")
        if item is None:
            continue
        items.append(item)
        used.add(line.line_index)
        used.add(next_line.line_index)

    # Pass 2: leftover descriptions whose price was printed on the line before.
    for idx, line in enumerate(lines):
        if line.line_index in used or line.line_index in protected:
            continue
        if not _candidate_item_description_line(line):
            continue
        prev_line = lines[idx - 1] if idx >= 1 else None
        if not _is_free_price_line(prev_line, used, protected):
            continue
        item = _paired_item(line, prev_line, Decimal("89.00"), Decimal("93.00"), "paired_prev_line")
        if item is None:
            continue
        items.append(item)
        used.add(line.line_index)
        used.add(prev_line.line_index)

    items.sort(key=lambda entry: entry.get("line_index", 0))
    return items
def _replace_receipt_line_items(db: Session, document: Document, items: list[dict]) -> None:
    """Swap the document's stored receipt line items for rows built from *items*.

    Every row in ``document.receipt_line_items`` is deleted, then one
    ``ReceiptLineItem`` is added per dict in *items*. Numeric fields go
    through ``_to_decimal`` and empty strings are stored as ``NULL``. Nothing
    is committed here; the caller owns the transaction.

    *items* may come from user-edited JSON, so entries that are not dicts
    are skipped instead of failing on ``.get``.
    """
    existing_items = list(getattr(document, "receipt_line_items", []) or [])
    for item in existing_items:
        db.delete(item)
    if existing_items:
        # Emit the DELETEs before queuing the replacements. Within one flush
        # SQLAlchemy's unit of work can order INSERTs ahead of DELETEs, which
        # would collide with any unique constraint covering the old rows.
        # NOTE(review): confirm whether such a constraint exists.
        db.flush()
    for item in items:
        if not isinstance(item, dict):
            continue
        db.add(
            ReceiptLineItem(
                document_id=document.id,
                line_index=item.get("line_index"),
                raw_description=item.get("raw_description") or "",
                normalized_description=item.get("normalized_description") or None,
                quantity=_to_decimal(item.get("quantity")),
                unit_price=_to_decimal(item.get("unit_price")),
                line_total=_to_decimal(item.get("line_total")),
                item_category=item.get("item_category") or None,
                confidence=_to_decimal(item.get("confidence")),
                extra_json=item.get("extra_json") or {},
            )
        )
def auto_extract_from_document(db: Session, document: Document) -> dict:
    """Heuristically extract receipt fields from a document's current text.

    The text comes from ``_get_current_reviewed_text`` (current reviewed
    version, else current raw OCR). Returns ``{}`` when there is none.
    Otherwise every value in the returned dict is a string ("" when nothing
    was found), keyed like the text parameters of ``save_extracted_fields``.
    ``extra_json`` is a JSON object with card/store/cashier hints, per-field
    source spans, analysis metadata and the extracted line items. ``db`` is
    not used by this function.
    """
    text_version = _get_current_reviewed_text(document)
    if text_version is None:
        return {}

    full_text = text_version.text_content or ""
    lines = _get_document_lines(text_version)

    # Line-based heuristics.
    merchant, merchant_line = _guess_merchant(lines)
    location, location_line = _guess_location(lines)
    subtotal, subtotal_line = _extract_subtotal(lines)
    tax, tax_line = _extract_tax(lines)
    total, total_line = _extract_total(lines)

    # Whole-text pattern searches.
    transaction_date = _parse_date(full_text)
    transaction_time = _parse_time(full_text)
    payment_match = PAYMENT_METHOD_RE.search(full_text)
    payment_method = payment_match.group(1).upper() if payment_match else None
    reference_match = REFERENCE_NUM_RE.search(full_text)
    reference_number = reference_match.group(1) if reference_match else None

    extra = _extract_extra(lines, full_text)
    extra["source_spans"] = {
        "merchant_raw": _source_span(merchant_line),
        "location": _source_span(location_line),
        "subtotal": _source_span(subtotal_line),
        "tax": _source_span(tax_line),
        "total": _source_span(total_line),
        "reference_number": {"value": reference_number} if reference_number else None,
    }
    extra["analysis"] = {
        "line_count": len(lines),
        "has_layout": bool(text_version.layout_json),
        "source_version_type": text_version.version_type,
    }
    extra["line_items"] = _extract_receipt_line_items(lines)

    def _amount_text(value: Decimal | None) -> str:
        return "" if value is None else str(value)

    # The merchant line doubles as the normalized name and the counterparty.
    merchant_text = merchant or ""
    return {
        "merchant_raw": merchant_text,
        "merchant_normalized": merchant_text,
        "transaction_date": transaction_date.isoformat() if transaction_date else "",
        "transaction_time": transaction_time or "",
        "subtotal": _amount_text(subtotal),
        "tax": _amount_text(tax),
        "total": _amount_text(total),
        "currency": "USD",
        "payment_method": payment_method or "",
        "receipt_number": reference_number or "",
        "location": location or "",
        "counterparty": merchant_text,
        # ``extra`` always holds "source_spans" at this point, so it is never empty.
        "extra_json": json.dumps(extra, indent=2, sort_keys=True),
    }
def get_current_extracted_fields(document: Document) -> ExtractedField | None:
    """Return the most recently touched ``ExtractedField`` of *document*, or ``None``.

    Recency is ``updated_at``, falling back to ``created_at`` when
    ``updated_at`` is unset. The earliest record in the collection wins ties.
    """
    records = document.extracted_fields
    if not records:
        return None
    return max(records, key=lambda record: record.updated_at or record.created_at)
def save_extracted_fields(
    db: Session,
    document: Document,
    merchant_raw: str,
    merchant_normalized: str,
    transaction_date: str,
    transaction_time: str,
    subtotal: str,
    tax: str,
    total: str,
    currency: str,
    payment_method: str,
    receipt_number: str,
    location: str,
    counterparty: str,
    extra_json: str,
) -> ExtractedField:
    """Persist extracted fields for *document* and commit the session.

    Updates the record returned by ``get_current_extracted_fields``, or
    creates a new one. Empty strings are stored as ``NULL``. Money fields go
    through ``_to_decimal``, so unparseable amounts also become ``NULL``.

    ``transaction_date`` must be ``YYYY-MM-DD``. Surrounding whitespace is
    ignored; any other non-blank value raises ``ValueError``.

    ``extra_json`` should be a JSON object. Unparseable text, or valid JSON
    that is not an object (a list, number, ``null``, ...), is kept verbatim as
    ``{"raw_text": extra_json}``. Previously a non-object crashed on
    ``.get``. The object's ``"line_items"`` list replaces the document's
    receipt line items; when it is absent or not a list, the existing items
    are cleared.

    Returns the saved record, refreshed after the commit.
    """
    current = get_current_extracted_fields(document)
    if current is None:
        current = ExtractedField(document_id=document.id)
        db.add(current)

    current.merchant_raw = merchant_raw or None
    current.merchant_normalized = merchant_normalized or None
    date_text = (transaction_date or "").strip()
    current.transaction_date = (
        datetime.strptime(date_text, "%Y-%m-%d").date() if date_text else None
    )
    current.transaction_time = transaction_time or None
    current.subtotal = _to_decimal(subtotal)
    current.tax = _to_decimal(tax)
    current.total = _to_decimal(total)
    current.currency = currency or None
    current.payment_method = payment_method or None
    current.receipt_number = receipt_number or None
    current.location = location or None
    current.counterparty = counterparty or None

    raw_extra = extra_json or ""
    try:
        parsed_extra = json.loads(raw_extra) if raw_extra.strip() else {}
    except json.JSONDecodeError:
        parsed_extra = None
    if not isinstance(parsed_extra, dict):
        # Unparseable text or a non-object JSON value: keep the raw input.
        parsed_extra = {"raw_text": extra_json}
    current.extra_json = parsed_extra

    line_items = parsed_extra.get("line_items", [])
    _replace_receipt_line_items(db, document, line_items if isinstance(line_items, list) else [])

    db.commit()
    db.refresh(current)
    return current