#!/usr/bin/env python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

import copy, numpy as np

from pycbc import io, events
from pycbc.events import triggers, trigger_fits as trstats
from pycbc.events import stat as statsmod
from pycbc.types.optparse import MultiDetOptionAction
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

def get_stat(statname, trigs, threshold):
    """Rank single-detector triggers and select those above a threshold.

    The triggers are processed in chunks of 2**23 entries so that only one
    chunk of every column is held in memory at a time.

    For now this is using the single detector ranking. If we want, this
    could use the Stat classes in stat.py using similar code as in hdf/io.py
    This requires additional options, so only change this if it's useful!

    Parameters
    ----------
    statname :
        Not used by this function; kept so that existing callers keep
        working. The ranking is configured from the module-level ``args``.
    trigs : mapping of str to array-like
        Per-trigger columns; must contain 'end_time'. Only entries with the
        same length as 'end_time' are handed to the ranking.
    threshold : float
        Triggers whose ranking is >= this value are selected.

    Returns
    -------
    select : numpy.ndarray of bool
        Boolean array, one entry per trigger, that selects the triggers at
        or above threshold.
    stat : numpy.ndarray
        Ranking values of the selected triggers only.
    """
    size = len(trigs['end_time'])
    if size == 0:
        # np.concatenate refuses an empty list of arrays, so answer directly
        return np.zeros(0, dtype=bool), np.zeros(0)

    chunk_size = 2**23
    # Only the columns with one entry per trigger can be sliced by chunk
    columns = [k for k in trigs if len(trigs[k]) == size]
    # The ranking object does not depend on the chunk: build it once
    rank_method = statsmod.get_statistic_from_opts(args, [args.ifo])

    stat = []
    select = []
    for s in range(0, size, chunk_size):
        e = min(s + chunk_size, size)

        # read and format chunk of data so it can be read by key
        # as the stat classes expect.
        chunk = {k: trigs[k][s:e] for k in columns}
        chunk_stat = rank_method.get_sngl_ranking(chunk)

        above = chunk_stat >= threshold
        stat.append(chunk_stat[above])
        select.append(above)

    return np.concatenate(select), np.concatenate(stat)

#### MAIN ####

# Build the command-line interface.
# Options that the rest of the script cannot run without (the two input files,
# the threshold and the fit function) are marked required so that a missing
# one is reported by argparse instead of failing later on a None value.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger"
                " distributions to various functions")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--bank-file", default=None, required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. "
                    "Format is PART/NUM_PARTS")
# The two veto options may be repeated and may each take several values, so
# they are collected as a list of lists and flattened after parsing.
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[], action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--output", required=True,
                    help="Location for output file containing fit coefficients"
                    ". Required")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--fit-function", required=True,
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--stat-threshold", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold.  Required.  Typically 6-6.5")
parser.add_argument("--save-trig-param",
                    help="For each template, save a parameter value read from "
                    "its trigger(s). Ex. template_duration")
parser.add_argument("--prune-param",
                    help="Parameter to define bins for 'pruning' loud triggers"
                    " to make the fit insensitive to signals and outliers. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--prune-bins", type=int,
                    help="Number of bins to divide bank into when pruning")
parser.add_argument("--prune-number", type=int,
                    help="Number of loudest events to prune in each bin")
parser.add_argument("--log-prune-param", action='store_true',
                    help="Bin in the log of prune-param")
# type=float so that a value given on the command line is a number like the
# default, not a string
parser.add_argument("--f-lower", default=-1., type=float,
                    help="Starting frequency for calculating template "
                    "duration, required if this is the prune parameter")
# FIXME : support using the trigger file duration as prune parameter?
# FIXME : have choice of SEOBNRv2 or PhenD duration formula ?
parser.add_argument("--min-duration", default=0., type=float,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before pruning. Units seconds")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")

# Add the ranking-statistic options defined in pycbc.events.stat
statsmod.insert_statistic_option_group(parser,
    default_ranking_statistic='single_ranking_only')
args = parser.parse_args()

# --veto-file / --veto-segment-name are parsed as lists of lists (one inner
# list per occurrence of the option): flatten them
args.veto_segment_name = [name for group in args.veto_segment_name
                          for name in group]
