# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/agents/benchmark_agents/03_ERM_agents.ipynb.

# %% auto 0
__all__ = ['FakePolicy', 'NewsvendorData', 'LinearModel', 'MLP', 'MLP_BLOCK', 'RNN', 'SGDBase', 'LERMsgdAgent', 'MLPsgdAgent',
           'RNNsgdAgent', 'RNNsgdMetaAgent', 'LlamaRotaryEmbedding', 'rotate_half', 'apply_rotary_pos_emb',
           'CausalSelfAttention', 'find_multiple', 'MLP_block', 'RMSNorm', 'Block', 'LagLlama', 'LagLlamasgdAgent',
           'NewsvendorDataMeta', 'LagLlamasgdMetaAgent']

# %% ../../../nbs/agents/benchmark_agents/03_ERM_agents.ipynb 4
# General libraries:
import numpy as np
from scipy.stats import norm
from tqdm import trange, tqdm
from time import sleep

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast


from ..processors.processors import GetTimeSeriesAndStaticFeatures


# Mushroom libraries
from mushroom_rl.core import Agent

from timeit import default_timer as timer

from ...utils.MLtools import LRSchedulerPerStep

from torchinfo import summary

# %% ../../../nbs/agents/benchmark_agents/03_ERM_agents.ipynb 6
class FakePolicy():
    """Stand-in policy object whose only capability is a no-op ``reset``.

    The agents in this file assign an instance of it to ``self.policy``.
    """

    def reset(*args, **kwargs):
        """Accept any arguments, do nothing, and return ``None``."""
        return None

class NewsvendorData(Dataset):
    """Torch dataset over feature/demand arrays with optional mask and feature map.

    Without ``feature_map`` one item is one row of ``x`` and ``y`` (plus the
    matching mask row when a mask was given).

    With ``feature_map`` one item is a single ``(row, column)`` cell of ``y``.
    The features are the entries of that row of ``x`` selected by
    ``feature_map[:, column]`` cast to bool, and the column index is returned
    as the last element of the item.
    """

    def __init__(self, x, y, mask=None, feature_map=None):
        # numpy arrays -> float32 tensors
        self.x = torch.from_numpy(x).float()
        self.y = torch.from_numpy(y).float()

        if mask is None:
            self.mask = None
        else:
            mask_tensor = torch.from_numpy(mask).bool()
            # a 1-D mask becomes a column so it lines up with 2-D targets
            if mask_tensor.dim() == 1:
                mask_tensor = mask_tensor.unsqueeze(1)
            self.mask = mask_tensor

        self.feature_map = feature_map

        if feature_map is None:
            self.n_samples = y.shape[0]
        else:
            n_rows, n_cols = y.shape[0], y.shape[1]
            self.n_samples = n_rows * n_cols
            # flat sample index -> (row, column) of y, enumerated row by row
            self.index_mapping = {
                row * n_cols + col: (row, col)
                for row in range(n_rows)
                for col in range(n_cols)
            }

    def __getitem__(self, index):
        # Row-wise mode: hand back the index-th row of every tensor.
        if self.feature_map is None:
            item = (self.x[index], self.y[index])
            if self.mask is not None:
                item += (self.mask[index],)
            return item

        # Cell-wise mode: resolve the flat index to a (row, column) pair.
        time, product = self.index_mapping[index]
        selector = self.feature_map[:, product].astype(bool)
        features = self.x[time, selector]
        demand = self.y[time, product]

        if self.mask is None:
            return features, demand, product
        return features, demand, self.mask[time, product], product

    def __len__(self):
        return self.n_samples

class LinearModel(nn.Module):
    """One affine layer, optionally followed by a ReLU on the output.

    Args:
        input_size: Number of input features.
        output_size: Number of outputs.
        relu_output: If true the output passes through ``nn.ReLU``;
            otherwise through ``nn.Identity``.
    """

    def __init__(self, input_size, output_size, relu_output=False):
        super().__init__()
        self.l1 = nn.Linear(input_size, output_size)
        self.final_activation = nn.ReLU() if relu_output else nn.Identity()

    def forward(self, x):
        return self.final_activation(self.l1(x))

