#!/usr/bin/env python

# Copyright (C) 2017 Vaibhav Tiwari

# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This script reads the rates priors and posteriors and plots them
"""

__author__ = "Vaibhav Tiwari"
__email__ = "vaibhav.tiwari@ligo.org"
__version__ = "0.0"
__date__ = "31.10.2017"

import logging
import argparse
import pycbc.version

import h5py
import pylab
import numpy as np
import scipy.stats as ss
from matplotlib import cm
from pycbc.population import rates_functions as rf
from numpy import logaddexp, log, newaxis, expm1

# Parse command line.
# Every option is mandatory. The list-valued options (--posterior-samples,
# --population-models and --plot-labels) are paired up positionally with
# zip() later in the script, so they must all have the same length.
parser = argparse.ArgumentParser(
    description="Plot rate prior and posteriors.")
parser.add_argument('--prior-samples', required=True,
                    help="File containing rate prior samples")
parser.add_argument('--posterior-samples', nargs='+', required=True,
                    help="Files(s) containing samples for the rate posterior.")
parser.add_argument('--population-models', nargs='+', required=True,
                    help="Models to which posteriors belong "
                         "('lnm', 'imf', 'bns').")
parser.add_argument('--output-rates', dest='rate_file', required=True,
                    help="File saving all the rate posteriors into one.")
parser.add_argument('--rates-figure', required=True,
                    help="Name of file to draw rates prior and posterior.")
parser.add_argument('--pastro-figure', required=True,
                    help="Name of file to save p_astro figure.")
parser.add_argument('--plot-labels', nargs='+', required=True,
                    help="Labels for the population models.")

opts = parser.parse_args()

# Validate with parser.error() instead of assert: assertions are removed when
# Python runs with -O, and parser.error() prints the usage message and exits
# with a non-zero status, which is the expected behaviour for bad arguments.
if len(opts.posterior_samples) != len(opts.population_models):
    parser.error("Unequal number of posterior files and population models!")

# zip() stops at the shortest input, so a missing label would silently drop
# a model from the plots. Reject the mismatch up front.
if len(opts.plot_labels) != len(opts.population_models):
    parser.error("Unequal number of plot labels and population models!")

# Save rate posteriors in one file.
# For each (posterior file, model) pair, copy the '<model>/Rf' and
# '<model>/Lf' datasets into a group named after the model in the output file.
with h5py.File(opts.rate_file, 'w') as out:
    for fname, model in zip(opts.posterior_samples, opts.population_models):

        # Use a context manager so the input file is closed even when a
        # dataset is missing or reading fails. The [:] slices load the data
        # into memory, so it stays valid after the file is closed.
        with h5py.File(fname, "r") as f:
            Rf = f[model + '/Rf'][:]
            Lf = f[model + '/Lf'][:]

        pl = out.create_group(model)
        pl.create_dataset('Rf', data=Rf, compression='gzip')
        pl.create_dataset('Lf', data=Lf, compression='gzip')
    
# Make prior/posterior plot -- estimate p_astro
#
# For every population model this section
#   1. fits the prior and posterior rate samples with rf.fit and draws the
#      corresponding skew-normal densities (dashed = prior, solid = posterior),
#   2. computes one p_astro value per entry of 'data/log_fg_bg_ratio' and
#      appends the resulting array to p_astro.
p_astro = []
pylab.figure()
color = iter(cm.rainbow(np.linspace(0, 1, len(opts.population_models))))
mods = zip(opts.posterior_samples, opts.population_models, opts.plot_labels)

# Context managers guarantee that the prior file and each posterior file are
# closed even if a dataset is missing or a computation below raises.
with h5py.File(opts.prior_samples, "r") as f:
    for fpost, model, lbl in mods:

        c = next(color)

        # Load everything needed from the posterior file up front; the [:]
        # slices copy the data into memory so the file can be closed at once.
        with h5py.File(fpost, "r") as fpo:
            Rfpr, Rfpo = f[model + '/Rf'][:], fpo[model + '/Rf'][:]
            Lfpo, Lbpo = fpo[model + '/Lf'][:], fpo[model + '/Lb'][:]
            log_fg_ratios = fpo['data/log_fg_bg_ratio'][:]

        # The three values returned by rf.fit are passed to ss.skewnorm.pdf
        # as shape, location and scale, evaluated on a log(R) grid.
        # NOTE(review): presumably rf.fit fits a skew-normal to log(R) --
        # confirm against pycbc.population.rates_functions.
        prior_alpha, prior_mu, prior_sigma = rf.fit(Rfpr)
        post_alpha, post_mu, post_sigma = rf.fit(Rfpo)

        # Both curves are drawn on a grid spanning the range of the *prior*
        # samples in log(R).
        log_R = np.log(Rfpr)
        xs = np.linspace(log_R.min(), log_R.max(), 200)
        pylab.plot(np.exp(xs),
                   ss.skewnorm.pdf(xs, prior_alpha, prior_mu, prior_sigma),
                   '--', label=lbl + ' Prior', color=c)
        pylab.plot(np.exp(xs),
                   ss.skewnorm.pdf(xs, post_alpha, post_mu, post_sigma),
                   label=lbl + ' Posterior', color=c)

        # p_astro, computed in log space for numerical stability.
        # For posterior sample i and event j (r_j = log_fg_ratios[j]):
        #     p_ij = Lf_i * exp(r_j) / (Lf_i * exp(r_j) + Lb_i)
        # Rows index posterior samples, columns index events.
        # assumes Lf and Lb are 1-D arrays of equal length -- TODO confirm
        log_fg_term = log(Lfpo[:, newaxis]) + log_fg_ratios[newaxis, :]
        log_bg_term = log(Lbpo[:, newaxis])
        log_p_ij = log_fg_term - logaddexp(log_fg_term, log_bg_term)

        # Average p_ij over the posterior samples: log-sum-exp along axis 0
        # minus log(number of samples) gives the log of the mean.
        log_pastros = logaddexp.reduce(log_p_ij, axis=0) - \
                      log(Lfpo.shape[0])

        # 1 + expm1(x) equals exp(x); values are stored in descending order.
        # NOTE(review): the p_astro figure plots these sorted values against
        # log_fg_ratios as read from file, which only pairs up correctly if
        # 'data/log_fg_bg_ratio' is itself stored in descending order --
        # confirm.
        p_astro.append(1 + expm1(np.sort(log_pastros)[::-1]))

pylab.xscale('log')
pylab.xlabel(r'$R$ ($\mathrm{Gpc}^{-3} \, \mathrm{yr}^{-1}$)')
pylab.ylabel(r'$RP(R)$')
pylab.legend(loc='best')
pylab.savefig(opts.rates_figure)

# Second figure: 1 - p_astro against the log foreground/background ratio,
# one set of points per population model.
pylab.figure()
palette = cm.rainbow(np.linspace(0, 1, len(opts.population_models)))
for pas, lbl, shade in zip(p_astro, opts.plot_labels, palette):
    # log_fg_ratios is the array left over from the last posterior file read
    # earlier in the script.
    pylab.plot(log_fg_ratios, 1 - pas, '.', label=lbl, color=shade)

pylab.xlabel(r'$\log p(x\mid f)/p(x\mid b)$')
pylab.ylabel(r'$1-p_\mathrm{astro}$')
pylab.legend(loc='best')
pylab.yscale('log')
pylab.savefig(opts.pastro_figure)

# Also dump the raw p_astro arrays to stdout.
print(p_astro)
