#!/usr/bin/env python

"""Fast and simple sky localization for a specific compact binary merger event.
Runs a single-template matched filter on strain data from a number of detectors
and calls BAYESTAR to produce a sky localization from the resulting set of SNR
time series."""

import os
import argparse
import logging
import subprocess
import tempfile
import shutil
import numpy as np
import h5py
import pycbc
from pycbc.filter import sigmasq
from pycbc.io import live, WaveformArray
from pycbc.types import (TimeSeries, FrequencySeries, load_timeseries,
                         load_frequencyseries, MultiDetMultiColonOptionAction,
                         MultiDetOptionAction, MultiDetOptionAppendAction)
from pycbc.pnutils import nearest_larger_binary_number
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc import frame, waveform, dq
from pycbc.psd import interpolate
from ligo.gracedb.rest import GraceDb


def default_frame_type(time, ifo):
    """Pick a default frame type from the detector name and GPS time.

    The GPS time selects one of three epochs (labelled O1, O2 and O3 in the
    comments below); the detector name then selects the frame type used
    for that epoch.

    Parameters
    ----------
    time : float
        GPS time of interest.
    ifo : str
        Detector name, e.g. 'H1', 'L1' or 'V1'.

    Returns
    -------
    str
        The default frame type.

    Raises
    ------
    ValueError
        If no default is defined for this detector at this time.
    """
    is_ligo = ifo in ('H1', 'L1')
    is_virgo = ifo == 'V1'

    frame_type = None
    if time < 1137254517:
        # O1: defaults exist for the two LIGO detectors only
        if is_ligo:
            frame_type = ifo + '_HOFT_C02'
    elif 1164556717 <= time < 1235433618:
        # O2
        if is_virgo:
            frame_type = 'V1O2Repro2A'
        elif is_ligo:
            frame_type = ifo + '_CLEANED_HOFT_C02'
    elif time >= 1235433618:
        # O3
        if is_virgo:
            frame_type = 'V1Online'
        elif is_ligo:
            frame_type = ifo + '_HOFT_CLEAN_SUB60HZ_C01'

    if frame_type is None:
        # unknown detector, or a time falling between the epochs above
        raise ValueError(
            'Detector {} not supported at time {}'.format(ifo, time))
    return frame_type

def default_channel_name(time, ifo):
    """Sensible defaults for channel name based on interferometer and time.

    The epoch boundaries are the same as in `default_frame_type`, so that
    any time for which a default frame type exists also has a default
    channel name.

    Parameters
    ----------
    time : float
        GPS time of interest.
    ifo : str
        Detector name, e.g. 'H1', 'L1' or 'V1'.

    Returns
    -------
    str
        The default strain channel name, prefixed with the detector name.

    Raises
    ------
    ValueError
        If no default is defined for this detector at this time.
    """
    if time < 1137254517:
        # O1
        if ifo in ['H1', 'L1']:
            return ifo + ':DCS-CALIB_STRAIN_C02'
    elif time >= 1164556717 and time < 1235433618:
        # O2 (lower bound is inclusive, matching default_frame_type)
        if ifo == 'V1':
            return ifo + ':Hrec_hoft_V1O2Repro2A_16384Hz'
        elif ifo in ['H1', 'L1']:
            return ifo + ':DCH-CLEAN_STRAIN_C02'
    elif time >= 1235433618:
        # O3
        if ifo == 'V1':
            return ifo + ':Hrec_hoft_16384Hz'
        elif ifo in ['H1', 'L1']:
            return ifo + ':DCS-CALIB_STRAIN_CLEAN_SUB60HZ_C01'
    raise ValueError('Detector {} not supported at time {}'.format(ifo, time))

def _per_ifo_option(option, ifo):
    """Return the value of a per-detector option for `ifo`, or None.

    `option` may be None (the option was not given at all) or a mapping
    keyed by detector name; a missing key is treated as "not set".
    """
    if option is None:
        return None
    try:
        return option[ifo]
    except KeyError:
        return None


