"""Fix, clean markdown headings and enrich it with figures description, ..."""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_enrichr.ipynb.

# %% auto 0
# Names exported by a star-import of this module.
__all__ = ['GEMINI_API_KEY', 'cfg', 'src_dir', 'lm', 'setup_enhanced_dir', 'get_hdgs', 'get_hdgs_with_pages', 'format_hdgs',
           'HeadingResult', 'FixHeadingHierarchy', 'fix_md', 'group_corrections_by_page', 'apply_corrections_to_page',
           'apply_all_corrections', 'fix_doc_hdgs', 'has_images', 'MarkdownPage', 'ImgRef', 'ImageRelevance',
           'describe_img', 'copy_page_to_enriched', 'process_single_page', 'enrich_images']

# %% ../nbs/04_enrichr.ipynb 3
from pathlib import Path
import os
import re
import json
import shutil
import time
from functools import partial
from tqdm import tqdm
from dotenv import load_dotenv
from fastcore.all import *
import dspy
from pydantic import BaseModel
from typing import List
from litellm import completion
import base64
from rich import print

# %% ../nbs/04_enrichr.ipynb 4
# Load variables from a `.env` file (if one is found) into the process environment,
# then read the Gemini API key from it. The value is None when the variable is unset.
load_dotenv()
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')

# %% ../nbs/04_enrichr.ipynb 5
# Module-wide settings, read through attribute access (e.g. `cfg.lm`).
cfg = AttrDict({
    'enhanced_dir': 'enhanced',  # subdirectory that receives the heading-corrected copies of the pages
    'enriched_dir': 'enriched',  # directory name that receives pages with images replaced by descriptions
    'lm': 'gemini/gemini-2.0-flash-exp',  # model identifier handed to `dspy.LM` and to litellm's `completion`
    'api_key': GEMINI_API_KEY,  # key read from the environment above; may be None
    'max_tokens': 8192,  # token limit passed to `dspy.LM` in `fix_md`
    'track_usage': False,  # forwarded to `dspy.settings.configure(track_usage=...)`
    'img_dir': 'img'  # not referenced elsewhere in this file; presumably the image folder name — TODO confirm
})

# %% ../nbs/04_enrichr.ipynb 6
# Hard-coded relative path to one document folder. It is not used by any function in this
# file; presumably it serves the source notebook's example cells — TODO confirm.
src_dir = Path("../_data/md_library/49d2fba781b6a7c0d94577479636ee6f")

# %% ../nbs/04_enrichr.ipynb 9
def setup_enhanced_dir(
    src_dir, # Source directory path
    enhanced_dir_name=cfg.enhanced_dir # Name of enhanced subdirectory
    ):
    "Make the enhanced subdirectory under `src_dir`, copy every `.md` file of `src_dir` into it and return its path"
    root = Path(src_dir)
    target = root / enhanced_dir_name
    # Only the leaf directory is created; an already existing one is reused
    target.mkdir(exist_ok=True)
    md_files = root.ls(file_exts=".md")
    for md_file in md_files:
        shutil.copy(md_file, target)
    return target

# %% ../nbs/04_enrichr.ipynb 12
def get_hdgs(md_txt):
    "Return, in order of appearance, every line of `md_txt` that starts with one or more `#`"
    heading_re = re.compile(r'^#+.*$', re.MULTILINE)
    return [found.group(0) for found in heading_re.finditer(md_txt)]

# %% ../nbs/04_enrichr.ipynb 13
def get_hdgs_with_pages(
    pages: list[Path] # List of pages
    ):
    "Collect the headings of every page, each paired with the 1-based index of its page in `pages`"
    return [
        {'heading': hdg, 'page': page_nb}
        for page_nb, page in enumerate(pages, start=1)
        for hdg in get_hdgs(page.read_text())
    ]