class MLP(nn.Module):
    """Feed-forward network of Linear -> ReLU -> Dropout stages plus an output layer.

    Args:
        input_size: Number of input features.
        hidden_size: Width of every hidden layer.
        output_size: Number of outputs.
        num_hidden_layers: Number of Linear/ReLU/Dropout stages; the first one
            is always present, so values below 1 behave like 1.
        drop_prob: Dropout probability used in every stage.
        relu_output: If true the final layer is followed by ``nn.ReLU``
            (non-negative output); otherwise by ``nn.Identity``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_hidden_layers=3, drop_prob=0.0, relu_output=False):
        super().__init__()

        def stage(n_in):
            # one hidden stage mapping n_in -> hidden_size
            return [nn.Linear(n_in, hidden_size), nn.ReLU(), nn.Dropout(p=drop_prob)]

        # first stage consumes the raw features, the rest are hidden -> hidden
        stack = stage(input_size)
        for _ in range(num_hidden_layers - 1):
            stack.extend(stage(hidden_size))

        # output projection and its activation
        stack.append(nn.Linear(hidden_size, output_size))
        stack.append(nn.ReLU() if relu_output else nn.Identity())

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)

class MLP_BLOCK(nn.Module):
    """Linear -> ReLU -> Dropout -> BatchNorm1d building block.

    Args:
        size_in: Number of input features of the linear layer.
        size_out: Number of output features (also the batch-norm width).
        drop_prob: Dropout probability.
    """

    def __init__(self, size_in, size_out, drop_prob):
        super().__init__()
        stages = [
            nn.Linear(size_in, size_out),
            nn.ReLU(),
            nn.Dropout(p=drop_prob),
            nn.BatchNorm1d(size_out),  # normalisation comes last in the block
        ]
        self.MLP_block = nn.Sequential(*stages)

    def forward(self, x):
        return self.MLP_block(x)

class RNN(nn.Module):
    """GRU encoder followed by a stack of ``MLP_BLOCK`` layers and a linear head.

    The flat input is first passed through ``GetTimeSeriesAndStaticFeatures``;
    only the first element of its return value is used. The GRU output at the
    last time step feeds the MLP blocks.

    Args:
        input_size: Length of the flat input; the GRU input width is
            ``int(input_size / lag_window)``.
        hidden_size: Base width. The GRU uses ``int(hidden_size / 2)`` units,
            the first MLP block widens to ``hidden_size`` and each further
            block halves the width.
        output_size: Number of outputs.
        num_hidden_layers: Number of ``MLP_BLOCK`` layers (0 connects the GRU
            directly to the output layer).
        num_RNN_layers: Number of stacked GRU layers.
        drop_prob: Dropout probability after the GRU and inside each block.
        num_time_series_features: Forwarded to the input processor.
        lag_window: Forwarded to the input processor; also divides
            ``input_size``, so it must be a number despite the ``None`` default.
        relu_output: If true the output passes through ``nn.ReLU``.
    """

    def __init__(self,
                input_size,
                hidden_size,
                output_size,
                num_hidden_layers=2,
                num_RNN_layers=2,
                drop_prob=0.0,
                num_time_series_features=None,
                lag_window=None,
                relu_output=False):
        super().__init__()

        self.input_processor = GetTimeSeriesAndStaticFeatures(num_time_series_features, lag_window)

        gru_in = int(input_size / lag_window)
        gru_width = int(hidden_size / 2)
        self._h1_recurrent = nn.GRU(gru_in, gru_width, num_RNN_layers, batch_first=True)
        self._h1_dropout = nn.Dropout(p=drop_prob)

        self.mlp_blocks = nn.ModuleList()
        for depth in range(num_hidden_layers):
            if depth == 0:
                # first block widens the GRU features to the full hidden size
                width_in, width_out = gru_width, hidden_size
            else:
                # every later block halves the width of its predecessor
                width_in = int(hidden_size / (2 ** (depth - 1)))
                width_out = int(hidden_size / (2 ** depth))
            self.mlp_blocks.append(MLP_BLOCK(size_in=width_in, size_out=width_out, drop_prob=drop_prob))

        # head input: GRU width without blocks, otherwise the last block's width
        if num_hidden_layers == 0:
            head_in = gru_width
        else:
            head_in = int(hidden_size / (2 ** depth))
        self.output_layer = nn.Linear(head_in, output_size)

        # ReLU keeps the output non-negative
        self.final_activation = nn.ReLU() if relu_output else nn.Identity()

    def forward(self, x):
        # a single unbatched sample gets a leading batch axis
        if len(x.shape) == 1:
            x = x.unsqueeze(0)

        # keep only the first element returned by the processor
        # (original note: no static data returned as no inventory)
        sequence = self.input_processor(x)[0]

        # recurrent encoder; keep the last time step only
        encoded, _ = self._h1_recurrent(sequence)
        hidden = self._h1_dropout(encoded)[:, -1, :]

        for block in self.mlp_blocks:
            hidden = block(hidden)

        return self.final_activation(self.output_layer(hidden))

class SGDBase(Agent):
    """Shared model construction, training and prediction for the SGD agents.

    Subclasses must set ``self.model_type`` (``"Linear"``, ``"MLP"``, ``"RNN"``
    or ``"LagLlama"``) before calling this ``__init__``, which reads it to
    pick the network. ``Agent.__init__`` is not called here.
    """

    def __init__(self, config=None, input_size=None, hidden_size=64, output_size=1,
                 learning_rate=0.01, num_hidden_layers=3, num_RNN_layers=None,
                 drop_prob=0.0, l2_reg=0.0, learning_rate_scheduler=None,
                 scheduler_params=None, num_time_series_features=None,
                 lag_window=None, relu_output=False):
        """Build the model, the Adam optimizer and the optional LR scheduler.

        Args:
            config: Keyword arguments for ``LagLlama`` (``"LagLlama"`` only).
            input_size: Input width; an integer or a shape sequence. Also
                forwarded to ``torchinfo.summary``.
            hidden_size: Hidden width for the MLP / RNN models.
            output_size: Number of model outputs.
            learning_rate: Adam learning rate (and scheduler base rate).
            num_hidden_layers: Hidden layers for the MLP / RNN models.
            num_RNN_layers: GRU layers for the RNN model.
            drop_prob: Dropout probability.
            l2_reg: Adam ``weight_decay``.
            learning_rate_scheduler: ``"LinearWarmupWithDecay"`` or ``None``.
            scheduler_params: Extra keyword arguments for the scheduler.
            num_time_series_features: Forwarded to the RNN model.
            lag_window: Forwarded to the RNN model.
            relu_output: Whether the model output passes through a ReLU.

        Raises:
            ValueError: If ``learning_rate_scheduler`` is not recognised.
        """
        if self.model_type == "Linear":
            self.model = LinearModel(input_size, output_size, relu_output=relu_output)
        elif self.model_type == "MLP":
            self.model = MLP(input_size, hidden_size, output_size,
                             num_hidden_layers=num_hidden_layers,
                             drop_prob=drop_prob, relu_output=relu_output)
        elif self.model_type == "RNN":
            self.model = RNN(input_size, hidden_size, output_size,
                             num_hidden_layers=num_hidden_layers,
                             num_RNN_layers=num_RNN_layers, drop_prob=drop_prob,
                             num_time_series_features=num_time_series_features,
                             lag_window=lag_window, relu_output=relu_output)
        elif self.model_type == "LagLlama":
            self.model = LagLlama(**config)

        # A bare integer (Python or numpy) is wrapped in a list before it is
        # handed to torchinfo's summary.
        if isinstance(input_size, (int, np.integer)):
            input_size = [int(input_size)]
        summary(self.model, input_size=input_size)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=l2_reg)

        if learning_rate_scheduler == "LinearWarmupWithDecay":
            # tolerate scheduler_params=None instead of failing on ``**None``
            self.scheduler = LRSchedulerPerStep(
                self.optimizer, base_learning_rate=learning_rate, **(scheduler_params or {})
            )
            self.scheduler.step()
        elif learning_rate_scheduler is None:
            self.scheduler = None
        else:
            raise ValueError("Learning rate scheduler not recognized")

        self.criterion = nn.MSELoss()

    def fit(self, X_train, y_train, mask, cu, co, batch_size=64, learning_rate=0.01, device="cpu"):
        """Run one pass over the training data, minimising the mean pinball loss.

        Args:
            X_train: Feature array (numpy).
            y_train: Demand array (numpy); a 1-D array is reshaped to a column.
            mask: Optional array multiplied into the per-element loss, or ``None``.
            cu: Underage cost.
            co: Overage cost.
            batch_size: Mini-batch size.
            learning_rate: Unused here; the optimizer keeps the rate it was
                built with.
            device: Device used for training; the model is moved back to the
                CPU afterwards.

        Returns:
            The trained model, in eval mode and on the CPU.
        """
        if y_train.ndim == 1:
            # was ``y.reshape`` -- ``y`` is undefined in this scope
            y_train = y_train.reshape(-1, 1)

        dataset_train = NewsvendorData(X_train, y_train, mask)

        self.model.to(device)

        train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)

        self.model.train()

        total_loss = 0
        for batch in train_loader:
            # batches are (features, labels) or (features, labels, mask)
            if len(batch) == 2:
                feat, labels = batch
                masks = None
            else:
                feat, labels, masks = batch
                masks = masks.to(device)

            feat = feat.to(device)
            labels = labels.to(device)
            outputs = self.model(feat)

            loss = torch.mean(SGDBase.pinball_loss(cu, co, labels, outputs, masks))
            total_loss += loss.item()

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        print("training loss: ", total_loss)
        self.model.eval()
        self.model.to("cpu")

        return self.model

    def predict(self, X):
        """Return model outputs for numpy input ``X`` as a numpy array.

        The model is evaluated on the CPU in eval mode without gradients.
        A 2-D output is flattened to 1-D.
        """
        self.model.to("cpu")

        self.model.eval()
        with torch.no_grad():
            X = torch.from_numpy(X)
            X = X.float()

            output = self.model(X)
            output = output.numpy()

        if len(output.shape) == 2:
            output = output.reshape(-1)

        return output

    def train(self):
        """Put the underlying model into training mode."""
        self.model.train()

    def eval(self):
        """Put the underlying model into evaluation mode."""
        self.model.eval()

    ### Helper functions

    @staticmethod
    def max_or_zero(data):
        """Return the element-wise maximum of ``data`` and zero."""
        value = torch.max(data, torch.zeros_like(data))
        return value

    @staticmethod
    def pinball_loss(cu, co, demand, order_quantity, mask):
        """Element-wise newsvendor (pinball) loss.

        Computes ``cu * max(demand - order, 0) + co * max(order - demand, 0)``
        and multiplies it by ``mask`` when a mask is given.

        Args:
            cu: Underage cost (converted to a float32 tensor).
            co: Overage cost (converted to a float32 tensor).
            demand: Observed demand; a 1-D tensor is turned into a column.
            order_quantity: Model output; must have the same shape as ``demand``.
            mask: Optional tensor multiplied into the loss, or ``None``.

        Returns:
            Tensor of per-element losses (not reduced).
        """
        # TODO: check that this works in terms of shapes

        if len(demand.shape) == 1:
            demand = demand.unsqueeze(1)

        assert demand.shape == order_quantity.shape

        cu = torch.tensor(cu, dtype=torch.float32)
        co = torch.tensor(co, dtype=torch.float32)

        cu = cu.to(demand.device)
        co = co.to(demand.device)

        underage = cu * SGDBase.max_or_zero(demand - order_quantity)
        overage = co * SGDBase.max_or_zero(order_quantity - demand)

        loss = underage + overage

        if mask is not None:
            loss = loss * mask
        return loss

class LERMsgdAgent(SGDBase):
    """Linear-model agent trained by SGD on the pinball loss.

    Args:
        input_size: Number of input features.
        output_size: Number of outputs.
        cu: Underage cost passed to the loss.
        co: Overage cost passed to the loss.
        final_activation: ``"identity"`` or ``"relu"``.
        batch_size: Mini-batch size used by ``fit_epoch``.
        learning_rate: Optimizer learning rate.
        device: Device used for training.
        agent_name: Stored as ``self.name``.

    Raises:
        ValueError: If ``final_activation`` is neither ``"identity"`` nor
            ``"relu"``.
    """

    def __init__(self,
                    input_size,
                    output_size,
                    cu,
                    co,
                    final_activation="identity",
                    batch_size=128,
                    learning_rate=0.01,
                    device="cpu",
                    agent_name = "LERM"
                    ):
        self.name = agent_name
        self.model_type = "Linear"  # read by SGDBase.__init__ to pick the network
        self.cu = cu
        self.co = co
        self.device = device
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        self.policy = FakePolicy()
        self._postprocessors = list()
        self._preprocessors = list()
        self.train_directly = True
        self.train_mode = "epochs"

        # Map the activation name to the boolean flag the model expects; an
        # unknown name used to fall through and fail later with an
        # AttributeError on ``self.final_activation``.
        if final_activation == "identity":
            self.final_activation = False
        elif final_activation == "relu":
            self.final_activation = True
        else:
            raise ValueError(
                f"final_activation must be 'identity' or 'relu', got {final_activation!r}"
            )

        super().__init__(input_size=input_size, hidden_size=None, output_size=output_size, learning_rate=learning_rate, relu_output=self.final_activation)

    def fit_epoch(self, features_train, demand_train, mask=None):
        """Train for one epoch with the agent's stored costs and settings."""
        super().fit(features_train, demand_train, mask=mask, cu=self.cu, co=self.co, batch_size=self.batch_size, learning_rate=self.learning_rate, device=self.device)

    def draw_action(self, X):
        """Return the model prediction for ``X``."""
        return super().predict(X)

