#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import os
import argparse
import h5py
import logging
import numpy
import sys
import pycbc
from pycbc import results
from pycbc.inference.option_utils import ParseLabelArg
from pycbc.inference.io import ResultsArgumentParser, results_from_cli

# Build the command-line interface. ResultsArgumentParser is the shared
# parser from pycbc.inference.io; it presumably contributes the common
# options read later in this script (opts.input_file, opts.parameters) --
# NOTE(review): confirm against pycbc.inference.io.
parser = ResultsArgumentParser(
    description="Makes a table of posterior results.")
parser.add_argument("--verbose", action="store_true", default=False,
    help="Print logging info.")
# The script writes an HTML table, not a plot, so the help text says so.
parser.add_argument("--output-file", type=str, required=True,
    help="Path to the output HTML file.")
# ParseLabelArg splits each PARAM[:LABEL] entry; the script later reads both
# opts.print_metadata and opts.print_metadata_labels.
parser.add_argument("--print-metadata", nargs="+", metavar="PARAM[:LABEL]",
                    default=[],
                    action=ParseLabelArg,
                    help="Add metadata information after the parameter table. "
                         "Any parameter stored in any of the file's attrs "
                         "may be printed. To specify an attribute in a "
                         "sub-group, prepend the parameter with the group  "
                         "name in the file. A label may be provided for the "
                         "html output; otherwise, the parameter name will be "
                         "used. If nothing provided, will just print the "
                         "number of posterior samples that were used (this is "
                         "always printed).")
# Exactly three percentiles: lower bound, central value, upper bound. They
# are sorted after parsing, so the order given on the command line is free.
parser.add_argument("--percentiles", type=float, nargs=3,
                    default=[5, 50, 95],
                    help="Percentiles to calculate. Must provide 3 values. "
                         "Default is 5 50 95 (the median with 90%% credible "
                         "interval).")

# parse the command line
opts = parser.parse_args()

pycbc.init_logging(opts.verbose)

# only a single results file is supported
ninputs = len(opts.input_file)
if ninputs > 1:
    raise ValueError("This program can only handle one file at a time")

# make sure the loglikelihood and logprior are included in the parameters
# so that we can get the maxL and MAP values; keep a copy of what the user
# actually asked for so that only those parameters end up in the table
requested_params = list(opts.parameters)
for stat in ('loglikelihood', 'logprior'):
    if stat not in requested_params:
        opts.parameters.append(stat)

# load the results
fp, parameters, labels, samples = results_from_cli(opts)
# replace with the requested parameters
parameters = requested_params

# get the index of the map and maxl; both default to None so that the table
# loop can print a placeholder when the sampler did not save the
# loglikelihood and/or the logprior (previously mapidx was left undefined
# when the loglikelihood was missing, which raised a NameError later on)
maxlidx = None
mapidx = None
if 'loglikelihood' in samples.fieldnames:
    maxlidx = samples['loglikelihood'].argmax()
    if 'logprior' in samples.fieldnames:
        # the sum is requested from the samples object by expression string
        mapidx = samples['loglikelihood+logprior'].argmax()

# make sure the percentiles are sorted
opts.percentiles.sort()

# Build one table row per requested parameter:
#   [label, "median -err +err", MAP value, maxL value]

# the percentile levels do not change between parameters, so sort them once
percentile_levels = sorted(opts.percentiles)

table = []
for param in parameters:
    row = [labels[param]]
    # calculate the score at each percentile in a single call; the levels
    # are ascending, so the results unpack as (lower, central, upper)
    x = samples[param]
    values_min, values_med, values_max = numpy.percentile(
        x, percentile_levels)
    negerror = values_med - values_min
    poserror = values_max - values_med
    fmt = '${0}$'.format(results.format_value(
        values_med, negerror, plus_error=poserror, use_scientific_notation=5))
    row.append(fmt)
    # we'll use the error from the credible interval to determine how many
    # significant figures to print for the map and maxl values
    error = min(negerror, poserror)
    # get the maxl and map values; the order here must match the
    # "Maximum Posterior", "Maximum Likelihood" column order of the header
    for idx in [mapidx, maxlidx]:
        if idx is not None:
            val = x[idx]
            fmt = '${0}$'.format(results.format_value(
                val, error, use_scientific_notation=5, include_error=False))
        else:
            # sampler didn't provide a loglikelihood or logprior, so just
            # enter nothing
            fmt = '--'
        row.append(fmt)
    # add to the table
    table.append(row)

# create table header; the width of the credible interval is the distance
# between the outermost percentiles. The 'g' format prints whole numbers
# without a decimal point (90 -> "90") but does not truncate fractional
# widths (e.g. 84.1 - 15.9 -> "68.2"), which int() used to do.
interval = opts.percentiles[-1]-opts.percentiles[0]
headers = ["Parameter",
           "{0:g}% Credible Interval".format(interval),
           "Maximum Posterior",
           "Maximum Likelihood"
           ]

# add mathjax header to display latex
html = results.mathjax_html_header() + '\n%s'%(
    str(results.static_table(table, headers) ))

# add extra metadata; template for one "<label>: <value>" line
mdatatmplt = '<h4><b>{}:</b> {}</h4>'
def formatattr(fp, attr):
    """Look up an attribute in the results file and make it printable.

    Parameters
    ----------
    fp : mapping of group path -> object with an ``attrs`` mapping
        The open results file. Only ``fp[group].attrs[name]`` is used.
    attr : str
        Attribute name, optionally prefixed with the path of the group that
        holds it (e.g. ``'name'``, ``'group/name'`` or ``'/group/name'``).
        An attribute without a group prefix is read from the root group.

    Returns
    -------
    object
        Arrays are returned as a single comma-separated string; byte strings
        are decoded to text; any other value is returned unchanged.
    """
    # Group paths inside the file always use '/', whatever the host OS, so
    # split with posixpath rather than the platform-dependent os.path.
    import posixpath

    def _to_text(item):
        # byte strings would otherwise be rendered as "b'...'" in the HTML
        if isinstance(item, bytes):
            return item.decode('utf-8', 'replace')
        return item

    group = posixpath.dirname(attr)
    if not group.startswith('/'):
        group = '/' + group
    val = fp[group].attrs[posixpath.basename(attr)]
    if isinstance(val, numpy.ndarray):
        return ', '.join('{}'.format(_to_text(x)) for x in val)
    return _to_text(val)
# one line per attribute requested with --print-metadata, using the label
# given on the command line
metadata = [mdatatmplt.format(opts.print_metadata_labels[p],
                              formatattr(fp, p))
            for p in opts.print_metadata]
# add the number of posterior samples
metadata.append(mdatatmplt.format('Number of posterior samples', samples.size))

html += '\n'.join(metadata) + '<br />\n<br />\n'

# save HTML table. The caption numbers use the 'g' format so that whole
# percentiles print as before ("95th") while fractional ones (e.g. 2.5 and
# 97.5) are no longer truncated by int().
results.save_fig_with_metadata(
    html, opts.output_file, {},
    cmd=" ".join(sys.argv),
    title="Parameter Estimates",
    caption="Summary of parameter estimates. The {0:g}-percent credible "
            "interval is the {1:g}th +/- {2:g}th/{3:g}th percentiles. "
            "The maximum posterior parameters are the parameters "
            "with the largest likelihood * prior. The maximum "
            "likelihood parameters are the parameters with the "
            "largest likelihood."
            .format(interval, opts.percentiles[1],
                    opts.percentiles[2], opts.percentiles[0]))
