#!/usr/bin/env python3
"""Local OCR with Baidu/PaddleOCR models.

This is "unlimited" in the practical sense that it runs locally and does not
meter pages through a hosted OCR API. Accuracy, speed, and model downloads
depend on your machine and PaddleOCR installation.

Requires:
  pip install paddleocr

Install the PaddlePaddle runtime that matches your platform from:
  https://www.paddlepaddle.org.cn/install/quick

Usage:
  python scripts/ocr/paddle_unlimited_ocr.py input/ output/
"""

from __future__ import annotations

import argparse
import json
import numbers
import sys
from collections import Counter
from pathlib import Path
from typing import Any, Iterable, NoReturn

# File extensions accepted as OCR inputs. Paths are matched against these using
# their lower-cased suffix, so ".PDF" and ".pdf" are treated the same; anything
# else found under the input path is skipped.
SUPPORTED_SUFFIXES = {".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"}


def fail(message: str) -> NoReturn:
    """Print ``error: <message>`` to stderr and exit the process with status 1.

    Annotated ``NoReturn`` (it always raises ``SystemExit``) so type checkers
    know code after a ``fail(...)`` call is unreachable, e.g. a name bound only
    on the success path of a ``try``/``except`` is not "possibly unbound".
    """
    print(f"error: {message}", file=sys.stderr)
    raise SystemExit(1)


def iter_inputs(input_path: Path) -> list[Path]:
    if input_path.is_file():
        return [input_path] if input_path.suffix.lower() in SUPPORTED_SUFFIXES else []
    if input_path.is_dir():
        return sorted(path for path in input_path.rglob("*") if path.suffix.lower() in SUPPORTED_SUFFIXES)
    return []


def as_mapping(obj: Any) -> dict[str, Any]:
    """Return *obj* as a dict, unwrapping PaddleOCR-style result objects.

    A dict is returned unchanged. Otherwise the ``json`` attribute is tried,
    then the ``res`` attribute, and the first one that is a dict is returned.
    If neither is a dict, an empty dict is returned.
    """
    if isinstance(obj, dict):
        return obj
    # Lazily probe attributes in priority order; stop at the first dict.
    candidates = (getattr(obj, attr_name, None) for attr_name in ("json", "res"))
    return next((found for found in candidates if isinstance(found, dict)), {})


def _is_text_score_pair(value: Any) -> bool:
    """Return True for a legacy PaddleOCR ``(text, confidence)`` pair."""
    return (
        len(value) == 2
        and isinstance(value[0], str)
        and isinstance(value[1], numbers.Real)
        and not isinstance(value[1], bool)
    )


def collect_text_values(value: Any) -> list[str]:
    """Extract likely OCR text fields from several PaddleOCR result shapes.

    Handles:

    * ``predict()``-style results: dicts, or objects exposing a dict via
      ``json``/``res``, whose ``rec_texts`` / ``texts`` / ``text`` keys hold the
      recognised strings. These may be nested under ``res``, ``ocr_result``,
      ``pages`` or ``results``.
    * Legacy ``ocr()`` output: nested lists of ``[box, (text, score)]`` lines.
      There the text lives in a ``(str, number)`` pair; list or tuple form
      are both accepted.

    Returns the non-blank strings in traversal order. Duplicates are kept.
    """
    if value is None or isinstance(value, (str, bytes, numbers.Number)):
        # Scalars carry no text on their own: a bare string only counts when it
        # sits under one of the text keys below, and box coordinates are numbers.
        return []

    texts: list[str] = []

    if isinstance(value, dict):
        for key in ("rec_texts", "texts", "text"):
            item = value.get(key)
            if isinstance(item, str):
                texts.append(item)
            elif isinstance(item, Iterable) and not isinstance(item, (bytes, str, dict)):
                texts.extend(str(part) for part in item if part)

        # PaddleOCR sometimes nests the useful payload under "res".
        if "res" in value:
            texts.extend(collect_text_values(value["res"]))

        for key in ("ocr_result", "pages", "results"):
            if key in value:
                texts.extend(collect_text_values(value[key]))

    elif isinstance(value, (list, tuple)):
        if _is_text_score_pair(value):
            texts.append(value[0])
        else:
            for item in value:
                texts.extend(collect_text_values(item))
    else:
        mapping = as_mapping(value)
        if mapping:
            texts.extend(collect_text_values(mapping))

    return [text for text in texts if text and text.strip()]