class MLPsgdAgent(SGDBase):
    """MLP agent trained by SGD on the pinball loss.

    Args:
        input_size: Number of input features.
        output_size: Number of outputs.
        cu: Underage cost passed to the loss.
        co: Overage cost passed to the loss.
        final_activation: ``"identity"`` or ``"relu"``.
        hidden_size: Width of the hidden layers.
        batch_size: Mini-batch size used by ``fit_epoch``.
        learning_rate: Optimizer learning rate.
        device: Device used for training.
        agent_name: Stored as ``self.name``.
        num_hidden_layers: Number of hidden layers.
        drop_prob: Dropout probability.
        l2_reg: Weight decay passed to the optimizer.

    Raises:
        ValueError: If ``final_activation`` is neither ``"identity"`` nor
            ``"relu"``.
    """

    def __init__(self,
                    input_size,
                    output_size,
                    cu,
                    co,
                    final_activation="identity",
                    hidden_size=64,
                    batch_size=128,
                    learning_rate=0.01,
                    device="cpu",
                    agent_name = "DLNV",
                    num_hidden_layers=3,
                    drop_prob=0.0,
                    l2_reg=0.0,
                    ):

        self.name = agent_name
        self.model_type = "MLP"  # read by SGDBase.__init__ to pick the network
        self.cu = cu
        self.co = co
        self.device = device
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        self.policy = FakePolicy()
        self._postprocessors = list()
        self._preprocessors = list()
        self.train_directly = True
        self.train_mode = "epochs"

        # Map the activation name to the boolean flag the model expects; an
        # unknown name used to fall through and fail later with an
        # AttributeError on ``self.final_activation``.
        if final_activation == "identity":
            self.final_activation = False
        elif final_activation == "relu":
            self.final_activation = True
        else:
            raise ValueError(
                f"final_activation must be 'identity' or 'relu', got {final_activation!r}"
            )

        super().__init__(input_size=input_size, hidden_size=hidden_size, output_size=output_size, learning_rate=learning_rate, num_hidden_layers=num_hidden_layers, drop_prob=drop_prob, l2_reg=l2_reg, relu_output=self.final_activation)

    def fit_epoch(self, features_train, demand_train, mask=None):
        """Train for one epoch with the agent's stored costs and settings."""
        super().fit(features_train, demand_train, mask=mask, cu=self.cu, co=self.co, batch_size=self.batch_size, learning_rate=self.learning_rate, device=self.device)

    def draw_action(self, X):
        """Return the model prediction for ``X``."""
        return super().predict(X)

class RNNsgdAgent(SGDBase):
    """GRU-based agent trained by SGD on the pinball loss.

    Args:
        input_size: Length of the flat input vector.
        output_size: Number of outputs.
        cu: Underage cost passed to the loss.
        co: Overage cost passed to the loss.
        num_time_series_features: Forwarded to the RNN model.
        lag_window: Forwarded to the RNN model.
        final_activation: ``"identity"`` or ``"relu"``.
        hidden_size: Base hidden width of the RNN model.
        batch_size: Mini-batch size used by ``fit_epoch``.
        learning_rate: Optimizer learning rate.
        device: Device used for training.
        agent_name: Stored as ``self.name``.
        num_RNN_layers: Number of stacked GRU layers.
        num_hidden_layers: Number of MLP blocks after the GRU.
        drop_prob: Dropout probability.
        l2_reg: Weight decay passed to the optimizer.

    Raises:
        ValueError: If ``final_activation`` is neither ``"identity"`` nor
            ``"relu"``.
    """

    def __init__(self,
                    input_size,
                    output_size,
                    cu,
                    co,
                    num_time_series_features,
                    lag_window,
                    final_activation="identity",
                    hidden_size=64,
                    batch_size=128,
                    learning_rate=0.01,
                    device="cpu",
                    agent_name = "DLNV_RNN",
                    num_RNN_layers = 2,
                    num_hidden_layers=3,
                    drop_prob=0.0,
                    l2_reg=0.0,
                    ):

        self.name = agent_name
        self.model_type = "RNN"  # read by SGDBase.__init__ to pick the network
        self.cu = cu
        self.co = co
        self.device = device
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        self.policy = FakePolicy()
        self._postprocessors = list()
        self._preprocessors = list()
        self.train_directly = True
        self.train_mode = "epochs"

        # Map the activation name to the boolean flag the model expects; an
        # unknown name used to fall through and fail later with an
        # AttributeError on ``self.final_activation``.
        if final_activation == "identity":
            self.final_activation = False
        elif final_activation == "relu":
            self.final_activation = True
        else:
            raise ValueError(
                f"final_activation must be 'identity' or 'relu', got {final_activation!r}"
            )

        super().__init__(input_size=input_size, hidden_size=hidden_size, output_size=output_size, learning_rate=learning_rate, num_hidden_layers=num_hidden_layers, num_RNN_layers=num_RNN_layers, drop_prob=drop_prob, l2_reg=l2_reg, num_time_series_features=num_time_series_features, lag_window=lag_window, relu_output=self.final_activation, )

    def fit_epoch(self, features_train, demand_train, mask=None):
        """Train for one epoch with the agent's stored costs and settings."""
        super().fit(features_train, demand_train, mask=mask, cu=self.cu, co=self.co, batch_size=self.batch_size, learning_rate=self.learning_rate, device=self.device)

    def draw_action(self, X):
        """Return the model prediction for ``X``."""
        return super().predict(X)

