"""Data logging, management, and reward calculation."""

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Optional

from deprecated import deprecated

from bsk_rl.utils.functional import Resetable

if TYPE_CHECKING:  # pragma: no cover
    from bsk_rl.sats import Satellite
    from bsk_rl.scene import Scenario


# Module-level logger; GlobalReward.reward writes a per-step reward summary to it.
logger = logging.getLogger(__name__)

# Type of whatever DataStore.get_log_state returns. Left as Any because each
# DataStore subclass decides what its log state contains.
LogStateType = Any


class Data(ABC):
    """Abstract unit of satellite data.

    A subclass only has to provide ``__add__``, which merges two units of data
    into one. Data stores use it to fold freshly generated data, and data
    communicated from other satellites, into what they have already collected.
    """

    @abstractmethod  # pragma: no cover
    def __add__(self, other: "Data") -> "Data":
        """Merge ``other`` into this unit of data and return the combination."""
        ...

    def __copy__(self) -> "Data":
        """Return a copy made by adding ``self`` onto a fresh, empty instance.

        The subclass constructor must therefore be callable with no arguments.
        How deep the copy goes depends on the subclass's ``__add__``.
        """
        blank = type(self)()
        return blank + self


class DataStore(ABC):
    """Base class for satellite data logging."""

    data_type: type[Data]  # Define the unit of data used by the DataStore

    def __init__(
        self, satellite: "Satellite", initial_data: Optional[Data] = None
    ) -> None:
        """Base class for satellite data logging.

        One DataStore is created for each satellite in the scenario each time the
        scenario is reset. The DataStore is responsible for generating data from the
        satellite's environment and actions by comparing the current and previous-step
        state from :class:`~DataStore.get_log_state` and returning a unit of data with
        :class:`~DataStore.compare_log_states`. These two methods must be implemented
        by subclasses.

        Args:
            satellite: Satellite which data is being stored for.
            initial_data: Initial data to start the store with. Usually comes from
                :class:`~bsk_rl.data.GlobalReward.initial_data`. When ``None``, an
                empty instance of ``data_type`` is used instead.
        """
        self.satellite = satellite
        # Data received from other satellites, held back until
        # update_with_communicated_data merges it into self.data.
        self.staged_data: list[Data] = []

        # Compare against None rather than relying on truthiness: a supplied Data
        # instance must be kept even if its type defines __bool__/__len__ and it
        # happens to evaluate as falsy.
        self.data: Data = (
            self.data_type() if initial_data is None else initial_data
        )
        # Data produced by the most recent update_from_logs comparison.
        self.new_data: Data = self.data_type()

    def get_log_state(self) -> LogStateType:
        """Pull information used in determining current data contribution.

        Subclasses are expected to override this; the base implementation
        returns ``None``.
        """
        pass

    @abstractmethod  # pragma: no cover
    def compare_log_states(
        self, old_state: LogStateType, new_state: LogStateType
    ) -> "Data":
        """Generate a unit of data based on previous step and current step logs.

        Args:
            old_state: A previous result of :class:`~DataStore.get_log_state`.
            new_state: A newer result of :class:`~DataStore.get_log_state`.

        Returns:
            Data: New data generated by the satellite.
        """
        pass

    def update_from_logs(self) -> "Data":
        """Update the data store based on collected information.

        The first call only records a baseline log state and returns empty data.
        Each later call compares the stored state with a fresh one, merges the
        resulting data into ``self.data``, and records it as ``self.new_data``.

        Returns:
            New data from the previous step.
        """
        if not hasattr(self, "log_state"):
            # No previous state to compare against yet: record the baseline only.
            self.log_state = self.get_log_state()
            return self.data_type()
        old_log_state = self.log_state
        self.log_state = self.get_log_state()
        new_data = self.compare_log_states(old_log_state, self.log_state)
        self.data += new_data
        self.new_data = new_data
        return new_data

    def stage_communicated_data(self, external_data: "Data") -> None:
        """Prepare data to be added from another source, but don't add it yet.

        Works with :class:`~DataStore.update_with_communicated_data` to add data from
        other satellites without erroneously propagating it through other satellites.

        Args:
            external_data: Data from another satellite to be added
        """
        self.staged_data.append(external_data)

    def update_with_communicated_data(self) -> None:
        """Merge all staged data into the data store, then clear the stage."""
        for staged in self.staged_data:
            self.data += staged
        self.staged_data = []


