#!/usr/bin/env python
""" Make a table of found injection information
"""
import argparse, h5py, numpy as np, pycbc.results, pycbc.detector, sys
from pycbc.types import MultiDetOptionAction
import pycbc.pnutils, pycbc.events
import pycbc.version
from itertools import combinations


# Command-line interface.
# The injection file and the output path are both used unconditionally
# further down, so they are marked as required: a missing option now produces
# a usage message instead of an obscure failure later in the script.
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--injection-file', required=True,
                    help='HDF File containing the matched injections')
parser.add_argument('--single-trigger-files', nargs='*',
                    action=MultiDetOptionAction,
                    help="HDF format single detector trigger files")
parser.add_argument('--verbose', action='count',
                    help='Verbosity level; may be given more than once')
parser.add_argument('--show-missed', action='store_true',
                    help='Tabulate the missed injections instead of the '
                         'found ones')
parser.add_argument('--output-file', required=True,
                    help='Path the rendered table is saved to')
args = parser.parse_args()

# Open the matched-injection file read-only and grab the group holding the
# injection parameters.
f = h5py.File(args.injection_file, 'r')
inj = f['injections']

# Detector names are stored as a single space-separated string attribute.
ifos = f.attrs['ifos'].split(' ')

# Extra table columns describing the recovered triggers; they stay empty
# when only missed injections are tabulated.
found_cols = []
found_names = []
found_formats = []

def _fmt_or_blank(values):
    """Format each value with two decimals, using a blank cell for NaN."""
    return ['%.2f' % v if not np.isnan(v) else ' ' for v in values]


# Select which injections are tabulated (``idx`` indexes into the injection
# datasets) and, for found injections, build the extra trigger columns.
if args.show_missed:
    title = "Missed Injections"
    idx = f['missed/after_vetoes'][:]
else:
    title = "Found Injections"
    found = f['found_after_vetoes']
    idx = found['injection_index'][:]

    # Only detectors that have their own group under found_after_vetoes
    # contribute time-difference and single-detector columns.
    found_keys = found.keys()
    detectors_used = [det for det in ifos if det in found_keys]

    # One column per detector pair: trigger time difference in milliseconds.
    tdiff, tdiff_str, tdiff_format = [], [], []
    for det_a, det_b in combinations(detectors_used, 2):
        time_a = np.array(found[det_a + '/time'][:])
        time_b = np.array(found[det_b + '/time'][:])
        tdiff_vals = (time_a - time_b) * 1000
        # Negative times are treated as "no trigger in this detector"
        # (presumably a placeholder written upstream) and shown as blanks.
        tdiff_vals[np.logical_or(time_a < 0, time_b < 0)] = np.nan
        tdiff.append(_fmt_or_blank(tdiff_vals))
        tdiff_str.append('%s - %s time (ms)' % (det_a, det_b))
        tdiff_format.append('##.##')

    # Per-detector indices into the single-detector trigger files.
    ids = {det: found[det + '/trigger_id'][:] for det in detectors_used}

    found_cols = [found['stat'], found['ifar_exc']] + tdiff
    found_names = ['Ranking Stat.', 'Exc. IFAR'] + tdiff_str
    found_formats = ['##.##', '##.##'] + tdiff_format

    if args.single_trigger_files:
        for ifo in args.single_trigger_files:
            ids_ifo = np.array(ids[ifo])
            # A trigger id of -1 means there is no trigger for this
            # injection in this detector. Indexing with -1 still reads the
            # last element, so those entries are overwritten with NaN below.
            no_trigger = ids_ifo == -1

            # Use a dedicated name and a context manager so the injection
            # file handle ``f`` is not clobbered and each trigger file is
            # closed as soon as its values have been read into memory.
            with h5py.File(args.single_trigger_files[ifo], 'r') as trig_file:
                trigs = trig_file[ifo]
                snr_vals = trigs['snr'][:][ids_ifo]
                chisq_dof = trigs['chisq_dof'][:][ids_ifo]
                # Reduced chi-squared.
                chisq_vals = trigs['chisq'][:][ids_ifo] / (2 * chisq_dof - 2)

            snr_vals[no_trigger] = np.nan
            chisq_vals[no_trigger] = np.nan
            newsnr_vals = pycbc.events.ranking.newsnr(snr_vals, chisq_vals)

            found_names += [ifo + " SNR", ifo + " CHISQ", ifo + " NewSNR"]
            found_cols += [_fmt_or_blank(snr_vals),
                           _fmt_or_blank(chisq_vals),
                           _fmt_or_blank(newsnr_vals)]
            found_formats += ['##.##', '##.##', '##.##']

# Effective distance of every selected injection in each detector listed in
# the injection file.
#
# The detector names are used exactly as stored. Reducing a name to its
# first letter and re-appending "1" collapses detectors that share a site
# letter (e.g. H1 and H2) into one column and evaluates any detector not
# named "<letter>1" with the wrong detector.

# The injection parameters do not depend on the detector, so read them once.
inj_distance = inj['distance'][:][idx]
inj_ra = inj['ra'][:][idx]
inj_dec = inj['dec'][:][idx]
inj_polarization = inj['polarization'][:][idx]
inj_tc = inj['tc'][:][idx]
inj_inclination = inj['inclination'][:][idx]

eff_dist_str = []
eff_distance = []
eff_dist_format = []
for ifo in ifos:
    d = pycbc.detector.Detector(ifo)
    edist = d.effective_distance(inj_distance, inj_ra, inj_dec,
                                 inj_polarization, inj_tc, inj_inclination)
    eff_distance.append(edist)
    eff_dist_str.append('Eff Dist (%s)' % ifo)
    eff_dist_format.append('##.##')

def _inj_param(name):
    """Return the named injection dataset restricted to the selected rows."""
    return inj[name][:][idx]


# NOTE(review): this takes the largest effective distance across detectors;
# confirm that is the intended "decisive" distance with more than two ifos.
dec_dist = np.max(eff_distance, 0)
m1 = _inj_param('mass1')
m2 = _inj_param('mass2')
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
dec_chirp_dist = pycbc.pnutils.chirp_distance(dec_dist, mchirp)

# Fixed injection-parameter columns first, then the per-detector effective
# distances, then whatever trigger columns were collected earlier.
spin_keys = ['spin1x', 'spin1y', 'spin1z', 'spin2x', 'spin2y', 'spin2z']
columns = [dec_chirp_dist, _inj_param('tc'), m1, m2, mchirp, eta]
columns += [_inj_param(key) for key in spin_keys]
columns += [_inj_param('distance')]
columns += eff_distance + found_cols

names = ['DChirp Dist', 'Inj Time', 'Mass1', 'Mass2', 'Mchirp', 'Eta']
names += ['s1x', 's1y', 's1z']
names += ['s2x', 's2y', 's2z']
names += ['Dist']
names += eff_dist_str + found_names

# The thirteen fixed columns all share the same number format.
format_strings = ['##.##'] * 13 + eff_dist_format + found_formats

# Convert every column (including lazily-read HDF datasets) to an array.
columns = [np.array(column) for column in columns]
html_table = pycbc.results.html_table(
    columns,
    names,
    format_strings=format_strings,
    page_size=20,
)

caption = ("A table of %s and their coincident statistic information."
           % title.lower())
pycbc.results.save_fig_with_metadata(
    str(html_table),
    args.output_file,
    title=title,
    caption=caption,
    cmd=' '.join(sys.argv),
)