# %% ../nbs/04_enrichr.ipynb 16
def format_hdgs(
    hdgs: list[dict] # List of headings with page numbers
    ):
    "Render one line per heading as `<heading> (Page <n>, Position <k>)`, where k counts the headings seen so far on page n"
    seen_on_page = {}  # page number -> how many headings of that page were emitted
    out_lines = []
    for hdg in hdgs:
        page_nb = hdg['page']
        position = seen_on_page.setdefault(page_nb, 0) + 1
        seen_on_page[page_nb] = position
        out_lines.append(f"{hdg['heading']} (Page {page_nb}, Position {position})")
    return "\n".join(out_lines)

# %% ../nbs/04_enrichr.ipynb 18
# Import-time setup: build the dspy language-model client from `cfg`, register it as the
# global dspy LM, then set dspy's usage tracking from `cfg.track_usage`.
# Note that `fix_md` registers a new LM (with `max_tokens`) each time it is called.
lm = dspy.LM(cfg.lm, api_key=cfg.api_key)
dspy.configure(lm=lm)
dspy.settings.configure(track_usage=cfg.track_usage)

# %% ../nbs/04_enrichr.ipynb 19
# One heading as returned by the LLM; element type of `FixHeadingHierarchy.results`.
# NOTE(review): documented with `#` comments on purpose — a class docstring on a pydantic
# model may end up in the schema shown to the LLM; confirm before adding one.
class HeadingResult(BaseModel):
    old: str  # heading text as found in the page; compared (stripped) against page lines when corrections are applied
    page: int  # page number; selects the `page_<page>.md` file that gets rewritten
    position: int  # presumably the heading's position on its page as shown in the prompt; not read elsewhere in this file
    new: str  # heading text written in place of `old`
    changed: bool  # True if correction was made

# %% ../nbs/04_enrichr.ipynb 20
# dspy signature used by `fix_md`: takes the text produced by `format_hdgs` and returns one
# `HeadingResult` per heading.
# NOTE(review): in dspy the signature docstring is used as the task instructions for the
# model, so the docstring below is runtime text — edit it only to change the prompt.
class FixHeadingHierarchy(dspy.Signature):
    """Fix markdown heading hierarchy by analyzing the document's numbering patterns:
    - Detect numbering scheme (1.2.3, I.A.1, A.1.a, etc.)
    - Apply hierarchy levels based on nesting depth: # for top level, ## for second level, ### for third level
    - When a section number is lower than a previously seen number at the same level (e.g., seeing '2.' after '3.1'), it's likely a subsection or list item, not a main section
    - Unnumbered headings: keep as-is if at document boundaries, treat as subsections if within numbered sections
    - Return ALL headings with their corrected form
    """
    
    # Input: newline-separated headings, each suffixed with "(Page n, Position k)"
    headings_with_pages: str = dspy.InputField(desc="List of headings with page numbers")
    # Output: consumed by `apply_all_corrections` through `result.results`
    results: List[HeadingResult] = dspy.OutputField(desc="All headings with corrections and change status")

# %% ../nbs/04_enrichr.ipynb 21
def fix_md(
    hdgs: list[dict], # List of headings with page numbers
    track_usage: bool=cfg.track_usage,
    ):
    "Send the formatted headings to the LLM and return the dspy prediction; its `results` attribute holds the corrected headings"
    # Register a fresh LM globally, this time with the configured token limit
    dspy.configure(lm=dspy.LM(cfg.lm, api_key=cfg.api_key, max_tokens=cfg.max_tokens))
    dspy.settings.configure(track_usage=track_usage)

    prompt_input = format_hdgs(hdgs)
    predictor = dspy.ChainOfThought(FixHeadingHierarchy)
    return predictor(headings_with_pages=prompt_input)

# %% ../nbs/04_enrichr.ipynb 23
def group_corrections_by_page(
    results: list[HeadingResult], # List of headings with corrections and change status
    ):
    "Bucket corrections by their `page` attribute; pages keep first-seen order and each bucket keeps input order"
    by_page = {}
    for item in results:
        by_page.setdefault(item.page, []).append(item)
    return by_page

