import logging
import math
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pymultinest
import scipy.optimize

from autofit import conf
from autofit import exc
from autofit.core import link
from autofit.core import model_mapper as mm

# Install a default root handler so this module's log records are emitted even if the caller configured nothing.
logging.basicConfig()
logger = logging.getLogger(name=__name__)

# NOTE(review): not referenced in the visible part of this module; presumably a width for simplex-derived priors.
SIMPLEX_TUPLE_WIDTH = 0.1


class Analysis(object):
    """Interface for the problem-specific analysis that the non-linear optimizers fit.

    Every method here raises NotImplementedError; concrete analyses override all three.
    """

    def fit(self, instance):
        """Return the figure of merit (likelihood) of the model ``instance``."""
        raise NotImplementedError

    def visualize(self, instance, suffix, during_analysis):
        """Visualise ``instance``; ``during_analysis`` is True while the optimizer is still running."""
        raise NotImplementedError

    def log(self, instance):
        """Log information about ``instance``."""
        raise NotImplementedError


class Result(object):

    def __init__(self, constant, figure_of_merit, variable=None):
        """
        The outcome of an optimization.

        Parameters
        ----------
        constant: mm.ModelInstance
            The model instance that produced the best fit.
        figure_of_merit: float
            The figure of merit (e.g. likelihood) of that best fit.
        variable: mm.ModelMapper
            Optional mapper holding the priors derived from this optimization, for use by a later stage.
        """
        self.constant = constant
        self.figure_of_merit = figure_of_merit
        self.variable = variable

    def __str__(self):
        # One "name: value" line per attribute, in insertion order.
        attribute_lines = ("{}: {}".format(name, value) for name, value in self.__dict__.items())
        return "Analysis Result:\n" + "\n".join(attribute_lines)


class IntervalCounter(object):
    """Callable that returns True on every ``interval``-th call.

    An interval of -1 disables the counter: it always returns False and does not count. An interval of 0 is treated
    the same way, rather than raising ZeroDivisionError on the first call.
    """

    def __init__(self, interval):
        self.count = 0
        self.interval = interval

    def __call__(self):
        if self.interval == -1 or self.interval == 0:
            return False
        self.count += 1
        return self.count % self.interval == 0


