#!/bin/env python
import h5py, argparse, numpy, pycbc.events, logging, pycbc.events, pycbc.io
import pycbc.version

# Command-line interface. The parsed options are read by the rest of this
# script through the ``args`` namespace.
parser = argparse.ArgumentParser()
parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--coinc-files', nargs='+',
                    help="List of coinc files to be redistributed")
# --background-bins, --bank-file and --output-files are dereferenced
# unconditionally further down, so mark them required: a missing option then
# produces a usage error instead of a traceback.
parser.add_argument('--background-bins', nargs='+', required=True,
                    help="Ordered list of mass bin upper boundaries. "
                         "An ordered list of type-boundary pairs, applied sequentially. "
                         "Must provide a name (can be any unique string for tagging "
                         "purposes), the parameter to bin on, and the membership "
                         "condition via 'lt' / 'gt' operators. "
                         "Ex. name1:component:lt2 name2:total:lt15 name3:SEOBNRv2Peak:gt1000")
parser.add_argument('--f-lower',
                    help="Lower frequency cutoff for evaluating template duration. Should"
                         " be equal to the lower cutoff used in inspiral jobs")
parser.add_argument('--bank-file', required=True,
                    help="hdf format template bank file")
parser.add_argument('--output-files', nargs='+', required=True,
                    help="list of output file names, one for each mass bin")
args = parser.parse_args()

# Binning on template duration needs a lower frequency cutoff. Each bin spec
# is a "name:parameter:condition" string (see the --background-bins help), so
# look at the parameter field of every spec; testing the list itself for the
# bare string 'duration' can never match a well-formed spec. The [1:2] slice
# yields the second field, or nothing for a malformed spec, so this check never
# raises IndexError. A substring match is used so that any parameter name
# containing 'duration' triggers the check.
bins_on_duration = any('duration' in field
                       for spec in args.background_bins
                       for field in spec.split(':')[1:2])
if bins_on_duration and not args.f_lower:
    raise RuntimeError("Can't bin on template duration without f-lower!")

pycbc.init_logging(args.verbose)

# Bins and output files are paired positionally later on, so their counts
# must agree.
if len(args.output_files) != len(args.background_bins):
    raise ValueError("Number of mass bins and output files does not match")

# Read the template parameters from the bank. Open it explicitly read-only
# (the default mode differs between h5py versions) and close it as soon as
# the arrays have been copied into memory by the [:] slices.
with h5py.File(args.bank_file, 'r') as bank:
    data = {key: bank[key][:]
            for key in ('mass1', 'mass2', 'spin1z', 'spin2z')}
if args.f_lower:
    data['f_lower'] = float(args.f_lower)

# locs_dict is indexed by bin name when the coincs are split up; the names
# are the first ':'-separated field of each bin specification.
locs_dict = pycbc.events.background_bin_from_string(args.background_bins, data)
names = [b.split(':')[0] for b in args.background_bins]

# Load all coinc triggers from the input files into one container.
d = pycbc.io.StatmapData(files=args.coinc_files)
logging.info('%s coinc triggers', len(d))
for name, outname in zip(names, args.output_files):
    # select the coincs from only this bin and save to a single combined file
    locs = locs_dict[name]
    e = d.select(numpy.in1d(d.template_id, locs))
    logging.info('%s coincs in mass bin: %s', len(e), name)
    e.save(outname)
    # Tag the saved file with its bin name. The file must be opened writable
    # explicitly (h5py >= 3 defaults to read-only), and the context manager
    # closes it so the attribute is flushed and no handle leaks per bin.
    with h5py.File(outname, 'a') as out:
        out.attrs['name'] = name
