# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/agents/benchmark_agents/04_Gradient_boosting_NV_agents.ipynb.

# %% auto 0
# Public names exported by this module.
__all__ = ['FakePolicy', 'XGBAgent', 'XGBMetaAgent']

# %% ../../../nbs/agents/benchmark_agents/04_Gradient_boosting_NV_agents.ipynb 4
# General libraries:
import numpy as np
from scipy.stats import norm
from tqdm import trange, tqdm
from time import sleep

import xgboost as xgb

from ..processors.processors import GetTimeSeriesAndStaticFeatures

# Mushroom libraries
from mushroom_rl.core import Agent

from timeit import default_timer as timer

from .nv_agents_ERM import NewsvendorDataMeta
from torch.utils.data import DataLoader

# %% ../../../nbs/agents/benchmark_agents/04_Gradient_boosting_NV_agents.ipynb 6
class FakePolicy:
    """Placeholder policy for agents that are trained directly via ``fit``.

    The agents in this module assign the class itself (not an instance) to
    ``self.policy``; presumably this is so code expecting a policy object can
    call ``policy.reset()``. ``reset`` is a static no-op, so it works whether
    it is called on the class or on an instance.
    """

    @staticmethod
    def reset():
        """Do nothing; there is no policy state to reset."""
        pass

class XGBAgent(Agent):
    """Benchmark agent wrapping an XGBoost quantile regressor.

    The regressor is trained with XGBoost's ``reg:quantileerror`` objective at
    quantile level ``cu / (cu + co)``. ``draw_action`` returns the regressor's
    prediction for the given features.

    Training is done in one shot through ``fit`` (``train_mode == "direct"``).
    ``Agent.__init__`` is not called, so only the attributes set in
    ``__init__`` below exist on the instance.
    """

    def __init__(self,

                 ### Pinball loss params
                 cu=None,
                 co=None,

                 ### XGB params
                 eta=0.3,
                 gamma=0,
                 max_depth=6,
                 min_child_weight=1,
                 max_delta_step=0,
                 subsample=1,
                 sampling_method="uniform",
                 colsample_bytree=1,
                 colsample_bylevel=1,
                 colsample_bynode=1,
                 lambda_=1,
                 alpha=0,
                 tree_method="auto",
                 scale_pos_weight=1,
                 # updater will always use default
                 refresh_leaf=1,
                 # process type will always use default
                 grow_policy="depthwise",
                 max_leaves=0,
                 max_bin=256,
                 num_parallel_tree=1,
                 multi_strategy="one_output_per_tree",
                 max_cached_hist_node=65536,

                 ### General params
                 nthread=1,
                 device="cpu",
                 agent_name="XGB_quantile",
                 ):
        """Create the agent and its unfitted ``xgb.XGBRegressor``.

        Args:
            cu, co: Pinball-loss cost weights (presumably per-unit underage and
                overage costs). Both are required despite the ``None``
                defaults; the target quantile is ``cu / (cu + co)``.
            eta ... max_cached_hist_node: Forwarded to ``xgb.XGBRegressor``.
                ``lambda_`` is forwarded as ``reg_lambda`` (``lambda`` is a
                Python keyword, and XGBoost does not recognise ``lambda_``).
            nthread, device: Forwarded to ``xgb.XGBRegressor``.
            agent_name: Stored as ``self.name``.

        Raises:
            TypeError: If ``cu`` or ``co`` is ``None``.
        """
        if cu is None or co is None:
            raise TypeError(
                "XGBAgent requires both `cu` and `co` to compute the quantile level."
            )

        self.cu = cu
        self.co = co

        self.model = xgb.XGBRegressor(
            objective="reg:quantileerror",
            quantile_alpha=cu / (cu + co),

            eta=eta,
            gamma=gamma,
            max_depth=max_depth,
            min_child_weight=min_child_weight,
            max_delta_step=max_delta_step,
            subsample=subsample,
            sampling_method=sampling_method,
            colsample_bytree=colsample_bytree,
            colsample_bylevel=colsample_bylevel,
            colsample_bynode=colsample_bynode,
            # `lambda_` is not a valid XGBoost parameter name (it would be
            # ignored with a warning); the L2 weight is `reg_lambda`.
            reg_lambda=lambda_,
            alpha=alpha,
            tree_method=tree_method,
            scale_pos_weight=scale_pos_weight,
            # updater will always use default
            refresh_leaf=refresh_leaf,
            # process type will always use default
            grow_policy=grow_policy,
            max_leaves=max_leaves,
            max_bin=max_bin,
            num_parallel_tree=num_parallel_tree,
            multi_strategy=multi_strategy,
            max_cached_hist_node=max_cached_hist_node,

            nthread=nthread,
            device=device
        )

        self.train_directly = True
        self.train_mode = "direct"
        # The class itself is stored, not an instance.
        self.policy = FakePolicy

        self._postprocessors = list()
        self._preprocessors = list()

        self.name = agent_name
        self.fitted = False

    def _get_fitted_model(self, X, y, mask=None):
        """Fit ``self.model`` on ``(X, y)`` and record the feature count.

        ``mask`` is accepted but unused.
        """
        self.model.fit(X, y)
        # Used by draw_action to reshape 1-D inputs into rows.
        self.n_features_ = X.shape[1]

    def fit(self, features, demand, mask=None):
        """Train the regressor to predict ``demand`` from ``features``.

        Args:
            features: Feature matrix; ``features.shape[1]`` is taken as the
                number of features.
            demand: Targets; a 1-D array is reshaped to a single column.
            mask: Unused; forwarded to ``_get_fitted_model`` for interface
                compatibility.
        """
        X = features
        y = demand

        if y.ndim == 1:
            y = np.reshape(y, (-1, 1))

        self._get_fitted_model(X, y, mask)

        self.fitted = True

        return

    def draw_action(self, X):
        """Return the model prediction for ``X``.

        A 1-D ``X`` is reshaped to rows of ``n_features_`` columns. Before
        ``fit`` has been called, a single uniform random value in ``[0, 1)``
        is returned instead.
        """
        if self.fitted:

            if X.ndim == 1:
                X = np.reshape(X, (-1, self.n_features_))

            pred = self.model.predict(X)

        else:
            pred = np.random.rand(1)

        return pred


