#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division

import collections
import copy
import glob
import logging
import operator
import os
import sys

from glue.ligolw import lsctables
import pycbc.version
import scipy
import scipy.stats
from pycbc.results.pygrb_plotting_utils import *
plt.switch_backend('Agg')

__author__  = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_coh_ifosnr"


# =============================================================================
# Main script starts here
# =============================================================================

# Parse the command line with the option parser shared by the PyGRB
# plotting scripts
usage = __program__ + ' [--options]'
description = 'Produces coherent SNR versus single IFO SNR plots.'
opts = pygrb_plot_opts_parser(usage=usage,
                              description=description,
                              version=__version__)

# Log INFO messages only when verbose output is requested; otherwise restrict
# the output to warnings and errors
level = logging.INFO if opts.verbose else logging.WARNING
logging.basicConfig(level=level,
                    format="%(asctime)s:%(levelname)s : %(message)s")

# Resolve the input files to absolute paths; the injection file is optional
trig_file = os.path.abspath(opts.trig_file)
inj_file = os.path.abspath(opts.inj_file) if opts.inj_file else None

# Output files: this script writes both a full-range and a zoomed plot, so
# the zoomed output path is mandatory
outfile = opts.output_file
zoomedoutfile = opts.zoomed_output_file
if zoomedoutfile is None:
    logging.error("Please specify a pathname with the zoomed-output-file "
                  "option")
    # Exit with a non-zero status so that the caller can detect the failure
    sys.exit(1)

# Collect the veto files for categories 2 up to the requested veto category.
# Note that the bracket expression in the pattern is a glob character class.
veto_files = []
if opts.veto_directory:
    veto_string = ','.join([str(i) for i in range(2, opts.veto_category+1)])
    veto_files = glob.glob(opts.veto_directory + '/*CAT[%s]*.xml' % veto_string)

# The interferometer whose single-detector SNR is plotted is mandatory
ifo = opts.ifo
if ifo is None:
    logging.error("Please specify an interferometer")
    # Exit with a non-zero status so that the caller can detect the failure
    sys.exit(1)

# Use a default title naming the chosen interferometer when none is provided
if opts.plot_title is None:
    opts.plot_title = '%s SNR vs Coherent SNR' % ifo

# Use a default caption explaining every element of the plot when none is
# provided
if opts.plot_caption is None:
    caption_pieces = (
        "Blue crosses: Background triggers\n",
        "Red crosses: Injections triggers\n",
        "Black line: Veto line\n",
        "Gray shaded region: vetoed area - The cut is ",
        "applied only to the two most sensitive detectors, ",
        "which can vary with mass and sky location.\n",
        "Green lines: the expected SNR for optimally ",
        "oriented injections (mean, min, and max)\n",
        "Magenta lines: 2 sigma errors\n",
        "Cyan lines: 3 sigma errors",
    )
    opts.plot_caption = "".join(caption_pieces)

logging.info("Imported and ready to go.")

# Make sure the directory hosting each of the two output files exists
for fig_file in (outfile, zoomedoutfile):
    outdir = os.path.dirname(os.path.abspath(fig_file))
    if not os.path.isdir(outdir):
        os.makedirs(outdir)

# Extract the IFOs from the trigger file
ifos = extract_ifos(trig_file)

# Extract the vetoes for those IFOs from the trigger file
vetoes = extract_vetoes(trig_file, ifos)

# Load the triggers, handing over the vetoes and the IFOs
trigs = load_triggers(trig_file, vetoes, ifos)

# Extract the trigger data needed for plotting
trig_data = PygrbFilterOutput(
    trigs, ifos, lsctables.MultiInspiralTable.loadcolumns, "triggers", opts)

# Load the injections only if an injection file was provided
injs = load_injections(inj_file, vetoes) if inj_file else None

# Extract the injection data (injs is None when there is no injection file)
inj_data = PygrbFilterOutput(
    injs, ifos, lsctables.MultiInspiralTable.loadcolumns, "injections", opts)

# Generate plots
logging.info("Plotting...")

# Rank the IFOs from the most to the least sensitive, using the product of
# the detector response and the mean sigma as the figure of merit
ifo_senstvty = {i_ifo: trig_data.f_resp[i_ifo]*trig_data.sigma_mean[i_ifo]
                for i_ifo in ifos}
ifo_senstvty = collections.OrderedDict(
    sorted(ifo_senstvty.items(), key=operator.itemgetter(1), reverse=True))
# Labels indexed by the position of an IFO in the ranking above
loudness_labels = ['Loudest', 'Second loudest', 'Third loudest']

