#!/usr/bin/python
"""Associate coincident triggers with injections listed in one or more LIGOLW
files.
"""

import argparse, h5py, logging, types, numpy, os.path
from glue.ligolw import ligolw, table, lsctables, utils as ligolw_utils
from ligo import segments
from pycbc import events
from pycbc.events import indices_within_segments
from pycbc.types import MultiDetOptionAction
from pycbc.inject import CBCHDFInjectionSet
import pycbc.version


# dummy class needed for loading LIGOLW files
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
    """Empty content-handler subclass used when loading LIGOLW XML files.

    It adds no behaviour of its own; it exists so that a dedicated handler
    class can be registered with ``lsctables`` and passed as the
    ``contenthandler`` argument when an injection XML file is loaded.
    """
    pass
# Register the handler with lsctables (presumably so that lsctables table
# types are recognised while parsing -- confirm against the ligolw docs).
lsctables.use_in(LIGOLWContentHandler)

def hdf_append(f, key, value):
    """Append ``value`` to the array stored under ``key`` in ``f``.

    Parameters
    ----------
    f : mapping-like container (e.g. an open HDF file)
        Must support ``in``, item access, item assignment and item deletion.
    key : str
        Name under which the data is stored.
    value : array-like
        Data to store. If ``key`` does not exist yet it is stored as is;
        otherwise the existing contents are read in full, ``value`` is
        concatenated after them, and the entry is replaced with the result.
    """
    if key not in f:
        # First write for this key: nothing to extend.
        f[key] = value
        return

    # The entry is rebuilt rather than resized in place: read the old data,
    # drop the old entry, then store the combined array under the same key.
    combined = numpy.concatenate([f[key][:], value])
    del f[key]
    f[key] = combined

# Expose hdf_append as an ``append`` method on h5py files. Assigning the plain
# function lets Python bind each File *instance* as ``f`` on attribute access.
# Wrapping it with types.MethodType(hdf_append, h5py.File) instead produced a
# method permanently bound to the File class itself, so calling
# ``some_file.append(key, value)`` would have operated on the class object.
h5py.File.append = hdf_append

def keep_ind(times, start, end):
    """Return the indices of ``times`` that fall inside any [start, end] span.

    Parameters
    ----------
    times : numpy.ndarray
        Times to test; they do not need to be sorted.
    start : array-like
        Start of each span.
    end : array-like
        End of each span, paired element-wise with ``start``.

    Returns
    -------
    numpy.ndarray
        Indices into the original (unsorted) ``times`` array. Each index
        appears once even when spans overlap, and both span boundaries are
        inclusive. The indices are ordered by increasing time.
    """
    time_sorting = times.argsort()
    times = times[time_sorting]
    # side='left' for the start and side='right' for the end makes both
    # boundaries of every span inclusive.
    leftidx = numpy.searchsorted(times, start, side='left')
    rightidx = numpy.searchsorted(times, end, side='right')

    # Collect the per-span index ranges and de-duplicate them once at the end
    # instead of re-sorting the accumulated array for every span. The empty
    # seed keeps concatenate valid when there are no spans at all.
    pieces = [numpy.array([], dtype=numpy.uint32)]
    for li, ri in zip(leftidx, rightidx):
        pieces.append(numpy.arange(li, ri, 1).astype(numpy.uint32))
    indices = numpy.unique(numpy.concatenate(pieces))

    # Map positions in the sorted array back to positions in the input.
    return time_sorting[indices]

def xml_to_hdf(table, hdf_file, hdf_key, columns):
    """Append columns of an XML table to datasets in an HDF file.

    Every column is converted to ``float32`` before being stored, so only
    numeric columns are supported.

    Parameters
    ----------
    table : object providing ``get_column(name)``
        Source table to read the columns from.
    hdf_file : h5py.File-like
        Destination container, written through ``hdf_append``.
    hdf_key : str
        Group path under which the datasets are stored.
    columns : iterable of str
        Column names. An entry of the form ``'a:b'`` reads column ``a`` from
        the table but stores it under the name ``b``.
    """
    # HDF5 object paths always use '/' as separator, whatever the platform,
    # so build them with posixpath rather than the OS-dependent os.path.
    import posixpath

    for col in columns:
        # Look for key in the form 'a:b' where we want to
        # read data from column a but store it as named output 'b'
        # This is used to keep the output consistent where the input may not
        # depending on the injection file format (xml vs hdf)
        if ':' in col:
            col, new_col = col.split(':')
        else:
            new_col = col
        key = posixpath.join(hdf_key, new_col)
        hdf_append(hdf_file, key, numpy.array(table.get_column(col),
                                              dtype=numpy.float32))

