from abc import ABC

import numpy as np

from bregman.base import ETA_COORDS, THETA_COORDS, Coords, DualCoords, Point
from bregman.manifold.connection import FlatConnection
from bregman.manifold.coordinate import Atlas
from bregman.manifold.generator import Generator


class EtaGeneratorNotAssigned(Exception):
    r"""Raised when an operation needs the dual :math:`\eta` Bregman generator
    (or anything derived from it) but the manifold was built without one."""


class BregmanManifold(ABC):
    r"""Base class for Bregman manifold. Constructs basic components of a Bregman manifold.

    Attributes:
        dimension: Dimension of canonical parameterizations (:math:`\theta`-or :math:`\eta`-coordinates).
        theta_generator: Primal generator for :math:`\theta`-coordinates.
        eta_generator: Dual generator for :math:`\eta`-coordinates, or ``None``
            if the manifold was constructed without one.
        theta_connection: Connection object generated by theta_generator.
        eta_connection: Connection object generated by eta_generator, or ``None``
            when eta_generator is ``None``.
        atlas: Atlas object used to manage different coordinates types of the manifold.
            :math:`\eta`-coordinates (and the :math:`\theta \leftrightarrow \eta`
            transitions) are only registered when eta_generator is provided.
    """

    def __init__(
        self,
        theta_generator: Generator,
        eta_generator: Generator | None,
        dimension: int,
    ) -> None:
        r"""Initializes Bregman manifold from Bregman generators.

        Args:
            theta_generator: Primal generator for :math:`\theta`-coordinates.
            eta_generator: Dual generator for :math:`\eta`-coordinates. Optional
                (may be ``None``); without it, :math:`\eta`-related functionality
                raises EtaGeneratorNotAssigned.
            dimension: Dimension of canonical parameterizations (:math:`\theta`-or :math:`\eta`-coordinates).
        """
        super().__init__()

        self.dimension = dimension

        # Generators
        self.theta_generator = theta_generator
        self.eta_generator = eta_generator

        # Connections: the eta connection only exists if a dual generator was given.
        self.theta_connection = FlatConnection(THETA_COORDS, theta_generator)

        self.eta_connection: FlatConnection | None = None
        if eta_generator is not None:
            self.eta_connection = FlatConnection(ETA_COORDS, eta_generator)

        # Atlas to change coordinates. Theta-coordinates are always available;
        # eta-coordinates and the transition maps between the two are only
        # registered when the dual generator is known.
        self.atlas = Atlas(dimension)
        self.atlas.add_coords(THETA_COORDS)
        if eta_generator is not None:
            self.atlas.add_coords(ETA_COORDS)
            self.atlas.add_transition(THETA_COORDS, ETA_COORDS, self._theta_to_eta)
            self.atlas.add_transition(ETA_COORDS, THETA_COORDS, self._eta_to_theta)

    def convert_coord(self, target_coords: Coords, point: Point) -> Point:
        r"""Converts coordinates of Point objects.

        Args:
            target_coords: Coords which one wants to convert point to.
            point: Point object being converted.

        Returns:
            point converted to target_coords coordinates based on manifold.
        """
        return self.atlas(target_coords, point)

    def bregman_generator(self, dcoords: DualCoords) -> Generator:
        r"""Get Bregman generator of a specific :math:`\theta`-or :math:`\eta`-coordinate.

        Args:
            dcoords: DualCoords specifying :math:`\theta`-or :math:`\eta`-coordinates.
                Any value other than ``DualCoords.THETA`` selects the
                :math:`\eta` generator.

        Raises:
            EtaGeneratorNotAssigned: If dcoords = DualCoords.ETA and
                self.eta_generator is not defined.

        Returns:
            Generator corresponding to specified dual coordinates.
        """
        generator = (
            self.theta_generator if dcoords == DualCoords.THETA else self.eta_generator
        )

        if generator is None:
            raise EtaGeneratorNotAssigned(
                "No eta generator assigned to this manifold; "
                "eta-coordinate generator is unavailable."
            )

        return generator

    def bregman_connection(self, dcoords: DualCoords) -> FlatConnection:
        r"""Get Bregman connection of a specific :math:`\theta`-or :math:`\eta`-coordinate.

        Args:
            dcoords: DualCoords specifying :math:`\theta`-or :math:`\eta`-coordinates.
                Any value other than ``DualCoords.THETA`` selects the
                :math:`\eta` connection.

        Raises:
            EtaGeneratorNotAssigned: If dcoords = DualCoords.ETA and
                self.eta_generator is not defined.

        Returns:
            FlatConnection corresponding to specified dual coordinates.
        """
        connection = (
            self.theta_connection
            if dcoords == DualCoords.THETA
            else self.eta_connection
        )

        if connection is None:
            raise EtaGeneratorNotAssigned(
                "No eta generator assigned to this manifold; "
                "eta-coordinate connection is unavailable."
            )

        return connection

    def _theta_to_eta(self, theta: np.ndarray) -> np.ndarray:
        r"""Internal method to convert data from :math:`\theta` to :math:`\eta`
        coordinates, via the gradient of the primal generator.

        Args:
            theta: :math:`\theta`-coordinate data.

        Returns:
            Data in :math:`\theta`-coordinates converted to the :math:`\eta`-coordinates.
        """
        return self.theta_generator.grad(theta)

    def _eta_to_theta(self, eta: np.ndarray) -> np.ndarray:
        r"""Internal method to convert data from :math:`\eta` to :math:`\theta`
        coordinates, via the gradient of the dual generator.

        Args:
            eta: :math:`\eta`-coordinate data.

        Raises:
            EtaGeneratorNotAssigned: If self.eta_generator is not defined.

        Returns:
            Data in :math:`\eta`-coordinates converted to the :math:`\theta`-coordinates.
        """
        # Delegate the None-check to bregman_generator so the error is raised
        # (and worded) consistently across the class.
        return self.bregman_generator(DualCoords.ETA).grad(eta)
