"""Trialmap of DAQ-HDF files

Electrophysiological recording sessions typically consist of trials. In
different trials, experimental conditions repeat or have other important
similarities. Trialmap information in DAQ-HDF files characterizes each
trial and makes it possible to determine which parts of signal or spike
data correspond to a particular trial.

The information contained in DAQ-HDF trialmap is:

-   Trial numbers as generated by the stimulation PC;
-   Stimulus numbers (encoded type of stimulation);
-   Outcome numbers (encoded behavioral data, such as successful or
    unsuccessful performance of the experiment subject in each trial);
-   Timestamps for the start and the end of each trial.

Trialmap is a dataset in the root group of a DAQ-HDF file:

`TRIALMAP` (`struct` array\[T\]):

| Offset (bytes) | Name      | Type    |
|----------------|-----------|---------|
| 0              | TrialNo   | `int32` |
| 4              | StimNo    | `int32` |
| 8              | Outcome   | `int32` |
| 12             | StartTime | `int64` |
| 20             | EndTime   | `int64` |

Here, `T` is the total number of trials in the file.

- **TrialNo** is a sequence number generated by the stimulation program and
then transferred to the recording program and stored in the data
acquisition file. This number can be used to combine trialmap
information with information about the trials from sources other than
the DAQ-HDF file itself. If the DAQ-HDF file and the trialmap were
created in such a way that these considerations do not apply, this
structure member may be ignored altogether or, better, filled with an
ascending sequence of numbers.

- **StimNo** is the so-called Stimulus Number. Trials with the same stimulus
number can usually be grouped together for analysis. StimNo thus
contains encoded information about the type of trial and, possibly,
some other conditions. Note: the Stimulus Number in VStim has been renamed
to Trial Type Number. For backwards compatibility, DH5 files still use the
field name StimNo.

- **Outcome** – behavioral data. The outcome code specifies the type of
behavior observed in the experimental subject. Typically, the Outcome
member specifies whether the subject performed its task successfully
during a trial and, if not, which particular kind of error it made.

- **StartTime** and **EndTime** are timestamps, in nanoseconds, for the beginning
and end of each trial. All timestamps throughout a DAQ-HDF file share
the same base value, so timestamps from `CONT` and `SPIKE` blocks and from
the `TRIALMAP` are comparable with each other. A typical task is to use
the timestamps from the `TRIALMAP` to locate the corresponding piece of
signal within `CONT` or `SPIKE` blocks.

"""

from enum import IntEnum
import logging
import h5py
from dh5io.errors import DH5Error
import numpy
from dhspec.trialmap import TRIALMAP_DATASET_DTYPE, TRIALMAP_DATASET_NAME

# Module-level logger, named after this module so its records can be
# filtered/configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)


class TrialOutcome(IntEnum):
    """Integer codes stored in the ``Outcome`` field of the TRIALMAP.

    Constructing a member from an integer without a matching code
    (``TrialOutcome(42)``) raises ``ValueError``.
    """

    NotStarted = 0
    Hit = 1
    WrongResponse = 2
    EarlyHit = 3
    EarlyWrongResponse = 4
    Early = 5
    Late = 6
    EyeErr = 7
    InexpectedStartSignal = 8
    WrongStartSignal = 9
    # Correctly spelled alias of ``InexpectedStartSignal``. The original
    # (misspelled) name remains the canonical member, so ``TrialOutcome(8).name``,
    # ``len(TrialOutcome)`` and iteration over the enum are unchanged.
    UnexpectedStartSignal = 8


def add_trialmap_to_file(
    file: h5py.File, trialmap: numpy.recarray, replace=True
) -> None:
    """Write *trialmap* as the TRIALMAP dataset of *file*.

    Args:
        file: HDF5 file the dataset is created in.
        trialmap: Trial records; their dtype must equal
            ``TRIALMAP_DATASET_DTYPE``.
        replace: When True, an existing TRIALMAP dataset is deleted and
            written anew; when False, an existing one is an error.

    Raises:
        DH5Error: If the dtype of *trialmap* does not match, or if a
            TRIALMAP dataset already exists and *replace* is False.
    """
    if trialmap.dtype != TRIALMAP_DATASET_DTYPE:
        raise DH5Error(
            f"Invalid trialmap dtype: {trialmap.dtype}. Expected {TRIALMAP_DATASET_DTYPE}"
        )

    exists = TRIALMAP_DATASET_NAME in file
    if exists and not replace:
        raise DH5Error(f"TRIALMAP dataset already exists in file {file.filename}")
    if exists:
        del file[TRIALMAP_DATASET_NAME]
        logger.debug(f"Replacing existing TRIALMAP dataset in file {file.filename}")

    file.create_dataset(TRIALMAP_DATASET_NAME, data=trialmap)


def get_trialmap_from_file(file: h5py.File) -> numpy.recarray | None:
    """Read the TRIALMAP dataset of *file*.

    Returns:
        The trial records as a ``numpy.recarray`` with dtype
        ``TRIALMAP_DATASET_DTYPE``, or ``None`` when *file* has no
        TRIALMAP entry.
    """
    dataset = file.get(TRIALMAP_DATASET_NAME)
    if dataset is not None:
        structured = numpy.array(dataset, dtype=TRIALMAP_DATASET_DTYPE)
        return numpy.rec.array(structured)
    return None


