# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/environments/01_multi_period_inventory.ipynb.

# %% auto 0
__all__ = ['MultiPeriodEnv', 'MultiPeriodEnvMP']

# %% ../../nbs/environments/01_multi_period_inventory.ipynb 4
# General libraries:
import numpy as np
    
# build env:
from mushroom_rl.core import Environment, MDPInfo
from mushroom_rl.utils.spaces import Discrete, Box
from .feature_converters import get_time_series_features

# others
from abc import ABC, abstractmethod

# %% ../../nbs/environments/01_multi_period_inventory.ipynb 6
class MultiPeriodEnv(Environment, ABC):
    def __init__(self,
                 
                ### demand data
                demand,

                ### feature data
                features = None,

                mask = None, # optional mask to mask out certain actions (e.g., if product not offered), shape(timesteps, num_products)
                feature_map = None, # optional feature map to map features to products, shape(num_features, num_products)

                ### cost related to overage and underage
                overage_cost = np.array([0]), 
                underage_cost = np.array([0]), #! check if there can be a case where this is different from penalty cost

                ### cost related to ordering
                fixed_ordering_cost = np.array([0]), 
                variable_ordering_cost = np.array([0]), 

                ### demand parameters
                max_demand_backlog = np.array([0]), 

                ### inventory parameters
                inventory_allowed = True,
                start_inventory = np.array([0]), 
                inventory_cap = np.array([np.inf]), 
                
                ### cost related to inventory
                holding_cost = None, 

                ### ordering parameters
                lead_time = np.array([0]), 
                lead_time_variance = np.array([0]), 
                max_lead_time = None,
                actions_per_product = 1,
                order_limit_low = None, 
                order_limit_high = None, 

                ### other parameters
                horizon = 365, # length of each episode in simulation
                gamma = 1, # discount factor
                observation_space_kwargs = None,

                ### logic of step-method
                calculations = None,

                ### Precision of state
                precision = 5,

                lag_window = 0,

                normalize_env_state = False,
                normalize_reward = False,

                record_state = False
                ):
        """
        Multi-period inventory environment covering all major single-echelon inventory models with decision on quantity.

        Note: Except True/False parameters and horizon/gamma, all per-product parameters may be given as a
        scalar, a list or a numpy array; values of size 1 are broadcast to all products.

        args:
            demand: historical observations of demand, indexed as (observations, products)
            features: features related to demand - normalize features before passing to environment
            mask: optional mask to mask out certain actions (stored on the instance only)
            feature_map: optional map from features to products (stored on the instance only)
            overage_cost: overage cost incurred per unit in each period
            underage_cost: underage cost incurred per unit in each period
            fixed_ordering_cost: fixed cost incurred per order
            variable_ordering_cost: variable cost incurred per unit ordered (note: related to ordering only, not cogs)
            max_demand_backlog: How much demand can be backlogged. If set to zero: Lost Sales model
            inventory_allowed: whether inventory is allowed. False: Excess orders will be thrown away
            start_inventory: initial inventory
            inventory_cap: maximum inventory allowed (must be finite if normalize_env_state is True)
            holding_cost: holding cost incurred per unit in each period
            lead_time: lead time between ordering and receiving the order
            lead_time_variance: variance in lead time
            max_lead_time: upper bound of the lead time; required when lead_time_variance is non-zero,
                otherwise lead_time is used
            actions_per_product: number of action dimensions per product
            order_limit_low: minimum order allowed (defaults to 0 for every product)
            order_limit_high: maximum order allowed (defaults to 1 for every product)
            horizon: length of each episode in simulation
            gamma: discount factor
            observation_space_kwargs: kwargs for set_observation_space
            calculations: function that calculates the next state and reward in the step function
            precision: precision of state variables (decimal places)
            lag_window: stored on the instance; not used elsewhere in this class
            normalize_env_state: whether inventory and order pipeline are scaled in the observation
            normalize_reward: stored on the instance (presumably read by ``calculations`` - TODO confirm)
            record_state: stored on the instance (presumably read by ``calculations`` - TODO confirm)

        raises:
            AssertionError: if inventory is not allowed but start_inventory/holding_cost are set, or if
                the state should be normalized with an infinite inventory_cap
            ValueError: if a parameter has an invalid type/size, or if lead_time_variance is non-zero
                and max_lead_time is missing
        """

        self.demand = demand
        self.avg_demand = np.mean(demand, axis = 0)

        print("num_products:", self.avg_demand.shape[0])
        print("average_demand:", self.avg_demand)

        self.num_observations = demand.shape[0] # number of observations
        self.num_products = demand.shape[1] # number of products
        self.features = features
        self.num_features = features.shape[1] if features is not None else 0 # number of features

        self.mask = mask
        self.feature_map = feature_map

        self.overage_cost = self.convert_to_numpy_array(overage_cost, expand = True)
        self.underage_cost = self.convert_to_numpy_array(underage_cost, expand = True)

        self.fixed_ordering_cost = self.convert_to_numpy_array(fixed_ordering_cost, expand = True)
        self.variable_ordering_cost = self.convert_to_numpy_array(variable_ordering_cost, expand = True)

        self.max_demand_backlog = self.convert_to_numpy_array(max_demand_backlog, expand = True)

        self.inventory_allowed = inventory_allowed

        # Convert first and validate the converted arrays: comparing a raw list with a scalar
        # (e.g. [0, 0] == 0) yields a single False instead of an element-wise result.
        self.start_inventory = self.convert_to_numpy_array(start_inventory, expand = True)
        self.inventory_cap = self.convert_to_numpy_array(inventory_cap, expand = True)

        if not self.inventory_allowed:
            assert np.all(self.start_inventory == 0), "If inventory is not allowed, start_inventory must be 0"

        if normalize_env_state:
            assert np.all(self.inventory_cap != np.inf), "If inventory is normalized, inventory_cap must be a finite number"

        if not self.inventory_allowed:
            assert holding_cost is None, "If inventory is not allowed, holding_cost must be None"
        self.holding_cost = self.convert_to_numpy_array(holding_cost, expand = True)

        self.lead_time = self.convert_to_numpy_array(lead_time, expand = True)
        self.lead_time_variance = self.convert_to_numpy_array(lead_time_variance, expand = True)

        # An order pipeline is only needed if at least one product has a non-zero lead time
        self.use_order_pipeline = not np.all(self.lead_time == 0)

        if np.all(self.lead_time_variance == 0):
            # Deterministic lead time: the pipeline never needs to be longer than the lead time itself
            self.max_lead_time = self.lead_time
        else:
            if max_lead_time is None:
                raise ValueError("max_lead_time must be provided if lead_time_variance is not 0")
            self.max_lead_time = self.convert_to_numpy_array(max_lead_time, expand = True)

        # Fall back to the [0, 1] default per bound; a provided bound is never overwritten
        if order_limit_low is None:
            self.order_limit_low = np.zeros(self.num_products)
        else:
            self.order_limit_low = self.convert_to_numpy_array(order_limit_low, expand = True)

        if order_limit_high is None:
            self.order_limit_high = np.ones(self.num_products)
        else:
            self.order_limit_high = self.convert_to_numpy_array(order_limit_high, expand = True)

        # Each product contributes actions_per_product action dimensions sharing the product's limits
        action_low = np.repeat(self.order_limit_low, actions_per_product)
        action_high = np.repeat(self.order_limit_high, actions_per_product)

        self.action_space = Box(action_low, action_high, shape=(self.num_products*actions_per_product,))

        # Set observation space
        if observation_space_kwargs is None:
            observation_space_kwargs = {}

        self.set_observation_space(**observation_space_kwargs)

        self.calculations = calculations

        # Set MDP info
        mdp_info = MDPInfo(observation_space=self.observation_space,
                           action_space=self.action_space,
                           horizon=horizon,
                           gamma=gamma)
        
        super().__init__(mdp_info)

        self.precision = precision

        self.lag_window = lag_window

        self.normalize_reward = normalize_reward
        self.normalize_env_state = normalize_env_state

        self.record_state = record_state

        # Set state variables
        self.reset()

    def set_observation_space(self,
                              features = False,
                              inventory = False,
                              order_pipeline = False,
                              demand_backlog = False
                              ):
        """
        Build the Box observation space from the selected state components.

        The bounds vector is laid out as [features | inventory | order pipeline | demand backlog];
        a component that is switched off takes no slots. If features are switched off,
        self.num_features is reset to 0 as a side effect.

        args:
            features: whether to include features in the observation space (bounds -1 to 1)
            inventory: whether to include inventory in the observation space (bounds 0 to inventory_cap)
            order_pipeline: whether to include order pipeline in the observation space (bounds are the order limits)
            demand_backlog: whether to include demand backlog in the observation space (bounds 0 to max_demand_backlog)
        """

        # Width of every segment; zero if the component is not part of the observation
        n_feat = self.num_features if features else 0
        n_inv = self.num_products if inventory else 0
        n_pipe = np.sum(self.num_products * np.max(self.max_lead_time)) if order_pipeline else 0
        n_backlog = self.num_products if demand_backlog else 0
        obs_length = n_feat + n_inv + n_pipe + n_backlog

        assert obs_length > 0, "Need at least one observation dimension"

        lower = np.zeros(obs_length)
        upper = np.zeros(obs_length)

        if not features:
            self.num_features = 0
        else:
            lower[:self.num_features] = -np.ones(self.num_features)
            upper[:self.num_features] = np.ones(self.num_features)

        # Start offsets of the segments that follow the features
        inv_start = self.num_features
        pipe_start = inv_start + n_inv
        backlog_start = pipe_start + n_pipe

        if inventory:
            lower[inv_start:pipe_start] = np.zeros(n_inv) # assuming the low inventory end is always 0
            upper[inv_start:pipe_start] = self.inventory_cap

        if order_pipeline:
            # Note: the length of the order pipeline can vary by product. Note that for RNN processing we need to pad the order pipeline to the same length in the preprocessing step
            cursor = pipe_start
            for product in range(self.num_products):
                depth = self.max_lead_time[product]
                lower[cursor:cursor+depth] = self.order_limit_low[product]
                upper[cursor:cursor+depth] = self.order_limit_high[product]
                cursor += depth

        if demand_backlog:
            lower[backlog_start:backlog_start+n_backlog] = 0
            upper[backlog_start:backlog_start+n_backlog] = self.max_demand_backlog

        self.observation_space = Box(lower, upper, shape=(obs_length,))

    def reset(self, state = None):

        # TODO: Check in mushroom where state would come from and implement here
        if self.num_observations == self._mdp_info.horizon:
            self.period = 0
        else:
            self.period = np.random.choice(self.num_observations-self._mdp_info.horizon)
        #print("reset with period {}".format(self.period))
        self.demand_backlog = np.zeros(self.num_products)
        self.inventory = self.start_inventory.copy()
        self.order_pipeline = np.zeros((self.num_products,np.max(self.max_lead_time)))

        self.set_observation_state()

        return self.observation_state

    def set_observation_state(self):
        
        # TODO: check if this is computationally expensive
        # TODO: make sure this is consistent if some elements are there but not tracked in state

        if self.period >= self.num_observations:
            self.observation_state = np.zeros(self.observation_space.shape)
        
        else:
            
            self.observation_state = np.zeros(self.observation_space.shape)

            if self.num_features > 0:
                self.observation_state[:self.num_features] = self.features[self.period]

            if self.inventory_allowed:

                if self.normalize_env_state:
                    inv = self.inventory/self.inventory_cap
                else:
                    inv = self.inventory

                self.observation_state[self.num_features:self.num_features+self.num_products] = inv

            if self.use_order_pipeline:
                slots_occupied = 0
                for product in range(self.num_products):

                    # print("pipeline_raw:", self.order_pipeline[product])
                    
                    if self.normalize_env_state:
                        prod = self.order_pipeline[product]/self.order_limit_high[product]
                    else:
                        prod = self.order_pipeline[product]

                    # print(prod)

                    self.observation_state[self.num_features+self.num_products+slots_occupied:self.num_features+self.num_products+slots_occupied+self.max_lead_time[product]] = prod
                    slots_occupied += self.max_lead_time[product]

            if np.max(self.max_demand_backlog) > 0:
                self.observation_state[self.num_features+self.num_products+np.sum(self.num_products * self.lead_time):] = self.demand_backlog

            self.observation_state = np.round(self.observation_state, self.precision)

            # print("inventory:", np.round(self.inventory, 2))
            # print("order pipeline:", np.round(self.order_pipeline, 2))
            # print("_______")


    def step(self, action):

        absorbing = False

        reward, info = self.calculations(self, action)

        self.set_observation_state()
        
        # if self.normalize_reward:
        #     print(self.observation_state)


        return self.observation_state, reward, absorbing, info
    
    def render(self):
        pass

    def convert_to_numpy_array(self, value, expand = False):

        """
        Convert a scalar, list, or numpy array to a numpy array.
        
        Parameters:
        - value (scalar, list, numpy.ndarray): The input value.
        
        Returns:
        - numpy.ndarray: The converted numpy array.
        """

        if np.isscalar(value):
            value = np.array([value])
        elif isinstance(value, list):
            value = np.array(value)
        elif isinstance(value, np.ndarray):
            value = value
        elif value is None:
            value = None
        else:
            raise ValueError("Invalid input type. Expected scalar, list, or numpy array.")

        if expand:
            if value is not None:
                if value.size == 1:
                    value = np.repeat(value, self.num_products)
                elif value.size == self.num_products:
                    value = value
                else:
                    raise ValueError("Invalid input size. Expected scalar, list, or numpy array of size 1 or num_products.")
        return value