class RNNsgdMetaAgent(SGDBase):
    """GRU-based agent that shares one single-output model across products.

    ``feature_map`` is indexed as ``feature_map[:, product]`` and cast to bool
    to pick the input columns belonging to each product. The model input width
    is taken from the first column of ``feature_map`` and the model always has
    one output, so the ``input_size`` and ``output_size`` arguments are
    accepted but not used.

    Args:
        feature_map: 2-D array; column ``j`` marks the features of product ``j``.
        input_size: Ignored (recomputed from ``feature_map``).
        output_size: Ignored (the model is built with ``output_size=1``).
        cu: Underage costs, indexed per product during training.
        co: Overage costs, indexed per product during training.
        num_time_series_features: Forwarded to the RNN model.
        lag_window: Forwarded to the RNN model.
        final_activation: ``"identity"`` or ``"relu"``.
        hidden_size: Base hidden width of the RNN model.
        batch_size: Mini-batch size used by ``fit_epoch``.
        learning_rate: Optimizer learning rate.
        device: Device used for training and prediction.
        agent_name: Stored as ``self.name``.
        num_RNN_layers: Number of stacked GRU layers.
        num_hidden_layers: Number of MLP blocks after the GRU.
        drop_prob: Dropout probability.
        l2_reg: Weight decay passed to the optimizer.

    Raises:
        ValueError: If ``final_activation`` is neither ``"identity"`` nor
            ``"relu"``.
    """

    def __init__(self,
                    feature_map,
                    input_size,
                    output_size,
                    cu,
                    co,
                    num_time_series_features,
                    lag_window,
                    final_activation="identity",
                    hidden_size=64,
                    batch_size=128,
                    learning_rate=0.01,
                    device="cpu",
                    agent_name = "DLNV_RNN",
                    num_RNN_layers = 2,
                    num_hidden_layers=3,
                    drop_prob=0.0,
                    l2_reg=0.0,
                    ):

        self.feature_map = feature_map
        self.name = agent_name
        self.model_type = "RNN"  # read by SGDBase.__init__ to pick the network
        self.cu = cu
        self.co = co
        self.device = device
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        self.policy = FakePolicy()
        self._postprocessors = list()
        self._preprocessors = list()
        self.train_directly = True
        self.train_mode = "epochs"

        # Map the activation name to the boolean flag the model expects; an
        # unknown name used to fall through and fail later with an
        # AttributeError on ``self.final_activation``.
        if final_activation == "identity":
            self.final_activation = False
        elif final_activation == "relu":
            self.final_activation = True
        else:
            raise ValueError(
                f"final_activation must be 'identity' or 'relu', got {final_activation!r}"
            )

        # Per-product input width, taken from the first product's column.
        # Cast to a plain int: np.sum yields a numpy scalar, which is not
        # recognised by the ``isinstance(input_size, int)`` check in SGDBase.
        input_size = int(np.sum(self.feature_map[:, 0]))

        super().__init__(input_size=input_size, hidden_size=hidden_size, output_size=1, learning_rate=learning_rate, num_hidden_layers=num_hidden_layers, num_RNN_layers=num_RNN_layers, drop_prob=drop_prob, l2_reg=l2_reg, num_time_series_features=num_time_series_features, lag_window=lag_window, relu_output=self.final_activation, )

    def fit_epoch(self, features_train, demand_train, mask=None):
        """Train for one epoch with the agent's stored costs and settings."""
        self.fit(features_train, demand_train, mask=mask, cu=self.cu, co=self.co, batch_size=self.batch_size, learning_rate=self.learning_rate, device=self.device)

    def fit(self, X_train, y_train, mask, cu, co, batch_size=64, learning_rate=0.01, device="cpu"):
        """Run one pass over all (time step, product) samples.

        Args:
            X_train: Feature array (numpy).
            y_train: Demand array (numpy); a 1-D array is reshaped to a column.
            mask: Optional array multiplied into the per-sample loss, or ``None``.
            cu: Underage costs, indexed with the batch's product indices.
            co: Overage costs, indexed with the batch's product indices.
            batch_size: Mini-batch size.
            learning_rate: Unused here; the optimizer keeps the rate it was
                built with.
            device: Device used for training; the model is moved back to the
                CPU afterwards.

        Returns:
            The trained model, in eval mode and on the CPU.
        """
        if y_train.ndim == 1:
            # was ``y.reshape`` -- ``y`` is undefined in this scope
            y_train = y_train.reshape(-1, 1)

        dataset_train = NewsvendorData(X_train, y_train, mask, self.feature_map)

        self.model.to(device)

        train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=False) # make shuffle true

        self.model.train()

        total_loss = 0
        for i, (output) in tqdm(enumerate(train_loader)):
            # batches are (features, labels, product) or
            # (features, labels, mask, product)
            if len(output) == 3:
                feat, labels, product = output
                masks = None
            else:
                feat, labels, masks, product = output
                masks = masks.to(device)

            feat = feat.to(device)
            labels = labels.to(device)
            outputs = self.model(feat)

            # pick each sample's cost by its product index
            cu_selected = cu[product]
            co_selected = co[product]

            loss_per_product = self.pinball_loss(cu_selected, co_selected, labels, outputs, masks)
            loss = torch.mean(loss_per_product)

            total_loss += loss.item()

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        print("training loss: ", total_loss)
        self.model.eval()
        self.model.to("cpu")

        return self.model

    def draw_action(self, X):
        """Return the model prediction for ``X``."""
        return self.predict(X)

    def predict(self, X):
        """Predict for every product by batching the per-product feature slices.

        A 1-D ``X`` is treated as a single row. The per-product slices are
        stacked along axis 0 (product 0's rows first), run through the model in
        one batch, and the singleton output axis is squeezed away.

        Returns:
            1-D numpy array of predictions.
        """
        self.model.eval()
        self.model.to(self.device)

        if len(X.shape) == 1:
            X = np.expand_dims(X, axis=0)

        # one feature slice per product, selected by that product's column
        X_batches = [
            X[:, self.feature_map[:, i].astype(bool)]
            for i in range(self.feature_map.shape[1])
        ]

        X_batch = np.concatenate(X_batches, axis=0)
        X_batch = torch.from_numpy(X_batch)
        X_batch = X_batch.float().to(self.device)

        with torch.no_grad():
            output = self.model(X_batch)
            output = output.cpu().numpy()

        # reduce output dimension
        output = output.squeeze(1)

        return output

    @staticmethod
    def pinball_loss(cu, co, demand, order_quantity, mask):
        """Element-wise pinball loss with one cost pair per sample.

        ``cu`` and ``co`` are converted to float32 tensors and given a trailing
        axis, so they must match the shape of the column-shaped quantities.

        Args:
            cu: Per-sample underage costs.
            co: Per-sample overage costs.
            demand: Observed demand; a 1-D tensor is turned into a column.
            order_quantity: Model output; must have the same shape as ``demand``.
            mask: Optional tensor multiplied into the loss, or ``None``.

        Returns:
            Tensor of per-element losses (not reduced).
        """
        if len(demand.shape) == 1:
            demand = demand.unsqueeze(1)

        assert demand.shape == order_quantity.shape

        cu = torch.tensor(cu, dtype=torch.float32)
        co = torch.tensor(co, dtype=torch.float32)

        cu = cu.unsqueeze(1)
        co = co.unsqueeze(1)

        cu = cu.to(demand.device)
        co = co.to(demand.device)

        underage_quantity = SGDBase.max_or_zero(demand - order_quantity)
        overage_quantity = SGDBase.max_or_zero(order_quantity - demand)

        assert cu.shape == underage_quantity.shape
        assert co.shape == overage_quantity.shape

        underage = cu * underage_quantity
        overage = co * overage_quantity

        loss = underage + overage

        if mask is not None:
            loss = loss * mask
        return loss


