245 lines
7.3 KiB
Python
245 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from datetime import datetime
|
|
from decimal import Decimal, InvalidOperation
|
|
|
|
from sqlalchemy.orm import Session, selectinload
|
|
|
|
from app.models.document import Document
|
|
from app.models.extracted_field import ExtractedField
|
|
from app.models.text_version import TextVersion
|
|
|
|
|
|
# Dollar amount with exactly two decimals, optional "$" prefix (e.g. "$ 12.34").
MONEY_RE = re.compile(r"\$?\s*([0-9]+(?:\.[0-9]{2}))")

# Date layouts, tried in this order: M/D/YYYY first, then YYYY-MM-DD.
# Each pattern captures three groups in the order they appear in the text.
DATE_PATTERNS = [
    re.compile(r"\b(\d{1,2})/(\d{1,2})/(\d{4})\b"),
    re.compile(r"\b(\d{4})-(\d{2})-(\d{2})\b"),
]

# 12-hour clock times with an AM/PM marker; seconds are optional in the first
# pattern. The second pattern matches a subset of the first and is kept only
# so the list keeps its original shape.
TIME_PATTERNS = [
    re.compile(r"\b(\d{1,2}:\d{2}(?::\d{2})?\s?(?:AM|PM|am|pm))\b"),
    re.compile(r"\b(\d{1,2}:\d{2}\s?(?:am|pm|AM|PM))\b"),
]

