#!/usr/bin/python

# Copyright 2016 Thomas Dent, Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py, argparse, logging, pycbc.version, numpy
from pycbc.events import triggers

# Command line interface. Options that the script cannot run without are
# marked required so that omitting them gives a usage error from argparse
# instead of a failure further down the script.
parser = argparse.ArgumentParser(usage="",
    description="Smooth (regress) the dependence of coefficients describing "
                "single-ifo background trigger distributions on a template "
                "parameter, to suppress random noise in the resulting "
                "background model.")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--template-fit-file", required=True,
                    help="hdf5 file containing fit coefficients for each"
                         " individual template. Required")
parser.add_argument("--bank-file", default=None,
                    help="hdf file containing template parameters. Required "
                         "unless reading param from template fit file")
parser.add_argument("--output", required=True,
                    help="Location for output file containing smoothed fit "
                         "coefficients.  Required")
parser.add_argument("--use-template-fit-param", action="store_true",
                    help="Use parameter values stored in the template fit file "
                         "as template_param for smoothing.", default=False)
parser.add_argument("--fit-param", nargs='+', required=True,
                    help="Parameter(s) over which to regress the background "
                         "fit coefficients. Required. Either read from "
                         "template fit file or choose from mchirp, mtotal, "
                         "chi_eff, eta, tau_0, tau_3, template_duration, "
                         "a frequency cutoff in pnutils or a frequency function "
                         "in LALSimulation. To regress the background over "
                         "multiple parameters, provide them as a list.")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                         "duration, if not reading from the template fit file")
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                         "values of template_duration: add to duration values"
                         " before fitting. Units seconds.")
# One 'true'/'false' string per fit param; the values are interpreted later
# in the script, so no type conversion is applied here.
parser.add_argument("--log-param", nargs='+', required=True,
                    help="Whether to take the log of each fit param before "
                         "smoothing: give 'true' or 'false' once per fit "
                         "param, in the same order. Required.")
parser.add_argument("--smoothing-width", type=float, nargs='+', required=True,
                    help="Distance in the space of fit param values (or the "
                         "logs of them) to smooth over. Required. "
                         "This must be a list corresponding to the smoothing "
                         "parameters.")
args = parser.parse_args()

# Validate the command line with parser.error() rather than assert: asserts
# are stripped under "python -O" (after which zip() below would silently
# truncate mismatched lists) and len(None) raises an unhelpful TypeError when
# an option is missing altogether.
if args.template_fit_file is None:
    parser.error("--template-fit-file is required")
if args.bank_file is None:
    # the template masses and spins are always read from the bank below
    parser.error("--bank-file is required")
if args.fit_param is None or args.log_param is None:
    parser.error("--fit-param and --log-param are required")
if not (len(args.log_param) == len(args.fit_param)
        == len(args.smoothing_width)):
    parser.error("--fit-param, --log-param and --smoothing-width must each "
                 "be given the same number of values")

pycbc.init_logging(args.verbose)

# Kept open: datasets and attributes are read from it later in the script
fits = h5py.File(args.template_fit_file, 'r')

# get the ifo from the template-level fit
ifo = fits.attrs['ifo']

# get template id and template parameter values
tid = fits['template_id'][:]

logging.info('Calculating template parameter values')
bank = h5py.File(args.bank_file, 'r')
m1, m2, s1z, s2z = triggers.get_mass_spin(bank, tid)

# One array of values per fit param, in the order given on the command line;
# params flagged 'true' in --log-param are stored as their natural log.
parvals = []

for param, slog in zip(args.fit_param, args.log_param):
    use_log = slog in ['true', 'True', 'TRUE']
    if not use_log and slog not in ['false', 'False', 'FALSE']:
        # reject a bad flag before spending time computing the parameter
        raise ValueError("invalid log param argument, use 'true', or 'false'")
    data = triggers.get_param(param, args, m1, m2, s1z, s2z)
    if use_log:
        logging.info('Using log param: %s', param)
        parvals.append(numpy.log(data))
    else:
        logging.info('Using param: %s', param)
        parvals.append(data)

# Older template fit files may not carry the 'count_in_template' dataset;
# tcount records whether it is available.
tcount = 'count_in_template' in fits.keys()
if tcount:
    ntotal = fits['count_in_template'][:]

# Counts above threshold, rescaled so that their mean over templates is 1
nabove = fits['count_above_thresh'][:]
nabove = nabove / nabove.mean()

# For an exponential fit, 1/alpha is linear in the trigger statistic values,
# so weighted sums or averages are taken of 1/alpha rather than of alpha.
# invalphan is 1/alpha weighted by the rescaled count above threshold.
invalpha = 1. / fits['fit_coeff'][:]
invalphan = nabove * invalpha