# %% ../../../nbs/agents/benchmark_agents/03_ERM_agents.ipynb 8
class LlamaRotaryEmbedding(torch.nn.Module):
    """Cached cosine/sine tables for rotary positional embeddings (RoPE).

    Based on https://arxiv.org/abs/2104.09864, following the implementation in
    https://github.com/time-series-foundation-models/lag-llama.

    The tables are built once for ``max_position_embeddings`` positions and
    rebuilt when ``forward`` is asked for a longer sequence. All buffers are
    non-persistent, so they are not part of the ``state_dict``.
    """

    # TODO: potentially also include LlamaLinearScalingRotaryEmbedding and LlamaDynamicNTKScalingRotaryEmbedding

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # one inverse frequency per pair of channels
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        # angle[p, f] = position p * inverse frequency f
        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)

        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(
                buffer_name, values[None, None, :, :].to(dtype), persistent=False
            )

    def forward(self, device, dtype, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``[1, 1, seq_len, dim]`` in ``dtype``."""
        cache_too_short = seq_len > self.max_seq_len_cached
        if cache_too_short:
            self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)

        cos = self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype)
        sin = self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype)
        return cos, sin

def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last axis is split at ``x.shape[-1] // 2``; the result is the negated
    second part followed by the first part.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to the query and key tensors.

    ``cos`` and ``sin`` carry two leading singleton axes that are squeezed to
    ``[seq_len, dim]``, indexed with ``position_ids`` and given a head axis at
    position 1 so they broadcast against ``q`` and ``k``. With
    ``position_ids=None`` the index only inserts a leading axis of size 1.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    def select(table):
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        flat = table.squeeze(1).squeeze(0)  # [seq_len, dim]
        return flat[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    cos, sin = select(cos), select(sin)
    q_embed, k_embed = [(t * cos) + (rotate_half(t) * sin) for t in (q, k)]
    return q_embed, k_embed

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and a KV cache.

    Args:
        n_embd_per_head: Width of a single attention head.
        n_head: Number of heads; the model width is ``n_embd_per_head * n_head``.
        block_size: Passed to the rotary embedding as ``max_position_embeddings``.
        dropout: Attention dropout probability. It is applied only while the
            module is in training mode.
    """

    def __init__(self, n_embd_per_head, n_head, block_size, dropout) -> None:
        super().__init__()
        # query projections for all heads, but in a batch
        self.q_proj = nn.Linear(
            n_embd_per_head * n_head,
            n_embd_per_head * n_head,
            bias=False,
        )
        # key, value projections
        self.kv_proj = nn.Linear(
            n_embd_per_head * n_head,
            2 * n_embd_per_head * n_head,
            bias=False,
        )
        # output projection
        self.c_proj = nn.Linear(
            n_embd_per_head * n_head,
            n_embd_per_head * n_head,
            bias=False,
        )

        self.n_head = n_head
        self.n_embd_per_head = n_embd_per_head
        self.block_size = block_size
        self.dropout = dropout

        # self.rope_scaling = rope_scaling # so far only default RoPE implemented
        self.rope_scaling=None
        # self._rope_scaling_validation() # so far only default RoPE implemented

        self._init_rope()
        self.kv_cache = None

    def _init_rope(self):
        """Create the rotary embedding module according to ``self.rope_scaling``."""
        if self.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.n_embd_per_head, max_position_embeddings=self.block_size
            )
        # note: the following is not yet implemented; ``rope_scaling`` is always
        # None at the moment, and the scaling classes named below are not
        # defined in the visible part of this file.
        else:
            scaling_type = self.rope_scaling["type"]
            scaling_factor = self.rope_scaling["factor"]
            if scaling_type == "nope":
                self.rotary_emb = None
            elif scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.n_embd_per_head,
                    max_position_embeddings=self.block_size,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.n_embd_per_head,
                    max_position_embeddings=self.block_size,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    # def _rope_scaling_validation(self):
    #     """
    #     Validate the `rope_scaling` configuration.
    #     """
    #     if self.rope_scaling is None:
    #         return

    #     if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
    #         raise ValueError(
    #             "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
    #             f"got {self.rope_scaling}"
    #         )
    #     rope_scaling_type = self.rope_scaling.get("type", None)
    #     rope_scaling_factor = self.rope_scaling.get("factor", None)
    #     if rope_scaling_type is None or rope_scaling_type not in [
    #         "linear",
    #         "dynamic",
    #         "nope",
    #     ]:
    #         raise ValueError(
    #             f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
    #         )
    #     if rope_scaling_type in ["linear", "dynamic"]:
    #         if (
    #             rope_scaling_factor is None
    #             or not isinstance(rope_scaling_factor, float)
    #             or rope_scaling_factor < 1.0
    #         ):
    #             raise ValueError(
    #                 f"`rope_scaling`'s factor field must be an float >= 1, got {rope_scaling_factor}"
    #             )

    def forward(self, x: torch.Tensor, use_kv_cache: bool) -> torch.Tensor:
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: Input of shape ``(batch, sequence length, n_embd)``.
            use_kv_cache: If true, keys/values are kept in ``self.kv_cache``
                between calls and attention runs without the causal mask;
                otherwise causal attention is used.

        Returns:
            Tensor of shape ``(batch, sequence length, n_embd)``.
        """
        # batch size, sequence length, embedding dimensionality (n_embd)
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.q_proj(x)
        k, v = self.kv_proj(x).split(self.n_embd_per_head * self.n_head, dim=2)

        if use_kv_cache:
            # Optimized for single next prediction
            if self.kv_cache is not None:
                # Update cache: append the new entries and drop the oldest one
                k = torch.cat([self.kv_cache[0], k], dim=1)[:, 1:]
                v = torch.cat([self.kv_cache[1], v], dim=1)[:, 1:]
                self.kv_cache = k, v
            else:
                # Build cache
                self.kv_cache = k, v

        k = k.view(B, -1, self.n_head, self.n_embd_per_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        q = q.view(B, -1, self.n_head, self.n_embd_per_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        v = v.view(B, -1, self.n_head, self.n_embd_per_head).transpose(
            1, 2
        )  # (B, nh, T, hs)

        if self.rotary_emb is not None:
            # NOTE(review): on the cached path k holds the cached length while
            # the rotary tables are still built for seq_len=T -- confirm the
            # positions applied to k are the intended ones.
            cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=T)
            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        #  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        #  att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        #  att = F.softmax(att, dim=-1)
        #  y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # efficient attention using Flash Attention CUDA kernels
        # When using kv cache at inference, is_causal=False since decoder is causal, at each generation step we want
        # to avoid recalculating the same previous token attention

        # TODO: understand how causal gets turned off in the encoder (we want to process the input sequence at once)

        # The functional attention call applies dropout whenever dropout_p > 0,
        # independent of the module mode, so it has to be disabled explicitly
        # outside of training.
        attn_dropout = self.dropout if self.training else 0.0
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=attn_dropout, is_causal=not use_kv_cache
        )

        # debug
        if not torch.isfinite(y).all():
            print("y is not finite")
            print(y)
            print(q)
            print(k)
            print(v)

        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.c_proj(y)

        return y

def find_multiple(n: int, k: int) -> int:
    """Return ``n`` rounded up to the next multiple of ``k`` (``n`` itself if it already is one)."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder

class MLP_block(nn.Module):
    """Bias-free feed-forward block used inside a transformer ``Block``.

    With ``n_mlp_layers == 2`` the block is gated: ``c_proj(silu(c_fc1(x)) * c_fc2(x))``.
    For any other value the gate branch is not created and the block computes
    ``c_proj(silu(c_fc1(x)))``.

    The hidden width is ``2/3`` of ``4 * n_embd`` (truncated to int) and then rounded
    up to a multiple of ``min_multiple``.
    """

    def __init__(self, n_embd_per_head, n_head, min_multiple = 256, n_mlp_layers=2) -> None:
        super().__init__()
        self.n_mlp_layers = n_mlp_layers

        n_embd = n_embd_per_head * n_head
        # TODO: ensure the hidden dim could also work on small data (might need to scale down the hidden dim)
        n_hidden = find_multiple(int(2 * (4 * n_embd) / 3), min_multiple)

        print("hidden dimension in MLP block: ", n_hidden)

        # Layers are created in the order c_fc1, c_fc2, c_proj so that parameter
        # ordering and random initialisation stay reproducible.
        self.c_fc1 = nn.Linear(n_embd, n_hidden, bias=False)
        if n_mlp_layers == 2:
            self.c_fc2 = nn.Linear(n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the (optionally gated) SiLU feed-forward transformation to ``x``."""
        hidden = F.silu(self.c_fc1(x))
        if self.n_mlp_layers == 2:
            # multiplicative gate from the second input projection
            hidden = hidden * self.c_fc2(x)
        return self.c_proj(hidden)

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

    Each slice along ``dim`` is divided by ``sqrt(mean(x**2) + eps)`` and multiplied
    by a learnable ``scale`` vector of length ``size`` (initialised to ones).
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along ``self.dim`` and return a tensor with ``x``'s dtype."""
        # NOTE: this is not equivalent to the original RMSNorm paper implementation,
        # which divides by (rms + eps) instead of sqrt(mean_square + eps).
        # The mean of squares is accumulated in float32 regardless of the input dtype.
        mean_square = torch.mean(x.to(torch.float32).pow(2), dim=self.dim, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        scaled = self.scale * (x * inv_rms)
        return scaled.type_as(x)

class Block(nn.Module):
    """One pre-norm transformer layer: self-attention followed by an MLP.

    Both sub-layers receive an RMS-normalised input and are wrapped in a
    residual connection.
    """

    def __init__(self, n_embd_per_head, n_head, block_size, dropout, min_multiple = 256, n_mlp_layers=2) -> None:
        super().__init__()
        n_embd = n_embd_per_head * n_head
        self.rms_1 = RMSNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd_per_head, n_head, block_size, dropout)
        self.rms_2 = RMSNorm(n_embd)
        self.mlp = MLP_block(
            n_embd_per_head,
            n_head,
            min_multiple=min_multiple,
            n_mlp_layers=n_mlp_layers,
        )

    def forward(self, x: torch.Tensor, use_kv_cache: bool) -> torch.Tensor:
        """Run attention and MLP sub-layers; ``use_kv_cache`` is forwarded to the attention module."""
        after_attention = x + self.attn(self.rms_1(x), use_kv_cache)
        return after_attention + self.mlp(self.rms_2(after_attention))

class LagLlama(nn.Module):
    """Decoder-only transformer (Lag-Llama style) with a single linear output head.

    The model embeds every time step with a linear layer, runs ``n_layer``
    transformer ``Block``s, applies a final ``RMSNorm`` and projects the result
    to ``output_size``. Only the prediction of the last time step is returned.

    No input scaling or lag construction happens inside the model; the upstream
    Lag-Llama ``prepare_input`` step was intentionally dropped, so inputs must be
    pre-processed (scaled, lagged) by the caller.

    Args:
        input_size: Number of features per time step (input width of the embedding layer).
        output_size: Width of the final projection.
        context_length: Stored on the instance; not otherwise used by this class.
        max_context_length: Passed to every ``Block`` as ``block_size``.
        n_layer: Number of transformer blocks (Lag-Llama paper default: 32).
        n_head: Number of attention heads (Lag-Llama paper default: 32).
        n_embd_per_head: Embedding width per head (Lag-Llama paper default: 128).
        rope_scaling: Accepted for interface compatibility; not implemented/used here.
        min_multiple: Forwarded to the MLP block (hidden width is rounded up to a multiple of it).
        n_mlp_layers: Forwarded to the MLP block (2 = gated variant).
        drop_prob: Dropout value forwarded to every ``Block``.
        num_time_series_features: Forwarded to ``GetTimeSeriesAndStaticFeatures``.
        lag_window: Forwarded to ``GetTimeSeriesAndStaticFeatures``.
        relu_output: If truthy, a ReLU is applied to the output (non-negative predictions);
            otherwise the output is left unchanged.
    """

    def __init__(self,
                input_size,
                output_size,

                context_length = 16,
                max_context_length = 16, # check what context and max_context do
                n_layer = 1, # default LagLlama paper: 32
                n_head = 4, # default LagLlama paper: 32
                n_embd_per_head = 32, # default LagLlama paper: 128
                rope_scaling = None, # not yet implemented (scaled version of rotary embeddings)
                # distr_output: Only needed for probabilistic forecasting
                # num_parallel_samples: int = 100, # Only needed for probabilistic forecasting
                # time_feat: True, # Already included in our multi-variate time series input
                min_multiple = 256,
                n_mlp_layers = 2,

                drop_prob=0.0,
                num_time_series_features=None,
                lag_window=None,
                relu_output=False):

        super().__init__()

        block_size = max_context_length
        self.context_length = context_length
        # Needed by _init_weights to scale the initialisation std with depth.
        self.n_layer = n_layer

        # Optional time features are not implemented since all inputs are day-level.
        # Any other time feature (e.g., weekday) can be added during pre-processing
        # as part of the feature vector.

        # Kept for agents that want to pre-process flat inputs; forward() does not call it.
        self.input_processor = GetTimeSeriesAndStaticFeatures(num_time_series_features,lag_window)

        # size of input vector (features + demand) on one timestep without positional embeddings
        self.input_size = input_size

        #! Note that no scaling is implemented in the model, needs to be addressed in pre-processing

        n_embd = n_embd_per_head * n_head

        # Single-output head (replaces the probabilistic distribution head of Lag-Llama).
        # Created before the transformer so parameter order/initialisation matches earlier versions.
        self.param_proj = nn.Linear(n_embd, output_size)

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(self.input_size, n_embd),
                h=nn.ModuleList(
                    [
                        Block(
                            n_embd_per_head,
                            n_head,
                            block_size,
                            drop_prob,
                            min_multiple = min_multiple,
                            n_mlp_layers=n_mlp_layers,
                        )
                        for _ in range(n_layer)
                    ]
                ),
                ln_f=RMSNorm(n_embd),
            )
        )
        self.y_cache = False  # used at time of inference when kv cache is used

        if relu_output:
            self.final_activation = nn.ReLU()  # output is non-negative
        else:
            self.final_activation = nn.Identity()

    def _init_weights(self, module: nn.Module) -> None:
        """Depth-scaled normal initialisation for linear and embedding weights.

        This is NOT applied automatically by ``__init__`` (default PyTorch
        initialisation is kept); call ``model.apply(model._init_weights)`` to use it.
        """
        # Local import: ``math`` is not among the module-level imports of this file.
        import math

        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layer)
            )

    def forward(    self,
                    x: torch.Tensor,
                    use_kv_cache: bool = False,) -> torch.Tensor:
        """Predict from a batch of sequences.

        Args:
            x: Tensor with three dimensions; the last one must match ``input_size``
                (assumed layout: batch, time, features -- TODO confirm with callers).
            use_kv_cache: If True and a previous cached call has happened, only the
                most recent time step of ``x`` is fed through the network.

        Returns:
            The activated projection of the last time step, i.e. ``output[:, -1, :]``.
        """
        #! Flat inputs are not reshaped here (not needed for the meta agent). If a
        #! standard agent needs it, that processing belongs in the agent class.

        # TODO: Confirm size of x is btc
        # The unpacking doubles as a check that x has exactly three dimensions.
        (B, T, C) = x.size()

        transformer_input = x

        if use_kv_cache and self.y_cache:
            used_kv_cache = True
            # Only use the most recent one, rest is in cache
            transformer_input = transformer_input[:, -1:]
        else:
            used_kv_cache = False

        # token embeddings of shape (b, t, n_embd_per_head*n_head)
        x = self.transformer.wte(transformer_input)

        # Diagnostics for non-finite values after the embedding layer.
        embeddings = x
        if torch.isnan(x).any():
            print("nan in embeddings")
        if not torch.isfinite(x).all():
            print("infinite in embeddings")
            for i in x:
                if not torch.isfinite(i).all():
                    print("infinite in one row")
                    print(i)

        for block in self.transformer.h:
            x = block(x, use_kv_cache)
        x = self.transformer.ln_f(x)
        if use_kv_cache:
            self.y_cache = True

        raw_output = x

        output = self.param_proj(x)

        output_after_projection = output

        output = self.final_activation(output)

        # Diagnostics: dump all intermediate tensors when the output contains NaN.
        if torch.isnan(output).any():
            print("nan in output")
            print("used_kv_cache: ", used_kv_cache)

            print(transformer_input)
            if torch.isnan(transformer_input).any():
                print("nan in transformer_input")
            else:
                print("no nan in input")
            print(raw_output)
            print(embeddings)
            print(output_after_projection)
            print(output)

            print(x)

        # keep only the prediction of the final time step
        output = output[:, -1, :]
        return output

    def reset_cache(self) -> None:
        """
        Resets all cached key-values in attention.
        Has to be called after prediction loop in predictor
        """
        # The flag read by forward() lives on the model itself, not on the blocks;
        # clearing it makes the next cached call process the full sequence again.
        self.y_cache = False
        for block in self.transformer.h:
            block.attn.kv_cache = None

class LagLlamasgdAgent(SGDBase):
    """Newsvendor agent that trains a ``LagLlama`` model via the ``SGDBase`` training loop.

    The constructor collects the model hyper-parameters into a ``config`` dict and
    hands it to ``SGDBase.__init__`` (which is expected to build the model --
    not visible in this part of the file).

    Args:
        input_size: Features per time step, forwarded to the model config.
        output_size: Output width, forwarded to the model config.
        cu: Underage cost(s), passed to ``fit`` on every epoch.
        co: Overage cost(s), passed to ``fit`` on every epoch.
        num_time_series_features: Forwarded to the model config.
        lag_window: Forwarded to the model config; also the default ``context_length``.
        agent_name: Stored as ``self.name``.
        final_activation: ``"identity"`` (no output activation) or ``"relu"``
            (non-negative output). Other values are passed through unchanged as
            the model's ``relu_output`` flag.
        context_length: Defaults to ``lag_window`` when ``None``.
        max_context_length: Defaults to ``context_length`` when ``None``.
        n_layer, n_head, n_embd_per_head, rope_scaling, min_multiple, n_mlp_layers,
        drop_prob: Forwarded to the model config.
        batch_size: Batch size handed to ``fit``.
        learning_rate: Passed to ``SGDBase.__init__`` and to ``fit``.
        l2_reg: Passed to ``SGDBase.__init__``.
        device: Device string handed to ``fit``.
    """

    def __init__(self,
                    input_size,
                    output_size,
                    cu,
                    co,
                    num_time_series_features,
                    lag_window,
                    agent_name = "DLNV_LagLlama",


                    final_activation="identity",

                    # Transformer block params
                    context_length = None, # if context length is not provided, it is set to the lag_window
                    max_context_length = None,  # if max_context_length is not provided, it is set to the context_length
                                                # max context length used for RoPE - can be longer than lag_window during training
                    n_layer = 1, # default LagLlama paper: 32
                    n_head = 4, # default LagLlama paper: 32
                    n_embd_per_head = 32, # default LagLlama paper: 128
                    rope_scaling = None, # not yet implemented (scaled version of rotary embeddings)
                    min_multiple = 256,
                    n_mlp_layers = 2,

                    # General params
                    drop_prob=0.0,
                    batch_size=128,
                    learning_rate=0.01,
                    l2_reg=0.0,
                    device="cpu",

                    ):

        print("in init")

        self.name=agent_name
        self.model_type="LagLlama"
        self.cu = cu
        self.co = co
        self.device = device

        # Translate the activation name into the boolean `relu_output` flag of the model.
        if final_activation=="identity":
            final_activation = False
        elif final_activation=="relu":
            final_activation = True
        # Previously the string comparison was repeated *after* this conversion, so it
        # never matched and `self.final_activation` was left unset. Store the flag here.
        self.final_activation = final_activation

        if context_length is None:
            context_length = lag_window
        if max_context_length is None:
            max_context_length = context_length

        config = {

            "input_size": input_size,
            "output_size": output_size,
            "num_time_series_features": num_time_series_features,

            "context_length": context_length,
            "max_context_length": max_context_length,
            "n_layer": n_layer,
            "n_head": n_head,
            "n_embd_per_head": n_embd_per_head,
            "rope_scaling": rope_scaling,
            "min_multiple": min_multiple,
            "n_mlp_layers": n_mlp_layers,

            "drop_prob": drop_prob,
            "lag_window": lag_window,
            "relu_output": final_activation,
        }

        self.batch_size=batch_size
        self.learning_rate=learning_rate

        # Attributes set directly on the instance for the surrounding agent framework
        # (presumably expected by the mushroom_rl Agent interface -- TODO confirm).
        self.policy=FakePolicy()
        self._postprocessors = list()
        self._preprocessors = list()
        self.train_directly=True
        self.train_mode = "epochs"

        print("init parent")
        print("config: ", config)
        super().__init__(learning_rate = learning_rate, l2_reg=l2_reg, config=config)


    def fit_epoch(self, features_train, demand_train, mask=None):
        """Run one call of ``SGDBase.fit`` with the agent's stored costs and training settings."""
        super().fit(features_train, demand_train, mask=mask, cu=self.cu, co=self.co, batch_size=self.batch_size, learning_rate=self.learning_rate, device=self.device)

    def draw_action(self, X):
        """Return the prediction of ``SGDBase.predict`` for ``X`` as the agent's action."""
        return super().predict(X)


