"""
Loss functions for diffusion training
Implements specialized loss computations for diffusion language models
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiffusionLoss(nn.Module):
    """
    Loss function for diffusion language models.

    Combines cross-entropy on masked positions with an optional
    confidence-weighted cross-entropy term and an optional energy term:

        total = CE(masked tokens, label_smoothing)
              + confidence_loss_weight * mean(CE_token * (1 - confidence))
              + energy_loss_weight * mean(energy_scores)

    Args:
        vocab_size: Vocabulary size. Stored on the module; not referenced by
            ``forward``.
        use_energy_model: If False, ``energy_scores`` passed to ``forward``
            are ignored.
        label_smoothing: Label smoothing for the main cross-entropy term only;
            the confidence-weighted term is computed without smoothing.
        confidence_loss_weight: Coefficient of the confidence-weighted term
            (default 0.1).
        energy_loss_weight: Coefficient of the energy term (default 0.01).
    """

    def __init__(self, vocab_size, use_energy_model=True, label_smoothing=0.0,
                 confidence_loss_weight=0.1, energy_loss_weight=0.01):
        super().__init__()
        self.vocab_size = vocab_size
        self.use_energy_model = use_energy_model
        self.label_smoothing = label_smoothing
        self.confidence_loss_weight = confidence_loss_weight
        self.energy_loss_weight = energy_loss_weight

    def forward(self, logits, targets, mask_positions, confidence=None, energy_scores=None):
        """
        Compute diffusion training loss.

        Args:
            logits: (batch, seq_len, vocab_size) - Model predictions
            targets: (batch, seq_len) - Ground truth tokens
            mask_positions: (batch, seq_len) - Mask indicating which positions
                were masked. Any dtype is accepted; nonzero means "masked".
            confidence: (batch, seq_len) - Optional confidence scores. They are
                detached, so no gradient flows through them.
            energy_scores: (batch,) - Optional energy model scores. Used only
                when ``use_energy_model`` is True.

        Returns:
            loss: Scalar loss tensor. When no position is masked, the
            cross-entropy terms contribute an autograd-connected zero
            instead of NaN.
        """
        # Coerce to bool: an integer 0/1 mask would otherwise be interpreted
        # as integer indices into the batch dimension, silently selecting
        # the wrong rows instead of masking.
        mask = mask_positions.bool()

        # Extract only masked positions
        flat_logits = logits[mask]  # (num_masked, vocab_size)
        flat_targets = targets[mask]  # (num_masked,)

        if flat_targets.numel() == 0:
            # Cross-entropy with reduction='mean' over zero elements is NaN,
            # which would poison the whole step. Summing the empty selection
            # yields 0 with logits' dtype/device and keeps the autograd graph.
            total_loss = flat_logits.sum()
        else:
            # Cross-entropy loss on masked tokens
            total_loss = F.cross_entropy(
                flat_logits,
                flat_targets,
                label_smoothing=self.label_smoothing,
                reduction='mean'
            )

            # Optional confidence-weighted loss: lower-confidence predictions
            # get a higher weight via (1 - confidence).
            # NOTE(review): assumes confidence lies in [0, 1]; values above 1
            # would produce negative weights.
            if confidence is not None:
                confidence_weight = confidence[mask].detach()
                per_token_ce = F.cross_entropy(
                    flat_logits,
                    flat_targets,
                    reduction='none'
                )
                weighted_loss = (per_token_ce * (1.0 - confidence_weight)).mean()
                total_loss = total_loss + self.confidence_loss_weight * weighted_loss

        # Optional energy-based term: minimizing the mean energy pushes
        # scores down (energy should be low for correct sequences).
        if self.use_energy_model and energy_scores is not None:
            energy_loss = energy_scores.mean()
            total_loss = total_loss + self.energy_loss_weight * energy_loss

        return total_loss


class ContrastiveLoss(nn.Module):
    """
    Symmetric contrastive loss for multimodal training.

    Text and image embeddings are L2-normalized, and their pairwise cosine
    similarities are divided by ``temperature``. Row ``i`` of each batch is
    treated as the matching pair for row ``i`` of the other. The result is the
    average of the text->image and image->text cross-entropy losses.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, text_embeds, image_embeds):
        """
        Compute the symmetric contrastive loss between paired embeddings.

        Args:
            text_embeds: (batch, dim) - Text embeddings
            image_embeds: (batch, dim) - Image embeddings

        Returns:
            loss: Scalar loss value
        """
        text_unit = F.normalize(text_embeds, p=2, dim=-1)
        image_unit = F.normalize(image_embeds, p=2, dim=-1)

        # (batch, batch) similarity matrix; entry [i, j] compares text i with image j
        similarity = (text_unit @ image_unit.T) / self.temperature

        # The matching pair for row i sits on the diagonal at column i
        diagonal = torch.arange(text_unit.size(0), device=similarity.device)

        text_to_image = F.cross_entropy(similarity, diagonal)
        image_to_text = F.cross_entropy(similarity.T, diagonal)
        return (text_to_image + image_to_text) / 2
