#!/usr/bin/env python
# Copyright (C) 2020 Josh Willis
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

from glob import glob
from os import path
import argparse
import numpy as np
from h5py import File
from pycbc.events import ranking
import pycbc
import logging

def parse_injection_path(injname, basedir):
    """Locate the files describing one injection set inside a run directory.

    Parameters
    ----------
    injname : str
        Label of the injection set; the coincidence results are looked for in
        the sub-directory ``<injname>_INJ_coinc`` of ``basedir``.
    basedir : str
        Top-level directory of the run. Environment variables and ``~`` are
        expanded and the result is normalised before use.

    Returns
    -------
    tuple
        ``(coinc_dir, found_injections_file, bank_file)`` where ``coinc_dir``
        is the path of the ``<injname>_INJ_coinc`` directory, and the two
        files are the unique matches of ``*HDFINJFIND*`` inside it and of
        ``*-BANK2HDF-*.hdf`` inside the ``bank`` sub-directory.

    Raises
    ------
    RuntimeError
        If any of the directories is missing or is not a directory, or if
        either file pattern does not match exactly one file.
    """
    run_dir = path.normpath(path.expanduser(path.expandvars(basedir)))
    if not path.exists(run_dir):
        raise RuntimeError("Directory {0} does not exist".format(basedir))
    if not path.isdir(run_dir):
        raise RuntimeError("Path {0} does not specify a"
                           " directory".format(basedir))

    # Template bank: expected in <run_dir>/bank
    bank_dir = path.join(run_dir, "bank")
    if not path.exists(bank_dir):
        raise RuntimeError("There is no bank sub-directory of"
                           " {0}".format(basedir))
    if not path.isdir(bank_dir):
        raise RuntimeError("Path {0} does not specify a"
                           " directory".format(bank_dir))
    # Matching on the file name pattern is not very robust, but it is the
    # only handle we have on which file is the complete bank.
    bank_matches = glob(bank_dir + "/*-BANK2HDF-*.hdf")
    if len(bank_matches) != 1:
        raise RuntimeError("There is not exactly one complete HDF bank file in"
                           " the path {0}".format(bank_dir))

    # Coincidence results of the requested injection set
    coinc_dir = path.join(run_dir, "{0}_INJ_coinc".format(injname))
    if not path.exists(coinc_dir):
        raise RuntimeError("There is no sub-directory {0}_INJ_coinc in"
                           " {1}".format(injname, basedir))
    if not path.isdir(coinc_dir):
        raise RuntimeError("Path {0}/{1}_INJ_coinc is not a"
                           " directory".format(basedir, injname))
    found_matches = glob(coinc_dir + "/*HDFINJFIND*")
    if len(found_matches) == 0:
        raise RuntimeError("No found-injections file in directory"
                           " {0}".format(coinc_dir))
    if len(found_matches) > 1:
        raise RuntimeError("More than one found-injections file in directory"
                           " {0}".format(coinc_dir))

    return coinc_dir, found_matches[0], bank_matches[0]

# Below are the parameters we use in hashing. For injections especially, we
# may need to watch out for additional meaningful parameters. Note that it
# does not work to simply take all of the keys present in the injection data
# group of an HDFINJFIND file: some of them can contain tiny numerical
# differences that spoil their use in a hash.
hash_inj_params = ['coa_phase', 'distance', 'end_time', 'inclination',
                   'latitude', 'longitude', 'mass1', 'mass2',
                   'polarization', 'spin1x', 'spin1y', 'spin1z',
                   'spin2x', 'spin2y', 'spin2z']