# %% ../../../nbs/agents/benchmark_agents/04_Gradient_boosting_NV_agents.ipynb 7
class XGBMetaAgent(Agent):
    """XGBoost quantile-regression agent trained on samples built by ``NewsvendorDataMeta``.

    The regressor uses XGBoost's ``reg:quantileerror`` objective at quantile
    level ``cu / (cu + co)``. Training samples are produced by
    ``NewsvendorDataMeta`` from time-dependent and time-independent features,
    flattened to 2-D, and fitted in one shot (``train_mode == "direct"``).
    ``Agent.__init__`` is not called.
    """

    def __init__(self,
                 feature_map,
                 input_size,
                 lag_features,

                 ### Pinball loss params
                 cu,
                 co,
                 num_time_series_features,
                 lag_window,
                 agent_name="XGBMeta_quantile",

                 ### XGB params
                 eta=0.3,
                 gamma=0,
                 max_depth=6,
                 min_child_weight=1,
                 max_delta_step=0,
                 subsample=1,
                 sampling_method="uniform",
                 colsample_bytree=1,
                 colsample_bylevel=1,
                 colsample_bynode=1,
                 lambda_=1,
                 alpha=0,
                 tree_method="auto",
                 scale_pos_weight=1,
                 # updater will always use default
                 refresh_leaf=1,
                 # process type will always use default
                 grow_policy="depthwise",
                 max_leaves=0,
                 max_bin=256,
                 num_parallel_tree=1,
                 multi_strategy="one_output_per_tree",
                 max_cached_hist_node=65536,

                 ### General params
                 nthread=1,
                 device="cpu",
                 ):
        """Create the agent and its unfitted ``xgb.XGBRegressor``.

        Args:
            feature_map, input_size, lag_features, lag_window: Stored and
                passed to ``NewsvendorDataMeta`` in ``fit``.
            cu, co: Pinball-loss cost weights (presumably underage/overage
                costs); the target quantile is ``cu / (cu + co)``.
            num_time_series_features: Stored as an attribute; not used
                elsewhere in this class.
            agent_name: Stored as ``self.name``.
            eta ... max_cached_hist_node, nthread, device: Forwarded to
                ``xgb.XGBRegressor``. ``lambda_`` is forwarded as ``reg_lambda``.
        """
        self.cu = cu
        self.co = co

        self.feature_map = feature_map
        self.num_time_series_features = num_time_series_features
        self.lag_window = lag_window
        self.lag_features = lag_features
        self.input_size = input_size

        self.model = xgb.XGBRegressor(
            objective="reg:quantileerror",
            quantile_alpha=cu / (cu + co),

            eta=eta,
            gamma=gamma,
            max_depth=max_depth,
            min_child_weight=min_child_weight,
            max_delta_step=max_delta_step,
            subsample=subsample,
            sampling_method=sampling_method,
            colsample_bytree=colsample_bytree,
            colsample_bylevel=colsample_bylevel,
            colsample_bynode=colsample_bynode,
            # `lambda_` is not a valid XGBoost parameter name (it would be
            # ignored with a warning); the L2 weight is `reg_lambda`.
            reg_lambda=lambda_,
            alpha=alpha,
            tree_method=tree_method,
            scale_pos_weight=scale_pos_weight,
            # updater will always use default
            refresh_leaf=refresh_leaf,
            # process type will always use default
            grow_policy=grow_policy,
            max_leaves=max_leaves,
            max_bin=max_bin,
            num_parallel_tree=num_parallel_tree,
            multi_strategy=multi_strategy,
            max_cached_hist_node=max_cached_hist_node,

            nthread=nthread,
            device=device
        )

        self.train_directly = True
        self.train_mode = "direct"
        # The class itself is stored, not an instance.
        self.policy = FakePolicy

        self._postprocessors = list()
        self._preprocessors = list()

        self.name = agent_name
        self.fitted = False

    def _get_fitted_model(self, X, y, mask=None):
        """Fit ``self.model`` on ``(X, y)`` and record the feature count.

        ``mask`` is accepted but unused.
        """
        self.model.fit(X, y)
        self.n_features_ = X.shape[1]

    def fit(self, features, demand, mask=None, sample_fraction=0.01):
        """Build the training set and fit the regressor on a random subsample.

        Args:
            features: Indexable pair; ``features[0]`` holds the time-dependent
                features and ``features[1]`` the time-independent ones.
            demand: Targets passed to ``NewsvendorDataMeta``.
            mask: Unused; kept for interface compatibility.
            sample_fraction: Fraction of the dataset used for training. Only
                the first shuffled batch is used, so this sets the subsample
                size (at least one sample). Defaults to the original 1%.

        Raises:
            ValueError: If the constructed dataset is empty.
        """
        print("separating features")

        features_time_dependent = features[0]
        features_timeless = features[1]

        print("setting up dataset")

        dataset_train = NewsvendorDataMeta(features_time_dependent, features_timeless, demand, self.feature_map, self.lag_window, self.input_size, self.lag_features)

        n_samples = len(dataset_train)
        if n_samples == 0:
            raise ValueError("XGBMetaAgent.fit received an empty training dataset.")

        # Only one shuffled batch is drawn below, so batch_size is effectively
        # the random subsample size. Clamp to >= 1: for datasets smaller than
        # 1 / sample_fraction, int(...) would be 0, which DataLoader rejects.
        batch_size = max(1, int(n_samples * sample_fraction))
        print("batch_size:", batch_size)
        train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=1)

        print("loading data")
        start_time = timer()

        # Take just the first shuffled batch. Presumably batch[0] are inputs
        # and batch[1] targets (batch[2], unused, appears to be a product id).
        batch = next(iter(train_loader))
        X = batch[0].numpy()
        # Flatten every non-sample dimension into the feature axis.
        X = np.reshape(X, (X.shape[0], -1))
        y = batch[1].numpy()

        end_time = timer()

        print("loaded data in seconds:", end_time-start_time)

        print(X.shape)
        print(y.shape)

        # print size of X and y in GB:
        print("size X in GB:", X.nbytes / 1e9)
        print("size y in GB:", y.nbytes / 1e9)

        if y.ndim == 1:
            y = np.reshape(y, (-1, 1))

        start_time = timer()
        self._get_fitted_model(X, y, mask)
        end_time = timer()
        print("fitted model in seconds:", end_time-start_time)

        self.fitted = True

        return

    def draw_action(self, X):
        """Return the model prediction for ``X``.

        A 1-D ``X`` is reshaped to rows of ``n_features_`` columns; a 3-D ``X``
        is flattened to ``(X.shape[0], -1)`` to match the training layout.
        Before ``fit`` has been called, a single uniform random value in
        ``[0, 1)`` is returned instead.
        """
        if self.fitted:

            if X.ndim == 1:
                X = np.reshape(X, (-1, self.n_features_))

            if X.ndim == 3:
                X = np.reshape(X, (X.shape[0], -1))

            pred = self.model.predict(X)

        else:
            pred = np.random.rand(1)

        return pred