# Command-line interface. The module docstring doubles as the description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
# Trigger files and injection files are consumed pairwise, in order: the
# i-th trigger file is matched against the i-th injection file.
parser.add_argument('--trigger-files', nargs='+', required=True)
parser.add_argument('--injection-files', nargs='+', required=True)
# NOTE(review): not marked as required, but the value is always passed on to
# indices_within_segments in the main loop -- confirm that None is acceptable.
parser.add_argument('--veto-file')
parser.add_argument('--segment-name', default=None,
                    help='Name of segment list to use for vetoes. Optional')
# Half-width of the time window around each injection time in which
# coincident triggers are searched for.
parser.add_argument('--injection-window', type=float, required=True)
parser.add_argument('--min-required-ifos', type=int, default=2,
                    help='Minimum number of IFOs required to be observing and'
                    ' not vetoed for an injection to be counted. Default 2')
# NOTE(review): no default is given here; the XML branch of the main loop
# iterates over this option, so it appears to be needed for XML injection
# files -- confirm.
parser.add_argument('--optimal-snr-column', nargs='+',
                    action=MultiDetOptionAction, metavar='DETECTOR:COLUMN',
                    help='Names of the sim_inspiral columns containing the'
                    ' optimal SNRs.')
parser.add_argument('--redshift-column', default=None,
                    help='Name of sim_inspiral column containing redshift. '
                    'Optional')
parser.add_argument('--verbose', action='count')
parser.add_argument('--output-file', required=True)
args = parser.parse_args()

