#!/usr/bin/python
''' Make plots of recovered injection parameters
'''
import h5py, numpy, logging, argparse, sys, matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plot

import pycbc.version, pycbc.detector
from pycbc import pnutils, results
from pycbc.events import triggers

# Command-line interface. Options created with required=True are enforced
# by argparse itself; the --plot-type / --x-param combination is checked
# by hand after parsing.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="count")
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("--injection-file", required=True,
                    help="hdf injection file containing found injections. "
                         "Required")
parser.add_argument("--bank-file",
                    help="hdf bank file containing template parameters. "
                         "Required for some parameters, ex. chirp mass")
parser.add_argument("--trigger-file",
                    help="hdf trigger merge file containing single triggers. "
                         "Required for some parameters, ex. end_time")
parser.add_argument("--ifo", required=True,
                    help="Detector where injections are recovered. Required")
# type=float makes argparse reject a non-numeric threshold with a usage
# message instead of letting the script fail later at the float() cast
parser.add_argument("--min-ifar", type=float, default=0.0,
                    help="Minimum IFAR for which to plot injections. Units:years")
parser.add_argument("--error-param", required=True,
                    help="Property to be compared between injected and recovered "
                         "event, ex. 'mchirp'. Required")
parser.add_argument("--x-param",
                    help="Injected parameter to plot on x-axis. Required if "
                         "--plot-type is err_v_param or fracerr_v_param")
parser.add_argument("--log-x", action="store_true",
                    help="Use a logarithmic x-axis.")
parser.add_argument("--log-y", action="store_true",
                    help="Use a logarithmic y-axis.")
parser.add_argument("--plot-type", choices=["scatter", "err", "fracerr",
                    "err_v_param", "fracerr_v_param", "errhist",
                    "fracerrhist", "ratio", "ratio_v_param"],
                    default="scatter",
                    help="Type of plot to show accuracy of recovery. "
                         "Default='scatter'")
parser.add_argument("--gradient-far", action="store_true",
                    help="Show FAR of found injections as a color gradient")
parser.add_argument("--f-lower", type=int,
                    help="lower frequency value for e.g. calculation of "
                         "template duration")
parser.add_argument("--output-file", required=True)
args = parser.parse_args()

# An x-axis parameter only makes sense for the "*_v_param" plot types, and
# those plot types cannot be drawn without one.
if args.x_param is not None and "v_param" not in args.plot_type:
    raise RuntimeError("Can't use an --x-param with this plot type!")
if "v_param" in args.plot_type and args.x_param is None:
    raise RuntimeError("Need an --x-param to plot errors against!")

# Logging output is only configured when --verbose was given at least once
if args.verbose:
    logging.basicConfig(format="%(asctime)s : %(message)s",
                        level=logging.INFO)

logging.info("Reading data...")
injs = h5py.File(args.injection_file, "r")
# The bank and trigger files are optional; leave them as None if not given
bank = None
if args.bank_file:
    bank = h5py.File(args.bank_file, "r")
trig = None
if args.trigger_file:
    trig = h5py.File(args.trigger_file, "r")

# The requested detector has to be listed in the injection file attributes
dets = injs.attrs["ifos"].split(" ")
if not (args.ifo in dets):
    raise RuntimeError("Can't find detector %s in injection file!" % args.ifo)
# lower-case first character of the detector name
site = args.ifo[0].lower()

# Read the indices and IFARs of the injections found after vetoes
found = injs["found_after_vetoes/injection_index"][:]
ifar_found = numpy.array(injs["found_after_vetoes/ifar"][:])
# above_min_ifar is a Boolean mask over the found injections; it is reused
# later on other arrays that are ordered like the found injections
above_min_ifar = ifar_found > float(args.min_ifar)
found, ifar_found = found[above_min_ifar], ifar_found[above_min_ifar]

# Get the injected param values
# need to hack end time and eff dist as different between detectors
param = args.error_param
if args.error_param == "end_time":
    param = "end_time_"+site
if args.error_param == "effective_distance":
    param = "eff_dist_"+site
# calculate the parameter for all injections, then select those found and
# above the IFAR threshold
inj_params = triggers.get_inj_param(injs, param, args.ifo, args)[found]

# Get the injected x-axis values
if args.x_param:
    xparam = args.x_param
    if args.x_param == "end_time":
        xparam = "end_time_"+site
    if args.x_param == "effective_distance":
        xparam = "eff_dist_"+site
    inj_xparams = triggers.get_inj_param(injs, xparam, args.ifo, args)[found]
else:
    # no x-axis parameter requested: define the name anyway so that the
    # filtering step below does not hit an undefined variable
    inj_xparams = None
# need a string value when defining axis labels even if not used
args.x_param = args.x_param or "badger"

# Get the recovered values
# for this, unhack the parameter name as triggers are stored by ifo
param = args.error_param

# calculate parameters for all recovered injection triggers
rec_params, found_in_ifo = triggers.get_found_param(injs, bank, trig, param,
                                                    args.ifo, args)
# select values above IFAR threshold
rec_params = rec_params[above_min_ifar]
found_in_ifo = found_in_ifo[above_min_ifar]

# restrict plotted values if some inj are not found in given ifo
if not numpy.all(found_in_ifo):
    n_removed = numpy.logical_not(found_in_ifo).sum()
    rec_params = rec_params[found_in_ifo]
    inj_params = inj_params[found_in_ifo]
    # the x-axis values only exist when --x-param was given
    if inj_xparams is not None:
        inj_xparams = inj_xparams[found_in_ifo]
    ifar_found = ifar_found[found_in_ifo]
    logging.info("Removing %i inj not found in %s", n_removed, args.ifo)