# Whole-line "<label> <amount>" matchers (case-insensitive, multi-line).
# The amount must end the line and have exactly two decimals.
TOTAL_RE = re.compile(r"(?im)^\s*total\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
SUBTOTAL_RE = re.compile(r"(?im)^\s*sub\.?\s*total\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
TAX_RE = re.compile(r"(?im)^\s*tax\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")

# Receipt / order identifier following one of several labels.
# The trailing word boundary is applied only to the labels that end in a
# letter. "#" is not a word character, so requiring "\b" after "receipt #"
# would reject the common "Receipt # 123" and "Receipt #: 123" layouts.
RECEIPT_NUM_RE = re.compile(
    r"\b(?:(?:order\s+number|receipt\s+number|tran\s+seq\s+no)\b|receipt\s*#)"
    r"[:\s]*([A-Za-z0-9\-]+)",
    re.IGNORECASE,
)

# First payment keyword appearing anywhere in the text.
PAYMENT_METHOD_RE = re.compile(r"\b(visa|mastercard|discover|amex|american express|cash|debit)\b", re.IGNORECASE)

# Four digits following a run of at least four asterisks (masked card number).
CARD_LAST4_RE = re.compile(r"\*{4,}\s*([0-9]{4})")

# "#" followed by a number of three or more digits; leading zeros are dropped.
STORE_NUM_RE = re.compile(r"#\s*0*([0-9]{3,})")
|
|
|
|
|
|
def _get_current_reviewed_text(document: Document) -> TextVersion | None:
    """Pick the text version that extraction should read from.

    Current ``"reviewed"`` versions take priority; if there are none, current
    ``"raw_ocr"`` versions are considered instead. Within the chosen kind the
    entry with the greatest ``(version_number, created_at)`` wins.

    Returns ``None`` when the document has no current version of either kind.
    """
    for wanted_kind in ("reviewed", "raw_ocr"):
        candidates = [
            version
            for version in document.text_versions
            if version.version_type == wanted_kind and version.is_current
        ]
        if candidates:
            # max() keeps the first of several equal keys, matching a stable
            # descending sort followed by taking element 0.
            return max(candidates, key=lambda v: (v.version_number, v.created_at))
    return None
|
|
|
|
|
|
def _clean_lines(text: str) -> list[str]:
    """Split *text* into lines, trim each one, and drop the ones left empty."""
    trimmed = (raw_line.strip() for raw_line in text.splitlines())
    return [entry for entry in trimmed if entry]
|
|
|
|
|
|
def _parse_date(text: str):
    """Return the first valid calendar date found in *text*, or ``None``.

    The patterns in ``DATE_PATTERNS`` are tried in list order, and every
    match of a pattern is examined, so an impossible date such as
    ``13/45/2024`` does not hide a valid date that appears later in the text.

    The result is a ``datetime.date``.
    """
    for pattern in DATE_PATTERNS:
        for match in pattern.finditer(text):
            parts = match.groups()
            # A four-character first group means the year comes first
            # (YYYY-MM-DD); otherwise the match is month/day/year. Deciding
            # from the captured text avoids depending on the regex source.
            if len(parts[0]) == 4:
                joined, layout = "-".join(parts), "%Y-%m-%d"
            else:
                joined, layout = "/".join(parts), "%m/%d/%Y"
            try:
                return datetime.strptime(joined, layout).date()
            except ValueError:
                # Digits matched but do not form a real date; keep looking.
                continue
    return None
|
|
|
|
|
|
def _parse_time(text: str) -> str | None:
    """Return the first clock time found in *text*, or ``None``.

    Patterns in ``TIME_PATTERNS`` are tried in list order; the captured text
    of the first pattern that matches is returned with surrounding whitespace
    removed.
    """
    attempts = (candidate.search(text) for candidate in TIME_PATTERNS)
    found = next((hit for hit in attempts if hit), None)
    if found is None:
        return None
    return found.group(1).strip()
|
|
|
|
|
|
# A number written with well-formed thousands separators, e.g. "1,234.56".
_GROUPED_NUMBER_RE = re.compile(r"[+-]?\d{1,3}(?:,\d{3})+(?:\.\d+)?")


def _to_decimal(value: str | None) -> Decimal | None:
    """Convert a user- or OCR-supplied amount to ``Decimal``.

    String input is trimmed, a leading ``$`` is dropped, and commas are
    removed only when they form proper thousands groups (``"1,234.56"``);
    other comma usage such as ``"12,5"`` is left alone and therefore rejected.

    Returns ``None`` for empty input, for text that is not a number, and for
    non-finite results (``"NaN"``, ``"Infinity"``), which are not meaningful
    monetary amounts.
    """
    if not value:
        return None

    candidate = value
    if isinstance(candidate, str):
        candidate = candidate.strip()
        if candidate.startswith("$"):
            candidate = candidate[1:].lstrip()
        if _GROUPED_NUMBER_RE.fullmatch(candidate):
            candidate = candidate.replace(",", "")

    try:
        number = Decimal(candidate)
    except (InvalidOperation, TypeError):
        return None

    # Decimal happily parses "NaN" and "Infinity"; treat those as "no amount".
    return number if number.is_finite() else None
|
|
|
|
|
|
def _find_amount(pattern: re.Pattern[str], text: str) -> Decimal | None:
    """Search *text* with *pattern* and convert capture group 1 to a Decimal.

    Returns ``None`` when the pattern does not match or when ``_to_decimal``
    cannot convert the captured text.
    """
    hit = pattern.search(text)
    return _to_decimal(hit.group(1)) if hit else None
|
|
|
|
|
|
def _clean_merchant_name(line: str) -> str:
    """Trim *line* and remove one leading greeting phrase, if present.

    The comparison is case-insensitive. Only the first greeting that matches
    is removed, and the remainder is trimmed again before being returned.
    """
    greetings = (
        "welcome to ",
        "thank you for shopping at ",
        "thank you for visiting ",
    )
    name = line.strip()
    folded = name.lower()
    matched = next((phrase for phrase in greetings if folded.startswith(phrase)), None)
    if matched is None:
        return name
    return name[len(matched):].strip()
|
|
|
|
|
|
def _guess_merchant(lines: list[str]) -> str | None:
    """Guess the merchant name from the top of the receipt.

    Among the first five lines, the first one that is at least three
    characters long and has no digit in its first eight characters is chosen.
    If none qualifies, the very first line is used. The chosen line is passed
    through ``_clean_merchant_name``. Returns ``None`` for an empty list.
    """
    if not lines:
        return None

    def _looks_like_name(entry: str) -> bool:
        return len(entry) >= 3 and not any(ch.isdigit() for ch in entry[:8])

    chosen = next((entry for entry in lines[:5] if _looks_like_name(entry)), lines[0])
    return _clean_merchant_name(chosen)
|
|
|
|
|
|
def _guess_location(lines: list[str]) -> str | None:
    """Guess a location line from lines two through six of the receipt.

    The first of those lines that contains a digit, a comma, or an opening
    parenthesis is returned unchanged; ``None`` if no line qualifies.
    """

    def _has_location_hint(entry: str) -> bool:
        if any(ch.isdigit() for ch in entry):
            return True
        return any(mark in entry for mark in (",", "("))

    return next((entry for entry in lines[1:6] if _has_location_hint(entry)), None)
|
|
|
|
|
|
def _extract_extra(lines: list[str], text: str) -> dict:
    """Collect optional receipt details that have no dedicated field.

    Possible keys, each present only when found:

    * ``card_last4`` - capture of ``CARD_LAST4_RE`` in *text*
    * ``store_number`` - capture of ``STORE_NUM_RE`` in *text*
    * ``cashier`` - the first entry of *lines* containing the word "cashier"
      (case-insensitive), stored as the whole line
    """
    details: dict = {}

    regex_fields = (
        ("card_last4", CARD_LAST4_RE),
        ("store_number", STORE_NUM_RE),
    )
    for key, regex in regex_fields:
        hit = regex.search(text)
        if hit:
            details[key] = hit.group(1)

    cashier_line = next(
        (entry for entry in lines if re.search(r"\bcashier\b", entry, re.IGNORECASE)),
        None,
    )
    if cashier_line:
        details["cashier"] = cashier_line

    return details
|
|
|
|
|
|
def auto_extract_from_document(db: Session, document: Document) -> dict:
    """Derive receipt fields from the document's current text using heuristics.

    The text comes from ``_get_current_reviewed_text``; when that yields
    nothing, an empty dict is returned. Otherwise every value in the result is
    a string (empty when the field was not found) and the keys are the same
    names that ``save_extracted_fields`` takes as keyword parameters.

    ``db`` is accepted but not used by this function. The currency is always
    reported as ``"USD"``, and the merchant guess is reused for
    ``merchant_normalized`` and ``counterparty``.
    """
    source = _get_current_reviewed_text(document)
    if source is None:
        return {}

    body = source.text_content or ""
    rows = _clean_lines(body)

    def _amount_text(amount) -> str:
        # Decimal zero is falsy, so test against None explicitly.
        return "" if amount is None else str(amount)

    merchant = _guess_merchant(rows)
    found_date = _parse_date(body)
    found_time = _parse_time(body)

    subtotal = _find_amount(SUBTOTAL_RE, body)
    tax = _find_amount(TAX_RE, body)
    total = _find_amount(TOTAL_RE, body)

    payment_hit = PAYMENT_METHOD_RE.search(body)
    payment_method = payment_hit.group(1).upper() if payment_hit else None

    receipt_hit = RECEIPT_NUM_RE.search(body)
    receipt_number = receipt_hit.group(1) if receipt_hit else None

    location = _guess_location(rows)
    extra = _extract_extra(rows, body)

    if extra:
        extra_text = json.dumps(extra, indent=2, sort_keys=True)
    else:
        extra_text = "{}"

    return {
        "merchant_raw": merchant or "",
        "merchant_normalized": merchant or "",
        "transaction_date": found_date.isoformat() if found_date else "",
        "transaction_time": found_time or "",
        "subtotal": _amount_text(subtotal),
        "tax": _amount_text(tax),
        "total": _amount_text(total),
        "currency": "USD",
        "payment_method": payment_method or "",
        "receipt_number": receipt_number or "",
        "location": location or "",
        "counterparty": merchant or "",
        "extra_json": extra_text,
    }
|
|
|
|
|
|
def get_current_extracted_fields(document: Document) -> ExtractedField | None:
    """Return the most recently touched extracted-fields record of *document*.

    Records are ranked by ``updated_at``, falling back to ``created_at`` when
    ``updated_at`` is falsy. Returns ``None`` when the document has no records.
    """
    records = document.extracted_fields
    if not records:
        return None

    def _last_touched(record):
        return record.updated_at or record.created_at

    # max() keeps the first of several equal keys, matching a stable
    # descending sort followed by taking element 0.
    return max(records, key=_last_touched)
|
|
|
|
|
|
def save_extracted_fields(
    db: Session,
    document: Document,
    merchant_raw: str,
    merchant_normalized: str,
    transaction_date: str,
    transaction_time: str,
    subtotal: str,
    tax: str,
    total: str,
    currency: str,
    payment_method: str,
    receipt_number: str,
    location: str,
    counterparty: str,
    extra_json: str,
) -> ExtractedField:
    """Write the submitted field values to the document's extracted-fields row.

    The row returned by ``get_current_extracted_fields`` is updated in place;
    if there is none, a new ``ExtractedField`` is created for the document and
    added to the session. The session is committed and the row refreshed
    before it is returned.

    Value handling:

    * Empty text values are stored as ``None``.
    * ``transaction_date`` must be ``YYYY-MM-DD`` when non-empty.
    * Amounts go through ``_to_decimal``.
    * ``extra_json`` is parsed as JSON; blank input becomes ``{}`` and text
      that is not valid JSON is kept as ``{"raw_text": <input>}``.

    Raises:
        ValueError: if ``transaction_date`` is non-empty and not in
            ``YYYY-MM-DD`` form.
    """
    record = get_current_extracted_fields(document)
    if record is None:
        record = ExtractedField(document_id=document.id)
        db.add(record)

    # Assignments follow the original field order so that a failure part-way
    # through (e.g. a malformed date) leaves the same fields already set.
    for attr, submitted in (
        ("merchant_raw", merchant_raw),
        ("merchant_normalized", merchant_normalized),
    ):
        setattr(record, attr, submitted or None)

    if transaction_date:
        record.transaction_date = datetime.strptime(transaction_date, "%Y-%m-%d").date()
    else:
        record.transaction_date = None
    record.transaction_time = transaction_time or None

    for attr, submitted in (("subtotal", subtotal), ("tax", tax), ("total", total)):
        setattr(record, attr, _to_decimal(submitted))

    for attr, submitted in (
        ("currency", currency),
        ("payment_method", payment_method),
        ("receipt_number", receipt_number),
        ("location", location),
        ("counterparty", counterparty),
    ):
        setattr(record, attr, submitted or None)

    try:
        parsed_extra = json.loads(extra_json) if extra_json.strip() else {}
    except json.JSONDecodeError:
        # Keep what the user typed instead of discarding it.
        parsed_extra = {"raw_text": extra_json}
    record.extra_json = parsed_extra

    db.commit()
    db.refresh(record)
    return record
|