class _ClassProperty:
    """Read-only property evaluated against the owning class.

    Replaces stacking ``@classmethod`` on top of ``@property``. That pattern was
    deprecated in Python 3.11 and no longer works from Python 3.13, where the
    lookup returns a bound method instead of the property's value.
    """

    def __init__(self, fget: Callable[[type], Any]) -> None:
        self.fget = fget
        self.__doc__ = getattr(fget, "__doc__", None)

    def __get__(self, instance: Any, owner: Optional[type] = None) -> Any:
        # Always evaluate on the class, whether accessed via the class or an instance.
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class GlobalReward(ABC, Resetable):
    """Base class for simulation-wide data management."""

    data_store_type: type[DataStore]  # type of DataStore managed by the GlobalReward

    @_ClassProperty
    @deprecated(reason="datastore_type is deprecated, use data_store_type instead")
    def datastore_type(cls) -> type[DataStore]:
        """:meta private: Deprecated alias for data_store_type."""
        return cls.data_store_type

    @deprecated(reason="datastore_type is deprecated, use data_store_type instead")
    def set_data_type_deprecated(self) -> None:
        """:meta private: Deprecated alias for data_store_type."""
        # A legacy subclass defines ``datastore_type`` directly, shadowing the alias
        # above. Copy it onto the current attribute names.
        self.data_type = self.datastore_type.data_type
        self.data_store_type = self.datastore_type

    def __init__(self) -> None:
        """Base class for simulation-wide data management and rewarding.

        The method :class:`calculate_reward` must be overridden by subclasses. Other
        methods may be extended as necessary for housekeeping.
        """
        self.scenario: "Scenario"
        store_type = getattr(self, "data_store_type", None)
        if store_type is None:
            # ``data_store_type`` is not set, so presumably the subclass uses the
            # legacy ``datastore_type`` name. Only this case takes the deprecated path.
            # An AttributeError raised by a correctly named store (e.g. missing
            # ``data_type``) now propagates instead of being misrouted here.
            self.set_data_type_deprecated()
        else:
            self.data_type = store_type.data_type
        # Extra keyword arguments forwarded to every DataStore constructor.
        self.data_store_kwargs: dict[str, Any] = {}

    def link_scenario(self, scenario: "Scenario") -> None:
        """Link the data manager to the scenario.

        Args:
            scenario: The scenario that the data manager is being used with.
        """
        self.scenario = scenario

    def reset_overwrite_previous(self) -> None:
        """Overwrite attributes from previous episode.

        Starts a fresh, empty global data record and clears cumulative rewards.
        """
        self.data = self.data_type()
        self.cum_reward: dict[str, float] = {}

    def initial_data(self, satellite: "Satellite") -> "Data":
        """Furnish the :class:`~bsk_rl.data.base.DataStore` with initial data.

        The base implementation returns an empty instance of ``data_type``.
        """
        return self.data_type()

    def create_data_store(self, satellite: "Satellite") -> None:
        """Create a data store for a satellite.

        Attaches the new store as ``satellite.data_store`` and resets that
        satellite's cumulative reward to zero.

        Args:
            satellite: Satellite to create a data store for.
        """
        satellite.data_store = self.data_store_type(
            satellite,
            initial_data=self.initial_data(satellite),
            **self.data_store_kwargs,
        )
        self.cum_reward[satellite.name] = 0.0

    @abstractmethod  # pragma: no cover
    def calculate_reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
        """Calculate step reward based on all satellite data from a step.

        Returns a dictionary of rewards for each satellite based on the new data
        generated by each satellite during the previous step, in the form:

        .. code-block:: python

            {"sat-1_id": 0.23, "sat-2_id": 0.0, ...}


        Args:
            new_data_dict: A dictionary of new data generated by each satellite, in the
                form:

                .. code-block:: python

                    {"sat-1_id": data1, "sat-2_id": data2, ...}
        """
        pass

    def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
        """Call :class:`calculate_reward`, accumulate rewards, and merge new data.

        Args:
            new_data_dict: New data generated by each satellite during the step,
                keyed by satellite identifier.

        Returns:
            The per-satellite rewards produced by :class:`calculate_reward`.
        """
        reward = self.calculate_reward(new_data_dict)
        for satellite_id, sat_reward in reward.items():
            self.cum_reward[satellite_id] += sat_reward

        # Fold every satellite's new data into the global record.
        for new_data in new_data_dict.values():
            self.data += new_data

        nonzero_reward = {k: v for k, v in reward.items() if v != 0}
        # Lazy %-style args: the dict is only formatted if INFO is enabled.
        logger.info("Total reward: %s", nonzero_reward)
        return reward

    def is_truncated(self, satellite: "Satellite") -> bool:
        """Check if the episode should be truncated for a satellite.

        The base implementation never truncates.
        """
        return False

    def is_terminated(self, satellite: "Satellite") -> bool:
        """Check if the episode should be terminated for a satellite.

        The base implementation never terminates.
        """
        return False


# Human-readable title for this module. Presumably read by the project's
# documentation tooling; confirm against the docs configuration.
__doc_title__ = "Base Data"
# Public API exported by a star-import of this module.
__all__ = ["GlobalReward", "DataStore", "Data"]
