"""
This module contains transformations that may be useful
for augmenting timeseries data during training
"""

import torch
from jaxtyping import Float
from torch import Tensor


class SignalInverter(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly inverts i.e. :math:`h(t) \\rightarrow -h(t)`
    each timeseries with probability ``prob``.

    The inversion is applied in place: the tensor passed to
    ``forward`` is modified and that same tensor is returned.

    Args:
        prob:
            Probability that a timeseries is inverted
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(
        self, X: Float[Tensor, "*batch time"]
    ) -> Float[Tensor, "*batch time"]:
        """
        Flip the sign of a random subset of the timeseries in ``X``.

        Args:
            X:
                Timeseries with time along the last dimension.
                Every leading index selects one timeseries, and
                each one is inverted independently.

        Returns:
            ``X`` itself, with the selected timeseries negated.
        """
        # One uniform draw per timeseries (all dims except time).
        # Sample on X's device so the boolean index lives next to
        # the data instead of always being created on the CPU.
        mask = torch.rand(size=X.shape[:-1], device=X.device) < self.prob
        X[mask] *= -1
        return X


class SignalReverser(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly reverses i.e. :math:`h(t) \\rightarrow h(-t)`
    each timeseries with probability ``prob``.

    The reversal is applied in place: the tensor passed to
    ``forward`` is modified and that same tensor is returned.

    Args:
        prob:
            Probability that a timeseries is reversed
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(
        self, X: Float[Tensor, "*batch time"]
    ) -> Float[Tensor, "*batch time"]:
        """
        Reverse a random subset of the timeseries in ``X`` along time.

        Args:
            X:
                Timeseries with time along the last dimension.
                Every leading index selects one timeseries, and
                each one is reversed independently.

        Returns:
            ``X`` itself, with the selected timeseries flipped
            along the last dimension.
        """
        # One uniform draw per timeseries (all dims except time).
        # Sample on X's device so the boolean index lives next to
        # the data instead of always being created on the CPU.
        mask = torch.rand(size=X.shape[:-1], device=X.device) < self.prob
        # ``flip`` returns a new tensor, so the right-hand side is
        # fully materialised before it is written back into X.
        X[mask] = X[mask].flip(-1)
        return X
