"""Scale up evaluation report mapping against evaluation frameworks using agentic workflows"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_mappr.ipynb.

# %% auto 0
# Names exported by `from <this module> import *`.
# The file header marks this module as autogenerated from a notebook, so edit the notebook instead.
__all__ = ['GEMINI_API_KEY', 'cfg', 'lm', 'traces_dir', 'find_section_path', 'get_content_tool', 'format_enabler_theme',
           'format_crosscutting_theme', 'format_gcm_theme', 'Overview', 'Exploration', 'Assessment', 'Phase',
           'TraceContext', 'Synthesis', 'setup_logger', 'setup_trace_logging', 'ThemeAnalyzer', 'PipelineResults',
           'PipelineOrchestrator', 'get_stage1_covered_context']

# %% ../nbs/06_mappr.ipynb 5
from pathlib import Path
from functools import reduce
from toolslm.md_hier import *
from rich import print
import json
from fastcore.all import *
from enum import Enum
import logging
import uuid
from datetime import datetime
from typing import List, Callable
import dspy
from asyncio import Semaphore, gather, sleep
import time
from collections import defaultdict
import copy

from .frameworks import (EvalData, 
                                 IOMEvalData, 
                                 FrameworkInfo, 
                                 Framework,
                                 FrameworkCat)

# %% ../nbs/06_mappr.ipynb 6
from dotenv import load_dotenv
import os

# Pull variables from a `.env` file (when one exists) into the process environment,
# then read the Gemini key; `os.getenv` yields `None` when the variable is not set.
load_dotenv()
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')

# %% ../nbs/06_mappr.ipynb 7
# Module-wide configuration, read at import time by the statements and functions below.
cfg = AttrDict({
    'lm': 'gemini/gemini-2.0-flash-exp', # model identifier handed to `dspy.LM`
    'api_key': GEMINI_API_KEY, # may be None when the environment variable is missing
    # NOTE(review): `max_tokens`, `track_usage` and `rpm_limit` are not read anywhere in the
    # code visible in this file - confirm whether they are consumed elsewhere.
    'max_tokens': 8192,
    'track_usage': False,
    'rpm_limit': 15, 
    'call_delay': 6, # pause (in seconds) applied after a non-cached LM call
    'dirs': AttrDict({
        'data': '.evaluator', # joined under the user's home directory to build `traces_dir`
        'trace': 'traces' # sub-directory of `data` that holds the trace files
    }),
    'verbosity': 1, # default console verbosity for `setup_trace_logging`
    'cache': AttrDict({
        'is_active': True, # forwarded as `cache=` to `dspy.LM`
        'delay': 0.1 # threshold in seconds below which we consider the response is cached
    })
})

# %% ../nbs/06_mappr.ipynb 8
# Build the language-model client from `cfg` and register it as the dspy default.
# Both statements run at import time, so importing this module configures dspy globally.
lm = dspy.LM(cfg.lm, api_key=cfg.api_key, cache=cfg.cache.is_active)
dspy.configure(lm=lm)

# %% ../nbs/06_mappr.ipynb 13
def find_section_path(
    hdgs: dict, # The nested dictionary structure
    target_section: str # The section name to find
) -> list: # The nested key path for the given section name, or `None` when it is not found
    """Find the nested key path for a given section name.

    The search is depth-first in dictionary order and stops at the first key equal to
    `target_section`, so when the same name occurs several times the earliest one wins.
    """
    def search_recursive(current_dict, path=()):
        # `path` is an immutable tuple: no mutable default shared between calls.
        for key, value in current_dict.items():
            current_path = path + (key,)
            if key == target_section:
                return list(current_path)
            if isinstance(value, dict):
                result = search_recursive(value, current_path)
                # Explicit `None` check: "not found" is the only reason to keep looking.
                if result is not None:
                    return result
        return None

    return search_recursive(hdgs)

# %% ../nbs/06_mappr.ipynb 17
def get_content_tool(
    hdgs: dict, # The nested dictionary structure
    keys_list: list, # The list of keys to navigate through
    ) -> str: # The content of the section
    "Descend through `hdgs` one exact key at a time and return the `.text` of the node reached."
    node = hdgs
    for key in keys_list:
        node = node[key]
    # With an empty `keys_list` this is the `.text` of `hdgs` itself.
    return node.text

# %% ../nbs/06_mappr.ipynb 21
def format_enabler_theme(
    theme: EvalData # The theme object
    ) -> str: # The formatted theme string
    "Render an SRF enabler (`id`, `title`, `description`) as a small markdown-style block for the LM."
    heading = f'## Enabler {theme.id}: {theme.title}'
    return '\n'.join((heading, '### Description', theme.description))

# %% ../nbs/06_mappr.ipynb 24
def format_crosscutting_theme(
    theme: EvalData # The theme object
    ) -> str: # The formatted theme string
    "Render an SRF cross-cutting priority (`id`, `title`, `description`) as a small markdown-style block for the LM."
    heading = f'## Cross-cutting {theme.id}: {theme.title}'
    return '\n'.join((heading, '### Description', theme.description))

# %% ../nbs/06_mappr.ipynb 27
def format_gcm_theme(
    theme: dict # The GCM theme object from gcm_small
    ) -> str: # The formatted theme string
    "Render a GCM objective as markdown-style text; optional list sections appear only when non-empty."
    lines = [
        f'## GCM Objective {theme["id"]}: {theme["title"]}',
        '### Core Theme',
        theme["core_theme"],
    ]

    # (section heading, dictionary key) pairs, in output order.
    optional_sections = (
        ('### Key Principles', "key_principles"),
        ('### Target Groups', "target_groups"),
        ('### Main Activities', "main_activities"),
    )
    for heading, key in optional_sections:
        entries = theme.get(key)
        if entries:
            lines += [heading, ', '.join(entries)]

    return '\n'.join(lines)

# %% ../nbs/06_mappr.ipynb 32
# dspy signature for the first step: choose which report sections to read first.
# NOTE: per the dspy documentation, a signature's docstring is used as the task instructions and
# each field's `desc` describes that field to the LM, so rewording them changes the prompt.
class Overview(dspy.Signature):
    "Based on framework theme to map and report's TOC determine the sections to explore first."
    # --- inputs ---
    theme: str = dspy.InputField(desc="Theme being analyzed")
    # Optional context; defaults to an empty string.
    prior_coverage_context: str = dspy.InputField(desc="Themes already covered in this report, indicating its scope and analytical focus", default="")
    all_headings: str = dspy.InputField(desc="Complete document structure")
    # --- outputs ---
    priority_sections: List[str] = dspy.OutputField(desc="Ordered list of section keys to explore first")
    strategy: str = dspy.OutputField(desc="Reasoning for this exploration strategy")

# %% ../nbs/06_mappr.ipynb 35
# dspy signature for one exploration step: pick the next section to read, or stop.
# NOTE: per the dspy documentation, the docstring and `desc` strings are prompt material;
# rewording them changes what the LM is asked.
class Exploration(dspy.Signature):
    "Decide next exploration step for theme to be mapped based on current findings and available sections."
    # --- inputs ---
    theme: str = dspy.InputField(desc="Theme being analyzed")
    # Optional context; defaults to an empty string.
    prior_coverage_context: str = dspy.InputField(desc="Themes already covered in this report, indicating its scope and analytical focus", default="")
    current_findings: str = dspy.InputField(desc="Evidence found so far")
    available_sections: str = dspy.InputField(desc="Remaining sections to explore")
    # --- outputs ---
    # The literal value 'DONE' is what `explore_iteratively` checks for to end its loop.
    next_section: str = dspy.OutputField(desc="Next section key to explore, or 'DONE' if sufficient")
    reasoning: str = dspy.OutputField(desc="Why this section or why stopping")

# %% ../nbs/06_mappr.ipynb 37
# dspy signature for the sufficiency check run between exploration steps.
# NOTE: per the dspy documentation, the docstring and `desc` strings are prompt material;
# rewording them changes what the LM is asked.
class Assessment(dspy.Signature):
    "Assess if current evidence is sufficient for theme analysis."
    # --- inputs ---
    theme: str = dspy.InputField(desc="Theme being analyzed")
    # Optional context; defaults to an empty string.
    prior_coverage_context: str = dspy.InputField(desc="Themes already covered in this report, indicating its scope and analytical focus", default="")
    evidence_so_far: str = dspy.InputField(desc="All evidence collected")
    sections_explored: str = dspy.InputField(desc="Sections already checked")
    # --- outputs ---
    # `sufficient` and `confidence_score` together drive the stop decision in `should_stop_exploring`.
    sufficient: bool = dspy.OutputField(desc="Is evidence sufficient to make conclusion?")
    confidence_score: float = dspy.OutputField(desc="Confidence in current findings (0-1)")
    next_priority: str = dspy.OutputField(desc="If continuing, what type of section to prioritize")
    reasoning: str = dspy.OutputField(desc="Why this assessment was made")

# %% ../nbs/06_mappr.ipynb 39
class Phase(Enum):
    "Identifier of a pipeline stage; `str(member)` yields the raw value (e.g. 'stage1')."
    STAGE1 = "stage1"
    STAGE2 = "stage2"
    STAGE3 = "stage3"

    def __str__(self):
        # Print the bare value instead of the default `Phase.STAGE1` form.
        return self.value

# %% ../nbs/06_mappr.ipynb 40
class TraceContext(AttrDict):
    "Identifies one analysis run (report, pipeline phase, framework theme) so trace records can be attributed to it."
    def __init__(self, 
                 report_id:str,  # Report identifier
                 phase:Phase,  # Pipeline phase number
                 framework:FrameworkInfo,  # Framework info (name, category, theme_id)
                 ): 
        # self.run_id = str(uuid.uuid4())[:8]  # Short unique ID
        # fastcore's `store_attr` saves the constructor arguments on `self` under their own names.
        store_attr()

    def __repr__(self):
        # Compact one-line form; `synthesize_findings` passes `str(self.trace_ctx)` to the LM.
        return f"TraceContext(report_id={self.report_id}, phase={self.phase}, framework={self.framework})"

# %% ../nbs/06_mappr.ipynb 42
# dspy signature for the final step: decide on theme coverage and justify the decision.
# NOTE: per the dspy documentation, the docstring and `desc` strings are prompt material;
# rewording them changes what the LM is asked.
class Synthesis(dspy.Signature):
    "Provide detailed rationale and synthesis of theme analysis."
    # --- inputs ---
    trace_ctx: str = dspy.InputField(desc="Trace context")
    theme: str = dspy.InputField(desc="Theme being analyzed")
    # Optional context; defaults to an empty string.
    prior_coverage_context: str = dspy.InputField(desc="Themes already covered in this report, indicating its scope and analytical focus", default="")
    all_evidence: str = dspy.InputField(desc="All collected evidence")
    sections_explored: str = dspy.InputField(desc="List of sections explored")
    # --- outputs ---
    # `theme_covered` is the flag `PipelineResults.__call__` filters on.
    theme_covered: bool = dspy.OutputField(desc="Final decision on theme coverage")
    confidence_explanation: str = dspy.OutputField(desc="Detailed explanation of confidence score")
    evidence_summary: str = dspy.OutputField(desc="Key evidence supporting the conclusion")
    gaps_identified: str = dspy.OutputField(desc="Any gaps or missing aspects")

# %% ../nbs/06_mappr.ipynb 45
# Directory receiving the trace files: <home>/<cfg.dirs.data>/<cfg.dirs.trace>.
# Created at import time (parents included); an already existing directory is not an error.
traces_dir = Path.home() / cfg.dirs.data / cfg.dirs.trace
traces_dir.mkdir(parents=True, exist_ok=True)

# %% ../nbs/06_mappr.ipynb 46
def setup_logger(name, handler, level=logging.INFO, **kwargs):
    """Configure the logger called `name` so that `handler` is its only handler.

    Handlers previously attached to that logger are detached and closed, so repeated calls
    (e.g. one per pipeline stage) do not leave earlier log files open. `level` becomes the
    logger level and every extra keyword argument is set as an attribute on the logger
    (used here to carry a `verbosity` setting). Returns the configured logger.
    """
    logger = logging.getLogger(name)
    # Iterate over a copy: `removeHandler` mutates `logger.handlers`.
    for previous in list(logger.handlers):
        logger.removeHandler(previous)
        # Never close the handler we are about to re-attach.
        if previous is not handler: previous.close()
    logger.addHandler(handler)
    logger.setLevel(level)
    for k,v in kwargs.items(): setattr(logger, k, v)
    return logger

# %% ../nbs/06_mappr.ipynb 47
def setup_trace_logging(report_id, verbosity=cfg.verbosity):
    "Point the 'trace.file' logger at a fresh timestamped file in `traces_dir` and reset the 'trace.console' logger with `verbosity`."
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    trace_path = traces_dir / f'{report_id}_{stamp}.jsonl'
    # mode='w': a file with the same name (same report and same second) is overwritten.
    setup_logger('trace.file', logging.FileHandler(trace_path, mode='w'))
    # The verbosity travels as an attribute on the console logger itself.
    setup_logger('trace.console', logging.StreamHandler(), verbosity=verbosity)

# %% ../nbs/06_mappr.ipynb 48
class ThemeAnalyzer(dspy.Module):
    """
    Analyzes a theme across a document by iteratively exploring sections, collecting evidence, and synthesizing findings. 
    Uses a structured pipeline of overview -> exploration -> assessment -> synthesis.

    Each of the four signatures is wrapped in a `dspy.ChainOfThought` predictor.
    """
    def __init__(self, 
                 overview_sig:dspy.Signature, # Overview signature
                 exploration_sig:dspy.Signature, # Exploration signature
                 assessment_sig:dspy.Signature, # Assessment signature
                 synthesis_sig:dspy.Signature, # Synthesis signature
                 trace_ctx:TraceContext, # Trace context
                 confidence_threshold:float=0.8, # Confidence threshold
                 max_iter:int=10, # Maximum number of iterations in the ReAct loop
                 semaphore=None # Semaphore for rate limiting
                 ):
        # Let `dspy.Module` set up its own state before sub-modules are attached.
        super().__init__()
        self.overview = dspy.ChainOfThought(overview_sig)
        self.explore = dspy.ChainOfThought(exploration_sig)
        self.assess = dspy.ChainOfThought(assessment_sig)
        self.synthesize = dspy.ChainOfThought(synthesis_sig)
        self.max_iter = max_iter
        self.trace_ctx = trace_ctx
        self.confidence_threshold = confidence_threshold
        self.semaphore = semaphore

# %% ../nbs/06_mappr.ipynb 49
@patch
def _log_trace(self:ThemeAnalyzer, event, **extra_data):
    "Write one trace record (context fields plus `extra_data`) to the file logger and a verbosity-dependent line to the console logger."
    ctx = self.trace_ctx
    record = {
        "timestamp": datetime.now().isoformat(),
        "event": event,
        "report_id": ctx.report_id,
        "phase": str(ctx.phase),
        "framework": str(ctx.framework.name),
        "framework_category": str(ctx.framework.category),
        "framework_theme_id": str(ctx.framework.theme_id),
        # Extra fields come last and override a context field of the same name.
        **extra_data,
    }

    # The file always receives the full record.
    # NOTE(review): `indent=2` spreads each record over several lines although the trace file
    # is named `*.jsonl` - confirm that the trace readers expect this layout.
    pretty = json.dumps(record, indent=2)
    logging.getLogger('trace.file').info(pretty)

    console = logging.getLogger('trace.console')
    # No `verbosity` attribute means the console logger was never set up: stay silent.
    if not hasattr(console, 'verbosity'):
        return

    level = console.verbosity
    if level == 1:
        line = f"{record['report_id']} - {record['phase']}"
    elif level == 2:
        line = f"{record['report_id']} - {record['phase']} - {record['framework']} - {record['framework_category']} - {record['framework_theme_id']} - {record['event']}"
    else:
        # Any other value (3 is the intended one) prints the full record.
        line = pretty
    console.info(line)

# %% ../nbs/06_mappr.ipynb 50
@patch    
async def _rate_limited_fn(self:ThemeAnalyzer, mod, **kwargs):
    """Await `mod.acall(**kwargs)` under the rate-limiting semaphore and return its result.

    A call that took longer than `cfg.cache.delay` seconds is treated as a real (non-cached)
    request and is followed by a `cfg.call_delay` second pause; the pause happens while the
    semaphore slot is still held. When no semaphore was given to the analyzer (the
    constructor default is `None`), the call runs without a concurrency limit but keeps
    the same pause.
    """
    async def call_then_pause():
        # Monotonic clock: the elapsed time cannot be skewed by wall-clock adjustments.
        start = time.monotonic()
        result = await mod.acall(**kwargs)
        # Fast answers are assumed to come from the cache and need no throttling.
        if time.monotonic() - start > cfg.cache.delay: await sleep(cfg.call_delay)
        return result

    # `async with None` would raise, so handle the "no semaphore" default explicitly.
    if self.semaphore is None: return await call_then_pause()
    async with self.semaphore:
        return await call_then_pause()

# %% ../nbs/06_mappr.ipynb 51
@patch
async def aforward(
    self:ThemeAnalyzer, 
    theme: str, # The formatted theme to analyze
    headings: dict, # The headings TOC of the document
    get_content_fn:Callable=get_content_tool, # The function to get the content of a section using `hdgs[keys_list].text` for instance
    prior_coverage_context: str = "" # The themes already covered in this report, indicating its scope and analytical focus
    ) -> Synthesis: # Synthesized analysis results including theme coverage, confidence, evidence and gaps
    "Run the full analysis for one theme: overview, then iterative exploration, then synthesis."
    self._log_trace(event="Starting Analysis", theme=theme)

    # 1. Ask which sections look most promising.
    first_sections = await self.get_overview(theme, headings, prior_coverage_context)
    # 2. Read sections one at a time, accumulating evidence.
    gathered = await self.explore_iteratively(
        theme, first_sections, headings, get_content_fn, prior_coverage_context
    )
    # 3. Turn the collected evidence into the final verdict.
    verdict = await self.synthesize_findings(theme, gathered, prior_coverage_context)
    return verdict

# %% ../nbs/06_mappr.ipynb 52
@patch
async def get_overview(
    self:ThemeAnalyzer, 
    theme: str, # The formatted theme to analyze
    headings: dict, # The headings TOC of the document
    prior_coverage_context: str = "" # The themes already covered in this report
    ) -> List[str]: # Ordered section keys to explore first (`priority_sections` of the overview prediction)
    "Based on framework theme to map and report's TOC determine the sections to explore first."
    overview = await self._rate_limited_fn(
        self.overview, 
        theme=theme, 
        all_headings=str(headings), # the TOC is passed to the LM in its `str()` form
        prior_coverage_context=prior_coverage_context)
    
    self._log_trace(
        event="Overview", 
        priority_sections=overview.priority_sections, 
        strategy=overview.strategy)
    # Only the section list is returned; the strategy text is kept in the trace only.
    return overview.priority_sections

# %% ../nbs/06_mappr.ipynb 53
@patch
async def explore_iteratively(
    self:ThemeAnalyzer, 
    theme: str, # The formatted theme to analyze
    priority_sections: list, # The sections to explore first
    headings: dict, # The headings TOC of the document
    get_content_fn: Callable, # The function to get the content of a section using `hdgs[keys_list].text` for instance
    prior_coverage_context: str = "" # The themes already covered in this report
    ) -> dict: # `{"evidence": [...], "sections": [...]}` with the collected texts and the section keys read
    """Iteratively explore the sections to collect evidence.

    Runs at most `self.max_iter` rounds. A round ends the loop when no candidate section
    is left, when the assessment judges the evidence sufficient, or when the exploration
    step answers 'DONE'; otherwise the chosen section is read and added to the evidence.
    """
    evidence_collected = []
    sections_explored = []
    # Work on a private list (the caller's one is left untouched); tolerate a missing/None value.
    available_sections = list(priority_sections or [])
    
    for i in range(self.max_iter):
        if not available_sections:
            self._log_trace(event="Iterative Exploration", iteration_nb=i+1, decision="No more sections to explore, stopping")
            break
            
        # Forward the coverage context so the assessment sees the same context as every other LM call.
        if await self.should_stop_exploring(theme, evidence_collected, sections_explored, prior_coverage_context):   
            break
        
        decision = await self.make_exploration_decision(theme, evidence_collected, available_sections, prior_coverage_context)
        self._log_trace(
            event="Iterative Exploration", 
            iteration_nb=i+1, 
            decision=decision.next_section, 
            reasoning=decision.reasoning)
        
        if decision.next_section == 'DONE':
            self._log_trace(event="Iterative Exploration", iteration_nb=i+1, decision="Done")
            break
        
        evidence_collected, sections_explored = self.process_section(decision, 
                                                                     headings, 
                                                                     get_content_fn, 
                                                                     evidence_collected, 
                                                                     sections_explored, 
                                                                     available_sections)
    
    return {"evidence": evidence_collected, "sections": sections_explored}


# %% ../nbs/06_mappr.ipynb 54
@patch
async def make_exploration_decision(
    self:ThemeAnalyzer, 
    theme: str, # The formatted theme to analyze
    evidence_collected: list, # The evidence collected so far
    available_sections: list, # The sections to explore
    prior_coverage_context: str = ""
    ):    
    "Ask the exploration predictor which section to read next; returns its prediction unchanged."
    # A placeholder sentence stands in for the findings while nothing has been collected.
    if evidence_collected:
        findings = "\n\n".join(evidence_collected)
    else:
        findings = "No evidence collected yet"

    return await self._rate_limited_fn(
        self.explore,
        theme=theme,
        current_findings=findings,
        available_sections=str(available_sections),
        prior_coverage_context=prior_coverage_context,
    )


# %% ../nbs/06_mappr.ipynb 55
@patch
async def should_stop_exploring(
    self:ThemeAnalyzer, 
    theme: str, # The formatted theme to analyze
    evidence_collected: list, # The evidence collected so far
    sections_explored: list, # The sections explored so far
    prior_coverage_context: str = ""
    ):
    "Tell whether exploration can end: the assessment must be `sufficient` with a confidence above `self.confidence_threshold`."
    # Without any evidence there is nothing to assess, so no LM call is made.
    if not evidence_collected:
        return False

    request = dict(
        theme=theme,
        evidence_so_far="\n\n".join(evidence_collected),
        sections_explored=str(sections_explored),
        prior_coverage_context=prior_coverage_context,
    )
    verdict = await self._rate_limited_fn(self.assess, **request)

    self._log_trace(
        "Should stop exploring",
        assessment=verdict.sufficient,
        confidence=verdict.confidence_score,
        reasoning=verdict.reasoning,
    )

    # Strictly greater than the threshold: a score equal to it keeps exploring.
    return verdict.sufficient and verdict.confidence_score > self.confidence_threshold

# %% ../nbs/06_mappr.ipynb 56
@patch
def process_section(self:ThemeAnalyzer, decision, headings, get_content_fn, evidence_collected, sections_explored, available_sections):
    """Read the section named by `decision.next_section` and record it as evidence.

    When the section key resolves to a path in `headings`, its content is appended to
    `evidence_collected` and the key to `sections_explored`. When it does not resolve,
    the miss is written to the trace. In both cases the key is removed from
    `available_sections`, so an unresolvable key cannot be proposed again and again until
    `max_iter` is reached. The three lists are modified in place; the first two are also returned.
    """
    section = decision.next_section
    path = find_section_path(headings, section)
    
    if path:
        content = get_content_fn(headings, path)
        evidence_collected.append(f"# Section: {section}\n## Content\n{content}")
        sections_explored.append(section)
    else:
        # The chosen key does not exist in the document structure: make it visible in the trace.
        self._log_trace(event="Iterative Exploration", decision="Section not found, skipping", section=section)

    if section in available_sections:
        available_sections.remove(section)
    
    return evidence_collected, sections_explored

# %% ../nbs/06_mappr.ipynb 57
@patch
async def synthesize_findings(self:ThemeAnalyzer, theme, evidence, prior_coverage_context):
    "Run the synthesis predictor on the collected evidence, trace its output and tag the prediction with the framework identifiers."
    request = dict(
        trace_ctx=str(self.trace_ctx),
        theme=theme,
        all_evidence="\n\n".join(evidence["evidence"]),
        sections_explored=str(evidence["sections"]),
        prior_coverage_context=prior_coverage_context,
    )
    outcome = await self._rate_limited_fn(self.synthesize, **request)

    self._log_trace(
        "Synthesis",
        theme=theme,
        reasoning=outcome.reasoning,
        theme_covered=outcome.theme_covered,
        confidence_explanation=outcome.confidence_explanation,
        evidence_summary=outcome.evidence_summary,
        gaps_identified=outcome.gaps_identified,
    )

    # The orchestrator uses these three attributes as keys when it files the result.
    framework = self.trace_ctx.framework
    outcome.framework_name = framework.name
    outcome.framework_category = framework.category
    outcome.framework_theme_id = framework.theme_id
    return outcome

# %% ../nbs/06_mappr.ipynb 77
class PipelineResults(dict):
    "Results store keyed by phase, then (auto-created) framework name, category and theme id."
    def __init__(self):
        super().__init__()
        # One independent three-level store per pipeline phase; the two outer levels are created on first access.
        for phase in Phase:
            self[phase] = defaultdict(lambda: defaultdict(dict))

# %% ../nbs/06_mappr.ipynb 78
@patch
def __call__(self:PipelineResults, stage=Phase.STAGE1, filter_type="all"):
    "List the results stored for `stage`, keeping all of them or only the covered / uncovered ones."
    def keep(theme):
        # `theme_covered` is only read for the two filters that need it.
        if filter_type == "all": return True
        if filter_type == "covered": return theme.theme_covered
        if filter_type == "uncovered": return not theme.theme_covered
        return False  # unknown filter: nothing is selected

    selected = []
    for by_category in self[stage].values():
        for by_theme_id in by_category.values():
            selected.extend(theme for theme in by_theme_id.values() if keep(theme))
    return selected

# %% ../nbs/06_mappr.ipynb 79
class PipelineOrchestrator:
    "Orchestrator for the IOM evaluation report mapping pipeline: holds the report inputs, sets up trace logging and collects per-stage results."
    def __init__(self, 
                 report_id:str, # Report identifier
                 headings:dict, # Report headings
                 get_content_fn:Callable, # Function to get the content of a section
                 eval_data:EvalData, # Evaluation data
                 verbosity:int=2, # Verbosity level
                 ):
        # fastcore's `store_attr` saves the constructor arguments on `self` under their own names.
        store_attr()
        # Starts a trace file right away; each `run_stage*` method calls this again.
        setup_trace_logging(report_id, verbosity)
        # Empty store, filled by the `run_stage*` methods.
        self.results = PipelineResults()

# %% ../nbs/06_mappr.ipynb 80
@patch
async def run_stage1(self:PipelineOrchestrator, semaphore):
    "Stage 1: analyze every SRF enabler and cross-cutting priority concurrently and file the results under `Phase.STAGE1`."
    setup_trace_logging(self.report_id, self.verbosity)
    jobs = []

    # (items, framework category, formatter) for each SRF collection handled by this stage.
    sources = (
        (self.eval_data.srf_enablers, FrameworkCat.ENABLERS, format_enabler_theme),
        (self.eval_data.srf_crosscutting_priorities, FrameworkCat.CROSSCUT, format_crosscutting_theme),
    )

    for items, category, render in sources:
        for item in items:
            ctx = TraceContext(self.report_id, Phase.STAGE1, FrameworkInfo(Framework.SRF, category, item.id))
            theme_text = render(item)
            # One analyzer per theme; all of them share the caller's semaphore.
            worker = ThemeAnalyzer(Overview, Exploration, Assessment, Synthesis, ctx, semaphore=semaphore)
            jobs.append((worker, theme_text))

    pending = [worker.acall(theme_text, self.headings, self.get_content_fn) for worker, theme_text in jobs]
    outcomes = await gather(*pending)

    # File each result under phase -> framework name -> category -> theme id.
    for outcome in outcomes:
        self.results[Phase.STAGE1][outcome.framework_name][outcome.framework_category][outcome.framework_theme_id] = outcome

# %% ../nbs/06_mappr.ipynb 84
def get_stage1_covered_context(results: PipelineResults, eval_data: EvalData) -> str:
    """Get and format covered themes in Stage 1.

    Returns a short markdown block listing the Stage 1 themes flagged as covered, or an
    empty string when there is none. A covered result whose category is neither enablers
    nor cross-cutting, or whose theme id is not present in `eval_data`, is left out
    instead of raising or reusing data from a previous result.
    """
    covered_themes = results(Phase.STAGE1, filter_type="covered")
    if not covered_themes: return ""
    
    context_parts = []
    for theme in covered_themes:
        # Compare string forms so the match works whether the category was stored as text or as an enum member.
        category = str(theme.framework_category)
        if category == str(FrameworkCat.ENABLERS):
            candidates = eval_data.srf_enablers
        elif category == str(FrameworkCat.CROSSCUT):
            candidates = eval_data.srf_crosscutting_priorities
        else:
            continue  # not an SRF category handled here

        # `None` default: a missing id must not escape as a bare StopIteration.
        theme_data = next((t for t in candidates if t.id == theme.framework_theme_id), None)
        if theme_data is None: continue
        
        context_parts.append(f"- **{theme.framework_category} {theme_data.id}**: {theme_data.title}")
    
    if not context_parts: return ""
    return "### Report Preliminary Context\nThis evaluation report covers the following Strategic Results Framework themes:\n" + "\n".join(context_parts)


# %% ../nbs/06_mappr.ipynb 87
@patch
async def run_stage2(self:PipelineOrchestrator, semaphore):
    "Stage 2: analyze every GCM objective concurrently, giving each analysis the Stage 1 coverage summary as context."
    setup_trace_logging(self.report_id, self.verbosity)
    stage1_summary = get_stage1_covered_context(self.results, self.eval_data)
    jobs = []

    # NOTE(review): `gcm_small` is neither defined nor explicitly imported in this file -
    # confirm where it comes from (a star import or a notebook-only global?).
    for objective in gcm_small:
        ctx = TraceContext(self.report_id, Phase.STAGE2, FrameworkInfo(Framework.GCM, FrameworkCat.OBJS, objective["id"]))
        theme_text = format_gcm_theme(objective)
        # One analyzer per objective; all of them share the caller's semaphore.
        worker = ThemeAnalyzer(Overview, Exploration, Assessment, Synthesis, ctx, semaphore=semaphore)
        jobs.append((worker, theme_text, stage1_summary))

    pending = [worker.acall(theme_text, self.headings, self.get_content_fn, context)
               for worker, theme_text, context in jobs]
    outcomes = await gather(*pending)

    # File each result under phase -> framework name -> category -> theme id.
    for outcome in outcomes:
        self.results[Phase.STAGE2][outcome.framework_name][outcome.framework_category][outcome.framework_theme_id] = outcome
