#!/usr/bin/env python

"""
This executable is used to take the output of pycbc_multiifo_sngls_findtrigs
and separate the triggers into foreground - times when coincs are not
possible - and background - times when coincs are possible. It also performs
foreground removal, creating background_exc, which is a background that
contains no triggers within a certain window of those which form foreground
coincidences. Finally, it calculates p_astro for each foreground trigger and
writes the triggers and p_astro values to the output file.

"""

import pycbc, pycbc.io, copy
import argparse, logging, numpy as np
from glue.ligolw import ligolw, table, lsctables, utils as ligolw_utils
from pycbc import conversions as conv
from pycbc.events import veto, stat, ranking, coinc, single as sngl
from ligo.segments import segment, segmentlist
import pycbc.version
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
from scipy.stats import gaussian_kde as gk

# Content handler passed to ligolw_utils.load_filename when the injection
# definition files are read; it adds nothing to the base class
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
    """Empty subclass of the LIGOLW content handler, used to load files."""


lsctables.use_in(LIGOLWContentHandler)

# Exponent applied to the injected distance when weighting injections, keyed
# by the form of the injection distribution (the --distribution choices)
d_power = dict(
    log=3.,
    uniform=2.,
    distancesquared=1.,
    volume=0.,
)

# Exponent applied to the injected chirp mass when weighting injections,
# keyed in the same way as d_power
mchirp_power = dict(
    log=0.,
    uniform=5. / 6.,
    distancesquared=5. / 3.,
    volume=15. / 6.,
)

