"""Postprocess markdown files by fixing heading hierarchy and describint images"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_refine.ipynb.

# %% auto 0
# Public API of this module; this list appears to be regenerated by nbdev export (see the "auto 0" cell marker).
__all__ = ['prompt_fix_hdgs', 'get_hdgs', 'fmt_hdgs_idx', 'HeadingCorrections', 'fix_hdg_hierarchy', 'mk_fixes_lut',
           'apply_hdg_fixes', 'fix_md_hdgs']

# %% ../nbs/01_refine.ipynb 3
from fastcore.all import *
from .core import read_pgs
from re import sub, findall, MULTILINE
from pydantic import BaseModel
from lisette.core import completion
import os
import json

# %% ../nbs/01_refine.ipynb 7
def get_hdgs(
    md:str # Markdown file string
    ):
    "Return the markdown (ATX) headings found outside fenced code blocks"
    # Drop fenced code blocks first so `#` comment lines inside code snippets aren't mistaken for
    # headings. Both backtick and tilde fences are recognised; the closing fence must repeat the
    # opening run (\1), and if it is longer than three characters the regex backtracks to a shorter run.
    md = sub(r'(`{3,}|~{3,})[\s\S]*?\1', '', md)
    # A heading is a line starting with 1-6 '#' followed by a space; MULTILINE anchors ^/$ per line.
    return L(findall(r'^#{1,6} .+$', md, MULTILINE))



# %% ../nbs/01_refine.ipynb 10
def fmt_hdgs_idx(
    hdgs: list[str] # List of markdown headings
    ) -> str: # Formatted string with index
    "Number each heading by its 0-based position, one per line"
    # e.g. ['# Title', '## Abstract'] -> "0. # Title\n1. ## Abstract"; the number is the index
    # the LLM prompt asks the model to echo back in its corrections.
    numbered = [f"{pos}. {heading}" for pos, heading in enumerate(hdgs)]
    return '\n'.join(numbered)


# %% ../nbs/01_refine.ipynb 13
class HeadingCorrections(BaseModel):
    # Structured-output schema passed as `response_format` to `completion` in fix_hdg_hierarchy.
    # NOTE(review): a class docstring would likely be copied into the generated JSON schema
    # (and so change what the model is sent) — confirm before adding one.
    corrections: dict[int, str]  # index (position in fmt_hdgs_idx output) → corrected heading

# %% ../nbs/01_refine.ipynb 15
prompt_fix_hdgs = """Fix markdown heading hierarchy errors while preserving the document's intended structure.

INPUT FORMAT: Each heading is prefixed with its index number (e.g., "0. # Title")

RULES - Only fix these errors:
1. **Level jumps**: Headings can only increase by one # at a time
   - Wrong: 0. # Title → 1. #### Abstract
   - Fixed: 0. # Title → 1. ## Abstract

2. **Numbering inconsistency**: Subsection numbers must be one level deeper
   - Wrong: 4. ## 3. Section → 5. ## 3.1 Subsection
   - Fixed: 4. ## 3. Section → 5. ### 3.1 Subsection

3. **Preserve working structure**: If sections are consistently marked, keep it

4. **Decreasing levels is OK**: Going from ### to ## is valid for new sections

OUTPUT: Return a Python dictionary mapping index to corrected heading (without the index prefix).
Only include entries that need changes. Example: {{1: '## Abstract', 15: '### PASCAL VOC'}}