class Injections(object):
    """Per-injection hashes of an injection parameter group.

    Instances are used to decide whether two runs performed the same set of
    injections: two instances compare equal when they hold the same number
    of injections, the same set of parameter names, and identical
    per-injection hashes of the ``hash_inj_params`` columns.

    Parameters
    ----------
    inj_group : mapping
        Mapping from parameter name to a one-dimensional column of values
        (for example the ``injections`` group of an HDFINJFIND file). It must
        provide ``keys()`` and contain every name in ``hash_inj_params``.

    Attributes
    ----------
    injgroup : the mapping passed in.
    inj_params : list of the parameter names present in ``injgroup``.
    ninj : number of entries in the first parameter column.
    inj_hashes : numpy array with one hash per injection.

    Raises
    ------
    RuntimeError
        If any of the computed hashes is zero.
    """

    def __init__(self, inj_group):
        self.injgroup = inj_group
        # On Python 3 ``keys()`` returns a view that cannot be indexed, so
        # materialise it as a list before taking the first entry.
        self.inj_params = list(self.injgroup.keys())
        self.ninj = len(self.injgroup[self.inj_params[0]][:])
        # Read every hashed column into memory once, then hash row by row;
        # iterating the stored columns directly would fetch one element at
        # a time.
        columns = [self.injgroup[p][:] for p in hash_inj_params]
        self.inj_hashes = np.array([hash(row) for row in zip(*columns)])
        # NOTE(review): np.all() is False when any hash is exactly zero; the
        # message speaks of "finite" hashes but that is what is tested here.
        if not np.all(self.inj_hashes):
            raise RuntimeError("Not all injection hashes were finite")

    def __eq__(self, other):
        """Return True if ``other`` describes the same injections.

        Raises ValueError if ``other`` is not an ``Injections`` instance.
        """
        if not isinstance(other, Injections):
            raise ValueError("{0} cannot be compared to Injections instance; is"
                             " not itself an instance".format(other))
        if self.ninj != other.ninj:
            return False
        if set(self.inj_params) != set(other.inj_params):
            return False
        return bool((self.inj_hashes == other.inj_hashes).all())

