#!/usr/bin/env python

# Copyright (C) 2016 Miriam Cabero Mueller
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Creates a movie showing how a sampler evolves.

To determine the length of the movie, you can either specify a --frame-number
or a --frame-step. The former specifies the number of frames to use in the
movie.  In this case, the code will load samples stored in the given file,
thinned such that the number of frames is <= the frame number. The latter
specifies how many samples to skip between frames. In both cases, you can
provide a --start-index or --end-index to control the range of samples to
consider. These values, along with the --frame-step, are specified in terms of
the index of samples on file. If the file was thinned (the file's thinned_by
attribute is > 1), then a sample index i corresponds to the sampler's
iteration i * thinned_by. The iteration number will be printed on each frame
of the movie.
"""

#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import argparse
import logging
import subprocess
import os
import glob
from multiprocessing import Pool

import numpy

import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot

import pycbc.results
from pycbc import conversions
from pycbc import __version__
from pycbc.inference import (option_utils, io)

from pycbc.results.scatter_histograms import (create_multidim_plot,
                                              get_scale_fac)
from pycbc.results.plot import (add_style_opt_to_parser, set_style_from_cli)


def integer_logspace(start, end, num):
    """Generates an array of integers that are spaced approximately uniformly
    in log10 space between `start` and `end` (both inclusive).

    Rounding a log-spaced grid to integers yields repeated values at the low
    end of the range. Whenever that happens, the run of consecutive integers
    at the beginning is kept as is, and the remaining points are re-spaced
    between the next integer and `end`. This is repeated until no duplicates
    are left, so that the output has exactly `num` elements.

    Parameters
    ----------
    start : int
        Integer to start with; must be >= 0.
    end : int
        Integer to end with; must be >= start.
    num : int
        The number of integers to generate. Cannot be larger than the number
        of integers contained in ``[start, end]``.

    Returns
    -------
    array
        The output array of integers; has length `num`.

    Raises
    ------
    ValueError
        If `start` is negative, if `end` is smaller than `start`, or if
        `num` is larger than ``end - start + 1`` (in which case it is not
        possible to produce `num` distinct integers).
    """
    if start < 0:
        raise ValueError("start ({}) must be >= 0".format(start))
    if end < start:
        raise ValueError("end ({}) must be >= start ({})".format(end, start))
    if num > end - start + 1:
        raise ValueError("cannot generate {} unique integers between {} and "
                         "{}".format(num, start, end))
    # shift everything by one so that a start of 0 can be represented in log
    # space; the shift is undone after rounding
    start += 1
    end += 1
    out = numpy.zeros(num, dtype=int)
    x = numpy.round(numpy.logspace(numpy.log10(start), numpy.log10(end),
                       num=num)).astype(int) - 1
    dx = numpy.diff(x)
    start_idx = 0
    while (dx == 0).any():
        # collect the unique values up to the point that their
        # difference becomes > 1
        x = numpy.unique(x)
        dx = numpy.diff(x)
        stop_idx = numpy.where(dx > 1)[0][0]
        keep = x[:stop_idx]
        stop_idx += start_idx
        out[start_idx:stop_idx] = keep
        start_idx = stop_idx
        # regenerate, starting from the new starting point; keep holds
        # unshifted values, so +2 is the next integer in shifted coordinates
        num -= len(keep)
        start = keep[-1] + 2
        x = numpy.round(numpy.logspace(numpy.log10(start), numpy.log10(end),
                           num=num)).astype(int) - 1
        dx = numpy.diff(x)
    # whatever is left has no duplicates, so it fills the rest of the output
    out[start_idx:len(x)+start_idx] = x
    return out

# Build the command-line parser.
# we won't add thinning arguments nor iteration, since this is determined by
# the frame number/step options
skip_args = ['thin-start', 'thin-interval', 'thin-end', 'iteration']
parser = io.ResultsArgumentParser(description=__doc__, skip_args=skip_args)
parser.add_argument("--version", action="version", version=__version__,
                    help="show version number and exit")
# make frame number and frame step mutually exclusive; one of the two has to
# be given, since it determines the length of the movie
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--frame-number", type=int,
                   help="Maximum number of frames for the movie.")
group.add_argument("--frame-step", type=int,
                   help="Number of sample indices to skip between frames.")
parser.add_argument("--start-index", type=int, default=0,
                    help="The starting index of the samples to load. Must be "
                         "< the number of samples that are stored in the "
                         "file. Default is 0.")
parser.add_argument("--end-index", type=int, default=None,
                    help="The ending index of the samples to load. Must be < "
                         "the number of samples that are stored in the file. "
                         "Default is to load everything to the end.")
parser.add_argument("--log-steps", action="store_true", default=False,
                    help="If frame-number is specified, make the number of "
                         "samples between frames uniform in log10. This "
                         "provides more detail of the early iterations, when "
                         "the sampler is changing most rapidly. An error will "
                         "be raised if frame-number is not provided.")
parser.add_argument("--output-prefix", type=str, required=True,
                    help="Output path and prefix for the frame files "
                         "(without extension).")
parser.add_argument('--verbose', action='store_true',
                    help="Print logging messages.")
parser.add_argument('--dpi', type=int, default=200,
                    help='Set the dpi for each frame; default is 200')
parser.add_argument("--nprocesses", type=int, default=None,
                    help="Number of processes to use. If not given then "
                         "use maximum.")
parser.add_argument("--movie-file", type=str,
                    help="Path for creating the movie automatically after "
                         "generating all the frames. If a movie with the same "
                         "name already exists, it will remove it. Format: mp4. "
                         "Installation of FFMPEG required.")
# the frames are only removed when a movie is made, i.e., when --movie-file
# is given
parser.add_argument("--cleanup", action='store_true',
                    help="Delete all plots generated after creating the movie."
                         " Only works together with the --movie-file option.")
# add options for what plots to create
option_utils.add_plot_posterior_option_group(parser)
# add scatter and density configuration options
option_utils.add_scatter_option_group(parser)
option_utils.add_density_option_group(parser)
# add the option for picking the matplotlib style
add_style_opt_to_parser(parser)

opts = parser.parse_args()

# apply the requested matplotlib style and set up logging before doing any
# work
set_style_from_cli(opts)
pycbc.init_logging(opts.verbose)

# a movie is made from a single results file
if len(opts.input_file) > 1:
    raise ValueError("this program can only plot one file at a time")

# Open the file and work out what to plot; the samples themselves are read
# further down, once the range to load is known
logging.info('Loading parameters')
fp, parameters, labels, _ = io.results_from_cli(opts, load_samples=False)

# the number of samples on disk is the number of iterations divided by the
# interval with which the file was thinned
thinned_by = fp.thinned_by
nsamples = fp.niterations // thinned_by

# both ends of the requested range have to point to a sample in the file
start_index, end_index = opts.start_index, opts.end_index
if nsamples <= start_index:
    raise ValueError("given start-index {} is >= the number of samples in the "
                     "input file {}".format(start_index, nsamples))
if end_index is not None and nsamples <= end_index:
    raise ValueError("given end-index {} is >= the number of samples in the "
                     "input file {}".format(end_index, nsamples))

# log steps only make sense when the number of frames is fixed
if not opts.frame_number and opts.log_steps:
    raise ValueError("log-steps requires a non-zero frame-number to be "
                     "provided; see help for details.")

# with a frame number, every sample is loaded and the frames are picked
# afterwards; with a frame step, the step is used to thin while loading
thinint = 1 if opts.frame_number else opts.frame_step
# get the samples
# NOTE(review): the loaded array is taken to start at start_index on file
# (i.e., column k holds the sample stored at index start_index + k*thinint),
# which is what the frame-step branch below has always assumed -- confirm
# against samples_from_cli
samples = fp.samples_from_cli(opts, parameters, thin_start=start_index,
                              thin_interval=thinint, thin_end=end_index,
                              flatten=False)
if samples.ndim > 2:
    # multi-tempered samplers will return 3 dims, so flatten
    _, ii, jj = samples.shape
    samples = samples.reshape((ii, jj))

# pull out the samples we want
# In both branches `indices` ends up holding the index on file of every
# frame; it is used later on to get the iteration number of each frame.
if opts.frame_number:
    # thinint is 1 here, so the loaded columns map to consecutive indices on
    # file, from start_index up to last_index
    last_index = start_index + samples.shape[1] - 1
    if opts.log_steps:
        # frame-number is a maximum: we cannot have more frames than samples
        nframes = min(opts.frame_number, samples.shape[1])
        indices = integer_logspace(start_index, last_index, nframes)
    else:
        indices = numpy.unique(numpy.linspace(start_index, last_index,
                               num=opts.frame_number).astype(int))
    # convert the indices on file into columns of the loaded array; using the
    # file indices directly would skip (or run past the end of) the loaded
    # samples whenever start_index is not 0
    frame_columns = indices - start_index
    samples = samples[:, frame_columns]
else:
    # set the index counter based on what was loaded
    indices = numpy.arange(samples.shape[1]) * thinint + start_index


# Get z-values
if opts.z_arg is not None:
    logging.info("Getting samples for colorbar")
    # the snr is not loaded directly, it is derived from the loglikelihood
    if opts.z_arg == 'snr':
        z_arg = 'loglikelihood'
    else:
        z_arg = opts.z_arg
    # load with the same range and thinning that was used for the samples, so
    # that the two line up
    zsamples = fp.samples_from_cli(opts, z_arg, thin_start=start_index,
                                   thin_interval=thinint, thin_end=end_index,
                                   flatten=False)
    if opts.z_arg == 'snr':
        loglr = zsamples[z_arg] - zsamples.lognl
        zsamples[z_arg] = conversions.snr_from_loglr(loglr)
    zlbl = opts.z_arg_labels[opts.z_arg]
    if zsamples.ndim > 2:
        _, ii, jj = zsamples.shape
        zsamples = zsamples.reshape((ii, jj))
    if opts.frame_number:
        # keep the same columns that were kept for the samples
        zsamples = zsamples[:, frame_columns]
    zvals = zsamples[z_arg]
    show_colorbar = True
    # Set common min and max for colorbar in all plots
    if opts.vmin is None:
        vmin = zvals.min()
    else:
        vmin = opts.vmin
    if opts.vmax is None:
        vmax = zvals.max()
    else:
        vmax = opts.vmax
else:
    zvals = None
    zlbl = None
    vmin = vmax = None
    show_colorbar = False

# everything needed from the file has been read
fp.close()

# get injection values if desired
# expected_parameters maps a parameter name to the value that will be marked
# on the plots
expected_parameters = {}
if opts.plot_injection_parameters:
    injections = io.injections_from_cli(opts)
    for p in parameters:
        # check that all of the injections are the same
        try:
            vals = injections[p]
        except NameError:
            # injection doesn't have this parameter, skip
            # NOTE(review): presumably the injections container raises
            # NameError for unknown fields -- confirm
            logging.warning("Could not find injection parameter %s", p)
            continue
        unique_vals = numpy.unique(vals)
        if unique_vals.size != 1:
            raise ValueError("More than one injection found! To use "
                "plot-injection-parameters, there must be a single unique "
                "injection in all input files. Use the expected-parameters "
                "option to specify an expected parameter instead.")
        # passed: use the value for the expected
        expected_parameters[p] = unique_vals[0]

# get expected parameter values from command line; these take precedence over
# the values obtained from the injections
expected_parameters.update(option_utils.expected_parameters_from_cli(opts))
expected_parameters_color = opts.expected_parameters_color

logging.info('Choosing common characteristics for all figures')
# All frames share the same axis limits: start from what was given on the
# command line, then fall back on the extrema of the samples to be plotted
# for anything that was not specified
mins, maxs = option_utils.plot_ranges_from_cli(opts)
for p in parameters:
    if p not in mins:
        mins[p] = samples[p].min()
    if p not in maxs:
        maxs[p] = samples[p].max()

# set colors:
# the first color in the style's cycle is used for the 1d marginals
linecolor = list(matplotlib.rcParams['axes.prop_cycle'])[0]['color']
# histograms are drawn in white on a dark background, in black otherwise
hist_color = 'white' if opts.mpl_style == 'dark_background' else 'black'
# an explicitly given contour color always wins; if none was given, use white
# when the density is plotted, and the histogram color if not
if opts.contour_color:
    contour_color = opts.contour_color
elif opts.plot_density:
    contour_color = 'white'
else:
    contour_color = hist_color

# Make each figure
# the largest iteration number sets how many digits the sample number in the
# file names needs to be zero-padded to, so that the frames sort correctly
max_iter_num = indices[-1]*thinned_by + 1

def _make_frame(frame):
    """Plots a single frame of the movie and saves it as a png file.

    This is defined at the module level so that it can be mapped over the
    frames, either serially or with a pool of processes. The samples and all
    of the plot settings are read from module-level variables.

    Parameters
    ----------
    frame : int
        The number of the frame to plot; this selects the column of
        `samples` (and of `zvals`, if set) to use.

    Returns
    -------
    float
        The height of the figure divided by its width.
    """
    frame_samples = samples[:, frame]
    frame_zvals = None if zvals is None else zvals[:, frame]
    # zero-pad the iteration number so that the frame files sort correctly
    pad_width = len(str(max_iter_num))
    iter_num = str(indices[frame]*thinned_by + 1).zfill(pad_width)
    output = opts.output_prefix + '-{}.png'.format(iter_num)

    fig, _ = create_multidim_plot(
        parameters, frame_samples, labels=labels,
        mins=mins, maxs=maxs,
        # 1d marginals
        plot_marginal=opts.plot_marginal,
        line_color=linecolor, hist_color=hist_color,
        # scatter points and their colorbar
        plot_scatter=opts.plot_scatter,
        zvals=frame_zvals, show_colorbar=show_colorbar,
        cbar_label=zlbl, vmin=vmin, vmax=vmax,
        scatter_cmap=opts.scatter_cmap,
        # density and contours
        plot_density=opts.plot_density,
        plot_contours=opts.plot_contours,
        density_cmap=opts.density_cmap,
        contour_color=contour_color,
        use_kombine=opts.use_kombine_kde,
        # expected values
        expected_parameters=expected_parameters,
        expected_parameters_color=expected_parameters_color)

    # Write sample number; it is moved to the left if there is a colorbar
    text_x = 0.85 if show_colorbar else 0.9
    text_size = 8*get_scale_fac(fig)
    pyplot.annotate('Iteration {}'.format(iter_num), xy=(text_x, 0.95),
        xycoords='figure fraction', horizontalalignment='right',
        verticalalignment='top', fontsize=text_size)

    fig.savefig(output, bbox_inches='tight', dpi=opts.dpi)
    pyplot.close()
    return fig.get_figheight()/fig.get_figwidth()

# create the pool
# make_frame is the function that gets mapped over the frames
make_frame = _make_frame
if opts.nprocesses is None or opts.nprocesses > 1:
    pool = Pool(opts.nprocesses)
    mfunc = pool.map
else:
    # a single process was requested, so just use the built-in map
    pool = None
    mfunc = map

logging.info("Making frames")
try:
    # every call returns the height/width ratio of its figure; the ratio of
    # the first frame is used for the movie
    aspect_ratio = list(mfunc(make_frame, range(len(indices))))[0]
finally:
    # release the worker processes, even if making a frame failed
    if pool is not None:
        pool.close()
        pool.join()

if opts.movie_file:
    logging.info("Making movie")
    frame_files = opts.output_prefix + "*.png"
    # set the size of the movie from the aspect ratio of the frames
    frame_size = "1024x{}".format(int(1024*aspect_ratio))
    if os.path.isfile(opts.movie_file):
        os.remove(opts.movie_file)
    # check_call raises CalledProcessError if ffmpeg exits with an error, so
    # that the frames are not deleted below when no movie was created
    subprocess.check_call(["ffmpeg", "-pix_fmt", "yuv420p", "-s", frame_size,
                           "-pattern_type", "glob", "-i",
                           frame_files, opts.movie_file])
    if opts.cleanup:
        logging.info("Removing frames")
        for frame in glob.glob(frame_files):
            os.remove(frame)

logging.info('Done')
