# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/agents/processors/01_processors.ipynb.

# %% auto 0
# Public API of this module: the processor classes exported by `from ... import *`.
__all__ = ['RoundAction', 'ClipNegative', 'DiscreteToContinuous', 'OneHotToDiscrete', 'HybridToContinuous',
           'OneHotHybridtoContinuous', 'ContToBinary', 'MakeInputZero', 'GetTimeSeriesAndStaticFeatures']

# %% ../../../nbs/agents/processors/01_processors.ipynb 4
# General libraries:
import numpy as np

# %% ../../../nbs/agents/processors/01_processors.ipynb 6
class RoundAction:
    """
    Processor that snaps an action onto a grid with spacing ``increment``.

    Each entry is divided by ``increment``, rounded to the nearest integer
    with ``np.round`` (exact ties go to the nearest even integer), and scaled
    back. The result is the closest multiple of ``increment``.

    args:
        increment (float or np.array): grid spacing. An array gives each
            action dimension its own spacing via NumPy broadcasting.

    """

    def __init__(self, increment):
        self.increment = increment

    def __call__(self, input):
        step = self.increment
        n_steps = np.round(input / step)
        return n_steps * step

# %% ../../../nbs/agents/processors/01_processors.ipynb 7
class ClipNegative():
    """
    This class implements a processor that clips the input to zero if it is
    negative.

    The input is not modified; a new array is returned with every negative
    entry replaced by 0 and all other entries unchanged.

    """

    def __call__(self, input):
        # ndarray.clip (without ``out=``) already returns a fresh array, so a
        # defensive ``.copy()`` beforehand would only allocate a second time.
        return input.clip(min=0)

# %% ../../../nbs/agents/processors/01_processors.ipynb 8
class DiscreteToContinuous():
    """
    This class implements a processor that converts a discrete action to a
    continuous action. The discrete action is assumed to be an integer between
    0 and n_actions - 1, and the continuous action is assumed to be between
    action_min and action_max.

    The mapping is linear. Action ``k`` becomes
    ``action_min + k * (action_max - action_min) / (n_actions - 1)``, so ``0``
    maps to ``action_min`` and ``n_actions - 1`` maps to ``action_max``.

    args:
        n_actions (int or array-like of int): the number of actions (one
            entry per action dimension). Must be odd and at least 3.
        action_min (np.array): the minimum value of each continuous action.
        action_max (np.array): the maximum value of each continuous action.

    """

    def __init__(self, n_actions, action_min, action_max):
        if isinstance(n_actions, int):
            # Wrap a plain int so the output of __call__ has at least one axis.
            n_actions = np.array([n_actions])
        elif isinstance(n_actions, (list, tuple)):
            # ``%`` and ``>=`` below need array semantics; plain sequences would raise TypeError.
            n_actions = np.asarray(n_actions)
        # With an odd count the middle index (n_actions - 1) / 2 maps to
        # (action_min + action_max) / 2, which is 0 when the bounds are symmetric.
        assert np.all(n_actions%2 != 0) , "Need odd number of actions to ensure 0 is included"
        # n_actions == 1 is odd but would make __call__ divide by n_actions - 1 == 0.
        assert np.all(n_actions >= 3), "Need at least 3 actions so that action_min and action_max are distinct grid points"
        self.n_actions = n_actions
        self.action_min = action_min
        self.action_max = action_max

    def __call__(self, action):
        step = (self.action_max - self.action_min) / (self.n_actions - 1)
        return action * step + self.action_min

# %% ../../../nbs/agents/processors/01_processors.ipynb 9
class OneHotToDiscrete:
    """
    Processor that decodes a one-hot action into its integer index.

    The index of the largest entry along the last axis is returned, so a
    one-hot vector of length ``n_actions`` yields an integer in
    ``[0, n_actions - 1]``. Any leading (batch) axes are preserved.

    The constructor takes no arguments.

    """
    # TODO: make this work for multidimensional one-hot vectors
    # (leading batch axes already work because argmax runs over axis=-1 only).

    def __init__(self):
        pass

    def __call__(self, action):
        return action.argmax(axis=-1)