class RunInjectionResults(object):
    """Found and missed injection results of one injection set in one run.

    On construction the found-injections file, the template bank file and
    one trigger merge file per detector are opened read-only, and the
    information about found and missed injections is loaded into plain
    dictionaries. The files stay open until ``close()`` is called.

    Parameters
    ----------
    injname : str
        Label of the injection set.
    basedir : str
        Top-level directory of the run (see ``parse_injection_path``).
    single_detector_statistic : str, optional
        Name of the single-detector ranking to compute for found injections.
    nmissed : int, optional
        Number of loudest missed injections to consider; only stored here.
    ifar_threshold : float or None, optional
        IFAR threshold; only stored here.

    Attributes
    ----------
    inj_dets : list
        The two detector names read from the ``detector_1`` and
        ``detector_2`` attributes of the found-injections file.
    found, found_after_vetoes : dict
        Top-level keys hold per-injection arrays ('fap', 'fap_exc', 'ifar',
        'ifar_exc', 'injection_index', 'stat', 'template_hash' and
        'minimum_single_detector_statistic'). Each detector name maps to a
        nested dictionary of single-detector values for the same injections.
    missed : dict
        The 'after_vetoes', 'all' and 'within_analysis' arrays of the
        ``missed`` group.
    injections : Injections
        Hashes of the injection parameters, used for comparing runs.

    Raises
    ------
    RuntimeError
        If the directory layout is not as expected, or if there is not
        exactly one trigger merge file for a detector.
    """

    def __init__(self, injname, basedir,
                 single_detector_statistic='newsnr',
                 nmissed=10, ifar_threshold=None):
        self.injname = injname
        # First, get the necessary files and load them up
        dirpath, injfile, bankfile = parse_injection_path(injname, basedir)
        self.dirpath = dirpath
        self.nmissed = nmissed
        self.ifar_threshold = ifar_threshold
        logging.info("Reading injection directory {0}".format(self.dirpath))
        self.injfile = File(injfile, "r")
        self.bankfile = File(bankfile, "r")
        self.inj_dets = [self.injfile.attrs['detector_1'],
                         self.injfile.attrs['detector_2']]
        self.trigger_merge_dict = {}
        for ifo in self.inj_dets:
            self.trigger_merge_dict[ifo] = {}
            merge_file_list = glob(self.dirpath +
                                   "/{0}-HDF_TRIGGER_MERGE*".format(ifo))
            if len(merge_file_list) != 1:
                raise RuntimeError("There is not exactly one trigger file for"
                                   " IFO {0} in directory {1}".format(
                                   ifo, self.dirpath))
            fileptr = File(merge_file_list[0], "r")
            self.trigger_merge_dict[ifo]['fileptr'] = fileptr
            self.trigger_merge_dict[ifo]['data'] = fileptr[ifo]
        # Next, load the information about the injections themselves.
        self.injgroup = self.injfile['injections']
        self.injections = Injections(self.injgroup)
        # Now, read the information from the found injection file. The trigger
        # merge files are typically quite large; we do not automatically load
        # them into memory
        self.found = {}
        self.found_after_vetoes = {}
        self.missed = {}
        for ifo in self.inj_dets:
            self.found[ifo] = {}
            self.found_after_vetoes[ifo] = {}
        for param in ['fap', 'fap_exc', 'ifar', 'ifar_exc',
                      'injection_index', 'stat']:
            self.found[param] = self.injfile['found'][param][:]
            self.found_after_vetoes[param] = \
                self.injfile['found_after_vetoes'][param][:]
        template_hashes = self.bankfile['template_hash'][:]
        # The ordering of detectors is arbitrary in the found
        # injection file, so we disentangle it and only record
        # via the IFO name
        d1 = self.injfile.attrs['detector_1']
        d2 = self.injfile.attrs['detector_2']
        for found_class in ['found', 'found_after_vetoes']:
            fdict = getattr(self, found_class)
            file_group = self.injfile[found_class]
            idx = file_group['template_id'][:]
            fdict['template_hash'] = template_hashes[idx]
            fdict[d1]['trigger_time'] = file_group['time1'][:]
            fdict[d1]['trigger_id'] = file_group['trigger_id1'][:]
            fdict[d2]['trigger_time'] = file_group['time2'][:]
            fdict[d2]['trigger_id'] = file_group['trigger_id2'][:]
        for param in ['after_vetoes', 'all', 'within_analysis']:
            self.missed[param] = self.injfile['missed'][param][:]
        # Record what the single detector statistic is
        self.single_detector_statistic = single_detector_statistic
        # Next, read information from the two trigger merge files
        logging.info("Reading single detector triggers from {0}; finding"
                     " minimum of {1}".format(
                         self.dirpath, self.single_detector_statistic))
        statname = self.single_detector_statistic
        for fdict in [self.found, self.found_after_vetoes]:
            for ifo in self.inj_dets:
                data = self.trigger_merge_dict[ifo]['data']
                idx = fdict[ifo]['trigger_id']
                # Copy whichever of these columns the merge file provides,
                # restricted to the triggers associated with injections.
                for param in ['chisq', 'chisq_dof', 'snr', 'sg_chisq',
                              'sigmasq', 'coa_phase', 'end_time']:
                    if param not in data:
                        continue
                    vals = data[param][:][idx]
                    # A hack: the stored value is rescaled before it is
                    # handed to the ranking function. Presumably this
                    # matches the convention that function expects; TODO
                    # confirm.
                    if param == 'chisq_dof':
                        vals = vals / 2 + 1
                    fdict[ifo][param] = vals
                # Compute the requested single-detector ranking from the
                # values gathered above.
                fdict[ifo][single_detector_statistic] = \
                    ranking.get_sngls_ranking_from_trigs(
                        fdict[ifo],
                        single_detector_statistic
                    )
            fdict['minimum_single_detector_statistic'] = np.minimum(
                fdict[self.inj_dets[0]][statname],
                fdict[self.inj_dets[1]][statname])

    def close(self):
        """Close every HDF file opened by the constructor."""
        self.injfile.close()
        self.bankfile.close()
        for ifo in self.inj_dets:
            self.trigger_merge_dict[ifo]['fileptr'].close()

    def __str__(self):
        return "RunInjectionResults instance for injection set {0} from" \
            " directory {1}".format(self.injname, self.dirpath)

    def __eq__(self, other):
        """Return True if both runs used the same detectors and injections.

        Raises ValueError if ``other`` is not a ``RunInjectionResults``.
        """
        if not isinstance(other, RunInjectionResults):
            raise ValueError("{0} cannot be compared to RunInjectionResults"
                             " instance; is not itself an instance".format(
                                 other))
        return ((set(self.inj_dets) == set(other.inj_dets)) and
                (self.injections == other.injections))

    def write_to_hdf(self, outpath):
        """Write the loaded results to a new HDF file at ``outpath``.

        The file receives the run metadata as attributes, a copy of the
        ``injections`` group, a ``missed`` group, and ``found`` and
        ``found_after_vetoes`` groups holding the top-level arrays plus one
        sub-group per detector. An existing file is overwritten.
        """
        # The context manager closes the file even if a write fails.
        with File(outpath, "w") as outfp:
            outfp.attrs['injection_set'] = self.injname
            outfp.attrs['original_directory'] = self.dirpath
            outfp.attrs['single_detector_statistic'] = \
                self.single_detector_statistic
            outfp.attrs['detector_1'] = self.inj_dets[0]
            outfp.attrs['detector_2'] = self.inj_dets[1]
            self.injfile.copy('injections', outfp)
            missed_group = outfp.create_group('missed')
            for col, values in self.missed.items():
                missed_group.create_dataset(col, data=values)
            for fname in ['found', 'found_after_vetoes']:
                fgroup = outfp.create_group(fname)
                fdict = getattr(self, fname)
                # Build a new list of the non-detector keys; on Python 3
                # ``dict.keys()`` is a view and has no ``remove`` method.
                keys_no_ifos = [key for key in fdict
                                if key not in self.inj_dets]
                for col in keys_no_ifos:
                    fgroup.create_dataset(col, data=fdict[col])
                for ifo in self.inj_dets:
                    igroup = fgroup.create_group(ifo)
                    for col, values in fdict[ifo].items():
                        igroup.create_dataset(col, data=values)


