"""Contour tracking on image series."""

# Standard library imports
from functools import wraps

# Misc. package imports
from numpy import nan as NaN
import numpy as np
import matplotlib.pyplot as plt
import h5py

# Local imports
from .analysis_base import Analysis
from .formatters import AnalysisPandasFormatterBase, AnalysisFormatterBase
from .formatters import MultiFormatterBase
from .results import ResultsBase
from ..fileio import FileIO
from ..parameters.analysis import ContourSelection, Threshold
from ..contours import Contour, ContourProperties, ContourCoordinates
from ..contours import ContourFinder
from ..viewers import AnalysisViewer

# ============================ Results formatting ============================


# Here we combine two formatters using a multiformatter
# one is for contour properties (centroid, area etc.)
# the other one is for raw contour data (coordinates)


def check_active(method):
    """Decorator to skip a method when its object is inactive.

    If ``self.active`` is falsy (e.g. making a table is not required), the
    decorated method is not called and the wrapper returns None; otherwise
    the call is forwarded unchanged.
    """

    # wraps() keeps the name/docstring of the decorated method, so that
    # introspection and error messages still refer to the real method.
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.active:
            return None
        return method(self, *args, **kwargs)

    return wrapper


class ContourTrackingTableFormatter(AnalysisPandasFormatterBase):
    """Format contour properties (centroid, area, perimeter) as a table.

    The table (pandas dataframe) has one group of columns per tracked
    contour and does not contain contour coordinates.
    """

    Properties = ContourProperties

    def __init__(self, *args, **kwargs):
        """Add `active` attribute, to deactivate when table is not needed."""
        super().__init__(*args, **kwargs)
        # When False, methods decorated with @check_active return None
        # (ContourTracking sets this to False if make_table=False).
        self.active = True

    @staticmethod
    def property_name(kind, k):
        """Column name for property `kind` of contour number `k`."""
        return f'{kind}_{k}'

    @property
    def results_dataframe(self):
        """Properties table, stored under the 'table' key of results.data.

        Redefined here because the pandas dataframe is not directly
        results.data (which also holds the contour coordinates data).
        """
        return self.analysis.results.data['table']

    def _column_names(self):
        """Column names, grouped by contour number, then by property."""
        return [
            self.property_name(name, k)
            for k in range(self.analysis.n_contours)
            for name in self.Properties.table_columns
        ]

    @check_active
    def _data_to_results_row(self, data):
        """Flatten the properties of all contours into one table row.

        Values are grouped by contour, in the order of data['contours'];
        a contour that was not detected (None) contributes one NaN per
        table column.
        """
        all_info = []

        for contour in data['contours']:

            if contour is None:    # no contour detected.
                info = (NaN,) * len(self.Properties.table_columns)
            else:
                ppties = contour.properties
                # NOTE(review): assumes Properties.table_columns is ordered
                # as (centroid coordinates..., area, perimeter) — confirm.
                info = (*ppties.centroid, ppties.area, ppties.perimeter)

            all_info.append(info)

        # "Flatten" list of tuples in a single pass; sum(..., start=())
        # would re-copy the growing tuple at every step (quadratic).
        return tuple(value for info in all_info for value in info)

    @check_active
    def _results_row_to_data(self, row):
        """Go from a table row to a list of properties objects.

        Returns one Properties object per tracked contour (not the full data
        dict: it is combined with coordinates data by the multiformatter).
        """
        properties_data = []

        for k in range(self.analysis.n_contours):

            # First and last column of contour k; label-based slicing with
            # .loc includes both ends (assuming row is a pandas Series).
            lim1 = self.property_name(self.Properties.table_columns[0], k)
            lim2 = self.property_name(self.Properties.table_columns[-1], k)

            ppties = self.Properties.from_table_row(row=row.loc[lim1:lim2])
            properties_data.append(ppties)

        return properties_data


