#!/usr/bin/env python

import sys
from optparse import OptionParser

import numpy

from glue.ligolw import utils, table, lsctables
from glue.segments import segment, PosInfinity, NegInfinity
from pylal import Fr

import lal

import lalsimutils

# Command line interface.  --channel and --param use the "append" action, so
# each may be repeated and the values are collected in order of appearance.
optp = OptionParser()
optp.add_option(
    "-c", "--channel",
    action="append",
    help="Channel to add waveform to. Specify in form inst=channel_name (e.g. H1=FAKE-STRAIN). Provide multiples times for different detectors.",
)
optp.add_option(
    "-s", "--start-time",
    type="int",
    help="Start time of the frame. Default is truncate beginning GPS of the waveform to an integer.",
)
optp.add_option(
    "-e", "--end-time",
    type="int",
    help="End time of the frame. Default is ceiling the end GPS of the waveform to an integer.",
)
optp.add_option(
    "-p", "--param",
    action="append",
    help="Override this parameter with a value. E.g. --param lambda1=1.0 Provide multiple times for different parameters.",
)
optp.add_option(
    "-x", "--xml-table",
    help="Name of XML table containing sim_inspiral table to read.",
)
optp.add_option(
    "-r", "--sample-rate",
    type="int",
    default=16384,
    help="Sample rate of data, default is 16kHz",
)

# Parse sys.argv; positional arguments are collected in ``args``.
opts, args = optp.parse_args()

def _key_value_pairs(assignments, option_name):
    """Turn a list of "key=value" strings into a dictionary.

    assignments -- list of strings as collected by an "append" option, or
                   None when the option was never given (treated as empty).
    option_name -- option spelling used in the error message.

    Only the first "=" separates key from value, so values may themselves
    contain "=".  A malformed item ends the program through optp.error().
    """
    pairs = {}
    for item in assignments or []:
        key, sep, value = item.partition("=")
        if not sep or not key:
            optp.error("%s expects an argument of the form key=value, got '%s'" % (option_name, item))
        pairs[key] = value
    return pairs


# At least one detector/channel pair is needed to produce any output; the
# parameter overrides are optional.
if not opts.channel:
    optp.error("at least one --channel inst=channel_name must be given")
det_chan = _key_value_pairs(opts.channel, "--channel")
param_val = _key_value_pairs(opts.param, "--param")

# Read the sim_inspiral table from the XML document named on the command
# line and seed the waveform parameters from its first row.
xmldoc = utils.load_filename(opts.xml_table)
sim_table = table.get_table(xmldoc, lsctables.SimInspiralTable.tableName)

wave_p = lalsimutils.ChooseWaveformParams()
wave_p.copy_lsctables_sim_inspiral(sim_table[0])

# Apply the --param overrides.  Every value is stored as a float; m1 and m2
# are additionally scaled by lal.LAL_MSUN_SI (presumably the command line
# gives them in solar masses -- confirm against lalsimutils).
for name, raw in param_val.items():
    value = float(raw)
    if name in ("m1", "m2"):
        value = float(value * lal.LAL_MSUN_SI)
    setattr(wave_p, name, value)

# The sample spacing always follows the requested sample rate.
wave_p.deltaT = 1.0 / opts.sample_rate

print("=======Waveform parameters after modification==========")
wave_p.print_params()

# ---------------------------------------------------------------------------
# Build the injection frame: generate h(t) for every requested detector, lay
# each channel onto one common frame span by zero padding, and write a single
# frame file named <detectors>-INJ_FRAME-<start>-<duration>.gwf.
# ---------------------------------------------------------------------------

# Generate all detector responses before fixing the frame span, because the
# default span is derived from the waveforms themselves.
responses = []
for det, chan in det_chan.items():
    wave_p.detector = det
    hoft = lalsimutils.hoft(wave_p)
    # Past this point only the raw sample array is used; keep the epoch and
    # sample spacing next to it.
    responses.append((det, chan, float(hoft.epoch), hoft.deltaT, hoft.data.data))

if not responses:
    sys.exit("No channels were requested, nothing to write.")

# Frame boundaries.  An option that was not given (None) falls back to the
# behaviour promised in the help text: the earliest waveform start truncated
# to an integer, and the latest waveform end rounded up to an integer.
frame_start = opts.start_time
if frame_start is None:
    frame_start = int(numpy.floor(min(st for _, _, st, _, _ in responses)))
frame_end = opts.end_time
if frame_end is None:
    frame_end = int(numpy.ceil(max(st + len(samples)*dt for _, _, st, dt, samples in responses)))
frame_seg = segment(frame_start, frame_end)

data = []
for det, chan, st, dt, samples in responses:
    # Refuse to silently drop part of a waveform.
    if segment(st, st + len(samples)*dt) not in frame_seg:
        sys.exit("Requested waveform(s) would be clipped by requested frame boundaries.")

    # Pad with whole samples so that every channel holds exactly the number
    # of samples spanned by the frame.  Counts are rounded rather than
    # truncated so floating point noise cannot cost a sample; the waveform
    # therefore lands on the frame's sample grid, at most half a sample away
    # from its original epoch.
    n_total = int(round(abs(frame_seg)/dt))
    n_front = int(round((st - frame_seg[0])/dt))
    n_back = n_total - n_front - len(samples)
    if n_back < 0:
        # Rounding the leading pad up pushed the tail past the frame end;
        # trim the overhang so the channel length still matches the frame.
        samples = samples[:n_back]
        n_back = 0
    samples = numpy.concatenate((numpy.zeros(n_front), samples, numpy.zeros(n_back)))

    data.append({
        "name": "%s:%s" % (det, chan),
        "data": samples,
        "start": frame_seg[0],
        "dx": dt})

dur = int(abs(frame_seg))
st = int(frame_seg[0])

# File name prefix: first letter of every detector name, sorted.
det = "".join(sorted(d[0] for d in det_chan))
Fr.frputvect("%s-INJ_FRAME-%d-%d.gwf" % (det, st, dur), data, verbose=False)