def dist(i1, i2):
    """Return the distance between templates i1 and i2 in fit-param space.

    Each parameter difference is scaled by the matching entry of
    args.smoothing_width before the squares are summed, so the result is in
    units of the smoothing width. The values come from the module-level
    parvals list; i2 may be an array of indices, in which case an array of
    distances from template i1 is returned.
    """
    scaled_squares = ((vals[i2] - vals[i1]) ** 2.0 / width ** 2.0
                      for vals, width in zip(parvals, args.smoothing_width))
    return sum(scaled_squares) ** 0.5

def _window_mean(values, order, left, right):
    """Average values over each template's window of nearby templates.

    order is the permutation which sorts the templates by parameter value;
    left[k] and right[k] are the first and last sorted positions (both
    inclusive) lying inside the window of the k-th sorted template. The
    means are returned in the original (unsorted) template order.
    """
    # The leading zero makes csum[b + 1] - csum[a] the sum over sorted
    # positions a..b inclusive. Differencing the bare cumulative sum drops
    # the template at position a and gives 0/0 when the window contains
    # only the template itself.
    csum = numpy.concatenate(([0.], numpy.cumsum(values[order])))
    means = (csum[right + 1] - csum[left]) / (right - left + 1)
    # Undo the sort so the results stay aligned with the unsorted inputs
    unsorted = numpy.empty_like(means)
    unsorted[order] = means
    return unsorted


logging.info("Smoothing ...")
# Handle the one-dimensional case separately, as it is easier to optimize
# its computational performance.
if len(parvals) == 1:
    sort = parvals[0].argsort()
    sorted_vals = parvals[0][sort]

    # For each template, find the range of nearby templates which fall within
    # the chosen window. parvals itself is left in its original order so that
    # everything derived from it stays aligned with the template ids.
    left = numpy.searchsorted(sorted_vals,
                              sorted_vals - args.smoothing_width[0])
    right = numpy.searchsorted(sorted_vals,
                               sorted_vals + args.smoothing_width[0]) - 1

    # older files may not have the per-template total count
    if tcount:
        ntotal_smoothed = _window_mean(ntotal, sort, left, right)
    nabove_smoothed = _window_mean(nabove, sort, left, right)
    invmean = _window_mean(invalphan, sort, left, right)
    alpha_smoothed = nabove_smoothed / invmean

else:
    nabove_smoothed = []
    if tcount:
        ntotal_smoothed = []
    alpha_smoothed = []
    rang = numpy.arange(0, len(nabove))

    for i in range(len(nabove)):
        # templates closer than one smoothing width (in scaled units)
        nearby = dist(i, rang) < 1
        if tcount:
            ntotal_smoothed.append(ntotal[nearby].mean())
        nabove_smoothed.append(nabove[nearby].mean())
        alpha_smoothed.append(nabove_smoothed[i] / invalphan[nearby].mean())

logging.info("Writing output")
# The context manager guarantees the output file is flushed and closed, even
# if one of the writes below fails.
with h5py.File(args.output, 'w') as outfile:
    outfile['template_id'] = tid
    outfile['count_above_thresh'] = nabove_smoothed
    outfile['fit_coeff'] = alpha_smoothed
    # median_sigma is optional in the input; copy it through when present.
    # Only the read is guarded so that a failure while writing is not
    # misreported as a missing dataset.
    try:
        median_sigma = fits['median_sigma'][:]
    except KeyError:
        logging.info('Median_sigma dataset not present in input file')
    else:
        outfile['median_sigma'] = median_sigma
    if tcount:
        outfile['count_in_template'] = ntotal_smoothed

    # Store the parameter values used for smoothing, undoing the log for
    # parameters that were smoothed in log space
    for param, vals, slog in zip(args.fit_param, parvals, args.log_param):
        if slog in ['false', 'False', 'FALSE']:
            outfile[param] = vals
        elif slog in ['true', 'True', 'TRUE']:
            outfile[param] = numpy.exp(vals)

    # add metadata, some is inherited from template level fit
    outfile.attrs['ifo'] = ifo
    outfile.attrs['stat_threshold'] = fits.attrs['stat_threshold']
    if 'analysis_time' in fits.attrs:
        outfile.attrs['analysis_time'] = fits.attrs['analysis_time']

    # add a magic file attribute so that coinc_findtrigs can parse it
    outfile.attrs['stat'] = ifo + '-fit_coeffs'

# Nothing further is read from the input files, so release them explicitly
fits.close()
bank.close()
logging.info('Done!')
