#!/usr/bin/python
# Copyright (C) 2015-2019 Christopher M. Biwer, Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import h5py
import numpy
import sys
import matplotlib as mpl; mpl.use('Agg')
import pylab
import pycbc.results
import pycbc.version
from pycbc import conversions
from scipy.stats import norm, poisson

# Command-line interface. --trigger-files and --output-file are both mandatory:
# the trigger files are opened unconditionally further down, so omitting them
# must produce a usage error rather than a TypeError when iterating over None.
parser = argparse.ArgumentParser(usage='pycbc_ifar_catalog [--options]',
                    description='Plots cumulative IFAR vs count for'
                          ' coincident foreground triggers')
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-files', nargs='+', required=True,
                    help='Path to coincident trigger HDF file(s)')
parser.add_argument('--output-file', required=True,
                    help='Path to output plot')
# --truncate-threshold and --remove-threshold are mutually exclusive; this is
# enforced later in the script once the foreground has been loaded.
parser.add_argument('--truncate-threshold', type=float,
                    help='Truncate plot above IFAR threshold (units: years)')
parser.add_argument('--remove-threshold', type=float,
                    help='Remove detected events with IFAR above threshold'
                         ' (units years)')
parser.add_argument('--open-box', action='store_true',
                    help='Are we putting open box results onto the plot? '
                         'Default=False.')
parser.add_argument('--use-tex', action='store_true',
                    help="Render using latex")
# When left unset (None) the script falls back to the number of hierarchical
# removal iterations recorded in the trigger files.
parser.add_argument('--use-hierarchical-level', type=int, default=None,
                    help='Indicate which inclusive background and FARs of '
                         'foreground triggers to plot if there were any '
                         'hierarchical removals done. Choosing None plots '
                         'the inclusive backgrounds prior to any '
                         'hierarchical removals with the updated FARs for '
                         'foreground triggers after hierarchical removal(s). '
                         'Choosing 0 means plotting inclusive background '
                         'from prior to any hierarchical removals with FARs '
                         'for foreground triggers prior to hierarchical '
                         'removal. Choosing 1 means plotting the inclusive '
                         'background after doing 1 hierarchical removal, and '
                         'includes updated FARs from after 1 hierarchical '
                         'removal. [default=plot after application of all '
                         'hierarchical removal iterations]')
# Selects the 'ifar_exc' dataset instead of 'ifar' when reading the foreground.
parser.add_argument('--use-exclusive-ifar', action='store_true',
                    help='Whether to use exclusive background for '
                         'IFAR calculations. Default=False.')
opts = parser.parse_args()

# Optional LaTeX text rendering with a Computer Modern serif font
if opts.use_tex:
    pylab.rc('text', usetex=True)
    pylab.rc('font', family='serif', serif=['Computer Modern'])

# Open every requested trigger file read-only
trigf = [h5py.File(path, 'r') for path in opts.trigger_files]

# Parse which inclusive background to use for the plotting
h_inc_back_num = opts.use_hierarchical_level

# Name of the IFAR dataset to read from each foreground group
ifar_str = 'ifar_exc' if opts.use_exclusive_ifar else 'ifar'

# Largest number of hierarchical removal iterations recorded in any file;
# if a file does not carry the attribute, treat it as no removals at all.
try:
    h_iterations = max(trig.attrs['hierarchical_removal_iterations']
                       for trig in trigf)
except KeyError:
    h_iterations = 0

# No explicit level requested: use every removal iteration that was done
if h_inc_back_num is None:
    print('Using {} hierarchical removal iterations'.format(h_iterations))
    h_inc_back_num = h_iterations

if h_inc_back_num > h_iterations:
    # The requested removal level was never reached: write a placeholder
    # plot that only carries an explanatory message, then stop.
    output_message = (
        "No more foreground events louder than all background\n"
        "at this removal level.\nAttempted to show {} removal(s),\n"
        "but only {} removal(s) done."
    ).format(h_inc_back_num, h_iterations)

    fig = pylab.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # message centred in the unit square
    ax.text(0.5, 0.5, output_message, ha='center', va='center')

    pycbc.results.save_fig_with_metadata(
        fig, opts.output_file,
        title='Cumulative Number vs. IFAR',
        caption=output_message,
        cmd=' '.join(sys.argv))

    # Exit the code successfully and bypass the rest of the plotting code.
    sys.exit(0)

