#!/usr/bin/env python3
"""Batch-convert PDFs to Markdown with Mistral OCR.

Requires:
  pip install mistralai

Usage:
  MISTRAL_API_KEY=... python scripts/ocr/mistral_batch_ocr.py pdfs/ markdown/
"""

from __future__ import annotations

import argparse
import base64
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, NoReturn

from mistralai import Mistral


# Upper-cased batch job statuses. Polling continues only while a job's status is
# in RUNNING_STATUSES and not in TERMINAL_STATUSES; any other value ends the wait.
TERMINAL_STATUSES = {
    "SUCCESS",
    "SUCCEEDED",
    "FAILED",
    "TIMEOUT_EXCEEDED",
    "CANCELLED",
    "CANCELED",
}
RUNNING_STATUSES = {
    "QUEUED",
    "RUNNING",
    "STARTED",
    "VALIDATING",
    # A cancel was requested but the job has not stopped yet; keep waiting for CANCELLED.
    "CANCELLATION_REQUESTED",
}


def fail(message: str) -> NoReturn:
    """Print ``error: <message>`` to stderr and exit the process with status 1.

    Annotated ``NoReturn`` (it always raises ``SystemExit``) so type checkers
    treat code after a ``fail(...)`` call as unreachable and narrow values
    checked in the guarding condition.
    """
    print(f"error: {message}", file=sys.stderr)
    raise SystemExit(1)


def find_pdfs(input_dir: Path, max_size_mb: int) -> list[Path]:
    """Return the PDF files under *input_dir* (recursively) no larger than *max_size_mb*.

    The ``.pdf`` extension is matched case-insensitively, so ``scan.PDF`` is
    included. Only regular files (following symlinks) are considered, so
    directories or dangling links whose name ends in ``.pdf`` are ignored.
    Files above the size cap are reported with a ``skip:`` line and excluded.
    The result is sorted by path.
    """
    max_bytes = max_size_mb * 1024 * 1024
    pdfs: list[Path] = []
    skipped: list[tuple[Path, float]] = []

    # rglob("*.pdf") is case-sensitive on POSIX and would miss "*.PDF"; filter on
    # the lower-cased suffix instead, and require a regular file so stat()/read
    # later cannot fail on a directory or broken symlink.
    candidates = (
        path
        for path in input_dir.rglob("*")
        if path.suffix.lower() == ".pdf" and path.is_file()
    )
    for path in sorted(candidates):
        size = path.stat().st_size
        if size <= max_bytes:
            pdfs.append(path)
        else:
            skipped.append((path.relative_to(input_dir), size / (1024 * 1024)))

    for rel_path, size_mb in skipped:
        print(f"skip: {rel_path} is {size_mb:.1f} MB, above --max-size-mb")

    return pdfs


def data_url(path: Path) -> str:
    """Return the contents of *path* as a ``data:application/pdf;base64,...`` URL.

    The whole file is read into memory and base64-encoded in one step.
    """
    payload = str(base64.b64encode(path.read_bytes()), "ascii")
    return "data:application/pdf;base64," + payload


def make_batch_request(pdf_path: Path, relative_path: Path, include_images: bool) -> dict[str, Any]:
    """Build one batch request object (one JSONL line) for a single PDF.

    ``custom_id`` carries the PDF path relative to the input directory so the
    matching result can be written back to the mirrored location. The request
    body embeds the whole PDF as a base64 data URL and forwards the
    ``include_image_base64`` flag.
    """
    document = {"type": "document_url", "document_url": data_url(pdf_path)}
    body = {"document": document, "include_image_base64": include_images}
    return {"custom_id": str(relative_path), "body": body}