# %% ../nbs/04_enrichr.ipynb 25
def apply_corrections_to_page(
    page_nb, # Page number
    corrections, # List of corrections
    enhanced_path, # Path to enhanced directory
    ):
    """Apply heading corrections to `page_<page_nb>.md` in the enhanced directory.

    Lines are scanned top to bottom. A line whose stripped text equals the stripped
    `old` of a still-unused correction is replaced by `"<new> .... page <page_nb>"`,
    and that correction is not used again. The caller's `corrections` list is not
    modified. If the page file does not exist, a message is printed and nothing is
    written. A trailing newline in the file is preserved.
    """
    page_file = Path(enhanced_path) / f"page_{page_nb}.md"
    if not page_file.exists():
        # Page numbers come from LLM output and may not correspond to a real file;
        # skipping keeps the remaining pages from being abandoned half-way.
        print(f"Skipping corrections for missing page file: {page_file}")
        return

    text = page_file.read_text()
    lines = text.splitlines()
    pending = list(corrections)  # private copy: entries are removed once applied

    for i, line in enumerate(lines):
        if not pending: break  # every correction has been used
        stripped = line.strip()
        for idx, correction in enumerate(pending):
            if stripped == correction.old.strip():
                lines[i] = f"{correction.new} .... page {page_nb}"
                del pending[idx]
                break

    new_text = '\n'.join(lines)
    # `splitlines` discards the final line break; put it back if the file had one
    if text.endswith('\n'): new_text += '\n'
    page_file.write_text(new_text)

# %% ../nbs/04_enrichr.ipynb 27
def apply_all_corrections(
    results, # List of headings with corrections and change status
    enhanced_path, # Path to enhanced directory
    ):
    "Group the corrections per page, then rewrite each affected page file of the enhanced directory"
    by_page = group_corrections_by_page(results)
    for page_nb in by_page:
        apply_corrections_to_page(page_nb, by_page[page_nb], enhanced_path)

# %% ../nbs/04_enrichr.ipynb 29
def fix_doc_hdgs(
    src_dir, # Path to the folder containing the document
    force=False, # Whether to overwrite the existing enhanced directory
    ):
    "Copy the document's pages into a fresh enhanced directory and apply the LLM heading corrections there"
    enhanced_path = Path(src_dir) / cfg.enhanced_dir

    if enhanced_path.exists():
        if not force:
            print(f"Enhanced directory '{cfg.enhanced_dir}' already exists. Use force=True to overwrite.")
            return
        # Forced run: discard the previous output and start over
        shutil.rmtree(enhanced_path)

    enhanced_path = setup_enhanced_dir(src_dir)

    def page_index(p):
        # File stems are expected to look like "page_<n>"; sort on the integer n
        return int(p.stem.split('_')[1])

    pages = enhanced_path.ls(file_exts=".md").sorted(key=page_index)
    corrections = fix_md(get_hdgs_with_pages(pages)).results
    apply_all_corrections(corrections, enhanced_path)

# %% ../nbs/04_enrichr.ipynb 33
def has_images(page_path):
    "Tell whether the file at `page_path` contains at least one markdown image reference of the form `![alt](target)`"
    text = Path(page_path).read_text()
    first_hit = re.search(r'!\[[^\]]*\]\([^)]+\)', text)
    return first_hit is not None

# %% ../nbs/04_enrichr.ipynb 36
class MarkdownPage:
    "Wrapper around the filesystem path of a single markdown page"
    def __init__(self, path):
        # Accept a string or a Path; always store a Path
        self.path = Path(path)

# %% ../nbs/04_enrichr.ipynb 37
class ImgRef(AttrDict):
    "An image reference found in a markdown page: its `filename`, surrounding `context`, plus any fields added later"
    def __repr__(self):
        # One-line preview of the context: newlines become spaces, cut at 50 characters
        preview = self.context.replace('\n', ' ')[:50] + "..."
        parts = [f"filename='{self.filename}'", f"context='{preview}'"]
        # Classification fields are only shown once they have been set
        for key in ('is_relevant', 'reason'):
            if hasattr(self, key): parts.append(f"{key}={getattr(self, key)}")
        return f"ImgRef({', '.join(parts)})"