def populate_injection_group(fgroup, fidx, fdict, ifos):
    """Write selected entries of a found-injection dictionary to a group.

    Parameters
    ----------
    fgroup : group-like
        Destination; must provide ``create_dataset(name, data=...)`` and
        ``create_group(name)`` (for example an h5py group).
    fidx : index
        Anything usable to index the arrays in ``fdict`` (integer index
        array, boolean mask, slice); selects which entries are written.
    fdict : dict
        Top-level keys map to arrays; each name in ``ifos`` maps to a nested
        dictionary of arrays. The dictionary is not modified.
    ifos : sequence
        Detector names whose nested dictionaries are written to sub-groups
        of the same name.
    """
    # Build a fresh list of the non-detector keys: on Python 3
    # ``fdict.keys()`` is a view and has no ``remove`` method.
    top_level_keys = [key for key in fdict if key not in ifos]
    for key in top_level_keys:
        fgroup.create_dataset(key, data=fdict[key][fidx])
    for ifo in ifos:
        igroup = fgroup.create_group(ifo)
        for key, values in fdict[ifo].items():
            igroup.create_dataset(key, data=values[fidx])

def compare_found_injs(reference_run, comparison_run, outfp):
    """Sort the found injections of two runs into common and exclusive sets.

    For each of the classes ``found`` and ``found_after_vetoes`` a group of
    that name is created in ``outfp`` with four sub-groups:
    ``found_in_both/reference``, ``found_in_both/comparison``,
    ``found_reference_only`` and ``found_comparison_only``. Each is filled
    by ``populate_injection_group`` with the matching entries of the
    corresponding run.

    Parameters
    ----------
    reference_run, comparison_run : RunInjectionResults
        The two runs to compare; the detector list is taken from
        ``reference_run``.
    outfp : group-like
        Open, writable HDF file (or group) receiving the results.
    """
    detectors = reference_run.inj_dets
    for found_class in ('found', 'found_after_vetoes'):
        logging.info(
            "Comparing found injections for class '{0}'".format(found_class))
        class_group = outfp.create_group(found_class)
        ref_table = getattr(reference_run, found_class)
        com_table = getattr(comparison_run, found_class)
        ref_ids = ref_table['injection_index']
        com_ids = com_table['injection_index']

        logging.info("Calculating indices of found/missed injections between"
                     " reference and comparison runs")
        # intersect1d yields the sorted, unique injection indices common to
        # both runs; with return_indices=True it also reports where those
        # common values sit in each input array. Only the positions are
        # needed here.
        _, ref_shared, com_shared = np.intersect1d(
            ref_ids, com_ids, return_indices=True)
        # Boolean masks of the injections found by one run but not the other
        ref_exclusive = np.isin(ref_ids, com_ids, invert=True)
        com_exclusive = np.isin(com_ids, ref_ids, invert=True)

        # Create the whole hierarchy of output groups before writing to it
        targets = {}
        for group_name in ('found_in_both/reference',
                           'found_in_both/comparison',
                           'found_reference_only',
                           'found_comparison_only'):
            targets[group_name] = class_group.create_group(group_name)

        # (log message, destination group, selection, source table)
        jobs = (
            ("Writing injections found in both, reference run",
             'found_in_both/reference', ref_shared, ref_table),
            ("Writing injections found only in reference run",
             'found_reference_only', ref_exclusive, ref_table),
            ("Writing injections found in both, comparison run",
             'found_in_both/comparison', com_shared, com_table),
            ("Writing injections found only in comparison run",
             'found_comparison_only', com_exclusive, com_table),
        )
        for message, group_name, selection, table in jobs:
            logging.info(message)
            populate_injection_group(targets[group_name], selection, table,
                                     detectors)

