#!/bin/env python

"""
Create a file containing the time, phase and amplitude correlations between two
or more detectors for signals, estimated by running a simple Monte Carlo
simulation.

Output is the relative amplitude, time, and phase as compared to a reference
IFO. A separate calculation is done for each IFO as a possible reference.  The
data is stored as two vectors: one vector gives the discrete integer bin
corresponding to a particular location in 3*(Nifo-1)-dimensional
amplitude/time/phase space, the other gives the weight assigned to that bin.
To get the signal rate this weight should be scaled by the local sensitivity
value and by the SNR of the event in the reference detector.
"""

import argparse, h5py, numpy as np, pycbc.detector, logging
from numpy.random import uniform
from scipy.stats import norm
from copy import deepcopy


# Command-line interface. The module docstring doubles as the --help
# description.
parser = argparse.ArgumentParser(description=__doc__)
# --ifos and --relative-sensitivities are both mandatory: without them the
# consistency check below would fail with an opaque TypeError on len(None).
parser.add_argument('--ifos', nargs='+', required=True,
                    help="The ifos to generate a histogram for")
parser.add_argument('--sample-size', type=int, required=True,
                    help="Approximate number of independent samples to draw "
                         "for the distribution")
parser.add_argument('--snr-ratio', type=float, required=True,
                    help="The SNR ratio permitted between reference ifo and "
                         "all others. Ex. giving 4 permits a ratio of "
                         "0.25 -> 4")
parser.add_argument('--relative-sensitivities', nargs='+', type=float,
                    required=True,
                    help="Numbers proportional to horizon distance or "
                         "expected SNR at fixed distance, one for each ifo")
parser.add_argument('--seed', type=int, default=124)
parser.add_argument('--output-file', required=True)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--bin-density', type=int, default=1,
                    help="Number of bins per 1 sigma uncertainty in a "
                         "parameter. Higher values increase the resolution of "
                         "the histogram at the expense of storage.")
parser.add_argument('--smoothing-sigma', type=int, default=2,
                    help="Width of the smoothing kernel in sigmas")
parser.add_argument('--timing-uncertainty', type=float, default=.001,
                    help="Timing uncertainty to set bin size and smoothing "
                         "interval [default=.001s]")
parser.add_argument('--phase-uncertainty', type=float, default=0.25,
                    help="Phase uncertainty used to set bin size and "
                         "smoothing")
parser.add_argument('--snr-reference', type=float, default=5,
                    help="Reference SNR to scale SNR uncertainty")
parser.add_argument('--snr-uncertainty', type=float, default=1.0,
                    help="SNR uncertainty to set bin size and smoothing")
parser.add_argument('--weight-threshold', type=float, default=1e-10,
                    help="Minimum histogram weight to store as a proportion "
                         "of the histogram maximum: bins with less weight "
                         "will not be stored.")
parser.add_argument('--batch-size', type=int, default=1000000)
parser.add_argument('--param-bin-dtype', default='int32',
                    help="Type to use to store param_bin information. "
                         "Affects the maximum values that can be taken by "
                         "bin parameters. Default int32.",
                    choices=['int8', 'int16', 'int32'])
args = parser.parse_args()

# The histogram is built from differences between a reference ifo and every
# other ifo, so a single detector would leave nothing to histogram.
if len(args.ifos) < 2:
    parser.error('--ifos requires at least two detectors')

if len(args.relative_sensitivities) != len(args.ifos):
    parser.error('--relative-sensitivities requires one numerical argument '
                 'for each detector')

# Width of one time-difference bin, in the same units as
# --timing-uncertainty (seconds): the approximate timing error at lower SNRs
# divided into bin_density bins
twidth = args.timing_uncertainty / args.bin_density
# Factor of sqrt(2) as we'll be combining two SNRs with independent
# uncertainties
serr = args.snr_uncertainty * 2 ** 0.5
# Reference SNR for bin smoothing
sref = args.snr_reference
# Width of one SNR-ratio bin: the combined SNR error relative to the
# reference SNR, divided into bin_density bins
swidth = serr / sref / args.bin_density
# Approximate phase error at lower SNRs, divided into bin_density bins to
# give the width of one phase-difference bin
pwidth = np.arctan(serr / sref) / args.bin_density

# Largest and smallest SNR-ratio bin indices that are kept, corresponding to
# ratios of --snr-ratio and 1 / --snr-ratio (int() truncates towards zero)
srbmax = int(args.snr_ratio / swidth)
srbmin = int((1.0 / args.snr_ratio) / swidth)