args.veto_file = [fname for group in args.veto_file for fname in group]

if args.veto_file and not args.veto_segment_name:
    # As promised by the option help, segment names are optional: a name of
    # None is passed on to the veto functions for every file.
    # NOTE(review): assumes pycbc.events.veto treats segment_name=None as
    # "all segments for the ifo" - confirm against the installed PyCBC
    args.veto_segment_name = [None] * len(args.veto_file)
elif len(args.veto_segment_name) != len(args.veto_file):
    raise RuntimeError("Number of veto files must match number of veto "
                       "segment names")

# Pruning needs all three of its options; giving only some is an error
prune_opts = (args.prune_param, args.prune_bins, args.prune_number)
if any(prune_opts) and not all(prune_opts):
    raise RuntimeError("To prune, need to specify param, number of bins and "
                       "nonzero number to prune in each bin!")

log_level = logging.DEBUG if args.verbose else logging.WARNING
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

logging.info('Fitting above threshold %f' % args.stat_threshold)

# Open both input files read-only
logging.info('Opening trigger file: %s' % args.trigger_file)
trigf = h5py.File(args.trigger_file, 'r')
logging.info('Opening template file: %s' % args.bank_file)
templatef = h5py.File(args.bank_file, 'r')

logging.info('Counting number of triggers in each template')
ifo_prefix = args.ifo + '/'
# The template boundaries dataset is stored in order of template_id
boundaries = trigf[ifo_prefix + 'template_boundaries'][:]
# The boundary values ascend in the same order as the template hash, so
# put the boundaries, and the matching template ids, into hash order
hash_order = np.argsort(templatef['template_hash'][:])
boundaries_by_hash = boundaries[hash_order]
ids_by_hash = np.arange(len(boundaries))[hash_order]

# The number of triggers in a template is the distance to the next boundary;
# the total number of triggers closes off the last template
num_trigs = len(trigf[ifo_prefix + 'template_id'])
upper_edges = np.append(boundaries_by_hash, num_trigs)
counts_by_hash = np.diff(upper_edges)
# Undo the hash ordering so that the counts are indexed by template id
count_in_template = counts_by_hash[np.argsort(ids_by_hash)]

# Rank every trigger and keep those at or above the threshold
logging.info('Calculating stat values')
abovethresh, stat = get_stat(args, trigf[args.ifo], args.stat_threshold)

# Read the other per-trigger quantities for the selected triggers only
tid = trigf[ifo_prefix + 'template_id'][abovethresh]
time = trigf[ifo_prefix + 'end_time'][abovethresh]
if args.save_trig_param:
    tparam = trigf[ifo_prefix + args.save_trig_param][abovethresh]
logging.info('%i trigs left after thresholding' % len(stat))

# The analysed time is built up from the search segments recorded in the
# trigger file; set() removes duplicated start / end values
seg_starts = sorted(set(trigf['{}/search/start_time'.format(args.ifo)][:]))
seg_ends = sorted(set(trigf['{}/search/end_time'.format(args.ifo)][:]))
all_segments = events.veto.start_end_to_segments(seg_starts, seg_ends)

# Apply each veto file: drop the triggers inside its segments and take the
# same segments out of the analysed time
for veto_file, veto_segment_name in zip(args.veto_file, args.veto_segment_name):
    veto_kwargs = dict(ifo=args.ifo, segment_name=veto_segment_name)
    keep, _ = events.veto.indices_outside_segments(time, [veto_file],
                                                   **veto_kwargs)
    all_segments -= events.veto.select_segments_by_definer(veto_file,
                                                           **veto_kwargs)
    stat = stat[keep]
    tid = tid[keep]
    time = time[keep]
    if args.save_trig_param:
        tparam = tparam[keep]
    logging.info('%i trigs left after vetoing with %s' %
                                                       (len(stat), veto_file))

