#!/usr/bin/python

"""
This script compares the sensitivities (VTs) of two searches having consistent
sets of injections. It reads two HDF files produced by pycbc_page_sensitivity's
--hdf-out option, and plots the ratios of their VTs at various IFARs.
"""

import sys
import argparse
import h5py
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from pycbc.results import save_fig_with_metadata


# Command-line interface. The module docstring doubles as the description
# printed by --help.
parser = argparse.ArgumentParser(description=__doc__)

# The two data sets take the same pair of options, so register them from a
# small table. Two passes keep the option order: both files first, then both
# descriptors.
_DATA_SETS = (('one', 'first'), ('two', 'second'))
for _suffix, _ordinal in _DATA_SETS:
    parser.add_argument(
        '--vt-file-{}'.format(_suffix), required=True,
        help='HDF file containing VT curves, '
             '{} set of data for comparison'.format(_ordinal))
for _suffix, _ordinal in _DATA_SETS:
    parser.add_argument(
        '--desc-{}'.format(_suffix), type=str, required=True,
        help='Descriptor tag for {} set of data '
             '(short, for use in subscript)'.format(_ordinal))

parser.add_argument(
    '--outfile', type=str, required=True,
    help='Output file to save to')
parser.add_argument(
    '--ifars', type=float, required=True, nargs='+',
    help='IFAR values to plot VT ratio for. Note that the '
         'plotted values will be the closest values available '
         'from the VT files')

# Optional switches for logarithmic axes
parser.add_argument(
    '--log-x', action='store_true',
    help='Use logarithmic x-axis')
parser.add_argument(
    '--log-y', action='store_true',
    help='Use logarithmic y-axis')

args = parser.parse_args()

# Load in the two datasets
ffirst_set = h5py.File(args.vt_file_one, 'r')
fsecond_set = h5py.File(args.vt_file_two, 'r')

# For every requested IFAR, find the index of the closest value on the IFAR
# grid stored in the first file. The grid is read from disk once here rather
# than once per requested value.
# NOTE: the same indices are later used to index the second file as well;
# nothing checks that its 'xvals' dataset matches the first file's.
ifar_grid = ffirst_set['xvals'][:]
idxs = [np.argmin(np.abs(ifar_grid - ifv)) for ifv in args.ifars]

keys = ffirst_set['data'].keys()
# sanitise the input so that the files have the same binning parameter and
# bins; the keys of the first file are used to look up entries in both files
if keys != fsecond_set['data'].keys():
    parser.error('keys do not match for the given input files - '
                 '{} vs {}'.format(keys, fsecond_set['data'].keys()))

# Global matplotlib styling (tick-label notation thresholds and figure
# resolution), then a single wide figure with a background grid
plt.rc('axes.formatter', limits=[-3, 4])
plt.rc('figure', dpi=300)
fig_mi = plt.figure(figsize=(10, 4))
ax_mi = fig_mi.gca()
ax_mi.grid(True, zorder=1)

# The names of the bins are the keys of the 'data' group of the first file.
# The parsing below presumes each key has the form
# '$<parameter> \in [<low>, <high>...$' - TODO confirm against the writer of
# these files.
bin_names = list(ffirst_set['data'].keys())

# tick labels: whatever follows '\in' in each key, re-opened as math text
labels = ['$ ' + bin_name.split('\\in')[-1] for bin_name in bin_names]

# axis label: whatever precedes '\in' in the first key, with the surrounding
# '$' and whitespace removed and then wrapped in '$...$' again
param_text = bin_names[0].split('\\in')[0].strip('$').strip()
x_param = r'$' + param_text + r'$'

# x positions: the number between the first '[' and the following ',' of
# each tick label
bin_starts = []
for tick_label in labels:
    after_bracket = tick_label.split('[')[1]
    bin_starts.append(float(after_bracket.split(',')[0]))
xpos = np.array(bin_starts)

# Offset the points of different IFARs by 1/20th of the mean distance between
# neighbouring bin positions (measured in log space when the x-axis is
# logarithmic) so that their error bars do not sit on top of each other.
if xpos.size > 1:
    xpos_spacing = np.diff(np.log(xpos) if args.log_x else xpos).mean()
    xpos_add_dx = 0.05 * np.ones_like(xpos) * xpos_spacing
