from typing import NamedTuple, Optional, cast

import torch
from torch.utils.data import Dataset as TorchDataset

from replay.data.nn import (
    MutableTensorMap,
    SequentialDataset,
    TensorMap,
    TorchSequentialDataset,
    TorchSequentialValidationDataset,
)


class SasRecTrainingBatch(NamedTuple):
    """
    Batch of data for training.
    Generated by `SasRecTrainingDataset`.

    Field order is part of the tuple interface (positional unpacking), so it must stay stable.
    """

    # Query identifier, passed through unchanged from the wrapped `TorchSequentialDataset`.
    query_id: torch.LongTensor
    # Padding mask aligned with `features`: the inner mask without its last `sequence_shift` positions.
    # NOTE(review): which boolean value marks padding is decided by the inner dataset - confirm there.
    padding_mask: torch.BoolTensor
    # Feature tensors keyed by feature name; sequential features have their
    # last `sequence_shift` positions dropped, other features are kept as is.
    features: TensorMap
    # Label feature without its first `sequence_shift` positions.
    labels: torch.LongTensor
    # Padding mask aligned with `labels`: the inner mask without its first `sequence_shift` positions.
    labels_padding_mask: torch.BoolTensor


class SasRecTrainingDataset(TorchDataset):
    """
    Dataset that generates samples to train SasRec-like model.

    Wraps a `TorchSequentialDataset` created with ``max_sequence_length + sequence_shift``
    and splits every sample into model inputs (sequences without their last
    ``sequence_shift`` positions) and labels (the label feature without its first
    ``sequence_shift`` positions).
    """

    def __init__(
        self,
        sequential: SequentialDataset,
        max_sequence_length: int,
        sequence_shift: int = 1,
        sliding_window_step: Optional[int] = None,
        padding_value: Optional[int] = None,
        label_feature_name: Optional[str] = None,
    ) -> None:
        """
        :param sequential: Sequential dataset with training data.
        :param max_sequence_length: Max length of sequence.
        :param sequence_shift: Shift of sequence to predict. Must be at least ``1``.
            Default: ``1``.
        :param sliding_window_step: A sliding window step.
            If not ``None`` provides iteration over sequences with window.
            Default: ``None``.
        :param padding_value: Value for padding a sequence to match the `max_sequence_length`.
            Forwarded unchanged to `TorchSequentialDataset`.
            Default: ``None`` (presumably the inner dataset then applies its own default - confirm).
        :param label_feature_name: Name of label feature in provided dataset.
            If ``None`` set an item_id_feature name from sequential dataset.
            Default: ``None``.
        :raises ValueError: If ``sequence_shift`` is less than ``1``, if the label feature is
            missing from the schema, is not categorical or is not sequential, or if no label
            feature name can be resolved at all.
        """
        super().__init__()
        # A non-positive shift would silently break slicing in `__getitem__`:
        # `x[: -0]` is `x[:0]`, i.e. an empty tensor.
        if sequence_shift < 1:
            msg = "Sequence shift must be a positive integer"
            raise ValueError(msg)

        schema = sequential.schema
        if label_feature_name:
            if label_feature_name not in schema:
                msg = "Label feature name not found in provided schema"
                raise ValueError(msg)

            if not schema[label_feature_name].is_cat:
                msg = "Label feature must be categorical"
                raise ValueError(msg)

            if not schema[label_feature_name].is_seq:
                msg = "Label feature must be sequential"
                raise ValueError(msg)

        # Resolve the label name once and fail fast here instead of relying on
        # a per-sample `assert`, which is stripped when Python runs with `-O`.
        resolved_label_name = label_feature_name or schema.item_id_feature_name
        if not resolved_label_name:
            msg = "Label feature name is not provided and schema has no item id feature"
            raise ValueError(msg)

        self._sequence_shift = sequence_shift
        # The inner dataset yields `sequence_shift` extra positions so that inputs
        # and labels can both be cut from the same sample.
        self._max_sequence_length = max_sequence_length + sequence_shift
        self._label_feature_name: str = resolved_label_name
        self._schema = schema

        self._inner = TorchSequentialDataset(
            sequential=sequential,
            max_sequence_length=self._max_sequence_length,
            sliding_window_step=sliding_window_step,
            padding_value=padding_value,
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped `TorchSequentialDataset`."""
        return len(self._inner)

    def __getitem__(self, index: int) -> SasRecTrainingBatch:
        """
        Build a training sample from the inner sample at ``index``.

        :param index: Index of the sample in the wrapped dataset.
        :returns: Sample with shifted inputs, labels and their padding masks.
        """
        query_id, padding_mask, features = self._inner[index]
        shift = self._sequence_shift

        # Labels are the label feature moved `shift` positions ahead of the inputs.
        labels = features[self._label_feature_name][shift:]
        labels_padding_mask = padding_mask[shift:]

        # Only sequential features are trimmed; other features pass through untouched.
        output_features: MutableTensorMap = {
            name: features[name][:-shift] if self._schema[name].is_seq else features[name]
            for name in self._schema
        }
        output_features_padding_mask = padding_mask[:-shift]

        return SasRecTrainingBatch(
            query_id=query_id,
            features=output_features,
            padding_mask=cast(torch.BoolTensor, output_features_padding_mask),
            labels=cast(torch.LongTensor, labels),
            labels_padding_mask=cast(torch.BoolTensor, labels_padding_mask),
        )


class SasRecPredictionBatch(NamedTuple):
    """
    Batch of data for model inference.
    Generated by `SasRecPredictionDataset`.

    All fields are taken unchanged from a sample of the wrapped `TorchSequentialDataset`.
    Field order is part of the tuple interface (positional unpacking), so it must stay stable.
    """

    # Query identifier of the sample.
    query_id: torch.LongTensor
    # Padding mask of the sample.
    # NOTE(review): which boolean value marks padding is decided by the inner dataset - confirm there.
    padding_mask: torch.BoolTensor
    # Feature tensors keyed by feature name.
    features: TensorMap


class SasRecPredictionDataset(TorchDataset):
    """
    Dataset that generates samples to infer SasRec-like model.

    Thin wrapper that repacks every sample of a `TorchSequentialDataset`
    into a `SasRecPredictionBatch`.
    """

    def __init__(
        self,
        sequential: SequentialDataset,
        max_sequence_length: int,
        padding_value: Optional[int] = None,
    ) -> None:
        """
        :param sequential: Sequential dataset with data to make predictions at.
        :param max_sequence_length: Max length of sequence.
        :param padding_value: Value for padding a sequence to match the `max_sequence_length`.
            Forwarded unchanged to `TorchSequentialDataset`.
            Default: ``None`` (presumably the inner dataset then applies its own default - confirm).
        """
        # Initialise the base class the same way `SasRecTrainingDataset` does.
        super().__init__()
        self._inner = TorchSequentialDataset(
            sequential=sequential,
            max_sequence_length=max_sequence_length,
            padding_value=padding_value,
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped `TorchSequentialDataset`."""
        return len(self._inner)

    def __getitem__(self, index: int) -> SasRecPredictionBatch:
        """
        Return the inner sample at ``index`` as a `SasRecPredictionBatch`.

        :param index: Index of the sample in the wrapped dataset.
        :returns: Sample with query id, padding mask and features taken unchanged
            from the inner dataset.
        """
        query_id, padding_mask, features = self._inner[index]
        return SasRecPredictionBatch(
            query_id=query_id,
            padding_mask=padding_mask,
            features=features,
        )


class SasRecValidationBatch(NamedTuple):
    """
    Batch of data for validation.
    Generated by `SasRecValidationDataset`.

    All fields are taken unchanged from a sample of the wrapped `TorchSequentialValidationDataset`.
    Field order is part of the tuple interface (positional unpacking), so it must stay stable.
    """

    # Query identifier of the sample.
    query_id: torch.LongTensor
    # Padding mask of the sample.
    # NOTE(review): which boolean value marks padding is decided by the inner dataset - confirm there.
    padding_mask: torch.BoolTensor
    # Feature tensors keyed by feature name.
    features: TensorMap
    # Presumably label values taken from the `ground_truth` dataset given to
    # `SasRecValidationDataset` - confirm against `TorchSequentialValidationDataset`.
    ground_truth: torch.LongTensor
    # Presumably label values taken from the `train` dataset given to
    # `SasRecValidationDataset` - confirm against `TorchSequentialValidationDataset`.
    train: torch.LongTensor


class SasRecValidationDataset(TorchDataset):
    """
    Dataset that generates samples to infer and validate SasRec-like model.

    Thin wrapper that repacks every sample of a `TorchSequentialValidationDataset`
    into a `SasRecValidationBatch`.
    """

    def __init__(
        self,
        sequential: SequentialDataset,
        ground_truth: SequentialDataset,
        train: SequentialDataset,
        max_sequence_length: int,
        padding_value: Optional[int] = None,
        label_feature_name: Optional[str] = None,
    ) -> None:
        """
        :param sequential: Sequential dataset with data to make predictions at.
        :param ground_truth: Sequential dataset with ground truth predictions.
        :param train: Sequential dataset with training data.
        :param max_sequence_length: Max length of sequence.
        :param padding_value: Value for padding a sequence to match the `max_sequence_length`.
            Forwarded unchanged to `TorchSequentialValidationDataset`.
            Default: ``None`` (presumably the inner dataset then applies its own default - confirm).
        :param label_feature_name: Name of label feature in provided dataset.
            Forwarded unchanged to `TorchSequentialValidationDataset`.
            If ``None`` set an item_id_feature name from sequential dataset.
            Default: ``None``.
        """
        # Initialise the base class the same way `SasRecTrainingDataset` does.
        super().__init__()
        self._inner = TorchSequentialValidationDataset(
            sequential=sequential,
            ground_truth=ground_truth,
            train=train,
            max_sequence_length=max_sequence_length,
            padding_value=padding_value,
            label_feature_name=label_feature_name,
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped `TorchSequentialValidationDataset`."""
        return len(self._inner)

    def __getitem__(self, index: int) -> SasRecValidationBatch:
        """
        Return the inner sample at ``index`` as a `SasRecValidationBatch`.

        :param index: Index of the sample in the wrapped dataset.
        :returns: Sample with query id, padding mask, features, ground truth and train
            tensors taken unchanged from the inner dataset.
        """
        query_id, padding_mask, features, ground_truth, train = self._inner[index]
        return SasRecValidationBatch(
            query_id=query_id,
            padding_mask=padding_mask,
            features=features,
            ground_truth=ground_truth,
            train=train,
        )
