#!/usr/bin/env python
import argparse, logging, numpy as np
from ligo.segments import infinity
from pycbc.events import veto, coinc, stat
import pycbc.conversions as conv
import pycbc.version
from pycbc import io
from pycbc.events import trigger_fits as trfits
from pycbc import init_logging

# Command-line interface. The trigger file and the output file are both used
# unconditionally further down, so they are declared as required here to fail
# early with a clear argparse message instead of an obscure error at open time.
parser = argparse.ArgumentParser(
    description="Rank single-detector triggers, optionally applying "
                "SNR / reduced chisq cuts, vetoes and time clustering, "
                "and store the surviving candidates.")
parser.add_argument("--verbose", action='count')
parser.add_argument("--version", action='version',
                    version=pycbc.version.git_verbose_msg)
# --veto-files and --segment-names are paired element by element.
parser.add_argument("--veto-files", nargs='+',
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored. Required if "
                         "--segment-names given.")
parser.add_argument("--segment-names", nargs='+',
                    help="Optional, name of veto segment in veto file. "
                         "Required if --veto-files given.")
parser.add_argument("--trigger-file", type=str, required=True,
                    help="File containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
# Optional trigger cuts; each takes a single float threshold.
parser.add_argument('--trigger-snr-cut', type=float,
                    help='Only consider triggers above the given SNR.')
parser.add_argument('--cluster-window', type=float,
                    help='Window (seconds) during which to keep the trigger '
                         'with the loudest statistic value. '
                         'Default=do not cluster')
parser.add_argument('--reduced-chisq-cut', type=float,
                    help='Only consider triggers below given reduced '
                         'chisquared.')
parser.add_argument("--output-file", required=True,
                    help="File to store the candidate triggers")
# Adds the ranking-statistic options consumed later by the stat module.
stat.insert_statistic_option_group(parser)
args = parser.parse_args()

# --veto-files and --segment-names must be given together (or not at all).
if bool(args.veto_files) != bool(args.segment_names):
    raise RuntimeError('--veto-files and --segment-names are mutually required')

# Only compare lengths when vetoes were actually requested: both options
# default to None, and len(None) would raise a TypeError.
if args.veto_files and len(args.veto_files) != len(args.segment_names):
    raise RuntimeError('--segment-names are required for each --veto-files')

init_logging(args.verbose)

logging.info('Opening trigger file: %s', args.trigger_file)
trigf = io.HFile(args.trigger_file, 'r')
# The first top-level key of the file is taken as the detector name.
# keys() is materialised into a list first because on Python 3 it returns a
# view object that cannot be indexed.
ifo = list(trigf.keys())[0]

# Build the analysed segments from the stored start/end times.
starts = trigf[ifo + '/search/start_time'][:]
ends = trigf[ifo + '/search/end_time'][:]
segments = veto.start_end_to_segments(starts, ends)

n_tot_trigs = trigf[ifo + '/snr'].size
logging.info("%d triggers in file", n_tot_trigs)

# Indices (into the on-disk trigger arrays) of the triggers surviving every
# cut applied so far. Start from all triggers so that each cut below is
# genuinely optional; without this, omitting --trigger-snr-cut left keep_idx
# undefined.
keep_idx = np.arange(n_tot_trigs)

if args.trigger_snr_cut:
    keep_idx = np.flatnonzero(trigf[ifo + '/snr'][:] >= args.trigger_snr_cut)
    # Human-friendly threshold without trailing zeros, e.g. "5.5" or "6".
    snr_cut_f = ("%f" % args.trigger_snr_cut).rstrip("0").rstrip(".")
    n_snr_cut = n_tot_trigs - keep_idx.size
    # max(..., 1) guards the percentage against an empty trigger file.
    logging.info("Cutting %d triggers with SNR < %s (%.2f%%)",
                 n_snr_cut, snr_cut_f,
                 100. * n_snr_cut / max(n_tot_trigs, 1))

if args.reduced_chisq_cut:
    n_before_chisq = keep_idx.size
    chisq = trigf[ifo + '/chisq'][:][keep_idx]
    chisq_dof = trigf[ifo + '/chisq_dof'][:][keep_idx]
    # Reduced chisq as computed here: chisq / (2 * chisq_dof - 2).
    reduced_chisq = chisq / (2 * chisq_dof - 2)
    # Positions within the current keep_idx (not within the file).
    chisq_keep_idx = np.flatnonzero(reduced_chisq <= args.reduced_chisq_cut)
    chisq_cut_f = ("%f" % args.reduced_chisq_cut).rstrip("0").rstrip(".")
    n_chisq_cut = n_before_chisq - chisq_keep_idx.size
    # chisq_cut_f is a string, so it must be formatted with %s; the raw
    # string keeps the literal backslash without an invalid escape sequence.
    logging.info(r"Cutting %d triggers with \chi^2 > %s (%.2f%%)",
                 n_chisq_cut, chisq_cut_f,
                 100. * n_chisq_cut / max(n_before_chisq, 1))
    # Select chisq-cut-kept idx from keep_idx
    keep_idx = keep_idx[chisq_keep_idx]

# Foreground segments start out as the full analysed time; each veto file
# then removes its own segments, so the removals accumulate across files
# rather than only the last file being applied.
fg_segs = segments

if args.veto_files:
    # Read the trigger times once; every veto pass just selects the subset
    # that is still being kept.
    all_end_times = trigf[ifo + '/end_time'][:]
    for veto_file, segment_name in zip(args.veto_files, args.segment_names):
        logging.info("Getting vetoed indices from file %s", veto_file)
        # Returned indices are positions within the array passed in, i.e.
        # within the current keep_idx.
        veto_keep_idx, _ = veto.indices_outside_segments(
            all_end_times[keep_idx], [veto_file],
            segment_name=segment_name, ifo=ifo)
        n_vetoed = keep_idx.size - veto_keep_idx.size
        # max(..., 1) avoids a division by zero when no triggers are left.
        logging.info("Cutting %d triggers in vetoed segments (%.2f%%)",
                     n_vetoed, 100. * n_vetoed / max(keep_idx.size, 1))
        # Select unvetoed idx from keep_idx
        keep_idx = keep_idx[veto_keep_idx]
        veto_segs = veto.select_segments_by_definer(veto_file, ifo=ifo,
                                                    segment_name=segment_name)
        fg_segs = fg_segs - veto_segs

if not len(keep_idx):
    raise RuntimeError("All triggers removed by vetoes or cuts")

logging.info("Loading %d triggers", len(keep_idx))

# Per-trigger datasets copied from the trigger file, restricted to the
# triggers that survived the cuts and vetoes.
dset_names = ['sigmasq', 'chisq', 'chisq_dof', 'coa_phase', 'end_time',
              'snr', 'template_id', 'sg_chisq']
trig_data = dict(
    (name, trigf["%s/%s" % (ifo, name)][:][keep_idx])
    for name in dset_names
)

# Remember where each kept trigger sits in the original file.
file_positions = np.arange(trigf[ifo + '/snr'].size)
trig_data['trigger_id'] = file_positions[keep_idx]

logging.info("Putting data into DictArray")
trigs = io.DictArray(data=trig_data)

# Everything needed has been read into memory.
trigf.close()

logging.info('Setting up ranking method')
# Check that every --statistic-keywords entry has the KEY:VALUE form.
# NOTE(review): extra_kwargs is not passed to anything below, so in this
# script the loop effectively only validates the input format.
extra_kwargs = {}
for inputstr in args.statistic_keywords:
    try:
        key, value = inputstr.split(':')
        extra_kwargs[key] = value
    except ValueError:
        err_txt = "--statistic-keywords must take input in the " \
                  "form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... " \
                  "Received {}".format(args.statistic_keywords)
        raise ValueError(err_txt)

# Stat class instance to calculate the ranking statistic
rank_method = stat.get_statistic_from_opts(args, [ifo])

logging.info("Computing single-detector statistic")
# NOTE(review): from here on the name `stat` refers to the array of
# statistic values and shadows the imported pycbc.events.stat module.
stat = rank_method.rank_stat_single((ifo, trigs.data))

# Only announce clustering when it is actually performed.
if args.cluster_window:
    logging.info("Clustering")
    cid = coinc.cluster_over_time(stat, trigs.data['end_time'],
                                  args.cluster_window)
    # Keep only the triggers (and their statistic values) selected by cid.
    trigs = trigs.select(cid)
    stat = stat[cid]
    logging.info("%d triggers after clustering", stat.size)

# Total duration of the foreground segments.
fg_time = abs(fg_segs)

# Datasets written to the output file, keyed by their HDF path.
data = {"stat": stat,
        "decimation_factor": np.ones_like(stat),
        "timeslide_id": np.zeros_like(stat),
        "template_id": trigs.data['template_id'],
        "%s/time" % ifo: trigs.data['end_time'],
        "%s/trigger_id" % ifo: trigs.data['trigger_id']}

logging.info("saving triggers")
# The context manager guarantees the output file is flushed and closed,
# even if one of the writes below raises.
with io.HFile(args.output_file, 'w') as f:
    for key, values in data.items():
        f.create_dataset(key, data=values,
                         compression="gzip",
                         compression_opts=9,
                         shuffle=True)

    # Store segments
    seg_starts, seg_ends = veto.segments_to_start_end(fg_segs)
    f['segments/%s/start' % ifo] = seg_starts
    f['segments/%s/end' % ifo] = seg_ends

    # The same duration is stored as both foreground and background time.
    f.attrs['foreground_time'] = fg_time
    f.attrs['background_time'] = fg_time
    f.attrs['num_of_ifos'] = 1
    f.attrs['pivot'] = ifo
    f.attrs['fixed'] = ifo
    f.attrs['ifos'] = ifo

# Do hierarchical removal
# h_iterations = 0
# if args.max_hierarchical_removal != 0:

logging.info("Done")
