#!/usr/bin/env python
"""Generate a bank of templates using a brute force stochastic method.
"""
import numpy, h5py, logging, argparse, numpy.random, sys
import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
from scipy.stats import gaussian_kde

# Command line interface. Options without a default are None when omitted.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--output-file', required=True,
    help='Output file name for template bank.')
parser.add_argument('--input-file',
    help='Bank to use as a starting point.')
parser.add_argument('--params', required=True,
    help='list of parameters to use', nargs='+')
parser.add_argument('--min', required=True,
    help='list of the minimum parameter values', nargs='+', type=float)
parser.add_argument('--max',  required=True,
    help='list of the maximum parameter values', nargs='+', type=float)
parser.add_argument('--approximant',  required=True,
    help='The waveform approximant to place')
parser.add_argument('--minimal-match', default=0.97, type=float,
    help='Match above which a candidate is considered already covered.')
parser.add_argument('--buffer-length', default=2, type=float,
    help='size of waveform buffer in seconds')
parser.add_argument('--sample-rate', default=2048, type=int,
    help='sample rate in Hz')
parser.add_argument('--low-frequency-cutoff', default=20.0, type=float)
parser.add_argument('--enable-sigma-bound', action='store_true')
parser.add_argument('--tau0-threshold', type=float,
    help='Width of the tau0 bins used to select which stored templates '
         'a candidate is compared against. Required, must be positive.')
parser.add_argument('--permissive', action='store_true',
    help='Allow waveform generator to fail.')
parser.add_argument('--seed', type=int, default=0,
    help='Seed for the numpy random number generator.')
parser.add_argument('--tolerance', type=float,
    help='Fraction of accepted candidates below which a region is '
         'considered converged. Defaults to (1 - minimal match) / 10.')
parser.add_argument('--max-mtotal', type=float)
parser.add_argument('--min-mchirp', type=float)
parser.add_argument('--max-mchirp', type=float)
parser.add_argument('--fixed-params', type=str, nargs='*',
    help='Names of parameters passed unchanged to every waveform.')
parser.add_argument('--fixed-values', type=float, nargs='*',
    help='Values for --fixed-params, in the same order.')
parser.add_argument('--max-q', type=float)
parser.add_argument('--tau0-crawl', type=float,
    help='Width of the tau0 window populated at one time; the window '
         'advances by half this value. Required, must be positive.')
parser.add_argument('--tau0-start', type=float,
    help='Lower tau0 edge of the first window. Required.')
parser.add_argument('--tau0-end', type=float,
    help='Placement stops once the lower window edge reaches this tau0. '
         'Required.')
pycbc.psd.insert_psd_option_group(parser)
args = parser.parse_args()

# zip() silently truncates to the shortest list, which would quietly drop
# parameters, so reject mismatched lists up front.
if not (len(args.params) == len(args.min) == len(args.max)):
    parser.error('--params, --min and --max must each be given the same '
                 'number of values')

fixed_params = args.fixed_params or []
fixed_values = args.fixed_values or []
if len(fixed_params) != len(fixed_values):
    parser.error('--fixed-params and --fixed-values must each be given the '
                 'same number of values')

# The placement loop uses all of these unconditionally; fail early with a
# usage message instead of a TypeError/AttributeError part way through a run.
for opt in ('tau0_start', 'tau0_end', 'tau0_crawl', 'tau0_threshold'):
    if getattr(args, opt) is None:
        parser.error('--%s is required' % opt.replace('_', '-'))
if args.tau0_crawl <= 0:
    # The window only advances by tau0_crawl / 2, so it must move forward.
    parser.error('--tau0-crawl must be positive')
if args.tau0_threshold <= 0:
    parser.error('--tau0-threshold must be positive')

pycbc.init_logging(args.verbose)
numpy.random.seed(args.seed)


# Parameters forced to a constant value for every generated waveform.
fdict = dict(zip(fixed_params, fixed_values))

class Shrinker(object):
    """Sequence wrapper that hands out its elements from the back.

    ``data`` is kept as a public attribute so callers may replace or filter
    the remaining elements between calls to `pop`.
    """

    def __init__(self, data):
        self.data = data

    def pop(self):
        """Remove and return the last element, or None when nothing is left."""
        # len() rather than truthiness: `data` may be a numpy array.
        if len(self.data) < 1:
            return None
        last, self.data = self.data[-1], self.data[:-1]
        return last

