#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division

import sys
import glob
import os
import copy
import logging
from glue.ligolw import lsctables
import pycbc.version
from pycbc.results.pygrb_plotting_utils import *
# `plt` is not imported explicitly in this file: it is picked up through the
# wildcard import above (presumably matplotlib.pyplot -- TODO confirm).
# The 'Agg' backend is selected before any figure is produced.
plt.switch_backend('Agg')

# Script metadata; the version string is also handed to the option parser
__author__  = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_signal_consistency"


# =============================================================================
# Main script starts here
# =============================================================================

# Parse the command line
description = 'Produces signal consistency plots of the form SNR versus veto.'
usage = __program__ + ' [--options]'
opts = pygrb_plot_opts_parser(usage=usage, description=description,
                              version=__version__)

# Informational messages are shown only when running verbosely
level = logging.INFO if opts.verbose else logging.WARNING
logging.basicConfig(format="%(asctime)s:%(levelname)s : %(message)s",
                    level=level)

# Input files: the injection file is optional
trig_file = os.path.abspath(opts.trig_file)
inj_file = None
if opts.inj_file:
    inj_file = os.path.abspath(opts.inj_file)

# Output files: both the non-zoomed and the zoomed plot are always produced,
# so both pathnames are needed.  Exit with a non-zero status on bad input so
# that callers can detect the failure.
outfile = opts.output_file
if outfile is None:
    logging.error("Please specify a pathname with the output-file option")
    sys.exit(1)
zoomedoutfile = opts.zoomed_output_file
if zoomedoutfile is None:
    logging.error("Please specify a pathname with the zoomed-output-file option")
    sys.exit(1)

# Collect the veto files for categories 2 up to opts.veto_category.
# NOTE(review): the bracket expression is a glob character class, so only
# single-digit categories are matched and the ',' separator is itself a
# matchable character.  veto_files is not used further in this script.
veto_files = []
if opts.veto_directory:
    veto_string = ','.join([str(i) for i in range(2, opts.veto_category + 1)])
    veto_files = glob.glob(opts.veto_directory + '/*CAT[%s]*.xml' % veto_string)

# If an ifo is given, the veto is intended as a single IFO quantity
ifo = opts.ifo
veto_type = opts.variable
if veto_type is None:
    logging.error("Please specify the chi-square veto to be plotted")
    sys.exit(1)

# If this is false, coherent SNR is used
use_sngl_ifo_snr = opts.use_sngl_ifo_snr
# The single IFO SNR is looked up by ifo name further down, so an ifo is
# mandatory in this case
if use_sngl_ifo_snr and not ifo:
    logging.error("Please specify the ifo whose single detector SNR "
                  "is to be plotted")
    sys.exit(1)

# Prepare plot title and caption, unless the user provided them
veto_labels = {'standard': "Chi Square Veto",
               'bank': "Bank Veto",
               'auto': "Auto Veto"}
if opts.plot_title is None:
    opts.plot_title = veto_labels[veto_type]
    if ifo:
        opts.plot_title = "%s %s" % (ifo, opts.plot_title)
    if use_sngl_ifo_snr:
        opts.plot_title = "%s vs %s SNR" % (opts.plot_title, ifo)
    else:
        opts.plot_title = "%s vs Coherent SNR" % opts.plot_title
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: Background triggers\n" +
                         "Red crosses: Injections triggers\n" +
                         "Black line: Veto line\n" +
                         "Gray shaded region: Vetoed area\n" +
                         "Yellow lines: Contours of new SNR")

logging.info("Imported and ready to go.")

# Make sure the directories hosting the two output plots exist
outdirs = [os.path.dirname(os.path.abspath(plot_path))
           for plot_path in (outfile, zoomedoutfile)]
for outdir in outdirs:
    if os.path.isdir(outdir):
        continue
    os.makedirs(outdir)

# Extract the IFOs from the trigger file
ifos = extract_ifos(trig_file)

# Extract the vetoes for those IFOs
vetoes = extract_vetoes(trig_file, ifos)

# Load the triggers, passing in the vetoes
trigs = load_triggers(trig_file, vetoes, ifos)

# Extract the trigger data
trig_data = PygrbFilterOutput(
    trigs, ifos, lsctables.MultiInspiralTable.loadcolumns, "triggers", opts)

# Load the injections only if an injection file was provided
injs = load_injections(inj_file, vetoes) if inj_file else None

# Extract the injection data: this is built in all cases, i.e. also with
# injs set to None when there is no injection file
inj_data = PygrbFilterOutput(
    injs, ifos, lsctables.MultiInspiralTable.loadcolumns, "injections", opts)

# Generate plots
logging.info("Plotting...")