class ContourTrackingFullFormatter(AnalysisFormatterBase):
    """Keep the full contour objects (incl. x, y coordinates) of each image."""

    def _prepare_data_storage(self):
        """Start with an empty mapping: image number -> list of contours."""
        self.data = {}

    def _store_data(self, data):
        """Record the contours found by the analysis on a single image.

        Input
        -----
        data : dict
            Output of Analysis.analyze(); its 'num' and 'contours' entries
            are used.
        """
        image_num = data['num']
        self.data[image_num] = data['contours']

    def _to_results_data(self):
        """Partial data, merged with the table data by the multiformatter."""
        return self.data

    def _to_metadata(self):
        """Analysis metadata (contour selection only, no paths/transforms)."""
        selection_data = self.analysis.contour_selection.data
        return {'contour selection': selection_data}

    def _regenerate_analysis_data(self, num):
        """Retrieve the contours of image `num` as produced by the analysis.

        Returns None if results hold no contour data at all, or if this
        particular image number was not analyzed.
        """
        try:
            return self.analysis.results.data['contours'][num]
        except KeyError:
            return None


class ContourTrackingFormatter(MultiFormatterBase):
    """Combine properties (table) formatter and coordinates formatter."""

    # Order matters: the _combine_* methods below unpack the individual
    # data as (table formatter output, full/coordinates formatter output).
    Formatters = [
        ContourTrackingTableFormatter,
        ContourTrackingFullFormatter,
    ]

    def _combine_results_data(self, individual_data):
        """How to combine individual data obtained from _to_results_data()

        individual_data is a list of data produced by each
        _to_results_data() method of the individual formatters.

        Returns data that will be stored in results.data: always has a
        'contours' key, plus a 'table' key if the analysis makes a table.
        """
        properties_table, contour_data = individual_data

        table_dict = {'table': properties_table} if self.analysis.make_table else {}
        contour_dict = {'contours': contour_data}

        return {**table_dict, **contour_dict}

    def _combine_regenerated_data(self, individual_regenerated_data):
        """How to combine individual data from _regenerate_analysis_data()

        [OPTIONAL]

        individual_regenerated_data is a list of data produced by each
        _regenerate_analysis_data() method of the individual formatters.
        Only the contours (second formatter) are used; the table part is
        ignored.

        Returns regenerated data that will be sent to viewers etc.
        (empty dict if no contours are available for this image).
        """
        _, contours = individual_regenerated_data
        if contours is None:
            return {}
        return {'contours': contours}


# ============================= Results classes ==============================