def main(trig_time, mass1, mass2, spin1z, spin2z, f_low, f_upper, sample_rate,
         ifar, ifos, thresh_SNR, ligolw_skymap_output='.',
         ligolw_event_output=None, snr_series_duration=0.1465,
         frame_types=None, channel_names=None,
         segment_source=None, segment_server=None,
         gracedb_server=None, test_event=True,
         custom_frame_files=None, approximant=None, detector_state=None,
         veto_definer=None, injection_file=None, fake_strain=None,
         fake_strain_seed=None, rescale_loglikelihood=None):
    """Produce a BAYESTAR sky map for a single compact binary merger event.

    Runs one `pycbc_single_template` process per detector, collects the
    resulting SNR time series, PSDs and template normalizations into a
    coincident-event document, optionally uploads it to GraceDB, then runs
    `bayestar-localize-coincs` and `ligo-skymap-plot` on it.

    All intermediate files live in a temporary directory, which is removed
    only when the function completes; on error it is left behind (the
    error messages refer to log files stored there).

    Parameters
    ----------
    trig_time : list of str
        Either a single GPS time (approximate mode: the peak of each
        detector's SNR time series is used), or per-detector entries of
        the form 'IFO:gps_time' (exact mode: the SNR sample at the given
        time is used, and detectors not listed are treated as
        subthreshold).
    mass1, mass2, spin1z, spin2z : float
        Template parameters passed to `pycbc_single_template`.
    f_low : float
        Low-frequency cutoff of the matched filter.
    f_upper : float or None
        High-frequency cutoff of the matched filter, if any.
    sample_rate : int
        Sample rate passed to `pycbc_single_template`.
    ifar : float
        Stored as the event's 'foreground/ifar'.
    ifos : list of str
        Detectors to analyze.
    thresh_SNR : float
        In approximate mode, detectors whose peak SNR is below this value
        are treated as subthreshold.
    ligolw_skymap_output : str
        Directory receiving the FITS sky map and its PNG plot.
    ligolw_event_output : str or None
        If given, the coinc XML file is moved to this path.
    snr_series_duration : float
        Half-width, in seconds, of the SNR time series snippet kept
        around each detector's end time.
    frame_types, channel_names : dict or None
        Per-detector overrides of the default frame types and channel
        names. The given dicts are not modified.
    segment_source, segment_server, veto_definer :
        Passed to `dq.query_str` when `detector_state` is used.
    gracedb_server : str or None
        GraceDB URL; if given, the event, sky map and plot are uploaded.
    test_event : bool
        Passed as `testing` to the GraceDB upload.
    custom_frame_files : mapping or None
        Per-detector lists of local frame files to use instead of
        looking up frames by frame type.
    approximant : list of str, str or None
        Approximant specification(s); None selects
        ["SPAtmplt:mtotal<4", "SEOBNRv4_ROM:else"].
    detector_state : mapping or None
        Per-detector flags; a detector is dropped when the queried
        segments do not cover the whole padded analysis span.
    injection_file : str or None
        Optional injection file passed to `pycbc_single_template`.
    fake_strain, fake_strain_seed : mapping or None
        Per-detector fake strain models and seeds.
    rescale_loglikelihood :
        If not None, passed to BAYESTAR as `--rescale-loglikelihood`.

    Raises
    ------
    RuntimeError
        If a non-test event is requested without a GraceDB URL, if all
        detectors are excluded or subthreshold, if custom frame files do
        not cover the required data, if a `pycbc_single_template`
        process fails, or if BAYESTAR does not produce a sky map.
    ValueError
        If a default frame type or channel name is needed but not
        available for a detector at the trigger time.
    """

    if not test_event and not gracedb_server:
        raise RuntimeError('a GraceDB URL must be specified if not a test event.')

    if approximant is None:
        # same default as the command-line interface
        approximant = ["SPAtmplt:mtotal<4", "SEOBNRv4_ROM:else"]
    elif isinstance(approximant, str):
        # a bare string would otherwise be split into single characters
        approximant = [approximant]

    tmpdir = tempfile.mkdtemp()

    if len(trig_time) == 1 and ':' not in trig_time[0]:
        # single approximate time given
        mean_trig_time = float(trig_time[0])
        trig_time_mode = 'approximate'
    else:
        # precise per-detector times given
        trig_time = {kv.split(':')[0]: float(kv.split(':')[1])
                     for kv in trig_time}
        mean_trig_time = np.mean([trig_time[ifo] for ifo in trig_time])
        trig_time_mode = 'exact'

    # work on copies so the caller's dictionaries are left untouched
    frame_types = dict(frame_types) if frame_types is not None else {}
    channel_names = dict(channel_names) if channel_names is not None else {}
    for ifo in ifos:
        ifo_fake_strain = _per_ifo_option(fake_strain, ifo)
        has_custom_frames = (custom_frame_files is not None
                             and ifo in custom_frame_files)
        # a frame type is only needed when real data is looked up by type;
        # do not fail on a missing default otherwise
        if ifo not in frame_types and ifo_fake_strain is None \
                and not has_custom_frames:
            frame_types[ifo] = default_frame_type(mean_trig_time, ifo)
        if ifo not in channel_names:
            if ifo_fake_strain is not None:
                channel_names[ifo] = ifo + ':FAKE_DATA'
            else:
                channel_names[ifo] = default_channel_name(mean_trig_time, ifo)

    # resulting files will be tagged with this string
    file_name_tag = '{:.0f}'.format(mean_trig_time)

    # parameters to fit a single-template inspiral job nicely
    # around the trigger time, without requiring too much data

    pad_data = 8

    # Padding set by 16 * 2 for psd and buffer for other filtering
    pad = 40
    template_duration = spa_length_in_time(mass1=mass1, mass2=mass2,
                                           f_lower=f_low, phase_order=-1)
    segment_length = int(nearest_larger_binary_number(template_duration + pad))
    # set minimum so there is enough for a psd estimate
    if segment_length < 128:
        segment_length = 128
    logging.info('Using segment length: %s', segment_length)

    gps_end_time = int(mean_trig_time + pad / 2)
    gps_start_time = gps_end_time - segment_length
    logging.info("Using data: %s-%s", gps_start_time, gps_end_time)

    # if requested, remove detectors with missing/vetoed data
    # over the required segment
    required_duration = (gps_end_time + pad_data) - (gps_start_time - pad_data)
    unavailable_ifos = set()
    for ifo in ifos:
        if detector_state is None or ifo not in detector_state:
            continue
        on_segs = dq.query_str(ifo, detector_state[ifo],
                               gps_start_time - pad_data,
                               gps_end_time + pad_data,
                               veto_definer=veto_definer,
                               source=segment_source,
                               server=segment_server)
        if abs(on_segs) < required_duration:
            logging.info('Excluding %s due to missing or vetoed data', ifo)
            unavailable_ifos.add(ifo)
    ifos = sorted(set(ifos) - unavailable_ifos)
    if not ifos:
        raise RuntimeError('All detectors have been excluded due to '
                           'missing or vetoed data')

    highpass_frequency = int(f_low * 0.7)
    logging.info("Setting highpass: %s Hz", highpass_frequency)

    procs = []
    st_psd_paths = {}
    st_out_paths = {}
    for ifo in ifos:
        # compose the command line for the single-template process
        st_psd_paths[ifo] = os.path.join(
                tmpdir, 'PSD_{}_{}.txt'.format(file_name_tag, ifo))
        st_out_paths[ifo] = os.path.join(
                tmpdir, 'SNRTS_{}_{}.hdf'.format(file_name_tag, ifo))

        command = ["pycbc_single_template",
        "--verbose",
        "--segment-length", str(segment_length),
        "--segment-start-pad", "0",
        "--segment-end-pad", "0",
        "--psd-estimation", "median",
        "--psd-segment-length", "16",
        "--psd-segment-stride", "8",
        "--psd-inverse-length", "16",
        "--order", "-1",
        "--taper-data", "1",
        "--allow-zero-padding",
        "--autogating-threshold", "100",
        "--autogating-cluster", "0.5",
        "--autogating-width", "0.25",
        "--autogating-taper", "0.25",
        "--autogating-pad", "16",
        "--strain-high-pass", str(highpass_frequency),
        "--pad-data", str(pad_data),
        "--chisq-bins", '0.72*get_freq("fSEOBNRv4Peak",params.mass1,params.mass2,params.spin1z,params.spin2z)**0.7',
        "--sample-rate", str(sample_rate),
        "--mass1", str(mass1),
        "--mass2", str(mass2),
        "--spin1z", str(spin1z),
        "--spin2z", str(spin2z),
        "--low-frequency-cutoff", str(f_low),
        "--gps-start-time", str(gps_start_time),
        "--gps-end-time", str(gps_end_time),
        "--trigger-time", '{:.6f}'.format(trig_time[ifo] if ifo in trig_time
                                          else mean_trig_time),
        "--window", '1',
        "--channel-name", channel_names[ifo],
        "--psd-output", st_psd_paths[ifo],
        "--output-file", st_out_paths[ifo]]

        command.append("--approximant")
        for apx in approximant:
            command.append(apx)
        if f_upper:
            command.append("--high-frequency-cutoff")
            command.append(str(f_upper))

        if injection_file:
            command.append("--injection-file")
            command.append(injection_file)

        ifo_fake_strain = _per_ifo_option(fake_strain, ifo)
        if ifo_fake_strain is not None:
            # use fake strain for this ifo
            command.append("--fake-strain")
            command.append(ifo_fake_strain)
            ifo_fake_strain_seed = _per_ifo_option(fake_strain_seed, ifo)
            if ifo_fake_strain_seed is not None:
                command.append("--fake-strain-seed")
                command.append(str(ifo_fake_strain_seed))
        elif custom_frame_files is None or ifo not in custom_frame_files:
            # use default guesses for this ifo
            command.append("--frame-type")
            command.append(frame_types[ifo])
        else:
            # use custom frame files for this ifo
            logging.info("Check if the segment in the custom frame file is safe")
            fr_start_times = []
            fr_end_times = []
            for custom_frame in custom_frame_files[ifo]:
                try:
                    frame_data = frame.read_frame(custom_frame, channel_names[ifo])
                except RuntimeError:
                    msg = 'Channel name in {} is not {}'.format(
                            custom_frame, channel_names[ifo])
                    raise RuntimeError(msg)
                fr_start_times.append(frame_data.start_time)
                fr_end_times.append(frame_data.end_time)
            if gps_start_time < min(fr_start_times):
                msg = 'Start time of {} must be before the required start time {}'
                msg = msg.format(min(fr_start_times), gps_start_time)
                raise RuntimeError(msg)
            if max(fr_end_times) < gps_end_time:
                msg = 'End time of {} must be after the required end time {}'
                msg = msg.format(max(fr_end_times), gps_end_time)
                raise RuntimeError(msg)
            if mean_trig_time < min(fr_start_times) \
                    or max(fr_end_times) < mean_trig_time:
                msg = 'Trigger time must be within your frame file(s)'
                raise RuntimeError(msg)

            command.append("--frame-files")
            for custom_frame in custom_frame_files[ifo]:
                command.append(custom_frame)

        # create and open a file to record the
        # single-template process' stdout and stderr
        log_path = os.path.join(
                tmpdir,
                'pycbc_single_template_{}_{}_log.txt'.format(file_name_tag, ifo))
        log_file = open(log_path, 'w')
        log_file.write(' '.join(command) + '\n\n')
        log_file.flush()

        # start the single-template process; all detectors run in parallel
        proc = subprocess.Popen(command, stdout=log_file, stderr=log_file)
        procs.append((ifo, proc, log_path, log_file))

    logging.info('Waiting for pycbc_single_template to complete')

    # wait for every process before failing, so all failures get reported
    snr_errors = False
    for ifo, proc, log_path, log_file in procs:
        proc.wait()
        if proc.returncode != 0:
            logging.error('%s pycbc_single_template failed, see %s',
                          ifo, log_path)
            snr_errors = True
        log_file.close()
    if snr_errors:
        raise RuntimeError('one or more pycbc_single_template failed, '
                           'please see the messages above for details.')

    logging.info('Gathering info from pycbc_single_template results')

    coinc_results = {}
    subthreshold_ifos = []
    network_snr2 = 0.
    for ifo in ifos:
        snr_series = load_timeseries(st_out_paths[ifo], group='snr')
        chisq_series = load_timeseries(st_out_paths[ifo], group='chisq')
        psd_series = load_frequencyseries(st_psd_paths[ifo])
        psd_series *= pycbc.DYN_RANGE_FAC ** 2.0
        psd_series = psd_series.astype(np.float32)
        template_series = load_frequencyseries(st_out_paths[ifo],
                                               group='template')

        key = 'foreground/{}/'.format(ifo)
        coinc_results[key + 'mass1'] = mass1
        coinc_results[key + 'mass2'] = mass2
        coinc_results[key + 'spin1z'] = spin1z
        coinc_results[key + 'spin2z'] = spin2z
        coinc_results[key + 'f_lower'] = f_low
        if f_upper:
            coinc_results[key + 'f_final'] = f_upper
        # required by SingleCoincForGraceDB
        coinc_results[key + 'template_id'] = -1
        coinc_results[key + 'snr_series'] = snr_series
        coinc_results[key + 'psd_series'] = psd_series
        coinc_results[key + 'sigmasq'] = sigmasq(template_series, psd_series)

        if trig_time_mode == 'approximate':
            # find the absolute max SNR peak in the entire SNR time
            # series; it will be the trigger time in this detector
            # if above threshold
            # FIXME there is no guarantee that the light travel times
            # between the peaks in different detectors will be physical.
            # Any good ideas for how to enforce this in a simple way?
            snr_series_peak = np.argmax(abs(snr_series))
            snr = abs(snr_series[snr_series_peak])
            logging.info('%s peak SNR: %.2f', ifo, snr)
            if snr < thresh_SNR:
                subthreshold_ifos.append(ifo)
                continue
        elif trig_time_mode == 'exact':
            if ifo not in trig_time:
                # treat this detector as subthreshold
                subthreshold_ifos.append(ifo)
                continue
            snr_series_peak = int((trig_time[ifo] - snr_series.start_time) \
                    / snr_series.delta_t)
            # NOTE(review): these checks guard against a requested time
            # outside the SNR time series, but are skipped under `python -O`
            assert snr_series_peak >= 0
            assert snr_series_peak < len(snr_series)
            snr = abs(snr_series[snr_series_peak])
            logging.info('%s SNR at given time: %.2f', ifo, snr)

        coinc_results[key + 'end_time'] = \
                float(snr_series.sample_times[snr_series_peak])
        coinc_results[key + 'snr'] = snr
        coinc_results[key + 'coa_phase'] = np.angle(snr_series[snr_series_peak])
        coinc_results[key + 'chisq'] = chisq_series[snr_series_peak]

        network_snr2 += snr ** 2

    ifos = sorted(set(ifos) - set(subthreshold_ifos))
    if not ifos:
        raise RuntimeError('all interferometers have SNR below threshold.'
                           ' Is this really a candidate event?')

    coinc_results['foreground/stat'] = network_snr2 ** 0.5
    coinc_results['foreground/ifar'] = ifar

    # BAYESTAR's convention for end time of subthreshold detectors
    subthreshold_sngl_time = np.mean(
            [coinc_results['foreground/%s/end_time' % ifo] for ifo in ifos])

    window_bins = int(snr_series_duration * sample_rate)
    followup_data = {}
    for ifo in ifos + subthreshold_ifos:
        key = 'foreground/{}/'.format(ifo)

        if ifo in subthreshold_ifos:
            coinc_results[key + 'end_time'] = subthreshold_sngl_time

        followup_data[ifo] = {
            'psd': interpolate(coinc_results[key + 'psd_series'], 0.25)
        }

        # now go back to the SNR time series and cut out
        # a shorter piece centered on the end time
        # NOTE(review): window_bins is not reset between detectors, so a
        # window clipped for one detector also shortens the window used
        # for the following ones; confirm this is intended
        time = coinc_results[key + 'end_time']
        snr_series = coinc_results[key + 'snr_series']
        peak_bin = int((time - snr_series.start_time) / snr_series.delta_t)
        max_bin = peak_bin + window_bins + 1
        if max_bin > len(snr_series):
            window_bins = len(snr_series) - peak_bin
            max_bin = len(snr_series)
            min_bin = peak_bin - window_bins + 1
        else:
            min_bin = peak_bin - window_bins
        if min_bin < 0:
            window_bins = peak_bin
            max_bin = peak_bin + window_bins + 1
            min_bin = peak_bin - window_bins

        followup_data[ifo]['snr_series'] = \
                snr_series[min_bin:max_bin].astype(np.complex64)

    kwargs = {'psds': {ifo: followup_data[ifo]['psd']
                       for ifo in ifos + subthreshold_ifos},
              'low_frequency_cutoff': f_low,
              'high_frequency_cutoff': f_upper,
              'followup_data': followup_data,
              'channel_names': channel_names}

    doc = live.SingleCoincForGraceDB(ifos, coinc_results, **kwargs)

    ligolw_file_path = os.path.join(tmpdir, file_name_tag + '.xml')

    if gracedb_server:
        comments = ['Manual followup up from PyCBC']
        gid = doc.upload(ligolw_file_path,
                         gracedb_server=gracedb_server,
                         testing=test_event,
                         extra_strings=comments)
        gracedb = GraceDb(gracedb_server)
    else:
        doc.save(ligolw_file_path)

    # Defining the approximant to feed into BAYESTAR

    row = WaveformArray.from_kwargs(
        mass1=mass1,
        mass2=mass2,
        spin1z=spin1z,
        spin2z=spin2z)

    bs_approximant = waveform.bank.parse_approximant_arg(approximant, row)[0]

    # BAYESTAR uses TaylorF2 instead of SPAtmplt
    if bs_approximant == 'SPAtmplt':
        bs_approximant = 'TaylorF2'

    # run BAYESTAR to generate the skymap
    cmd = ['bayestar-localize-coincs',
           ligolw_file_path,
           ligolw_file_path,
           '--waveform', str(bs_approximant),
           '--f-low', str(f_low),
           '-o', tmpdir]
    if rescale_loglikelihood is not None:
        # command-line arguments must be strings
        cmd += ['--rescale-loglikelihood', str(rescale_loglikelihood)]
    bayestar_status = subprocess.call(cmd)

    skymap_fits_name = os.path.join(tmpdir, '0.fits')
    if not os.path.isfile(skymap_fits_name):
        # fail here with a clear message rather than later, when the
        # missing file would be plotted and moved
        raise RuntimeError(
            'bayestar-localize-coincs did not produce {} '
            '(exit status {})'.format(skymap_fits_name, bayestar_status))

    # plot the skymap

    skymap_plot_path = os.path.join(ligolw_skymap_output,
                                    file_name_tag + '_skymap.png')
    cmd = ['ligo-skymap-plot',
           skymap_fits_name,
           '-o', skymap_plot_path,
           '--contour', '50', '90',
           '--annotate']
    subprocess.call(cmd)

    final_fits_path = os.path.join(ligolw_skymap_output,
                                   file_name_tag + '.fits')
    shutil.move(skymap_fits_name, final_fits_path)

    if gracedb_server:
        gracedb.writeLog(gid, 'Bayestar skymap FITS file upload',
                         filename=final_fits_path,
                         tag_name=['sky_loc'],
                         displayName=['Bayestar FITS skymap'])
        gracedb.writeLog(gid, 'Bayestar skymap plot upload',
                         filename=skymap_plot_path,
                         tag_name=['sky_loc'],
                         displayName=['Bayestar skymap plot'])

    if ligolw_event_output:
        shutil.move(ligolw_file_path, ligolw_event_output)
    shutil.rmtree(tmpdir)

