"""
Hallucination Detection Module

This module provides the core HallucinationDetector class that implements
confidence-aware routing for LLM reliability enhancement using a multi-signal approach.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Union
import warnings

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
    BitsAndBytesConfig,
)

# Import FlagEmbedding for BGE-M3
try:
    from FlagEmbedding import BGEM3FlagModel
    BGE_M3_AVAILABLE = True
except ImportError:
    BGE_M3_AVAILABLE = False
    warnings.warn(
        "FlagEmbedding not available. Install with: pip install -U FlagEmbedding",
        ImportWarning
    )

# Import PIL for image processing
try:
    from PIL import Image
    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False
    warnings.warn(
        "PIL not available. Install with: pip install Pillow",
        ImportWarning
    )

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


class ProjectionHead(nn.Module):
    """
    Feed-forward head that maps LLM hidden states into the embedding space.

    The network is three linear layers; the first two are each followed by
    ReLU and dropout (p=0.1), and the last is followed by a LayerNorm over
    the output features. Its output is meant to be compared against
    reference embeddings (BGE-M3) for confidence estimation.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 1024):
        """
        Build the projection network.

        Args:
            input_dim: Size of the incoming LLM hidden-state vectors
            output_dim: Size of the target embedding vectors
            hidden_dim: Width of the two hidden layers
        """
        super().__init__()

        # Two identical "Linear -> ReLU -> Dropout" stages, then the output
        # stage. The stages are created in forward order so the Sequential
        # indices (and therefore the state-dict keys) stay net.0 ... net.7.
        stages = []
        for fan_in, fan_out in ((input_dim, hidden_dim), (hidden_dim, hidden_dim)):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.1)]
        stages += [nn.Linear(hidden_dim, output_dim), nn.LayerNorm(output_dim)]

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Project ``x`` through the network and return the result."""
        projected = self.net(x)
        return projected


class UltraStableProjectionHead(nn.Module):
    """
    Heavily normalized projection head used for MedGemma models.

    Compared with ProjectionHead this variant normalizes the input, uses
    tanh activations each followed by a LayerNorm, applies stronger dropout,
    narrows the second hidden layer to ``hidden_dim // 2``, initializes all
    linear weights with a small standard deviation, and clamps its output
    to the range [-10, 10].
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 1280):
        """
        Build the layers and apply the conservative initialization.

        Args:
            input_dim: Size of the incoming LLM hidden-state vectors
            output_dim: Size of the target embedding vectors
            hidden_dim: Width of the first hidden layer (the second is half of it)
        """
        super().__init__()
        bottleneck_dim = hidden_dim // 2

        # Stage 0: normalize the raw hidden states.
        self.input_norm = nn.LayerNorm(input_dim)

        # Stage 1: input_dim -> hidden_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.dropout1 = nn.Dropout(0.3)

        # Stage 2: hidden_dim -> hidden_dim // 2
        self.fc2 = nn.Linear(hidden_dim, bottleneck_dim)
        self.norm2 = nn.LayerNorm(bottleneck_dim)
        self.dropout2 = nn.Dropout(0.2)

        # Stage 3: hidden_dim // 2 -> output_dim
        self.fc3 = nn.Linear(bottleneck_dim, output_dim)
        self.output_norm = nn.LayerNorm(output_dim)

        # Re-initialize every submodule once they are all registered.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Reset a single submodule: N(0, 0.01) linear weights, identity LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)
            return
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the projection and clamp the result to [-10, 10]."""
        hidden = self.input_norm(x)

        # tanh keeps activations bounded before each LayerNorm.
        hidden = self.dropout1(self.norm1(self.fc1(hidden).tanh()))
        hidden = self.dropout2(self.norm2(self.fc2(hidden).tanh()))

        projected = self.output_norm(self.fc3(hidden))

        # Guard against extreme output magnitudes.
        return projected.clamp(min=-10.0, max=10.0)


def get_pooled_embeddings(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    texts: List[str],
    device: str,
    max_length: int = 512,
):
    """
    Extract mean-pooled last-layer embeddings from an LLM.

    The texts are tokenized as one padded batch, run through the model with
    ``output_hidden_states=True``, and the final hidden layer is averaged
    over the tokens selected by the attention mask.

    The averaging is carried out in float32 so that summing many
    half-precision activations cannot overflow, and the token count is
    clamped to at least 1 so that a row whose attention mask is entirely
    zero yields a zero vector instead of NaN. The result is cast back to the
    dtype of the model's hidden states.

    Note: the model is switched to eval mode as a side effect.

    Args:
        model: The LLM model
        tokenizer: Associated tokenizer
        texts: List of input texts
        device: Computing device the tokenized inputs and the result are moved to
        max_length: Maximum sequence length (longer inputs are truncated)

    Returns:
        Pooled embeddings tensor with one row per input text
    """
    model.eval()
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    last_hidden = outputs.hidden_states[-1]

    # Bring the mask to the hidden states' device and to float32 so the
    # masked sum below is accumulated in full precision.
    mask = inputs.attention_mask.unsqueeze(-1).to(
        device=last_hidden.device, dtype=torch.float32
    )
    summed = (last_hidden.float() * mask).sum(dim=1)

    # Avoid 0/0 for sequences without any attended token.
    token_counts = mask.sum(dim=1).clamp(min=1.0)

    pooled = (summed / token_counts).to(last_hidden.dtype)

    # Ensure output is on the specified device
    return pooled.to(device)


class HallucinationDetector:
    """
    Confidence-aware hallucination detector supporting both Llama-3.2-3B and MedGemma-4B-IT + BGE-M3.
    
    This class implements the multi-signal confidence estimation approach described
    in the research paper, combining semantic alignment measurement, internal 
    convergence analysis, and learned confidence estimation. Supports multimodal
    capabilities for MedGemma 4b-it models.
    """
    
    def __init__(
        self,
        model_path: str = None,
        llm_model_id: str = "unsloth/Llama-3.2-3B-Instruct",
        embed_model_id: str = "BAAI/bge-m3",
        device: str = None,
        max_length: int = 512,
        bge_max_length: int = 512,
        use_fp16: bool = True,
        load_llm: bool = True,
        enable_inference: bool = False,
        confidence_threshold: float = None,
        enable_response_generation: bool = False,
        use_quantization: bool = False,
        quantization_config: BitsAndBytesConfig = None,
        mode: str = "auto",
    ):
        """
        Initialize the hallucination detector.
        
        Loads the projection-head checkpoint, optionally the LLM (plus
        tokenizer/processor), the BGE-M3 embedding model, and finally the
        projection head matching the model family. A model whose ID contains
        "4b-it" is treated as MedGemma 4B-IT.
        
        Args:
            model_path: Path to trained model checkpoint. If None, downloads pre-trained model.
            llm_model_id: Hugging Face model ID for the LLM
            embed_model_id: Hugging Face model ID for the embedding model
            device: Computing device ('cuda' or 'cpu'); indexed forms such as
                'cuda:0' are handled like 'cuda'. Auto-detected when None.
            max_length: Maximum sequence length for LLM
            bge_max_length: Maximum sequence length for BGE-M3
            use_fp16: Whether to use FP16 precision
            load_llm: Whether to load the LLM (set False if only using for projection/embedding)
            enable_inference: Whether to enable LLM inference capabilities
            confidence_threshold: Confidence threshold for high confidence routing
                (defaults to 0.62 for MedGemma, 0.65 otherwise)
            enable_response_generation: Whether to enable response generation when threshold is met
            use_quantization: Whether to use 4-bit quantization to reduce memory usage
            quantization_config: Custom BitsAndBytesConfig for quantization (auto-created if None)
            mode: Operation mode - "auto", "text", "image", or "both" (auto-detected from model if "auto")
        
        Raises:
            ImportError: If FlagEmbedding is not installed.
            ValueError: If ``mode`` is invalid, or "image"/"both" is requested
                for a model that is not recognised as MedGemma 4B-IT.
            RuntimeError: If ``mode`` is "image" and no multimodal model could be loaded.
        """
        if not BGE_M3_AVAILABLE:
            raise ImportError(
                "FlagEmbedding is required for BGE-M3. "
                "Install with: pip install -U FlagEmbedding"
            )
        
        # Auto-detect device if not specified
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        
        # Branch on the device *type* rather than on the exact string, so that
        # "cuda:0" (or a torch.device) follows the same GPU code paths as
        # "cuda" instead of silently being treated as a CPU setup.
        on_cuda = torch.device(self.device).type == "cuda"
        
        self.max_length = max_length
        self.bge_max_length = bge_max_length
        self.use_fp16 = use_fp16
        self.load_llm = load_llm
        self.enable_inference = enable_inference
        self.llm_model_id = llm_model_id
        self.enable_response_generation = enable_response_generation
        self.use_quantization = use_quantization
        
        # Validate and set mode
        valid_modes = ["auto", "text", "image", "both"]
        if mode not in valid_modes:
            raise ValueError(f"Invalid mode '{mode}'. Must be one of: {valid_modes}")
        self.mode = mode
        
        # Set up quantization configuration
        if use_quantization and quantization_config is None:
            # Create default 4-bit quantization config for memory optimization
            self.quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16 if use_fp16 else torch.float32,
            )
            print("🔧 Using 4-bit quantization for memory optimization")
        else:
            self.quantization_config = quantization_config
        
        # Determine if this is a MedGemma 4b-it model (substring heuristic on the model ID)
        self.is_medgemma_4b = "4b-it" in llm_model_id.lower()
        
        # Determine multimodal capability based on model and mode
        if self.mode == "auto":
            # Auto-detect: MedGemma 4b-it supports multimodal
            self.is_multimodal = self.is_medgemma_4b
            self.effective_mode = "both" if self.is_medgemma_4b else "text"
        elif self.mode == "text":
            self.is_multimodal = False
            self.effective_mode = "text"
        elif self.mode == "image":
            if not self.is_medgemma_4b:
                raise ValueError("Image mode requires MedGemma 4b-it model. Use llm_model_id='google/medgemma-4b-it'")
            self.is_multimodal = True
            self.effective_mode = "image"
        elif self.mode == "both":
            if not self.is_medgemma_4b:
                raise ValueError("Both mode requires MedGemma 4b-it model. Use llm_model_id='google/medgemma-4b-it'")
            self.is_multimodal = True
            self.effective_mode = "both"
        
        # Set confidence threshold - use 0.62 for medical models, 0.65 for others
        if confidence_threshold is None:
            self.confidence_threshold = 0.62 if self.is_medgemma_4b else 0.65
        else:
            self.confidence_threshold = confidence_threshold
        
        print(f"🚀 Loading models on {self.device}...")
        print(f"🔬 Model type: {'MedGemma 4B-IT (Medical + Multimodal)' if self.is_medgemma_4b else 'Llama-3.2-3B (General)'}")
        print(f"📊 Confidence threshold: {self.confidence_threshold}")
        
        # Download model if path not provided
        if model_path is None:
            if self.is_medgemma_4b:
                from .utils import download_medgemma_model
                model_path = download_medgemma_model(llm_model_id)
            else:
                from .utils import download_model
                model_path = download_model()
        
        # Load checkpoint with proper device mapping.
        # NOTE(security): weights_only=False unpickles arbitrary Python objects,
        # so only checkpoints from a trusted source must be loaded here.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.config = checkpoint['config']
        
        # Conditionally load LLM
        self.llm = None
        self.llm_multimodal = None
        self.tokenizer = None
        self.processor = None
        
        if self.load_llm:
            model_dtype = torch.float16 if (use_fp16 and on_cuda) else torch.float32
            model_kwargs = dict(
                dtype=model_dtype,  # Updated from torch_dtype to dtype
                device_map="auto" if on_cuda else None,
            )
            
            # Add quantization config if enabled
            if self.use_quantization and self.quantization_config is not None:
                if not on_cuda:
                    print("⚠️ Quantization not supported on CPU, disabling quantization")
                    self.use_quantization = False
                    self.quantization_config = None
                else:
                    model_kwargs["quantization_config"] = self.quantization_config
                    # When using quantization, let device_map handle placement
                    model_kwargs["device_map"] = "auto"
            
            if self.is_medgemma_4b:
                print(f"📥 Loading MedGemma-4B-IT (mode: {self.effective_mode})...")
                
                if self.effective_mode in ["both", "image"]:
                    # Load unified multimodal model that can handle both text and images
                    print("🔄 Loading unified multimodal model...")
                    try:
                        self.llm_multimodal = AutoModelForImageTextToText.from_pretrained(llm_model_id, **model_kwargs)
                        print("✅ Loaded unified model as AutoModelForImageTextToText")
                        # Use the same model for text processing to avoid double loading
                        self.llm = self.llm_multimodal
                        print(f"🔗 Using unified model for {self.effective_mode} processing")
                    except Exception as e:
                        # Deliberate best-effort: try the next loader class before giving up.
                        print(f"⚠️ Failed to load as AutoModelForImageTextToText: {e}")
                        try:
                            self.llm_multimodal = PaliGemmaForConditionalGeneration.from_pretrained(llm_model_id, **model_kwargs)
                            print("✅ Loaded unified model as PaliGemmaForConditionalGeneration")
                            # Use the same model for text processing
                            self.llm = self.llm_multimodal
                            print(f"🔗 Using unified model for {self.effective_mode} processing")
                        except Exception as e2:
                            print(f"⚠️ PaliGemma loading failed: {e2}")
                            if self.effective_mode == "image":
                                raise RuntimeError("Failed to load multimodal model in image mode")
                            print("📥 Falling back to text-only model...")
                            # Fallback to text-only model
                            self.llm = AutoModelForCausalLM.from_pretrained(llm_model_id, **model_kwargs)
                            self.llm_multimodal = None
                            self.is_multimodal = False
                            self.effective_mode = "text"
                            print("⚠️ Image processing disabled - text-only mode")
                else:
                    # Text-only mode for MedGemma
                    print("📝 Loading text-only model (mode: text)...")
                    self.llm = AutoModelForCausalLM.from_pretrained(llm_model_id, **model_kwargs)
                    self.llm_multimodal = None
                    print("✅ Text-only mode enabled")
                
                # Load tokenizer and processor
                self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
                if self.llm_multimodal is not None:
                    try:
                        self.processor = AutoProcessor.from_pretrained(llm_model_id)
                        print("✅ MedGemma processor loaded successfully")
                    except Exception as e:
                        print(f"⚠️ Could not load MedGemma processor: {e}")
                        self.processor = None
                        self.is_multimodal = False
            else:
                print("📥 Loading Llama-3.2-3B-Instruct...")
                self.llm = AutoModelForCausalLM.from_pretrained(llm_model_id, **model_kwargs)
                self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
            
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            # Ensure LLM is on the correct device. Off CUDA no device_map was
            # used (and quantization has been disabled above), so the model
            # has to be placed explicitly.
            if not on_cuda:
                self.llm = self.llm.to(self.device)
                # For MedGemma unified models, llm and llm_multimodal are the same object
                if self.llm_multimodal is not None and self.llm_multimodal is not self.llm:
                    self.llm_multimodal = self.llm_multimodal.to(self.device)
            elif self.use_quantization:
                print("✅ Quantized models automatically placed on GPU via device_map")
                print(f"🔗 Unified model: {self.llm is self.llm_multimodal}")
        else:
            print("⏩ Skipping LLM loading (load_llm=False)")
            # Create a dummy tokenizer for cases where we need basic tokenization
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Load BGE-M3
        print("📥 Loading BGE-M3...")
        # Disable FP16 for BGE-M3 unless running on CUDA
        bge_use_fp16 = use_fp16 and on_cuda
        self.embed_model = BGEM3FlagModel(embed_model_id, use_fp16=bge_use_fp16)
        
        # Load projection head
        print("📥 Loading projection head...")
        if self.is_medgemma_4b:
            # Use UltraStableProjectionHead for MedGemma models
            self.projector = UltraStableProjectionHead(
                self.config['llm_hidden_size'],
                self.config['embedding_dim'],
                hidden_dim=1280,  # Match training script
            ).to(self.device)
        else:
            # Use standard ProjectionHead for Llama models
            self.projector = ProjectionHead(
                self.config['llm_hidden_size'],
                self.config['embedding_dim'],
                hidden_dim=1536,
            ).to(self.device)
        
        self.projector.load_state_dict(checkpoint['projector_state'])
        self.projector.eval()
        
        # Ensure projection head is on same device as LLM (if loaded).
        # device_map="auto" may have placed the LLM somewhere other than
        # self.device; the projector then follows the LLM's first parameter.
        if self.llm is not None:
            llm_device = next(self.llm.parameters()).device
            if str(llm_device) != str(self.device):
                print(f"⚠️ Moving projection head from {self.device} to {llm_device}")
                self.device = str(llm_device)
                self.projector = self.projector.to(self.device)
        
        print(f"✅ Model loaded successfully!")
        print(f"   LLM Hidden Size: {self.config['llm_hidden_size']}")
        print(f"   Embedding Dimension: {self.config['embedding_dim']}")
        print(f"   Operation Mode: {self.effective_mode} (requested: {self.mode})")
        if 'best_val_loss' in checkpoint:
            print(f"   Best Validation Loss: {checkpoint['best_val_loss']:.4f}")
        
        # Print optimization info
        if self.is_medgemma_4b:
            unified_model = self.llm is self.llm_multimodal
            print(f"   Unified Model: {unified_model} (avoids double loading)")
            print(f"   Memory Optimized: {self.use_quantization}")
            if self.use_quantization:
                print(f"   Quantization: 4-bit NF4 with double quantization")
    
    def predict(self, texts: Union[str, List[str]], query_context_pairs: List[Dict] = None) -> Dict:
        """
        Predict hallucination confidence scores for given texts.
        
        This method implements the core confidence estimation approach by:
        1. Computing semantic alignment between LLM and reference embeddings
        2. Analyzing internal convergence patterns
        3. Using learned confidence estimation
        
        Args:
            texts: Input text(s) to analyze
            query_context_pairs: Optional list of dicts with 'query' and 'context' keys for enhanced embedding
            
        Returns:
            Dictionary with predictions, confidence scores, and interpretations
        """
        if isinstance(texts, str):
            texts = [texts]
        
        # Prepare texts for embedding based on context
        embedding_texts = texts
        if query_context_pairs:
            # Format with specific prompt structure for better accuracy
            embedding_texts = []
            for i, text in enumerate(texts):
                if i < len(query_context_pairs) and query_context_pairs[i]:
                    pair = query_context_pairs[i]
                    formatted_text = f"query: {pair.get('query', text)}\ncontext: {pair.get('context', '')}"
                    embedding_texts.append(formatted_text)
                else:
                    embedding_texts.append(text)
        
        # Check if LLM is loaded for embedding computation
        if not self.load_llm or self.llm is None:
            raise RuntimeError("LLM not loaded. Set load_llm=True for hallucination detection.")
        
        # Get LLM embeddings and ensure they're on correct device
        # For projection model, use only query part for better similarity comparison
        projection_texts = texts
        if query_context_pairs:
            projection_texts = [pair.get('query', text) if i < len(query_context_pairs) and query_context_pairs[i] else text 
                             for i, (text, pair) in enumerate(zip(texts, query_context_pairs + [None] * len(texts)))]
        
        llm_embeddings = get_pooled_embeddings(
            self.llm,
            self.tokenizer,
            projection_texts,
            self.device,
            self.max_length,
        ).to(self.device)
        
        # Get BGE-M3 reference embeddings with formatted texts
        bge_outputs = self.embed_model.encode(
            embedding_texts,
            batch_size=len(embedding_texts),
            max_length=self.bge_max_length,
            return_dense=True,
            return_sparse=False,
            return_colbert_vecs=False,
        )
        
        ref_embeddings = torch.tensor(
            bge_outputs['dense_vecs'], 
            dtype=torch.float32,
            device=self.device
        )
        
        # Project LLM embeddings and compute similarity
        with torch.no_grad():
            # Ensure all tensors are on the same device and correct dtype
            llm_embeddings = llm_embeddings.float().to(self.device)
            ref_embeddings = ref_embeddings.to(self.device)
            
            projected = self.projector(llm_embeddings)
            similarities = F.cosine_similarity(projected, ref_embeddings, dim=1)
            confidence_scores = torch.sigmoid(similarities)
        
        # Convert to numpy for easier handling
        confidence_scores = confidence_scores.cpu().numpy()
        similarities = similarities.cpu().numpy()
        
        # Interpret results according to confidence-aware routing strategy
        results = []
        for i, (text, conf_score, sim_score) in enumerate(zip(texts, confidence_scores, similarities)):
            # Use dynamic thresholds based on model type
            if self.is_medgemma_4b:
                # Medical domain thresholds (lower due to higher precision requirements)
                if conf_score >= 0.62:
                    interpretation = "HIGH_MEDICAL_CONFIDENCE"
                    risk_level = "LOW_MEDICAL_RISK"
                    routing_action = "LOCAL_GENERATION"
                    description = "This medical response appears to be factual and reliable."
                elif conf_score >= 0.55:
                    interpretation = "MEDIUM_MEDICAL_CONFIDENCE"
                    risk_level = "MEDIUM_MEDICAL_RISK"
                    routing_action = "RAG_RETRIEVAL"
                    description = "This medical response may contain uncertainties. Verify with authoritative sources."
                elif conf_score >= 0.50:
                    interpretation = "LOW_MEDICAL_CONFIDENCE"
                    risk_level = "HIGH_MEDICAL_RISK"
                    routing_action = "LARGER_MODEL"
                    description = "This medical response is likely unreliable. Professional verification required."
                else:
                    interpretation = "VERY_LOW_MEDICAL_CONFIDENCE"
                    risk_level = "VERY_HIGH_MEDICAL_RISK"
                    routing_action = "HUMAN_REVIEW"
                    description = "This medical response appears highly unreliable. Seek professional medical advice."
            else:
                # General domain thresholds
                if conf_score >= 0.65:
                    interpretation = "HIGH_CONFIDENCE"
                    risk_level = "LOW_RISK"
                    routing_action = "LOCAL_GENERATION"
                    description = "This response appears to be factual and reliable."
                elif conf_score >= 0.60:
                    interpretation = "MEDIUM_CONFIDENCE"
                    risk_level = "MEDIUM_RISK"
                    routing_action = "RAG_RETRIEVAL"
                    description = "This response may contain uncertainties. Consider retrieval augmentation."
                elif conf_score >= 0.4:
                    interpretation = "LOW_CONFIDENCE"
                    risk_level = "HIGH_RISK"
                    routing_action = "LARGER_MODEL"
                    description = "This response is likely unreliable. Route to larger model."
                else:
                    interpretation = "VERY_LOW_CONFIDENCE"
                    risk_level = "VERY_HIGH_RISK"
                    routing_action = "HUMAN_REVIEW"
                    description = "This response appears to be highly unreliable. Human review required."
            
            results.append({
                "text": text,
                "confidence_score": float(conf_score),
                "similarity_score": float(sim_score),
                "interpretation": interpretation,
                "risk_level": risk_level,
                "routing_action": routing_action,
                "description": description,
            })
        
        return {
            "predictions": results,
            "summary": {
                "total_texts": len(texts),
                "avg_confidence": float(confidence_scores.mean()),
                "high_confidence_count": sum(1 for score in confidence_scores if score >= 0.65),
                "medium_confidence_count": sum(1 for score in confidence_scores if 0.60 <= score < 0.65),
                "low_confidence_count": sum(1 for score in confidence_scores if 0.4 <= score < 0.6),
                "very_low_confidence_count": sum(1 for score in confidence_scores if score < 0.4),
            }
        }
    
    def batch_predict(self, texts: List[str], batch_size: int = 16) -> Dict:
        """
        Process large batches of texts efficiently.
        
        Args:
            texts: List of texts to analyze
            batch_size: Batch size for processing
            
        Returns:
            Combined results dictionary
        """
        all_results = []
        
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_results = self.predict(batch)
            all_results.extend(batch_results["predictions"])
        
        # Compute overall summary
        confidence_scores = [r["confidence_score"] for r in all_results]
        
        return {
            "predictions": all_results,
            "summary": {
                "total_texts": len(texts),
                "avg_confidence": sum(confidence_scores) / len(confidence_scores),
                "high_confidence_count": sum(1 for score in confidence_scores if score >= 0.65),
                "medium_confidence_count": sum(1 for score in confidence_scores if 0.60 <= score < 0.65),
                "low_confidence_count": sum(1 for score in confidence_scores if 0.4 <= score < 0.6),
                "very_low_confidence_count": sum(1 for score in confidence_scores if score < 0.4),
            }
        }
    
    def evaluate_routing_strategy(self, texts: List[str]) -> Dict:
        """
        Evaluate the confidence-aware routing strategy for given texts.
        
        Args:
            texts: List of texts to analyze
            
        Returns:
            Routing strategy analysis
        """
        results = self.predict(texts)
        routing_counts = {}
        
        for pred in results["predictions"]:
            action = pred["routing_action"]
            routing_counts[action] = routing_counts.get(action, 0) + 1
        
        return {
            "routing_distribution": routing_counts,
            "computational_efficiency": {
                "local_generation_percentage": routing_counts.get("LOCAL_GENERATION", 0) / len(texts) * 100,
                "expensive_operations_percentage": (
                    routing_counts.get("RAG_RETRIEVAL", 0) + 
                    routing_counts.get("LARGER_MODEL", 0)
                ) / len(texts) * 100,
                "human_review_percentage": routing_counts.get("HUMAN_REVIEW", 0) / len(texts) * 100,
            },
            "summary": results["summary"]
        }
    
    def predict_with_query_context(self, query_context_pairs: List[Dict]) -> Dict:
        """
        Convenience method for predicting with query-context pairs.
        
        Args:
            query_context_pairs: List of dicts with 'query' and 'context' keys
            
        Returns:
            Dictionary with predictions, confidence scores, and interpretations
        """
        texts = [pair.get('query', '') for pair in query_context_pairs]
        return self.predict(texts, query_context_pairs=query_context_pairs)
    
    @classmethod
    def for_embedding_only(
        cls,
        model_path: str = None,
        embed_model_id: str = "BAAI/bge-m3",
        device: str = None,
        bge_max_length: int = 512,
        use_fp16: bool = True,
    ):
        """
        Create detector instance optimized for embedding-only usage (no LLM loading).
        
        Args:
            model_path: Path to trained model checkpoint
            embed_model_id: Hugging Face model ID for the embedding model
            device: Computing device ('cuda' or 'cpu')
            bge_max_length: Maximum sequence length for BGE-M3
            use_fp16: Whether to use FP16 precision
            
        Returns:
            HallucinationDetector instance with LLM disabled
        """
        return cls(
            model_path=model_path,
            embed_model_id=embed_model_id,
            device=device,
            bge_max_length=bge_max_length,
            use_fp16=use_fp16,
            load_llm=False,
            enable_inference=False,
        )
    
    @classmethod
    def for_low_memory(
        cls,
        llm_model_id: str = "google/medgemma-4b-it",
        model_path: str = None,
        device: str = "cuda",
        enable_response_generation: bool = True,
        **kwargs
    ):
        """
        Alternate constructor for a memory-constrained setup.

        Calls the class constructor with ``use_quantization=True``,
        ``enable_inference=True`` and ``use_fp16=True`` pinned, alongside the
        arguments below. Supplying one of the pinned or named keywords again
        through ``**kwargs`` raises ``TypeError`` (duplicate keyword argument).

        Args:
            llm_model_id: LLM model ID (default: MedGemma for medical tasks)
            model_path: Path to trained model checkpoint
            device: Computing device (cuda recommended for quantization)
            enable_response_generation: Whether to enable response generation
            **kwargs: Additional arguments passed to the class constructor

        Returns:
            Instance produced by ``cls(...)`` with the memory-saving settings applied
        """
        # Everything this factory decides on the caller's behalf.
        preset = {
            "model_path": model_path,
            "llm_model_id": llm_model_id,
            "device": device,
            "use_quantization": True,
            "enable_response_generation": enable_response_generation,
            "enable_inference": True,
            "use_fp16": True,
        }
        # Two separate unpackings keep the duplicate-keyword TypeError intact.
        return cls(**preset, **kwargs)
    
    def generate_response(self, prompt: str, max_length: int = 512, check_confidence: bool = True) -> Union[str, Dict]:
        """
        Generate a response from the LLM with optional confidence checking.

        When ``check_confidence`` is true the prompt is first scored with
        ``self.predict``; if the score is below ``self.confidence_threshold``
        nothing is generated and a refusal dict is returned instead.

        Args:
            prompt: Input prompt/question
            max_length: Maximum number of newly generated tokens (forwarded to
                ``generate`` as ``max_new_tokens``)
            check_confidence: Whether to check confidence before generating

        Returns:
            With ``check_confidence=False``: the generated text, or the string
            ``"[Error: ...]"`` if prompt formatting or generation failed.
            With ``check_confidence=True``: a dict that always contains
            ``response``, ``confidence_score`` and ``should_generate``, plus
            ``reason`` and ``recommendation`` (score below threshold),
            ``meets_threshold`` (successful generation) or ``error`` (failure).

        Raises:
            RuntimeError: If response generation is disabled or no LLM is loaded
        """
        if not self.enable_response_generation:
            raise RuntimeError("Response generation not enabled. Set enable_response_generation=True.")

        if not self.load_llm or self.llm is None:
            raise RuntimeError("LLM not loaded. Set load_llm=True for response generation.")

        # Stays None when the confidence check is skipped.
        confidence_score = None
        if check_confidence:
            top_prediction = self.predict([prompt])["predictions"][0]
            confidence_score = top_prediction["confidence_score"]

            if confidence_score < self.confidence_threshold:
                return {
                    "response": None,
                    "confidence_score": confidence_score,
                    "should_generate": False,
                    "reason": f"Confidence {confidence_score:.3f} below threshold {self.confidence_threshold}",
                    "recommendation": top_prediction["routing_action"]
                }

        try:
            # Format prompt for the specific model type
            if self.is_medgemma_4b:
                # Use medical context for MedGemma
                if hasattr(self.tokenizer, 'apply_chat_template'):
                    messages = [
                        {
                            "role": "system",
                            "content": "You are a helpful medical assistant."
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ]
                    formatted_prompt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True
                    )
                else:
                    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
            else:
                # Use general context for Llama
                formatted_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

            # The formatted text may already start with the BOS token (the
            # Llama literal above, or chat-template output). Tokenizing it
            # with special tokens enabled would prepend a second BOS, so
            # special tokens are only added when the text lacks one.
            bos_token = getattr(self.tokenizer, "bos_token", None)
            prompt_has_bos = (
                isinstance(bos_token, str)
                and bos_token != ""
                and formatted_prompt.startswith(bos_token)
            )

            # Tokenize input
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=not prompt_has_bos,
            ).to(self.device)

            # Compare against None: token id 0 is a valid id and must not
            # fall through to pad_token_id.
            eos_token_id = self.tokenizer.eos_token_id
            pad_token_id = eos_token_id if eos_token_id is not None else self.tokenizer.pad_token_id

            # Generate response
            with torch.inference_mode():
                input_len = inputs["input_ids"].shape[-1]

                # early_stopping is intentionally not passed: it only applies
                # to beam-based decoding, and this call samples with one beam.
                generation = self.llm.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    no_repeat_ngram_size=3,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )

                # Extract only the generated part
                generation = generation[0][input_len:]

            # Decode response
            response = self.tokenizer.decode(generation, skip_special_tokens=True).strip()

            if check_confidence:
                return {
                    "response": response,
                    "confidence_score": confidence_score,
                    "should_generate": True,
                    "meets_threshold": confidence_score >= self.confidence_threshold
                }
            return response

        except Exception as e:
            # Best-effort contract: failures are reported in the return value
            # rather than raised.
            if check_confidence:
                return {
                    "response": None,
                    "confidence_score": confidence_score,
                    "error": str(e),
                    "should_generate": False
                }
            return f"[Error: {str(e)}]"
    
    def predict_images(self, images: List, image_descriptions: List[str] = None) -> Dict:
        """
        Predict confidence scores for medical images (MedGemma 4b-it only).

        NOTE: this is a placeholder implementation. The image content is not
        inspected; every image receives the same fixed confidence score, which
        is then mapped to an interpretation via fixed thresholds.

        Args:
            images: PIL Images to analyze. A list or tuple is treated as a
                batch; any other object is treated as a single image.
            image_descriptions: Optional descriptions of what the images should
                show. A single string is treated as one description. Missing
                entries are filled with ``"Medical image <n>"``; entries beyond
                the number of images are ignored.

        Returns:
            Dictionary with ``predictions`` (one entry per image) and
            ``summary`` (counts and average confidence; the average is 0.0
            when no images were given)

        Raises:
            ValueError: If the current mode is not 'image' or 'both', the model
                is not multimodal, or the multimodal model/processor is missing
            ImportError: If PIL is not available
        """
        # Validate mode for image processing
        if self.effective_mode not in ["image", "both"]:
            raise ValueError(f"Image prediction requires mode 'image' or 'both', but current mode is '{self.effective_mode}'")

        if not self.is_multimodal:
            raise ValueError("Image prediction only supported for MedGemma 4b-it models")

        if not PIL_AVAILABLE:
            raise ImportError("PIL is required for image processing. Install with: pip install Pillow")

        if self.llm_multimodal is None or self.processor is None:
            raise ValueError("Multimodal model or processor not available for image processing")

        # Medical image thresholds and the placeholder score, named once so the
        # per-image classification and the summary cannot drift apart.
        high_threshold = 0.62
        medium_threshold = 0.55
        placeholder_confidence = 0.60  # Default medium confidence for images

        # Normalize to a list; a bare (non-sequence) object is a single image.
        images = list(images) if isinstance(images, (list, tuple)) else [images]

        # Normalize descriptions so every image has exactly one. A bare string
        # would otherwise be iterated character by character.
        if image_descriptions is None:
            image_descriptions = []
        elif isinstance(image_descriptions, str):
            image_descriptions = [image_descriptions]
        descriptions = list(image_descriptions)[:len(images)]
        descriptions.extend(
            f"Medical image {i + 1}" for i in range(len(descriptions), len(images))
        )

        results = []
        for index, desc in enumerate(descriptions):
            confidence_score = placeholder_confidence

            if confidence_score >= high_threshold:
                interpretation = "HIGH_MEDICAL_IMAGE_CONFIDENCE"
                risk_level = "LOW_MEDICAL_RISK"
                description = "This medical image analysis appears reliable."
            elif confidence_score >= medium_threshold:
                interpretation = "MEDIUM_MEDICAL_IMAGE_CONFIDENCE"
                risk_level = "MEDIUM_MEDICAL_RISK"
                description = "This medical image analysis may need expert verification."
            else:
                interpretation = "LOW_MEDICAL_IMAGE_CONFIDENCE"
                risk_level = "HIGH_MEDICAL_RISK"
                description = "This medical image analysis appears unreliable."

            results.append({
                "image_index": index,
                "image_description": desc,
                "confidence_score": float(confidence_score),
                "interpretation": interpretation,
                "risk_level": risk_level,
                "description": description,
            })

        scores = [r["confidence_score"] for r in results]
        # Guard the average: an empty batch must not divide by zero.
        avg_confidence = sum(scores) / len(scores) if scores else 0.0

        return {
            "predictions": results,
            "summary": {
                "total_images": len(images),
                "avg_confidence": avg_confidence,
                "high_confidence_count": sum(1 for s in scores if s >= high_threshold),
                "medium_confidence_count": sum(1 for s in scores if medium_threshold <= s < high_threshold),
                "low_confidence_count": sum(1 for s in scores if s < medium_threshold),
            }
        }
    
    def generate_image_response(self, image, prompt: str = "Describe this medical image.", max_length: int = 200) -> str:
        """
        Ask the multimodal model to respond to a prompt about one image.

        Builds a two-turn chat (a system turn plus a user turn carrying the
        prompt text and the image), runs sampled generation, and decodes only
        the newly generated tokens.

        Args:
            image: PIL Image to analyze
            prompt: Text prompt for the image analysis
            max_length: Maximum number of newly generated tokens (forwarded to
                ``generate`` as ``max_new_tokens``)

        Returns:
            The stripped response text, ``"[No response generated]"`` when the
            decoded text is empty, or ``"[Error: ...]"`` if any step inside the
            generation pipeline raised

        Raises:
            ValueError: If the current mode is not 'image' or 'both', the model
                is not multimodal, or the multimodal model/processor is missing
            ImportError: If PIL is not available
        """
        # Preconditions are checked outside the try block so they surface as
        # exceptions rather than as "[Error: ...]" strings.
        if self.effective_mode not in ("image", "both"):
            raise ValueError(f"Image response generation requires mode 'image' or 'both', but current mode is '{self.effective_mode}'")

        if not self.is_multimodal:
            raise ValueError("Image response generation only supported for MedGemma 4b-it models")

        if not PIL_AVAILABLE:
            raise ImportError("PIL is required for image processing. Install with: pip install Pillow")

        if self.llm_multimodal is None or self.processor is None:
            raise ValueError("Multimodal model or processor not available for image processing")

        try:
            # Chat turns in the structured (typed-content) message format.
            system_turn = {
                "role": "system",
                "content": [{"type": "text", "text": "You are an expert radiologist."}],
            }
            user_turn = {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": image},
                ],
            }

            # The processor renders the chat template and tokenizes in one step.
            model_inputs = self.processor.apply_chat_template(
                [system_turn, user_turn],
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            model_inputs = model_inputs.to(self.device)

            # Length of the prompt, used to slice it off the generated sequence.
            prompt_length = model_inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                run_generate = self.llm_multimodal.generate
                sampling_options = dict(
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
                output_ids = run_generate(**model_inputs, **sampling_options)
                new_token_ids = output_ids[0][prompt_length:]

            text = self.processor.decode(new_token_ids, skip_special_tokens=True).strip()
            if text:
                return text
            return "[No response generated]"

        except Exception as exc:
            return f"[Error: {str(exc)}]"