#!/usr/bin/env python

import sys
import numpy, h5py, argparse, matplotlib
from matplotlib import colors
matplotlib.use('Agg')
import pylab, pycbc.results
from pycbc.io import get_chisq_from_file_choice, chisq_choices
from pycbc import conversions
from pycbc.detector import Detector
import pycbc.version

def snr_from_chisq(chisq, newsnr, q=6.):
    """Return the SNR giving a fixed NewSNR at each reduced chisq value.

    Solves ``newsnr = snr * (0.5 * (1 + chisq ** (q / 2))) ** (-1 / q)`` for
    ``snr`` wherever ``chisq > 1``. Where ``chisq <= 1`` (or is NaN) the SNR
    is simply equal to ``newsnr``.

    Parameters
    ----------
    chisq : array-like
        Reduced chisq values. Lists and tuples are accepted as well as numpy
        arrays.
    newsnr : float or str
        NewSNR value of the contour; anything accepted by ``float()``.
    q : float, optional
        Index used in the reweighting. Default 6.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``chisq`` holding the SNR values.
    """
    # Coerce so that plain Python sequences support the comparison below
    chisq = numpy.asarray(chisq)
    target = float(newsnr)

    # Start from snr == newsnr everywhere, then correct the chisq > 1 entries
    snr = numpy.full(chisq.shape, target)
    above = chisq > 1.
    snr[above] = target / (0.5 * (1. + chisq[above] ** (q/2.))) ** (-1./q)
    return snr

# Command line interface. The four input files and the output file are
# required; everything else is optional.
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--found-injection-file', required=True,
                    help='HDF format found injection file. Required')
parser.add_argument('--single-injection-file', required=True,
                    help='Single detector trigger files from the injection set'
                         ': one file per ifo')
parser.add_argument('--coinc-statistic-file', required=True,
                    help='HDF format statistic file. Required')
parser.add_argument('--single-trigger-file', required=True,
                    help='Single detector trigger files from the zero lag run.'
                         ' Required')
# Converting to float here makes argparse reject a malformed contour value
# immediately instead of failing later while the contours are being drawn
parser.add_argument('--newsnr-contours', nargs='*', default=[], type=float,
                    help="List of newsnr values to draw contours. Optional")
parser.add_argument('--background-front', action='store_true', default=False,
                    help='If set, plot background on top of injections rather '
                         'than vice versa')
parser.add_argument('--colorbar-choice', choices=('effective_spin', 'mchirp',
                    'eta', 'effective_distance', 'mtotal', 'optimal_snr',
                    'redshift'), default='effective_distance',
                    help='Parameter to use for the colorbar. '
                         'Default=effective_distance')
parser.add_argument('--chisq-choice', choices=chisq_choices,
                    default='traditional',
                    help='Which chisquared to plot. Default=traditional')
parser.add_argument('--output-file', required=True)
args = parser.parse_args()

# Add the background triggers
# The statistic file lists the detectors and, for each of them, the ids of
# the single detector triggers that make up the background
with h5py.File(args.coinc_statistic_file, 'r') as f:
    ifos = f.attrs['ifos'].split(' ')
    b_tids = {}
    for tmpifo in ifos:
        # Copy the ids into memory so they stay usable after the file closes
        tid_path = 'background_exc/{}/trigger_id'.format(tmpifo)
        b_tids[tmpifo] = f[tid_path][:]

# The first top-level group of the single trigger file is taken to be the
# detector whose triggers are plotted
with h5py.File(args.single_trigger_file, 'r') as f:
    ifo = tuple(f.keys())[0]
    tid = b_tids[ifo]
    # Read each full column, then pick out the background triggers by id
    bkg_snr = f[ifo]['snr'][:][tid]
    bkg_chisq = get_chisq_from_file_choice(f[ifo], args.chisq_choice)[tid]

# don't plot if chisq is not calculated
bkg_pos = bkg_chisq > 0
bkg_snr = bkg_snr[bkg_pos]
bkg_chisq = bkg_chisq[bkg_pos]
fig = pylab.figure()
# zorder is 1 when the background should be drawn in front, 0 otherwise
pylab.scatter(bkg_snr, bkg_chisq, marker='o', color='black',
              linewidth=0, s=4, label='Background', alpha=0.6,
              zorder=args.background_front)

# Add the found injection points
with h5py.File(args.found_injection_file, 'r') as f:
    inj_tids = {}
    for tmpifo in ifos:
        # Copy the ids into memory so they stay usable after the file closes
        tid_path = 'found_after_vetoes/{}/trigger_id'.format(tmpifo)
        inj_tids[tmpifo] = f[tid_path][:]

    # Effective distance of every injection in the plotted detector
    eff_dists = Detector(ifo).effective_distance(
        f['injections/distance'][:],
        f['injections/ra'][:],
        f['injections/dec'][:],
        f['injections/polarization'][:],
        f['injections/tc'][:],
        f['injections/inclination'][:])

    # Restrict every injection parameter to the found injections
    inj_idx = f['found_after_vetoes/injection_index'][:]
    eff_dist = eff_dists[inj_idx]
    m1 = f['injections/mass1'][:][inj_idx]
    m2 = f['injections/mass2'][:][inj_idx]
    s1 = f['injections/spin1z'][:][inj_idx]
    s2 = f['injections/spin2z'][:][inj_idx]
    # The redshift dataset is only read when it is actually requested
    redshift = f['injections/redshift'][:][inj_idx] if \
        args.colorbar_choice == 'redshift' else None

    # The optimal SNR is only available if the file stores it for this ifo
    opt_snr = None
    if 'optimal_snr_{}'.format(ifo) in f['injections']:
        opt_snr_str = 'injections/optimal_snr_{}'.format(ifo)
        opt_snr = f[opt_snr_str][:][inj_idx]