# get foreground IFAR values and cumulative number for each IFAR value
#
# When hierarchical removal was performed (h_iterations != 0) and a
# non-negative level is requested, the plotted foreground comes from the
# 'foreground_h<level>' group, while the plain 'foreground' group supplies
# the IFARs used to mark the hierarchically removed triggers. Otherwise only
# the plain 'foreground' group is read.
if h_inc_back_num >= 0 and h_iterations != 0:
    fore_ifar = numpy.array([])
    supp_fore_ifar = numpy.array([])
    for f in trigf:
        ifar = f['foreground_h%s/%s' % (h_inc_back_num, ifar_str)][:]
        # Get ifar of hierarchically removed foreground triggers
        supp_ifar = f['foreground/' + ifar_str][:]
        # increment arrays to be plotted
        fore_ifar = numpy.append(fore_ifar, ifar)
        supp_fore_ifar = numpy.append(supp_fore_ifar, supp_ifar)

    # Use to plot hierarchically removed foreground triggers
    # on the plot and correct the cumulative number at IFAR level:
    # the removed triggers are taken to be the h_inc_back_num largest IFARs.
    hrm_sorted = sorted(supp_fore_ifar)
    # Clamp at zero: with fewer events than the removal level a negative
    # start index would wrap around and silently drop events from the slice.
    idx_start = max(0, len(supp_fore_ifar) - h_inc_back_num)

    h_rm_ifar = hrm_sorted[idx_start:]
    h_rm_cumnum = numpy.arange(len(h_rm_ifar), 0, -1)
else:
    fore_ifar = numpy.array([])
    for f in trigf:
        fore_ifar = numpy.append(fore_ifar, f['foreground/' + ifar_str][:])

# Removing and truncating loud events are mutually exclusive treatments
removing = opts.remove_threshold is not None
truncating = opts.truncate_threshold is not None
if removing and truncating:
    raise RuntimeError("Can't both remove and truncate foreground events!")

# Drop every event at or above the removal threshold
if removing:
    keep = fore_ifar < opts.remove_threshold
    fore_ifar = fore_ifar[keep]

# Sort by increasing IFAR; the cumulative count then runs from N down to 1
sortifar = fore_ifar.argsort()
fore_ifar = fore_ifar[sortifar]
fore_cumnum = numpy.arange(len(fore_ifar), 0, -1)

# get expected (noise-only) foreground IFAR values and cumulative number for
# each IFAR value: a log-spaced grid from 1e-4 up to the truncation threshold
# (or 10000 when no threshold is set)
max_expected_ifar = opts.truncate_threshold or 10000
expected_ifar = numpy.logspace(-4., numpy.log10(max_expected_ifar),
                               num=1000, base=10.0)

# total foreground time summed over all trigger files
fg_time = sum(trig.attrs['foreground_time'] for trig in trigf)
expected_cumnum = conversions.sec_to_year(fg_time) / expected_ifar

# make figure
fig = pylab.figure(1)

# plot the expected background as a thin dashed black line
pylab.loglog(expected_ifar, expected_cumnum, label='Expected Background',
             color='black', linewidth=1, linestyle='--')


def _lower_band_edge(mu, eps):
    """Return the lower counting-error boundary for Poisson mean `mu`.

    Starting from a 'small' count, step upwards until the Poisson cdf
    (which includes the count itself) exceeds `eps`. The previous count is
    then the largest one still inside the eps tail, so the boundary is
    placed half a count below the stopping value.
    """
    nlo = max(0, int(mu - 5 * mu**0.5))
    while not poisson.cdf(nlo, mu) > eps:
        nlo += 1
    return nlo - 0.5


def _upper_band_edge(mu, eps):
    """Return the upper counting-error boundary for Poisson mean `mu`.

    Starting from a 'large' count, step downwards until the Poisson survival
    function (probability strictly above the count) exceeds `eps`. The
    smallest count inside the eps tail is then two above the stopping value,
    so the boundary is placed at the stopping value plus 1.5.
    """
    nup = max(10, int(mu + 5 * mu**0.5))
    while not poisson.sf(nup, mu) > eps:
        nup -= 1
    return nup + 1.5


