#! /usr/bin/env python

# Copyright (C) 2020 Collin D. Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Performs a PP test on a collection of parameters, writing results out to
an html table.
"""

import argparse
import logging
import numpy
import sys
from scipy import stats
import pycbc
from pycbc import results
from pycbc.inference.option_utils import add_injsamples_map_opt
from pycbc.inference.io import (ResultsArgumentParser, results_from_cli,
                                injections_from_cli)

# Build the command-line interface on top of the shared results parser; the
# module docstring doubles as the program description.
parser = ResultsArgumentParser(description=__doc__)
parser.add_argument("--verbose", action="store_true", default=False,
    help="Print logging info.")
# This program writes an HTML table rather than a plot, so the help text
# describes the output accordingly.
parser.add_argument("--output-file", type=str, required=True,
    help="Path to output HTML file.")
# register the extra option(s) supplied by add_injsamples_map_opt
add_injsamples_map_opt(parser)

# parse the command line
opts = parser.parse_args()

# configure logging; verbosity is controlled by the --verbose flag
pycbc.init_logging(opts.verbose)

# load the samples (and the parameter names/labels) selected on the
# command line
logging.info('Loading parameters')
_, parameters, labels, samples = results_from_cli(opts)

# wrap a lone samples object in a list so the loop below can always iterate
if not isinstance(samples, list):
    samples = [samples]

logging.info("Getting percentiles of injections")
# maps parameter name -> list of percentiles, one entry per input file
measured_percentiles = {}
# taken before the loop, since opts.input_file is overwritten on each pass
ninjections = len(opts.input_file)
for fname, posterior in zip(opts.input_file, samples):
    # restrict opts to the current file so that only its injection is loaded
    opts.input_file = fname
    injection = injections_from_cli(opts)

    for name in parameters:
        truth = injection[name]
        draws = posterior[name]
        # percentile (0-100) of the samples that are <= the injected value
        pct = stats.percentileofscore(draws, truth, kind='weak')
        measured_percentiles.setdefault(name, []).append(pct)

logging.info("Performing percentile-percentile test")
# one KS p-value per parameter; reused for the summary test on p-values
p_values = numpy.zeros(len(parameters))
# rows of the output table: [label, KS statistic, p-value]
table = []
# all numbers in the table are shown with three decimal places
fmt = '{:.3f}'
for ii, param in enumerate(parameters):
    # percentileofscore reports values on [0, 100]; rescale to [0, 1] so they
    # can be compared against the standard uniform distribution
    meas = numpy.array(measured_percentiles[param]) / 100.
    # perform ks test (no pre-sorting is needed; kstest works on the raw
    # sample)
    ks, p = stats.kstest(meas, 'uniform')
    p_values[ii] = p
    table.append([labels[param], fmt.format(ks), fmt.format(p)])

# do the p-value of p-values test if more than one parameter provided:
# the per-parameter p-values are themselves tested against a uniform
# distribution, and the result is appended to the table as a bold summary row
if len(parameters) > 1:
    summary_ks, summary_p = stats.kstest(p_values, 'uniform')
    table.append(['<b>p-values</b>',
                  '<b>'+fmt.format(summary_ks)+'</b>',
                  '<b>'+fmt.format(summary_p)+'</b>'])
    # explanation of the summary row, appended to the caption of the table
    extra_caption = (
           ' If all parameters '
           'satisfy the percentile-percentile test, then the distribution of '
           'p-values should be uniform. The p-value of p-values is obtained '
           'by applying the KS test to the distribution of p-values. The '
           'smaller this value, the lower the probability of obtaining the '
           'observed distribution of parameters assuming that the analysis '
           'did satisfy the expectation that $X$ percent of injections fall '
           'in the $X$ percent credible interval.')
else:
    # a single parameter has no distribution of p-values to test
    extra_caption = ''

# create table header
# NOTE(review): the doubled braces look like a str.format escape, but this
# string is not formatted in this script -- confirm whether static_table
# relies on them before changing to single braces.
headers = ["Parameter", "$D_{{KS}}$", "p-value"]

# add mathjax header to display latex, followed by the rendered table
html = results.mathjax_html_header() + '\n' + str(
    results.static_table(table, headers))

# add the number of injections that were used
mdatatmplt = '<h4><b>{}:</b> {}</h4>'
metadata = [mdatatmplt.format('Number of injections', ninjections)]

html += '\n'.join(metadata) + '<br />\n<br />\n'

# explanation of the table columns; the p-value-of-p-values explanation is
# appended only when that summary row was added to the table
caption = (
    'Percentile-percentile test results. The value of the KS '
    'statistic $D_{KS}$ gives the maximum distance '
    'between the measured CDF and the '
    'expected (uniform) CDF. The associated two-tailed p-value gives '
    'the probability of getting a maximum distance larger than the '
    'observed $D_{KS}$ '
    'assuming that the measured CDF is the same as the expected. '
    'In other words, the larger (smaller) the p-value ($D_{KS}$), the '
    'more likely the measured distribution is the same as the '
    'expected.' + extra_caption)

# save HTML table, recording the command line that produced it
results.save_fig_with_metadata(
    html, opts.output_file, {},
    cmd=" ".join(sys.argv),
    title="Percentile-Percentile Test",
    caption=caption)