def write_jsonl_batch(input_dir: Path, pdfs: list[Path], batch_path: Path, include_images: bool) -> None:
    """Write one JSON request per line to *batch_path*, one line per PDF in *pdfs*.

    Each PDF's ``custom_id`` is its path relative to *input_dir*; an
    ``encode i/n`` progress line is printed before each one is encoded.
    """
    total = len(pdfs)
    with batch_path.open("w", encoding="utf-8") as batch_file:
        for position, pdf in enumerate(pdfs, start=1):
            rel = pdf.relative_to(input_dir)
            print(f"encode {position}/{total}: {rel}")
            request = make_batch_request(pdf, rel, include_images)
            batch_file.write(json.dumps(request))
            batch_file.write("\n")


def response_bytes(response: Any) -> bytes:
    """Return the raw payload of a file-download response as ``bytes``.

    Checked in order:
      * an object with ``iter_bytes()`` (streamed response): chunks are joined;
      * an object with a ``content`` attribute: that attribute is converted;
      * otherwise *response* itself is converted.
    Conversion keeps ``bytes`` as-is, UTF-8 encodes ``str``, and passes other
    bytes-like/iterable values through ``bytes()``.

    Raises:
        TypeError: if the payload is an ``int`` (``bytes(n)`` would silently
            produce *n* zero bytes instead of data) or otherwise unconvertible.
    """
    if hasattr(response, "iter_bytes"):
        return b"".join(response.iter_bytes())

    payload = response.content if hasattr(response, "content") else response
    if isinstance(payload, bytes):
        return payload
    if isinstance(payload, str):
        # bytes(str) raises without an encoding; results JSONL is UTF-8 text.
        return payload.encode("utf-8")
    if isinstance(payload, int):
        raise TypeError(f"cannot convert {type(payload).__name__} response payload to bytes")
    return bytes(payload)


def get_attr_or_key(obj: Any, name: str, default: Any = None) -> Any:
    """Read *name* from *obj*, returning *default* when it is absent.

    ``dict`` instances are read by key; every other object is read as an
    attribute. This lets SDK model objects and raw JSON dicts be handled alike.
    """
    if not isinstance(obj, dict):
        return getattr(obj, name, default)
    return obj.get(name, default)


def markdown_from_page(page: Any, fallback_index: int) -> str:
    """Return one page's OCR text with trailing whitespace stripped.

    Handles both current Mistral OCR pages and older SDK/dict shapes. Tries the
    ``markdown`` field, then ``text``, then recurses into a nested ``body``.
    If none yields text, returns an HTML comment placeholder naming page
    ``fallback_index + 1``.
    """
    for key in ("markdown", "text"):
        value = get_attr_or_key(page, key)
        if value:
            return str(value).rstrip()

    # Some SDK versions expose a nested response dict after JSONL download.
    nested = get_attr_or_key(page, "body")
    if not nested:
        return f"<!-- no OCR text returned for page {fallback_index + 1} -->"
    return markdown_from_page(nested, fallback_index)


def _safe_output_path(output_dir: Path, relative_pdf: str) -> Path | None:
    """Map a relative PDF path to its ``.md`` path under *output_dir*.

    Returns None when the path is empty, rooted or drive-qualified, or contains ``..``.
    Joining such a path onto *output_dir* could target a file outside it.
    """
    candidate = Path(relative_pdf)
    if not candidate.name or candidate.anchor or ".." in candidate.parts:
        return None
    return output_dir / candidate.with_suffix(".md")


