document-processor/app/logic/extraction.py

245 lines
7.3 KiB
Python

from __future__ import annotations
import json
import re
from datetime import datetime
from decimal import Decimal, InvalidOperation
from sqlalchemy.orm import Session, selectinload
from app.models.document import Document
from app.models.extracted_field import ExtractedField
from app.models.text_version import TextVersion
# --- Compiled patterns used by the extraction helpers in this module ---

# Amount with an optional leading "$" and optional whitespace; captures digits
# followed by a mandatory two-digit fraction.
# NOTE(review): no use of MONEY_RE is visible in this part of the file.
MONEY_RE = re.compile(r"\$?\s*([0-9]+(?:\.[0-9]{2}))")
# Date shapes, tried in list order by _parse_date:
#   1. one/two digits "/" one/two digits "/" four digits (parsed there as %m/%d/%Y)
#   2. four digits "-" two digits "-" two digits (parsed there as %Y-%m-%d)
DATE_PATTERNS = [
re.compile(r"\b(\d{1,2})/(\d{1,2})/(\d{4})\b"),
re.compile(r"\b(\d{4})-(\d{2})-(\d{2})\b"),
]
# Clock times that carry an AM/PM marker; the first pattern also allows an
# optional ":SS" seconds part.
# NOTE(review): everything the second pattern matches is already matched by the
# first (seconds are optional there), so the second entry looks redundant.
TIME_PATTERNS = [
re.compile(r"\b(\d{1,2}:\d{2}(?::\d{2})?\s?(?:AM|PM|am|pm))\b"),
re.compile(r"\b(\d{1,2}:\d{2}\s?(?:am|pm|AM|PM))\b"),
]
# Labelled amounts. All three are case-insensitive and multiline: the label must
# start a line (after optional whitespace), may be followed by ":" / whitespace
# and an optional "$", and the captured amount (digits "." two digits) must end
# the line. Because "\s" also matches newlines, the amount may sit on the line
# after the label.
TOTAL_RE = re.compile(r"(?im)^\s*total\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
# Accepts "subtotal", "sub total", "sub. total" and similar spacings.
SUBTOTAL_RE = re.compile(r"(?im)^\s*sub\.?\s*total\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
TAX_RE = re.compile(r"(?im)^\s*tax\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
# Receipt identifier: one of several labels ("order number", "receipt number",
# "receipt #", "tran seq no"), optional ":" / whitespace, then a captured run of
# letters, digits and hyphens. Case-insensitive.
RECEIPT_NUM_RE = re.compile(
r"\b(?:order\s+number|receipt\s+number|receipt\s*#|tran\s+seq\s+no)\b[:\s]*([A-Za-z0-9\-]+)",
re.IGNORECASE,
)
# First whole-word mention of a known payment method keyword (case-insensitive).
PAYMENT_METHOD_RE = re.compile(r"\b(visa|mastercard|discover|amex|american express|cash|debit)\b", re.IGNORECASE)
# Masked card number: four or more "*", optional whitespace, then the captured
# four digits.
CARD_LAST4_RE = re.compile(r"\*{4,}\s*([0-9]{4})")
# "#" followed by optional whitespace and a number; leading zeros are skipped
# and at least three digits are captured. Matches any "#"-prefixed number, not
# only ones labelled as a store.
STORE_NUM_RE = re.compile(r"#\s*0*([0-9]{3,})")
def _get_current_reviewed_text(document: Document) -> TextVersion | None:
reviewed = [tv for tv in document.text_versions if tv.version_type == "reviewed" and tv.is_current]
if reviewed:
return sorted(reviewed, key=lambda x: (x.version_number, x.created_at), reverse=True)[0]
raw = [tv for tv in document.text_versions if tv.version_type == "raw_ocr" and tv.is_current]
if raw:
return sorted(raw, key=lambda x: (x.version_number, x.created_at), reverse=True)[0]
return None
def _clean_lines(text: str) -> list[str]:
return [line.strip() for line in text.splitlines() if line.strip()]
def _parse_date(text: str):
    """Return the first valid calendar date found in *text*.

    The patterns in ``DATE_PATTERNS`` are tried in list order. A match whose
    first captured group has four digits is read year-first (``%Y-%m-%d``);
    any other match is read as ``%m/%d/%Y``.

    Every occurrence of a pattern is examined, so a match that is not a real
    calendar date (for example ``99/99/2024``) does not hide a later, valid
    date of the same form.

    Args:
        text: Text to scan.

    Returns:
        A ``datetime.date`` for the first match that parses, or ``None`` when
        no match forms a valid date.
    """
    for pat in DATE_PATTERNS:
        for found in pat.finditer(text):
            parts = found.groups()
            # Decide the format from what was captured rather than from the
            # pattern's source text: a four-digit leading group is the year.
            if len(parts[0]) == 4:
                candidate, fmt = "-".join(parts), "%Y-%m-%d"
            else:
                candidate, fmt = "/".join(parts), "%m/%d/%Y"
            try:
                return datetime.strptime(candidate, fmt).date()
            except ValueError:
                # Right shape, impossible date; keep looking.
                continue
    return None
def _parse_time(text: str) -> str | None:
    """Return the first time-of-day string found in *text*, or ``None``.

    The patterns in ``TIME_PATTERNS`` are tried in list order; the captured
    text of the first one that matches is returned with surrounding
    whitespace removed.
    """
    hits = (pattern.search(text) for pattern in TIME_PATTERNS)
    first_hit = next((hit for hit in hits if hit), None)
    if first_hit is None:
        return None
    return first_hit.group(1).strip()
def _to_decimal(value: str | None) -> Decimal | None:
if not value:
return None
try:
return Decimal(value)
except (InvalidOperation, TypeError):
return None
def _find_amount(pattern: re.Pattern[str], text: str) -> Decimal | None:
    """Search *text* with *pattern* and convert its first group to a ``Decimal``.

    Returns ``None`` when the pattern does not match or when the captured
    text cannot be converted by ``_to_decimal``.
    """
    found = pattern.search(text)
    return _to_decimal(found.group(1)) if found else None
def _clean_merchant_name(line: str) -> str:
prefixes = [
"welcome to ",
"thank you for shopping at ",
"thank you for visiting ",
]
cleaned = line.strip()
lower = cleaned.lower()
for prefix in prefixes:
if lower.startswith(prefix):
cleaned = cleaned[len(prefix):].strip()
break
return cleaned
def _guess_merchant(lines: list[str]) -> str | None:
    """Guess the merchant name from the top of the receipt.

    The first of the leading five lines that is at least three characters
    long and has no digit among its first eight characters is used. If none
    qualifies, the very first line is used instead. The chosen line is passed
    through ``_clean_merchant_name``.

    Returns:
        The cleaned name, or ``None`` when *lines* is empty.
    """
    if not lines:
        return None
    for candidate in lines[:5]:
        if len(candidate) < 3:
            continue
        if any(ch.isdigit() for ch in candidate[:8]):
            continue
        return _clean_merchant_name(candidate)
    return _clean_merchant_name(lines[0])
def _guess_location(lines: list[str]) -> str | None:
for line in lines[1:6]:
if any(ch.isdigit() for ch in line) or "," in line or "(" in line:
return line
return None
def _extract_extra(lines: list[str], text: str) -> dict:
    """Collect optional details that have no dedicated field.

    Args:
        lines: Cleaned receipt lines, searched for a cashier line.
        text: Full receipt text, searched for card and store numbers.

    Returns:
        A dict containing any of ``card_last4``, ``store_number`` and
        ``cashier`` (the whole first line mentioning the word "cashier");
        keys with nothing found are omitted.
    """
    details: dict = {}

    card_hit = CARD_LAST4_RE.search(text)
    if card_hit is not None:
        details["card_last4"] = card_hit.group(1)

    store_hit = STORE_NUM_RE.search(text)
    if store_hit is not None:
        details["store_number"] = store_hit.group(1)

    cashier_line = next(
        (entry for entry in lines if re.search(r"\bcashier\b", entry, re.IGNORECASE)),
        None,
    )
    if cashier_line:
        details["cashier"] = cashier_line

    return details
def auto_extract_from_document(db: Session, document: Document) -> dict:
    """Derive receipt fields from the document's current text.

    The text comes from ``_get_current_reviewed_text``. Every value in the
    result is a string, with ``""`` standing for "not found"; ``extra_json``
    is a JSON object serialised as text (``"{}"`` when empty).

    Args:
        db: Database session; not used by this function.
        document: Document whose text versions are inspected.

    Returns:
        A dict of field name to string value, or an empty dict when the
        document has no usable text version.
    """
    source_version = _get_current_reviewed_text(document)
    if source_version is None:
        return {}

    body = source_version.text_content or ""
    body_lines = _clean_lines(body)

    def amount_text(pattern: re.Pattern[str]) -> str:
        amount = _find_amount(pattern, body)
        return "" if amount is None else str(amount)

    # The raw merchant guess also serves as normalized name and counterparty.
    merchant = _guess_merchant(body_lines) or ""
    found_date = _parse_date(body)
    method_hit = PAYMENT_METHOD_RE.search(body)
    number_hit = RECEIPT_NUM_RE.search(body)
    details = _extract_extra(body_lines, body)

    return {
        "merchant_raw": merchant,
        "merchant_normalized": merchant,
        "transaction_date": found_date.isoformat() if found_date else "",
        "transaction_time": _parse_time(body) or "",
        "subtotal": amount_text(SUBTOTAL_RE),
        "tax": amount_text(TAX_RE),
        "total": amount_text(TOTAL_RE),
        "currency": "USD",
        "payment_method": method_hit.group(1).upper() if method_hit else "",
        "receipt_number": number_hit.group(1) if number_hit else "",
        "location": _guess_location(body_lines) or "",
        "counterparty": merchant,
        "extra_json": json.dumps(details, indent=2, sort_keys=True) if details else "{}",
    }
def get_current_extracted_fields(document: Document) -> ExtractedField | None:
    """Return the document's most recently touched extracted-fields row.

    Rows are ranked by ``updated_at``, falling back to ``created_at`` when
    ``updated_at`` is empty.

    Returns:
        The latest row, or ``None`` when the document has none.
    """
    rows = document.extracted_fields
    if not rows:
        return None
    return max(rows, key=lambda row: row.updated_at or row.created_at)
def save_extracted_fields(
    db: Session,
    document: Document,
    merchant_raw: str,
    merchant_normalized: str,
    transaction_date: str,
    transaction_time: str,
    subtotal: str,
    tax: str,
    total: str,
    currency: str,
    payment_method: str,
    receipt_number: str,
    location: str,
    counterparty: str,
    extra_json: str,
) -> ExtractedField:
    """Create or update the document's extracted-fields row and commit it.

    The row returned by ``get_current_extracted_fields`` is updated in place;
    when the document has none, a new ``ExtractedField`` is created and added
    to the session. Empty strings are stored as ``None``.

    Args:
        db: Session used to add, commit and refresh the row.
        document: Document the fields belong to.
        transaction_date: ``YYYY-MM-DD`` text, or empty for no date.
        subtotal, tax, total: Amount text; values ``_to_decimal`` cannot
            convert are stored as ``None``.
        extra_json: JSON text. Blank input is stored as ``{}``; text that is
            not valid JSON is stored as ``{"raw_text": <the text>}``.
        merchant_raw, merchant_normalized, transaction_time, currency,
        payment_method, receipt_number, location, counterparty: Stored as
            given, with empty values stored as ``None``.

    Returns:
        The refreshed row after the commit.

    Raises:
        ValueError: If *transaction_date* is non-empty and not ``YYYY-MM-DD``.
    """
    # Parse everything that can fail BEFORE touching the session, so invalid
    # input cannot leave a half-populated pending row behind for a later commit.
    parsed_date = (
        datetime.strptime(transaction_date, "%Y-%m-%d").date() if transaction_date else None
    )
    # Treat a missing extra_json like blank text, as the other fields do.
    extra_text = extra_json or ""
    try:
        parsed_extra = json.loads(extra_text) if extra_text.strip() else {}
    except json.JSONDecodeError:
        # Keep what was typed rather than discarding it.
        parsed_extra = {"raw_text": extra_json}

    current = get_current_extracted_fields(document)
    if current is None:
        current = ExtractedField(document_id=document.id)
        db.add(current)

    current.merchant_raw = merchant_raw or None
    current.merchant_normalized = merchant_normalized or None
    current.transaction_date = parsed_date
    current.transaction_time = transaction_time or None
    current.subtotal = _to_decimal(subtotal)
    current.tax = _to_decimal(tax)
    current.total = _to_decimal(total)
    current.currency = currency or None
    current.payment_method = payment_method or None
    current.receipt_number = receipt_number or None
    current.location = location or None
    current.counterparty = counterparty or None
    current.extra_json = parsed_extra

    db.commit()
    db.refresh(current)
    return current