def _missed_and_recovered(run):
    """Return (missed, recovered) injection indices of ``run`` after vetoes.

    Without an IFAR threshold these are the ``after_vetoes`` missed list and
    the found-after-vetoes injection indices. With a threshold, found
    injections whose IFAR is below it are appended to the missed list and
    removed from the recovered list.
    """
    missed = run.missed['after_vetoes']
    found_ids = run.found_after_vetoes['injection_index']
    if run.ifar_threshold is None:
        return missed, found_ids
    below = run.found_after_vetoes['ifar'] < run.ifar_threshold
    return np.append(missed, found_ids[below]), found_ids[~below]


def _loudest_missed(run, missed):
    """Return the first ``run.nmissed`` of ``missed``, loudest first.

    Loudness is the smaller of the two optimal SNR columns of the injection
    table, so this assumes those columns are present.
    """
    snr_1 = run.injgroup['optimal_snr_1'][:][missed]
    snr_2 = run.injgroup['optimal_snr_2'][:][missed]
    descending = np.minimum(snr_1, snr_2).argsort()[::-1]
    return missed[descending][0:run.nmissed]


def _missed_here_found_there(loudest_missed, recovered_there, found_there):
    """Cross-match one run's loud missed injections with the other run.

    Returns the injection indices from ``loudest_missed`` that appear in
    ``recovered_there``, their positions within ``loudest_missed``, and a
    boolean mask over ``found_there`` selecting those same injections.
    """
    hit = np.isin(loudest_missed, recovered_there)
    only_missed_here = loudest_missed[hit]
    positions = np.where(hit)[0]
    rows_there = np.isin(found_there, only_missed_here)
    return only_missed_here, positions, rows_there


def compare_missed_injs(reference_run, comparison_run, outfp):
    """Record loud injections missed by one run but recovered by the other.

    A group ``missed_after_vetoes`` is created in ``outfp`` with sub-groups
    ``missed_only_reference`` and ``missed_only_comparison``. Each holds the
    injection indices concerned, their position among that run's loudest
    missed injections (``loudest_rank``), and a sub-group with the other
    run's found-after-vetoes entries for those injections.

    Parameters
    ----------
    reference_run, comparison_run : RunInjectionResults
        The two runs to compare; each run's own ``nmissed`` and
        ``ifar_threshold`` are applied to it.
    outfp : group-like
        Open, writable HDF file (or group) receiving the results.
    """
    logging.info("Comparing missed injections after vetoes")
    detectors = reference_run.inj_dets
    outfp.create_group('missed_after_vetoes')

    ref_missed, ref_above = _missed_and_recovered(reference_run)
    com_missed, com_above = _missed_and_recovered(comparison_run)

    # Now find the required loudest missed injections. Note that we are
    # assuming that the injection table contains optimal SNR columns, which
    # may not always be true.
    logging.info("Finding {0} loudest missed injections in"
                 " reference run".format(reference_run.nmissed))
    ref_loudest = _loudest_missed(reference_run, ref_missed)
    logging.info("Finding {0} loudest missed injections in"
                 " comparison run".format(comparison_run.nmissed))
    com_loudest = _loudest_missed(comparison_run, com_missed)

    # Now see which missed injections in one run were found in the other
    logging.info("Finding loud injections missed only in reference run")
    ref_only, ref_only_ranks, rows_in_com = _missed_here_found_there(
        ref_loudest, com_above,
        comparison_run.found_after_vetoes['injection_index'])
    logging.info("Finding loud injections missed only in comparison run")
    com_only, com_only_ranks, rows_in_ref = _missed_here_found_there(
        com_loudest, ref_above,
        reference_run.found_after_vetoes['injection_index'])

    # Everything is computed, so create the output groups...
    ref_only_group = outfp.create_group(
        "missed_after_vetoes/missed_only_reference")
    com_only_group = outfp.create_group(
        "missed_after_vetoes/missed_only_comparison")

    # ...and fill them
    logging.info("Writing loud injections missed only in reference run")
    ref_only_group.create_dataset("injection_index", data=ref_only)
    ref_only_group.create_dataset("loudest_rank", data=ref_only_ranks)
    populate_injection_group(ref_only_group.create_group("comparison"),
                             rows_in_com,
                             comparison_run.found_after_vetoes, detectors)
    logging.info("Writing loud injections missed only in comparison run")
    com_only_group.create_dataset("injection_index", data=com_only)
    com_only_group.create_dataset("loudest_rank", data=com_only_ranks)
    populate_injection_group(com_only_group.create_group("reference"),
                             rows_in_ref,
                             reference_run.found_after_vetoes, detectors)


