#!/usr/bin/env python
# Copyright (C) 2020 Josh Willis
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


import h5py
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import argparse

# Command-line interface.  All options except --nbins and --log-y are
# required; --found-category selects which top-level group of the combined
# comparison file is read.
parser = argparse.ArgumentParser(description="Plot histograms of ratio of IFAR"
                                 " and ranking statistic of injections between"
                                 " two comparable runs")
parser.add_argument('--combined-comparison-file', required=True,
                    help="HDF file holding output from"
                    " 'pycbc_combine_injection_comparisons'")
parser.add_argument('--outfile', type=str, required=True,
                    help='Output file to save to')
parser.add_argument('--plot-title', type=str, required=True,
                    help='(Possibly) quoted string to be title of plot')
parser.add_argument('--found-category', type=str, required=True,
                    choices=['found', 'found_after_vetoes'],
                    help='Which class of found injections to compare')
# The same bin count is used for both the IFAR-ratio and the
# ranking-statistic-ratio histograms.
parser.add_argument('--nbins', type=int, default=10,
                    help='Number of histogram bins to use for each ratio'
                    ' (x-axis)')
parser.add_argument('--log-y', action='store_true', default=False,
                    help='Use logarithmic y-axis')
args = parser.parse_args()

# Load the IFAR and ranking-statistic ratios for the requested found
# category.  The file is opened explicitly read-only and is closed as soon as
# the datasets have been copied into memory ([:] reads each dataset into a
# numpy array, so nothing below needs the open file).
with h5py.File(args.combined_comparison_file, 'r') as f:
    found_in_both = f[args.found_category]['found_in_both']
    ifar_ratio = found_in_both['ifar']['ratio'][:]
    stat_ratio = found_in_both['stat']['ratio'][:]

nbins = args.nbins

# Two side-by-side histograms sharing a common count (y) axis
fig, (ax_ifar, ax_stat) = plt.subplots(1, 2, sharey=True)
fig.suptitle(args.plot_title)

# IFAR ratio: compute bin edges that are uniform in log10 space, then convert
# them back to linear values so the bars have equal width on the log x-axis.
# NOTE(review): np.log10 assumes every IFAR ratio is strictly positive --
# confirm against the producer of the combined comparison file.
_, log_edges = np.histogram(np.log10(ifar_ratio), bins=nbins)
ax_ifar.hist(ifar_ratio, bins=10.0**log_edges)
ax_ifar.set_xscale('log')
ax_ifar.set(xlabel='IFAR Ratio', ylabel='Count')

# Ranking statistic ratio: plain linearly spaced bins
ax_stat.hist(stat_ratio, bins=nbins)
ax_stat.set(xlabel='Ranking Statistic Ratio')

if args.log_y:
    for ax in (ax_ifar, ax_stat):
        ax.set_yscale('log')

fig.savefig(args.outfile)
# Close the figure created above rather than whichever one is "current"
plt.close(fig)