def validate_trialmap(file: h5py.File):
    """Validate the TRIALMAP dataset of *file*, if there is one.

    A missing TRIALMAP is not an error: it only produces a warning in the
    log. An existing one is checked by ``validate_trialmap_dataset``,
    which raises on failure.
    """
    if TRIALMAP_DATASET_NAME in file:
        validate_trialmap_dataset(file[TRIALMAP_DATASET_NAME])
    else:
        logger.warning(f"TRIALMAP dataset not found in file {file.filename}")


def validate_trialmap_dataset(trialmap: h5py.Dataset) -> None:
    """Check that *trialmap* is an HDF5 dataset with the TRIALMAP dtype.

    The expected dtype is ``TRIALMAP_DATASET_DTYPE``, a compound type with
    fields 'TrialNo', 'StimNo', 'Outcome', 'StartTime', 'EndTime'.

    Raises:
        DH5Error: If *trialmap* is not an ``h5py.Dataset`` (e.g. a group
            stored under the TRIALMAP name) or its dtype does not match.
    """
    # Check the object type separately and first: non-dataset objects such
    # as h5py.Group have no ``dtype`` attribute, so building the dtype error
    # message for them would raise AttributeError instead of DH5Error.
    if not isinstance(trialmap, h5py.Dataset):
        raise DH5Error(
            f"TRIALMAP is not a dataset but {type(trialmap).__name__}"
        )
    if trialmap.dtype != TRIALMAP_DATASET_DTYPE:
        raise DH5Error(
            f"TRIALMAP dataset is not a named data type with fields 'TrialNo', 'StimNo', 'Outcome', 'StartTime', 'EndTime': {trialmap.dtype}"
        )


class Trialmap:
    """Trialmap class for DAQ-HDF files.

    Wraps the TRIALMAP trial records and exposes their fields under
    descriptive names.

    Attributes:
        recarray (numpy.recarray): The trial records; dtype equals
            ``TRIALMAP_DATASET_DTYPE``.

    Properties:
        trial_numbers (numpy.ndarray): Trial numbers (``TrialNo``).
        trial_type_numbers (numpy.ndarray): Trial type numbers (``StimNo``).
        trial_outcomes_integer (numpy.ndarray): Trial outcomes as integers (``Outcome``).
        trial_outcomes_as_enum (list[TrialOutcome]): Trial outcomes as TrialOutcome members.
        start_time_nanoseconds (numpy.ndarray): Start time of each trial in nanoseconds.
        start_time_float_seconds (numpy.ndarray): Start time of each trial in seconds as float.
        end_time_nanoseconds (numpy.ndarray): End time of each trial in nanoseconds.
        end_time_float_seconds (numpy.ndarray): End time of each trial in seconds as float.
    """

    recarray: numpy.recarray

    def __init__(self, trialmap: numpy.ndarray):
        """Wrap *trialmap*.

        Args:
            trialmap: Trial records with dtype ``TRIALMAP_DATASET_DTYPE``.
                A ``numpy.recarray`` is stored as is. A plain structured
                ``numpy.ndarray`` is stored as a zero-copy recarray view,
                so the attribute-style field access used by the properties
                works for it too.

        Raises:
            DH5Error: If the dtype of *trialmap* does not match.
        """
        if trialmap.dtype != TRIALMAP_DATASET_DTYPE:
            raise DH5Error(
                f"Invalid trialmap dtype: {trialmap.dtype}. Expected {TRIALMAP_DATASET_DTYPE}"
            )
        if isinstance(trialmap, numpy.ndarray) and not isinstance(
            trialmap, numpy.recarray
        ):
            # Plain structured arrays only support trialmap["StimNo"], not
            # trialmap.StimNo; a recarray view shares the same memory.
            trialmap = trialmap.view(numpy.recarray)
        self.recarray = trialmap

    def __len__(self):
        """Number of trials."""
        return len(self.recarray)

    def __str__(self):
        return f"""Trialmap with {len(self)} trials
    TrialNo, TrialTypeNo, Outcome, StartTimeNs, EndTimeNs
{self.recarray}"""

    @property
    def trial_type_numbers(self) -> numpy.ndarray:
        """Trial type numbers (stored in the ``StimNo`` field)."""
        return self.recarray.StimNo

    @property
    def trial_numbers(self) -> numpy.ndarray:
        """Trial numbers (stored in the ``TrialNo`` field)."""
        return self.recarray.TrialNo

    @property
    def trial_outcomes_integer(self) -> numpy.ndarray:
        """Trial outcomes as integers (stored in the ``Outcome`` field)."""
        return self.recarray.Outcome

    @property
    def trial_outcomes_as_enum(self) -> list[TrialOutcome]:
        """Trial outcomes converted to TrialOutcome members.

        Raises:
            ValueError: If an outcome code has no matching TrialOutcome member.
        """
        return [TrialOutcome(outcome) for outcome in self.recarray.Outcome]

    @property
    def start_time_nanoseconds(self) -> numpy.ndarray:
        """Start time of each trial in nanoseconds."""
        return self.recarray.StartTime

    @property
    def start_time_float_seconds(self) -> numpy.ndarray:
        """Start time of each trial in seconds as float64.

        float64 has a 53-bit mantissa, so very large nanosecond timestamps
        are rounded during the conversion.
        """
        return self.recarray.StartTime.astype(numpy.float64) / 1e9

    @property
    def end_time_nanoseconds(self) -> numpy.ndarray:
        """End time of each trial in nanoseconds."""
        return self.recarray.EndTime

    @property
    def end_time_float_seconds(self) -> numpy.ndarray:
        """End time of each trial in seconds as float64.

        float64 has a 53-bit mantissa, so very large nanosecond timestamps
        are rounded during the conversion.
        """
        return self.recarray.EndTime.astype(numpy.float64) / 1e9