def result_to_text(result: Any) -> str:
    """Join the OCR text found in *result* into one newline-separated string.

    Each extracted line is stripped and only its first occurrence is kept, in
    original order. This de-duplication spans the whole result, so a line that
    legitimately repeats (e.g. a header on several PDF pages) appears once.
    """
    # dict.fromkeys keeps insertion order and drops later duplicates.
    first_occurrences = dict.fromkeys(line.strip() for line in collect_text_values(result))
    return "\n".join(first_occurrences).strip()


def result_to_jsonable(result: Any) -> Any:
    """Recursively convert a PaddleOCR result into data ``json.dumps`` accepts.

    Conversion rules, applied recursively:

    * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    * Objects exposing a dict via a ``json`` attribute use that view. Newer
      result objects provide one that omits raw arrays and images.
    * Dicts keep JSON-compatible keys and stringify any other key.
    * Lists and tuples become lists.
    * Array-likes with ``tolist()`` (e.g. numpy arrays and scalars) are
      converted through it.
    * Other objects go through ``as_mapping``; if that yields nothing, their
      ``repr`` is used.
    """
    if result is None or isinstance(result, (str, int, float, bool)):
        return result

    json_view = getattr(result, "json", None)
    if isinstance(json_view, dict) and json_view is not result:
        return result_to_jsonable(json_view)

    if isinstance(result, dict):
        converted: dict[Any, Any] = {}
        for key, value in result.items():
            if not (key is None or isinstance(key, (str, int, float, bool))):
                # json.dumps rejects e.g. tuple keys outright.
                key = str(key)
            converted[key] = result_to_jsonable(value)
        return converted

    if isinstance(result, (list, tuple)):
        return [result_to_jsonable(item) for item in result]

    to_list = getattr(result, "tolist", None)
    if callable(to_list):
        return result_to_jsonable(to_list())

    mapping = as_mapping(result)
    if mapping:
        return result_to_jsonable(mapping)
    return repr(result)


def make_output_path(input_root: Path, source: Path, output_dir: Path) -> Path:
    """Map *source* to its Markdown output path under *output_dir*.

    When *input_root* is a directory, the source's path relative to it is
    mirrored under *output_dir*. Otherwise only the file name is used. The
    final suffix is replaced with ``.md``, so sources that differ only by
    extension (``a.pdf`` / ``a.png``) map to the same path.
    """
    relative = source.relative_to(input_root) if input_root.is_dir() else source.name
    return output_dir.joinpath(relative).with_suffix(".md")


def create_ocr(args: argparse.Namespace) -> Any:
    """Build a PaddleOCR engine from the parsed CLI options.

    First tries the full keyword set (language, orientation/unwarping flags,
    optional ``ocr_version`` and ``device``). If the installed PaddleOCR
    rejects it with ``TypeError``, the engine is rebuilt with only ``lang``
    (plus ``ocr_version`` when given). A warning on stderr then lists every
    enabled option that was dropped, so the user knows the run is not using
    the requested settings.

    Exits via ``fail`` when ``paddleocr`` cannot be imported.
    """
    try:
        from paddleocr import PaddleOCR
    except ImportError as exc:
        # The import can also fail when paddleocr is installed but its Paddle
        # runtime is not, so surface the underlying reason.
        fail(
            f"cannot import paddleocr ({exc}). Run: pip install paddleocr "
            "and install a PaddlePaddle runtime for your platform"
        )

    kwargs: dict[str, Any] = {
        "lang": args.lang,
        "use_doc_orientation_classify": args.use_doc_orientation,
        "use_doc_unwarping": args.use_doc_unwarping,
        "use_textline_orientation": args.use_textline_orientation,
    }
    if args.ocr_version:
        kwargs["ocr_version"] = args.ocr_version
    if args.device:
        kwargs["device"] = args.device

    try:
        return PaddleOCR(**kwargs)
    except TypeError as exc:
        # Older PaddleOCR versions use different option names. NOTE: this also
        # catches a TypeError raised for any other reason inside the constructor.
        fallback_kwargs: dict[str, Any] = {"lang": args.lang}
        if args.ocr_version:
            fallback_kwargs["ocr_version"] = args.ocr_version
        dropped = sorted(
            key for key, value in kwargs.items() if value and key not in fallback_kwargs
        )
        if dropped:
            print(
                f"warning: PaddleOCR rejected the requested options ({exc}); "
                f"continuing without: {', '.join(dropped)}",
                file=sys.stderr,
            )
        return PaddleOCR(**fallback_kwargs)