class NonLinearOptimizer(object):

    def __init__(self, model_mapper=None, name=None):
        """Abstract base class for non-linear optimizers.

        This class sets up the output file structure for the non-linear optimizer, which is standardized across all \
        non-linear optimizers. Every path is built under ``conf.instance.output_path``.

        Parameters
        ------------
        model_mapper: mm.ModelMapper or None
            The mapper whose priors are fitted. An empty ``mm.ModelMapper()`` is used when None.
        name: str or None
            The phase name, used as the output sub-folder. Defaults to "phase".
        """
        self.named_config = conf.instance.non_linear

        name = name or "phase"

        self.phase_path = "{}/{}/".format(conf.instance.output_path, name)
        self.opt_path = "{}/{}/optimizer".format(conf.instance.output_path, name)

        sym_path = self.opt_path
        self.backup_path = "{}/{}/optimizer_backup".format(conf.instance.output_path, name)

        # Create the parent folder of the optimizer path (if missing) before handing the path to
        # link.make_linked_folder.
        os.makedirs(os.path.dirname(sym_path), exist_ok=True)

        self.path = link.make_linked_folder(sym_path)

        self.variable = model_mapper or mm.ModelMapper()
        self.constant = mm.ModelInstance()

        self.label_config = conf.instance.label

        self.file_param_names = "{}/{}".format(self.opt_path, 'multinest.paramnames')
        self.file_model_info = "{}/{}".format(self.phase_path, 'model.info')

    def backup(self):
        """Replace the backup folder with a fresh copy of the optimizer output folder (``opt_path``)."""
        try:
            shutil.rmtree(self.backup_path)
        except FileNotFoundError:
            pass  # no previous backup to remove
        shutil.copytree(self.opt_path, self.backup_path)

    def config(self, attribute_name, attribute_type=str):
        """
        Get a config field from this optimizer's section in non_linear.ini by a key and value type.

        The section name is the name of the concrete optimizer class.

        Parameters
        ----------
        attribute_name: str
            The key of the field
        attribute_type: type
            The type of the value

        Returns
        -------
        attribute
            An attribute for the key with the specified type.
        """
        return self.named_config.get(self.__class__.__name__, attribute_name, attribute_type)

    def save_model_info(self):
        """Write the paramnames file and, if it does not exist yet, the model.info file.

        The optimizer results folder (``self.path``) is created first if it is missing.
        """
        os.makedirs(self.path, exist_ok=True)

        self.create_paramnames_file()
        if not os.path.isfile(self.file_model_info):
            with open(self.file_model_info, 'w') as info_file:
                info_file.write(self.variable.info)

    def fit(self, analysis):
        raise NotImplementedError("Fitness function must be overridden by non linear optimizers")

    @property
    def param_names(self):
        """One name per variable prior, ordered by prior id, of the form ``<prior model name>_<prior name>``.

        Used for *GetDist* visualization via the paramnames file."""
        prior_prior_model_name_dict = self.variable.prior_prior_model_name_dict
        return [prior_prior_model_name_dict[prior] + '_' + prior_name
                for prior_name, prior in self.variable.prior_tuples_ordered_by_id]

    @property
    def constant_names(self):
        """One name per constant, ordered by id, of the form ``<prior model name>_<constant name>``."""
        constant_prior_model_name_dict = self.variable.constant_prior_model_name_dict
        return [constant_prior_model_name_dict[constant] + '_' + constant_name
                for constant_name, constant in self.variable.constant_tuples_ordered_by_id]

    @property
    def param_labels(self):
        """One LaTeX label per variable prior, ordered by prior id, used for *GetDist* visualization.

        Each label is ``<param label>_{\\mathrm{<class subscript><component number + 1>}}``, built from the label
        config and the prior model each prior belongs to."""

        paramnames_labels = []
        prior_class_dict = self.variable.prior_class_dict
        prior_prior_model_dict = self.variable.prior_prior_model_dict

        for prior_name, prior in self.variable.prior_tuples_ordered_by_id:
            param_string = self.label_config.label(prior_name)
            prior_model = prior_prior_model_dict[prior]
            cls = prior_class_dict[prior]
            cls_string = "{}{}".format(self.label_config.subscript(cls), prior_model.component_number + 1)
            param_label = "{}_{{\\mathrm{{{}}}}}".format(param_string, cls_string)
            paramnames_labels.append(param_label)

        return paramnames_labels

    def create_paramnames_file(self):
        """Write the paramnames file used for *GetDist* visualization: one ``name label`` line per prior.

        Names are padded so that labels start at column 70."""
        paramnames_names = self.param_names
        paramnames_labels = self.param_labels
        with open(self.file_param_names, 'w') as paramnames:
            for i in range(self.variable.prior_count):
                name = paramnames_names[i]
                # Always keep at least one space: a name of 70+ characters must not run into its label, otherwise
                # the name and label can no longer be told apart on the line.
                padding = ' ' * max(70 - len(name), 1)
                paramnames.write(name + padding + paramnames_labels[i] + '\n')