# %% ../../../nbs/agents/processors/01_processors.ipynb 10
class HybridToContinuous:
    """
    Processor that turns a hybrid (decision + quantity) action into a
    continuous order quantity.

    The action is split after its first ``num_products`` entries. The
    leading part holds the order decisions and the rest holds the
    quantities. A decision strictly greater than 0 keeps the matching
    quantity; any other decision zeroes it.

    args:
        num_products (int): number of products, i.e. the number of decision
            entries at the start of the action.

    """

    def __init__(self, num_products):
        self.num_products = num_products

    def __call__(self, action):
        split = self.num_products
        decisions, quantities = action[:split], action[split:]
        # Integer 0/1 mask applied multiplicatively to the quantities.
        mask = np.where(decisions > 0, 1, 0)
        return mask * quantities

# %% ../../../nbs/agents/processors/01_processors.ipynb 11
class OneHotHybridtoContinuous():

    """
    Converts a hybrid action whose discrete part is one-hot encoded into a
    continuous order quantity per product.

    The action is read as ``[q_1, ..., q_n, d_1, ..., d_n]``, where
    ``n = num_products``:

    - ``q_i`` is the quantity for product ``i``.
    - ``d_i`` is a length-2 block (one-hot or scores) for product ``i``.

    When ``argmax(d_i) == 1`` the quantity ``q_i`` is kept; otherwise it
    becomes 0. Entries beyond ``3 * num_products`` are ignored.

    NOTE(review): for more than one product the per-product blocks are
    assumed to be contiguous pairs. Confirm this against the agent that
    produces the action.

    args:
        num_products (int): the number of products.

    """

    def __init__(self, num_products):
        self.num_products = num_products

    def __call__(self, action):
        n = self.num_products
        order_quantity = action[:n]
        # One length-2 block per product. Reshaping gives each product its own
        # row, so the argmax is taken independently per product (0 or 1).
        one_hot = action[n:3 * n].reshape(n, 2)
        order_action = np.argmax(one_hot, axis=-1)
        return order_action * order_quantity

# %% ../../../nbs/agents/processors/01_processors.ipynb 12
class ContToBinary:
    """
    Processor that turns a continuous state into a single binary feature.

    Only the first entry of the state is inspected. The result is
    ``np.array([1])`` when that entry is strictly positive and
    ``np.array([0])`` otherwise, including for 0 and NaN.

    #! Currently only works for single product: any further state entries
    #! are ignored.

    """

    def __init__(self):
        pass

    def __call__(self, state):
        first = state[0]
        flag = 1 if first > 0 else 0
        return np.array([flag])

# %% ../../../nbs/agents/processors/01_processors.ipynb 13
class MakeInputZero:
    """
    Processor that discards its input.

    Returns an array of zeros with the same shape and dtype as the input,
    as produced by ``np.zeros_like``.

    """

    def __init__(self):
        pass

    def __call__(self, state):
        zeros = np.zeros_like(state)
        return zeros

# %% ../../../nbs/agents/processors/01_processors.ipynb 14
class GetTimeSeriesAndStaticFeatures():

    """
    This class implements a processor that splits the state into time series and static features.
    The time series features are assumed to be the first num_ts_features * len_ts_features features
    and the static features are assumed to be the remaining features.

    The state is indexed as a 2-D ``(batch, features)`` array or tensor.

    - Within each row, time-series feature ``i`` occupies columns
      ``i * len_ts_features`` to ``(i + 1) * len_ts_features - 1``
      (row-major reshape).
    - The time-series part is returned with shape
      ``(batch, len_ts_features, num_ts_features)``.
    - The static part is returned with shape ``(batch, remaining_features)``.

    args:
        num_ts_features (int): the number of time series features.
        len_ts_features (int): the length of each time series feature.

    returns:
        list: ``[ts_data, static_data]``.
    """

    def __init__(self, num_ts_features, len_ts_features):
        self.num_ts_features = num_ts_features
        self.len_ts_features = len_ts_features

    def __call__(self, state):
        n_ts_cols = self.num_ts_features * self.len_ts_features

        ts_data = state[:, :n_ts_cols]
        static_data = state[:, n_ts_cols:]

        ts_data = ts_data.reshape(state.shape[0], self.num_ts_features, self.len_ts_features)
        # Swap the feature and time axes -> (batch, len_ts_features, num_ts_features).
        # ndarray.transpose(1, 2) would be an (invalid) full axis permutation for a
        # 3-D array, so NumPy input uses swapaxes. Other inputs (presumably torch
        # tensors) keep the two-dim transpose(dim0, dim1) call.
        if isinstance(ts_data, np.ndarray):
            ts_data = ts_data.swapaxes(1, 2)
        else:
            ts_data = ts_data.transpose(1, 2)

        return [ts_data, static_data]