class ContourTrackingResults(ResultsBase):
    """Store, save and load contour tracking data and metadata.

    self.data is a dict with keys
    - 'contours': {num: [Contour or None, ...]}, one list entry per tracked
      contour (None when the contour was not detected)
    - 'table': table of contour properties (only present if available)

    The table is saved to .tsv; contour coordinates and properties to .hdf5;
    metadata to .json.
    """

    # define in subclass (e.g. 'Img_GreyLevel')
    # Note that the program will add extensions depending on context
    # (data or metadata).
    default_filename = 'Img_ContourTracking'

    # Define type of data (e.g. data / metadata and corresponding extensions)
    # Possible to change in subclasses.
    extensions = {
        'data': ('.tsv', '.hdf5'),
        'metadata': ('.json',),
    }

    # What to add to the default filename or specified filename
    # needs to be same length as extensions above.
    # useful if two extensions are the same, to distinguish filenames
    filename_adds = {
        'data': ('', ''),
        'metadata': ('',),
    }

    # Corresponding loading and saving methods, possibility to put several
    # in order to save data to various files or different formats.
    # Must be same length as extensions above.
    load_methods = {
        'data': ('_load_data_tsv', '_load_data_hdf5'),
        'metadata': ('_load_metadata',),
    }

    # idem for save methods
    save_methods = {
        'data': ('_save_data_tsv', '_save_data_hdf5'),
        'metadata': ('_save_metadata',),
    }

    @property
    def n_contours(self):
        """Number of tracked contours, as recorded in the metadata."""
        return len(self.metadata['contour selection']['properties'])

    # Loading data -----------------------------------------------------------

    def _load_data_tsv(self, filepath):
        """Load data from tsv file (only properties, no coordinates).

        Returns None if the file does not exist (no table saved).
        """
        if not filepath.exists():
            return

        return FileIO.from_tsv(filepath=filepath)

    def _load_data_hdf5(self, filepath):
        """Load data from hdf5 file (both properties and coordinates)

        Data structure is
        contour tracking/num/17/contour/0/...

        Returns
        -------
        dict
            {num: [Contour or None, ...]}; None where no group was saved
            for a contour (i.e. it was not detected at analysis time).
        """

        contours = {}

        with h5py.File(filepath, 'r') as f:

            basegrp = f['contour tracking']
            n_contours = basegrp.attrs['number of contours']

            for img_group in basegrp['num'].values():

                ctrgrp = img_group['contour']
                num = ctrgrp.attrs['num']
                contours[num] = []

                for k in range(n_contours):

                    try:
                        ctr_group = ctrgrp[f'{k}']
                    except KeyError:  # None contours are not saved in hdf5
                        contour = None
                    else:
                        contour = Contour(
                            coordinates=ContourCoordinates.from_hdf5_group(ctr_group),
                            properties=ContourProperties.from_hdf5_group(ctr_group),
                        )

                    # Append only on success (not in a `finally` clause,
                    # which would append a stale/undefined contour when
                    # reading the group raises).
                    contours[num].append(contour)

        return contours

    def loaded_data_to_data(self, loaded_data):
        """Combine (tsv, hdf5) loaded data into the self.data dict.

        The 'table' key is only added if a tsv table was found.
        """
        data_tsv, data_hdf5 = loaded_data

        coords_dict = {'contours': data_hdf5}
        ppties_dict = {'table': data_tsv} if data_tsv is not None else {}

        return {**ppties_dict, **coords_dict}

    # Saving data ------------------------------------------------------------

    def _save_data_tsv(self, data, filepath):
        """Save properties table to tsv (does nothing if there is no table)."""
        try:
            table = data['table']
        except KeyError:
            return

        FileIO.to_tsv(data=table, filepath=filepath)

    def _save_data_hdf5(self, data, filepath):
        """Save contour coordinates and properties to hdf5.

        Layout: contour tracking/num/<num>/contour/<k>/...
        Contours that are None are skipped (HDF5 cannot store None).
        """
        try:
            contour_data = data['contours']
        except KeyError:
            return

        with h5py.File(filepath, 'w') as f:

            basegrp = f.create_group('contour tracking')
            basegrp.attrs['number of contours'] = self.n_contours
            numgrp = basegrp.create_group('num', track_order=True)

            for num, contours in contour_data.items():

                ctrgrp = numgrp.create_group(f'{num}/contour', track_order=True)
                # The line below is useful when re-loading the data
                ctrgrp.attrs['num'] = num

                for k, contour in enumerate(contours):

                    if contour is None:  # HDF5 cannot store None data
                        continue

                    group = ctrgrp.create_group(f'{k}', track_order=True)
                    group.attrs['num'] = num
                    group.attrs['contour id'] = k

                    if contour.coordinates is not None:
                        contour.coordinates.to_hdf5_group(group)

                    if contour.properties is not None:
                        contour.properties.to_hdf5_group(group)

    # --------------------- Loading and saving metadata ----------------------

    def _load_metadata(self, filepath):
        """Return analysis metadata from file as a dictionary.

        Parameters
        ----------
        filepath : pathlib.Path object
            file to load the metadata from

        Returns
        -------
        dict
            metadata
        """
        return FileIO.from_json(filepath)

    def _save_metadata(self, metadata, filepath):
        """Write metadata to file (json, with git information added).

        Parameters
        ----------
        metadata : dict
            Metadata as a dictionary

        filepath : pathlib.Path object
            file to write the metadata to

        Returns
        -------
        Whatever FileIO.to_json_with_gitinfo() returns.
        """
        return FileIO.to_json_with_gitinfo(data=metadata, filepath=filepath)

    # ============== Methods specific to ContourTrackingResults ==============

    @property
    def table(self):
        """Returns the table of properties data, if exists (else None)"""
        return self.data.get('table')

    def get_contour(self, k=0, num=0):
        """Return contour number k on image num (None if not detected)."""
        return self.data['contours'][num][k]