# Logging is only configured when --verbose is given at least once; any
# number of repetitions selects the same INFO level.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Main processing. For each (trigger file, injection file) pair, classify
# every injection as found or missed and append the results to the output
# file. ``injection_index`` is the running offset that turns per-file
# injection indices into indices into the combined 'injections/*' datasets.
# Both the output file and each trigger file are opened as context managers
# so they are closed even if processing fails part way through.
injection_index = 0
with h5py.File(args.output_file, 'w') as fo:
    for trigger_file, injection_file in zip(args.trigger_files,
                                            args.injection_files):
        logging.info('Read in the coinc data: %s', trigger_file)
        with h5py.File(trigger_file, 'r') as f:
            # Get list of groups which contain subgroup 'time'
            # - these will be the IFOs
            ifo_list = [key for key in f['foreground']
                        if 'time' in f['foreground/%s/' % key]]
            assert len(ifo_list) > 1
            # Check required ifos option
            if len(ifo_list) < args.min_required_ifos:
                raise RuntimeError('min-required-ifos (%s) must be <= number '
                                   'of ifos being searched (%s)' %
                                   (args.min_required_ifos, len(ifo_list)))
            fo.attrs['ifos'] = ' '.join(sorted(ifo_list))

            template_id = f['foreground/template_id'][:]
            stat = f['foreground/stat'][:]
            ifar_exc = f['foreground/ifar_exc'][:]
            fap_exc = f['foreground/fap_exc'][:]
            try:
                ifar = f['foreground/ifar'][:]
                fap = f['foreground/fap'][:]
            except KeyError:
                logging.info('No inclusive ifar/fap. Proceeding anyway')
                ifar = None
                fap = None

            # using multi-ifo-style trigger file input
            time_dict = {}
            trig_dict = {}
            for ifo in ifo_list:
                time_dict[ifo] = f['foreground/%s/time' % ifo][:]
                trig_dict[ifo] = f['foreground/%s/trigger_id' % ifo][:]
            ifo_times = tuple(time_dict[ifo] for ifo in ifo_list)
            # One representative time per coincident event, derived from
            # the per-detector times of that event.
            time = numpy.array([events.mean_if_greater_than_zero(vals)[0]
                                for vals in zip(*ifo_times)])

            # We will discard injections which cannot be associated with a
            # coincident event, thus combine segments over all combinations
            # of coincident detectors to determine which times to keep
            any_seg = segments.segmentlist([])
            for key in f['segments']:
                if key == 'foreground':
                    continue
                starts = f['/segments/%s/start' % key][:]
                ends = f['/segments/%s/end' % key][:]
                any_seg += events.start_end_to_segments(starts, ends)
            ana_start, ana_end = events.segments_to_start_end(any_seg)

            time_sorting = time.argsort()
            sorted_time = time[time_sorting]

            logging.info('Read in the injection file')
            # Any file name containing '.xml' (which includes '.xml.gz') is
            # read as LIGOLW XML; everything else as an HDF injection set.
            is_xml = '.xml' in injection_file
            if is_xml:
                indoc = ligolw_utils.load_filename(
                    injection_file, False,
                    contenthandler=LIGOLWContentHandler)
                sim_table = table.get_table(
                    indoc, lsctables.SimInspiralTable.tableName)
                inj_time = numpy.array(
                    sim_table.get_column('geocent_end_time') +
                    1e-9 * sim_table.get_column('geocent_end_time_ns'),
                    dtype=numpy.float64)
            else:
                inj_file = CBCHDFInjectionSet(injection_file)
                inj_data = inj_file.table
                inj_time = inj_data['tc'][:]

            logging.info('Determined the found injections by time')
            # For every injection, [left, right) is the range of
            # time-sorted triggers within +/- injection_window of it.
            left = numpy.searchsorted(sorted_time,
                                      inj_time - args.injection_window,
                                      side='left')
            right = numpy.searchsorted(sorted_time,
                                       inj_time + args.injection_window,
                                       side='right')
            n_match = right - left
            found = numpy.where(n_match == 1)[0]
            missed = numpy.where(n_match == 0)[0]
            ambiguous = numpy.where(n_match > 1)[0]
            # Injections with several candidate triggers count as missed.
            missed = numpy.concatenate([missed, ambiguous])
            logging.info('Found: %s, Missed: %s Ambiguous: %s',
                         len(found), len(missed), len(ambiguous))

            if len(ambiguous) > 0:
                logging.warning('More than one coinc trigger found '
                                'associated with injection')

            logging.info('Removing injections outside of analyzed time')
            ki = keep_ind(inj_time, ana_start, ana_end)
            found_within_time = numpy.intersect1d(ki, found)
            missed_within_time = numpy.intersect1d(ki, missed)
            logging.info('Found: %s, Missed: %s',
                         len(found_within_time), len(missed_within_time))

            logging.info('Removing injections in vetoed time')

            # Put individual detector vetoes into list of all vetoed
            # indices, if an injection is vetoed in N ifos then its index
            # will appear N times
            vetoid = numpy.array([], dtype=numpy.int64)
            veto_dict = {}
            for ifo in ifo_list:
                vi, _ = indices_within_segments(
                    inj_time, [args.veto_file], ifo=ifo,
                    segment_name=args.segment_name)
                vetoid = numpy.append(vetoid, vi)
                veto_dict[ifo] = vi

            # Find out which of these indices are unique, and how many
            # times each vetoed index occurs - this is how many detectors
            # it is vetoed in
            vetoid_unique, count_vetoid = numpy.unique(vetoid,
                                                       return_counts=True)

            # remove injections where the number of unvetoed ifos is less
            # than the minimum specified by the user
            vetoed_all = vetoid_unique[(len(ifo_list) - count_vetoid)
                                       < args.min_required_ifos]
            # setdiff1d keeps the integer dtype of its first argument even
            # when the result is empty, so the arrays below can always be
            # used as indices.
            found_after_vetoes = numpy.setdiff1d(found_within_time,
                                                 vetoed_all)
            missed_after_vetoes = numpy.setdiff1d(missed_within_time,
                                                  vetoed_all)

            logging.info('Found: %s, Missed: %s',
                         len(found_after_vetoes), len(missed_after_vetoes))

            # Position, in the time-sorted trigger arrays, of the single
            # trigger associated with each found injection.
            found_fore = left[found]
            found_fore_v = left[found_after_vetoes]

            logging.info('Saving injection information')
            if is_xml:
                ninj = len(sim_table)
                columns = ['mass1', 'mass2', 'spin1x', 'spin1y',
                           'spin1z', 'spin2x', 'spin2y', 'spin2z',
                           'inclination', 'polarization', 'coa_phase',
                           'latitude:dec', 'longitude:ra', 'distance']
                xml_to_hdf(sim_table, fo, 'injections', columns)
                hdf_append(fo, 'injections/tc', inj_time)

                # pick up optimal SNRs (the option is None when not given
                # on the command line, in which case nothing is stored)
                for ifo, column in (args.optimal_snr_column or {}).items():
                    # As a single detector being vetoed won't veto all
                    # combinations, need to set optimal_snr of a vetoed
                    # ifo to zero in order to later calculate decisive
                    # optimal snr
                    optimal_snr_all = numpy.array(
                        sim_table.get_column(column))
                    optimal_snr_all[veto_dict[ifo]] = 0
                    hdf_append(fo, 'injections/optimal_snr_%s' % ifo,
                               optimal_snr_all)

                # pick up redshift
                if args.redshift_column:
                    hdf_append(fo, 'injections/redshift',
                               sim_table.get_column(args.redshift_column))
            else:
                # hdf injection format
                ninj = len(inj_data)

                # fill these columns if not provided (rest should be there
                # if we go this far) as they are commonly used for plots
                for k in ['spin1x', 'spin1y', 'spin1z',
                          'spin2x', 'spin2y', 'spin2z']:
                    if k not in inj_data.dtype.names:
                        inj_data = inj_data.add_fields(numpy.zeros(ninj), k)

                for k in inj_data.dtype.names:
                    data = inj_data[k][:]

                    # set optimal snr to zero for detectors that were
                    # vetoed; the detector name is the last '_' field
                    if 'optimal_snr' in k:
                        ifo = k.split('_')[-1]
                        data[veto_dict[ifo]] = 0

                    # unicode string columns are converted to byte strings
                    # before being written out
                    if data.dtype.char == 'U':
                        data = data.astype('S')

                    hdf_append(fo, 'injections/{}'.format(k), data)

            # copy over common search info
            if 'foreground_time' in f.attrs:
                fo.attrs['foreground_time'] = f.attrs['foreground_time']
            if 'foreground_time_exc' in f.attrs:
                fo.attrs['foreground_time_exc'] = \
                    f.attrs['foreground_time_exc']

            for key in f['segments'].keys():
                if 'foreground' in key or 'coinc' in key:
                    continue
                if key not in fo:
                    fo.create_group(key)
                # Use the per-key group's attributes when the input file
                # has such a group, otherwise the file-level attributes.
                fkey = f[key] if key in f else f
                for attr in ('pivot', 'fixed', 'foreground_time',
                             'foreground_time_exc'):
                    fo[key].attrs[attr] = fkey.attrs[attr]

            hdf_append(fo, 'missed/all', missed + injection_index)
            hdf_append(fo, 'missed/within_analysis',
                       missed_within_time + injection_index)
            hdf_append(fo, 'missed/after_vetoes',
                       missed_after_vetoes + injection_index)
            hdf_append(fo, 'found/injection_index',
                       found + injection_index)
            hdf_append(fo, 'found_after_vetoes/injection_index',
                       found_after_vetoes + injection_index)

            # Per-event quantities are written once for all found
            # injections and once for those surviving the vetoes. Driving
            # both from the same table guarantees that each dataset is
            # filled from the matching source array.
            coinc_columns = [('template_id', template_id),
                             ('stat', stat),
                             ('ifar_exc', ifar_exc),
                             ('fap_exc', fap_exc)]
            if ifar is not None:
                coinc_columns += [('ifar', ifar), ('fap', fap)]

            selections = [('found', found_fore),
                          ('found_after_vetoes', found_fore_v)]
            for group, sel in selections:
                for name, values in coinc_columns:
                    hdf_append(fo, '%s/%s' % (group, name),
                               values[time_sorting][sel])
                for ifo in ifo_list:
                    hdf_append(fo, '%s/%s/time' % (group, ifo),
                               time_dict[ifo][time_sorting][sel])
                    hdf_append(fo, '%s/%s/trigger_id' % (group, ifo),
                               trig_dict[ifo][time_sorting][sel])

        injection_index += ninj