class NewsvendorDataMeta(Dataset):
    """Dataset producing one (feature window, target demand, product id) sample per
    time step and product.

    Every sample is a ``(lag_window, feature_size)`` matrix whose columns are, in order:

    1. lagged demand of the product, one column per entry of ``lag_features``
       (zero-padded where the lag reaches before the start of the series),
    2. the product's timeless features, repeated for every row,
    3. the shared time-dependent features selected by ``feature_map[product]``,
    4. the two product-specific time-dependent features.

    Args:
        x_time_dependent: 2-D array indexed by time; its last ``2 * num_products``
            columns hold two product-specific features laid out as
            ``[feature A for all products, feature B for all products]``.
        x_timeless: Indexed by product; each entry must be a 1-D feature vector.
        y: 2-D demand array indexed as ``y[time, product]``.
        feature_map: ``feature_map[product]`` is used as a column indexer into the
            shared part of ``x_time_dependent``.
        lag_window: Number of rows (time steps) per sample.
        feature_size: Total number of columns per sample; must be at least the sum
            of the four column groups above.
        lag_features: Sequence of demand lags (in time steps).

    Masking of samples is not supported.
    """

    def __init__(self, x_time_dependent, x_timeless, y, feature_map, lag_window, feature_size, lag_features):

        self.x_time_dependent = x_time_dependent
        self.x_timeless = x_timeless
        self.y = y
        self.feature_map = feature_map

        self.lag_window = lag_window
        self.feature_size = feature_size
        self.num_products = y.shape[1]

        print("lag lookback: ", lag_window)

        # The first `lag_window` time steps only serve as history, never as targets.
        num_steps = y.shape[0] - lag_window
        self.n_samples = num_steps * self.num_products

        self.lag_features = lag_features

        # flat sample index -> (time offset, product)
        self.index_mapping = {
            step * self.num_products + product: (step, product)
            for step in range(num_steps)
            for product in range(self.num_products)
        }

    def __getitem__(self, index):
        """Return ``(x, y_target, product)`` for the flat sample ``index``.

        ``x`` is a float32 tensor of shape ``(lag_window, feature_size)``,
        ``y_target`` is the raw entry ``y[time, product]`` and ``product`` is a
        long tensor holding the product index.
        """
        step, product = self.index_mapping[index]
        time = step + self.lag_window

        demand = self.y[:, product]
        n_lags = len(self.lag_features)

        x = np.zeros((self.lag_window, self.feature_size))

        ####### ADD LAG DEMAND
        lagged_demand = np.zeros((self.lag_window, n_lags))
        for col, lag in enumerate(self.lag_features):
            end_index = time - lag + 1
            start_index = end_index - self.lag_window
            if start_index >= 0:
                lagged_demand[:, col] = demand[start_index:end_index]
            elif end_index > 0:
                # The window starts before the series: leading rows stay zero and
                # the available history fills the trailing `end_index` rows.
                lagged_demand[-start_index:, col] = demand[:end_index]
            # else: the whole window lies before the series and stays zero
        x[:, 0:n_lags] = lagged_demand

        ####### ADD TIMELESS FEATURES
        features_timeless = np.expand_dims(self.x_timeless[product], axis=0)
        features_timeless = np.repeat(features_timeless, self.lag_window, axis=0)
        col_start = n_lags
        col_end = col_start + features_timeless.shape[1]
        x[:, col_start:col_end] = features_timeless

        ####### ADD TIME DEPENDENT FEATURES
        # +1 because features of the current period are visible to the agent.
        relevant_features = self.x_time_dependent[(time - self.lag_window + 1):time + 1]

        # TODO: make the number of product-specific feature blocks (2) configurable
        n_product_columns = self.num_products * 2

        # shared features, reduced to the ones mapped to this product
        shared = relevant_features[:, :-n_product_columns]
        shared_mapped = shared[:, self.feature_map[product]]

        # features unique per product: one column out of each of the two blocks
        per_product = relevant_features[:, -n_product_columns:][:, [product, self.num_products + product]]

        col_start, col_end = col_end, col_end + shared_mapped.shape[1]
        x[:, col_start:col_end] = shared_mapped
        col_start, col_end = col_end, col_end + per_product.shape[1]
        x[:, col_start:col_end] = per_product

        y_target = demand[time]
        x = torch.tensor(x, dtype=torch.float32)
        product = torch.tensor(product, dtype=torch.long)

        return x, y_target, product

    def __len__(self):
        """Number of (time step, product) samples."""
        return self.n_samples

