#!/usr/bin/env python
"""
Finds the loudest event, ranked by snr or newsnr, within a given time window,
parses the parameters of its template, and writes them to an HDF file and/or
stdout.
"""

import lal
import numpy as np
import h5py
import argparse
import logging
import pycbc.events
from pycbc.pnutils import mass1_mass2_to_mchirp_eta

# Emit timestamped progress messages at INFO level and above.
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

# Command-line interface. The module docstring doubles as the --help
# description. Option values are left as strings unless a type is given.
parser = argparse.ArgumentParser(description=__doc__)

# --- Required inputs -------------------------------------------------------
parser.add_argument(
    '--single-ifo-trigs', required=True,
    help='HDF file containing single IFO CBC triggers',
)
parser.add_argument(
    '--tmpltbank-file', required=True,
    help='HDF file containing template information for CBC search',
)
parser.add_argument(
    '--ifo', required=True,
    help='IFO, L1 or H1',
)
parser.add_argument(
    '--central-time', type=float, required=True,
    help='Central time over which to search',
)

# --- Optional search settings ----------------------------------------------
# NOTE: the script searches central_time +/- window, i.e. this value is used
# as a half-width on either side of the central time.
parser.add_argument(
    '--window', type=float, default=8.0,
    help='Time window over which to search for loudest trigger',
)
parser.add_argument(
    '--ranking-statistic', default='newsnr', choices=['snr','newsnr'],
    help='Ranking statistic to use when searching for loudest events',
)

# --- Optional outputs ------------------------------------------------------
parser.add_argument(
    '--output-file',
    help='Output hdf file to write parameters',
)
parser.add_argument(
    '--print-params', action='store_true',
    help='Toggle printing parameters to stdout',
)

args = parser.parse_args()

# Main body: locate the loudest single-detector trigger inside the requested
# time window, look up the parameters of the template that produced it, and
# report them to an HDF file and/or stdout.

if not (args.output_file or args.print_params):
    # The search still runs, but nothing would be reported; tell the user.
    logging.warning('Neither --output-file nor --print-params was given; '
                    'the result will not be written anywhere')

# The search interval is central_time +/- window (window is a half-width).
t_low = args.central_time - args.window
t_high = args.central_time + args.window

logging.info('Reading in HDF files')
# Both inputs are opened up front so a bad path fails before any work is
# done, and the context manager closes them even if a lookup below raises.
with h5py.File(args.single_ifo_trigs, 'r') as trigs, \
        h5py.File(args.tmpltbank_file, 'r') as template_file:
    ifo_trigs = trigs[args.ifo]

    # Select triggers strictly inside the window (edges are excluded).
    times = ifo_trigs['end_time'][:]
    mask = (times > t_low) & (times < t_high)
    if not mask.any():
        # Same exception type np.argmax would raise on an empty array, but
        # with a message that says what actually went wrong.
        raise ValueError('No %s triggers found between %s and %s'
                         % (args.ifo, t_low, t_high))

    # Per-trigger vectors restricted to the window: snr, chisq, chisq_dof,
    # template_id and end_time.
    snr = ifo_trigs['snr'][mask]
    chisq = ifo_trigs['chisq'][mask]
    chisq_dof = ifo_trigs['chisq_dof'][mask]
    template_ids = ifo_trigs['template_id'][mask]
    end_times = times[mask]  # already in memory; no second read from disk

    # NOTE(review): the 2*chisq_dof - 2 normalisation presumably reflects how
    # the trigger file stores chisq_dof -- confirm against the producer.
    reduced_chisq = chisq / (2 * chisq_dof - 2)
    newsnr = pycbc.events.ranking.newsnr(snr, reduced_chisq)

    # Index (within the masked vectors) of the loudest trigger.
    ranking = snr if args.ranking_statistic == 'snr' else newsnr
    idx = np.argmax(ranking)

    # Template parameters are looked up by the trigger's template_id.
    tid = template_ids[idx]
    cbc_end_time = end_times[idx]
    m1 = template_file['mass1'][tid]
    m2 = template_file['mass2'][tid]
    s1z = template_file['spin1z'][tid]
    s2z = template_file['spin2z'][tid]

mchirp, eta = mass1_mass2_to_mchirp_eta(m1, m2)

# Each value is wrapped in a one-element list, so it is written as a
# length-1 dataset and printed in list form.
data = {
    '%s/snr' % args.ifo: [snr[idx]],
    '%s/chisq' % args.ifo: [chisq[idx]],
    '%s/newsnr' % args.ifo: [newsnr[idx]],
    '%s/template_id' % args.ifo: [tid],
    '%s/end_time' % args.ifo: [cbc_end_time],
    'template/mass1': [m1],
    'template/mass2': [m2],
    'template/mchirp': [mchirp],
    'template/eta': [eta],
    'template/spin1z': [s1z],
    'template/spin2z': [s2z],
}

if args.output_file:
    # Opened only now, so a failed search never truncates an existing file;
    # the context manager guarantees the data is flushed and the file closed.
    with h5py.File(args.output_file, 'w') as outfile:
        for key, value in data.items():
            outfile.create_dataset(key, data=value)

if args.print_params:
    for key, value in data.items():
        # Print only the parameter name, without its group prefix.
        print(key.split('/')[1], str(value))
