#!/usr/bin/env python

"Plot gated time segments from inspiral HDF5 files."

import argparse
import logging
import h5py
import numpy as np
import matplotlib
matplotlib.use('agg')
import pylab as pl
from matplotlib.patches import Rectangle
import mpld3
import mpld3.plugins
from pycbc.results.color import ifo_color
from pycbc.results.mpld3_utils import MPLSlide
import pycbc.version


# Command-line interface: one or more inspiral HDF5 files in, a single plot
# file out. The module docstring doubles as the program description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    '--version',
    action='version',
    version=pycbc.version.git_verbose_msg,
)
parser.add_argument(
    '--input-file',
    required=True,
    nargs='+',
    help='Single-detector inspiral HDF5 files to take gating data from.',
)
parser.add_argument(
    '--output-file',
    required=True,
    help='Destination file for the plot.',
)
args = parser.parse_args()

# Progress messages are logged at INFO level, each one prefixed with a
# second-resolution timestamp.
log_date_fmt = '%Y-%m-%d %H:%M:%S'
log_fmt = '%(asctime)s %(message)s'
logging.basicConfig(format=log_fmt, datefmt=log_date_fmt, level=logging.INFO)

# Collect the gates from every input file, grouped by (detector, gate type).
# gate_data[(ifo, gate_type)] is a list of (time, width, pad) tuples built
# element-wise from the three datasets under '<ifo>/gating/<gate_type>/'.
# have_gates becomes True as soon as any group holds at least one gate.
gate_data = {}
have_gates = False
for fn in args.input_file:
    logging.info('Reading gates from %s', fn)
    # Open each file in a context manager so its handle is released as soon
    # as the datasets have been copied out, instead of keeping every input
    # file open until the interpreter exits.
    with h5py.File(fn, 'r') as f:
        # The first top-level key of the file is taken as the detector name.
        ifo = tuple(f.keys())[0]
        for gate_type in ['file', 'auto']:
            group = ifo + '/gating/' + gate_type
            try:
                # Slicing with [:] materialises the data in memory, so the
                # values remain usable after the file is closed.
                gate_time = f[group + '/time'][:]
                gate_width = f[group + '/width'][:]
                gate_pad = f[group + '/pad'][:]
            except KeyError:
                # One of the three datasets is absent: this file contributes
                # no gates of this type.
                continue
            key = (ifo, gate_type)
            gate_data.setdefault(key, []).extend(
                zip(gate_time, gate_width, gate_pad))
            have_gates = have_gates or len(gate_data[key]) > 0

# Clear mpld3's default plugin list; the interactive controls wanted for this
# plot are connected explicitly once the figure has been drawn.
mpld3.plugins.DEFAULT_PLUGINS = []

# Single wide axes. The vertical axis only separates rows of gates, so it
# carries no tick marks.
fig = pl.figure(figsize=(10, 5))
ax = fig.add_subplot(1, 1, 1)
ax.set_yticks([])

if not have_gates:
    # No input file contributed a gate: show a centred placeholder message
    # on a unit square without x ticks.
    ax.text(0.5, 0.5, 'No gating data to plot', horizontalalignment='center',
            verticalalignment='center')
    ax.set_xticks([])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
else:
    # One horizontal row per (detector, gate type), in sorted key order.
    rows = sorted(gate_data)
    total_time = {}
    total_time_pad = {}
    starts_seen = []
    ends_seen = []

    for row, key in enumerate(rows):
        ifo, gate_type = key
        logging.info('Plotting %s gates for %s', gate_type, ifo)
        # Drop exact duplicate (time, width, pad) triples, then sort them.
        gd = np.array(sorted(frozenset(gate_data[key])))
        if len(gd) == 0:
            total_time[key] = 0.
            total_time_pad[key] = 0.
            continue

        centres, widths, pads = gd[:, 0], gd[:, 1], gd[:, 2]
        # Each rectangle spans time -/+ (width + pad).
        starts = centres - widths - pads
        ends = centres + widths + pads
        for s, e in zip(starts, ends):
            ax.add_patch(Rectangle((s, row), e - s, 0.7,
                                   color=ifo_color(ifo), alpha=0.5))
            starts_seen.append(s)
            ends_seen.append(e)

        # Widths extend on both sides of the gate time, hence the factor 2.
        total_time[key] = sum(widths) * 2
        total_time_pad[key] = sum(widths + pads) * 2

    # have_gates guarantees that at least one rectangle was drawn, so both
    # lists are non-empty here.
    t_min = min(starts_seen)
    t_max = max(ends_seen)

    # Annotate each row just above its rectangles, at the left plot edge.
    for row, key in enumerate(rows):
        ifo, gate_type = key
        label = '%s %s gates: %.1f s (zeroes), %.1f s (zeroes + pad)' \
                % (ifo, gate_type, total_time[key], total_time_pad[key])
        ax.text(t_min, row + 0.75, label)

    ax.set_xlim(t_min, t_max)
    ax.set_ylim(0, len(gate_data))
    ax.set_xlabel('GPS Time (s)')

# Attach the interactive controls for the HTML output, in this order: mouse
# position readout, box zoom, pan/zoom and reset.
for plugin in (mpld3.plugins.MousePosition(fontsize=14, fmt='10.1f'),
               mpld3.plugins.BoxZoom(),
               mpld3.plugins.Zoom(),
               mpld3.plugins.Reset()):
    mpld3.plugins.connect(fig, plugin)

# Write the figure as HTML. The context manager flushes and closes the output
# file deterministically, even if rendering raises, instead of leaving an
# anonymous handle to be cleaned up at interpreter exit.
with open(args.output_file, 'w') as out_file:
    mpld3.save_html(fig, out_file)