else:
    # With fewer than two positions np.diff returns an empty array and its
    # mean is nan (numpy only warns, it does not raise), which would turn
    # every plotted x value into nan. Fall back to a fixed step instead.
    xpos_add_dx = 0.05

# set the x ticks to be the positions given in the labels
plt.xticks(xpos, labels, rotation='horizontal')

# colour cycle for the IFAR series; reused from the start if there are more
# series than colours
colors = ['#7b85d4', '#f37738', '#83c995', '#d7369e', '#c4c9d8', '#859795']


def _per_bin(vt_file, group, ifar_index):
    """Return the `group` entry at `ifar_index` of every bin as an array."""
    return np.array([vt_file[group][key][ifar_index] for key in keys])


# the series are spread symmetrically around each bin position, so the
# middle series (or the gap between the two middle ones) gets zero offset
centre = float(len(args.ifars) - 1) / 2.

# loop through each IFAR and plot the VT ratio with error bars
for series, ifar_idx in enumerate(idxs):
    vt_one = _per_bin(ffirst_set, 'data', ifar_idx)
    vt_two = _per_bin(fsecond_set, 'data', ifar_idx)

    high_one = _per_bin(ffirst_set, 'errorhigh', ifar_idx)
    low_one = _per_bin(ffirst_set, 'errorlow', ifar_idx)

    high_two = _per_bin(fsecond_set, 'errorhigh', ifar_idx)
    low_two = _per_bin(fsecond_set, 'errorlow', ifar_idx)

    ratio = vt_one / vt_two

    # relative errors of numerator and denominator added in quadrature,
    # then scaled by the ratio to get absolute error bars
    rel_low = np.sqrt((low_one / vt_one)**2 + (low_two / vt_two)**2)
    rel_high = np.sqrt((high_one / vt_one)**2 + (high_two / vt_two)**2)
    bar_low = rel_low * ratio
    bar_high = rel_high * ratio

    # the offset is additive in log space on a logarithmic axis
    shift = xpos_add_dx * (series - centre)
    if args.log_x:
        x_plot = np.exp(np.log(xpos) + shift)
    else:
        x_plot = xpos + shift

    ax_mi.errorbar(
        x_plot, ratio, yerr=[bar_low, bar_high],
        fmt='o', markersize=7, linewidth=5, capsize=5, capthick=2, mec='k',
        color=colors[series % len(colors)],
        label='IFAR = %d yr' % ffirst_set['xvals'][ifar_idx])

# Apply the requested axis scales, then set the tick positions and labels
# again now that the scales are final
if args.log_x:
    plt.xscale('log')
if args.log_y:
    plt.yscale('log')
plt.xticks(xpos, labels, rotation='horizontal')

# get the limit of the x axes, and draw a black line in order to highlight
# equal comparison
xlimits = plt.xlim()
plt.plot(xlimits, [1, 1], 'k', lw=3, zorder=0)
plt.xlim(xlimits)  # reassert the x limits so that the plot doesn't expand

# legend above the axes, one column per requested IFAR
ax_mi.legend(bbox_to_anchor=(0.5, 1.01), ncol=len(args.ifars),
             loc='lower center')
legend = ax_mi.get_legend()
legend.get_title().set_fontsize('14')
legend.get_frame().set_alpha(0.7)

ax_mi.set_xlabel(x_param)
# y label is the fraction VT(desc_one) / VT(desc_two). Every piece is a raw
# string so that the backslashes reach mathtext unchanged and no
# invalid-escape warning is raised.
ax_mi.set_ylabel(r'$\frac{VT(\mathrm{' + args.desc_one +
                 r'})}{VT(\mathrm{' + args.desc_two + r'})}$')
plt.tight_layout()

title = 'VT sensitivity comparison between {} and {}'.format(
        args.vt_file_one, args.vt_file_two)

# write out to file
save_fig_with_metadata(fig_mi, args.outfile, cmd=' '.join(sys.argv),
                       title=title)

# Everything needed from the input files was copied into arrays while
# plotting, so release the HDF5 handles explicitly rather than leaving them
# open until the interpreter exits
ffirst_set.close()
fsecond_set.close()

plt.close()