# plot the counting error
lower = {}
upper = {}
# the expected counts are walked in reverse, matching plotifar below
reversed_cumnum = expected_cumnum[::-1]
for sigma in (1, 2, 3, 4):
    # one-tailed significance
    eps = norm.sf(sigma, 0, 1)
    lower[sigma] = [_lower_band_edge(mu, eps) for mu in reversed_cumnum]
    upper[sigma] = [_upper_band_edge(mu, eps) for mu in reversed_cumnum]

plotifar = expected_ifar[::-1]
# (first edge, second edge, opacity, extra keyword arguments); only the
# lower-side bands and the central band carry a legend label
band_specs = [
    (lower[4], lower[3], 0.15, {'label': r'$<4\sigma$'}),
    (lower[3], lower[2], 0.3, {'label': r'$<3\sigma$'}),
    (lower[2], lower[1], 0.45, {'label': r'$<2\sigma$'}),
    (lower[1], upper[1], 0.6, {'label': r'$<1\sigma$'}),
    (upper[1], upper[2], 0.45, {}),
    (upper[2], upper[3], 0.3, {}),
    (upper[3], upper[4], 0.15, {}),
]
for edge_a, edge_b, shade, extra in band_specs:
    pylab.fill_between(plotifar, edge_a, edge_b, facecolor='k', alpha=shade,
                       **extra)

# plot the foreground triggers
#
# The foreground can be empty (no events in the files, or everything cut by
# --remove-threshold). min()/max() of an empty sequence raise ValueError, so
# the foreground is only drawn, and the axes only scaled around it, when at
# least one event is present.
have_foreground = opts.open_box and len(fore_ifar) > 0
max_ifar = None
if opts.open_box:
    if have_foreground:
        if opts.truncate_threshold:
            # pin events above the threshold to the threshold and mark each
            # with a right-pointing arrow at its cumulative number
            over_trunc = fore_ifar > opts.truncate_threshold
            fore_ifar[over_trunc] = opts.truncate_threshold
            for i in fore_cumnum[over_trunc]:
                pylab.arrow(opts.truncate_threshold, i,
                            opts.truncate_threshold, 0,
                            head_width=0.1 * i,
                            head_length=0.4 * opts.truncate_threshold,
                            ec='b', fc='b')
        pylab.loglog(fore_ifar, fore_cumnum, linestyle='None', color='blue',
                     marker='^', ms=6, label='Foreground')
        max_ifar = max(fore_ifar)

    # hierarchically removed triggers, skipped when there are none to show
    if h_inc_back_num > 0 and len(h_rm_ifar) > 0:
        top_removed = max(h_rm_ifar)
        if max_ifar is None:
            max_ifar = top_removed
        else:
            max_ifar = max(max_ifar, top_removed)
        pylab.loglog(h_rm_ifar, h_rm_cumnum, linestyle='None', color='#b66dff',
                     marker='v', label='Hierarchically Removed Foreground')

# format plot
if have_foreground:
    # scale the plot around the foreground
    pylab.ylim(0.7, 0.05 * len(fore_cumnum))
    pylab.xlim(30 * min(fore_ifar), 3 * max_ifar)
else:
    # closed box, or open box without any foreground event: scale the
    # vertical axis on the expected background instead
    pylab.ylim(0.7, max(expected_cumnum))
pylab.grid()
pylab.legend(loc='upper right', fontsize=13)
pylab.ylabel('Cumulative Number', size='large')
ifar_label = 'Inverse False Alarm Rate (yr)'
if opts.use_exclusive_ifar:
    ifar_label = 'Exclusive ' + ifar_label
pylab.xlabel(ifar_label, size='large')

# save the figure together with its title, caption and the command line used
caption = (
    'This is a cumulative histogram of triggers. The blue triangles '
    'represent coincident foreground triggers. The dashed line '
    'represents the expected background given the analysis time. '
    'The shaded regions represent counting errors.'
)
pycbc.results.save_fig_with_metadata(
    fig, opts.output_file,
    title='Cumulative Number vs. IFAR',
    caption=caption,
    cmd=' '.join(sys.argv))

