"""
Importance Sampling Loss
=========================

Implements importance sampling for off-policy learning.
"""

import math
from typing import Optional

import torch
import torch.nn.functional as F

from .custom_loss import CustomLoss, CustomLossConfig


class ImportanceSamplingLoss(CustomLoss):
    """
    Importance sampling loss for off-policy learning.

    Used when training on data generated by a different policy. Each
    advantage is weighted by the clipped (and optionally self-normalized)
    probability ratio between the current policy and the behavior policy.
    """

    def __init__(
        self,
        config: Optional[CustomLossConfig] = None,
        clip_ratio: float = 10.0,
        normalize: bool = True,
    ):
        """
        Initialize importance sampling loss.

        Args:
            config: Configuration. When omitted, a ``CustomLossConfig`` named
                ``"importance_sampling"`` is created.
            clip_ratio: Maximum importance ratio; must be positive.
            normalize: Whether to normalize importance weights so that they
                average to one over the contributing elements.

        Raises:
            ValueError: If ``clip_ratio`` is not a positive number.
        """
        # ``not (x > 0)`` also rejects NaN, which ``x <= 0`` would let through.
        if not clip_ratio > 0:
            raise ValueError(
                f"clip_ratio must be positive, got {clip_ratio!r}"
            )
        if config is None:
            config = CustomLossConfig(name="importance_sampling")
        super().__init__(config)
        self.clip_ratio = clip_ratio
        self.normalize = normalize

    def compute_loss(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute importance sampling loss.

        Args:
            log_probs: Log probabilities from current policy
            old_log_probs: Log probabilities from behavior policy
            advantages: Advantage estimates
            mask: Optional mask. When given, the importance weights are
                multiplied by it, so positions where it is zero contribute
                zero loss. NOTE(review): presumably a 0/1 mask that
                broadcasts against ``log_probs`` -- confirm with callers.
            **kwargs: Additional arguments (unused here)

        Returns:
            Element-wise importance sampling loss ``-(weight * advantages)``;
            no reduction is applied in this method.
        """
        # Clip in log space *before* exponentiating. Clamping only after
        # ``exp`` caps the forward value, but an overflowed ``inf`` still
        # meets a zero upstream gradient in backward (0 * inf = NaN).
        log_ratio = torch.clamp(
            log_probs - old_log_probs, max=math.log(self.clip_ratio)
        )
        # Re-clamp so the cap is exactly ``clip_ratio`` despite rounding in
        # exp(log(clip_ratio)).
        ratio = torch.clamp(torch.exp(log_ratio), max=self.clip_ratio)

        # Apply the mask whenever one is supplied, not only when normalizing,
        # so masked positions never leak into the loss.
        mask_weights = None
        if mask is not None:
            mask_weights = mask.to(ratio.dtype)
            ratio = ratio * mask_weights

        # Normalize importance weights if requested: rescale so the weights
        # sum to the number of contributing elements.
        if self.normalize:
            if mask_weights is not None:
                count = mask_weights.sum()
            else:
                count = ratio.numel()
            # Floor the denominator so an all-zero mask (or fully underflowed
            # ratios) yields zeros instead of 0/0 = NaN.
            denom = torch.clamp(
                ratio.sum(), min=torch.finfo(ratio.dtype).tiny
            )
            ratio = ratio / denom * count

        # Compute weighted loss
        loss = -(ratio * advantages)

        return loss