# Determine x-axis values of triggers and injections
# Default is coherent SNR
x_label = "Coherent SNR"
trig_snr = trig_data.snr
inj_snr = None
if inj_file:
    inj_snr = inj_data.snr
# Case with single SNR
if use_sngl_ifo_snr:
    x_label = "%s SNR" % ifo
    trig_snr = trig_data.ifo_snr[ifo]
    if inj_file:
        inj_snr = inj_data.ifo_snr[ifo]
# Coherent SNR requested with the coherent standard chi-square.
# The y-axis values are restricted to reweighted_snr != 0 only for the
# coherent standard chi-square (no ifo given): apply the mask to the SNR
# under the very same condition, so that the x and y arrays always have
# matching lengths.
elif veto_type == 'standard' and not ifo:
    trig_snr = trig_data.snr[trig_data.reweighted_snr != 0]
    if inj_file:
        inj_snr = inj_data.snr[inj_data.reweighted_snr != 0]
# Sanity check
if trig_snr is None and inj_snr is None:
    # logging.warn is a deprecated alias: use logging.warning
    logging.warning("No data to be plotted on the x-axis was found")
    sys.exit()

# Determine the minimum and maximum SNR value we are dealing with:
# the lower end of the x-axis is the single detector SNR threshold option
x_min = opts.sngl_snr_threshold
x_max = axis_max_value(trig_snr, inj_snr, inj_file)


def _select_veto_data(data, veto_type, ifo=None):
    """Return the veto values of `data` to be plotted on the y-axis.

    Only the attributes needed for the requested veto are accessed.

    Parameters
    ----------
    data : object
        Trigger or injection data exposing the attributes `chi_square`,
        `reweighted_snr`, `bank_veto`, `auto_veto`, `ifo_stan_cs`,
        `ifo_bank_cs` and `ifo_auto_cs`.
    veto_type : str
        One of 'standard', 'bank' or 'auto'.
    ifo : str, optional
        If given, the single IFO veto stored under this key is returned
        instead of the coherent one.

    Returns
    -------
    The selected veto values.  The coherent standard chi-square is
    restricted to the entries with non-zero reweighted SNR.

    Raises
    ------
    KeyError
        If `veto_type` is not one of the supported values.
    """
    if ifo:
        sngl_ifo_vetoes = {'standard': data.ifo_stan_cs,
                           'bank': data.ifo_bank_cs,
                           'auto': data.ifo_auto_cs}
        return sngl_ifo_vetoes[veto_type][ifo]
    if veto_type == 'standard':
        return data.chi_square[data.reweighted_snr != 0]
    coherent_vetoes = {'bank': data.bank_veto,
                       'auto': data.auto_veto}
    return coherent_vetoes[veto_type]


# Determine y-axis values of triggers and injections
y_label = veto_labels[veto_type]
if ifo:
    y_label = "%s Single %s" % (ifo, y_label)
trig_veto_data = _select_veto_data(trig_data, veto_type, ifo)
# Injection values are picked only if an injection file was provided
inj_veto_data = None
if inj_file:
    inj_veto_data = _select_veto_data(inj_data, veto_type, ifo)
# Sanity check
if trig_veto_data is None and inj_veto_data is None:
    # logging.warn is a deprecated alias: use logging.warning
    logging.warning("No data to be plotted on the y-axis was found")
    sys.exit()

# Determine the maximum veto value we are dealing with
y_max = axis_max_value(trig_veto_data, inj_veto_data, inj_file)

# Determine contours for plots
(bank_conts, auto_conts, chi_conts, null_cont,
 snr_vals, cont_value, colors) = calculate_contours(trigs, opts)
# Contours are passed to the plotter only when no ifo was given;
# otherwise conts stays None
cont_dict = {'standard': chi_conts,
             'bank': bank_conts,
             'auto': auto_conts}
conts = cont_dict[veto_type] if ifo is None else None

# Produce a non-zoomed and a zoomed veto vs. SNR plot: each entry holds
# the output path together with the x-axis and y-axis limits of that plot
plot_setups = [
    (outfile, [x_min, 50], [1, 20000]),
    (zoomedoutfile, [x_min, 1.1*x_max], [1, 10*y_max]),
]
pygrb_shared_plot_setups()
# The command line is handed to the plotter for both figures
cmd = ' '.join(sys.argv)
for fig_path, xlims, ylims in plot_setups:
    pygrb_plotter(trig_snr, trig_veto_data, inj_snr, inj_veto_data,
                  inj_file, x_label, y_label, fig_path,
                  snr_vals=snr_vals, conts=conts, colors=colors,
                  shade_cont_value=cont_value, vert_spike=True,
                  xlims=xlims, ylims=ylims, cmd=cmd,
                  plot_title=opts.plot_title,
                  plot_caption=opts.plot_caption)
