#!/usr/bin/python
""" Plot the rate of triggers across the template bank
"""
import matplotlib
matplotlib.use('Agg')
import numpy, argparse, h5py, os, pylab, pycbc.pnutils

# Command-line interface.  The script reads every one of these options
# unconditionally, so each is marked required: a missing flag now produces a
# usage message up front instead of a traceback when the value is first used.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--trigger-files', nargs='+', required=True,
                    help='HDF5 trigger files; the template_id, snr and '
                         'chisq datasets are read from each one')
parser.add_argument('--bank-file', required=True,
                    help='HDF5 template bank file; the mass1 and mass2 '
                         'datasets are read from it')
parser.add_argument('--output-file', required=True,
                    help='Path the trigger-count plot is saved to')
parser.add_argument('--chisq-bins', required=True,
                    help='Number of chisq bins; chisq values are divided '
                         'by 2 * bins - 2 before plotting')
args = parser.parse_args()

# Read the component masses of every template.  The bank is opened
# explicitly read-only (older h5py releases default to a read/write mode)
# and is closed as soon as both datasets have been copied into memory.
with h5py.File(args.bank_file, 'r') as bf:
    m1 = bf['mass1'][:]
    m2 = bf['mass2'][:]

# Derived mass parameters (chirp mass and eta, going by the helper's name).
mc, et = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)

num_templates = len(m1)
# Per-template trigger counter; one slot per bank entry.
template_num = numpy.zeros(num_templates)
# Per-template buffer of the same length, initialised to zero.
max_snr = numpy.zeros(num_templates)

# Accumulators filled while walking over the trigger files.
num_trigs = 0
snrs = []
chisqs = []

for trig_filename in args.trigger_files:
    # Open read-only and close each file once its datasets are in memory, so
    # a long list of trigger files does not pile up open HDF5 handles.
    with h5py.File(trig_filename, 'r') as f:
        tid = f['template_id'][:]
        snr = f['snr'][:]
        chisq = f['chisq'][:]

    # Order this file's triggers by template id; snr and chisq follow the
    # same permutation so the three arrays stay aligned.
    tsort = tid.argsort()
    tid = tid[tsort]
    snr = snr[tsort]
    chisq = chisq[tsort]

    # Count triggers produced by each template.  `u` holds the distinct
    # template ids and `n` how often each one occurs in this file.
    u, n = numpy.unique(tid, return_counts=True)
    template_num[u] += n

    snrs.append(snr)
    chisqs.append(chisq)

    num_trigs += len(tsort)

# Combine all files; chisq is rescaled by 2 * bins - 2.
chisq = numpy.concatenate(chisqs) / (float(args.chisq_bins) * 2 - 2)
snr = numpy.concatenate(snrs)

# Diagnostic scatter of the rescaled chisq against SNR.  Only the leading
# slice of the combined trigger arrays is drawn.
leading_triggers = slice(0, 1000000)
pylab.figure()
pylab.scatter(snr[leading_triggers], chisq[leading_triggers])
# Fixed view window: x (SNR) from 6 to 8, y (rescaled chisq) from 0.8 to 3.
pylab.xlim(6, 8)
pylab.ylim(.8, 3)
# NOTE: the output name is hard-coded and relative to the working directory.
pylab.savefig('snrchi.png')
   
# Scatter the bank in the (eta, total mass) plane; each template is coloured
# by its trigger count and sized relative to the busiest template.
peak_count = template_num.max()
if peak_count > 0:
    # Normalise so the template with the most triggers gets marker size 50.
    marker_size = template_num / peak_count * 50
else:
    # No template produced a trigger: the normalisation would be 0/0 and
    # yield NaN sizes, so fall back to zero-size markers for every template.
    marker_size = numpy.zeros_like(template_num)

pylab.figure()
pylab.scatter(et, m1 + m2, c=template_num, s=marker_size)
pylab.ylabel('Total Mass')
pylab.xlabel('Eta')
pylab.colorbar()
pylab.savefig(args.output_file)