class TriangleBank(object):
    """ A bank of templates that uses the triangle inequality to estimate
    matches based on prior ones.

    Templates are kept in insertion order in ``self.waveforms``. The test
    ``hp in bank`` computes matches against stored templates and uses the
    matches recorded on those templates to bound, and so skip, further
    comparisons. Several methods read the module-level ``args`` namespace.
    """

    class _Culled(object):
        """Stand-in for a culled template; carries only ``tau0``, ``params``
        and ``s`` copied from the template it replaces."""

    def __init__(self, p=None):
        """Create a bank, optionally wrapping an existing list of templates.

        Parameters
        ----------
        p : list, optional
            Templates to start from. Note that ``tbins`` is not populated
            for them; only `insert` fills it.
        """
        self.waveforms = p if p is not None else []
        # tau0 bin number -> indices of templates whose own bin is that one
        # or a direct neighbour of it.
        self.tbins = {}
        # Per-template arrays built on demand and rebuilt when the bank grows.
        self.sigma = None
        self.r = None
        self.t0 = None

    def __len__(self):
        return len(self.waveforms)

    def activelen(self):
        """Return how many entries are pycbc FrequencySeries instances,
        i.e. have not been replaced by `culltau0`."""
        return sum(1 for w in self.waveforms
                   if isinstance(w, pycbc.types.FrequencySeries))

    def insert(self, hp):
        """Append ``hp`` and register its index in its tau0 bin and the two
        neighbouring bins. ``hp.tbin`` must already be set."""
        self.waveforms.append(hp)
        index = len(self) - 1
        for b in (hp.tbin - 1, hp.tbin, hp.tbin + 1):
            self.tbins.setdefault(b, []).append(index)

    def __getitem__(self, index):
        return self.waveforms[index]

    def keys(self):
        """Return the parameter names of the first template.

        Raises IndexError if the bank is empty.
        """
        return self.waveforms[0].params.keys()

    def key(self, k):
        """Return an array holding parameter ``k`` of every template."""
        return numpy.array([p.params[k] for p in self.waveforms])

    def sigma_match_bound(self, sig):
        """Return, per stored template, min(sig / s, s / sig) where ``s`` is
        that template's ``s`` attribute."""
        if self.sigma is None or len(self.sigma) != len(self):
            self.sigma = numpy.array([h.s for h in self.waveforms])
        return numpy.minimum(sig / self.sigma, self.sigma / sig)

    def range(self):
        """Return the cached index array ``[0, 1, ..., len(self) - 1]``."""
        if self.r is None or len(self.r) != len(self):
            self.r = numpy.arange(0, len(self))
        return self.r

    def culltau0(self, threshold):
        """Replace every template whose tau0 is below ``threshold`` with a
        placeholder that keeps only its ``tau0``, ``params`` and ``s``."""
        for c in numpy.where(self.tau0() < threshold)[0]:
            old = self.waveforms[c]
            if isinstance(old, TriangleBank._Culled):
                # Already a placeholder from an earlier call.
                continue
            d = TriangleBank._Culled()
            d.tau0 = old.tau0
            d.params = old.params
            d.s = old.s
            self.waveforms[c] = d

    def tau0(self):
        """Return the cached array of every template's ``tau0``."""
        if self.t0 is None or len(self.t0) != len(self):
            self.t0 = numpy.array([h.tau0 for h in self.waveforms])
        return self.t0

    def __contains__(self, hp):
        """Return True if ``hp`` matches a stored template above
        ``hp.threshold``.

        Side effects on ``hp``: when ``args.tau0_threshold`` is set,
        ``hp.tau0`` and ``hp.tbin`` are assigned; when False is returned,
        ``hp.matches`` and ``hp.indices`` record the (bounded or computed)
        matches and the indices of the candidates considered.
        """
        mmax = 0
        mnum = 0

        # Apply sigmas maximal match.
        if args.enable_sigma_bound:
            matches = self.sigma_match_bound(hp.s)
            r = self.range()[matches > hp.threshold]
        else:
            matches = numpy.ones(len(self))
            r = self.range()

        msig = len(r)

        # Apply tau0 threshold
        if args.tau0_threshold:
            hp.tau0 = pycbc.conversions.tau0_from_mass1_mass2(
                hp.params['mass1'], hp.params['mass2'], 15.0)
            hp.tbin = int(hp.tau0 / args.tau0_threshold)

            # NOTE(review): this replaces, rather than intersects with, the
            # candidate list produced by the sigma bound above - confirm
            # that is intended.
            if hp.tbin in self.tbins:
                r = numpy.array(self.tbins[hp.tbin])
            else:
                r = r[:0]

        mtau = len(r)

        # Candidates whose bounded match is not above this value are dropped
        # without computing a match.
        skip_threshold = 1 - (1 - hp.threshold) * 2.0

        # Try to do some actual matches
        inc = Shrinker(r * 1)  # r * 1 copies, so pruning inc leaves r intact
        while True:
            j = inc.pop()
            if j is None:
                hp.matches = matches[r]
                hp.indices = r
                logging.info("TADD MaxMatch:%0.3f Size:%i "
                             "AfterSigma:%i AfterTau0:%i Matches:%i",
                             mmax, len(self), msig, mtau, mnum)
                return False

            hc = self[j]
            m = hp.gen.match(hp, hc)
            matches[j] = m
            mnum += 1

            # Update bounding match values, apply triangle inequality.
            # NOTE(review): the 1.10 offset presumably pads the bound;
            # its derivation is not visible here.
            maxmatches = hc.matches - m + 1.10
            update = numpy.where(maxmatches < matches[hc.indices])[0]
            matches[hc.indices[update]] = maxmatches[update]

            # Update where to calculate matches
            inc.data = inc.data[matches[inc.data] > skip_threshold]

            if m > hp.threshold:
                return True
            if m > mmax:
                mmax = m

    def check_params(self, gen, params, threshold):
        """Generate a waveform for each point in ``params`` and insert those
        not already covered by the bank.

        Parameters
        ----------
        gen : object
            Provides ``generate(**kwds)`` and ``match(hp, hc)``.
        params : dict
            Maps parameter name to a sequence of values; point ``i`` takes
            element ``i`` of every sequence.
        threshold : float
            Match above which a candidate counts as already covered.

        Returns
        -------
        (TriangleBank, float)
            This bank and the fraction of points that were inserted
            (0.0 when there were no points).
        """
        if not params:
            return self, 0.0

        npoints = len(next(iter(params.values())))
        num_added = 0
        for i in range(npoints):
            try:
                hp = gen.generate(**{key: params[key][i] for key in params})
            except Exception as err:
                # Best effort: a point whose waveform cannot be generated is
                # skipped. NOTE(review): --permissive is parsed but is not
                # consulted in the visible code.
                logging.warning("Waveform generation failed, skipping "
                                "point: %s", err)
                continue

            hp.gen = gen
            hp.threshold = threshold
            if hp not in self:
                num_added += 1
                self.insert(hp)

        if npoints == 0:
            return self, 0.0
        return self, num_added / float(npoints)