# ======================= Plotting / Animation classes =======================


class ContourTrackingViewer(AnalysisViewer):
    """Display images with tracked contours and their centroids overlaid."""

    def _plot_contours_and_centroids(self, data):
        """Create one centroid marker and one contour line per contour."""

        self.centroid_pts = []
        self.contour_lines = []

        for _ in range(self.analysis.n_contours):
            centroid_pt, = self.ax_img.plot([], [], '+')
            # contour line drawn with the same color as its centroid
            contour_line, = self.ax_img.plot([], [], '-', c=centroid_pt.get_color())
            self.centroid_pts.append(centroid_pt)
            self.contour_lines.append(contour_line)

        self._update_contours_and_centroids(data)

    def _hide_all(self):
        """Hide all centroid markers and contour lines."""
        for pt, line in zip(self.centroid_pts, self.contour_lines):
            pt.set_visible(False)
            line.set_visible(False)

    def _update_contours_and_centroids(self, data):
        """Update centroid markers / contour lines from incoming data."""

        try:
            contours = data['contours']
        except KeyError:  # no contour data in incoming data
            self._hide_all()
            return

        for contour, pt, line in zip(
            contours,
            self.centroid_pts,
            self.contour_lines,
        ):
            if contour is None:
                # Contour not detected on this image: hide its artists,
                # otherwise the previous image's contour would stay drawn.
                pt.set_visible(False)
                line.set_visible(False)
                continue

            if contour.properties is None:
                pt.set_visible(False)
            else:
                pt.set_visible(True)
                xc, yc = contour.properties.centroid
                pt.set_data((xc,), (yc,))

            if contour.coordinates is None:
                line.set_visible(False)
            else:
                line.set_visible(True)
                line.set_data(contour.coordinates.x, contour.coordinates.y)

    # ---------------- Methods subclassed from AnalysisViewer ----------------

    def _create_figure(self):
        self.fig, self.ax_img = plt.subplots()
        self.ax_img.axis('off')
        self.axs = self.ax_img,

    def _first_plot(self, data):
        """What to do the first time data arrives on the plot."""
        self._create_image(data)
        self.fig.tight_layout()
        self._plot_contours_and_centroids(data)
        self.updated_artists = self.centroid_pts + self.contour_lines + [self.imshow]

    def _update_plot(self, data):
        """What to do upon iterations of the plot after the first time."""
        self._update_image(data)
        self._update_contours_and_centroids(data)


# ----------------------------------------------------------------------------
# =========================== Main ANALYSIS class ============================
# ----------------------------------------------------------------------------