class AbstractFitness(object):
    """Shared bookkeeping for the fitness callables handed to the non-linear optimizers.

    Tracks the best likelihood seen so far (and a Result for it), and on the configured ``output`` intervals
    visualises, logs and backs up.
    """

    def __init__(self, nlo, analysis, instance_from_physical_vector, constant):
        self.nlo = nlo
        self.result = None
        self.instance_from_physical_vector = instance_from_physical_vector
        self.constant = constant
        self.max_likelihood = -np.inf
        self.analysis = analysis

        general_config = conf.instance.general
        visualise_every = general_config.get('output', 'visualise_interval', int)
        log_every = general_config.get('output', 'log_interval', int)
        backup_every = general_config.get('output', 'backup_interval', int)

        self.should_log = IntervalCounter(log_every)
        self.should_visualise = IntervalCounter(visualise_every)
        self.should_backup = IntervalCounter(backup_every)

    def _record_best(self, instance, likelihood):
        """Store a new best fit and, on the visualisation interval, visualise it."""
        self.max_likelihood = likelihood
        self.result = Result(instance, likelihood)
        if self.should_visualise():
            self.analysis.visualize(instance, suffix=None, during_analysis=True)

    def fit_instance(self, instance):
        """Add the constants to ``instance``, fit it and return its likelihood.

        The visualisation counter only advances on an improvement; the log and backup counters advance on every call.
        """
        instance += self.constant

        likelihood = self.analysis.fit(instance)
        if likelihood > self.max_likelihood:
            self._record_best(instance, likelihood)

        if self.should_log():
            self.analysis.log(instance)
        if self.should_backup():
            self.nlo.backup()

        return likelihood


class DownhillSimplex(NonLinearOptimizer):

    def __init__(self, model_mapper=None, fmin=scipy.optimize.fmin, name=None):
        """Set up a downhill-simplex optimizer.

        Parameters
        ----------
        model_mapper: mm.ModelMapper or None
            The mapper whose priors are fitted.
        fmin: callable
            The minimiser, called as ``fmin(function, x0=vector)``; defaults to ``scipy.optimize.fmin``.
        name: str or None
            The phase name, used to build the output paths.
        """
        super(DownhillSimplex, self).__init__(model_mapper=model_mapper, name=name)

        self.xtol = self.config("xtol", float)
        self.ftol = self.config("ftol", float)
        self.maxiter = self.config("maxiter", int)
        self.maxfun = self.config("maxfun", int)

        self.full_output = self.config("full_output", int)
        self.disp = self.config("disp", int)
        self.retall = self.config("retall", int)

        self.fmin = fmin

        logger.debug("Creating DownhillSimplex NLO")

    def fit(self, analysis):
        """Minimise ``-2 * likelihood`` with ``self.fmin``, starting from the prior medians.

        Returns
        -------
        Result
            The best fit seen during the search, with ``variable`` set to a mapper of Gaussian priors built from the
            minimiser's output vector.
        """
        initial_vector = self.variable.physical_values_from_prior_medians

        class Fitness(AbstractFitness):
            def __init__(self, nlo, instance_from_physical_vector, constant):
                super().__init__(nlo, analysis, instance_from_physical_vector, constant)

            def __call__(self, vector):
                try:
                    instance = self.instance_from_physical_vector(vector)
                    likelihood = self.fit_instance(instance)
                except exc.FitException:
                    likelihood = -np.inf
                # fmin minimises, so the likelihood is scaled by -2.
                return -2 * likelihood

        fitness_function = Fitness(self, self.variable.instance_from_physical_vector, self.constant)

        logger.info("Running DownhillSimplex...")
        # NOTE(review): xtol, ftol, maxiter, maxfun, full_output, disp and retall are read from config in __init__
        # but are not forwarded here, so fmin runs with its own defaults. Forwarding full_output/retall would also
        # make fmin return a tuple that must be unpacked before building the Gaussian priors.
        output = self.fmin(fitness_function, x0=initial_vector)
        logger.info("DownhillSimplex complete")
        res = fitness_function.result

        # Create a set of Gaussian priors from this result and associate them with the result object.
        res.variable = self.variable.mapper_from_gaussian_means(output)

        # Visualise the best-fit instance found during the search; res.constant already has self.constant added
        # (fit_instance adds it before building the Result).
        analysis.visualize(instance=res.constant, suffix=None, during_analysis=False)

        return res


