#!/usr/bin/env python

# Copyright (C) 2016 Miriam Cabero Mueller, Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import argparse
import itertools
import logging
import sys

import numpy

import matplotlib
from matplotlib import (patches, use)

import pycbc
import pycbc.version
from pycbc.results.plot import (add_style_opt_to_parser, set_style_from_cli)
from pycbc.results import metadata
from pycbc.io import FieldArray
from pycbc import __version__
from pycbc import conversions
from pycbc.workflow import WorkflowConfigParser
from pycbc.inference import (option_utils, io)

from pycbc.results.scatter_histograms import create_multidim_plot

# switch matplotlib to the 'agg' backend
use('agg')

# build the command-line interface on top of the results parser
parser = io.ResultsArgumentParser()

# options that are specific to this program
parser.add_argument(
    "--version", action="version", version=__version__,
    help="Prints version information.")
parser.add_argument(
    "--output-file", type=str, required=True,
    help="Output plot path.")
parser.add_argument(
    "--verbose", action="store_true", default=False,
    help="Be verbose")
parser.add_argument(
    "--plot-prior", nargs="+", type=str,
    help="Plot the prior on the 1D marginal plots "
         "using the given config file(s).")
parser.add_argument(
    "--prior-nsamples", type=int, default=10000,
    help="The number of samples to use for plotting "
         "the prior. Default is 10000.")

# shared option groups, registered in this order: which plots to create,
# then the scatter configuration, then the density configuration
for add_option_group in (option_utils.add_plot_posterior_option_group,
                         option_utils.add_scatter_option_group,
                         option_utils.add_density_option_group):
    add_option_group(parser)

parser.add_argument(
    '--dpi', type=int, default=200,
    help="Set the DPI of the plot. Default is 200.")

# matplotlib style option
add_style_opt_to_parser(parser)

# read the command line
opts = parser.parse_args()

# apply the requested matplotlib style
set_style_from_cli(opts)

# configure logging according to --verbose
pycbc.init_logging(opts.verbose)

# read the file handle(s), parameter names, labels and samples requested
# on the command line
fps, parameters, labels, samples = io.results_from_cli(opts)

# make sure both are lists so that the input files can be looped over
if not isinstance(fps, list):
    fps = [fps]
if not isinstance(samples, list):
    samples = [samples]

# colorbar values: one array per input file, or None if no z-arg was given
if opts.z_arg is None:
    zvals = None
    zlbl = None
else:
    logging.info("Getting samples for colorbar")
    # 'snr' is not read directly; it is derived from the loglikelihood
    snr_requested = opts.z_arg == 'snr'
    if snr_requested:
        z_arg = 'loglikelihood'
    else:
        z_arg = opts.z_arg
    zlbl = opts.z_arg_labels[opts.z_arg]
    zvals = []
    for fp in fps:
        zsamples = fp.samples_from_cli(opts, parameters=z_arg)
        if snr_requested:
            # convert via the log likelihood ratio (loglikelihood - lognl)
            loglr = zsamples[z_arg] - zsamples.lognl
            zsamples[z_arg] = conversions.snr_from_loglr(loglr)
        zvals.append(zsamples[z_arg])

# when none of the plot types was requested, pick defaults from the
# number of parameters: marginal only for a single parameter, scatter
# plus marginal otherwise
plot_options = [opts.plot_marginal, opts.plot_scatter, opts.plot_density]
nothing_requested = not numpy.any(plot_options)
if nothing_requested:
    # FIXME: right now, if there are two parameters, both plot_scatter
    # and plot_marginal are wanted. One should have the option of giving
    # only plot_scatter, and that should be the default for two or more
    # parameters
    opts.plot_marginal = True
    if len(parameters) != 1:
        opts.plot_scatter = True

def _attach_static_params(target, attrs):
    """Set every static parameter listed in ``attrs`` as an attribute of
    ``target``, using the value stored under the same name in ``attrs``."""
    for static_name in attrs['static_params']:
        setattr(target, static_name, attrs[static_name])