mchirp = conversions.mchirp_from_mass1_mass2(m1, m2)
eta = conversions.eta_from_mass1_mass2(m1, m2)
weighted_spin = conversions.chi_eff(m1, m2, s1, s2)

# choices to color the found injections, each given as
# (values, colorbar label, color normalisation)
coloring = {'effective_distance': (eff_dist, "Effective Distance (Mpc)",
                                                             colors.LogNorm()),
            'mchirp': (mchirp, "Chirp Mass", colors.LogNorm()),
            'eta': (eta, "Symmetric Mass Ratio", colors.LogNorm()),
            'effective_spin': (weighted_spin, "Weighted Aligned Spin", None),
            'mtotal': (m1 + m2, "Total Mass", colors.LogNorm()),
            'redshift': (redshift, "Redshift", None)
            }

if opt_snr is not None:
    coloring['optimal_snr'] = (opt_snr, 'Optimal SNR', colors.LogNorm())

# Fail with an explanation rather than a bare KeyError further down
if args.colorbar_choice not in coloring:
    raise KeyError("--colorbar-choice {} was requested but {} does not "
                   "contain the dataset it needs for {}"
                   .format(args.colorbar_choice, args.found_injection_file,
                           ifo))

with h5py.File(args.single_injection_file, 'r') as inj_file:
    f = inj_file[ifo]
    tid = inj_tids[ifo]
    if len(tid):
        # Read each full column, then pick out the found injections by id
        inj_snr = f['snr'][:][tid]
        inj_chisq = get_chisq_from_file_choice(f, args.chisq_choice)[tid]
    else:
        inj_snr = numpy.array([])
        inj_chisq = numpy.array([])

inj_pos = inj_chisq > 0
# Catch not enough found injections case
if len(coloring[args.colorbar_choice][0]) == 0:
    coloring[args.colorbar_choice] = (None, None, None)
else:  # Only plot positive chisq
    # zorder is the opposite of the one used for the background points
    pylab.scatter(inj_snr[inj_pos], inj_chisq[inj_pos],
                  c=coloring[args.colorbar_choice][0][inj_pos],
                  norm=coloring[args.colorbar_choice][2], s=20,
                  marker='^', linewidth=0, label="Injections",
                  zorder=(not args.background_front))

# Reduced chisq range covered by the plotted points. Injections only
# contribute when at least one of them has a positive chisq
inj_chisq_pos = inj_chisq[inj_pos]
if len(inj_chisq_pos):
    chisq_min = min(bkg_chisq.min(), inj_chisq_pos.min())
    chisq_max = max(bkg_chisq.max(), inj_chisq.max())
else:
    # Allow code to continue in the absence of injection triggers
    chisq_min, chisq_max = bkg_chisq.min(), bkg_chisq.max()

# SNR range, likewise falling back to the background alone
if len(inj_snr):
    snr_min = min(inj_snr.min(), bkg_snr.min())
    snr_max = max(inj_snr.max(), bkg_snr.max())
else:
    snr_min, snr_max = bkg_snr.min(), bkg_snr.max()

# numpy.logspace takes base-10 exponents, so the end points must be passed
# through log10. The grid spans the same range as the y-axis limits set
# below so that the contours run across the full height of the plot
r = numpy.logspace(numpy.log10(chisq_min * 0.7),
                   numpy.log10(chisq_max * 1.4), 200)

for cval in args.newsnr_contours:
    snrv = snr_from_chisq(r, cval)
    pylab.plot(snrv, r, '--', color='grey', linewidth=1)

ax = pylab.gca()
ax.set_xscale('log')
ax.set_yscale('log')

# The injection scatter is the only colour-mapped artist. Its coloring entry
# is reset to (None, None, None) when no found injections were plotted, in
# which case there is nothing for a colorbar to describe
color_label = coloring[args.colorbar_choice][1]
if coloring[args.colorbar_choice][0] is not None:
    cb = pylab.colorbar()
    cb.set_label(color_label, size='large')
if color_label is None:
    # Keep the caption readable when no injections were plotted
    color_label = args.colorbar_choice

pylab.title('%s Coincident Triggers' % ifo, size='large')
pylab.xlabel('SNR', size='large')
# Raw string: '\c' is not a valid escape sequence in a normal string literal
pylab.ylabel(r'Reduced $\chi^2$', size='large')
pylab.xlim(snr_min * 0.99, snr_max * 1.4)
pylab.ylim(chisq_min * 0.7, chisq_max * 1.4)
pylab.legend(loc='lower right', prop={'size': 10})
pylab.grid(which='major', ls='solid', alpha=0.7, linewidth=.5)
pylab.grid(which='minor', ls='solid', alpha=0.7, linewidth=.1)

title = '%s %s chisq vs SNR. %s background with injections %s' \
        % (ifo.upper(), args.chisq_choice, ''.join(ifos).upper(),
           'behind' if args.background_front else 'ontop')
caption = """Distribution of SNR and %s chi-squared veto for single detector
triggers. Black points are %s background triggers. Triangles are injection
triggers colored by %s of the injection. Dashed lines show contours of
constant NewSNR.""" % (args.chisq_choice, ''.join(ifos).upper(),
                       color_label)
# Write the figure together with its title, caption and the command line
pycbc.results.save_fig_with_metadata(fig,
                                     args.output_file,
                                     title=title,
                                     caption=caption,
                                     cmd=' '.join(sys.argv),
                                     fig_kwds={'dpi':200})