class MultiNest(NonLinearOptimizer):

    def __init__(self, model_mapper=None, sigma_limit=3, run=pymultinest.run, name=None):
        """Set up a MultiNest non-linear search and the paths of its output files.

        The model_mapper is used to turn each point MultiNest samples into a model instance.

        Parameters
        ----------
        model_mapper: mm.ModelMapper or None
            The mapper whose priors are sampled.
        sigma_limit: float
            The sigma limit at which ``fit`` computes the Gaussian priors of the returned Result.
        run: callable
            The function called to run MultiNest (``pymultinest.run`` by default).
        name: str or None
            The phase name, used to build the output paths.
        """
        super(MultiNest, self).__init__(model_mapper=model_mapper, name=name)

        self.file_summary = "{}/{}".format(self.path, 'multinestsummary.txt')
        self.file_weighted_samples = "{}/{}".format(self.path, 'multinest.txt')
        self.file_results = "{}/{}".format(self.phase_path, 'model.results')
        self._weighted_sample_model = None
        self.sigma_limit = sigma_limit

        # Each MultiNest setting is read, in this order, from this optimizer's config section and stored on the
        # attribute of the same name (e.g. self.n_live_points).
        settings = (
            ('importance_nested_sampling', bool),
            ('multimodal', bool),
            ('const_efficiency_mode', bool),
            ('n_live_points', int),
            ('evidence_tolerance', float),
            ('sampling_efficiency', float),
            ('n_iter_before_update', int),
            ('null_log_evidence', float),
            ('max_modes', int),
            ('mode_tolerance', float),
            ('outputfiles_basename', str),
            ('seed', int),
            ('verbose', bool),
            ('resume', bool),
            ('context', int),
            ('write_output', bool),
            ('log_zero', float),
            ('max_iter', int),
            ('init_MPI', bool),
        )
        for setting, setting_type in settings:
            setattr(self, setting, self.config(setting, setting_type))

        self.run = run

        logger.debug("Creating MultiNest NLO")

    @property
    def pdf(self):
        """GetDist samples loaded from the ``multinest`` chain files under ``opt_path``.

        The chains are re-read from disk on every access.
        """
        import getdist
        chain_root = "{}/multinest".format(self.opt_path)
        return getdist.mcsamples.loadMCSamples(chain_root)

    def fit(self, analysis):
        """Run MultiNest on ``analysis`` and build a Result from its output files.

        The paramnames / model info files are written first. MultiNest is then run with the settings loaded in
        __init__, and results and PDF plots are output. The returned Result holds:
        - the most likely instance read from the summary file, plus ``self.constant``;
        - the maximum likelihood read from the summary file;
        - a mapper of Gaussian priors computed at ``self.sigma_limit``.
        """
        self.save_model_info()

        class Fitness(AbstractFitness):

            # noinspection PyShadowingNames
            def __init__(self, nlo, instance_from_physical_vector, constant, output_results):
                super().__init__(nlo, analysis, instance_from_physical_vector, constant)
                self.output_results = output_results
                self.accepted_samples = 0
                self.number_of_accepted_samples_between_output = conf.instance.general.get(
                    "output",
                    "number_of_accepted_samples_between_output",
                    int)

            def __call__(self, cube, ndim, nparams, lnew):
                # fit_instance raises self.max_likelihood when it sees an improvement, so the previous best must be
                # captured first. Comparing against self.max_likelihood afterwards would never detect an improvement.
                previous_max_likelihood = self.max_likelihood
                try:
                    instance = self.instance_from_physical_vector(cube)
                    likelihood = self.fit_instance(instance)
                except exc.FitException:
                    likelihood = -np.inf

                if likelihood > previous_max_likelihood:

                    self.accepted_samples += 1

                    if self.accepted_samples == self.number_of_accepted_samples_between_output:
                        self.accepted_samples = 0
                        self.output_results(during_analysis=True)

                return likelihood

        # noinspection PyUnusedLocal
        def prior(cube, ndim, nparams):
            # Map the unit-hypercube point MultiNest provides to physical values, writing them back in place.
            phys_cube = self.variable.physical_vector_from_hypercube_vector(hypercube_vector=cube)

            for i in range(self.variable.prior_count):
                cube[i] = phys_cube[i]

            return cube

        fitness_function = Fitness(self, self.variable.instance_from_physical_vector, self.constant,
                                   self.output_results)

        logger.info("Running MultiNest...")
        self.run(fitness_function.__call__,
                 prior,
                 self.variable.prior_count,
                 outputfiles_basename="{}/multinest".format(self.path),
                 n_live_points=self.n_live_points,
                 const_efficiency_mode=self.const_efficiency_mode,
                 importance_nested_sampling=self.importance_nested_sampling,
                 evidence_tolerance=self.evidence_tolerance,
                 sampling_efficiency=self.sampling_efficiency,
                 null_log_evidence=self.null_log_evidence,
                 n_iter_before_update=self.n_iter_before_update,
                 multimodal=self.multimodal,
                 max_modes=self.max_modes,
                 mode_tolerance=self.mode_tolerance,
                 seed=self.seed,
                 verbose=self.verbose,
                 resume=self.resume,
                 context=self.context,
                 write_output=self.write_output,
                 log_zero=self.log_zero,
                 max_iter=self.max_iter,
                 init_MPI=self.init_MPI)
        logger.info("MultiNest complete")

        self.output_results(during_analysis=False)
        self.output_pdf_plots()

        constant = self.most_likely_instance_from_summary()
        constant += self.constant
        variable = self.variable.mapper_from_gaussian_tuples(
            tuples=self.gaussian_priors_at_sigma_limit(self.sigma_limit))

        analysis.visualize(instance=constant, suffix=None, during_analysis=False)

        return Result(constant=constant, figure_of_merit=self.max_likelihood_from_summary(), variable=variable)

    def open_summary_file(self):
        """Open the MultiNest summary file for reading, positioned at offset 1.

        The caller is responsible for closing the returned file object.
        """
        summary_file = open(self.file_summary)
        summary_file.seek(1)
        return summary_file

    def read_vector_from_summary(self, number_entries, offset):
        """Parse ``number_entries`` consecutive fixed-width floats from the MultiNest summary file.

        Starting from offset 1, ``2 + offset * prior_count`` characters are skipped. Then ``number_entries`` fields of
        28 characters each are parsed as floats.

        Parameters
        ----------
        number_entries: int
            How many 28-character fields to parse.
        offset: int
            Multiplied by ``self.variable.prior_count`` to give the number of characters to skip (on top of 2).

        Returns
        -------
        list of float
        """
        # ``with`` guarantees the file is closed even if a field fails to parse as a float.
        with self.open_summary_file() as summary:
            summary.seek(1)
            summary.read(2 + offset * self.variable.prior_count)
            return [float(summary.read(28)) for _ in range(number_entries)]

    def most_probable_from_summary(self):
        """
        Return the most probable model parameters from the MultiNest summary file.

        These are the ``prior_count`` values read at offset 0 of the summary file.
        """
        entry_count = self.variable.prior_count
        return self.read_vector_from_summary(offset=0, number_entries=entry_count)

    def most_likely_from_summary(self):
        """
        Return the most likely model parameters from the MultiNest summary file.

        These are the ``prior_count`` values read at offset 56 of the summary file.
        """
        entry_count = self.variable.prior_count
        return self.read_vector_from_summary(offset=56, number_entries=entry_count)

    def max_likelihood_from_summary(self):
        """Return the first of the two values stored at offset 112 of the summary file."""
        first_value, _ = self.read_vector_from_summary(number_entries=2, offset=112)
        return first_value

    def max_log_likelihood_from_summary(self):
        """Return the second of the two values stored at offset 112 of the summary file."""
        _, second_value = self.read_vector_from_summary(number_entries=2, offset=112)
        return second_value

    def most_probable_instance_from_summary(self):
        """Build a model instance from the most probable parameters in the summary file."""
        return self.variable.instance_from_physical_vector(self.most_probable_from_summary())

    def most_likely_instance_from_summary(self):
        """Build a model instance from the most likely parameters in the summary file."""
        return self.variable.instance_from_physical_vector(self.most_likely_from_summary())

    def gaussian_priors_at_sigma_limit(self, sigma_limit):
        """Compute the (mean, sigma) tuples of the Gaussian priors the next phase should be initialized with.

        The mean of each tuple is the parameter's most probable value. The sigma is the larger of the distances from
        that mean to the upper and lower limits of its PDF at ``sigma_limit``.

        Parameters
        -----------
        sigma_limit : float
            The sigma limit within which the PDF is used to estimate errors (e.g. sigma_limit = 1.0 uses 0.6826 of the \
            PDF).
        """
        means = self.most_probable_from_summary()
        uppers = self.model_at_upper_sigma_limit(sigma_limit)
        lowers = self.model_at_lower_sigma_limit(sigma_limit)

        return [(mean, max(upper - mean, mean - lower)) for mean, upper, lower in zip(means, uppers, lowers)]

    def model_at_sigma_limit(self, sigma_limit):
        """Return the limits of each parameter's 1D marginalised PDF at ``sigma_limit``.

        ``erf(sigma_limit / sqrt(2))`` is the fraction of a Gaussian lying within +/- sigma_limit; that fraction is
        passed to GetDist's ``getLimits``. Index 0 of each returned item is used as the lower limit and index 1 as the
        upper limit.
        """
        limit = math.erf(0.5 * sigma_limit * math.sqrt(2))
        # self.pdf reloads the chains from disk on every access, so load them once for all parameters.
        pdf = self.pdf
        densities_1d = [pdf.get1DDensity(name) for name in pdf.getParamNames().names]
        return [density.getLimits(limit) for density in densities_1d]

    def model_at_upper_sigma_limit(self, sigma_limit):
        """Return the upper limit of every parameter's 1D marginalised PDF at ``sigma_limit``.

        Parameters
        -----------
        sigma_limit : float
            The sigma limit within which the PDF is used to estimate errors (e.g. sigma_limit = 1.0 uses 0.6826 of the \
            PDF).
        """
        return [limits[1] for limits in self.model_at_sigma_limit(sigma_limit)]

    def model_at_lower_sigma_limit(self, sigma_limit):
        """Return the lower limit of every parameter's 1D marginalised PDF at ``sigma_limit``.

        Parameters
        -----------
        sigma_limit : float
            The sigma limit within which the PDF is used to estimate errors (e.g. sigma_limit = 1.0 uses 0.6826 of the \
            PDF).
        """
        return [limits[0] for limits in self.model_at_sigma_limit(sigma_limit)]

    def model_errors_at_sigma_limit(self, sigma_limit):
        """Return, per parameter, the width (upper - lower) of its PDF interval at ``sigma_limit``."""
        uppers = self.model_at_upper_sigma_limit(sigma_limit=sigma_limit)
        lowers = self.model_at_lower_sigma_limit(sigma_limit=sigma_limit)
        return [upper - lower for upper, lower in zip(uppers, lowers)]

    def weighted_sample_instance_from_weighted_samples(self, index):
        """Return (model instance, weight, likelihood) for one weighted sample.

        The sample's raw parameter vector is also cached on ``self._weighted_sample_model``.

        Parameters
        -----------
        index : int
            The index of the weighted sample to return.
        """
        sample_vector, sample_weight, sample_likelihood = self.weighted_sample_model_from_weighted_samples(index)

        self._weighted_sample_model = sample_vector

        sample_instance = self.variable.instance_from_physical_vector(sample_vector)
        return sample_instance, sample_weight, sample_likelihood

    def weighted_sample_model_from_weighted_samples(self, index):
        """From a weighted sample return the model (parameter list), weight and likelihood.

        NOTE: GetDist reads the log likelihood from the weighted_sample.txt file (column 2), which are defined as \
        -2.0*likelihood. This routine converts these back to likelihood.

        Parameters
        -----------
        index : int
            The index of the weighted sample to return.
        """
        # self.pdf reloads the chains from disk on every access; load them once rather than three times.
        pdf = self.pdf
        return list(pdf.samples[index]), pdf.weights[index], -0.5 * pdf.loglikes[index]

    def output_pdf_plots(self):
        """Export GetDist PDF plots of the MultiNest samples to ``phase_path + 'image/'``.

        The plots made are controlled by the ``plot_pdf_1d_params`` and ``plot_pdf_triangle`` options in the ``output``
        section of the general config. A failed triangle plot is logged (with its traceback) and otherwise ignored.
        """
        import getdist.plots
        pdf_plot = getdist.plots.GetDistPlotter()

        plot_pdf_1d_params = conf.instance.general.get('output', 'plot_pdf_1d_params', bool)

        if plot_pdf_1d_params:

            # self.pdf reloads the chains from disk on every access; load them once for all 1D plots.
            pdf = self.pdf
            for param_name in self.param_names:
                pdf_plot.plot_1d(roots=pdf, param=param_name)
                pdf_plot.export(fname=self.phase_path + 'image/pdf_' + param_name + '_1D.png')

        plt.close()

        plot_pdf_triangle = conf.instance.general.get('output', 'plot_pdf_triangle', bool)

        if plot_pdf_triangle:

            try:
                pdf_plot.triangle_plot(roots=self.pdf)
                pdf_plot.export(fname=self.phase_path + 'image/pdf_triangle.png')
            except Exception:
                # Best effort: a failed triangle plot must not abort the run, but keep the actual error visible.
                logger.warning('The PDF triangle of this non-linear search could not be plotted. This is most likely '
                               'due to a lack of smoothness in the sampling of parameter space. Sampler further by '
                               'decreasing the parameter evidence_tolerance.', exc_info=True)

        plt.close()

    def output_results(self, during_analysis=False):

        if os.path.isfile(self.file_summary):

            with open(self.file_results, 'w') as results:

                max_likelihood = self.max_likelihood_from_summary()

                results.write('Most likely model, Likelihood = ' + str(max_likelihood) + '\n')
                results.write('\n')

                most_likely = self.most_likely_from_summary()

                if len(most_likely) != self.variable.prior_count:
                    raise exc.MultiNestException('MultiNest and GetDist have counted a different number of parameters.'
                                                 'See github issue https://github.com/Jammy2211/PyAutoLens/issues/49')

                for i in range(self.variable.prior_count):
                    line = self.param_names[i]
                    line += ' ' * (60 - len(line)) + str(most_likely[i])
                    results.write(line + '\n')

                if during_analysis is False:

                    most_probable = self.most_probable_from_summary()

                    lower_limit = self.model_at_lower_sigma_limit(sigma_limit=3.0)
                    upper_limit = self.model_at_upper_sigma_limit(sigma_limit=3.0)

                    results.write('\n')
                    results.write('Most probable model (3 sigma limits)' + '\n')
                    results.write('\n')

                    for i in range(self.variable.prior_count):
                        line = self.param_names[i]
                        line += ' ' * (60 - len(line)) + str(most_probable[i]) + ' (' + str(lower_limit[i]) + ', ' + \
                                str(upper_limit[i]) + ')'
                        results.write(line + '\n')

                    lower_limit = self.model_at_lower_sigma_limit(sigma_limit=1.0)
                    upper_limit = self.model_at_upper_sigma_limit(sigma_limit=1.0)

                    results.write('\n')
                    results.write('Most probable model (1 sigma limits)' + '\n')
                    results.write('\n')

                    for i in range(self.variable.prior_count):
                        line = self.param_names[i]
                        line += ' ' * (60 - len(line)) + str(most_probable[i]) + ' (' + str(lower_limit[i]) + ', ' + \
                                str(upper_limit[i]) + ')'
                        results.write(line + '\n')

                results.write('\n')
                results.write('Constants' + '\n')
                results.write('\n')

                constant_names = self.constant_names
                constants = self.variable.constant_tuples_ordered_by_id

                for i in range(self.variable.constant_count):
                    line = constant_names[i]
                    line += ' ' * (60 - len(line)) + str(constants[i][1].value)
