#!/usr/bin/env python

# Copyright (C) 2020 Collin Capano
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Makes a plot of MCMC parameters saved over checkpoint history.."""

import sys
import logging
import argparse

import numpy
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot

import pycbc
from pycbc.inference import io
from pycbc.results import metadata

# Command-line interface: two mandatory file paths, plus one boolean switch
# for each panel that can be added to the figure.
parser = argparse.ArgumentParser(description=__doc__)

# input and output locations (both must be given)
parser.add_argument(
    '--input-file', required=True, type=str,
    help='Path to the input HDF file.')
parser.add_argument(
    '--output-file', required=True, type=str,
    help='Name of the output plot.')

# panel switches; each one that is set adds a row to the figure
parser.add_argument(
    '-t', '--plot-checkpoint-dt', action='store_true',
    help='Plot the wall-clock time between checkpoints.')
parser.add_argument(
    '-a', '--plot-act', action='store_true',
    help='Plot ACT vs checkpoint iteration.')
parser.add_argument(
    '-n', '--plot-effective-nsamples', action='store_true',
    help=('Plot the number of effective samples '
          'versus checkpoint iteration.'))
parser.add_argument(
    '-b', '--plot-nchains-burned-in', action='store_true',
    help=('Plot the number of chains that were burned in versus '
          'checkpoint iteration. Note that for ensemble samplers, '
          'this will be all or nothing.'))

opts = parser.parse_args()

pycbc.init_logging(True)

# Each plot flag is a bool, so their sum is the number of panels requested.
nplots = sum([opts.plot_act, opts.plot_effective_nsamples,
              opts.plot_nchains_burned_in, opts.plot_checkpoint_dt])
if nplots == 0:
    raise ValueError("nothing to plot")

# load the data
logging.info('Loading data')

fp = io.loadfile(opts.input_file, 'r')
# Everything that reads from the file lives inside the try block so that the
# file is closed even when a requested history dataset is missing.
try:
    history = fp['sampler_info/checkpoint_history']
    # x-values shared by every panel
    iterations = history['niterations'][()]

    if opts.plot_checkpoint_dt:
        checkpoint_dt = history['checkpoint_dt'][()]

    if opts.plot_act:
        try:
            raw_acts = history['act'][()]
        except KeyError:
            raise ValueError("ACT history was not saved")
        if raw_acts.ndim == 2:
            # separate acts for each chain: the first axis is averaged over
            # and the second axis is what gets plotted against the
            # iterations. Only finite values enter the mean; a column with
            # no finite value at all is left at inf.
            acts = numpy.full(raw_acts.shape[1], numpy.inf)
            for ii in range(raw_acts.shape[1]):
                aa = raw_acts[:, ii]
                aa = aa[numpy.isfinite(aa)]
                if aa.size > 0:
                    acts[ii] = aa.mean()
        else:
            acts = raw_acts

    if opts.plot_effective_nsamples:
        nsamples = history['effective_nsamples'][()]

    if opts.plot_nchains_burned_in:
        try:
            burn_in_iter = history['burn_in_iteration'][()]
        except KeyError:
            raise ValueError("Burn-in history not saved")
        nchains_burned_in = numpy.zeros(iterations.size, dtype=int)
        nchains = fp.nchains
        # The dimensionality of the burn-in history does not change from one
        # checkpoint to the next, so it is tested once, outside the loops.
        if burn_in_iter.ndim == 1:
            # ensemble sampler; all or none
            for ii in range(nchains_burned_in.size):
                nchains_burned_in[ii] = nchains*(burn_in_iter[ii] > 0)
        else:
            # one entry per chain along the first axis: count the entries
            # with a positive burn-in iteration at each checkpoint
            for ii in range(nchains_burned_in.size):
                nchains_burned_in[ii] = (burn_in_iter[:, ii] > 0).sum()
finally:
    fp.close()

# plot
logging.info("Plotting")
# one row per requested quantity, stacked vertically
fig, axes = pyplot.subplots(nrows=nplots, figsize=(8, 3*nplots))
if nplots == 1:
    # with a single row pyplot.subplots returns a bare Axes rather than an
    # array; wrap it so that it can be indexed like the multi-row case
    axes = [axes]
# index of the panel currently being filled; incremented before each use
pi = -1
# pad the x-range by 2.5% of its width on either side so that the first and
# last markers do not sit on the frame
dx = iterations.max() - iterations.min()
if dx == 0:
    # only one distinct iteration value: use a unit width for the padding
    dx = 1
xmin = iterations.min() - 0.025*dx
xmax = iterations.max() + 0.025*dx

# The caption is built up alongside the panels so that the numbering in the
# text follows the order of the panels in the figure.
caption = ("Status of the sampler at each checkpoint iteration (indicated by "
           "the x markers). Shown are:")

if opts.plot_checkpoint_dt:
    pi += 1
    ax = axes[pi]
    # the caption reports minutes, hence the division by 60
    ax.plot(iterations, checkpoint_dt/60., lw=2, marker='x')
    ax.set_ylabel('wallclock dt (m)')
    ax.set_xlim(xmin, xmax)
    caption += (" ({}) the amount of wall-clock time between checkpoints "
                "(in minutes);".format(pi+1))

if opts.plot_act:
    pi += 1
    ax = axes[pi]
    ax.plot(iterations, acts, lw=2, marker='x', zorder=1)
    if raw_acts.ndim == 2:
        # plot each of the chains separately, as faint lines drawn
        # underneath the averaged curve
        for ii in range(raw_acts.shape[0]):
            ax.plot(iterations, raw_acts[ii, :], lw=1, color='C1',
                    alpha=0.3,
                    zorder=0)
    ax.set_ylabel('ACT')
    ax.set_xlim(xmin, xmax)
    caption += (" ({}) the autocorrelation time (ACT) as computed from the "
                "estimated burn-in iteration to the end of the chain(s) at "
                "the checkpoint;".format(pi+1))

if opts.plot_effective_nsamples:
    pi += 1
    ax = axes[pi]
    ax.plot(iterations, nsamples, lw=2, marker='x', zorder=1)
    ax.set_ylabel(r'eff. N samples')
    ax.set_xlim(xmin, xmax)
    caption += (" ({}) the estimated effective number of samples post "
                "burn-in at that checkpoint;".format(pi+1))

if opts.plot_nchains_burned_in:
    pi += 1
    ax = axes[pi]
    ax.plot(iterations, nchains_burned_in, lw=2, marker='x', zorder=1)
    # put a horizontal line at the total number of chains
    ax.axhline(nchains, ls='--', color='C3', zorder=0)
    ax.set_ylabel(r'N chains burned in')
    ax.set_xlim(xmin, xmax)
    caption += (" ({}) the number of chains that are burned in at the "
                "checkpoint (for ensemble samplers, this will either be all "
                "or nothing), and the total number of chains (red dashed "
                "line);".format(pi+1))

# only the bottom panel carries the x-axis label
axes[-1].set_xlabel('iteration')
# replace the trailing semicolon with a period
caption = caption[:-1] + "."

# common settings
for ii, ax in enumerate(axes):
    ax.grid(ls=':', zorder=-1)
    # turn off x ticks for all but the bottom
    if ii < len(axes) - 1:
        ax.set_xticklabels([])

# save
metadata.save_fig_with_metadata(
    fig, opts.output_file,
    cmd=" ".join(sys.argv),
    title="MCMC history",
    caption=caption,
    fig_kwds={'bbox_inches': 'tight'})
