# Heuristic field extraction for receipt-like OCR text.
#
# Everything here is best-effort pattern matching: a field that cannot be
# recognised is reported as missing rather than raising.

MONEY_RE = re.compile(r"\$?\s*([0-9]+(?:\.[0-9]{2}))")
DATE_PATTERNS = [
    re.compile(r"\b(\d{1,2})/(\d{1,2})/(\d{4})\b"),
    re.compile(r"\b(\d{4})-(\d{2})-(\d{2})\b"),
]
# (separator, strptime format) for each entry of DATE_PATTERNS, in the same
# order.  Keep the two lists in sync when adding a pattern.
_DATE_LAYOUTS = [
    ("/", "%m/%d/%Y"),
    ("-", "%Y-%m-%d"),
]
TIME_PATTERNS = [
    re.compile(r"\b(\d{1,2}:\d{2}(?::\d{2})?\s?(?:AM|PM|am|pm))\b"),
    re.compile(r"\b(\d{1,2}:\d{2}\s?(?:am|pm|AM|PM))\b"),
]
TOTAL_RE = re.compile(r"(?im)^\s*total\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
SUBTOTAL_RE = re.compile(r"(?im)^\s*sub\.?\s*total\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
TAX_RE = re.compile(r"(?im)^\s*tax\b[:\s]*\$?\s*([0-9]+\.[0-9]{2})\s*$")
# The closing word boundary lives inside each alternative that ends in a word
# character.  A single trailing ``\b`` after ``receipt\s*#`` would only match
# when a word character follows the ``#``, so "Receipt # 123" was never found.
RECEIPT_NUM_RE = re.compile(
    r"\b(?:order\s+number\b|receipt\s+number\b|receipt\s*#|tran\s+seq\s+no\b)[:\s]*([A-Za-z0-9\-]+)",
    re.IGNORECASE,
)
PAYMENT_METHOD_RE = re.compile(r"\b(visa|mastercard|discover|amex|american express|cash|debit)\b", re.IGNORECASE)
CARD_LAST4_RE = re.compile(r"\*{4,}\s*([0-9]{4})")
STORE_NUM_RE = re.compile(r"#\s*0*([0-9]{3,})")
# Amount written with comma thousands separators, e.g. "1,234.50".
_THOUSANDS_RE = re.compile(r"^[+-]?\d{1,3}(?:,\d{3})+(?:\.\d+)?$")


def _get_current_reviewed_text(document: Document) -> TextVersion | None:
    """Return the text version extraction should read, or None.

    A current "reviewed" version wins over a current "raw_ocr" version.
    Within one type the highest (version_number, created_at) is chosen.
    """
    for wanted_type in ("reviewed", "raw_ocr"):
        candidates = [
            tv
            for tv in document.text_versions
            if tv.version_type == wanted_type and tv.is_current
        ]
        if candidates:
            return max(candidates, key=lambda tv: (tv.version_number, tv.created_at))
    return None


def _clean_lines(text: str) -> list[str]:
    """Split *text* into stripped lines, dropping blank ones."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def _parse_date(text: str):
    """Return the first valid calendar date found in *text*, or None.

    Patterns are tried in DATE_PATTERNS order (MM/DD/YYYY, then YYYY-MM-DD).
    Every match of a pattern is tried, so a date-shaped but invalid token
    such as "99/99/2024" does not hide a valid date later in the text.
    """
    for pattern, (separator, layout) in zip(DATE_PATTERNS, _DATE_LAYOUTS):
        for match in pattern.finditer(text):
            try:
                return datetime.strptime(separator.join(match.groups()), layout).date()
            except ValueError:
                # Looks like a date but is not one; keep scanning.
                continue
    return None


def _parse_time(text: str) -> str | None:
    """Return the first 12-hour clock time in *text* as written, or None."""
    for pattern in TIME_PATTERNS:
        match = pattern.search(text)
        if match:
            return match.group(1).strip()
    return None


def _to_decimal(value: str | None) -> Decimal | None:
    """Convert a user- or OCR-supplied amount to Decimal, or None.

    Surrounding whitespace, a dollar sign and comma thousands separators are
    tolerated.  Unparsable input and the non-finite values Decimal accepts
    ("NaN", "Infinity") yield None.
    """
    if not value:
        return None
    if isinstance(value, str):
        value = value.replace("$", "").strip()
        if _THOUSANDS_RE.match(value):
            value = value.replace(",", "")
    try:
        number = Decimal(value)
    except (InvalidOperation, TypeError, ValueError):
        return None
    return number if number.is_finite() else None


def _find_amount(pattern: re.Pattern[str], text: str) -> Decimal | None:
    """Return group 1 of the first *pattern* match in *text* as a Decimal."""
    match = pattern.search(text)
    if not match:
        return None
    return _to_decimal(match.group(1))


def _clean_merchant_name(line: str) -> str:
    """Strip one leading greeting phrase (case-insensitive) from *line*."""
    prefixes = (
        "welcome to ",
        "thank you for shopping at ",
        "thank you for visiting ",
    )
    cleaned = line.strip()
    lowered = cleaned.lower()
    for prefix in prefixes:
        if lowered.startswith(prefix):
            return cleaned[len(prefix):].strip()
    return cleaned


def _guess_merchant(lines: list[str]) -> str | None:
    """Guess the merchant from the first five lines.

    Picks the first line of at least three characters with no digit in its
    first eight characters; falls back to the first line, or None when
    *lines* is empty.
    """
    for line in lines[:5]:
        if len(line) >= 3 and not any(ch.isdigit() for ch in line[:8]):
            return _clean_merchant_name(line)
    return _clean_merchant_name(lines[0]) if lines else None


def _guess_location(lines: list[str]) -> str | None:
    """Return the first of lines 2-6 containing a digit, "," or "(", or None."""
    for line in lines[1:6]:
        if any(ch.isdigit() for ch in line) or "," in line or "(" in line:
            return line
    return None


def _extract_extra(lines: list[str], text: str) -> dict:
    """Collect optional fields: card_last4, store_number and cashier."""
    extra: dict = {}

    match = CARD_LAST4_RE.search(text)
    if match:
        extra["card_last4"] = match.group(1)

    match = STORE_NUM_RE.search(text)
    if match:
        extra["store_number"] = match.group(1)

    # The whole line mentioning the cashier is kept, not just the name.
    cashier = next(
        (line for line in lines if re.search(r"\bcashier\b", line, re.IGNORECASE)),
        None,
    )
    if cashier:
        extra["cashier"] = cashier

    return extra


def auto_extract_from_document(db: Session, document: Document) -> dict:
    """Guess extracted-field form values from the document's current text.

    *db* is not used by this function.  Returns an empty dict when the
    document has no current reviewed or raw OCR text; otherwise a dict of
    strings keyed like the ``save_extracted_fields`` parameters, with ""
    for anything not found and "{}" for empty ``extra_json``.
    """
    text_version = _get_current_reviewed_text(document)
    if text_version is None:
        return {}

    text = text_version.text_content or ""
    lines = _clean_lines(text)

    merchant_raw = _guess_merchant(lines)
    transaction_date = _parse_date(text)
    transaction_time = _parse_time(text)

    subtotal = _find_amount(SUBTOTAL_RE, text)
    tax = _find_amount(TAX_RE, text)
    total = _find_amount(TOTAL_RE, text)

    payment_match = PAYMENT_METHOD_RE.search(text)
    payment_method = payment_match.group(1).upper() if payment_match else None

    receipt_match = RECEIPT_NUM_RE.search(text)
    receipt_number = receipt_match.group(1) if receipt_match else None

    location = _guess_location(lines)
    extra = _extract_extra(lines, text)

    return {
        "merchant_raw": merchant_raw or "",
        # No normalisation step exists yet; the raw guess is reused.
        "merchant_normalized": merchant_raw or "",
        "transaction_date": transaction_date.isoformat() if transaction_date else "",
        "transaction_time": transaction_time or "",
        "subtotal": str(subtotal) if subtotal is not None else "",
        "tax": str(tax) if tax is not None else "",
        "total": str(total) if total is not None else "",
        "currency": "USD",
        "payment_method": payment_method or "",
        "receipt_number": receipt_number or "",
        "location": location or "",
        "counterparty": merchant_raw or "",
        "extra_json": json.dumps(extra, indent=2, sort_keys=True) if extra else "{}",
    }


def get_current_extracted_fields(document: Document) -> ExtractedField | None:
    """Return the most recently touched extracted-fields row, or None.

    Rows are ranked by ``updated_at``, falling back to ``created_at``.  A row
    with neither timestamp ranks oldest instead of breaking the comparison.
    """
    rows = document.extracted_fields
    if not rows:
        return None
    return max(rows, key=lambda row: row.updated_at or row.created_at or datetime.min)


def _parse_extra_json(extra_json: str) -> dict:
    """Parse the free-form JSON form field into a dict.

    Blank input gives ``{}``.  Invalid JSON, or valid JSON that is not an
    object, is preserved verbatim under ``"raw_text"``.
    """
    raw = extra_json or ""
    if not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {"raw_text": raw}
    return parsed if isinstance(parsed, dict) else {"raw_text": raw}


def save_extracted_fields(
    db: Session,
    document: Document,
    merchant_raw: str,
    merchant_normalized: str,
    transaction_date: str,
    transaction_time: str,
    subtotal: str,
    tax: str,
    total: str,
    currency: str,
    payment_method: str,
    receipt_number: str,
    location: str,
    counterparty: str,
    extra_json: str,
) -> ExtractedField:
    """Create or update the document's extracted-fields row and commit.

    All values arrive as form strings.  Empty strings are stored as None.
    Amounts and the date are parsed leniently: a value that cannot be
    parsed is stored as None rather than raising.  Returns the refreshed row.
    """
    current = get_current_extracted_fields(document)
    if current is None:
        current = ExtractedField(document_id=document.id)
        db.add(current)

    current.merchant_raw = merchant_raw or None
    current.merchant_normalized = merchant_normalized or None

    # Same parser as auto-extraction, so YYYY-MM-DD and MM/DD/YYYY both work
    # and a typo in the form does not raise out of the request handler.
    current.transaction_date = _parse_date(transaction_date) if transaction_date else None
    current.transaction_time = transaction_time or None

    current.subtotal = _to_decimal(subtotal)
    current.tax = _to_decimal(tax)
    current.total = _to_decimal(total)
    current.currency = currency or None

    current.payment_method = payment_method or None
    current.receipt_number = receipt_number or None
    current.location = location or None
    current.counterparty = counterparty or None

    current.extra_json = _parse_extra_json(extra_json)

    db.commit()
    db.refresh(current)
    return current
- merchant_raw: Mapped[str | None] = mapped_column(String(255), nullable=True) - merchant_normalized: Mapped[str | None] = mapped_column(String(255), nullable=True) + merchant_raw: Mapped[str | None] = mapped_column(Text, nullable=True) + merchant_normalized: Mapped[str | None] = mapped_column(Text, nullable=True) transaction_date: Mapped[date | None] = mapped_column(Date, nullable=True) + transaction_time: Mapped[str | None] = mapped_column(String(32), nullable=True) subtotal: Mapped[Decimal | None] = mapped_column(Numeric(12, 2), nullable=True) tax: Mapped[Decimal | None] = mapped_column(Numeric(12, 2), nullable=True) total: Mapped[Decimal | None] = mapped_column(Numeric(12, 2), nullable=True) + currency: Mapped[str | None] = mapped_column(String(16), nullable=True) - currency: Mapped[str | None] = mapped_column(String(10), nullable=True) - payment_method: Mapped[str | None] = mapped_column(String(100), nullable=True) - receipt_number: Mapped[str | None] = mapped_column(String(255), nullable=True) - location: Mapped[str | None] = mapped_column(String(255), nullable=True) - counterparty: Mapped[str | None] = mapped_column(String(255), nullable=True) + payment_method: Mapped[str | None] = mapped_column(String(64), nullable=True) + receipt_number: Mapped[str | None] = mapped_column(String(128), nullable=True) + location: Mapped[str | None] = mapped_column(Text, nullable=True) + counterparty: Mapped[str | None] = mapped_column(Text, nullable=True) extra_json: Mapped[dict | None] = mapped_column(JSON, nullable=True) - created_at: Mapped[datetime] = mapped_column( - DateTime, default=datetime.utcnow, nullable=False - ) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column( - DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False + DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + nullable=False, ) document: Mapped["Document"] = 
def _extracted_field_form_values(document: Document, request: Request) -> dict:
    """Build the string values used to pre-fill the extracted-fields form.

    With ``?autofill_extracted=1`` the values come from auto-extraction.
    Otherwise they come from the most recent saved row, if any.  The returned
    dict always contains every form key: blanks ("" and "{}" for
    ``extra_json``) are the baseline and each source only overrides them, so
    the result keeps all keys even when auto-extraction returns nothing.
    """
    import json  # local import: no module-level json import is visible in this file

    values = {
        "merchant_raw": "",
        "merchant_normalized": "",
        "transaction_date": "",
        "transaction_time": "",
        "subtotal": "",
        "tax": "",
        "total": "",
        "currency": "",
        "payment_method": "",
        "receipt_number": "",
        "location": "",
        "counterparty": "",
        "extra_json": "{}",
    }

    if request.query_params.get("autofill_extracted") == "1":
        # No database session is handed to the extractor, matching the
        # existing call (None is passed as `db`).
        values.update(auto_extract_from_document(None, document))
        return values

    current = get_current_extracted_fields(document)
    if current is None:
        return values

    values.update(
        {
            "merchant_raw": current.merchant_raw or "",
            "merchant_normalized": current.merchant_normalized or "",
            "transaction_date": current.transaction_date.isoformat() if current.transaction_date else "",
            "transaction_time": current.transaction_time or "",
            "subtotal": str(current.subtotal) if current.subtotal is not None else "",
            "tax": str(current.tax) if current.tax is not None else "",
            "total": str(current.total) if current.total is not None else "",
            "currency": current.currency or "",
            "payment_method": current.payment_method or "",
            "receipt_number": current.receipt_number or "",
            "location": current.location or "",
            "counterparty": current.counterparty or "",
            "extra_json": (
                "{}"
                if current.extra_json is None
                else json.dumps(current.extra_json, indent=2, sort_keys=True)
            ),
        }
    )
    return values
@router.post("/{document_id}/save-extracted-fields", response_class=RedirectResponse)
def save_extracted_fields_route(
    document_id: str,
    merchant_raw: str = Form(""),
    merchant_normalized: str = Form(""),
    transaction_date: str = Form(""),
    transaction_time: str = Form(""),
    subtotal: str = Form(""),
    tax: str = Form(""),
    total: str = Form(""),
    currency: str = Form(""),
    payment_method: str = Form(""),
    receipt_number: str = Form(""),
    location: str = Form(""),
    counterparty: str = Form(""),
    extra_json: str = Form("{}"),
    db: Session = Depends(get_db),
):
    """Persist the submitted extracted-fields form for one document.

    Looks the document up by its public ``document_id``.  An unknown id
    redirects to the document list; otherwise the values are saved and the
    browser is redirected (303) back to the detail page with autofill off.
    """
    document = (
        db.query(Document)
        # Only the extracted-fields rows are read while saving, so the text
        # versions are not eager-loaded here.
        .options(selectinload(Document.extracted_fields))
        .filter(Document.document_id == document_id)
        .first()
    )
    if document is None:
        return RedirectResponse(url="/documents/", status_code=303)

    save_extracted_fields(
        db=db,
        document=document,
        merchant_raw=merchant_raw,
        merchant_normalized=merchant_normalized,
        transaction_date=transaction_date,
        transaction_time=transaction_time,
        subtotal=subtotal,
        tax=tax,
        total=total,
        currency=currency,
        payment_method=payment_method,
        receipt_number=receipt_number,
        location=location,
        counterparty=counterparty,
        extra_json=extra_json,
    )

    return RedirectResponse(url=f"/documents/{document.document_id}?autofill_extracted=0", status_code=303)
_get_current_text_versions(document) + editor_source = request.query_params.get("editor_source", "reviewed") review_text_value = _build_review_text_value(raw_ocr, reviewed_ocr, editor_source) - base_layout = ( - reviewed_ocr.layout_json if reviewed_ocr and reviewed_ocr.layout_json - else raw_ocr.layout_json if raw_ocr else None - ) expected_line_count = _line_count_from_layout(raw_ocr.layout_json if raw_ocr else None) actual_line_count = len(review_text_value.splitlines()) if review_text_value else 0 line_numbers = list(range(1, max(actual_line_count, expected_line_count) + 1)) @@ -324,6 +364,9 @@ def document_detail(document_id: str, request: Request, db: Session = Depends(ge error_expected = request.query_params.get("expected") error_actual = request.query_params.get("actual") + extracted_form = _extracted_field_form_values(document, request) + current_extracted = get_current_extracted_fields(document) + return templates.TemplateResponse( request=request, name="documents/detail.html", @@ -344,5 +387,7 @@ def document_detail(document_id: str, request: Request, db: Session = Depends(ge "error": error, "error_expected": error_expected, "error_actual": error_actual, + "extracted_form": extracted_form, + "current_extracted": current_extracted, }, ) diff --git a/app/templates/documents/detail.html b/app/templates/documents/detail.html index ffd5e5b..61b775a 100644 --- a/app/templates/documents/detail.html +++ b/app/templates/documents/detail.html @@ -221,5 +221,77 @@ textarea.addEventListener("input", updateEditorState); updateEditorState(); + +

Extracted fields

+ + {% if current_extracted %} +

Current extracted fields last updated at {{ current_extracted.updated_at }}

+ {% else %} +

No extracted fields saved yet.

+ {% endif %} + +
+ + +
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+
+ +
+
+ +
+
+