# Command-line entry point: parse the options and hand them over to main()
if __name__ == '__main__':
    pycbc.init_logging(True)

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--version', action=pycbc.version.Version)
    # note that I am not using a MultiDetOptionAction for --trig-time as I
    # explicitly want to handle cases like `--trig-time 1234` and
    # `--trig-time H1:1234 L1:1234` in different ways
    parser.add_argument('--trig-time', required=True, nargs='+',
                        help='GPS time of trigger. Can be a single value, or a list of per-detector times '
                             'using the format H1:123456 L1:123457. In the first case, the given time is only '
                             'an approximate guess of the trigger time, and '
                             'the code will try to find the maximum-SNR peaks '
                             'at each detector on its own. The given time should be within 1 s '
                             'of the peak. This procedure becomes unreliable '
                             'for low-SNR triggers. When times are given at specific detectors, instead, they '
                             'literally indicate which sample of the '
                             'SNR time series should be taken as peak SNR at '
                             'the given detector. Unspecified detectors will '
                             'be considered subthreshold and their time will '
                             'be chosen as the mean of the given times.')
    parser.add_argument('--mass1', type=float, required=True)
    parser.add_argument('--mass2', type=float, required=True)
    parser.add_argument('--spin1z', type=float, required=True)
    parser.add_argument('--spin2z', type=float, required=True)
    parser.add_argument('--approximant', type=str, nargs='+',
                        default=["SPAtmplt:mtotal<4", "SEOBNRv4_ROM:else"])
    parser.add_argument('--f-low', type=float,
                        help="lower frequency cut-off (float)", default=20.0)
    parser.add_argument('--f-upper', type=float,
                        help="upper frequency cut-off (float)")
    parser.add_argument('--sample-rate', type=int, default=2048,
                        help='sample rate of the data')
    parser.add_argument('--ifar', type=float,
                        help="inverse false alarm rate (float)",
                        default=1)
    parser.add_argument('--thresh-SNR', type=float, default=4.5,
                        help='When `--trig-time` is given a single (approximate) time, only '
                             'detectors with SNR above this threshold are used '
                             'to find the precise trigger times. The others are still '
                             'used for sky localization, but do not determine '
                             'the trigger time.')
    parser.add_argument('--ifos', type=str, required=True, nargs='+',
                        help='List of interferometer names, e.g. H1 L1')
    parser.add_argument('--frame-type', type=str, nargs='+')
    parser.add_argument('--channel-name', type=str, nargs='+')
    parser.add_argument('--detector-state', type=str, nargs='+',
                        action=MultiDetMultiColonOptionAction,
                        help='Use the segment database to exclude detectors '
                             'with missing or vetoed data')
    parser.add_argument('--veto-definer', type=str,
                        help='Optional path to a veto definer file to resolve '
                             'macro flags in --detector-state')
    parser.add_argument('--segment-source', choices=['any', 'GWOSC', 'dqsegdb'],
                        default='any',
                        help='When using `--detector-state`, this controls '
                             'whether to query the GWOSC or dqsegdb server')
    parser.add_argument('--segment-server', metavar='URL',
                        help='URL of segment server to query when using '
                             '`--detector-state`',
                        default='https://segments.ligo.org')
    parser.add_argument('--ligolw-skymap-output', type=str, default='.',
                        help='Option to output sky map files to directory')
    parser.add_argument('--ligolw-event-output', type=str,
                        help='Option to keep coinc file under given name')
    parser.add_argument('--enable-production-gracedb-upload',
                        action='store_true', default=False,
                        help='Do not mark triggers uploaded to GraceDB as test '
                             'events. This option should *only* be enabled in '
                             'production analyses!')
    parser.add_argument('--gracedb-server', metavar='URL',
                        help='URL of GraceDB server API for uploading events.')
    parser.add_argument('--custom-frame-file', type=str, nargs='+',
                        action=MultiDetOptionAppendAction,
                        help='Lists of local frame files, e.g., '
                             'H1:/path/to/frame/file L1:/path/to/frame/file')
    parser.add_argument('--injection-file', type=str,
                        help='Optional path to an injection file.')
    parser.add_argument('--fake-strain', type=str,
                        action=MultiDetOptionAction, metavar="IFO:FAKE_STRAIN",
                        help="The set of fake-strains for each detector")
    parser.add_argument('--fake-strain-seed', type=int,
                        action=MultiDetOptionAction, metavar="IFO:FAKE_STRAIN_SEED",
                        help="The set of fake-strains-seeds for each detector")
    parser.add_argument('--rescale-loglikelihood',
                        help="SNR rescaling factor used in BAYESTAR")

    opt = parser.parse_args()

    # `--frame-type H1:TYPE ...` becomes {'H1': 'TYPE', ...}
    frame_type_dict = {f.split(':')[0]: f.split(':')[1] for f in opt.frame_type} \
            if opt.frame_type is not None else None
    # `--channel-name H1:CHANNEL ...` becomes {'H1': 'H1:CHANNEL', ...},
    # i.e. the values keep the detector prefix
    chan_name_dict = {f.split(':')[0]: f for f in opt.channel_name} \
            if opt.channel_name is not None else None

    main(opt.trig_time, opt.mass1, opt.mass2,
         opt.spin1z, opt.spin2z, opt.f_low, opt.f_upper, opt.sample_rate,
         opt.ifar, opt.ifos, opt.thresh_SNR, opt.ligolw_skymap_output,
         opt.ligolw_event_output,
         frame_types=frame_type_dict, channel_names=chan_name_dict,
         gracedb_server=opt.gracedb_server,
         test_event=not opt.enable_production_gracedb_upload,
         custom_frame_files=opt.custom_frame_file,
         approximant=opt.approximant, detector_state=opt.detector_state,
         segment_source=opt.segment_source, segment_server=opt.segment_server,
         veto_definer=opt.veto_definer, injection_file=opt.injection_file,
         fake_strain=opt.fake_strain, fake_strain_seed=opt.fake_strain_seed,
         rescale_loglikelihood=opt.rescale_loglikelihood)