class GenUniformWaveform(object):
    """Generate PSD-weighted, normalised frequency-domain waveforms on a
    fixed frequency grid and compute matches between them.

    The PSD is read from the module-level ``args`` via `pycbc.psd.from_cli`.
    """

    def __init__(self, buffer_length, sample_rate, f_lower):
        """Set up the frequency grid, the weighting and the inverse FFT.

        Parameters
        ----------
        buffer_length : float
            Length of the waveform buffer in seconds; sets delta_f.
        sample_rate : int
            Samples per second; sets delta_t and the buffer size.
        f_lower : float
            Low frequency cutoff used for generation and normalisation.
        """
        self.f_lower = f_lower
        self.delta_f = 1.0 / buffer_length
        self.delta_t = 1.0 / sample_rate
        tlen = int(buffer_length * sample_rate)
        self.flen = tlen // 2 + 1
        psd = pycbc.psd.from_cli(args, self.flen, self.delta_f, self.f_lower)
        # Index of the first frequency bin at or above f_lower.
        self.kmin = int(f_lower * buffer_length)
        # Per-bin weight 1 / sqrt(PSD) applied to every generated waveform.
        self.w = ((1.0 / psd[self.kmin:-1]) ** 0.5).astype(numpy.float32)
        qtilde = pycbc.types.zeros(tlen, numpy.complex64)
        q = pycbc.types.zeros(tlen, numpy.complex64)
        # Correlations are written here, i.e. into the input of the IFFT.
        self.qtilde_view = qtilde[self.kmin:self.flen - 1]
        self.ifft = pycbc.fft.IFFT(qtilde, q)
        # Last and first 100 samples of the IFFT output; `match` only
        # inspects these.
        self.md = q._data[-100:]
        self.md2 = q._data[0:100]

    def generate(self, **kwds):
        """Return a weighted, normalised waveform for the given parameters.

        ``kwds`` is updated with the module-level fixed parameters ``fdict``
        and must contain ``approximant``. The returned series carries the
        extra attributes ``params`` (the keyword dict), ``view`` (the bins
        from ``kmin`` up to, but excluding, the last) and ``s`` (the value
        of `pycbc.filter.sigmasq` before normalisation).
        """
        kwds.update(fdict)
        if kwds['approximant'] in pycbc.waveform.fd_approximants():
            hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
                                                f_lower=self.f_lower, **kwds)
            if 'fratio' in kwds:
                # Linear mix of the two polarisations.
                hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])
        else:
            # Use the sample rate this generator was built with so the
            # waveform agrees with the buffers allocated in __init__.
            hp = pycbc.waveform.get_waveform_filter(
                        pycbc.types.zeros(self.flen, dtype=numpy.complex64),
                        delta_f=self.delta_f, delta_t=self.delta_t,
                        f_lower=self.f_lower, **kwds)

        hp.resize(self.flen)
        hp = hp.astype(numpy.complex64)
        hp[self.kmin:-1] *= self.w
        # s = 1 / sqrt(sigmasq); scaling by it normalises the waveform.
        s = float(1.0 / pycbc.filter.sigmasq(hp, low_frequency_cutoff=self.f_lower) ** 0.5)
        hp *= s
        hp.params = kwds
        hp.view = hp[self.kmin:-1]
        hp.s = (1.0 / s) ** 2.0
        return hp

    def match(self, hp, hc):
        """Return the match between two waveforms made by `generate`.

        The correlation is inverse transformed and the largest magnitude
        among the first and last 100 output samples is returned, scaled by
        ``4 * delta_f``.
        """
        pycbc.filter.correlate(hp.view, hc.view, self.qtilde_view)
        self.ifft.execute()
        m = max(abs(self.md).max(), abs(self.md2).max())
        return m * 4.0 * self.delta_f