def write_markdown_result(result: dict[str, Any], output_dir: Path) -> bool:
    """Write one batch result as a Markdown file under *output_dir*.

    Args:
        result: One decoded line of the batch output JSONL. The keys read here
            are ``custom_id``, used as the PDF path relative to the input
            directory, and ``response`` with ``status_code`` and ``body``
            (``body.pages``).
        output_dir: Root for Markdown output. The relative path from
            ``custom_id`` is mirrored below it with a ``.md`` suffix.

    Returns:
        True if a Markdown file was written. False if the result was skipped:
        not a JSON object, missing or unsafe ``custom_id``, non-200 status,
        malformed response, or no pages. A reason is printed in each case.

    Only page text is written; image payloads in the response are not saved here.
    """
    if not isinstance(result, dict):
        print(f"skip result that is not a JSON object: {type(result).__name__}")
        return False

    relative_pdf = result.get("custom_id")
    if not relative_pdf:
        print("skip result without custom_id")
        return False

    response = result.get("response") or {}
    if not isinstance(response, dict):
        print(f"failed: {relative_pdf}: unexpected response {response!r}")
        return False
    if response.get("status_code") != 200:
        error = response.get("body") or response.get("error") or response
        print(f"failed: {relative_pdf}: {error}")
        return False

    body = response.get("body") or {}
    pages = (body.get("pages") if isinstance(body, dict) else None) or []
    if not pages:
        print(f"failed: {relative_pdf}: no pages in OCR response")
        return False

    # custom_id comes back from the remote service; never let it escape output_dir.
    output_path = _safe_output_path(output_dir, str(relative_pdf))
    if output_path is None:
        print(f"failed: {relative_pdf}: custom_id is not a safe relative path")
        return False
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with output_path.open("w", encoding="utf-8") as out:
        for position, page in enumerate(pages):
            # The page's own ``index`` is treated as 0-based; fall back to the list
            # position when it is missing or not numeric instead of crashing.
            page_index = get_attr_or_key(page, "index", position)
            try:
                page_number = int(page_index) + 1
            except (TypeError, ValueError):
                page_number = position + 1
            out.write(f"<!-- page {page_number} -->\n\n")
            # Pass the same index so a "no OCR text" placeholder names the same page.
            out.write(markdown_from_page(page, page_number - 1))
            out.write("\n\n---\n\n")

    return True


def process_results(results_path: Path, output_dir: Path) -> int:
    """Turn every result line of *results_path* into Markdown under *output_dir*.

    Blank lines are ignored. Lines that are not valid JSON are reported with
    their 1-based line number and skipped.

    Returns:
        The number of results for which a Markdown file was written.
    """
    written = 0
    with results_path.open("r", encoding="utf-8") as handle:
        for number, raw in enumerate(handle, start=1):
            if not raw.strip():
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError as err:
                print(f"skip malformed JSONL line {number}: {err}")
            else:
                if write_markdown_result(record, output_dir):
                    written += 1
    return written


def wait_for_job(client: Mistral, job_id: str, poll_seconds: int) -> Any:
    """Poll batch job *job_id* every *poll_seconds* until it stops running; return the job.

    After each poll, a one-line progress report (status plus total, succeeded
    and failed request counts, ``?`` when unknown) is redrawn in place with a
    carriage return. Polling stops as soon as the upper-cased status is in
    ``TERMINAL_STATUSES`` or is not in ``RUNNING_STATUSES``, so an unrecognised
    status also ends the wait.
    """
    counters = (
        ("total", "total_requests"),
        ("succeeded", "succeeded_requests"),
        ("failed", "failed_requests"),
    )
    while True:
        job = client.batch.jobs.get(job_id=job_id)
        status = str(get_attr_or_key(job, "status", "")).upper()
        progress = " | ".join(f"{label}: {get_attr_or_key(job, attr, '?')}" for label, attr in counters)
        print(f"\rstatus: {status} | {progress}", end="", flush=True)

        keep_polling = status in RUNNING_STATUSES and status not in TERMINAL_STATUSES
        if not keep_polling:
            print()
            return job
        time.sleep(poll_seconds)


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command line (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="Batch OCR PDF files with Mistral OCR.")
    parser.add_argument("input_dir", type=Path, help="Folder containing PDF files.")
    parser.add_argument("output_dir", type=Path, help="Folder where Markdown files will be written.")
    parser.add_argument("--api-key", default=os.getenv("MISTRAL_API_KEY"), help="Mistral API key. Defaults to MISTRAL_API_KEY.")
    parser.add_argument("--model", default="mistral-ocr-latest", help="Mistral OCR model name.")
    parser.add_argument("--max-size-mb", type=int, default=50, help="Skip PDFs larger than this size. Default: 50.")
    parser.add_argument("--include-images", action="store_true", help="Ask Mistral to include extracted image base64 in OCR results.")
    parser.add_argument("--work-dir", type=Path, default=Path(".ocr-work"), help="Temporary JSONL directory.")
    parser.add_argument("--poll-seconds", type=int, default=5, help="Batch job polling interval.")
    parser.add_argument("--keep-work-files", action="store_true", help="Keep request/result JSONL files after completion.")
    args = parser.parse_args(argv)

    # time.sleep() rejects negative values and 0 would poll the API in a tight loop;
    # a non-positive size cap would skip every non-empty PDF.
    if args.poll_seconds <= 0:
        parser.error("--poll-seconds must be a positive integer")
    if args.max_size_mb <= 0:
        parser.error("--max-size-mb must be a positive integer")
    return args


