#!/usr/bin/env python

# Copyright (C) 2017 Vaibhav Tiwari

# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This script estimates black hole merger rates for astrophysical models:
1) Uniform in log of the component mass
2) Power-law distribution of the masses
3) Binary Neutron Star with uniform in masses specified in a range.

It performs weighted Monte Carlo integration to calculate the sensitive volume.
"""

__author__ = "Vaibhav Tiwari"
__email__ = "vaibhav.tiwari@ligo.org"
__version__ = "0.0"
__date__ = "31.10.2017"

import logging
import argparse
import pycbc.version

import h5py
import numpy as np
import scipy.stats as ss
from pycbc.population import scale_injections as si
from pycbc.population import rates_functions as rf
from numpy import logaddexp, log, newaxis, expm1


# Parse command line
#
# Every option except the threshold settings and the calibration error is
# mandatory. The parsed namespace is kept in `opts`; `path` is an alias for
# the output file name that is handed to the background-saving step below.
parser = argparse.ArgumentParser(
    description="Estimate rates for flat in "
                "log and power-law mass distribution.")

# Injection (simulation) inputs and the distributions they were drawn from
parser.add_argument('--sim-files', nargs='+', required=True,
                    help="List of simulation files to estimate the "
                         "sensitive volume")
parser.add_argument('--m_dist', nargs='+', required=True,
                    help="Specify the mass distribution for the simulations")
parser.add_argument('--s_dist', nargs='+', required=True,
                    help="Specify the spin distribution for the simulations")
parser.add_argument('--d_dist', nargs='+', required=True,
                    help="Specify the distance distribution for the "
                         "simulations")

# Search products, prior samples and output location
parser.add_argument('--bank-file', required=True,
                    help="File containing template bank used in the search.")
parser.add_argument('--statmap-file', required=True,
                    help="File containing trigger information")
parser.add_argument('--prior-samples', required=True,
                    help="File storing samples of prior for the analysis - "
                         "posterior from the previous analysis")
parser.add_argument('--output-file', required=True,
                    help="Name of the output file for saving rate "
                         "posteriors.")

# Detection threshold
parser.add_argument('--thr-var', required=False, default='stat',
                    help="Variable used to define threshold.")
parser.add_argument('--thr-val', type=float, required=False, default=8.0,
                    help="Value of threshold.")

# Population model: only the models this script knows how to dispatch on
# are accepted, so a typo is reported as a usage error at parse time.
parser.add_argument('--population-model', required=True,
                    choices=('imf', 'lnm', 'bns'),
                    help='Population model defined in rates_functions')

# Mass limits of the population and of the templates used for triggers
parser.add_argument('--min-mass', type=float, required=True,
                    help="Minimum mass of the compact object.")
parser.add_argument('--max-mass', type=float, required=True,
                    help="Maximum mass of the compact object.")
parser.add_argument('--max-mtotal', type=float, required=True,
                    help="Maximum total mass of the binary.")
parser.add_argument('--min-tmplt-mchirp', type=float, required=True,
                    help="Minimum chirp mass of the template considered "
                         "for trigger identification.")
parser.add_argument('--max-tmplt-mchirp', type=float, required=True,
                    help="Maximum chirp mass in the template considered "
                         "for trigger identification.")

# Calibration error, stored on the namespace as `cal_err`
parser.add_argument('--calibration-error', dest='cal_err', type=float,
                    required=False, default=3.0,
                    help="Percentage calibration errors in measurement of "
                         "distance.")

opts = parser.parse_args()
path = opts.output_file

# Validate with parser.error rather than `assert`: assertions are removed
# when Python runs with -O, and a usage error is clearer than a traceback.
if not opts.min_tmplt_mchirp < opts.max_tmplt_mchirp:
    parser.error("Minimum chirp mass should be less than the maximum "
                 "chirp mass")

# Select the chirp-mass sampler and probability function for the requested
# population model. Both live in rates_functions and share a name suffix.
# The model is checked before the injections are read so that an
# unsupported model fails immediately with a clear message instead of a
# NameError after the (potentially slow) read.
model_suffix = {'imf': 'imf', 'lnm': 'lnm', 'bns': 'flat'}
if opts.population_model not in model_suffix:
    raise ValueError("Unknown population model '{}'; expected one of: "
                     "{}".format(opts.population_model,
                                 ', '.join(sorted(model_suffix))))
suffix = model_suffix[opts.population_model]
mchirp_sampler = getattr(rf, 'mchirp_sampler_' + suffix)
prob = getattr(rf, 'prob_' + suffix)

# Read the simulation files
injections = si.read_injections(opts.sim_files,
                                opts.m_dist, opts.s_dist, opts.d_dist)

# Estimate the sensitive volume-time for the chosen model
vt = si.estimate_vt(injections, mchirp_sampler, prob,
                    thr_var=opts.thr_var,
                    thr_val=opts.thr_val,
                    min_mass=opts.min_mass,
                    max_mass=opts.max_mass,
                    max_mtotal=opts.max_mtotal)

# Unpack the pieces of the result that the rest of the script uses
vol_time, sig_vt = vt['VT'], vt['VT_err']
inj_falloff = vt['thr_falloff']

# Include the calibration uncertainty: the percentage distance calibration
# error is scaled by 3 and converted to a fraction, then added in
# quadrature to the fractional error on the sensitive volume-time.
# NOTE(review): the factor 3 presumably reflects volume scaling as distance
# cubed -- confirm.
vol, vol_err, cal_err = vol_time, sig_vt, opts.cal_err
sigma_w_cal_uncrt = np.sqrt((3*cal_err/100.)**2 + (vol_err/vol)**2)

# Save background data and coincidences
all_bkg, coincs = rf.save_bkg_falloff(opts.statmap_file, opts.bank_file,
                                      path, opts.thr_val,
                                      opts.min_tmplt_mchirp,
                                      opts.max_tmplt_mchirp)

# Build the background bin edges from the lower edges plus the last upper
# edge returned with the background counts.
bg_l, bg_h, bg_counts = all_bkg
bg_bins = np.append(bg_l, bg_h[-1])

# Keep only the injection statistics above the threshold. Without any, the
# foreground binning below is undefined, so fail with a clear message
# rather than numpy's "zero-size array" reduction error.
fg_stats = inj_falloff[inj_falloff > opts.thr_val]
if len(fg_stats) == 0:
    raise ValueError("No injection has a statistic above the threshold "
                     "{}; cannot build the foreground "
                     "distribution.".format(opts.thr_val))

# 100 logarithmically spaced foreground bins from the threshold up to the
# loudest injection statistic.
fg_bins = np.logspace(np.log10(opts.thr_val), np.log10(np.max(fg_stats)), 101)

# Log of the foreground-to-background ratio for each coincidence
log_fg_ratios = rf.log_rho_fgmc(coincs, fg_stats, fg_bins)
log_fg_ratios -= rf.log_rho_bg(coincs, bg_bins, bg_counts)

# Load the prior rate samples for the chosen population model (per the
# --prior-samples help text, the posterior of the previous analysis) and
# fit a skew-log-normal to them.
prior_dataset = opts.population_model + '/Rf'
with h5py.File(opts.prior_samples, "r") as prior_file:
    R = np.array(prior_file[prior_dataset])

alpha, mu, sigma = rf.fit(R)

# Estimate rates
# Draw rate samples from the fitted distribution, bounded by the range of
# the log of the prior samples.
log_R = log(R)
Rf_samp = rf.skew_lognormal_samples(alpha, mu, sigma,
                                    min(log_R), max(log_R))

# Log of the sensitive volume-time (divided by 1e9) and its uncertainty,
# which already includes the calibration contribution.
mu_log_vt = np.log(vol_time/1e9)
sigma_log_vt = sigma_w_cal_uncrt

Rf_post, Lf_post, Lb_post = rf.fgmc(log_fg_ratios, mu_log_vt, sigma_log_vt,
                                    Rf_samp, max(fg_stats))
rate_samples = {'Rf': Rf_post, 'Lf': Lf_post, 'Lb': Lb_post}

# Median and 95th / 5th percentiles of the rate posterior
rate_post = rate_samples['Rf']
r50, r95, r05 = np.percentile(rate_post, [50, 95, 5])

# Save rate posteriors
# The output file is created (or overwritten) with two groups: one named
# after the population model holding the 'Lf', 'Lb' and 'Rf' samples, and
# a 'data' group holding the log foreground/background ratios and the
# coincidences. Every dataset is gzip-compressed.
with h5py.File(opts.output_file, 'w') as out_file:

    model_grp = out_file.create_group(opts.population_model)
    for key in ('Lf', 'Lb', 'Rf'):
        model_grp.create_dataset(key, data=rate_samples[key],
                                 compression='gzip')

    data_grp = out_file.create_group('data')
    data_grp.create_dataset('log_fg_bg_ratio', data=log_fg_ratios,
                            compression='gzip')
    data_grp.create_dataset('newsnr', data=coincs, compression='gzip')