if opts.plot_prior is not None:
    # the prior can only be drawn on the 1D marginal plots
    if not opts.plot_marginal:
        raise ValueError("prior may only be plotted on 1D marginal plot; "
                         "either turn on --plot-marginal, "
                         "or turn off --plot-prior")
    logging.info("Loading prior")
    cp = WorkflowConfigParser(opts.plot_prior)
    prior = option_utils.prior_from_config(cp)
    logging.info("Drawing samples from prior")
    prior_samples = prior.rvs(opts.prior_nsamples)
    # metadata is taken from the first input file only
    fp = fps[0]
    _attach_static_params(prior_samples, fp.attrs)
    # apply any parameter remapping recorded in the file
    if 'remapped_params' in fp.attrs:
        remapped_params = {}
        for func, param in fp.attrs['remapped_params']:
            try:
                remapped_params[param] = prior_samples[func]
            except (NameError, TypeError, AttributeError):
                # this remapping could not be evaluated on the prior
                # samples; leave it out
                pass
        prior_samples = FieldArray.from_kwargs(**remapped_params)
        # the new array needs the static parameters set again
        _attach_static_params(prior_samples, fp.attrs)

# plot ranges given on the command line, keyed by parameter
mins, maxs = option_utils.plot_ranges_from_cli(opts)

# any parameter without a range falls back on the extreme values found
# across the samples of all input files
for name in parameters:
    if name not in mins:
        lows = [file_samples[name].min() for file_samples in samples]
        mins[name] = numpy.array(lows).min()
    if name not in maxs:
        highs = [file_samples[name].max() for file_samples in samples]
        maxs[name] = numpy.array(highs).max()

# get injection values if desired; these are collected in
# expected_parameters, keyed by parameter name
expected_parameters = {}
if opts.plot_injection_parameters:
    injections = io.injections_from_cli(opts)

    if opts.pick_injection_by_time:
        # select the single injection whose 'tc' is closest to the time
        # of the posterior
        if 'tc' not in injections:
            raise ValueError("Couldn't determine injection time, tried tc")

        inj_time = injections['tc']

        # posterior time: the mean 'tc' of the first file's samples if
        # present, otherwise the trigger_time or tc_ref attribute of the
        # first file
        if 'tc' in samples[0]:
            pos_time = samples[0]['tc'].mean()
        elif 'trigger_time' in fps[0].attrs:
            pos_time = fps[0].attrs['trigger_time']
        elif 'tc_ref' in fps[0].attrs:
            pos_time = fps[0].attrs['tc_ref']
        else:
            raise ValueError("Couldn't find posterior time, "
                             "tried tc, tc_ref, and trigger_time attribute")
        # index of the injection nearest in time
        pick = abs(inj_time - pos_time).argmin()

    for p in parameters:
        try:
            vals = injections[p]
        except (NameError, TypeError, AttributeError):
            # injection doesn't have this parameter, skip
            # (logging.warn is a deprecated alias; use logging.warning)
            logging.warning("Could not find injection parameter %s", p)
            continue

        if opts.pick_injection_by_time:
            # reuse the values already retrieved above
            expected_parameters[p] = vals[pick]
        else:
            # check that all of the injections are the same
            unique_vals = numpy.unique(vals)
            if unique_vals.size != 1:
                raise ValueError("More than one injection found! To use "
                    "plot-injection-parameters, there must be a single unique "
                    "injection in all input files. Use the expected-parameters"
                    " option to specify an expected parameter instead.")

            # passed: use the value for the expected
            expected_parameters[p] = unique_vals[0]

# everything needed has been read, so the input files can be closed
for open_file in fps:
    open_file.close()

# expected values given on the command line are merged in last, so they
# replace injection values stored under the same parameter name
cli_expected = option_utils.expected_parameters_from_cli(opts)
expected_parameters.update(cli_expected)

# endlessly repeat the colors of the active matplotlib property cycle
color_cycle = [
    entry['color'] for entry in matplotlib.rcParams['axes.prop_cycle']
]
colors = itertools.cycle(color_cycle)