# Then veto a window around each gate recorded in the trigger file
if args.gating_veto_windows:
    # the window for this ifo is given as "before,after"
    window = args.gating_veto_windows[args.ifo].split(',')
    gveto_before, gveto_after = float(window[0]), float(window[1])
    if gveto_before > 0 or gveto_after < 0:
        raise ValueError("Gating veto window values must be negative before "
                         "gates and positive after gates.")
    # a 0,0 window means there is nothing to veto
    if gveto_before != 0 or gveto_after != 0:
        gate_times = np.unique(trigf[args.ifo + '/gating/auto/time'][:])
        veto_starts = gate_times + gveto_before
        veto_ends = gate_times + gveto_after
        gveto_segs = events.veto.start_end_to_segments(veto_starts,
                                                       veto_ends).coalesce()
        all_segments -= gveto_segs
        keep = events.veto.indices_outside_times(time, veto_starts, veto_ends)
        stat = stat[keep]
        tid = tid[keep]
        time = time[keep]
        if args.save_trig_param:
            tparam = tparam[keep]
        logging.info('%i trigs left after vetoing triggers near gates' % len(stat))

# Whatever is left of the search segments is the time actually analysed
total_time = abs(all_segments)

# do pruning (removal of trigs at N loudest times defined over param bins)
#
# The template bank is divided into args.prune_bins bins of the prune
# parameter. Working down from the loudest remaining trigger, all triggers
# within args.prune_window of it are removed from the fit, until
# args.prune_number times have been pruned in every bin.
if args.prune_param:
    logging.info('Getting min and max param values')
    pars = triggers.get_param(args.prune_param, args,
                     templatef['mass1'][:], templatef['mass2'][:],
                     templatef['spin1z'][:], templatef['spin2z'][:])
    minpar = min(pars)
    maxpar = max(pars)
    del pars
    logging.info('%f %f' % (minpar, maxpar))

    # hard-coded time window of 0.1s
    args.prune_window = 0.1
    # initialize bin storage: the times pruned so far in each bin
    prunedtimes = {i: [] for i in range(args.prune_bins)}
    prune_target = args.prune_bins * args.prune_number

    # keep a record of the triggers if all successive loudest events were to
    # be pruned: events landing in a full bin are removed from these reference
    # arrays only, so that the search can move on to the next loudest event
    statpruneall = stat.copy()
    tidpruneall = tid.copy()
    timepruneall = time.copy()

    # many trials may be required to prune in 'quieter' bins
    max_trials = 1000
    for j in range(max_trials):
        # are all the bins full already?
        numpruned = sum(len(times) for times in prunedtimes.values())
        if numpruned == prune_target:
            logging.info('Finished pruning!')
            break
        if numpruned > prune_target:
            msg = ('Uh-oh, we pruned too many things .. %i, to be '
                   'precise' % numpruned)
            logging.error(msg)
            raise RuntimeError(msg)
        if len(statpruneall) == 0:
            # np.argmax fails on an empty array: stop cleanly instead
            logging.warning('Ran out of triggers after pruning %i of %i '
                            'events' % (numpruned, prune_target))
            break
        loudest = np.argmax(statpruneall)
        lstat = statpruneall[loudest]
        ltid = tidpruneall[loudest]
        ltime = timepruneall[loudest]
        m1, m2, s1z, s2z = triggers.get_mass_spin(templatef, ltid)
        lbin = trstats.which_bin(triggers.get_param(args.prune_param, args,
                                                    m1, m2, s1z, s2z),
                                 minpar, maxpar, args.prune_bins,
                                                      log=args.log_prune_param)
        # is the bin where the loudest trigger lives full already?
        if len(prunedtimes[lbin]) == args.prune_number:
            logging.info('%i - Bin %i full, not pruning event with stat %f at '
                         'time %.3f' % (j, lbin, lstat, ltime))
        else:
            logging.info('Pruning event with stat %f at time %.3f in bin %i' %
                         (lstat, ltime, lbin))
            # now do the pruning
            retain = abs(time - ltime) > args.prune_window
            logging.info('%i trigs before pruning' % len(stat))
            stat = stat[retain]
            logging.info('%i trigs remain' % len(stat))
            tid = tid[retain]
            time = time[retain]
            if args.save_trig_param:
                tparam = tparam[retain]
            # record the time
            prunedtimes[lbin].append(ltime)
        # in both cases, prune the reference trigger arrays
        retain = abs(timepruneall - ltime) > args.prune_window
        statpruneall = statpruneall[retain]
        tidpruneall = tidpruneall[retain]
        timepruneall = timepruneall[retain]
        del retain
    else:
        # all trials were used up: say so if the bins did not fill in time
        numpruned = sum(len(times) for times in prunedtimes.values())
        if numpruned < prune_target:
            logging.warning('Stopped pruning after %i trials with %i of %i '
                            'events pruned' %
                            (max_trials, numpruned, prune_target))
    del statpruneall
    del tidpruneall
    del timepruneall
    logging.info('%i trigs remain after pruning loop' % len(stat))

