import numpy as np
import pandas as pd

from autogluon.timeseries.models.local.abstract_local_model import AbstractLocalModel
from autogluon.timeseries.utils.datetime import get_time_features_for_frequency
from autogluon.timeseries.utils.forecast import get_forecast_horizon_index_single_time_series


class NPTSModel(AbstractLocalModel):
    """Non-Parametric Time Series Forecaster.

    This model is especially well suited for forecasting sparse or intermittent time series with many zero values.

    Based on `gluonts.model.npts.NPTSPredictor <https://ts.gluon.ai/stable/api/gluonts/gluonts.model.npts.html>`_.
    See GluonTS documentation for more information about the model.

    Other Parameters
    ----------------
    kernel_type : {"exponential", "uniform"}, default = "exponential"
        Kernel used by the model.
    exp_kernel_weights : float, default = 1.0
        Scaling factor used in the exponential kernel.
    use_seasonal_model : bool, default = True
        Whether to use the seasonal variant of the model. Automatically disabled if no time features are
        available for the frequency of the data.
    num_samples : int, default = 100
        Number of samples generated by the forecast.
    num_default_time_features : int, default = 1
        Number of time features used by seasonal model.
    n_jobs : int or float, default = joblib.cpu_count(only_physical_cores=True)
        Number of CPU cores used to fit the models in parallel.
        When set to a float between 0.0 and 1.0, that fraction of available CPU cores is used.
        When set to a positive integer, that many cores are used.
        When set to -1, all CPU cores are used.
    max_ts_length : Optional[int], default = 2500
        If not None, only the last ``max_ts_length`` time steps of each time series will be used to train the model.
        This significantly speeds up fitting and usually leads to no change in accuracy.
    """

    ag_priority = 80
    allowed_local_model_args = [
        "kernel_type",
        "exp_kernel_weights",
        "use_seasonal_model",
        "num_samples",
        "num_default_time_features",
        "seasonal_period",
    ]

    def _update_local_model_args(self, local_model_args: dict) -> dict:
        """Fill in NPTS-specific defaults on top of the arguments prepared by the parent class.

        Sets ``num_samples=100`` and ``num_default_time_features=1`` unless they were provided.
        """
        local_model_args = super()._update_local_model_args(local_model_args)
        local_model_args.setdefault("num_samples", 100)
        local_model_args.setdefault("num_default_time_features", 1)
        return local_model_args

    def _predict_with_local_model(
        self,
        time_series: pd.Series,
        local_model_args: dict,
    ) -> pd.DataFrame:
        """Forecast a single time series with GluonTS ``NPTSPredictor``.

        Parameters
        ----------
        time_series : pd.Series
            Past values of a single time series with a datetime index.
        local_model_args : dict
            Model arguments as prepared by ``_update_local_model_args``. This dict is not modified.

        Returns
        -------
        pd.DataFrame
            Forecast with a ``"mean"`` column and one column per entry of ``self.quantile_levels``
            (column name ``str(q)``), built from the samples drawn by NPTS.
        """
        from gluonts.model.npts import NPTSPredictor

        # Work on a copy so that the pops / overrides below never leak into the caller's arguments
        local_model_args = dict(local_model_args)
        # NPTSPredictor is not given a seasonal period; the value is discarded
        local_model_args.pop("seasonal_period", None)
        num_samples = local_model_args.pop("num_samples")
        num_default_time_features = local_model_args.pop("num_default_time_features")

        ts = time_series.copy(deep=False)
        # We generate time features outside NPTSPredictor since GluonTS does not support all pandas frequencies
        future_index = get_forecast_horizon_index_single_time_series(
            ts.index, freq=self.freq, prediction_length=self.prediction_length
        )
        past_and_future_index = ts.index.union(future_index)
        time_features = get_time_features_for_frequency(self.freq)[:num_default_time_features]
        if len(time_features) == 0:
            # No seasonal features exist for this frequency -> fall back to the non-seasonal variant
            local_model_args["use_seasonal_model"] = False
            custom_features = None
        else:
            # Shape (num_features, len(past) + prediction_length): one row per time feature
            custom_features = np.vstack([feat(past_and_future_index) for feat in time_features])

        # We pass dummy frequency to GluonTS because it does not support all pandas frequencies.
        # Lowercase "s" (seconds) is used because the uppercase alias "S" is deprecated since pandas 2.2;
        # older pandas versions accept "s" as well.
        dummy_freq = "s"
        ts.index = ts.index.to_period(freq=dummy_freq)
        predictor = NPTSPredictor(
            prediction_length=self.prediction_length,
            use_default_time_features=False,
            **local_model_args,
        )
        forecast = predictor.predict_time_series(ts, num_samples=num_samples, custom_features=custom_features)
        forecast_dict = {"mean": forecast.mean}
        for q in self.quantile_levels:
            forecast_dict[str(q)] = forecast.quantile(q)
        return pd.DataFrame(forecast_dict)

    def _more_tags(self) -> dict:
        """Declare that this model can handle missing values in the input data."""
        return {"allow_nan": True}
