"""
Knowledge analysis module for LLM Ripper.

This module implements Part II of the framework: Analysis and synthesis 
of extracted knowledge components.
"""

import torch
import numpy as np

try:
    from sklearn.decomposition import PCA  # type: ignore
    from sklearn.cluster import KMeans  # type: ignore
    from sklearn.metrics import silhouette_score  # type: ignore
except Exception:
    PCA = None  # type: ignore
    KMeans = None  # type: ignore
    silhouette_score = None  # type: ignore
try:
    import umap  # type: ignore
except Exception:  # optional for visualization; not required for tests
    umap = None  # type: ignore
import json
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional
from dataclasses import dataclass

# SciPy is optional. When it cannot be imported, minimal stand-ins are defined
# below so that the names ``cosine`` and ``spearmanr`` still exist at module
# level.
try:
    from scipy.spatial.distance import cosine  # type: ignore
    from scipy.stats import spearmanr  # type: ignore
except Exception:

    def cosine(a, b):  # type: ignore
        """Return the cosine distance (1 - cosine similarity) of ``a`` and ``b``.

        NumPy-only fallback used when SciPy is unavailable. If either input
        has zero norm the distance is reported as 1.0 instead of dividing by
        zero.
        """
        a = np.asarray(a)
        b = np.asarray(b)
        na = np.linalg.norm(a)
        nb = np.linalg.norm(b)
        if na == 0 or nb == 0:
            return 1.0
        return 1.0 - float(np.dot(a, b) / (na * nb))

    def spearmanr(x, y):  # type: ignore
        """Placeholder that ignores its inputs and always returns ``(0.0, 1.0)``.

        NOTE(review): presumably mirrors SciPy's (correlation, p-value) order;
        no rank correlation is actually computed here - confirm no caller
        relies on a real value when SciPy is missing.
        """
        return 0.0, 1.0


try:
    import h5py  # type: ignore
except Exception:
    h5py = None  # type: ignore
# tqdm is not used directly; avoid import to keep lightweight

from ..utils.config import ConfigManager
from ..utils.model_loader import ModelLoader

logger = logging.getLogger(__name__)


@dataclass
class AttentionHeadMetrics:
    """Metrics for attention head analysis.

    One instance is produced per analyzed head (or per key/value group for
    GQA/MQA layers) by ``KnowledgeAnalyzer._compute_head_metrics`` and is
    exported through ``__dict__`` into the analysis results.
    """

    # Identifier built as "<kb_dir_name>_layer<L>_head<H>" or "..._group<G>".
    head_id: str
    # Result of _compute_syntactic_head_score; 0.5 is the neutral fallback.
    syntactic_score: float
    # Result of _compute_factual_head_score; 0.5 is the neutral fallback.
    factual_score: float
    # One of "multi_functional", "syntactic_dependency", "factual_retrieval",
    # "weakly_specialized" or "general_purpose".
    functional_label: str
    # Layer number parsed from the "layer_<N>" directory name.
    layer_idx: int
    # Head index within the layer, or KV-group index for GQA/MQA layers.
    head_idx: int


@dataclass
class EmbeddingMetrics:
    """Metrics for embedding analysis.

    Built by ``KnowledgeAnalyzer.analyze_embeddings`` and exported through
    ``__dict__``.
    """

    # Perplexity of the small student model; float("inf") when the downstream
    # evaluation is disabled by config or could not be computed.
    perplexity_score: float
    # 1 - average cosine similarity over sampled embedding pairs.
    semantic_coverage: float
    # PCA summary with keys "total_dimensions", "effective_dimensions",
    # "variance_concentration" and "explained_variance_95".
    dimension_analysis: Dict[str, Any]


@dataclass
class FFNMetrics:
    """Metrics for FFN analysis.

    Built by ``KnowledgeAnalyzer._compute_ffn_metrics`` and exported through
    ``__dict__``.
    """

    # Sum of the explained-variance ratios of the leading PCA components.
    pca_variance_explained: float
    # Silhouette score of the KMeans labelling (an unsupervised proxy, not a
    # label-based purity); 0.0 when only one cluster label is present.
    cluster_purity_score: float
    # Maps "cluster_<i>" to a dict with "center", "size", "proportion" and
    # "description".
    conceptual_clusters: Dict[str, Any]
    # Index of the FFN layer these metrics describe.
    layer_idx: int