# parse template range: "PART/NUM_PARTS" selects one contiguous slice of the
# template ids
num_templates = len(templatef['template_hash'])
rangestr = args.template_fraction_range
range_fields = rangestr.split('/')
part = int(range_fields[0])
pieces = int(range_fields[1])
tmin = int(num_templates / float(pieces) * part)
tmax = int(num_templates / float(pieces) * (part + 1))
trange = range(tmin, tmax)

# initialize result storage
tids = []
counts_total = []
counts_above = []
fits = []
tpars = []

# Sort the triggers by template id so that the triggers of one template are
# contiguous. Every per-trigger array must be reordered together, including
# the saved trigger parameter.
tsort = tid.argsort()
tid = tid[tsort]
stat = stat[tsort]
if args.save_trig_param:
    tparam = tparam[tsort]

# Get range of tid values which are the same
left = np.searchsorted(tid, trange, side='left')
right = np.searchsorted(tid, trange, side='right')

logging.info("Fitting ...")
for j, tnum in enumerate(trange):
    in_template = slice(left[j], right[j])
    stat_in_template = stat[in_template]
    count_above = len(stat_in_template)
    count_total = count_in_template[tnum]
    if count_above == 0:
        # default/sentinel value to indicate no data, shouldn't hurt if 1/alpha is averaged
        alpha = -100.
    else:
        # the second returned value is not stored
        alpha, _ = trstats.fit_above_thresh(
                      args.fit_function, stat_in_template, args.stat_threshold)
    tids.append(tnum)
    counts_above.append(count_above)
    counts_total.append(count_total)
    fits.append(alpha)
    if args.save_trig_param:
        if count_above == 0:
            # no trigger to read the value from: store NaN as a placeholder
            tpars.append(np.nan)
        else:
            # save the param value of the first trig in this template
            tpars.append(tparam[in_template][0])
    if (tnum % 1000 == 0): logging.info('Fitted template %i / %i' %
                                                    (tnum - tmin, tmax - tmin))

logging.info("Calculating median sigma for each template")
# one region reference into the sigmasq dataset per entry
sigma_regions = trigf[args.ifo + '/sigmasq_template'][:]
sigmasq = trigf[args.ifo + '/sigmasq']
median_sigma = []
# NOTE(review): this loops over every entry of sigmasq_template, whereas the
# other output datasets only cover the requested template fraction range, so
# the dataset lengths differ unless the range is the whole bank - confirm
# that this is what downstream code expects
for reg in sigma_regions:
    strigs = sigmasq[reg]
    if len(strigs) == 0:
        # np.median of nothing is NaN anyway; skip the numpy warning
        median_sigma.append(np.nan)
    else:
        median_sigma.append(np.median(strigs) ** 0.5)

# The context manager closes the output file even if a write fails
with h5py.File(args.output, 'w') as outfile:
    outfile.create_dataset("template_id", data=trange)
    outfile.create_dataset("count_above_thresh", data=counts_above)
    outfile.create_dataset("fit_coeff", data=fits)
    if args.save_trig_param:
        outfile.create_dataset("template_param", data=tpars)
    outfile.create_dataset("count_in_template", data=counts_total)
    outfile.create_dataset("median_sigma", data=median_sigma)
    # add some metadata
    outfile.attrs.create("ifo", data=args.ifo.encode())
    outfile.attrs.create("fit_function", data=args.fit_function.encode())
    outfile.attrs.create("sngl_stat", data=args.sngl_ranking)
    if args.save_trig_param:
        outfile.attrs.create("save_trig_param",
                             data=args.save_trig_param.encode())
    outfile.attrs.create("stat_threshold", data=args.stat_threshold)
    outfile.attrs.create("analysis_time", data=total_time)

# The input files are not needed any more
trigf.close()
templatef.close()
logging.info('Done!')
