import numpy as np
from typing import Union

from autoarray.plot.abstract_plotters import AbstractPlotter
from autoarray.plot.visuals.two_d import Visuals2D
from autoarray.plot.mat_plot.two_d import MatPlot2D
from autoarray.plot.auto_labels import AutoLabels
from autoarray.structures.arrays.uniform_2d import Array2D
from autoarray.inversion.pixelization.mappers.rectangular import (
    MapperRectangular,
)

import logging

logger = logging.getLogger(__name__)


class MapperPlotter(AbstractPlotter):
    def __init__(
        self,
        mapper: MapperRectangular,
        mat_plot_2d: MatPlot2D = None,
        visuals_2d: Visuals2D = None,
    ):
        """
        Plots the attributes of `Mapper` objects using the matplotlib method `imshow()` and many other matplotlib
        functions which customize the plot's appearance.

        The `mat_plot_2d` attribute wraps matplotlib function calls to make the figure. By default, the settings
        passed to every matplotlib function called are those specified in the `config/visualize/mat_wrap/*.ini` files,
        but a user can manually input values into `MatPlot2D` to customize the figure's appearance.

        Overlaid on the figure are visuals, contained in the `Visuals2D` object. Attributes may be extracted from
        the `Mapper` and plotted via the visuals object.

        Parameters
        ----------
        mapper
            The mapper the plotter plots.
        mat_plot_2d
            Contains objects which wrap the matplotlib function calls that make 2D plots.
        visuals_2d
            Contains 2D visuals that can be overlaid on 2D plots.
        """
        super().__init__(visuals_2d=visuals_2d, mat_plot_2d=mat_plot_2d)

        self.mapper = mapper

    def figure_2d(
        self,
        interpolate_to_uniform: bool = False,
        solution_vector: np.ndarray = None,
    ):
        """
        Plots the plotter's `Mapper` object in 2D.

        Parameters
        ----------
        interpolate_to_uniform
            If `True`, the mesh's reconstruction is interpolated to a uniform 2D array before plotting. By default
            (`False`) the reconstruction is plotted in its native format where one exists (e.g. using Voronoi pixels
            for a Voronoi mesh).
        solution_vector
            An optional vector of values used to color the mapper's source pixels. It is forwarded to
            `plot_mapper` as `pixel_values`.
        """
        self.mat_plot_2d.plot_mapper(
            mapper=self.mapper,
            visuals_2d=self.visuals_2d,
            interpolate_to_uniform=interpolate_to_uniform,
            pixel_values=solution_vector,
            auto_labels=AutoLabels(
                title="Pixelization Mesh (Source-Plane)", filename="mapper"
            ),
        )

    def figure_2d_image(self, image: Array2D):
        """
        Plots an input image in the image-plane, alongside the plotter's `Mapper`.

        The mapper's over-sampled image-plane data grid is passed to `plot_array` as `grid_indexes`. Presumably
        this lets indexed image pixels be marked on the figure (see `subplot_image_and_mapper`); confirm this
        against `MatPlot2D.plot_array`.

        Parameters
        ----------
        image
            The image which is plotted.
        """
        self.mat_plot_2d.plot_array(
            array=image,
            visuals_2d=self.visuals_2d,
            grid_indexes=self.mapper.mapper_grids.image_plane_data_grid.over_sampled,
            auto_labels=AutoLabels(
                title="Image (Image-Plane)", filename="mapper_image"
            ),
        )

    def subplot_image_and_mapper(
        self, image: Array2D, interpolate_to_uniform: bool = False
    ):
        """
        Make a subplot of an input image and the `Mapper`'s source-plane reconstruction.

        This function can include colored points that mark the mappings between the image pixels and their
        corresponding locations in the `Mapper` source-plane and reconstruction. This therefore visually illustrates
        the mapping process.

        Parameters
        ----------
        image
            The image which is plotted on the subplot.
        interpolate_to_uniform
            If `True`, the mesh's reconstruction is interpolated to a uniform 2D array before plotting. By default
            (`False`) the reconstruction is plotted in its native format where one exists (e.g. using Voronoi pixels
            for a Voronoi mesh).
        """
        self.open_subplot_figure(number_subplots=2)

        self.figure_2d_image(image=image)
        self.figure_2d(interpolate_to_uniform=interpolate_to_uniform)

        self.mat_plot_2d.output.subplot_to_figure(
            auto_filename="subplot_image_and_mapper"
        )
        self.close_subplot_figure()

    def plot_source_from(
        self,
        pixel_values: np.ndarray,
        zoom_to_brightest: bool = True,
        interpolate_to_uniform: bool = False,
        auto_labels: AutoLabels = None,
    ):
        """
        Plot the source of the `Mapper` where the coloring is specified by an input set of values.

        A `ValueError` raised while plotting is caught and logged at INFO level rather than propagated, so this
        plot is best-effort.

        Parameters
        ----------
        pixel_values
            The values of the mapper's source pixels used for coloring the figure.
        zoom_to_brightest
            For images not in the image-plane (e.g. the `plane_image`), whether to automatically zoom the plot to
            the brightest regions of the galaxies being plotted as opposed to the full extent of the grid.
        interpolate_to_uniform
            If `True`, the mapper's reconstruction is interpolated to a uniform grid before plotting, for example
            meaning that an irregular Delaunay grid can be plotted as a uniform grid.
        auto_labels
            The labels given to the figure. If `None`, a default `AutoLabels()` is created for this call.
        """
        if auto_labels is None:
            # Build a fresh instance per call. A default of `AutoLabels()` in the signature would be evaluated
            # once and shared by every call, so any state set on it during one plot could leak into the next.
            auto_labels = AutoLabels()

        try:
            self.mat_plot_2d.plot_mapper(
                mapper=self.mapper,
                visuals_2d=self.visuals_2d,
                auto_labels=auto_labels,
                pixel_values=pixel_values,
                zoom_to_brightest=zoom_to_brightest,
                interpolate_to_uniform=interpolate_to_uniform,
            )
        except ValueError:
            # Deliberately best-effort: a failed source-plane plot should not abort the caller's wider plotting.
            logger.info(
                "Could not plot the source-plane via the Mapper because of a ValueError."
            )