# Apply a simple smoothing to help account for measurement errors for
# weak signals
def smooth_param(data, index, wrapped=None, smoothing_sigma=None,
                 bin_density=None, weight_threshold=None):
    """Smooth a sparse histogram along one axis with a Gaussian kernel.

    Every sufficiently heavy bin spreads to the bins within
    ``smoothing_sigma * bin_density`` steps of it along axis ``index``. The
    new weight of each touched bin is the kernel-weighted sum of the
    original weights of its neighbours along that axis.

    Parameters
    ----------
    data : dict
        Maps a tuple of integer bin indices to that bin's weight.
    index : int
        Position within the key tuple of the parameter to smooth.
    wrapped : float, optional
        Period, in bins, of a cyclic parameter (e.g. phase). If given,
        shifted indices are folded back with ``floor(i % wrapped)``.
    smoothing_sigma : int, optional
        Kernel half-width in sigmas. Defaults to ``args.smoothing_sigma``.
    bin_density : int, optional
        Number of bins per sigma; also the standard deviation of the
        kernel in bins. Defaults to ``args.bin_density``.
    weight_threshold : float, optional
        Bins lighter than this fraction of the maximum weight are not
        spread (they still contribute to their neighbours). Defaults to
        ``args.weight_threshold``.

    Returns
    -------
    dict
        New mapping from bin-index tuples to smoothed weights. The index at
        ``index`` is a plain Python int. An empty ``data`` gives an empty
        dict.
    """
    # Nothing to smooth; also avoids max() failing on an empty sequence
    if not data:
        return {}

    # Fall back to the command-line configuration for anything not given
    if smoothing_sigma is None:
        smoothing_sigma = args.smoothing_sigma
    if bin_density is None:
        bin_density = args.bin_density
    if weight_threshold is None:
        weight_threshold = args.weight_threshold

    half_width = smoothing_sigma * bin_density
    bins = np.arange(-half_width, half_width + 1)
    kernel = norm.pdf(bins, scale=bin_density)
    offsets = [int(b) for b in bins]

    def shifted(key, offset):
        # Move `key` by `offset` bins along the smoothed axis, putting a
        # wrapped parameter back into its allowed range
        moved = key[index] + offset
        if wrapped:
            moved = np.floor(moved % wrapped)
        new_key = list(key)
        new_key[index] = int(moved)
        return tuple(new_key)

    mref = max(data.values())
    nweights = {}

    # Neighbouring source bins reach many of the same target bins; each
    # target is evaluated once and then skipped.
    for key, value in data.items():
        if value / mref < weight_threshold:
            continue

        for a in offsets:
            tnkey = shifted(key, a)
            if tnkey in nweights:
                continue

            weight = 0
            for b, w in zip(offsets, kernel):
                neighbour = data.get(shifted(tnkey, b))
                if neighbour is not None:
                    weight += neighbour * w

            nweights[tnkey] = weight
    return nweights

# One detector object per requested ifo, keyed by ifo name
d = dict(zip(args.ifos, map(pycbc.detector.Detector, args.ifos)))

pycbc.init_logging(args.verbose)

# Seed the global numpy generator so runs are reproducible, then work out
# how many batches of `size` samples to draw
np.random.seed(args.seed)
size = args.batch_size
chunks = 1 + int(args.sample_size / size)

# Largest time difference that fits in the chosen integer type: the number
# of distinct values it can hold times the width of one time bin
bin_dtype = args.param_bin_dtype
dtype_limits = np.iinfo(bin_dtype)
max_td = twidth * (dtype_limits.max - dtype_limits.min + 1)

if max_td < 0.0425:
    raise RuntimeError('Max allowed time difference is less than the earth '
                       'travel time. Some observatories may be further apart '
                       'than this.')