# %% ../../nbs/environments/01_multi_period_inventory.ipynb 7
class MultiPeriodEnvMP(Environment, ABC):
    def __init__(self,
                 
                ### demand data
                demand,

                ### feature data
                features = None,

                features_timeless = None,

                mask = None, # optional mask to mask out certain actions (e.g., if product not offered), shape(timesteps, num_products)
                feature_map = None, # optional feature map to map features to products, shape(num_features, num_products)

                ### cost related to overage and underage
                overage_cost = np.array([0]), 
                underage_cost = np.array([0]), #! check if there can be a case where this is different from penalty cost

                ### cost related to ordering
                fixed_ordering_cost = np.array([0]), 
                variable_ordering_cost = np.array([0]), 

                ### demand parameters
                max_demand_backlog = np.array([0]), 

                ### inventory parameters
                inventory_allowed = True,
                start_inventory = np.array([0]), 
                inventory_cap = np.array([np.inf]), 
                
                ### cost related to inventory
                holding_cost = None, 

                ### ordering parameters
                lead_time = np.array([0]), 
                lead_time_variance = np.array([0]), 
                max_lead_time = None,
                actions_per_product = 1,
                order_limit_low = None, 
                order_limit_high = None, 

                ### other parameters
                horizon = 365, # length of each episode in simulation
                gamma = 1, # discount factor
                observation_space_kwargs = None,

                ### logic of step-method
                calculations = None,

                ### Precision of state
                precision = 5,

                lag_window = 0,

                normalize_env_state = False,
                normalize_reward = False,

                record_state = False
                ):
        """
        Multi-period inventory environment covering all major single-echelon inventory models with decision on quantity.

        In this variant the demand data includes lag_window extra observations at the beginning, so the
        number of usable observations is demand.shape[0] - lag_window.

        Note: Except True/False parameters and horizon/gamma, all per-product parameters may be given as a
        scalar, a list or a numpy array; values of size 1 are broadcast to all products.

        args:
            demand: historical observations of demand, indexed as (observations, products)
            features: features related to demand - normalize features before passing to environment
            features_timeless: additional features stored next to ``features``; their second dimension
                is added to the number of features
            mask: optional mask to mask out certain actions (stored on the instance only)
            feature_map: map from features to products; required when features are given, since its
                second dimension determines the number of mapped features
            overage_cost: overage cost incurred per unit in each period
            underage_cost: underage cost incurred per unit in each period
            fixed_ordering_cost: fixed cost incurred per order
            variable_ordering_cost: variable cost incurred per unit ordered (note: related to ordering only, not cogs)
            max_demand_backlog: How much demand can be backlogged. If set to zero: Lost Sales model
            inventory_allowed: whether inventory is allowed. False: Excess orders will be thrown away
            start_inventory: initial inventory
            inventory_cap: maximum inventory allowed (must be finite if normalize_env_state is True)
            holding_cost: holding cost incurred per unit in each period
            lead_time: lead time between ordering and receiving the order
            lead_time_variance: variance in lead time
            max_lead_time: upper bound of the lead time; required when lead_time_variance is non-zero,
                otherwise lead_time is used
            actions_per_product: number of action dimensions per product
            order_limit_low: minimum order allowed (defaults to 0 for every product)
            order_limit_high: maximum order allowed (defaults to 1 for every product)
            horizon: length of each episode in simulation
            gamma: discount factor
            observation_space_kwargs: kwargs for set_observation_space
            calculations: function that calculates the next state and reward in the step function
            precision: precision of state variables (decimal places)
            lag_window: number of leading demand observations reserved as history
            normalize_env_state: stored on the instance; also enforces a finite inventory_cap
            normalize_reward: stored on the instance (presumably read by ``calculations`` - TODO confirm)
            record_state: stored on the instance (presumably read by ``calculations`` - TODO confirm)

        raises:
            AssertionError: if inventory is not allowed but start_inventory/holding_cost are set, or if
                the state should be normalized with an infinite inventory_cap
            ValueError: if a parameter has an invalid type/size, if features are given without a
                feature_map, or if lead_time_variance is non-zero and max_lead_time is missing
        """

        self.demand = demand
        self.avg_demand = np.mean(demand, axis = 0)

        # in this environment, the demand will include lag-window observations in the beginning
        self.num_observations = demand.shape[0]-lag_window # number of observations

        self.num_products = demand.shape[1] # number of products
        self.features = [features, features_timeless] 

        if features is not None:

            if feature_map is None:
                raise ValueError("feature_map must be provided if features are provided")

            # number of overarching features (some may be partially product specific)
            self.num_features = feature_map.shape[1]
            self.num_features = self.num_features + 2 # price and sold information. # TODO: Automate this instead of hardcoding
            self.num_features += 1 # lag demand of previous step

            if features_timeless is not None:
                self.num_features += features_timeless.shape[1]

        else:   
            self.num_features = 0
        
        self.mask = mask

        self.feature_map = feature_map

        self.overage_cost = self.convert_to_numpy_array(overage_cost, expand = True)
        self.underage_cost = self.convert_to_numpy_array(underage_cost, expand = True)

        self.fixed_ordering_cost = self.convert_to_numpy_array(fixed_ordering_cost, expand = True)
        self.variable_ordering_cost = self.convert_to_numpy_array(variable_ordering_cost, expand = True)

        self.max_demand_backlog = self.convert_to_numpy_array(max_demand_backlog, expand = True)

        self.inventory_allowed = inventory_allowed

        # Convert first and validate the converted arrays: comparing a raw list with a scalar
        # (e.g. [0, 0] == 0) yields a single False instead of an element-wise result.
        self.start_inventory = self.convert_to_numpy_array(start_inventory, expand = True)
        self.inventory_cap = self.convert_to_numpy_array(inventory_cap, expand = True)

        if not self.inventory_allowed:
            assert np.all(self.start_inventory == 0), "If inventory is not allowed, start_inventory must be 0"

        if normalize_env_state:
            assert np.all(self.inventory_cap != np.inf), "If inventory is normalized, inventory_cap must be a finite number"

        if not self.inventory_allowed:
            assert holding_cost is None, "If inventory is not allowed, holding_cost must be None"
        self.holding_cost = self.convert_to_numpy_array(holding_cost, expand = True)

        self.lead_time = self.convert_to_numpy_array(lead_time, expand = True)
        self.lead_time_variance = self.convert_to_numpy_array(lead_time_variance, expand = True)

        # An order pipeline is only needed if at least one product has a non-zero lead time
        self.use_order_pipeline = not np.all(self.lead_time == 0)

        if np.all(self.lead_time_variance == 0):
            # Deterministic lead time: the pipeline never needs to be longer than the lead time itself
            self.max_lead_time = self.lead_time
        else:
            if max_lead_time is None:
                raise ValueError("max_lead_time must be provided if lead_time_variance is not 0")
            self.max_lead_time = self.convert_to_numpy_array(max_lead_time, expand = True)

        # Fall back to the [0, 1] default per bound; a provided bound is never overwritten
        if order_limit_low is None:
            self.order_limit_low = np.zeros(self.num_products)
        else:
            self.order_limit_low = self.convert_to_numpy_array(order_limit_low, expand = True)

        if order_limit_high is None:
            self.order_limit_high = np.ones(self.num_products)
        else:
            self.order_limit_high = self.convert_to_numpy_array(order_limit_high, expand = True)

        # Each product contributes actions_per_product action dimensions sharing the product's limits
        action_low = np.repeat(self.order_limit_low, actions_per_product)
        action_high = np.repeat(self.order_limit_high, actions_per_product)

        self.action_space = Box(action_low, action_high, shape=(self.num_products*actions_per_product,))

        # Set observation space
        if observation_space_kwargs is None:
            observation_space_kwargs = {}

        self.set_observation_space(**observation_space_kwargs)

        self.calculations = calculations

        # Set MDP info
        mdp_info = MDPInfo(observation_space=self.observation_space,
                           action_space=self.action_space,
                           horizon=horizon,
                           gamma=gamma)
        
        super().__init__(mdp_info)

        self.precision = precision

        self.lag_window = lag_window

        self.normalize_reward = normalize_reward
        self.normalize_env_state = normalize_env_state

        self.record_state = record_state

        # Set state variables
        self.reset()

    def set_observation_space(self,
                              features = False,
                              inventory = False,
                              order_pipeline = False,
                              demand_backlog = False
                              ):
        """
        Set observation space. This is used to customize the observation space.

        args:
            features: whether to include features in the observation space
            inventory: whether to include inventory in the observation space
            order_pipeline: whether to include order pipeline in the observation space
            demand_backlog: whether to include demand backlog in the observation space
        """

        print(self.num_features)

        features_space = self.num_features if features else 0
        #! for now only focusing on NV problem with features only
        # inventory_space = self.num_products if inventory else 0
        # order_pipeline_space = np.sum(self.num_products * np.max(self.max_lead_time)) if order_pipeline else 0
        # demand_backlog_space = self.num_products if demand_backlog else 0

        # total_obs_space_length = features_space + inventory_space + order_pipeline_space + demand_backlog_space
        total_obs_space_length = features_space

        assert total_obs_space_length > 0, "Need at least one observation dimension"

        low_limit = np.zeros(total_obs_space_length)
        high_limit = np.zeros(total_obs_space_length)

        if features:
            low_limit[:self.num_features] = -np.ones(self.num_features)
            high_limit[:self.num_features] = np.ones(self.num_features)
        else:
            self.num_features = 0

        if inventory:
            low_limit[self.num_features:self.num_features+inventory_space] = np.zeros(inventory_space) # assuming the low inventory end is always 0
            high_limit[self.num_features:self.num_features+inventory_space] = self.inventory_cap

        if order_pipeline:
            slots_occupied = 0
            for product in range(self.num_products):
                # Note: the length of the order pipeline can vary by product. Note that for RNN processing we need to pad the order pipeline to the same length in the preprocessing step
                low_limit[self.num_features+inventory_space+slots_occupied:self.num_features+inventory_space+slots_occupied+self.max_lead_time[product]] = self.order_limit_low[product]
                high_limit[self.num_features+inventory_space+slots_occupied:self.num_features+inventory_space+slots_occupied+self.max_lead_time[product]] = self.order_limit_high[product]
                slots_occupied += self.max_lead_time[product]

        if demand_backlog:
            low_limit[self.num_features+inventory_space+order_pipeline_space:self.num_features+inventory_space+order_pipeline_space+demand_backlog_space] = 0
            high_limit[self.num_features+inventory_space+order_pipeline_space:self.num_features+inventory_space+order_pipeline_space+demand_backlog_space] = self.max_demand_backlog

        # print("features space:", features_space)
        # print("inventory space:", inventory_space)
        # print("order pipeline space:", order_pipeline_space)
        # print("demand backlog space:", demand_backlog_space)
        # print("length:", total_obs_space_length)
        self.observation_space = Box(low_limit, high_limit, shape=(total_obs_space_length,))

    def reset(self, state = None):

        # TODO: Check in mushroom where state would come from and implement here
       
        if self.num_observations == self._mdp_info.horizon:
            self.period = 0 # in this environment, the demand will include lag-window observations in the beginning
        else:
            self.period = np.random.choice(self.num_observations-self._mdp_info.horizon)
        #print("reset with period {}".format(self.period))
        self.demand_backlog = np.zeros(self.num_products)
        self.inventory = self.start_inventory.copy()
        self.order_pipeline = np.zeros((self.num_products,np.max(self.max_lead_time)))

        self.set_observation_state()

        return self.observation_state

    def set_observation_state(self):
        """Rebuild ``self.observation_state`` for the current ``self.period``.

        Once ``self.period`` has reached ``self.num_observations`` the state is
        an all-zero array of shape ``self.observation_space.shape``.

        Otherwise the state is an array of shape
        ``(num_products, lag_window, observation_space.shape[0])`` (products on
        the batch dimension, then time, then features). Its last axis is
        filled, in this order, with:

        1. the demand rows ``period .. period + lag_window - 1``,
        2. the time-independent features ``self.features[1]``, repeated for
           every lag step,
        3. the leading columns of the time-dependent features
           ``self.features[0]``, picked per product via ``self.feature_map``,
        4. the trailing ``2 * num_products`` columns of ``self.features[0]``,
           regrouped to two values per product.

        The time-dependent features use the window shifted one row later than
        the demand window (rows ``period + 1 .. period + lag_window``), i.e.
        the current row is visible for features but not for demand.

        The filled array is rounded to ``self.precision`` decimals. Nothing is
        returned; the result is stored on the instance.

        TODO: check if this is computationally expensive.
        TODO: make sure this is consistent if some elements are there but not
        tracked in the state.
        """
        if self.period >= self.num_observations:
            # NOTE(review): this zero state takes its shape directly from
            # observation_space, which appears to differ from the 3-D layout
            # built below -- confirm that consumers expect this.
            self.observation_state = np.zeros(self.observation_space.shape)
            return

        values_per_product = 2  # TODO: make this configurable instead of hard-coded
        n_products = self.num_products
        window = self.lag_window
        start = self.period
        stop = start + window

        state = np.zeros((n_products, window, self.observation_space.shape[0]))
        self.observation_state = state  # filled in place through the local alias

        # --- column 0: past demand, one row of the window per lag step
        demand_window = self.demand[start:stop]
        state[:, :, 0:1] = np.expand_dims(demand_window.T, -1)

        # --- time-independent features, copied to every lag step
        static_features = self.features[1]
        left = 1
        right = left + static_features.shape[1]
        state[:, :, left:right] = np.repeat(
            np.expand_dims(static_features, axis=1), window, axis=1
        )

        # --- time-dependent features: window includes the current row
        timed_features = self.features[0][start + 1:stop + 1]
        split = -n_products * values_per_product
        shared_columns = timed_features[:, :split]
        product_columns = timed_features[:, split:]

        # Each product picks its own subset of the shared columns.
        # TODO: this loop is slow; profile to see whether it is a bottleneck.
        mapped = np.zeros((n_products, window, self.feature_map.shape[1]))
        for product in range(n_products):
            mapped[product] = shared_columns[:, self.feature_map[product]]

        # (time, value, product) -> (product, time, value)
        per_product = product_columns.reshape(
            (window, values_per_product, n_products)
        ).transpose((2, 0, 1))

        left, right = right, right + mapped.shape[2]
        state[:, :, left:right] = mapped

        left, right = right, right + per_product.shape[2]
        state[:, :, left:right] = per_product

        # TODO: so far only demand and features are written into the state.
        # Still missing (previously sketched here as commented-out code):
        #   * the inventory (optionally normalised by inventory_cap),
        #   * the order pipeline per product (optionally normalised by
        #     order_limit_high),
        #   * the demand backlog when max_demand_backlog > 0.

        self.observation_state = np.round(state, self.precision)


    def step(self, action):
        """Apply ``action`` for one step and return the resulting transition.

        ``self.calculations`` is called with the environment and the action and
        must return a ``(reward, info)`` pair. Afterwards the observation state
        is rebuilt through ``set_observation_state``.

        Parameters:
        - action: The action forwarded unchanged to ``self.calculations``.

        Returns:
        - tuple: ``(observation_state, reward, absorbing, info)`` where
          ``absorbing`` is always ``False``.
        """
        # NOTE(review): the environment is passed explicitly as first argument,
        # so `calculations` is presumably stored as a plain callable rather
        # than defined as a bound method -- confirm against __init__.
        step_reward, step_info = self.calculations(self, action)

        # Refresh the state only after the calculations ran for this step.
        self.set_observation_state()

        return self.observation_state, step_reward, False, step_info
    
    def render(self):
        """Do nothing; this environment has no visual representation."""
        return None

    def convert_to_numpy_array(self, value, expand = False):
        """
        Normalise a parameter value to a numpy array (or keep ``None``).

        Parameters:
        - value (scalar, list, numpy.ndarray, None): The input value. Anything
          ``np.isscalar`` accepts is wrapped into a one-element array, a list
          is converted with ``np.array``, an ndarray is passed through as the
          same object (no copy) and ``None`` stays ``None``.
        - expand (bool): If true, a result holding exactly one element is
          repeated ``self.num_products`` times; a result that already holds
          ``self.num_products`` elements is returned unchanged.

        Returns:
        - numpy.ndarray or None: The converted (and optionally expanded) value.

        Raises:
        - ValueError: If ``value`` has an unsupported type, or if ``expand`` is
          set and the size is neither 1 nor ``self.num_products``.
        """
        # None is a legal "not provided" marker and is never expanded.
        if value is None:
            return None

        if np.isscalar(value):
            array = np.array([value])
        elif isinstance(value, list):
            array = np.array(value)
        elif isinstance(value, np.ndarray):
            array = value
        else:
            raise ValueError("Invalid input type. Expected scalar, list, or numpy array.")

        if not expand:
            return array

        # A single entry is checked first, so it is always re-created through
        # np.repeat (even when num_products == 1).
        if array.size == 1:
            return np.repeat(array, self.num_products)
        if array.size == self.num_products:
            return array
        raise ValueError("Invalid input size. Expected scalar, list, or numpy array of size 1 or num_products.")
