#!/usr/bin/env python
""" Validate and generate diagnostic plots for a inference file using the
test posterior model.
"""
import sys
import numpy
import argparse
from matplotlib import use; use('Agg')
import pylab
from pycbc.inference.option_utils import prior_from_config
from pycbc.inference import models, io
from scipy.stats import gaussian_kde, ks_2samp
from pycbc.io import FieldArray
numpy.random.seed(0)

# Command-line interface. The input file, output file and p-value threshold
# are all used unconditionally further down, so they are marked as required:
# a missing option is reported by argparse immediately instead of surfacing
# later as an obscure failure (e.g. comparing a p-value against None).
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--input-file', required=True,
                    help='inference posterior file')
parser.add_argument('--output-file', required=True,
                    help='diagnostic plot')
parser.add_argument('--p-value-threshold', required=True, type=float,
                    help='minimum ks test p-value')
parser.add_argument('--ind-samples', help='use only this number of samples',
                    default=1000, type=int)
args = parser.parse_args()

# Number of reference samples to draw from the model's kde
size = int(1e6)
d1 = io.loadfile(args.input_file, 'r')

# We directly recreate the model and prior from the stored
# config to ensure the same configuration
config = d1.read_config_file()

prior = prior_from_config(config)
model = models.read_from_config(config)

# Draw reference samples directly from the kde; row i of the draw is used
# as the samples for the i-th entry of model.variable_params
draw = model.kde.resample(size=size)
data = {v: draw[i, :] for i, v in enumerate(model.variable_params)}
ref = FieldArray.from_kwargs(**data)

# Apply the prior bounds to ensure kde leakage is not a concern. Every
# parameter of every prior distribution is cut (not only the first one), so
# priors that cover several parameters are fully enforced.
for dist in prior.distributions:
    for param in dist._params:
        bound = dist._bounds[param]
        # keep only reference samples strictly inside (min, max)
        ref = ref[(bound.min < ref[param]) & (ref[param] < bound.max)]

nparam = len(model.variable_params)
# squeeze=False makes subplots always return a 2D array of axes, so the loop
# below also works when there is only a single parameter (otherwise a bare
# Axes object would be returned, which cannot be iterated over).
fig, axs = pylab.subplots(1, nparam, figsize=[6*nparam, 4], dpi=100,
                          squeeze=False)

result = d1.read_samples(model.variable_params)
# the samples are now in memory; the input file is not used again
d1.close()

# Compare each parameter's sampler output against the reference samples
# with a two-sample KS test and overlay the two histograms.
failed = False
for param, ax in zip(model.variable_params, axs[0]):
    samples = result[param]

    # Drawing without replacement cannot return more values than exist, so
    # cap the request at the number of available samples.
    nsub = min(args.ind_samples, len(samples))
    if nsub < args.ind_samples:
        sys.stderr.write(
            "warning: only {} samples available for {}, requested {}\n"
            .format(nsub, param, args.ind_samples))
    rpart = numpy.random.choice(samples, replace=False, size=nsub)

    kv, pvalue = ks_2samp(ref[param], rpart)
    print("{}, p-value={:.3f}".format(param, pvalue))

    pylab.sca(ax)
    pylab.hist(ref[param], density=True, bins=30, label='reference')
    # the histogram shows all sampler samples, not just the KS subset
    pylab.hist(samples, density=True, bins=30, alpha=0.5, label='sampler')
    pylab.title('KS p-value = {:.4f}'.format(pvalue))
    pylab.xlabel(param)
    pylab.legend()
    ax.get_yaxis().set_visible(False)

    # a single parameter below the threshold fails the whole validation
    if pvalue < args.p_value_threshold:
        failed = True

pylab.savefig(args.output_file)
# exit status 1 if any parameter failed the KS test, 0 otherwise
sys.exit(1 if failed else 0)