def _remove_work_files(work_dir: Path, *paths: Path) -> None:
    """Delete the given work files, then remove *work_dir* if it is now empty."""
    for path in paths:
        path.unlink(missing_ok=True)
    try:
        work_dir.rmdir()
    except OSError:
        # Directory still holds other files (or is already gone): best effort only.
        pass


def main() -> None:
    """Find PDFs, run one Mistral OCR batch job over them, and write Markdown files.

    Exits with status 1 via ``fail`` in these cases: the API key is missing,
    the input directory is invalid, no eligible PDFs are found, no upload or
    job id is returned, or the job does not end in SUCCESS/SUCCEEDED with an
    output file. On those failures the JSONL work files stay in ``--work-dir``.
    """
    args = _parse_args()

    if not args.api_key:
        fail("pass --api-key or set MISTRAL_API_KEY")

    input_dir = args.input_dir.expanduser().resolve()
    output_dir = args.output_dir.expanduser().resolve()
    work_dir = args.work_dir.expanduser().resolve()

    if not input_dir.is_dir():
        fail(f"input_dir is not a directory: {input_dir}")

    pdfs = find_pdfs(input_dir, args.max_size_mb)
    if not pdfs:
        fail(f"no PDFs found under {input_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)
    work_dir.mkdir(parents=True, exist_ok=True)

    request_path = work_dir / "mistral_ocr_requests.jsonl"
    result_path = work_dir / "mistral_ocr_results.jsonl"
    write_jsonl_batch(input_dir, pdfs, request_path, args.include_images)

    client = Mistral(api_key=args.api_key)

    print("upload batch file")
    with request_path.open("rb") as request_file:
        uploaded = client.files.upload(
            file={"file_name": request_path.name, "content": request_file},
            purpose="batch",
        )
    # Read the id the same tolerant way as the job fields below.
    uploaded_id = get_attr_or_key(uploaded, "id")
    if not uploaded_id:
        fail("Mistral did not return an uploaded file id")

    print("create batch job")
    job = client.batch.jobs.create(
        input_files=[uploaded_id],
        model=args.model,
        endpoint="/v1/ocr",
        metadata={"job_type": "research_memex_ocr"},
    )

    job_id = get_attr_or_key(job, "id")
    if not job_id:
        fail("Mistral did not return a batch job id")
    # Printed up front so an interrupted or crashed run can still be traced to its job.
    print(f"batch job id: {job_id}")

    finished = wait_for_job(client, job_id, args.poll_seconds)
    status = str(get_attr_or_key(finished, "status", "")).upper()
    output_file = get_attr_or_key(finished, "output_file")
    error_file = get_attr_or_key(finished, "error_file")

    if status not in {"SUCCESS", "SUCCEEDED"} or not output_file:
        fail(f"batch job ended with status={status}; output_file={output_file!r}; error_file={error_file!r}")

    print("download results")
    response = client.files.download(file_id=output_file)
    result_path.write_bytes(response_bytes(response))

    processed = process_results(result_path, output_dir)
    print(f"processed {processed}/{len(pdfs)} PDFs into {output_dir}")
    if processed < len(pdfs) and error_file:
        print(f"note: per-request errors may be listed in batch error file {error_file}")

    if not args.keep_work_files:
        _remove_work_files(work_dir, request_path, result_path)


if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0 on success.
    raise SystemExit(main())