# Command-line driver: load the same injection set from a reference and a
# comparison run, check that both performed identical injections, and write
# the comparison of found and missed injections to a single HDF file.
parser = argparse.ArgumentParser(usage="",
    description="Detailed comparison of a specified injection set between two"
                " PyCBC runs")
parser.add_argument("--injection-label", type=str, required=True,
                    help="Label of injection set")
parser.add_argument("--reference-dir", type=str, required=True,
                    help="Directory containing reference run of this injection"
                    " set. This should contain a 'bank/' sub-directory, as well"
                    " as a sub-directory '<injection_label>_INJ_coinc'")
parser.add_argument("--comparison-dir", type=str, required=True,
                    help="Directory containing comparison run of this"
                    " injection set.  This should contain a 'bank/' sub-"
                    "directory, as well as a sub-directory"
                    " '<injection_label>_INJ_coinc'")
parser.add_argument("--output-file", type=str, required=True,
                    help="Name of HDF output file in which to store results")
# Not marked as required: a required option never uses its default, so the
# default of 10 would otherwise be dead. Passing the option still works.
parser.add_argument("--number-missed", type=int, default=10,
                    help="Number of the loudest missed injections to compare"
                    " between runs")
parser.add_argument('--ifar-threshold', type=float, default=None,
                    help="If given, also followup injections with ifar smaller "
                         "than this threshold.")
parser.add_argument("--single-detector-statistic", type=str, default='newsnr',
                    choices=ranking.sngls_ranking_function_dict.keys(),
                    help="Which single-detector statistic to calculate for"
                    " found injections")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Print extra debugging information")

args = parser.parse_args()
pycbc.init_logging(args.verbose)

ref_found_injs = RunInjectionResults(args.injection_label, args.reference_dir,
                                     args.single_detector_statistic,
                                     args.number_missed, args.ifar_threshold)
com_found_injs = RunInjectionResults(args.injection_label, args.comparison_dir,
                                     args.single_detector_statistic,
                                     args.number_missed, args.ifar_threshold)

try:
    logging.info("Comparing injection parameters between reference and"
                 " comparison runs")
    same_inj = (ref_found_injs == com_found_injs)
    if not same_inj:
        raise RuntimeError("Reference and comparison runs did not perform the"
                           " same injections")

    # The output file is only created once the runs are known to be
    # comparable; the context manager closes it even if a later step fails.
    with File(args.output_file, "w") as outfp:
        ref_found_injs.injfile.copy('injections', outfp)

        outfp.attrs['injection_label'] = args.injection_label
        outfp.attrs['reference_dir'] = args.reference_dir
        outfp.attrs['comparison_dir'] = args.comparison_dir
        outfp.attrs['single_detector_statistic'] = \
            args.single_detector_statistic
        outfp.attrs['detector_1'] = ref_found_injs.inj_dets[0]
        outfp.attrs['detector_2'] = ref_found_injs.inj_dets[1]
        outfp.attrs['number_missed'] = args.number_missed
        # None cannot be stored as an HDF5 attribute, so the attribute is
        # only written when a threshold was actually given.
        if args.ifar_threshold is not None:
            outfp.attrs['ifar_threshold'] = args.ifar_threshold

        compare_found_injs(ref_found_injs, com_found_injs, outfp)
        compare_missed_injs(ref_found_injs, com_found_injs, outfp)
finally:
    # Release the input files of both runs whatever happened above
    ref_found_injs.close()
    com_found_injs.close()

logging.info("Finished")
