from autoarray.plot.abstract_plotters import AbstractPlotter
from autoarray.plot.visuals.one_d import Visuals1D
from autoarray.plot.visuals.two_d import Visuals2D
from autoarray.plot.mat_plot.one_d import MatPlot1D
from autoarray.plot.mat_plot.two_d import MatPlot2D
from autoarray.plot.auto_labels import AutoLabels
from autoarray.dataset.interferometer.dataset import Interferometer
from autoarray.structures.grids.irregular_2d import Grid2DIrregular


class InterferometerPlotter(AbstractPlotter):
    def __init__(
        self,
        dataset: Interferometer,
        mat_plot_1d: MatPlot1D = None,
        visuals_1d: Visuals1D = None,
        mat_plot_2d: MatPlot2D = None,
        visuals_2d: Visuals2D = None,
    ):
        """
        Plots the attributes of `Interferometer` objects using the matplotlib methods `plot()`, `scatter()` and
        `imshow()` and other matplotlib functions which customize the plot's appearance.

        The `mat_plot_1d` and `mat_plot_2d` attributes wrap matplotlib function calls to make the figure. By default,
        the settings passed to every matplotlib function called are those specified in
        the `config/visualize/mat_wrap/*.ini` files, but a user can manually input values into `MatPlot1D` and
        `MatPlot2D` to customize the figure's appearance.

        Overlaid on the figure are visuals, contained in the `Visuals1D` and `Visuals2D` objects passed to this
        plotter.

        Parameters
        ----------
        dataset
            The interferometer dataset the plotter plots.
        mat_plot_1d
            Contains objects which wrap the matplotlib function calls that make 1D plots.
        visuals_1d
            Contains 1D visuals that can be overlaid on 1D plots.
        mat_plot_2d
            Contains objects which wrap the matplotlib function calls that make 2D plots.
        visuals_2d
            Contains 2D visuals that can be overlaid on 2D plots.
        """
        self.dataset = dataset

        super().__init__(
            mat_plot_1d=mat_plot_1d,
            visuals_1d=visuals_1d,
            mat_plot_2d=mat_plot_2d,
            visuals_2d=visuals_2d,
        )

    @property
    def interferometer(self):
        """
        Alias for `self.dataset`, the `Interferometer` this plotter plots.
        """
        return self.dataset

    def figures_2d(
        self,
        data: bool = False,
        noise_map: bool = False,
        u_wavelengths: bool = False,
        v_wavelengths: bool = False,
        uv_wavelengths: bool = False,
        amplitudes_vs_uv_distances: bool = False,
        phases_vs_uv_distances: bool = False,
        dirty_image: bool = False,
        dirty_noise_map: bool = False,
        dirty_signal_to_noise_map: bool = False,
    ):
        """
        Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D.

        The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type
        bool of the function, which if switched to `True` means that it is plotted.

        Parameters
        ----------
        data
            Whether to make a 2D plot (via `scatter`) of the visibility data.
        noise_map
            Whether to make a 2D plot (via `scatter`) of the noise-map, drawn at the positions of the visibility
            data and colored by the real component of the noise-map.
        u_wavelengths
            Whether to make a 1D plot (via `plot`) of the u-wavelengths (column 0 of `uv_wavelengths`).
        v_wavelengths
            Whether to make a 1D plot (via `plot`) of the v-wavelengths (column 1 of `uv_wavelengths`).
        uv_wavelengths
            Whether to make a 2D plot (via `scatter`) of the uv-wavelengths, with v on the y-axis and u on the
            x-axis, both divided by 10^3.
        amplitudes_vs_uv_distances
            Whether to make a 1D scatter plot of the amplitudes versus the uv distances (divided by 10^3).
        phases_vs_uv_distances
            Whether to make a 1D scatter plot of the phases versus the uv distances (divided by 10^3).
        dirty_image
            Whether to make a 2D plot (via `imshow`) of the dirty image.
        dirty_noise_map
            Whether to make a 2D plot (via `imshow`) of the dirty noise map.
        dirty_signal_to_noise_map
            Whether to make a 2D plot (via `imshow`) of the dirty signal-to-noise map.
        """

        if data:
            self.mat_plot_2d.plot_grid(
                grid=self.dataset.data.in_grid,
                visuals_2d=self.visuals_2d,
                auto_labels=AutoLabels(title="Visibilities", filename="data"),
            )

        if noise_map:
            # The noise-map is scattered at the visibility positions, colored by its real component.
            self.mat_plot_2d.plot_grid(
                grid=self.dataset.data.in_grid,
                visuals_2d=self.visuals_2d,
                color_array=self.dataset.noise_map.real,
                auto_labels=AutoLabels(title="Noise-Map", filename="noise_map"),
            )

        if u_wavelengths:
            self.mat_plot_1d.plot_yx(
                y=self.dataset.uv_wavelengths[:, 0],
                x=None,
                visuals_1d=self.visuals_1d,
                auto_labels=AutoLabels(
                    title="U-Wavelengths",
                    filename="u_wavelengths",
                    ylabel="Wavelengths",
                ),
                plot_axis_type_override="linear",
            )

        if v_wavelengths:
            self.mat_plot_1d.plot_yx(
                y=self.dataset.uv_wavelengths[:, 1],
                x=None,
                visuals_1d=self.visuals_1d,
                auto_labels=AutoLabels(
                    title="V-Wavelengths",
                    filename="v_wavelengths",
                    ylabel="Wavelengths",
                ),
                plot_axis_type_override="linear",
            )

        if uv_wavelengths:
            # v (column 1) goes on the y-axis and u (column 0) on the x-axis, both scaled down by 10^3.
            self.mat_plot_2d.plot_grid(
                grid=Grid2DIrregular.from_yx_1d(
                    y=self.dataset.uv_wavelengths[:, 1] / 10**3.0,
                    x=self.dataset.uv_wavelengths[:, 0] / 10**3.0,
                ),
                visuals_2d=self.visuals_2d,
                auto_labels=AutoLabels(
                    title="UV-Wavelengths", filename="uv_wavelengths"
                ),
            )

        if amplitudes_vs_uv_distances:
            self.mat_plot_1d.plot_yx(
                y=self.dataset.amplitudes,
                x=self.dataset.uv_distances / 10**3.0,
                visuals_1d=self.visuals_1d,
                auto_labels=AutoLabels(
                    title="Amplitudes vs UV-distances",
                    filename="amplitudes_vs_uv_distances",
                    yunit="Jy",
                    # Raw string so the backslash in the LaTeX `\lambda` is not parsed as an escape sequence.
                    xunit=r"k$\lambda$",
                ),
                plot_axis_type_override="scatter",
            )

        if phases_vs_uv_distances:
            self.mat_plot_1d.plot_yx(
                y=self.dataset.phases,
                x=self.dataset.uv_distances / 10**3.0,
                visuals_1d=self.visuals_1d,
                auto_labels=AutoLabels(
                    title="Phases vs UV-distances",
                    filename="phases_vs_uv_distances",
                    yunit="deg",
                    # Raw string so the backslash in the LaTeX `\lambda` is not parsed as an escape sequence.
                    xunit=r"k$\lambda$",
                ),
                plot_axis_type_override="scatter",
            )

        if dirty_image:
            self.mat_plot_2d.plot_array(
                array=self.dataset.dirty_image,
                visuals_2d=self.visuals_2d,
                auto_labels=AutoLabels(title="Dirty Image", filename="dirty_image"),
            )

        if dirty_noise_map:
            self.mat_plot_2d.plot_array(
                array=self.dataset.dirty_noise_map,
                visuals_2d=self.visuals_2d,
                auto_labels=AutoLabels(
                    title="Dirty Noise Map", filename="dirty_noise_map"
                ),
            )

        if dirty_signal_to_noise_map:
            self.mat_plot_2d.plot_array(
                array=self.dataset.dirty_signal_to_noise_map,
                visuals_2d=self.visuals_2d,
                auto_labels=AutoLabels(
                    title="Dirty Signal-To-Noise Map",
                    filename="dirty_signal_to_noise_map",
                ),
            )

    def subplot(
        self,
        data: bool = False,
        noise_map: bool = False,
        u_wavelengths: bool = False,
        v_wavelengths: bool = False,
        uv_wavelengths: bool = False,
        amplitudes_vs_uv_distances: bool = False,
        phases_vs_uv_distances: bool = False,
        dirty_image: bool = False,
        dirty_noise_map: bool = False,
        dirty_signal_to_noise_map: bool = False,
        auto_filename: str = "subplot_dataset",
    ):
        """
        Plots the individual attributes of the plotter's `Interferometer` object in 1D and 2D on a subplot.

        The API is such that every plottable attribute of the `Interferometer` object is an input parameter of type
        bool of the function, which if switched to `True` means that it is included on the subplot.

        Parameters
        ----------
        data
            Whether to include a 2D plot (via `scatter`) of the visibility data.
        noise_map
            Whether to include a 2D plot (via `scatter`) of the noise-map.
        u_wavelengths
            Whether to include a 1D plot (via `plot`) of the u-wavelengths.
        v_wavelengths
            Whether to include a 1D plot (via `plot`) of the v-wavelengths.
        uv_wavelengths
            Whether to include a 2D plot (via `scatter`) of the uv-wavelengths.
        amplitudes_vs_uv_distances
            Whether to include a 1D scatter plot of the amplitudes versus the uv distances.
        phases_vs_uv_distances
            Whether to include a 1D scatter plot of the phases versus the uv distances.
        dirty_image
            Whether to include a 2D plot (via `imshow`) of the dirty image.
        dirty_noise_map
            Whether to include a 2D plot (via `imshow`) of the dirty noise map.
        dirty_signal_to_noise_map
            Whether to include a 2D plot (via `imshow`) of the dirty signal-to-noise map.
        auto_filename
            The filename given to `AutoLabels` for the subplot.
        """
        # `_subplot_custom_plot` is not defined in this class; presumably inherited from `AbstractPlotter`.
        self._subplot_custom_plot(
            data=data,
            noise_map=noise_map,
            u_wavelengths=u_wavelengths,
            v_wavelengths=v_wavelengths,
            uv_wavelengths=uv_wavelengths,
            amplitudes_vs_uv_distances=amplitudes_vs_uv_distances,
            phases_vs_uv_distances=phases_vs_uv_distances,
            dirty_image=dirty_image,
            dirty_noise_map=dirty_noise_map,
            dirty_signal_to_noise_map=dirty_signal_to_noise_map,
            auto_labels=AutoLabels(filename=auto_filename),
        )

    def subplot_dataset(self):
        """
        Standard subplot of the attributes of the plotter's `Interferometer` object.

        Includes the visibility data, uv-wavelengths, amplitudes and phases versus uv distances, the dirty image
        and the dirty signal-to-noise map, saved under the filename `subplot_dataset`.
        """
        return self.subplot(
            data=True,
            uv_wavelengths=True,
            amplitudes_vs_uv_distances=True,
            phases_vs_uv_distances=True,
            dirty_image=True,
            dirty_signal_to_noise_map=True,
            auto_filename="subplot_dataset",
        )

    def subplot_dirty_images(self):
        """
        Standard subplot of the dirty attributes of the plotter's `Interferometer` object.

        Includes the dirty image, dirty noise map and dirty signal-to-noise map, saved under the filename
        `subplot_dirty_images`.
        """
        return self.subplot(
            dirty_image=True,
            dirty_noise_map=True,
            dirty_signal_to_noise_map=True,
            auto_filename="subplot_dirty_images",
        )