class LagLlamasgdMetaAgent(SGDBase):
    """Newsvendor agent training one shared ``LagLlama`` model across all products.

    Training samples are produced by ``NewsvendorDataMeta`` (one window per time
    step and product) and the model is optimised with a per-product pinball loss.
    The model, optimizer and scheduler are expected to be created by
    ``SGDBase.__init__`` from the ``config`` dict (not visible in this part of the file).

    Args:
        feature_map: Per-product selector of shared time-dependent features,
            forwarded to ``NewsvendorDataMeta``.
        input_size: Features per time step; also used as the dataset's ``feature_size``.
        output_size: Ignored -- the meta agent always predicts a single value per sample.
        cu: Per-product underage costs, indexable by an array of product indices.
        co: Per-product overage costs, indexable by an array of product indices.
        num_time_series_features: Forwarded to the model config.
        lag_window: Window length of every training sample; default ``context_length``.
        lag_features: Demand lags used as features; defaults to ``[1]``.
        agent_name: Stored as ``self.name``.
        final_activation: ``"identity"`` or ``"relu"``; other values are passed
            through unchanged as the model's ``relu_output`` flag.
        context_length: Defaults to ``lag_window`` when ``None``.
        max_context_length: Defaults to ``context_length`` when ``None``.
        n_layer, n_head, n_embd_per_head, rope_scaling, min_multiple, n_mlp_layers,
        drop_prob: Forwarded to the model config.
        batch_size: Training batch size.
        learning_rate: Passed to ``SGDBase.__init__`` and to ``fit``.
        l2_reg: Passed to ``SGDBase.__init__``.
        learning_rate_scheduler: Passed to ``SGDBase.__init__`` (e.g. "LinearWarmupWithDecay").
        scheduler_params: Passed to ``SGDBase.__init__``.
        device: Device used for training and inference.
        max_batch_size: Maximum number of samples per forward pass in ``predict``.
        weight_products: If True, the loss of every sample is multiplied by the
            weight of its product (requires ``product_weights``).
        product_weights: Optional 1-D array-like with one weight per product.
    """

    def __init__(self,
                    feature_map,
                    input_size,
                    output_size,
                    cu,
                    co,
                    num_time_series_features,
                    lag_window,
                    lag_features = None,
                    agent_name = "DLNV_LagLlama",

                    final_activation="identity",

                    # Transformer block params
                    context_length = None, # if context length is not provided, it is set to the lag_window
                    max_context_length = None,  # if max_context_length is not provided, it is set to the context_length
                                                # max context length used for RoPE - can be longer than lag_window during training
                    n_layer = 1, # default LagLlama paper: 32
                    n_head = 4, # default LagLlama paper: 32
                    n_embd_per_head = 32, # default LagLlama paper: 128
                    rope_scaling = None, # not yet implemented (scaled version of rotary embeddings)
                    min_multiple = 256,
                    n_mlp_layers = 2,

                    # General params
                    drop_prob=0.0,
                    batch_size=128, # batch size for training
                    learning_rate=0.01,
                    l2_reg=0.0,
                    learning_rate_scheduler=None, # Other: "LinearWarmupWithDecay"
                    scheduler_params=None, # For LinearWarmupWithDecay warmup_steps=..., learning_rate will be interpreted as base learning rate.
                    device="cpu",

                    max_batch_size=2048, # relevant for inference only

                    weight_products = False,
                    product_weights = None,
                    ):

        print("in init")

        self.feature_map = feature_map

        self.name=agent_name
        self.model_type="LagLlama"
        self.cu = cu
        self.co = co
        self.device = device

        print("in lag lama agent:", input_size)
        # One shared model predicts a single order quantity per (time, product) sample.
        output_size = 1

        # Translate the activation name into the boolean `relu_output` flag of the model.
        if final_activation=="identity":
            final_activation = False
        elif final_activation=="relu":
            final_activation = True
        # Previously the string comparison was repeated *after* this conversion, so it
        # never matched and `self.final_activation` was left unset. Store the flag here.
        self.final_activation = final_activation

        if context_length is None:
            context_length = lag_window
        if max_context_length is None:
            max_context_length = context_length

        config = {

            "input_size": input_size,
            "output_size": output_size,
            "num_time_series_features": num_time_series_features,

            "context_length": context_length,
            "max_context_length": max_context_length,
            "n_layer": n_layer,
            "n_head": n_head,
            "n_embd_per_head": n_embd_per_head,
            "rope_scaling": rope_scaling,
            "min_multiple": min_multiple,
            "n_mlp_layers": n_mlp_layers,

            "drop_prob": drop_prob,
            "lag_window": lag_window,
            "relu_output": final_activation,
        }

        self.batch_size=batch_size
        self.learning_rate=learning_rate

        # Attributes set directly on the instance for the surrounding agent framework
        # (presumably expected by the mushroom_rl Agent interface -- TODO confirm).
        self.policy=FakePolicy()
        self._postprocessors = list()
        self._preprocessors = list()
        self.train_directly=True
        self.train_mode = "epochs"

        self.lag_window = lag_window
        self.input_size = input_size

        self.max_batch_size = max_batch_size

        self.weight_products = weight_products
        self.product_weights = product_weights

        # `None` instead of a mutable default list in the signature.
        self.lag_features = [1] if lag_features is None else lag_features

        super().__init__(learning_rate = learning_rate, l2_reg=l2_reg, learning_rate_scheduler=learning_rate_scheduler, scheduler_params=scheduler_params, config=config)

    def fit_epoch(self, features_train, demand_train, mask=None):
        """Train for one pass over the data using the agent's stored costs and settings."""
        self.fit(features_train, demand_train, mask=mask, cu=self.cu, co=self.co, batch_size=self.batch_size, learning_rate=self.learning_rate, device=self.device)

    def fit(self, X_train, y_train, mask, cu, co, batch_size=64, learning_rate=0.01, device="cpu"):
        """Run one epoch of mini-batch training and return the model (moved back to CPU).

        Args:
            X_train: Pair ``(time_dependent_features, timeless_features)``.
            y_train: Demand array indexed as ``[time, product]``.
            mask: Accepted for interface compatibility; not used.
            cu: Per-product underage costs, indexable by an array of product indices.
            co: Per-product overage costs, indexable by an array of product indices.
            batch_size: Mini-batch size of the data loader.
            learning_rate: Accepted for interface compatibility; the optimizer is
                configured elsewhere and this value is not used here.
            device: Device to train on. Mixed precision (autocast + gradient scaling)
                is only enabled for CUDA devices.

        Raises:
            ValueError: If the loss becomes NaN, or if ``weight_products`` is set
                but no ``product_weights`` were provided.
        """
        # NOTE(review): anomaly detection is a process-wide debugging switch that slows
        # down backward passes considerably; consider disabling it once training is stable.
        torch.autograd.set_detect_anomaly(True)

        # float16 autocast only exists for CUDA here. On CPU the model stays in float32,
        # so feeding float16 inputs would fail with a dtype mismatch.
        use_amp = torch.device(device).type == "cuda"
        scaler = GradScaler(enabled=use_amp)

        start_total_time = timer()
        start_preparation = timer()

        features_time_dependent = X_train[0]
        features_timeless = X_train[1]

        dataset_train=NewsvendorDataMeta(features_time_dependent, features_timeless, y_train, self.feature_map, self.lag_window, self.input_size, self.lag_features)

        self.model.to(device)

        train_loader=DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=1,)

        weights = None
        if self.weight_products:
            product_weights = getattr(self, "product_weights", None)
            if product_weights is None:
                raise ValueError("weight_products=True requires product_weights to be provided")
            weights = torch.as_tensor(product_weights, dtype=torch.float32, device=device)

        self.model.train()

        total_loss = 0

        end_preparation = timer()

        start_time_loop = timer()
        time_to_train_model = 0

        for feat, labels, product in tqdm(train_loader):

            start_time_training = timer()

            feat=feat.to(device)
            labels=labels.to(device)
            if use_amp:
                feat = feat.to(torch.float16)

            # Index the cost vectors with a NumPy array so that a batch containing a
            # single sample still yields a 1-D array instead of a scalar.
            product_index = product.numpy()

            # Only the forward pass and the loss run under autocast.
            with autocast(enabled=use_amp):
                outputs=self.model(feat)

                cu_selected = cu[product_index]
                co_selected = co[product_index]

                loss_per_product = self.pinball_loss(cu_selected, co_selected, labels, outputs)

                if weights is not None:
                    loss_per_product = loss_per_product * weights[product.to(weights.device)].unsqueeze(-1)

                loss = torch.mean(loss_per_product)

            # Abort before the optimizer can apply an update derived from a NaN loss.
            if torch.isnan(loss):
                raise ValueError("got nan")

            total_loss += loss.item()

            # backward (with disabled scaler these calls reduce to a plain backward/step)
            self.optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            if self.scheduler is not None:
                self.scheduler.step()

            end_time_training = timer()
            time_to_train_model += end_time_training-start_time_training

        end_time_loop = timer()

        print("training loss: ", total_loss)
        self.model.eval()
        self.model.to("cpu")

        end_total_time = timer()

        print("total time: ", end_total_time-start_total_time)
        print("preparation time: ", end_preparation-start_preparation)
        print("total time in loop: ", end_time_loop-start_time_loop)
        print("thereof time to train model: ", time_to_train_model)

        return self.model

    def draw_action(self, X):
        """Predict actions for ``X`` and make sure neither input nor output contains NaN.

        Raises:
            ValueError: If ``X`` or the predicted action contains NaN values.
        """
        # Check the complete input (previously only the first row was inspected).
        if np.isnan(X).any():
            print("nan values in input")
            raise ValueError("NaN values in Input")

        action = self.predict(X)

        if np.isnan(action).any():
            print("nan values in action")
            raise ValueError("NaN values in Output")

        return action

    def predict(self, X):
        """Run the model on ``X`` in chunks of at most ``max_batch_size`` samples.

        Args:
            X: NumPy array whose first axis enumerates the samples.

        Returns:
            NumPy array of predictions with all size-1 dimensions squeezed out
            (a single sample therefore yields a 0-d array).
        """
        self.model.eval()
        self.model.to(self.device)

        X = torch.from_numpy(X).to(self.device)
        X = X.float()

        # Split along the sample axis; an empty input is still passed on as one batch.
        batches = [
            X[start:start + self.max_batch_size]
            for start in range(0, X.size(0), self.max_batch_size)
        ] or [X]

        outputs = []
        with torch.no_grad():
            for X_batch in batches:
                outputs.append(self.model(X_batch).cpu().numpy())

        # Concatenate the results from the smaller batches
        output = np.concatenate(outputs, axis=0)

        # Reduce output dimension
        output = output.squeeze()

        return output

    @staticmethod
    def pinball_loss(cu, co, demand, order_quantity):
        """Element-wise newsvendor (pinball) loss.

        Args:
            cu: 1-D underage costs, one entry per sample.
            co: 1-D overage costs, one entry per sample.
            demand: Realised demand; a 1-D tensor is reshaped to ``(batch, 1)``.
            order_quantity: Model output with the same shape as ``demand``.

        Returns:
            Tensor ``cu * max(demand - q, 0) + co * max(q - demand, 0)`` with the
            shape of ``demand``; no reduction is applied.
        """
        if len(demand.shape)==1:
            demand = demand.unsqueeze(1)

        assert demand.shape == order_quantity.shape

        # as_tensor accepts arrays as well as tensors without a copy-construct warning
        cu = torch.as_tensor(cu, dtype=torch.float32).unsqueeze(1).to(demand.device)
        co = torch.as_tensor(co, dtype=torch.float32).unsqueeze(1).to(demand.device)

        underage_quantity = SGDBase.max_or_zero(demand-order_quantity)
        overage_quantity = SGDBase.max_or_zero(order_quantity-demand)

        assert cu.shape == underage_quantity.shape
        assert co.shape == overage_quantity.shape

        underage=cu*underage_quantity
        overage=co*overage_quantity

        loss=underage+overage

        return loss