class ContourTracking(Analysis):
    """Class to track contours on image series.

    Class attributes
    ----------------
    Viewer : class
        (subclass of AnalysisViewer)
        Viewer class/subclasses that is used to display and inspect
        analysis data (is used by ViewerTools)

    Formatter: class
        (subclass of Formatter)
        class used to format results spit out by the raw analysis into
        something storable/saveable by the Results class.

    Results : class
        (subclass of Results)
        Results class/subclasses that is used to store, save and load
        analysis data and metadata.

    Finder : class
        class used to find contours on an image (find_contours()) and to
        match them with previously tracked contours (match()).
    """
    Viewer = ContourTrackingViewer
    Formatter = ContourTrackingFormatter
    Results = ContourTrackingResults
    Finder = ContourFinder

    # If results are independent (results from one num do not depend on
    # analysis on other nums), one does not need to re-do the analysis when
    # asking for the same num twice, and parallel computing is possible.
    # Here results are NOT independent: contours on each image are matched
    # to the properties found on the previously analyzed image.
    independent_results = False

    def __init__(
        self,
        img_series,
        savepath=None,
        make_table=True,
        tolerance_displacement=None,
        tolerance_area=None,
    ):
        """Analysis of iso-grey-level contours and their evolution in series.

        Parameters
        ----------
        img_series : ImgSeries or ImgStack object
            image series on which the analysis will be run

        savepath : str or Path object
            folder in which to save analysis data & metadata
            (if not specified, the img_series savepath is used)

        make_table : bool
            if True (default), create table (pandas dataframe) with the
            contour properties as a function of the image number;
            the table does not contain contour coordinates data

        tolerance_displacement : float
            if None (default), no restriction on displacements
            if value = d > 0, do not consider displacements more than d pixels

        tolerance_area : float
            if None (default), no restriction on area variations of contours
            if value = x > 0, do not consider relative variation in area of
            more than x.
        """
        super().__init__(img_series=img_series, savepath=savepath)

        self.make_table = make_table
        if not make_table:
            # formatters[0] is the table formatter (see Formatter.Formatters)
            self.formatters[0].active = False

        # empty contour param object, needs to be filled with contours.define()
        # or contours.load() prior to starting analysis with self.run()
        self.contour_selection = ContourSelection(self)
        self.threshold = Threshold(self)

        # Tolerance in displacement and areas to match contours
        self.contour_finder = self.Finder(
            tolerance_displacement=tolerance_displacement,
            tolerance_area=tolerance_area,
        )

    @property
    def n_contours(self):
        """Number of contours tracked (from the contour selection)."""
        return len(self.contour_selection.properties)

    # ------------------- Subclassed methods from Analysis -------------------

    def _init_analysis(self):
        """Check everything OK before starting analysis & initialize params.

        Raises
        ------
        AttributeError
            if no contours have been selected yet.
        """

        if self.contour_selection.is_empty:
            msg = (
                "Contours not defined yet. Use self.contour_selection.define(), or "
                "self.contour_selection.load() if contours have been previously saved."
            )
            raise AttributeError(msg)

        # Reference properties used to match contours on the next image;
        # updated after each analyzed image in _analyze().
        self.previous_contour_ppties = self.contour_selection.properties.copy()

    def _analyze(self, img):
        """Find contours at level in img closest to the reference contours.

        Parameters
        ----------
        img : array_like
            image array to be analyzed (e.g. numpy array).

        Returns
        -------
        dict
            {'contours': [c1, c2, ..., cn]} where n is the number of
            followed contours and each ci is the matching Contour object,
            or None if no matching contour was found.
        """
        data = {'contours': []}
        new_contour_ppties = {}

        contours = self.contour_finder.find_contours(
            img=img,
            level=self.threshold.value,
        )

        # The loop is on the number of followed contours, basically
        for name, prev_contour_ppties in self.previous_contour_ppties.items():

            # contour can be None if no matching contour found
            # (if matching contour, contour is a Contour object)
            contour = self.contour_finder.match(contours, prev_contour_ppties)
            data['contours'].append(contour)

            # If contour not found, keep same ref. ppties as before
            if contour is None:
                new_contour_ppties[name] = prev_contour_ppties
            # If found, update.
            else:
                new_contour_ppties[name] = contour.properties

        self.previous_contour_ppties = new_contour_ppties

        return data

    # ------------------ Redefinitions of Analysis methods -------------------

    def regenerate(self, filename=None):
        """Load saved data, metadata and regenerate objects from them.

        Is used to reset the system in a state similar to the end of the
        analysis that was made before saving the results.

        Parameters
        ----------
        filename : str
            name of the analysis results file (if None, use default)

        Notes
        -----
            More or less equivalent to:
            >>> analysis.results.load(filename=filename)
            >>> image_series.load_transforms()
            (except that transforms are loaded from the metadata file of the
            analysis, not from a file generated by
            image_series.save_transforms())
        """

        # Load data
        super().regenerate(filename=filename)

        # regenerate internal threshold / contours object
        self.contour_selection.load(filename=filename)

        # at the moment, this is already done by contours.load(), but I'm
        # putting this there to be sure in case contours are modified to not
        # include threshold level information
        self.threshold.load(filename=filename)