# %% ../nbs/04_enrichr.ipynb 38
@patch
def find_img_refs(
    self:MarkdownPage, # Markdown page of interest
    context_lines: int = 3, # Number of lines of context to include around the image
    ):
    """Find every image reference in the markdown page, with the text around it.

    Returns a list of `ImgRef`, one per `![alt](target)` occurrence, in document order.
    `filename` is the text between the parentheses; `context` is the image's line plus
    up to `context_lines` lines before and after it. Several images on the same line
    each get their own `ImgRef` with the same context.
    """
    img_re = re.compile(r'!\[[^\]]*\]\(([^)]+)\)')
    lines = self.path.read_text().splitlines()
    results = []

    for i, line in enumerate(lines):
        # finditer (not search) so that a line holding several images yields all of them
        matches = list(img_re.finditer(line))
        if not matches: continue

        # Window of lines around the image, clipped to the page
        start = max(0, i - context_lines)
        end = min(len(lines), i + context_lines + 1)
        context = '\n'.join(lines[start:end])

        results.extend(
            ImgRef({"filename": m.group(1), "context": context})
            for m in matches
        )

    return results

# %% ../nbs/04_enrichr.ipynb 41
# dspy signature used by `MarkdownPage.classify_imgs` to decide whether an image is worth
# describing. The decision is based on the file name and surrounding text only; the image
# itself is not an input.
# NOTE(review): in dspy the signature docstring is used as the task instructions for the
# model, so the docstring below is runtime text — edit it only to change the prompt.
class ImageRelevance(dspy.Signature):
    """Determine if an image contains substantive content for document understanding.
    
    RELEVANT: Charts, graphs, diagrams, figures, tables, screenshots, flowcharts
    IRRELEVANT: Logos, cover images, decorative elements, headers, footers
    """
    img_filename: str = dspy.InputField()
    surrounding_context: str = dspy.InputField(desc="Text context around the image")
    is_relevant: bool = dspy.OutputField(desc="True only for substantive content like data visualizations")
    reason: str = dspy.OutputField(desc="Brief explanation of decision")


# %% ../nbs/04_enrichr.ipynb 42
@patch
def classify_imgs(
    self:MarkdownPage, # Markdown page of interest
    img_refs: list[ImgRef], # List of image references
    ):
    """Classify each image reference as relevant or not, using the LLM.

    Sets `is_relevant` and `reason` on every element of `img_refs` (in place) and
    returns the same list.
    """
    classifier = dspy.ChainOfThought(ImageRelevance)
    for img_ref in img_refs:
        # Only the inputs declared by `ImageRelevance` are passed
        result = classifier(
            img_filename=img_ref.filename,
            surrounding_context=img_ref.context,
        )
        img_ref.is_relevant = result.is_relevant
        img_ref.reason = result.reason
    return img_refs

# %% ../nbs/04_enrichr.ipynb 45
def describe_img(
    img_path: Path, # Path to the image
    context: str, # Context of the image
    api_key: str = GEMINI_API_KEY, # API key for the Gemini model
    model: str = cfg.lm, # Model to use
    ):
    "Ask the LLM for a one-paragraph description of the image at `img_path`, passing `context` along in the prompt"
    with open(img_path, "rb") as fh:
        payload = base64.b64encode(fh.read()).decode('utf-8')

    # The lower-cased extension without its dot becomes the MIME subtype ("jpg" is spelled "jpeg")
    ext = img_path.suffix.lower()[1:]
    subtype = 'jpeg' if ext == 'jpg' else ext
    data_url = f"data:image/{subtype};base64,{payload}"

    prompt = f"""Provide a concise paragraph description of this image for evaluation report analysis. Include: type of content, main topic, key data/statistics, trends, and takeaways. Write as flowing text, not numbered points. Context: {context}"""

    # Single user turn carrying the instructions and the inlined image
    user_msg = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": data_url}},
        ],
    }
    response = completion(model=model, messages=[user_msg], api_key=api_key)
    return response.choices[0].message.content

