"""Batch OCR for PDFs using Mistral API"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.

# %% auto 0
# Public API of this module, grouped roughly by pipeline stage
__all__ = [
    'ocr_model', 'ocr_endpoint',
    'get_api_key',
    'upload_pdf', 'create_batch_entry', 'prep_pdf_batch',
    'submit_batch', 'wait_for_job', 'download_results',
    'save_images', 'save_page', 'save_pages',
    'ocr', 'read_pgs',
]

# %% ../nbs/00_core.ipynb 3
from fastcore.all import *
import os, re, json, time, base64, tempfile, logging
from io import BytesIO
from pathlib import Path
from PIL import Image
from mistralai import Mistral

# %% ../nbs/00_core.ipynb 6
def get_api_key(
    key:str=None # Mistral API key
    ):
    "Return `key` when truthy, otherwise the `MISTRAL_API_KEY` environment variable; raise ValueError if neither is set"
    resolved = key if key else os.getenv("MISTRAL_API_KEY")
    if not resolved:
        raise ValueError("MISTRAL_API_KEY not found")
    return resolved

# %% ../nbs/00_core.ipynb 7
# Default model name and batch endpoint used by `submit_batch`
ocr_model, ocr_endpoint = "mistral-ocr-latest", "/v1/ocr"

# %% ../nbs/00_core.ipynb 10
def upload_pdf(
    path:str, # Path to PDF file
    key:str=None # Mistral API key
    ) -> tuple[str, Mistral]: # Mistral pdf signed url and client
    "Upload a local PDF with purpose 'ocr'; return its signed URL and the client that performed the upload"
    client = Mistral(api_key=get_api_key(key))
    pdf = Path(path)
    # NOTE(review): the upload name is the stem, i.e. without the `.pdf` suffix — confirm Mistral does not rely on the extension
    payload = dict(file_name=pdf.stem, content=pdf.read_bytes())
    uploaded = client.files.upload(file=payload, purpose="ocr")
    signed = client.files.get_signed_url(file_id=uploaded.id)
    return signed.url, client

# %% ../nbs/00_core.ipynb 15
def create_batch_entry(
    path:str, # Path to PDF file, 
    url:str, # Mistral signed URL
    cid:str=None, # Custom ID (by default using the file name without extension)
    inc_img:bool=True # Include image in response
    ) -> dict[str, str | dict[str, str | bool]]: # Batch entry dict
    "Create a batch entry dict for OCR"
    path = Path(path)
    if not cid: cid = path.stem
    return dict(custom_id=cid, body=dict(document=dict(type="document_url", document_url=url), include_image_base64=inc_img))

# %% ../nbs/00_core.ipynb 17
def prep_pdf_batch(
    path:str, # Path to PDF file
    cid:str=None, # Custom ID (by default using the file name without extension)
    inc_img:bool=True, # Include image in response
    key:str=None # API key
    ) -> tuple[dict, Mistral]: # Batch entry dict and the Mistral client used for the upload
    "Upload PDF and create batch entry in one step; returns `(entry, client)`"
    url, c = upload_pdf(path, key)
    # The client is returned too so callers can reuse it to submit/poll the batch
    return create_batch_entry(path, url, cid, inc_img), c

# %% ../nbs/00_core.ipynb 21
def submit_batch(
    entries:list[dict], # List of batch entries
    c:Mistral=None, # Mistral client (created from `MISTRAL_API_KEY` if omitted)
    model:str=ocr_model, # Model name
    endpoint:str=ocr_endpoint # Endpoint name
    ) -> dict: # Job dict
    "Upload `entries` as a JSONL batch file and create a batch job for it; return the job"
    if not entries: raise ValueError("No batch entries to submit")
    if c is None: c = Mistral(api_key=get_api_key())
    # Serialize to JSONL in memory and upload the bytes directly (as `upload_pdf` does): no temporary
    # file has to be reopened by name (unsupported on Windows while it is open) or closed afterwards.
    jsonl = "".join(json.dumps(e) + "\n" for e in entries).encode("utf-8")
    batch_data = c.files.upload(file=dict(file_name="batch.jsonl", content=jsonl), purpose="batch")
    return c.batch.jobs.create(input_files=[batch_data.id], model=model, endpoint=endpoint)

# %% ../nbs/00_core.ipynb 24
def wait_for_job(
    job:dict, # Job object from `submit_batch` (read via `.id` and `.status`)
    c:Mistral=None, # Mistral client (created from `MISTRAL_API_KEY` if omitted)
    poll_interval:int=10 # Poll interval in seconds
    ) -> dict: # Job (with final status)
    "Poll job until it leaves the queued/running/cancelling states and return the refreshed job"
    # NOTE(review): per the mistralai SDK's BatchJobStatus, CANCELLATION_REQUESTED is transitional
    # (it ends as CANCELLED), so polling continues through it instead of returning a non-final status.
    pending = ("QUEUED", "RUNNING", "CANCELLATION_REQUESTED")
    while job.status in pending:
        time.sleep(poll_interval)
        if c is None: c = Mistral(api_key=get_api_key())
        job = c.batch.jobs.get(job_id=job.id)
    return job

# %% ../nbs/00_core.ipynb 26
def download_results(
    job:dict, # Completed job object (its `.output_file` id is downloaded)
    c:Mistral=None # Mistral client (created from `MISTRAL_API_KEY` if omitted)
    ) -> list[dict]: # List of results
    "Download a batch job's output file and parse it as JSONL (one dict per non-blank line)"
    if not job.output_file: raise ValueError(f"Batch job has no output file (status: {job.status})")
    if c is None: c = Mistral(api_key=get_api_key())
    content = c.files.download(file_id=job.output_file).read().decode('utf-8')
    # splitlines() handles both LF and CRLF; skipping whitespace-only lines avoids json.loads('\r') errors
    return [json.loads(line) for line in content.splitlines() if line.strip()]

# %% ../nbs/00_core.ipynb 31
def save_images(
    page:dict, # Page dict
    img_dir:str='img' # Directory to save images
    ) -> None:
    "Decode every base64 image on a page and save it under `img_dir` using the image id as file name"
    img_dir = Path(img_dir)  # accept str (including the default) as well as Path
    for img in page.get('images') or []:
        data, name = img.get('image_base64'), img.get('id')
        if not (data and name): continue
        # Take the text after the last comma, so both "<prefix>,<payload>" (e.g. a data URI) and bare base64 work
        img_bytes = base64.b64decode(data.rsplit(',', 1)[-1])
        img_dir.mkdir(parents=True, exist_ok=True)
        # Keep only the final path component so an id from the response cannot write outside img_dir
        with Image.open(BytesIO(img_bytes)) as im:
            im.save(img_dir / Path(name).name)

# %% ../nbs/00_core.ipynb 32
def save_page(
    page:dict, # Page dict
    dst:str, # Directory to save page
    img_dir:str='img' # Directory to save images
    ) -> None:
    "Save one page's markdown as `page_<index+1>.md` in `dst`, plus its images in `img_dir`"
    dst = Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: OCR text often contains characters outside the platform default encoding (e.g. cp1252)
    (dst / f"page_{page['index']+1}.md").write_text(page['markdown'], encoding='utf-8')
    if page.get('images'):
        img_dir = Path(img_dir)  # the declared default is a str, which has no .mkdir
        img_dir.mkdir(parents=True, exist_ok=True)
        save_images(page, img_dir)

# %% ../nbs/00_core.ipynb 34
def save_pages(
    ocr_resp:dict, # OCR response
    dst:str, # Directory to save pages
    cid:str # Custom ID
    ) -> Path: # Output directory
    "Write every page of an OCR response (markdown plus images) into `dst/cid` and return that folder"
    out_dir = Path(dst).joinpath(cid)
    out_dir.mkdir(parents=True, exist_ok=True)
    images_dir = out_dir.joinpath('img')  # images for all pages share one sub-folder
    for pg in ocr_resp['pages']:
        save_page(pg, out_dir, images_dir)
    return out_dir

# %% ../nbs/00_core.ipynb 40
def _get_paths(path:str) -> list[Path]:
    "Get list of PDFs from file or folder"
    path = Path(path)
    if path.is_file(): return [path]
    if path.is_dir():
        pdfs = path.ls(file_exts='.pdf')
        if not pdfs: raise ValueError(f"No PDFs found in {path}")
        return pdfs
    raise ValueError(f"Path not found: {path}")

# %% ../nbs/00_core.ipynb 41
def _prep_batch(pdfs:list[Path], inc_img:bool=True, key:str=None) -> tuple[list[dict], Mistral]:
    "Upload each PDF and collect its batch entry; also return the client from the last upload (None if `pdfs` is empty)"
    prepared = [prep_pdf_batch(pdf, inc_img=inc_img, key=key) for pdf in pdfs]
    entries = [entry for entry, _ in prepared]
    client = prepared[-1][1] if prepared else None
    return entries, client

# %% ../nbs/00_core.ipynb 42
def _run_batch(entries:list[dict], c:Mistral, poll_interval:int=2) -> list[dict]:
    "Submit batch, wait for completion, and download results; raise RuntimeError unless the job succeeded"
    job = submit_batch(entries, c)
    job = wait_for_job(job, c, poll_interval)
    # RuntimeError subclasses Exception, so existing `except Exception` handlers still catch it
    if job.status != 'SUCCESS': raise RuntimeError(f"Job failed with status: {job.status}")
    return download_results(job, c)

# %% ../nbs/00_core.ipynb 43
def ocr(
    path:str, # Path to PDF file or folder
    dst:str='md', # Directory to save markdown pages
    inc_img:bool=True, # Include image in response
    key:str=None, # API key
    poll_interval:int=2 # Poll interval in seconds
    ) -> list[Path]: # List of output directories
    "OCR one PDF or every PDF in a folder via a single batch job; save each result under `dst/<custom_id>`"
    entries, client = _prep_batch(_get_paths(path), inc_img, key)
    out_dirs = []
    for res in _run_batch(entries, client, poll_interval):
        # Each result line carries the request's custom_id and the OCR response under response.body
        out_dirs.append(save_pages(res['response']['body'], dst, res['custom_id']))
    return L(out_dirs)

# %% ../nbs/00_core.ipynb 48
def read_pgs(
    path:str, # OCR output directory, 
    join:bool=True # Join pages into single string
    ) -> str|list[str]: # Joined string or list of page contents
    "Read specific page or all pages from OCR output directory"
    path = Path(path)
    pgs = sorted(path.glob('page_*.md'), key=lambda p: int(p.stem.split('_')[1]))
    contents = L([p.read_text() for p in pgs])
    return '\n\n'.join(contents) if join else contents
