#!/usr/bin/env python
"""
This program combines a set of STATMAP files from disjoint times into one file.
The resulting file contains the triggers from the full set of input files.
"""
import numpy
import argparse
import logging
import h5py

import pycbc
import pycbc.version

def com(f, files, group):
    """Combine the same column from multiple files into another file f.

    Parameters
    ----------
    f : mapping-like (e.g. a writable h5py.File)
        Destination; the concatenated column is stored as ``f[group]``.
    files : iterable of mapping-like (e.g. h5py.File)
        Input files. Files that do not contain ``group`` are skipped.
    group : str
        Path of the dataset to combine.

    Notes
    -----
    If none of the input files contains ``group``, an empty uint32 array is
    stored so that the dataset still exists in the output.
    """
    # Only concatenate the columns that actually exist. Padding a missing
    # column with an empty uint32 array would make numpy promote the dtype of
    # the result (e.g. float32 -> float64, int32 -> int64) although the
    # placeholder contributes no data.
    columns = [fi[group][:] for fi in files if group in fi]
    if not columns:
        columns = [numpy.array([], dtype=numpy.uint32)]
    f[group] = numpy.concatenate(columns)

def com_with_detector_key(f, files, group):
    """
    Combine a detector-numbered column (e.g. time1/time2 or
    trigger_id1/trigger_id2) from multiple files into f.

    The trailing '1' or '2' of ``group`` refers to a detector whose name is
    stored in the 'detector_1' / 'detector_2' attributes. Input files may
    number the detectors differently, so for every file after the first the
    column belonging to the *same detector name* as in f is selected. The
    first file is taken as-is, since it is the one that defines the
    convention of f.

    Raises ValueError if ``group`` does not end in 1 or 2, or if one of the
    later files does not list the required detector in its attributes.
    """
    # Split e.g. 'foreground/time1' into 'foreground/time' and '1'
    stem, det_num = group[:-1], group[-1:]
    if det_num not in ('1', '2'):
        raise ValueError("Group name must end in 1 or 2, got %s" % group)

    # Name of the detector that this column refers to in the output file
    ifo_name = f.attrs['detector_' + det_num]

    def _matching_column(fp):
        # Return the column of fp belonging to ifo_name, whichever number
        # that detector has in this particular file.
        for candidate in ('1', '2'):
            if fp.attrs['detector_' + candidate] == ifo_name:
                return fp[stem + candidate][:]
        raise ValueError("Cannot find detector %s in input file" % ifo_name)

    # The first file set the convention, so its column is used unchanged
    pieces = [files[0][group][:]]
    pieces.extend(_matching_column(fp) for fp in files[1:])
    f[group] = numpy.concatenate(pieces)
    

# Command line interface. The input list and the output path are both
# mandatory: nothing can be combined without them, so let argparse report a
# clear usage error instead of failing later on a None value.
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--statmap-files', nargs='+', required=True,
                    help="List of statmap files to be combined")
parser.add_argument('--output-file', required=True,
                    help="name of output file")
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Open every input read-only. The mode is given explicitly because older h5py
# releases did not default to read-only when no mode was passed.
files = []
try:
    for name in args.statmap_files:
        files.append(h5py.File(name, 'r'))

    # The context manager makes sure the output file is closed even if
    # combining fails part way through.
    with h5py.File(args.output_file, "w") as f:
        # The output adopts the detector numbering of the first input.
        # It's not guaranteed that all files will follow this, so the columns
        # that encode a detector number get special treatment further down.
        for det_key in ('detector_1', 'detector_2'):
            f.attrs[det_key] = files[0].attrs[det_key]

        # The inputs come from disjoint times, so the analysed times add up
        for time_key in ('background_time', 'foreground_time',
                         'background_time_exc', 'foreground_time_exc'):
            f.attrs[time_key] = sum(cfp.attrs[time_key] for cfp in files)

        # Combine segments
        for key in files[0]['segments'].keys():
            com(f, files, 'segments/%s/start' % key)
            com(f, files, 'segments/%s/end' % key)

        # Copy over all the columns in the foreground and background groups.
        # A few special cases here:
        #  * columns carrying a detector number must be matched by detector
        #    name, as each file may number the detectors differently
        #  * FAP numbers are not stored ... Could be recalculated.
        detector_keys = ('time1', 'time2', 'trigger_id1', 'trigger_id2')
        skipped_keys = ('fap', 'fap_exc')
        for fg_bg_key in ['foreground', 'background', 'background_exc']:
            for key in files[0][fg_bg_key].keys():
                if key in skipped_keys:
                    continue
                group = '%s/%s' % (fg_bg_key, key)
                if key in detector_keys:
                    com_with_detector_key(f, files, group)
                else:
                    com(f, files, group)
finally:
    # Release the input files whether or not the combination succeeded
    for fp in files:
        fp.close()