# draw the samples of every input file onto one shared figure
logging.info("Plotting")
hist_colors = []
for idx, file_samples in enumerate(samples):

    # the first call creates the figure; later calls draw on top of it
    if idx == 0:
        fig = axis_dict = None

    # next color of the cycle; used for the 1D marginal lines
    linecolor = next(colors)

    if len(opts.input_file) == 1:
        # single input file: histogram color is chosen to contrast with
        # the background of the style in use
        hist_color = ('white' if opts.mpl_style == 'dark_background'
                      else 'black')
        # histograms are filled when only one file is plotted
        fill_color = 'gray'
        # contour color: an explicit choice wins; otherwise white on top
        # of a density plot, else the same as the histogram color
        if opts.contour_color:
            contour_color = opts.contour_color
        elif opts.plot_density:
            contour_color = 'white'
        else:
            contour_color = hist_color
    else:
        # several input files: no fill, and contours and histograms share
        # the color of the 1D marginal lines
        fill_color = None
        contour_color = hist_color = linecolor

    # remember the color of this file for the legend
    hist_colors.append(hist_color)

    # gather the keyword arguments, then plot
    plot_kwds = dict(
        labels=labels, fig=fig, axis_dict=axis_dict,
        plot_marginal=opts.plot_marginal,
        marginal_percentiles=opts.marginal_percentiles,
        plot_scatter=opts.plot_scatter,
        zvals=None if zvals is None else zvals[idx],
        show_colorbar=opts.z_arg is not None,
        cbar_label=zlbl,
        vmin=opts.vmin, vmax=opts.vmax,
        scatter_cmap=opts.scatter_cmap,
        plot_density=opts.plot_density,
        plot_contours=opts.plot_contours,
        contour_percentiles=opts.contour_percentiles,
        density_cmap=opts.density_cmap,
        contour_color=contour_color,
        hist_color=hist_color,
        line_color=linecolor,
        fill_color=fill_color,
        use_kombine=opts.use_kombine_kde,
        mins=mins, maxs=maxs,
        expected_parameters=expected_parameters,
        expected_parameters_color=opts.expected_parameters_color,
    )
    fig, axis_dict = create_multidim_plot(parameters, file_samples,
                                          **plot_kwds)

# overlay the prior on the 1D marginal panels of the existing figure
if opts.plot_prior:
    # with several input files the prior takes the next color of the
    # cycle; with a single file the last histogram color is reused
    if len(opts.input_file) > 1:
        hist_color = next(colors)
    prior_kwds = dict(
        fig=fig, axis_dict=axis_dict,
        labels=labels,
        # marginals only: no percentiles, scatter, density or contours
        plot_marginal=True, marginal_percentiles=[],
        plot_scatter=False, plot_density=False, plot_contours=False,
        fill_color=None,
        marginal_title=False, marginal_linestyle=':',
        hist_color=hist_color,
        mins=mins, maxs=maxs,
    )
    fig, axis_dict = create_multidim_plot(parameters, prior_samples,
                                          **prior_kwds)

# with more than one input file, add a legend in the upper right that
# pairs each file's label with the color it was plotted in
if len(opts.input_file) > 1:
    legend_entries = [
        (color, opts.input_file_labels[fn])
        for color, fn in zip(hist_colors, opts.input_file)
    ]
    handles = [patches.Patch(color=color, label=label)
               for color, label in legend_entries]
    labels = [label for _, label in legend_entries]
    fig.legend(loc="upper right", handles=handles,
               labels=labels)

# apply the requested resolution
fig.set_dpi(opts.dpi)

# write the figure to disk; the full command line is passed along as
# metadata
command_line = " ".join(sys.argv)
save_kwds = {'bbox_inches': 'tight'}
metadata.save_fig_with_metadata(
    fig, opts.output_file,
    cmd=command_line,
    title="Posteriors",
    caption="Posterior probability density functions.",
    fig_kwds=save_kwds)

# all done
logging.info("Done")