# Store results for each ifo as a reference. The reference ifo is the
# ifo which gets the smallest amplitude. This allows us to get the correct
# symmetries handled under ifo switch and apply more consistent treatment
# of error uncertainties.
#
# The output file is opened in a context manager so that it is always
# closed, including when an error interrupts the calculation.
with h5py.File(args.output_file, 'w') as f:
    for ifo0 in args.ifos:
        logging.info('Storing results using %s as a reference', ifo0)
        other_ifos = deepcopy(args.ifos)
        other_ifos.remove(ifo0)

        nbins = 0
        nsamples = 0
        # Maps a tuple of (dt, dphase, snr-ratio) bin indices, three per
        # non-reference ifo, to the accumulated weight of that bin
        weights = {}
        for _ in range(chunks):
            nsamples += size
            logging.info('generating %s samples', size)

            # Choose random sky location and polarizations from
            # an isotropic population
            ra = uniform(0, 2 * np.pi, size=size)
            dec = np.arccos(uniform(-1., 1., size=size)) - np.pi/2
            inc = np.arccos(uniform(-1., 1., size=size))
            pol = uniform(0, 2 * np.pi, size=size)
            ic = np.cos(inc)
            ip = 0.5 * (1.0 + ic * ic)

            # calculate the toa, poa, and amplitude of each sample
            data = {}
            for rs, ifo in zip(args.relative_sensitivities, args.ifos):
                data[ifo] = {}
                fp, fc = d[ifo].antenna_pattern(ra, dec, pol, 0)
                sp, sc = fp * ip, fc * ic
                data[ifo]['s'] = (sp ** 2. + sc ** 2.) ** 0.5 * rs
                data[ifo]['t'] = d[ifo].time_delay_from_earth_center(ra, dec,
                                                                     0)
                data[ifo]['p'] = np.arctan2(sc, sp)

            # Bin the data
            bind = []
            keep = None
            for ifo1 in other_ifos:
                dt = (data[ifo0]['t'] - data[ifo1]['t'])
                dp = (data[ifo0]['p'] - data[ifo1]['p']) % (2. * np.pi)
                sr = (data[ifo1]['s'] / data[ifo0]['s'])
                dtbin = (dt / twidth).astype(int)
                dpbin = (dp / pwidth).astype(int)
                srbin = (sr / swidth).astype(int)

                # We'll only store a limited range of ratios: a sample is
                # kept only if it is in range for every non-reference ifo
                in_range = (srbin <= srbmax) & (srbin >= srbmin)
                keep = in_range if keep is None else keep & in_range
                bind += [dtbin, dpbin, srbin]

            # Calculate and sum the weights for each bin
            # use first ifo as reference for weights
            bind = [a[keep] for a in bind]

            w = data[ifo0]['s'][keep] ** 3.
            for key, wt in zip(zip(*bind), w):
                weights[key] = weights.get(key, 0) + wt

            # Progress: total bins, bins added by this batch, new bins per
            # sample in this batch, and bins per sample overall
            nbins_prev = nbins
            nbins = len(weights)
            logging.info('%s, %s, %s, %s', nbins, nbins - nbins_prev,
                         (nbins - nbins_prev) / float(size),
                         nbins / float(nsamples))

        logging.info('applying smoothing')
        # apply smoothing iteratively
        # Phase is cyclic, so its smoothing wraps at the number of phase
        # bins in 2 pi
        pwrap = (2 * np.pi) / pwidth
        for i in range(len(args.ifos)-1):
            logging.info('%s-phase', len(weights))
            weights = smooth_param(weights, i * 3 + 1, wrapped=pwrap)
            logging.info('%s-time', len(weights))
            weights = smooth_param(weights, i * 3 + 0)
            logging.info('%s-amp', len(weights))
            weights = smooth_param(weights, i * 3 + 2)
        logging.info('smoothing done: %s', len(weights))

        logging.info('converting to numpy arrays and normalizing')
        keys = np.array(list(weights.keys()))
        values = np.array(list(weights.values()))
        values /= values.max()

        # Smoothing can spread weight into ratio bins beyond the limits
        logging.info('Removing bins outside of SNR ratio limits')
        n_precut = len(keys)
        keep = None
        for i in range(len(args.ifos)-1):
            # Column i * 3 + 2 holds the SNR-ratio bin of the i-th
            # non-reference ifo
            srbin = keys[:, i * 3 + 2]
            in_range = (srbin <= srbmax) & (srbin >= srbmin)
            keep = in_range if keep is None else keep & in_range
        keys = keys[keep]
        values = values[keep]
        logging.info('Removed %s bins', n_precut - len(keys))

        logging.info('discarding bins below threshold to limit storage')
        above_threshold = values > args.weight_threshold
        keys = keys[above_threshold].astype(bin_dtype)
        values = values[above_threshold].astype(np.float32)
        logging.info('Final length: %s', len(keys))

        logging.info('Presorting by keys for downstream use')
        # Pack each key into one structured record (fields c0, c1, ...) so
        # that argsort orders the bins field by field
        ncol = keys.shape[1]
        pdtype = [('c%s' % i, bin_dtype) for i in range(ncol)]
        keys_bin = np.zeros(len(values), dtype=pdtype)
        for i in range(ncol):
            keys_bin['c%s' % i] = keys[:, i]
        lsort = keys_bin.argsort()
        keys_bin = keys_bin[lsort]
        values = values[lsort]

        logging.info('Writing results to file')
        f.create_dataset('%s/param_bin' % ifo0, data=keys_bin,
                         compression='gzip', compression_opts=7)
        f.create_dataset('%s/weights' % ifo0, data=values,
                         compression='gzip', compression_opts=7)

    # Record the binning configuration alongside the histograms
    f.attrs['sensitivity_ratios'] = args.relative_sensitivities
    f.attrs['srbmin'] = srbmin
    f.attrs['srbmax'] = srbmax
    f.attrs['twidth'] = twidth
    f.attrs['pwidth'] = pwidth
    f.attrs['swidth'] = swidth
    f.attrs['ifos'] = args.ifos
    f.attrs['stat'] = 'phasetd_newsnr_%s' % ''.join(args.ifos)

logging.info('Done')