# Maximum coherent SNR value among the data to be plotted
x_max = axis_max_value(trig_data.snr, inj_data.snr, inj_file)
# The reference curves always extend at least up to a coherent SNR of 50
max_snr = max(x_max, 50.)

# Maximum single-IFO SNR value among the data to be plotted
y_max = axis_max_value(trig_data.ifo_snr[ifo], inj_data.ifo_snr[ifo], inj_file)

# Coherent SNR values at which the reference curves are evaluated
zoom_snr = numpy.arange(0.01, max_snr, 0.01)

# Apply the plot settings shared by the PyGRB plotting scripts
pygrb_shared_plot_setups()

# Axis labels, common to the full-range and the zoomed plot
y_label = "%s sngl SNR" % ifo
x_label = "Coherent SNR"

# The full-range plot is produced first, the zoomed one second
fig_path_list = [outfile, zoomedoutfile]
fig_path = fig_path_list[0]
fig_name = os.path.basename(os.path.abspath(fig_path))
logging.info(" * %s (%s vs %s)...", fig_name, x_label, y_label)

fig = plt.figure()
ax = fig.gca()

# Background triggers are drawn as blue 'x' markers
ax.plot(trig_data.snr, trig_data.ifo_snr[ifo], 'bx')
ax.grid()

# Injection triggers, if any, are drawn as red '+' markers
if inj_file:
    ax.plot(inj_data.snr, inj_data.ifo_snr[ifo], 'r+')
# Expected single-IFO SNR as a function of coherent SNR, computed as
# zoom_snr * sqrt(response * sigma) for sigma-mean, sigma-min and sigma-max
# (in this order). A list is built explicitly, because the individual curves
# are indexed below and map() returns a non-indexable iterator on Python 3.
sigmas = [trig_data.sigma_mean[ifo], trig_data.sigma_min[ifo],
          trig_data.sigma_max[ifo]]
y_data = [zoom_snr*(trig_data.f_resp[ifo]*sigma)**0.5 for sigma in sigmas]
for el in y_data:
    ax.plot(zoom_snr, el, 'g-')

# ncx2: non-central chi-squared; ppf: percent point function.
# The confidence bands are drawn with 2 degrees of freedom and the squared
# expected SNR as non-centrality parameter; the square root converts the
# result back to an SNR. The lower edge uses the sigma_min curve and the
# upper edge the sigma_max curve.
# A two-sided n-sigma interval covers 0.9545 (2 sigma) or 0.9973 (3 sigma) of
# the probability, so each tail holds half of the complement, that is,
# 0.0455*0.5 and 0.0027*0.5, respectively.
two_sigma_tail = 0.0455*0.5
three_sigma_tail = 0.0027*0.5
for tail, line_style in ((two_sigma_tail, 'm-'), (three_sigma_tail, 'c-')):
    ax.plot(zoom_snr,
            scipy.stats.ncx2.ppf(tail, 2, y_data[1]*y_data[1])**0.5,
            line_style)
    ax.plot(zoom_snr,
            scipy.stats.ncx2.ppf(1.-tail, 2, y_data[2]*y_data[2])**0.5,
            line_style)
# Non-zoomed plot: draw the veto line at a single-IFO SNR of 4
ax.plot([0,max_snr], [4,4], 'k-')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_xlim([0,1.1*x_max])
ax.set_ylim([0,1.1*y_max])
# Position of the chosen IFO in the sensitivity ranking (0 = most sensitive).
# The keys are turned into a list first, because dictionary key views have no
# index() method on Python 3.
loudness_index = list(ifo_senstvty).index(ifo)
# Veto applies to the two most sensitive IFOs, so shade the region between
# the bottom of the plot and the veto line for them
if loudness_index < 2:
    limy = ax.get_ylim()[0]
    ax.fill([0, max_snr, max_snr, 0], [4, 4, limy, limy], color='#dddddd')
# Record the ranking of the chosen IFO in the plot title
opts.plot_title += " (%s SNR)" % loudness_labels[loudness_index]
# Save non-zoomed plot
save_fig_with_metadata(fig, fig_path, cmd=' '.join(sys.argv),
                       title=opts.plot_title, caption=opts.plot_caption)
# Zoomed plot: reuse the same figure, restricting the axes to a fixed window
fig_path = fig_path_list[1]
fig_name = os.path.basename(os.path.abspath(fig_path))
logging.info(" * %s (%s vs %s)...", fig_name, x_label, y_label)
ax.set_xlim([6,50])
ax.set_ylim([0,20])
# Save zoomed plot, storing the command line used as metadata
cmd_line = ' '.join(sys.argv)
save_fig_with_metadata(fig, fig_path, cmd=cmd_line,
                       title=opts.plot_title, caption=opts.plot_caption)
plt.close()