# Command-line interface. The parsed options are stored in the module-level
# name `args`, which the rest of the script reads.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action='count')
parser.add_argument("--version", action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("--single-statmap-files", nargs='+', required=True,
                    help="Single statmap files for which p_astro is "
                         "calculated.")

# Files to help remove foreground events
parser.add_argument('--coinc-statmap-files', nargs='+', required=True,
                    help="Coincident statmap files, containing coincident "
                         "events to be removed from the background.")
parser.add_argument('--coinc-veto-ifar-threshold', type=float, default=1.0,
                    help="Censor triggers around coincidences with "
                         "IFAR (years) above the threshold [default=1 yr]")
parser.add_argument('--coinc-veto-window', type=float, default=0.1,
                    help="Time around each coincident event above threshold "
                         "to window out. Default = 0.1s")
parser.add_argument("--remove-n-loudest", type=int,
                    help="If given, will remove this number of triggers from "
                         "the background after applying foreground censor. "
                         "This helps prevent signal contamination, but too "
                         "many removals can adversely affect background "
                         "estimate.")

# Arguments for dealing with the injection trigger files for signal population
parser.add_argument('--inj-single-statmap-files', nargs='+', required=True,
                    help="Single statmap files for the injections. "
                         "Must be in same order as --inj-files")
parser.add_argument('--inj-files', nargs='+', required=True,
                    help="Files which define injections. Must be in same "
                         "order as --inj-single-statmap-files")
parser.add_argument('--injection-window', type=float, default=1,
                    help="Window for how close a trigger is to an "
                         "injection to be considered to be associated "
                         "with it (seconds). Default=1.")

# Arguments to decide the weighting of the injection distribution
# The default is set explicitly so that it matches the help text; anything
# other than 'chirp_distance' leaves the injected distance unconverted
parser.add_argument('--distance-param', default='distance',
                    choices=['distance', 'chirp_distance'],
                    help="Parameter used to calculate injection distribution "
                         "for weighting. Default='distance'")
parser.add_argument('--distribution', default='uniform',
                    choices=['log', 'uniform', 'distancesquared', 'volume'],
                    help="Form of distribution over --distance-param. "
                         "Default='uniform'")
# NOTE(review): the help below refers to --stat-threshold, which is not an
# option of this script - confirm the intended wording
parser.add_argument('--expected-signal-rate', type=float, required=True,
                    help="Expected rate of signals (per year) with stat "
                         "value above --stat-threshold for use in p_astro "
                         "calculation.")
parser.add_argument('--signal-ifar-threshold', type=float, default=1,
                    help="Coincident IFAR threshold to consider an injection "
                         "as 'found' (years). Default=1")

# Choice of noise model used above --bg-distribution-limit
parser.add_argument('--pastro-method', required=True,
                    choices=['truncated_shelf', 'callister', 'background_kde'],
                    help="Which method to use for calculating p_astro. ")
parser.add_argument("--bg-distribution-limit", type=float, default=8,
                    help="The point at which the noise model will change from "
                         "the background distribution to --pastro-method")

parser.add_argument('--plot-distribution',
                    help="If given, will plot the distribution of rate "
                         "densities and p_astro given ranking statistic.")
parser.add_argument('--output-file', required=True,
                    help="name of output file")
args = parser.parse_args()

# Pick the exponents for the chosen injection distribution out of the
# module-level lookup tables rather than repeating the tables here. From this
# point on d_power and mchirp_power are the scalar exponents used to weight
# the injections, not the tables themselves.
d_power = d_power[args.distribution]
mchirp_power = mchirp_power[args.distribution]


pycbc.init_logging(args.verbose)

# The detector name is taken from the first single statmap file only; all
# files are presumably for the same detector - TODO confirm.
# The file is opened in a context manager so the handle is closed again
# straight away instead of being left open for the rest of the run.
with pycbc.io.HFile(args.single_statmap_files[0], 'r') as first_smap:
    sngl_ifo = first_smap.attrs['ifos']

# Datasets read from every single statmap file
groups = ['decimation_factor', 'stat', 'template_id', 'timeslide_id',
          sngl_ifo + '/time', sngl_ifo + '/trigger_id']
empty_darray = {g: np.array([]) for g in groups}
sngl_trigs = pycbc.io.DictArray(data=empty_darray)

logging.info('Getting single-detector triggers and segments')
# Accumulate the triggers and the single-detector segments over all files
sngl_detector_segs = segmentlist([segment(0, 0)])
for smap_file in args.single_statmap_files:
    with pycbc.io.HFile(smap_file, 'r') as f_in:
        sngl_starts = f_in['segments/' + sngl_ifo + '/start'][:]
        sngl_ends = f_in['segments/' + sngl_ifo + '/end'][:]
        sngl_trigs += pycbc.io.DictArray(data={g: f_in[g][:] for g in groups})
    sngl_detector_segs += veto.start_end_to_segments(sngl_starts, sngl_ends)
# Total duration of the single-detector segments
fg_time = abs(sngl_detector_segs)

# Both segment lists are built up file by file in the loop below
coinc_segs = segmentlist([segment(0, 0)])
fgveto_segs = segmentlist([segment(0, 0)])
logging.info('Getting coinc segments and foreground vetoes')
for cfilename in args.coinc_statmap_files:
    with pycbc.io.HFile(cfilename, 'r') as cfile:
        # The group holding the coinc segments is named differently
        # depending on the statmap file
        if 'coinc' in cfile['segments']:
            seg_name = 'coinc'
        else:
            seg_name = cfile.attrs['ifos'].replace(' ','')
        seg_path = 'segments/' + seg_name
        c_starts = cfile[seg_path + '/start'][:]
        c_ends = cfile[seg_path + '/end'][:]

        # Work out which detectors are in this file, and which dataset
        # holds the event times for sngl_ifo
        if 'ifos' in cfile.attrs:
            ifolist = cfile.attrs['ifos'].split(' ')
            time_key = sngl_ifo + '/time'
        else:
            det_1 = cfile.attrs['detector_1']
            det_2 = cfile.attrs['detector_2']
            ifolist = [det_1, det_2]
            time_key = 'time1' if sngl_ifo == det_1 else 'time2'

        # Files which do not involve this detector contribute nothing
        if sngl_ifo not in ifolist:
            logging.warning("IFO %s is not in file %s", sngl_ifo, cfilename)
            continue

        # Times of the coinc events with IFAR above the threshold
        loud = cfile['foreground/ifar'][:] > args.coinc_veto_ifar_threshold
        loud_times = cfile['foreground'][time_key][:][loud]

    # Add this file's coinc segments to the running total
    coinc_segs = coinc_segs + veto.start_end_to_segments(c_starts, c_ends)
    coinc_segs.coalesce()

    # Window out the time around each loud coinc event
    veto_starts = loud_times - args.coinc_veto_window
    veto_ends = loud_times + args.coinc_veto_window
    fgveto_segs = fgveto_segs + veto.start_end_to_segments(veto_starts,
                                                           veto_ends)
    fgveto_segs.coalesce()


logging.info("Removing triggers in single-detector time from background.")
# Only triggers whose time lies within the coinc segments are kept
sngl_times = sngl_trigs.data[sngl_ifo + '/time']
in_coinc_time = np.array([t in coinc_segs for t in sngl_times])
bg_trigs = sngl_trigs.select(np.flatnonzero(in_coinc_time))

logging.info("Removing triggers in foreground vetoed time from background.")
# Drop the triggers which fall inside any of the foreground veto windows
bg_times = bg_trigs.data[sngl_ifo + '/time']
not_vetoed = np.array([t not in fgveto_segs for t in bg_times])
bg_trigs = bg_trigs.select(np.flatnonzero(not_vetoed))

# Duration of the coinc segments once the vetoed windows are taken out
bg_time = abs(coinc_segs - fgveto_segs)

if args.remove_n_loudest:
    n_remove = args.remove_n_loudest
    logging.info("Removing %d loudest remaining triggers from background",
                 n_remove)
    # argsort is ascending, so the loudest triggers are the last n_remove
    stat_order = np.argsort(bg_trigs.data['stat'])
    bg_trigs = bg_trigs.remove(stat_order[-n_remove:])

logging.info("Getting injected signal triggers and information.")

# Statistic values and weights of the found injections, accumulated over all
# of the injection sets
signal_stat = np.array([])
signal_weights = np.array([])
# The injection definition files and the injection single statmap files are
# paired up by position, so they must be given in the same order - would
# prefer to not have to do this
for inj_filename, inj_trigger_filename in zip(args.inj_files,
                                              args.inj_single_statmap_files):

    logging.info('Reading injection statistic file')
    with pycbc.io.HFile(inj_trigger_filename, 'r') as inj_trig_f:
        inj_trig_time = inj_trig_f[sngl_ifo]['time'][:]
        inj_stat = inj_trig_f['stat'][:]

    # Sort the triggers by time so that injections can be matched to them
    # with a binary search. The statistic is put into the same order, as the
    # indices returned by the search refer to the sorted arrays.
    time_sort = inj_trig_time.argsort()
    sorted_trig_time = inj_trig_time[time_sort]
    sorted_stat = inj_stat[time_sort]

    logging.info('Reading injection file')
    indoc = ligolw_utils.load_filename(inj_filename, False,
                                       contenthandler=LIGOLWContentHandler)
    sim_table = table.get_table(indoc, lsctables.SimInspiralTable.tableName)
    inj_time = np.array(sim_table.get_column('geocent_end_time') +
                        1e-9 * sim_table.get_column('geocent_end_time_ns'),
                        dtype=np.float64)

    # Range of (time-sorted) triggers within the window of each injection
    left = np.searchsorted(sorted_trig_time,
                           inj_time - args.injection_window, side='left')
    right = np.searchsorted(sorted_trig_time,
                            inj_time + args.injection_window, side='right')
    # Keep only the injections with exactly one trigger inside the window
    found = np.flatnonzero((right - left) == 1)

    found_d = np.array(sim_table.get_column('distance'),
                       dtype=np.float32)[found]
    found_m1 = np.array(sim_table.get_column('mass1'),
                        dtype=np.float32)[found]
    found_m2 = np.array(sim_table.get_column('mass2'),
                        dtype=np.float32)[found]

    # left[found] is the index, in time-sorted order, of the single trigger
    # associated with each found injection
    signal_stat = np.append(signal_stat, sorted_stat[left[found]])

    found_mchirp = conv.mchirp_from_mass1_mass2(found_m1, found_m2)
    if args.distance_param == 'chirp_distance':
        found_d = conv.chirp_distance(found_d, found_mchirp)

    # Weight each found injection according to the injection distribution
    weights = found_mchirp ** mchirp_power * found_d ** d_power
    signal_weights = np.append(signal_weights, weights)


# Ranking statistic of the background triggers and of all the single-detector
# (foreground) triggers
bg_stat = bg_trigs.data['stat']
fg_stat = sngl_trigs.data['stat']

# Expected number of signals in the foreground time
n_sig_exp = args.expected_signal_rate * conv.sec_to_year(fg_time)
logging.info("%.3f signals are expected in this amount of data", n_sig_exp)

max_bg = bg_stat.max()
min_bg = bg_stat.min()

logging.info('Getting Gaussian kernel density estimates of signal '
             'distribution')
# The signal density is normalised by the integral of the signal KDE above
# the loudest background statistic
sig_kern = gk(signal_stat, weights=signal_weights, bw_method=1)
sig_norm = sig_kern.integrate_box_1d(max_bg, np.inf)
sig_dens = sig_kern(fg_stat) / sig_norm

logging.info('Getting Gaussian kernel density estimates of background '
             'distribution')
# Scale the background KDE by the number of background triggers
bg_kern = gk(bg_stat, bw_method=1)
bg_rate_dens = bg_kern(fg_stat) * bg_stat.size

# Locate the peak of the background distribution. Anything with statistic
# below the peak gets p_astro zero, which prevents problems with the
# background distribution reducing as triggers are clustered away
bg_dens_max_idx = bg_rate_dens.argmax()
bg_dens_max_stat = fg_stat[bg_dens_max_idx]
fg_stat_below_peak = fg_stat < bg_dens_max_stat

logging.info('Getting noise distribution of %s method.', args.pastro_method)
# Foreground triggers quieter than the limit keep the background density as
# their noise model; the chosen method only applies at or above the limit
use_method = fg_stat >= args.bg_distribution_limit
noise_model = bg_rate_dens.copy()

if args.pastro_method == 'callister':
    noise_model[use_method] = sig_dens[use_method]
elif args.pastro_method == 'truncated_shelf':
    fgs = fg_stat[use_method]
    # Loudest background statistic strictly below each foreground statistic,
    # or -inf when there is no quieter background trigger
    shelf_low = []
    for fs in fgs:
        if fs > min_bg:
            shelf_low.append(max(bg_stat[bg_stat < fs]))
        else:
            shelf_low.append(-np.inf)
    stat_below = np.array(shelf_low)
    # Normalise the signal KDE over the interval between the two
    ts_norm = np.array([sig_kern.integrate_box_1d(sbel, fs)
                        for sbel, fs in zip(stat_below, fgs)])
    noise_model[use_method] = sig_kern(fgs) / ts_norm

logging.info("Calculating pastro for foreground events")
sig_rate_dens = n_sig_exp * sig_dens
p_astro = sig_rate_dens / (noise_model + sig_rate_dens)
p_astro[fg_stat_below_peak] = 0

# Optional diagnostic plot: rate densities on the top panel, p_astro as a
# function of ranking statistic on the bottom panel
if args.plot_distribution:
    logging.info("Making distribution plot")
    max_bin_c = min(signal_stat.max(), bg_stat.max() * 4)
    stat_bins = np.linspace(min_bg, max_bin_c, 201)
    stat_bin_c = (stat_bins[1:] + stat_bins[:-1]) / 2

    fig = plt.figure(figsize=[6, 6])
    ax0 = fig.add_subplot(211)
    ax1 = fig.add_subplot(212)

    # Plot background trigger rate density
    bg_rate_dens_plt = bg_kern(stat_bin_c) * bg_stat.size
    ax0.plot(stat_bin_c, bg_rate_dens_plt, linestyle="--",
             label="Background KDE")

    # Plot signal trigger rate density (from injections)
    sig_rate_dens_plt = sig_kern(stat_bin_c) / sig_norm
    ax0.plot(stat_bin_c, sig_rate_dens_plt, linestyle=":",
             label="Signals KDE")

    # Calculate noise model according to chosen method. The noise model is
    # modified in place further down, so it must be a copy: aliasing the
    # signal density here would corrupt the signal density which is used to
    # calculate p_astro below.
    if args.pastro_method == 'callister':
        noise_mod_plt = sig_rate_dens_plt.copy()
    elif args.pastro_method == 'truncated_shelf':
        # Loudest background statistic at or below each bin centre, or -inf
        # when the bin centre is below every background trigger
        stat_below_plt = [max(bg_stat[sbc >= bg_stat])
                          if sbc >= min_bg else -np.inf for sbc in stat_bin_c]
        ts_norm_plt = np.array([sig_kern.integrate_box_1d(sbel, sbc)
                                for sbel, sbc in zip(stat_below_plt,
                                                     stat_bin_c)])
        noise_mod_plt = sig_kern(stat_bin_c) / ts_norm_plt
    elif args.pastro_method == 'background_kde':
        noise_mod_plt = bg_rate_dens_plt.copy()

    # triggers below the chosen distribution limit use the background
    # distribution for noise model
    past_noise_idx_plt = stat_bin_c < args.bg_distribution_limit
    noise_mod_plt[past_noise_idx_plt] = bg_rate_dens_plt[past_noise_idx_plt]
    ax0.plot(stat_bin_c, noise_mod_plt, linestyle='-', c='k',
             label=args.pastro_method)

    ax0.semilogy()
    ax0.grid()
    # Set axis limits according to the highest point of the background
    # distribution
    ax0.set_xlim([bg_dens_max_stat, max_bin_c])
    # The KDE returns a one-element array for a single point; take the
    # element so that a plain scalar is passed as the axis limit
    ax0.set_ylim([sig_kern(max_bin_c)[0] / sig_norm / 1.5,
                  bg_rate_dens[bg_dens_max_idx] * 1.5])
    ax0.legend()
    ax0.set_xlabel("Ranking Statistic")
    ax0.set_ylabel("Rate Density")

    # Calculate p_astro for the plotted statistic values
    sig_exp_dens_plt = n_sig_exp * sig_rate_dens_plt
    p_astro_plt = sig_exp_dens_plt / (noise_mod_plt + sig_exp_dens_plt)

    ax1.plot(stat_bin_c, p_astro_plt, c='k', label="p_astro distribution")
    ax1.scatter(fg_stat, p_astro, c='k', marker='x', label='foreground events')
    ax1.grid()
    ax1.legend()
    ax1.set_xlim([bg_dens_max_stat, max_bin_c])
    ax1.set_ylim([0, 1])
    ax1.set_xlabel("Ranking Statistic")
    ax1.set_ylabel("p_astro (%s)" % args.pastro_method)

    fig.savefig(args.plot_distribution)

# Write the results. The file is opened in a context manager so that it is
# closed, and everything flushed to disk, even if one of the writes fails.
with pycbc.io.HFile(args.output_file, 'w') as f_out:
    # Compression settings shared by all of the trigger datasets
    dset_opts = dict(compression='gzip', compression_opts=9, shuffle=True)

    # The same datasets are stored for the foreground (all single-detector
    # triggers) and for the background (triggers left after the cuts)
    for k in sngl_trigs.data:
        f_out.create_dataset('foreground/' + k, data=sngl_trigs.data[k],
                             **dset_opts)
        f_out.create_dataset('background/' + k, data=bg_trigs.data[k],
                             **dset_opts)

    f_out.create_dataset('foreground/p_astro', data=p_astro, **dset_opts)

    # Store segments
    seg_starts, seg_ends = veto.segments_to_start_end(sngl_detector_segs)
    f_out['segments/%s/start' % sngl_ifo] = seg_starts
    f_out['segments/%s/end' % sngl_ifo] = seg_ends

    f_out.attrs['foreground_time'] = fg_time
    f_out.attrs['background_time'] = bg_time
    f_out.attrs['num_of_ifos'] = 1
    f_out.attrs['pivot'] = sngl_ifo
    f_out.attrs['fixed'] = sngl_ifo
    f_out.attrs['ifos'] = sngl_ifo
    f_out.attrs['pastro_method'] = args.pastro_method
logging.info('Done')