def predict(ocr: Any, source: Path) -> Any:
    """Run OCR on *source* and return the engine's raw result.

    Engines with a ``predict`` method are called as ``predict(input=path)``.
    Otherwise the ``ocr(path, cls=True)`` call is used.
    """
    path_str = str(source)
    if not hasattr(ocr, "predict"):
        return ocr.ocr(path_str, cls=True)
    return ocr.predict(input=path_str)


def write_output(output_path: Path, source: Path, text: str, raw_result: Any, include_json: bool) -> None:
    """Write the Markdown file for one source, plus an optional JSON dump.

    The Markdown starts with an HTML comment naming *source*, followed by
    *text*. When *text* is empty, a placeholder comment is written instead.
    Parent directories are created as needed. With *include_json*, the raw
    result is passed through ``result_to_jsonable`` and written next to the
    Markdown file with a ``.json`` suffix.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    body = f"{text}\n" if text else "<!-- no OCR text extracted -->\n"
    output_path.write_text(f"<!-- source: {source} -->\n\n{body}", encoding="utf-8")

    if not include_json:
        return
    serialized = json.dumps(result_to_jsonable(raw_result), ensure_ascii=False, indent=2)
    output_path.with_suffix(".json").write_text(serialized, encoding="utf-8")


def _plan_output_paths(input_root: Path, sources: list[Path], output_dir: Path) -> list[Path]:
    """Map each source to a Markdown output path without collisions.

    ``make_output_path`` drops the source's extension, so ``scan.pdf`` and
    ``scan.png`` in the same folder would both become ``scan.md``, and the
    second would silently overwrite the first. Colliding sources keep their
    full file name instead (``scan.pdf.md`` / ``scan.png.md``).
    """
    planned = [make_output_path(input_root, source, output_dir) for source in sources]
    counts = Counter(planned)
    return [
        path.with_name(f"{source.name}.md") if counts[path] > 1 else path
        for source, path in zip(sources, planned)
    ]


def main() -> None:
    """CLI entry point: OCR every supported input and write one Markdown file each.

    Per-file errors are reported on stderr and skipped, so one bad file does
    not abort the batch. If any file failed, the process still exits with
    status 1 after the batch, so scripts can detect it.
    """
    parser = argparse.ArgumentParser(description="Run local OCR with Baidu/PaddleOCR models.")
    parser.add_argument("input_path", type=Path, help="PDF/image file or folder to process.")
    parser.add_argument("output_dir", type=Path, help="Folder where Markdown files will be written.")
    parser.add_argument("--lang", default="en", help="PaddleOCR language code. Use 'ch' for Chinese.")
    parser.add_argument("--ocr-version", default=None, help="Optional PaddleOCR version, for example PP-OCRv5.")
    parser.add_argument("--device", default=None, help="Optional device string, for example cpu, gpu:0, or mps.")
    parser.add_argument("--use-doc-orientation", action="store_true", help="Enable document orientation classification.")
    parser.add_argument("--use-doc-unwarping", action="store_true", help="Enable document unwarping.")
    parser.add_argument("--use-textline-orientation", action="store_true", help="Enable text-line orientation classification.")
    parser.add_argument("--include-json", action="store_true", help="Write raw PaddleOCR JSON next to each Markdown file.")
    args = parser.parse_args()

    input_path = args.input_path.expanduser().resolve()
    output_dir = args.output_dir.expanduser().resolve()
    inputs = iter_inputs(input_path)
    if not inputs:
        fail(f"no supported PDF/image files found: {input_path}")

    output_dir.mkdir(parents=True, exist_ok=True)
    output_paths = _plan_output_paths(input_path, inputs, output_dir)
    ocr = create_ocr(args)

    failures = 0
    for index, (source, output_path) in enumerate(zip(inputs, output_paths), start=1):
        print(f"ocr {index}/{len(inputs)}: {source}")
        try:
            result = predict(ocr, source)
            text = result_to_text(result)
            write_output(output_path, source, text, result, args.include_json)
        except Exception as exc:
            # Best-effort batch: report this file and continue with the next.
            failures += 1
            print(f"failed: {source}: {exc}", file=sys.stderr)

    if failures:
        fail(f"{failures} of {len(inputs)} file(s) failed")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