# %% ../nbs/04_enrichr.ipynb 46
@patch
def describe_imgs(
    self:MarkdownPage, # Markdown page of interest
    img_refs: list[ImgRef], # List of image references
    img_dir: str # Image directory
    ):
    "Attach an LLM-written `description` to every relevant image reference (in place) and return the list"
    for ref in img_refs:
        # Irrelevant images are left without a `description`
        if not ref.is_relevant: continue
        ref.description = describe_img(Path(img_dir) / ref.filename, ref.context, GEMINI_API_KEY)
    return img_refs

# %% ../nbs/04_enrichr.ipynb 49
@patch
def replace_imgs_with_desc(
    self:MarkdownPage, # Markdown page of interest
    img_refs, # List of image references
    enriched_dir: str = cfg.enriched_dir, # Enriched directory
    ):
    """Write a copy of the page in which described images are replaced by their description.

    For every reference that is relevant and has a `description`, each `![...](filename)`
    occurrence in the page is replaced by the description text. The result is written to
    `<page dir>/../<enriched_dir>/<page name>` and that path is returned. The source page
    is not modified.
    """
    enriched_path = self.path.parent.parent / enriched_dir
    enriched_path.mkdir(exist_ok=True)

    content = self.path.read_text()
    for img_ref in img_refs:
        if img_ref.is_relevant and hasattr(img_ref, 'description'):
            pattern = r'!\[[^\]]*\]\(' + re.escape(img_ref.filename) + r'\)'
            # A callable replacement inserts the text literally. Passing the description
            # as a string would make `re.sub` interpret backslashes and group references
            # in LLM output (e.g. "\d", "\1"), raising or corrupting the page.
            content = re.sub(pattern, lambda _m, desc=img_ref.description: desc, content)

    enriched_file = enriched_path / self.path.name
    enriched_file.write_text(content)
    return enriched_file

# %% ../nbs/04_enrichr.ipynb 50
def copy_page_to_enriched(
    page, # Page to copy
    enriched_dir: str = cfg.enriched_dir, # Enriched directory
    ):
    "Copy `page` unchanged into `<page dir>/../<enriched_dir>` (created if needed) and return what `shutil.copy` returns"
    dest_dir = page.parent.parent / enriched_dir
    dest_dir.mkdir(exist_ok=True)
    copied = shutil.copy(page, dest_dir)
    return copied

# %% ../nbs/04_enrichr.ipynb 51
def process_single_page(
    page, # Page to process
    img_dir, # Image directory
    enriched_dir: str = cfg.enriched_dir, # Enriched directory
    ):
    """Run the image pipeline on one page and write the result to the enriched directory.

    Pages without image references are copied as-is. Otherwise the references are
    classified, the relevant ones described, and the page is written with those
    images replaced by their descriptions. Returns the value of whichever writer ran.
    """
    md_page = MarkdownPage(page)
    # Pipeline: find → classify → describe → replace
    img_refs = md_page.find_img_refs()

    if not img_refs: return copy_page_to_enriched(page, enriched_dir)

    classified_refs = md_page.classify_imgs(img_refs)
    described_refs = md_page.describe_imgs(classified_refs, img_dir)
    # Forward `enriched_dir` so both branches write to the directory the caller asked for
    return md_page.replace_imgs_with_desc(described_refs, enriched_dir)

# %% ../nbs/04_enrichr.ipynb 52
def enrich_images(
    pages_dir, # Pages directory
    img_dir, # Image directory
    n_workers=2, # Number of workers
    ):
    "Produce the enriched version of every `.md` page in `pages_dir`: copy image-free pages, process the others in a thread pool"
    pages = Path(pages_dir).ls(file_exts=".md")

    to_process = []
    for pg in pages:
        if not has_images(pg):
            # Nothing to describe on this page: copy it unchanged
            copy_page_to_enriched(pg)
            continue
        to_process.append(pg)

    if len(to_process) > 0:
        worker = partial(process_single_page, img_dir=img_dir)
        parallel(worker, to_process, n_workers=n_workers, threadpool=True, progress=True)

    print(f"✓ Processed {len(pages)} pages ({len(to_process)} with images)")