class KnowledgeAnalyzer:
    """
    Analyzes extracted knowledge components to compute interpretability metrics.

    Implements Section 4 and 5 of the framework: Knowledge analysis and cataloging.
    """

    def __init__(self, config: ConfigManager):
        self.config = config
        self.model_loader = ModelLoader(
            cache_dir=config.get("model_cache_dir"), device=config.get("device")
        )

    def analyze_knowledge_bank(
        self,
        knowledge_bank_dir: str,
        activations_file: Optional[str] = None,
        output_dir: str = "./analysis_output",
    ) -> Dict[str, Any]:
        """
        Analyze all components in a knowledge bank.

        Args:
            knowledge_bank_dir: Directory containing extracted components
            activations_file: Optional HDF5 file with captured activations
            output_dir: Directory to save analysis results

        Returns:
            Dictionary containing analysis results
        """
        logger.info(f"Starting knowledge bank analysis: {knowledge_bank_dir}")

        knowledge_bank_path = Path(knowledge_bank_dir)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Load extraction metadata
        with open(knowledge_bank_path / "extraction_metadata.json", "r") as f:
            extraction_metadata = json.load(f)

        import uuid

        analysis_results = {
            "source_model": extraction_metadata["source_model"],
            "analysis_config": self.config.config.copy(),
            "component_analysis": {},
            "run_id": str(uuid.uuid4()),
            "seed": self.config.get("seed"),
        }

        # Analyze embeddings
        if (knowledge_bank_path / "embeddings").exists():
            embedding_analysis = self.analyze_embeddings(
                knowledge_bank_path / "embeddings", extraction_metadata["source_model"]
            )
            analysis_results["component_analysis"]["embeddings"] = embedding_analysis

        # Analyze attention heads
        if (knowledge_bank_path / "heads").exists():
            attention_analysis = self.analyze_attention_heads(
                knowledge_bank_path / "heads", activations_file
            )
            analysis_results["component_analysis"][
                "attention_heads"
            ] = attention_analysis

        # Analyze FFN layers
        if (knowledge_bank_path / "ffns").exists():
            ffn_analysis = self.analyze_ffn_layers(
                knowledge_bank_path / "ffns", activations_file
            )
            analysis_results["component_analysis"]["ffn_layers"] = ffn_analysis

        # Create head catalog
        head_catalog = self.create_head_catalog(
            analysis_results["component_analysis"].get("attention_heads", {})
        )
        analysis_results["head_catalog"] = head_catalog

        # Save analysis results
        with open(output_path / "analysis_results.json", "w") as f:
            json.dump(analysis_results, f, indent=2, default=self._json_serializer)

        # Save head catalog separately
        with open(output_path / "head_catalog.json", "w") as f:
            json.dump(head_catalog, f, indent=2, default=self._json_serializer)

        logger.info(f"Analysis completed. Results saved to: {output_path}")

        return analysis_results

    def analyze_embeddings(
        self, embeddings_dir: Path, model_name: str
    ) -> Dict[str, Any]:
        """
        Analyze embedding semantic coverage using downstream perplexity.

        Implements Section 4.1: Embeddings → Semantic Coverage
        """
        logger.info("Analyzing embeddings semantic coverage...")

        # Load embedding weights
        with open(embeddings_dir / "config.json", "r") as f:
            embedding_config = json.load(f)

        # Load embeddings via storage helpers (supports sharded pt)
        from ..utils.storage import load_pt_or_safetensors, load_sharded_pt

        embedding_weights = None
        idx = embeddings_dir / "embeddings.index.json"
        if idx.exists():
            try:
                embedding_weights = load_sharded_pt(idx)
            except Exception:
                embedding_weights = None
        if embedding_weights is None:
            st = embeddings_dir / "embeddings.safetensors"
            pt = embeddings_dir / "embeddings.pt"
            if st.exists():
                embedding_weights = load_pt_or_safetensors(st)["weight"]
            elif pt.exists():
                t = torch.load(str(pt))
                embedding_weights = (
                    t if isinstance(t, torch.Tensor) else t.get("weight")
                )
            else:
                raise FileNotFoundError("Embeddings file not found")

        # Compute intrinsic metrics
        intrinsic_metrics = self._compute_embedding_intrinsic_metrics(embedding_weights)

        # Compute downstream perplexity with frozen embeddings
        perplexity_score = self._compute_downstream_perplexity(
            embedding_weights, model_name, embedding_config
        )

        return {
            "metrics": EmbeddingMetrics(
                perplexity_score=perplexity_score,
                semantic_coverage=intrinsic_metrics["semantic_coverage"],
                dimension_analysis=intrinsic_metrics["dimension_analysis"],
            ).__dict__,
            "config": embedding_config,
            "intrinsic_metrics": intrinsic_metrics,
        }

    def _compute_embedding_intrinsic_metrics(
        self, embedding_weights: torch.Tensor
    ) -> Dict[str, Any]:
        """Compute intrinsic metrics for embeddings."""

        # Convert to numpy for analysis
        embeddings_np = embedding_weights.cpu().numpy()

        # Compute PCA to analyze dimensionality
        if PCA is None:
            raise RuntimeError(
                "scikit-learn is required for PCA-based analysis but is not installed."
            )
        pca = PCA()
        pca.fit(embeddings_np)

        # Find effective dimensionality (95% variance explained)
        cumsum_var = np.cumsum(pca.explained_variance_ratio_)
        effective_dim = np.argmax(cumsum_var >= 0.95) + 1

        # Compute average cosine similarity (semantic density)
        sample_indices = np.random.choice(
            len(embeddings_np), min(1000, len(embeddings_np)), replace=False
        )
        sample_embeddings = embeddings_np[sample_indices]

        similarities = []
        for i in range(len(sample_embeddings)):
            for j in range(i + 1, min(i + 100, len(sample_embeddings))):
                sim = 1 - cosine(sample_embeddings[i], sample_embeddings[j])
                similarities.append(sim)

        avg_similarity = np.mean(similarities)

        return {
            "semantic_coverage": 1.0
            - avg_similarity,  # Lower similarity = higher coverage
            "dimension_analysis": {
                "total_dimensions": len(pca.explained_variance_ratio_),
                "effective_dimensions": int(effective_dim),
                "variance_concentration": float(pca.explained_variance_ratio_[0]),
                "explained_variance_95": float(cumsum_var[effective_dim - 1]),
            },
            "average_cosine_similarity": float(avg_similarity),
        }

    def _compute_downstream_perplexity(
        self,
        embedding_weights: torch.Tensor,
        model_name: str,
        embedding_config: Dict[str, Any],
    ) -> float:
        """Compute perplexity using frozen embeddings in a small student model."""
        # Allow disabling due to cost
        if not self.config.get("enable_downstream_perplexity", False):
            logger.info("Downstream perplexity disabled by config; returning inf")
            return float("inf")

        try:
            # Create a small student model
            from transformers import GPT2Config, GPT2LMHeadModel

            student_config = GPT2Config(
                vocab_size=embedding_config["vocab_size"],
                n_embd=embedding_config["hidden_size"],
                n_layer=2,  # Small model
                n_head=4,
                n_positions=512,
            )

            student_model = GPT2LMHeadModel(student_config)

            # Replace embeddings with frozen extracted embeddings
            with torch.no_grad():
                student_model.transformer.wte.weight.copy_(embedding_weights)

            # Freeze embeddings
            student_model.transformer.wte.weight.requires_grad = False

            # Train a tiny student model for a few steps on a small corpus with frozen embeddings
            from ..utils.data_manager import DataManager

            data_manager = DataManager(self.config)
            corpus = data_manager._create_diverse_corpus()
            texts = (
                corpus["text"]
                if isinstance(corpus["text"], list)
                else list(corpus["text"])
            )

            # Prepare tokenizer for student model (simple character-level tokenizer approximation)
            # To keep it general without external downloads, use a basic whitespace split and map to ids
            vocab = {"<pad>": 0, "<eos>": 1}
            for t in texts:
                for tok in t.split():
                    if tok not in vocab:
                        vocab[tok] = len(vocab)

            # Build simple dataset of token ids
            sequences = []
            max_len = 64
            for t in texts:
                ids = [vocab.get(tok, 0) for tok in t.split()][: max_len - 1] + [1]
                sequences.append(ids)

            # Resize student model embeddings to match our vocab size and copy in available vectors by hashing words
            student_model.resize_token_embeddings(len(vocab))
            with torch.no_grad():
                for tok, idx in vocab.items():
                    if idx < student_model.transformer.wte.weight.shape[
                        0
                    ] and tok not in ("<pad>", "<eos>"):
                        # Initialize by averaging subword embeddings of characters if available else random small init
                        # Here we use a deterministic hash to pick some rows from the provided embedding matrix
                        h = abs(hash(tok)) % embedding_weights.shape[0]
                        student_model.transformer.wte.weight[idx].copy_(
                            embedding_weights[h]
                        )

            # Freeze embeddings
            student_model.transformer.wte.weight.requires_grad = False

            device = self.model_loader.device
            student_model.to(device)

            # Simple training loop
            optimizer = torch.optim.AdamW(
                [p for p in student_model.parameters() if p.requires_grad], lr=3e-4
            )
            student_model.train()
            steps = 50
            batch_size = 4
            import random

            random.seed(0)
            for step in range(steps):
                batch = random.sample(sequences, k=min(batch_size, len(sequences)))
                # Pad batch
                max_b = max(len(s) for s in batch)
                inp = torch.full((len(batch), max_b), vocab["<pad>"], dtype=torch.long)
                for i, s in enumerate(batch):
                    inp[i, : len(s)] = torch.tensor(s)
                labels = inp.clone()
                inp = inp.to(device)
                labels = labels.to(device)
                outputs = student_model(input_ids=inp, labels=labels)
                loss = outputs.loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Evaluation perplexity on the same corpus (small-scale demonstration)
            student_model.eval()
            total_loss = 0.0
            total_tokens = 0
            with torch.no_grad():
                for s in sequences:
                    inp = torch.tensor([s], dtype=torch.long).to(device)
                    labels = inp.clone()
                    out = student_model(input_ids=inp, labels=labels)
                    loss = out.loss.item()
                    total_loss += loss * len(s)
                    total_tokens += len(s)
            if total_tokens == 0:
                return float("inf")
            avg_loss = total_loss / total_tokens
            return float(np.exp(avg_loss))

        except Exception as e:
            logger.warning(f"Could not compute downstream perplexity: {e}")
            return float("inf")

    def analyze_attention_heads(
        self, heads_dir: Path, activations_file: Optional[str] = None
    ) -> Dict[
        str, Any
    ]:  # heads_dir is <kb>/heads; we'll infer donor model from kb metadata
        """
        Analyze attention heads for interpretability.

        Implements Section 4.2: Attention Maps → Head Interpretability
        """
        logger.info("Analyzing attention heads...")

        head_analysis = {}

        # Iterate through all layers
        for layer_dir in heads_dir.iterdir():
            if not layer_dir.is_dir() or not layer_dir.name.startswith("layer_"):
                continue

            layer_idx = int(layer_dir.name.split("_")[1])

            # Load layer configuration
            with open(layer_dir / "config.json", "r") as f:
                layer_config = json.load(f)

            # Analyze this layer's attention mechanism
            layer_analysis = self._analyze_layer_attention(
                layer_dir, layer_config, layer_idx, activations_file
            )

            head_analysis[f"layer_{layer_idx}"] = layer_analysis

        return head_analysis

    def _analyze_layer_attention(
        self,
        layer_dir: Path,
        layer_config: Dict[str, Any],
        layer_idx: int,
        activations_file: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Analyze attention patterns for a specific layer."""

        attention_type = layer_config.get("attention_type", "MHA")
        num_heads = layer_config.get(
            "num_heads", layer_config.get("num_query_heads", 0)
        )

        head_metrics = []

        if attention_type == "MHA":
            # For MHA, analyze each head separately
            for head_idx in range(num_heads):
                metrics = self._compute_head_metrics(
                    layer_dir, layer_idx, head_idx, activations_file
                )
                head_metrics.append(metrics)

        elif attention_type in ["GQA", "MQA"]:
            # For GQA/MQA, analyze query groups
            num_kv_heads = layer_config.get("num_key_value_heads", 1)

            for group_idx in range(num_kv_heads):
                metrics = self._compute_head_metrics(
                    layer_dir, layer_idx, group_idx, activations_file, is_group=True
                )
                head_metrics.append(metrics)

        return {
            "attention_type": attention_type,
            "num_heads": num_heads,
            "head_metrics": [
                metric.__dict__ if hasattr(metric, "__dict__") else metric
                for metric in head_metrics
            ],
            "layer_config": layer_config,
        }

    def _compute_head_metrics(
        self,
        layer_dir: Path,
        layer_idx: int,
        head_idx: int,
        activations_file: Optional[str] = None,
        is_group: bool = False,
    ) -> AttentionHeadMetrics:
        """Compute interpretability metrics for a specific attention head."""

        head_id = f"{layer_dir.parent.parent.name}_layer{layer_idx}_{'group' if is_group else 'head'}{head_idx}"

        # Compute syntactic score via diagonal attention dominance on a syntactic corpus
        syntactic_score = self._compute_syntactic_head_score(
            layer_dir, head_idx, activations_file
        )

        # Compute factual score via a sensitivity proxy using activation variance
        factual_score = self._compute_factual_head_score(
            layer_dir, head_idx, activations_file
        )

        # Determine functional label based on scores
        functional_label = self._determine_functional_label(
            syntactic_score, factual_score
        )

        return AttentionHeadMetrics(
            head_id=head_id,
            syntactic_score=syntactic_score,
            factual_score=factual_score,
            functional_label=functional_label,
            layer_idx=layer_idx,
            head_idx=head_idx,
        )

    def _compute_syntactic_head_score(
        self, layer_dir: Path, head_idx: int, activations_file: Optional[str] = None
    ) -> float:
        """Compute syntactic head score (SHS) using diagonal attention dominance.
        If activations_file is provided and contains attention maps for this layer, we compute
        the ratio between diagonal attention mass and total attention as a proxy for syntactic alignment.
        """
        try:
            if activations_file and Path(activations_file).exists():
                with h5py.File(activations_file, "r") as f:
                    # Look for attention activations for this layer
                    key = f"transformer_h_{layer_dir.name.split('_')[1]}_attn"
                    if "activations" in f and key in f["activations"]:
                        layer_group = f["activations"][key]
                        diag_scores = []
                        count = 0
                        for sample_key in layer_group.keys():
                            if sample_key.startswith("sample_"):
                                attn = np.array(
                                    layer_group[sample_key]["activations"]
                                )  # [seq, seq] or [..., seq, seq]
                                # Reduce to 2D attention matrix
                                while attn.ndim > 2:
                                    attn = attn.mean(axis=0)
                                if attn.shape[0] == 0 or attn.shape[1] == 0:
                                    continue
                                # Diagonal dominance
                                diag = np.trace(attn) / max(1, min(attn.shape))
                                total = attn.sum() + 1e-8
                                diag_scores.append(float(diag / total * attn.size))
                                count += 1
                        if diag_scores:
                            return float(np.clip(np.mean(diag_scores), 0.0, 1.0))
        except Exception as e:
            logger.warning(
                f"Failed SHS computation for {layer_dir} head {head_idx}: {e}"
            )
        # Fallback: compute a neutral score
        return 0.5

    def _compute_factual_head_score(
        self, layer_dir: Path, head_idx: int, activations_file: Optional[str] = None
    ) -> float:
        """Compute factual head score (FHS) using sensitivity of next-token prob to masking the head.
        This is an approximation: we estimate importance by comparing LM loss with and without contribution from this layer
        using available hidden states in activations_file if present.
        """
        try:
            if not activations_file or not Path(activations_file).exists():
                return 0.5
            # Without full model forward interception per-head, approximate by layer-level sensitivity
            with h5py.File(activations_file, "r") as f:
                key = f"transformer_h_{layer_dir.name.split('_')[1]}_attn"
                if "activations" not in f or key not in f["activations"]:
                    return 0.5
                layer_group = f["activations"][key]
                norms = []
                for sample_key in list(layer_group.keys())[:10]:
                    if sample_key.startswith("sample_"):
                        attn = np.array(
                            layer_group[sample_key]["activations"]
                        )  # [..., seq, seq]
                        # Head-wise variance as a proxy for contribution
                        var = float(np.var(attn))
                        norms.append(var)
                if norms:
                    # Normalize to [0,1] across heads by ranking
                    val = float(np.mean(norms))
                    return float(np.tanh(val)) if np.isfinite(val) else 0.5
        except Exception as e:
            logger.warning(
                f"Failed FHS computation for {layer_dir} head {head_idx}: {e}"
            )
        return 0.5

    def _determine_functional_label(
        self, syntactic_score: float, factual_score: float
    ) -> str:
        """Determine functional label based on scores."""
        threshold = 0.7

        if syntactic_score > threshold and factual_score > threshold:
            return "multi_functional"
        elif syntactic_score > threshold:
            return "syntactic_dependency"
        elif factual_score > threshold:
            return "factual_retrieval"
        elif syntactic_score > 0.5 or factual_score > 0.5:
            return "weakly_specialized"
        else:
            return "general_purpose"

    def analyze_ffn_layers(
        self, ffns_dir: Path, activations_file: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Analyze FFN layers for conceptual clustering.

        Implements Section 4.3: FFNs → Conceptual Clustering
        """
        logger.info("Analyzing FFN layers...")

        ffn_analysis = {}

        # Iterate through all FFN layers
        for layer_dir in ffns_dir.iterdir():
            if not layer_dir.is_dir() or not layer_dir.name.startswith("layer_"):
                continue

            layer_idx = int(layer_dir.name.split("_")[1])

            # Load layer configuration
            with open(layer_dir / "config.json", "r") as f:
                layer_config = json.load(f)

            # Analyze this layer's FFN
            layer_analysis = self._analyze_layer_ffn(
                layer_dir, layer_config, layer_idx, activations_file
            )

            ffn_analysis[f"layer_{layer_idx}"] = layer_analysis

        return ffn_analysis

    def _analyze_layer_ffn(
        self,
        layer_dir: Path,
        layer_config: Dict[str, Any],
        layer_idx: int,
        activations_file: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Analyze FFN for a specific layer."""

        # Compute PCA and clustering metrics
        metrics = self._compute_ffn_metrics(layer_dir, layer_idx, activations_file)

        return {
            "layer_idx": layer_idx,
            "metrics": metrics.__dict__ if hasattr(metrics, "__dict__") else metrics,
            "layer_config": layer_config,
        }

    def _compute_ffn_metrics(
        self, layer_dir: Path, layer_idx: int, activations_file: Optional[str] = None
    ) -> FFNMetrics:
        """Compute PCA and clustering metrics for FFN."""

        # Load FFN activations; require real data, no synthetic fallback
        if activations_file and Path(activations_file).exists():
            activations = self._load_ffn_activations(activations_file, layer_idx)
        else:
            raise ValueError(
                "Missing activations file for FFN analysis. Provide a real 'activations_file' captured from the model."
            )

        # Compute PCA
        if PCA is None:
            raise RuntimeError(
                "scikit-learn is required for PCA-based analysis but is not installed."
            )
        pca = PCA(n_components=min(50, activations.shape[1]))
        pca_result = pca.fit_transform(activations)

        # Compute variance explained by top components
        n_components = self.config.get("pca_components", 50)
        variance_explained = np.sum(pca.explained_variance_ratio_[:n_components])

        # Perform clustering
        n_clusters = self.config.get("n_clusters", 10)
        if KMeans is None:
            raise RuntimeError(
                "scikit-learn is required for clustering analysis but is not installed."
            )
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_labels = kmeans.fit_predict(pca_result)

        # Compute an unsupervised clustering quality proxy (silhouette)
        cluster_purity = self._compute_cluster_purity(cluster_labels, activations)

        # Extract conceptual clusters
        conceptual_clusters = self._extract_conceptual_clusters(
            kmeans.cluster_centers_, cluster_labels, n_clusters
        )

        return FFNMetrics(
            pca_variance_explained=float(variance_explained),
            cluster_purity_score=cluster_purity,
            conceptual_clusters=conceptual_clusters,
            layer_idx=layer_idx,
        )

    def _load_ffn_activations(
        self, activations_file: str, layer_idx: int
    ) -> np.ndarray:
        """Load FFN activations from HDF5 file."""
        try:
            if h5py is None:
                raise RuntimeError(
                    "h5py is required to read activation datasets but is not installed."
                )
            with h5py.File(activations_file, "r") as f:
                # Look for FFN layer activations
                ffn_layer_name = f"layer_{layer_idx}_ffn"
                if ffn_layer_name in f["activations"]:
                    layer_group = f["activations"][ffn_layer_name]
                    activations_list = []

                    # Collect activations from all samples
                    for sample_key in layer_group.keys():
                        if sample_key.startswith("sample_"):
                            sample_activations = np.array(
                                layer_group[sample_key]["activations"]
                            )
                            # Average over sequence length
                            if len(sample_activations.shape) > 1:
                                sample_activations = np.mean(sample_activations, axis=0)
                            activations_list.append(sample_activations)

                    return np.array(activations_list)
        except Exception as e:
            logger.warning(f"Could not load FFN activations for layer {layer_idx}: {e}")
            raise

    def _compute_cluster_purity(
        self, cluster_labels: np.ndarray, activations: np.ndarray
    ) -> float:
        """Compute unsupervised clustering quality via silhouette score (proxy)."""
        # Use silhouette score as an unsupervised proxy metric
        if len(np.unique(cluster_labels)) > 1:
            return float(silhouette_score(activations, cluster_labels))
        else:
            return 0.0

    def _extract_conceptual_clusters(
        self, cluster_centers: np.ndarray, cluster_labels: np.ndarray, n_clusters: int
    ) -> Dict[str, Any]:
        """Extract and describe conceptual clusters."""

        clusters = {}
        for i in range(n_clusters):
            cluster_mask = cluster_labels == i
            cluster_size = np.sum(cluster_mask)

            clusters[f"cluster_{i}"] = {
                "center": cluster_centers[i].tolist(),
                "size": int(cluster_size),
                "proportion": float(cluster_size / len(cluster_labels)),
                "description": f"Cluster {i} - {cluster_size} items",
            }

        return clusters

    def create_head_catalog(
        self, attention_analysis: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Create a catalog of attention heads with their functional properties.

        Implements Section 5.2: Building the Head Catalog
        """
        logger.info("Creating head catalog...")

        catalog = []

        for layer_key, layer_data in attention_analysis.items():
            layer_idx = int(layer_key.split("_")[1])
            layer_cfg = layer_data.get("layer_config", {})
            q2kv = layer_cfg.get("q_to_kv_mapping")

            for head_metrics in layer_data.get("head_metrics", []):
                if isinstance(head_metrics, dict):
                    kv_group = None
                    if q2kv is not None:
                        try:
                            kv_group = q2kv[head_metrics.get("head_idx", 0)]
                        except Exception:
                            kv_group = None
                    catalog_entry = {
                        "id": head_metrics.get("head_id", f"unknown_layer{layer_idx}"),
                        "layer_idx": layer_idx,
                        "head_idx": head_metrics.get("head_idx", 0),
                        "scores": {
                            "syntactic_score": head_metrics.get("syntactic_score", 0.0),
                            "factual_score": head_metrics.get("factual_score", 0.0),
                        },
                        "function": head_metrics.get("functional_label", "unknown"),
                        "kv_group": kv_group,
                        "attention_type": layer_data.get("attention_type", "MHA"),
                    }
                    catalog.append(catalog_entry)

        # Sort by functional importance (combined score)
        catalog.sort(
            key=lambda x: x["scores"]["syntactic_score"] + x["scores"]["factual_score"],
            reverse=True,
        )

        return catalog

    def _json_serializer(self, obj):
        """JSON serializer for numpy types."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.float32, np.float64)):
            return float(obj)
        elif isinstance(obj, (np.int32, np.int64)):
            return int(obj)
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