Headings to analyze:
{headings_list}
"""

# %% ../nbs/01_refine.ipynb 18
def fix_hdg_hierarchy(
    hdgs: list[str], # List of markdown headings
    prompt: str=prompt_fix_hdgs, # Prompt template; must contain a `{headings_list}` placeholder
    model: str='claude-sonnet-4-5', # Model to use
    api_key: str=os.getenv('ANTHROPIC_API_KEY') # API key (default read from the environment once, at import time)
    ) -> dict[int, str]: # Dictionary of index → corrected heading
    "Ask the LLM for heading-level corrections, keyed by the heading's index in `hdgs`"
    if not hdgs: return {}  # nothing to fix; skip the API call
    # Use the caller-supplied template (previously the global default was always used).
    content = prompt.format(headings_list=fmt_hdgs_idx(hdgs))
    r = completion(
        model=model,
        messages=[{"role": "user", "content": content}],
        response_format=HeadingCorrections,
        api_key=api_key
        )
    # Structured output comes back as JSON text. JSON object keys are always strings, so convert
    # them back to the int indices promised by the return annotation.
    fixes = json.loads(r.choices[0].message.content)['corrections']
    return {int(k): v for k, v in fixes.items()}

# %% ../nbs/01_refine.ipynb 21
def mk_fixes_lut(
    hdgs: list[str], # List of markdown headings
    model: str='claude-sonnet-4-5', # Model to use
    api_key: str=os.getenv('ANTHROPIC_API_KEY') # API key (default read from the environment once, at import time)
    ) -> dict[str, str]: # Dictionary of old → new heading
    "Make a lookup table mapping original heading text to its corrected text"
    # Pass by keyword: fix_hdg_hierarchy's 2nd positional parameter is `prompt`, so positional
    # arguments would send the model name as the prompt and the API key as the model.
    fixes = fix_hdg_hierarchy(hdgs, model=model, api_key=api_key)
    # int() accepts both int and JSON-string keys. Indices outside the list (e.g. hallucinated by
    # the model) are dropped rather than raising IndexError or wrapping around via negative indexing.
    # NOTE: identical heading texts share one key, so the later entry in `fixes` wins.
    by_idx = {int(k): v for k, v in fixes.items()}
    return {hdgs[i]: v for i, v in by_idx.items() if 0 <= i < len(hdgs)}

# %% ../nbs/01_refine.ipynb 24
def apply_hdg_fixes(
    p:str, # Page to fix
    lut_fixes: dict[str, str], # Lookup table of fixes
    pg: int=None, # Optionally specify the page number to append to each heading
    ) -> str: # Page with fixes applied
    "Apply the fixes to the page's headings, optionally tagging each with its page number"
    sfx = f' .... page {pg}' if pg else ''
    def _fix(m):
        # Fenced code blocks (``` or ~~~) are returned untouched, so `#` comments in them are never rewritten
        if m.group('fence') is not None: return m.group('fence')
        h = m.group('hdg')
        return lut_fixes.get(h, h) + sfx
    # Rewrite whole heading lines in a single pass. Matching full lines (instead of str.replace per
    # heading) prevents rewriting substrings such as "# Intro" inside "# Introduction". It also
    # gives each heading line exactly one page suffix, even when the same heading text repeats.
    return sub(r'(?P<fence>(?P<tick>`{3,}|~{3,})[\s\S]*?(?P=tick))|^(?P<hdg>#{1,6} .+)$',
               _fix, p, flags=MULTILINE)

# %% ../nbs/01_refine.ipynb 27
def fix_md_hdgs(
    src:str, # Source directory with markdown pages
    model:str='claude-sonnet-4-5', # Model
    dst:str=None, # Destination directory (None=overwrite)
    pg_nums:bool=True # Add page numbers
):
    "Fix heading hierarchy in markdown document"
    in_dir = Path(src)
    out_dir = Path(dst) if dst else in_dir
    if out_dir != in_dir: out_dir.mkdir(parents=True, exist_ok=True)
    # Headings are gathered from read_pgs' default (presumably whole-document) output so a single
    # LLM call sees the full outline; the resulting lookup table is then applied page by page.
    lut = mk_fixes_lut(get_hdgs(read_pgs(in_dir)), model)
    pages = read_pgs(in_dir, join=False)
    for n, page in enumerate(pages, start=1):
        fixed = apply_hdg_fixes(page, lut, pg=n if pg_nums else None)
        # Output files are named page_1.md, page_2.md, ... in read order.
        (out_dir/f'page_{n}.md').write_text(fixed)
