#!/usr/bin/env python
"""
Compute source probabilities using mchirp estimation method for all events
in a chunk with an IFAR above certain threshold.
"""
import h5py
import json
import tqdm
import argparse
import logging
import numpy as np
import pycbc
from pycbc.io import hdf
from pycbc.pnutils import mass1_mass2_to_mchirp_eta
from pycbc import mchirp_area

# Command-line interface. The module docstring is reused as the --help
# description.
parser = argparse.ArgumentParser(description=__doc__)

# Input files.
parser.add_argument("--trigger-file", required=True)
parser.add_argument("--bank-file", required=True)
parser.add_argument("--single-detector-triggers", required=True, nargs="+")

# Output naming and event selection.
parser.add_argument(
    "--search-tag",
    required=True,
    help="String to add to the output file names identifying the search. "
         "eg: PYCBC_AllSky, PYCBC_HighMass",
)
parser.add_argument(
    "--ifar-threshold",
    type=float,
    default=None,
    help="Select only candidate events with IFAR above threshold.",
)
parser.add_argument(
    "--include-mass-gap",
    action="store_true",
    help="Option to include the Mass Gap region.",
)

# Generic options.
parser.add_argument("--verbose", action="count")
parser.add_argument(
    "--version", action="version", version=pycbc.version.git_verbose_msg
)

# mchirp_area is handed the parser before parsing (presumably so it can add
# its own options; see mchirp_area.insert_args).
mchirp_area.insert_args(parser)
args = parser.parse_args()

# Build the mchirp_area configuration from the parsed command-line options.
mc_area_args = mchirp_area.from_cli(args)

# When the Mass Gap region is not requested, the NS and Mass Gap upper
# boundaries must be the same. This is checked with an explicit exception
# rather than an ``assert`` so that the validation is still enforced when
# Python runs with -O (which strips assert statements).
if not args.include_mass_gap:
    mass_bdary = mc_area_args['mass_bdary']
    if mass_bdary['ns_max'] != mass_bdary['gap_max']:
        raise ValueError('NS/Mass Gap boundaries should be the same.')

pycbc.init_logging(args.verbose)

# Keep only the file name of the trigger file (text after the last '/').
TRIGGER_FILE = args.trigger_file.rsplit('/', 1)[-1]

# The third and fourth dash-separated fields of the file name label the
# chunk; the fourth is cut at its first dot.
_name_fields = TRIGGER_FILE.split('-')
GPS_START_TIME = _name_fields[2]
GPS_START_TIME_NS = _name_fields[3].split('.')[0]

# Report the input files by base name only.
_bank_name = args.bank_file.split('/')[-1]
_sngl_names = [path.split('/')[-1] for path in args.single_detector_triggers]
logging.info('Using files: %s, %s, %s', TRIGGER_FILE, _bank_name,
             ','.join(_sngl_names))

# One results directory per chunk, created through pycbc.makedir.
_chunk_label = (GPS_START_TIME, GPS_START_TIME_NS)
dir_path = 'source_probability_results/CHUNK_%s-%s/' % _chunk_label
pycbc.makedir(dir_path)
logging.info('Saving results in %s', dir_path)

# Open the candidate events together with the template bank and the
# single-detector trigger files.
fortrigs = hdf.ForegroundTriggers(args.trigger_file, args.bank_file,
                                  sngl_files=args.single_detector_triggers)

ifar = fortrigs.get_coincfile_array('ifar')
N_original = len(ifar)

# Compare against None (the argparse default) rather than relying on
# truthiness, so that an explicit threshold of 0 is still applied.
if args.ifar_threshold is not None:
    # Boolean mask over the full event list: True where IFAR > threshold.
    idx = ifar > args.ifar_threshold
    ifar = ifar[idx]
    logging.info('%i triggers out of %i with IFAR > %s',
                 len(ifar), N_original, args.ifar_threshold)
else:
    # No threshold given: keep every event.
    idx = np.full(N_original, True)

# Per-event quantities restricted to the selected events.
mass1 = fortrigs.get_bankfile_array('mass1')[idx]
mass2 = fortrigs.get_bankfile_array('mass2')[idx]
mchirp, _ = mass1_mass2_to_mchirp_eta(mass1, mass2)
end_time = fortrigs.get_end_time()[idx]
# Unlike the arrays above, these per-detector dictionaries are NOT masked
# by idx.
sngl_snr = fortrigs.get_snglfile_array_dict('snr')
sngl_sigmasq = fortrigs.get_snglfile_array_dict('sigmasq')

# Positions, in the unfiltered event list, of the events selected by idx.
# mchirp and end_time were masked by idx, but sngl_snr and sngl_sigmasq were
# not, so the latter must be indexed with the original position. Indexing
# them with the filtered position only gives the right event when the
# selected events happen to be the leading entries of the full list.
selected = np.flatnonzero(idx)

for event in tqdm.trange(len(ifar)):
    full_pos = selected[event]
    # sngl_snr[ifo] contains 2 arrays: first contains SNR values and
    # second (True/False) boolean values. If boolean is False, the ifo was
    # not present in the coincidence
    ifos_event = [ifo for ifo in sngl_snr if sngl_snr[ifo][1][full_pos]]
    snrs_event = [sngl_snr[ifo][0][full_pos] for ifo in ifos_event]
    # Quadrature sum of the SNRs of the participating detectors.
    coinc_snr = sum(snr ** 2 for snr in snrs_event) ** 0.5
    # Smallest sqrt(sigmasq) / SNR over the participating detectors.
    min_eff_dist = min(sngl_sigmasq[ifo][0][full_pos] ** 0.5 / snr
                       for ifo, snr in zip(ifos_event, snrs_event))
    probs = mchirp_area.calc_probabilities(mchirp[event], coinc_snr,
                                           min_eff_dist, mc_area_args)
    if not args.include_mass_gap:
        probs.pop("Mass Gap")
    # One JSON file per event, named after the sorted detector names, the
    # search tag and the integer part of the event end time.
    ifo_names = ''.join(sorted(ifos_event))
    out_name = dir_path + '%s-%s-%i-1.json' % (ifo_names, args.search_tag,
                                               int(end_time[event]))
    with open(out_name, 'w') as outfile:
        json.dump(probs, outfile)