# calculate needed values
# NB: these are substring tests, so "err" also matches the "fracerr*" and
# "*hist" plot types, which guarantees diff exists before reldiff uses it
if "err" in args.plot_type:
    diff = rec_params - inj_params
if "fracerr" in args.plot_type:
    reldiff = diff / inj_params
if "ratio" in args.plot_type:
    errrat = rec_params / inj_params

# Axis labels for every plot type
xlabels = {
  "scatter" : "Injected "+args.error_param,
  "err" : "Injected "+args.error_param,
  "fracerr" : "Injected "+args.error_param,
  "ratio" : "Injected "+args.error_param,
  "err_v_param" : "Injected "+args.x_param,
  "fracerr_v_param" : "Injected "+args.x_param,
  "ratio_v_param" : "Injected "+args.x_param,
  "errhist" : "Error (rec-inj) in "+args.error_param,
  "fracerrhist" : "Fractional error (rec-inj)/inj in "+args.error_param
}
ylabels = {
  "scatter" : "Recovered "+args.error_param,
  "err" : "Error (rec-inj) in "+args.error_param,
  "fracerr" : "Fractional error (rec-inj)/inj in "+args.error_param,
  "ratio" : "Ratio rec/inj in "+args.error_param,
  "err_v_param" : "Error (rec-inj) in "+args.error_param,
  "fracerr_v_param" : "Fractional error (rec-inj)/inj in "+args.error_param,
  "ratio_v_param" : "Ratio rec/inj in "+args.error_param,
  "errhist" : "Number of injections",
  "fracerrhist" : "Number of injections"
}

# Pick the quantity that measures the recovery accuracy. The prefix of the
# plot type decides it; "fracerr" must be tested before "err" because the
# latter is a prefix of the former.
if args.plot_type.startswith("fracerr"):
    accuracy = reldiff
elif args.plot_type.startswith("err"):
    accuracy = diff
elif args.plot_type.startswith("ratio"):
    accuracy = errrat
else:
    # plain "scatter": recovered value against injected value
    accuracy = rec_params

# Decide what goes on each axis. Only the "*_v_param" types touch
# inj_xparams, so it is never referenced when no --x-param was supplied.
if args.plot_type.endswith("hist"):
    # histograms bin the accuracy measure itself and have no y data
    xvals, yvals = accuracy, None
elif args.plot_type.endswith("_v_param"):
    xvals, yvals = inj_xparams, accuracy
else:
    xvals, yvals = inj_params, accuracy

logging.info("Plotting %i found inj ...", len(xvals))
fig = plot.figure()

if "hist" not in args.plot_type:
    if not args.gradient_far:
        # make a scatter with crude colour coding
        # One colour value per plotted point. The array is sized from
        # ifar_found rather than found, because found is not reduced when
        # injections missing in the chosen ifo are dropped.
        color = numpy.zeros(len(ifar_found))
        color[ifar_found > 100] = 0.5
        color[ifar_found > 1000] = 1.0
        caption = ("Found injections: blue circles are found with IFAR<100yr, "
                   "green are IFAR<1000yr, red are IFAR >=1000yr")
        points = plot.scatter(xvals, yvals, c=color, linewidth=0, vmin=0,
                              vmax=1, s=12, marker="o", label="found",
                              alpha=0.6)

    else:
        # make a pretty rainbow coloured plot
        color = 1.0 / ifar_found
        # sort so quiet found is on top
        csort = color.argsort()
        xvals = xvals[csort]
        yvals = yvals[csort]
        color = color[csort]
        caption = ("Found injections: color indicates the estimated false "
                   "alarm rate")
        points = plot.scatter(xvals, yvals, c=color, linewidth=0,
                              norm=matplotlib.colors.LogNorm(),
                              s=16, marker="o", label="found")
        plot.subplots_adjust(right=0.99)
        try:
            c = plot.colorbar()
            c.set_label("False Alarm Rate (yr$^{-1}$)")
        except TypeError:
            # Can't make colorbar if no quiet found injections; only
            # swallow the error when there really is nothing to plot
            if len(xvals):
                raise
else:
    # it's a histogram!
    # sqrt(N) bins, but never fewer than one so that an empty or
    # single-element sample still produces a valid histogram
    n_bins = max(1, int(len(xvals)**0.5))
    plot.hist(xvals, bins=n_bins, histtype="step", label="found")
    caption = ("Found injections")

# Apply the requested axis scaling and label the axes for this plot type
ax = plot.gca()
if args.log_x:
    ax.set_xscale("log")
if args.log_y:
    ax.set_yscale("log")
plot.xlabel(xlabels[args.plot_type], size='large')
plot.ylabel(ylabels[args.plot_type], size='large')
plot.grid(True)

# Extra keyword arguments forwarded when saving; only set for png output
fig_kwds = {}
if ".png" in args.output_file:
    fig_kwds["dpi"] = 150

if (".html" in args.output_file):
    plot.subplots_adjust(left=0.1, right=0.8, top=0.9, bottom=0.1)
    # mpld3 is only needed for html output, so import it lazily
    import mpld3, mpld3.plugins, mpld3.utils
    mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fmt=".5g"))
    # The interactive legend toggles the scatter collection, which only
    # exists for the non-histogram plot types
    if "hist" not in args.plot_type:
        legend =  mpld3.plugins.InteractiveLegendPlugin([points], ["found"],
                                                        alpha_unsel=0.1)
        mpld3.plugins.connect(fig, legend)

results.save_fig_with_metadata(fig, args.output_file,
                     fig_kwds=fig_kwds,
                     title="Injection %s recovery" % (args.error_param),
                     cmd=" ".join(sys.argv),
                     caption=caption)
logging.info("Done!")
