#!/usr/bin/env python
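"""Find multi-detector coincidences between single-detector triggers,
optionally applying time slides to estimate the background and decimating
quiet background coincidences to control the output size."""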
import h5py, copy, argparse, logging, numpy, numpy.random, multiprocessing
import shutil, uuid, os.path
from ligo.segments import infinity
from pycbc.events import veto, coinc, stat, ranking
import pycbc.version
from numpy.random import seed, shuffle
from pycbc.io.hdf import ReadByTemplate
from pycbc.types.optparse import MultiDetOptionAction

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="count")
parser.add_argument("--version", action="version", version=pycbc.version.git_verbose_msg)
parser.add_argument("--veto-files", nargs='*', action='append', default=[],
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored")
parser.add_argument("--segment-name", nargs='*', action='append', default=[],
                    help="Optional, name of veto segment in veto file")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--trigger-files", nargs='*', action='append', default=[],
                    help="Files containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
parser.add_argument("--pivot-ifo", required=True,
                    help="Add the ifo to use as the pivot for multi "
                         "detector coincidence")
parser.add_argument("--fixed-ifo", required=True,
                    help="Add the ifo to use as the fixed ifo for "
                         "multi detector coincidence")
# the nargs='*' + action='append' options above each produce a list of
# lists, allowing multiple invocations of an option with multiple arguments
parser.add_argument("--use-maxalpha", action="store_true")
parser.add_argument("--coinc-threshold", type=float, default=0.0,
                    help="Seconds to add to time-of-flight coincidence window")
parser.add_argument("--timeslide-interval", type=float,
                    help="Interval between timeslides in seconds. Timeslides are"
                         " disabled if the option is omitted.")
parser.add_argument("--loudest-keep-values", type=str, nargs='*',
                    default=['6:1'],
                    help="Apply successive multiplicative levels of decimation"
                         " to coincs with stat value below the given thresholds"
                         " Ex. 9:10 8.5:30 8:30 7.5:30. Default: no decimation")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. Format"
                         " PART/NUM_PARTS")
parser.add_argument("--randomize-template-order", action="store_true",
                    help="Random shuffle templates with fixed seed "
                         "before selecting range to analyze")
parser.add_argument("--cluster-window", type=float,
                    help="Optional, window size in seconds to cluster "
                         "coincidences over the bank")
parser.add_argument("--output-file",
                    help="File to store the coincident triggers")
parser.add_argument("--batch-singles", default=5000, type=int,
                    help="Number of single triggers to process at once")
parser.add_argument('--nprocesses', type=int, default=1,
                    help="Number of processes to use")
parser.add_argument('--stage-input', action='store_true',
                    help="Stage input files into --stage-input-dir to "
                         "speed up access by multiple processes")
parser.add_argument('--stage-input-dir', type=str, default='/dev/shm',
                    help="Directory to stage input files")
stat.insert_statistic_option_group(parser)
args = parser.parse_args()

# flatten each list of lists produced by the append actions into a single
# list (which may be empty)
args.segment_name = sum(args.segment_name, [])
args.veto_files = sum(args.veto_files, [])
args.trigger_files = sum(args.trigger_files, [])

if args.verbose:
    logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)


def parse_template_range(num_templates, rangestr):
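    """Return the [tmin, tmax) range of template indices selected by a
    PART/NUM_PARTS string, e.g. '0/2' gives the first half of the bank."""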
    part = int(rangestr.split('/')[0])
    pieces = int(rangestr.split('/')[1])
    tmin = int(num_templates / float(pieces) * part)
    tmax = int(num_templates / float(pieces) * (part+1))
    return tmin, tmax

logging.info('Starting...')

with h5py.File(args.template_bank, "r") as bank_f:
    num_templates = len(bank_f['template_hash'])
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing templates %s - %s' % (tmin, tmax-1))

class MultiifoTrigs(object):
    """store trigger info in parallel with ifo name and shift vector"""
    def __init__(self):
        self.ifos = []
        self.to_shift = []
        self.singles = []

trigs = MultiifoTrigs()
cleanup_files = []
for i, fname in enumerate(args.trigger_files):
    if args.stage_input:
        dest = os.path.join(args.stage_input_dir, str(uuid.uuid4()) + '.hdf')
        logging.info("Staging %s as %s", fname, dest)
        shutil.copyfile(fname, dest)
        cleanup_files.append(dest)
    else:
        dest = fname

    logging.info('Opening trigger file %s: %s' % (i, dest))
    reader = ReadByTemplate(dest,
                            args.template_bank,
                            args.segment_name,
                            args.veto_files,
                            args.gating_veto_windows)
    ifo = reader.ifo
    trigs.ifos.append(ifo)
    # time shift is subtracted from pivot ifo time
    trigs.to_shift.append(-1 if ifo == args.pivot_ifo else 0)
    logging.info('Applying time shift multiple %i to ifo %s' %
                 (trigs.to_shift[-1], trigs.ifos[-1]))
    trigs.singles.append(reader)

# coinc_segs contains only the segments where all ifos were analyzed
coinc_segs = veto.start_end_to_segments([-infinity()], [infinity()])
for sngl in trigs.singles:
    coinc_segs = coinc_segs & sngl.segs
for sngl in trigs.singles:
    sngl.segs = coinc_segs
    sngl.valid = veto.segments_to_start_end(sngl.segs)

# Stat class instance to calculate the coinc ranking statistic
rank_method = stat.get_statistic_from_opts(args, trigs.ifos)

# Sanity check: the time slide interval should be larger than twice the
# Earth light-crossing time (~0.085 s), so that time-slid triggers cannot
# form physical coincidences.
TWOEARTH = 0.085
if args.timeslide_interval is not None and args.timeslide_interval <= TWOEARTH:
    parser.error("The time slide interval should be larger "
                 "than twice the Earth crossing time.")

# slide = 0 means don't do timeslides
if args.timeslide_interval is None:
    args.timeslide_interval = 0

if args.randomize_template_order:
    seed(0)
    template_ids = numpy.arange(0, num_templates)
    shuffle(template_ids)
    template_ids = template_ids[tmin:tmax]
else:
    template_ids = range(tmin, tmax)

# 'data' stores the output of coinc finding: lists of coinc info plus the
# trigger times and ids in each ifo
data = {'stat': [], 'decimation_factor': [], 'timeslide_id': [],
        'template_id': []}
for ifo in trigs.ifos:
    data['%s/time' % ifo] = []
    data['%s/trigger_id' % ifo] = []

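# Parse the decimation levels: 'threshes' collects the stat thresholds,
# starting from infinity, and 'total_factors' the cumulative decimation
# factor in force below each threshold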
total_factors = [1]
threshes = [numpy.inf]
for decstr in args.loudest_keep_values:
    thresh, factor = decstr.split(':')
    # int() raises ValueError if 'factor' is not the string representation
    # of an integer
    factor = int(factor)
    if factor == 1:
        continue
    threshes.append(float(thresh))
    total_factors.append(total_factors[-1] * factor)

# Gather the coincs from a single template
def process_template(tnum):
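    """Find the coincidences formed by the triggers of a single template.

    Returns a dict of lists with the same keys as the module-level 'data'
    dict; results from all templates are merged by the parent process.
    """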
    local_data = copy.deepcopy(data)
    times_full = {}
    sds_full = {}
    tids_full = {}
    logging.info('Obtaining trigs for template %i ..' % (tnum))
    for i, sngl in zip(trigs.ifos, trigs.singles):
        tids_full[i] = sngl.set_template(tnum)
        times_full[i] = sngl['end_time']
        logging.info('%s:%s' % (i, len(tids_full[i])))

        # get single-detector statistic
        sds_full[i] = rank_method.single(sngl)

    mintrigs = min([len(ti) for ti in tids_full.values()])
    if mintrigs == 0:
        logging.info('No triggers in at least one ifo for template %i, '
                     'skipping' % tnum)
        return local_data

    if isinstance(rank_method.single_dtype, list):
        entries = [e[0] for e in rank_method.single_dtype]
        if 'snr' in entries:
            min_snr = min([numpy.min(sds_full[i]['snr']) for i in trigs.ifos])
        else:
            min_snr = None
        if 'sigmasq' in entries:
            max_sigmasq = min([numpy.max(sds_full[i]['sigmasq']) for i in trigs.ifos])
        else:
            max_sigmasq = None
    else:
        min_snr = None
        max_sigmasq = None

    # Test whether sds_full contains plain arrays or record arrays; this
    # depends on the stat in use. Indexing a plain ndarray with a field
    # name raises IndexError, hence the except clause below.
    try:
        pivot_stat = sds_full[args.pivot_ifo]['snglstat'].copy()
        fixed_stat = sds_full[args.fixed_ifo]['snglstat'].copy()
    except IndexError:
        pivot_stat = sds_full[args.pivot_ifo].copy()
        fixed_stat = sds_full[args.fixed_ifo].copy()

    pivot_sort = numpy.argsort(pivot_stat)
    fixed_sort = numpy.argsort(fixed_stat)

    if not rank_method.single_increasing:
        pivot_sort = pivot_sort[::-1]
        fixed_sort = fixed_sort[::-1]

    # Loop over the single triggers and calculate the coincs they can form
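    # Triggers are processed in batches of --batch-singles, iterating over
    # fixed-ifo triggers (outer loop) and pivot-ifo triggers (inner loop)
    # in order of their single-detector statistic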
    start0 = 0
    while start0 < len(sds_full[args.fixed_ifo]):
        start1 = 0

        end0 = min(start0 + args.batch_singles,
                   len(sds_full[args.fixed_ifo]))

        fixed_idxs = fixed_sort[start0:end0]

        times = times_full.copy()
        times[args.fixed_ifo] = times_full[args.fixed_ifo][fixed_idxs]

        sds = sds_full.copy()
        sds[args.fixed_ifo] = sds_full[args.fixed_ifo][fixed_idxs]

        tids = tids_full.copy()
        tids[args.fixed_ifo] = tids_full[args.fixed_ifo][fixed_idxs]

        # if decimation is in use, precalculate the fixed-network coincidences
        if len(threshes) > 1:
            fixed_ifos = [i for i in trigs.ifos if i != args.pivot_ifo]
            fixed_times = {i: times[i] for i in fixed_ifos}

            if len(fixed_ifos) > 1:
                fixed_ids, fixed_slide = coinc.time_multi_coincidence(fixed_times, 0.,
                                                                      args.coinc_threshold,
                                                                      fixed_ifos[0],
                                                                      fixed_ifos[1])
                if len(fixed_slide) == 0:
                    start0 += args.batch_singles
                    continue

                # list in ifo order of remaining trigger data
                fixed_single_info = [(i, sds[i][fixed_ids[i]]) for i in fixed_ifos]
                fixed_times = {i: fixed_times[i][fixed_ids[i]] for i in fixed_ifos}

            else:
                det = fixed_ifos[0]
                fixed_ids = {det: numpy.arange(len(times[det]))}
                fixed_single_info = [(i, sds[i]) for i in fixed_ifos]
                fixed_times = {i: times[i] for i in fixed_ifos}

        pivot_lims = {}
        pivot_lower = {}
        for kidx in range(1, len(threshes)):
            # For each trigger in the fixed network, calculate the limit on
            # the pivot single-detector stat needed to pass the current
            # decimation threshold
            pivot_lims[kidx] = rank_method.coinc_lim_for_thresh(
                fixed_single_info, threshes[kidx],
                args.pivot_ifo,
                time_addition=args.coinc_threshold,
                min_snr=min_snr,
                max_sigmasq=max_sigmasq
            )
            if not rank_method.single_increasing:
                pivot_lims[kidx] *= -1.
            # subtract small amount to account for errors due to rounding
            pivot_lims[kidx] -= 1e-6
            # The smallest pivot stat that could pass the current decimation
            # threshold with any trigger in the fixed network
            pivot_lower[kidx] = pivot_lims[kidx].min()

        while start1 < len(sds_full[args.pivot_ifo]):
            end1 = min(start1 + args.batch_singles,
                       len(sds_full[args.pivot_ifo]))

            pivot_idxs = pivot_sort[start1:end1]

            times[args.pivot_ifo] = times_full[args.pivot_ifo][pivot_idxs]

            sds[args.pivot_ifo] = sds_full[args.pivot_ifo][pivot_idxs]

            tids[args.pivot_ifo] = tids_full[args.pivot_ifo][pivot_idxs]

            # Do time coincidence for slides that will be kept after the last decimation
            ids, slide = coinc.time_multi_coincidence(times,
                                                      args.timeslide_interval*total_factors[-1],
                                                      args.coinc_threshold,
                                                      args.pivot_ifo,
                                                      args.fixed_ifo)
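            # the slide ids were computed with the interval scaled up by the
            # total decimation factor; rescale them to units of the base
            # timeslide interval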
            slide *= total_factors[-1]

            single_info = [(i, sds[i][ids[i]]) for i in trigs.ifos]
            cstat = rank_method.rank_stat_coinc(
                single_info, slide, args.timeslide_interval,
                to_shift=trigs.to_shift,
                time_addition=args.coinc_threshold
            )

            # index values of the zerolag coincidences
            fi = numpy.where(slide == 0)[0]

            # index values of the background (time-slid) coincidences; keep
            # only those below the lowest decimation threshold, as louder
            # ones are collected at finer slide spacing in the loop below
            bi = numpy.where(slide != 0)[0]
            bl = bi[cstat[bi] < threshes[-1]]

            ti = numpy.concatenate([fi, bl])

            ids = {i: t[ti] for i, t in ids.items()}

            cstat = cstat[ti]
            slide = slide[ti]
            dec = numpy.concatenate([numpy.ones(len(fi)), numpy.repeat(total_factors[-1], len(bl))])

            try:
                pivot_stat = sds[args.pivot_ifo]['snglstat'].copy()
            except IndexError:
                pivot_stat = sds[args.pivot_ifo].copy()

            if not rank_method.single_increasing:
                pivot_stat *= -1.

            # Starting from the largest decimation threshold, find the first
            # decimation step where the loudest single-detector trigger in
            # the pivot could pass the threshold with any trigger in the
            # fixed network
            tidx = len(threshes)
            for i in range(1, len(threshes)):
                if pivot_stat[-1] >= pivot_lower[i]:
                    tidx = i
                    break

            # loop through decimation steps starting from the first step where passing
            # the threshold is possible
            for kidx in range(tidx, len(threshes)):

                # Remove triggers in pivot that cannot form coincidences above
                # the current decimation threshold
                pivot_cut = numpy.searchsorted(pivot_stat, pivot_lower[kidx])

                pivot_s = pivot_stat[pivot_cut:]

                test_times = fixed_times.copy()
                test_times[args.pivot_ifo] = times[args.pivot_ifo][pivot_cut:]

                # Do time coincidence for the current decimation factor
                set_ids, set_slide = coinc.time_multi_coincidence(test_times,
                                                                  args.timeslide_interval*total_factors[kidx - 1],
                                                                  args.coinc_threshold,
                                                                  args.pivot_ifo,
                                                                  args.fixed_ifo)
                set_slide *= total_factors[kidx - 1]

                # Keep only background (time-slid) coincidences
                sets = set_slide != 0

                # Each precalculated coincidence from the fixed network should
                # be treated like a single trigger, so only keep coincidences
                # where all triggers from the fixed-network coincidence are
                # still together
                for i in range(1, len(fixed_ifos)):
                    sets &= set_ids[fixed_ifos[i-1]] == set_ids[fixed_ifos[i]]

                set_ids = {i: set_ids[i][sets] for i in trigs.ifos}
                set_slide = set_slide[sets]

                # Only keep coincidences where pivot has a single stat above the threshold
                # calculated earlier
                above = pivot_s[set_ids[args.pivot_ifo]] >= pivot_lims[kidx][set_ids[fixed_ifos[0]]]

                test_ids = {i: fixed_ids[i][set_ids[i][above]] for i in fixed_ifos}
                test_ids[args.pivot_ifo] = set_ids[args.pivot_ifo][above] + pivot_cut
                set_slide = set_slide[above]
                test_single_info = [(i, sds[i][test_ids[i]]) for i in trigs.ifos]

                test_cstat = rank_method.rank_stat_coinc(
                    test_single_info, set_slide, args.timeslide_interval,
                    to_shift=trigs.to_shift,
                    time_addition=args.coinc_threshold)

                # Keep triggers in the current decimation range
                test_bi = numpy.where(test_cstat >= threshes[kidx])[0]
                test_bi = test_bi[test_cstat[test_bi] < threshes[kidx - 1]]

                for i in trigs.ifos:
                    ids[i] = numpy.concatenate([ids[i], test_ids[i][test_bi]])
                cstat = numpy.concatenate([cstat, test_cstat[test_bi]])
                slide = numpy.concatenate([slide, set_slide[test_bi]])
                dec = numpy.concatenate([dec, numpy.repeat(total_factors[kidx - 1], len(test_bi))])

            # append this batch's coinc results to the per-template lists
            for ifo in trigs.ifos:
                addtime = times[ifo][ids[ifo]]
                addtriggerid = tids[ifo][ids[ifo]]
                local_data['%s/time' % ifo] += [addtime]
                local_data['%s/trigger_id' % ifo] += [addtriggerid]
            local_data['stat'] += [cstat]
            local_data['decimation_factor'] += [dec]
            local_data['timeslide_id'] += [slide]
            local_data['template_id'] += [numpy.repeat(tnum, len(slide))]
            start1 += args.batch_singles
        start0 += args.batch_singles
    return local_data

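# Process the templates, serially or with a worker pool; each call returns
# a per-template data dict which is merged below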
if args.nprocesses == 1:
    ldatas = list(map(process_template, list(template_ids)))
else:
    with multiprocessing.Pool(args.nprocesses) as pool:
        ldatas = pool.map(process_template, list(template_ids))

logging.info('merging data from the templates')
for ldata in ldatas:
    for key in data:
        data[key] += ldata[key]

if len(data['stat']) > 0:
    for key in data:
        data[key] = numpy.concatenate(data[key])

if args.cluster_window and len(data['stat']) > 0:
    timestring0 = '%s/time' % args.pivot_ifo
    timestring1 = '%s/time' % args.fixed_ifo
    cid = coinc.cluster_coincs(data['stat'], data[timestring0], data[timestring1],
                               data['timeslide_id'], args.timeslide_interval,
                               args.cluster_window)

logging.info('saving coincident triggers')
f = h5py.File(args.output_file, 'w')
if len(data['stat']) > 0:
    for key in data:
        var = data[key][cid] if args.cluster_window else data[key]
        f.create_dataset(key, data=var,
                         compression='gzip',
                         compression_opts=9,
                         shuffle=True)

# Store coinc segments keyed by detector combination
key = ''.join(sorted(trigs.ifos))
f['segments/%s/start' % key], f['segments/%s/end' % key] = trigs.singles[0].valid

f.attrs['timeslide_interval'] = args.timeslide_interval
f.attrs['coinc_time'] = abs(coinc_segs)
f.attrs['num_of_ifos'] = len(args.trigger_files)
f.attrs['pivot'] = args.pivot_ifo
f.attrs['fixed'] = args.fixed_ifo
for i, sngl in zip(trigs.ifos, trigs.singles):
    f.attrs['%s_foreground_time' % i] = abs(sngl.segs)
f.attrs['ifos'] = ' '.join(sorted(trigs.ifos))

# Number of time slides performed: the analyzed time (every ifo shares the
# coincident segments set above) divided by the slide interval
if args.timeslide_interval:
    maxtrigs = max([abs(sngl.segs) for sngl in trigs.singles])
    nslides = int(maxtrigs / args.timeslide_interval)
else:
    nslides = 0
f.attrs['num_slides'] = nslides

for fname in cleanup_files:
    logging.info('cleaning up %s', fname)
    os.remove(fname)

logging.info('Done')
