#!/bin/env python
import h5py, pycbc.types, numpy, argparse, pycbc.results, sys
import matplotlib; matplotlib.use('Agg')
import pylab

# Command-line interface. The input and output paths are mandatory: the script
# cannot do anything useful without them, so let argparse report a missing one
# up front instead of failing later with an obscure error.
parser = argparse.ArgumentParser(
    description="Plot the time series of each chisq bin from a "
                "pycbc_single_template output file")
parser.add_argument('--single-template-file', required=True,
                    help="HDF output of pycbc_single_template")
parser.add_argument('--central-time', type=float,
                    help="Time to center the plot, optional",
                    default=0)
parser.add_argument('--window', type=float,
                    help="Time around central time to plot, optional. "
                         "Used with the central time option")
parser.add_argument('--plot-type', choices=['snr', 'chisq'], default='chisq',
                    help="Quantity to plot for each bin, default chisq")
parser.add_argument('--output-file', required=True,
                    help="Path of the plot file to write")
args = parser.parse_args()

# Read the chisq bin boundaries and count the per-bin time series. Only this
# metadata is needed from the raw HDF handle (the series themselves are loaded
# through pycbc below), so the file is closed as soon as it has been read.
with h5py.File(args.single_template_file, 'r') as f:
    y = f['chisq_boundaries'][:]
    num_bins = len(f['chisq_bins'])

if num_bins == 0:
    raise ValueError("No chisq bins found in %s" % args.single_template_file)

fig = pylab.figure()
ax = pylab.gca()

# For the 'chisq' plot every bin has rho^2 / p subtracted from its squared
# norm, with p = len(y) - 1 (one fewer than the number of boundaries,
# presumably the number of bins). This term does not depend on the bin, so it
# is computed once here rather than inside the loop.
if args.plot_type == 'chisq':
    snr = pycbc.types.load_timeseries(args.single_template_file, group='snr')
    expected = snr.squared_norm() / (len(y) - 1)

# First pass: load every bin and work out what will be drawn, so that a single
# colour scale covering all bins can be chosen before anything is plotted.
panels = []
for i in range(num_bins):
    ts = pycbc.types.load_timeseries(args.single_template_file,
                                     group='chisq_bins/%s' % i)
    delta_t = ts.delta_t
    st = ts.sample_times.numpy()

    power = ts.squared_norm()
    if args.plot_type == 'chisq':
        power = power - expected
    power = power.numpy()

    # Cell edges: half a sample either side of each sample time, shifted so
    # that the requested central time sits at zero.
    x = numpy.append(st - delta_t / 2.0,
                     [st[-1] + delta_t / 2.0]) - args.central_time
    panels.append((x, power))

# Shared colour limits. Without these each row would be normalised on its own
# and the single colorbar would only describe one of the rows.
vmin = min(numpy.nanmin(power) for _, power in panels)
vmax = max(numpy.nanmax(power) for _, power in panels)

# Second pass: draw bin i as one row of cells spanning i to i + 1 on the
# y axis. The y axis is therefore in units of bin index, not frequency.
for i, (x, power) in enumerate(panels):
    mesh = ax.pcolorfast(x, numpy.array([i, i + 1]),
                         power.reshape(1, len(power)),
                         vmin=vmin, vmax=vmax)

pylab.ylabel('Frequency (Hz)')

xlabel = 'Time (s)'
if args.central_time:
    xlabel += ' - %.2f' % args.central_time
    # The window is only meaningful relative to a central time.
    if args.window:
        pylab.xlim(-args.window, args.window)

pylab.xlabel(xlabel)

# Use the artist returned by pcolorfast directly; every row shares the same
# colour limits, so any one of them is a valid source for the colorbar.
c = pylab.colorbar(mesh, ax=ax)

if args.plot_type == 'chisq':
    c.set_label("$\\rho_l^2 - \\rho^2/p$")
else:
    c.set_label("$\\rho_l^2$")

# Set the frequency labels. The y axis is drawn in bin-index units, so integer
# position k corresponds to the k-th entry of the boundary array y. Ticks are
# placed explicitly on those integer positions instead of parsing the
# auto-generated tick text, which breaks on fractional or negative ticks and
# on tick values beyond the end of y.
max_ticks = 10
tick_step = max(1, (len(y) + max_ticks - 1) // max_ticks)
edge_index = numpy.arange(len(y))[::tick_step]
ax.set_yticks(edge_index)
ax.set_yticklabels(['%.0f' % y[k] for k in edge_index])

pycbc.results.save_fig_with_metadata(fig, args.output_file,
                cmd=' '.join(sys.argv),
                title="chisq timeseries for each bin",
                caption="Plot of the time series of each chisq bin")
                