# Convergence tolerance: a region is done once the fraction of accepted
# candidates falls to this level.
tolerance = args.tolerance if args.tolerance else (1 - args.minimal_match) / 10

# Number of candidate points drawn per round.
size = int(1.0 / tolerance)

gen = GenUniformWaveform(args.buffer_length,
    args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

if args.input_file:
    # Read everything into memory so the file can be closed straight away.
    with h5py.File(args.input_file, 'r') as f:
        params = {k: f[k][:] for k in f}

    # Banks written by this script contain an 'f_lower' dataset. The
    # generator already passes its own f_lower keyword, so forwarding this
    # one as well would make every waveform call fail with a duplicate
    # keyword error.
    params.pop('f_lower', None)

    # Strings are written out as bytes; turn them back into str so that
    # values such as the approximant name compare equal to str names.
    params = {k: (v.astype('U') if v.dtype.kind == 'S' else v)
              for k, v in params.items()}

    bank, _ = bank.check_params(gen, params, args.minimal_match)


def draw(rtype):
    """Draw candidate points and keep those passing the configured cuts.

    Reads the module-level ``args``, ``size``, ``bank`` and ``fdict``.

    Parameters
    ----------
    rtype : str
        'uniform' draws each of ``args.params`` uniformly between its
        minimum and maximum. 'kde' resamples a Gaussian KDE built from the
        parameters of (at most) the last 300 templates in ``bank``.

    Returns
    -------
    dict
        Parameter name -> array of surviving values, plus 'approximant'.
    """
    params = {}

    if rtype == 'uniform':
        params.update(
            (name, numpy.random.uniform(lo, hi, size=size))
            for name, lo, hi in zip(args.params, args.min, args.max)
        )
    elif rtype == 'kde':
        trail = min(300, len(bank))
        names = [k for k in bank.keys() if k not in fdict]
        names.remove('approximant')
        recent = numpy.array([bank.key(k)[-trail:] for k in names])

        samples = gaussian_kde(recent).resample(size=size)
        params.update(zip(names, samples))

    params['approximant'] = numpy.array([args.approximant]*size)

    # Filter out stuff: start with the per-parameter bounds (both exclusive)
    keep = None
    for name, lo, hi in zip(args.params, args.min, args.max):
        inside = (params[name] < hi) & (params[name] > lo)
        keep = inside if keep is None else (inside & keep)

    if args.max_q:
        m1, m2 = params['mass1'], params['mass2']
        keep = keep & (numpy.maximum(m1 / m2, m2 / m1) < args.max_q)

    if args.max_mtotal:
        keep = keep & (params['mass1'] + params['mass2'] < args.max_mtotal)

    if args.max_mchirp or args.min_mchirp:
        from pycbc.conversions import mchirp_from_mass1_mass2
        mc = mchirp_from_mass1_mass2(params['mass1'], params['mass2'])
        if args.max_mchirp:
            keep = keep & (mc < args.max_mchirp)
        if args.min_mchirp:
            keep = keep & (mc > args.min_mchirp)

    return {name: values[keep] for name, values in params.items()}

def cdraw(rtype, ts, te):
    """Draw points with `draw` until at least ``size`` of them have a tau0
    strictly between ``ts`` and ``te``.

    tau0 is computed from mass1 and mass2 with a frequency argument of 15.0.
    Drawing stops after 1001 extra rounds even if fewer than ``size`` points
    were collected.

    Returns
    -------
    dict or None
        Parameter name -> array of values, or None if no point survived.
    """
    from pycbc.conversions import tau0_from_mass1_mass2

    def count(points):
        # Every array has the same length; measure the first one.
        return len(points[list(points.keys())[0]])

    def restrict(points):
        # Keep only the points whose tau0 lies inside the window.
        if count(points) == 0:
            return
        tau0 = tau0_from_mass1_mass2(points['mass1'], points['mass2'], 15.0)
        inside = (tau0 < te) & (tau0 > ts)
        for name in points:
            points[name] = points[name][inside]

    p = draw(rtype)
    restrict(p)

    for _ in range(1001):
        if count(p) >= size:
            break
        extra = draw(rtype)
        for name in p:
            p[name] = numpy.concatenate([p[name], extra[name]])
        restrict(p)

    return p if count(p) != 0 else None

# Populate the bank one tau0 window [tau0s, tau0e] at a time. The window
# advances by half its width, so consecutive windows overlap.
tau0s = args.tau0_start
tau0e = tau0s + args.tau0_crawl

region = 0
while tau0s < args.tau0_end:
    conv = 1
    r = 0
    while conv > tolerance:
        # Standard Round
        r += 1
        params = cdraw('uniform', tau0s, tau0e)
        if params is None:
            # No point could be drawn inside this window; the region is done.
            break

        blen = len(bank)
        bank, uconv = bank.check_params(gen, params, args.minimal_match)
        logging.info("%s: Round (U): %s Size: %s conv: %s added: %s",
                     region, r, len(bank), uconv, len(bank) - blen)
        # Uniform rounds only count towards convergence after round 10.
        if r > 10:
            conv = uconv

        # KDE rounds: keep going while they stay at least half as productive
        # as the first KDE round of this pass.
        kloop = 0
        while ((kloop == 0) or (kconv / okconv) > .5) and len(bank) > 10:
            r += 1
            kloop += 1
            params = cdraw('kde', tau0s, tau0e)
            if params is None:
                # The KDE produced nothing inside the window; return to
                # uniform rounds instead of passing None to check_params.
                break

            blen = len(bank)
            bank, kconv = bank.check_params(gen, params, args.minimal_match)
            logging.info("%s: Round (K) (%s): %s Size: %s conv: %s added: %s",
                         region, kloop, r, len(bank), kconv, len(bank) - blen)
            if uconv:
                logging.info('Ratio of convergences: %2.3f', kconv / uconv)

            if kloop == 1:
                okconv = kconv

            if kconv <= tolerance:
                conv = kconv
                break

    # Templates well below the current window no longer need their waveform.
    bank.culltau0(tau0s - args.tau0_threshold * 2.0)
    logging.info("Region Done %3.1f-%3.1f, %s stored", tau0s, tau0e, bank.activelen())
    region += 1
    tau0s += args.tau0_crawl / 2
    tau0e += args.tau0_crawl / 2

# Write one dataset per parameter plus the low frequency cutoff. The context
# manager guarantees the file is flushed and closed.
with h5py.File(args.output_file, 'w') as o:
    if len(bank) > 0:
        for k in bank.keys():
            val = bank.key(k)
            # Unicode arrays are converted to bytes before being stored.
            if val.dtype.char == 'U':
                val = val.astype('bytes')
            o[k] = val
    else:
        logging.warning("No templates were placed; writing an empty bank")
    o['f_lower'] = numpy.ones(len(bank)) * args.low_frequency_cutoff
