#!/usr/bin/env python
import sys
import os
import tempfile
from xml.sax import SAXException
from optparse import OptionParser

import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot
import numpy

from glue.ligolw import lsctables, utils, table
import xmlutils

import bayesutils

# Command-line interface.  All options are optional; positional arguments are
# returned in args.
optp = OptionParser()
optp.add_option(
    "-p", "--param",
    dest="param",
    action="append",
    help="Plot this parameter. Give multiple times to append.",
)
optp.add_option(
    "-o", "--output",
    dest="output",
    help="Save figure to this name.",
)
optp.add_option(
    "-i", "--injection",
    dest="injection",
    help="Use SimInspiral table from this to identify injection values.",
)
optp.add_option(
    "-v", "--verbose",
    dest="verbose",
    action="store_true",
    help="Be verbose.",
)
# Left unset (None) unless given; None selects the single static plot.
optp.add_option(
    "-a", "--animate-stride",
    dest="animate_stride",
    type="int",
    help="Animate a frame of images every 'n' samples. Default is off.",
)
opts, args = optp.parse_args()

# Parameters to plot, in plotting order.  Repeated -p values are collapsed to
# their first occurrence: otherwise nparm would count the duplicates while the
# set() below drops them, and every params[:nparm] slice would then reach into
# the alpha columns.
params = []
for _p in opts.param or ["inclination", "distance", "longitude", "latitude", "geocent_end_time", "coa_phase", "polarization"]:
    if _p not in params:
        params.append(_p)
nparm = len(params)

# alpha1..alpha3 are read by make_plot to compute the sample weights, so they
# are always requested in addition to the plotted parameters.
tmp = list(set(params + ["alpha1", "alpha2", "alpha3"]))

def sort(i):
    """Return the position of ``i`` in params; names not in params sort last."""
    try:
        return params.index(i)
    except ValueError:
        return len(params)

# Plotted parameters first (in the order requested), then the remaining
# columns by name so the result does not depend on set iteration order.
params = sorted(tmp, key=lambda name: (sort(name), name))

# Optional injection: the file must hold exactly one SimInspiral row, whose
# values for the plotted parameters are later handed to the plotting routine.
if opts.injection:
    injection = table.get_table(utils.load_filename(opts.injection), lsctables.SimInspiralTable.tableName)
    # Explicit check instead of assert, which is stripped under python -O.
    if len(injection) != 1:
        raise ValueError("Expected exactly one SimInspiral row in %s, found %d" % (opts.injection, len(injection)))
    injection = numpy.array( [getattr(s, p) for s in injection for p in params[:nparm]] )
else:
    injection = None

# The sample file is the first positional argument; fail with a usage message
# rather than an IndexError when it is missing.
if not args:
    optp.error("a sample file must be given as the first positional argument")

try:
    samples = table.get_table(utils.load_filename(args[0]), lsctables.SimInspiralTable.tableName)
except SAXException:
    # The file did not parse as XML; fall back to xmlutils.db_to_samples
    # (presumably a database reader, judging by its name -- confirm).
    samples = xmlutils.db_to_samples(args[0], lsctables.SimInspiralTable, params)

def make_plot(samples, title=None):
    """Weight the samples and hand them to bayesutils.triplot.

    samples -- non-empty sequence of rows exposing alpha1, alpha2, alpha3 and
               each plotted parameter as attributes
    title   -- optional title forwarded to bayesutils.triplot

    Reads the module-level params, nparm and injection.
    """
    plotted = params[:nparm]

    # Weight of each sample: exp(alpha1 - max(alpha1)) * alpha2 / alpha3.
    # Subtracting the maximum first avoids overflow and nans (white patches
    # in the plots) - underflow is the lesser evil.
    peak = numpy.max([row.alpha1 for row in samples])
    weights = []
    for row in samples:
        weights.append(numpy.exp(row.alpha1 - peak)*row.alpha2/row.alpha3)

    # Unpack the plotted attributes row by row into an (nsamples, nparm) array.
    flat = [getattr(row, name) for row in samples for name in plotted]
    values = numpy.array(flat).reshape(len(weights), nparm)

    # Axis labels; parameters without a known symbol get an empty label.
    symbols = {
        "inclination": "$\\iota$",
        "distance": "d",
        "longitude": "$\\alpha$",
        "latitude": "$\\delta$",
        "geocent_end_time": "t",
        "coa_phase": "$\\phi$",
        "polarization": "$\\psi$",
    }
    # NOTE(review): labels cover every entry of params, while the data only has
    # nparm columns -- confirm triplot tolerates the surplus labels.
    labels = [symbols.get(name, "") for name in params]

    # FIXME: title
    bayesutils.triplot(values, weights=weights, labels=labels, title=title, inj=injection)

def select_samples(samples, begin=0, end=None):
    """Return the samples whose integer simulation_id lies in [begin, end].

    Both bounds are inclusive.  When end is None it defaults to len(samples),
    which is compared against the ids like any other upper bound.
    """
    upper = len(samples) if end is None else end
    chosen = []
    for row in samples:
        if begin <= int(row.simulation_id) <= upper:
            chosen.append(row)
    return chosen

def convert_fname(fin, fout, verbose=False):
    """Convert image file ``fin`` into ``fout`` by running the external ``convert`` program.

    fin     -- path of the existing image
    fout    -- path of the image to create
    verbose -- if true, print the command before running it

    The exit status of ``convert`` is not checked.
    """
    import subprocess
    # Pass the arguments as a list rather than building a string and
    # re-splitting it, so paths containing spaces or quotes stay intact.
    argv = ["convert", "%s" % fin, "%s" % fout]
    if verbose:
        print(" ".join(argv))
    subprocess.call(argv)
    

def make_animation(imglist, outfile, delay=100, loop=True, verbose=False):
    """Combine gif frames into one animated gif using the external ``gifsicle`` program.

    imglist -- paths of the gif frames, in playback order
    outfile -- path the animation is written to (gifsicle's standard output)
    delay   -- value passed to gifsicle's --delay option
    loop    -- if true, pass --loop to gifsicle
    verbose -- if true, print the command before running it

    Raises ValueError when imglist is empty.  The exit status of gifsicle is
    not checked.
    """
    import subprocess
    frames = list(imglist)
    if not frames:
        # Without input files gifsicle would fall back to reading standard
        # input, which is never what the caller wants here.
        raise ValueError("make_animation needs at least one input image")

    # Arguments are passed as a list rather than a re-split string so that
    # paths containing spaces or quotes stay intact.
    argv = ["gifsicle", "--delay", "%s" % delay]
    if loop:
        argv.append("--loop")
    argv.extend(frames)
    if verbose:
        print(" ".join(argv))
    # Binary mode: the file receives raw gif data.
    with open(outfile, "wb") as imgfile:
        subprocess.call(argv, stdout=imgfile)

if opts.animate_stride is None:
    # Static mode: one figure built from every sample.
    make_plot(samples)
    pyplot.savefig(opts.output or "triplot.png")
else:
    # Animated mode: frame i shows the samples with simulation_id in
    # [0, i*nsamp]; the frames are then merged into a single gif.
    nsamp = opts.animate_stride
    if nsamp <= 0:
        # i*nsamp would never exceed max_samp and the loop below would not end.
        optp.error("--animate-stride must be a positive integer")

    # Computed before the temporary directory exists so that a failure here
    # leaves nothing behind.
    max_samp = max(int(s.simulation_id) for s in samples)

    tmpdir = tempfile.mkdtemp()
    tbdel, tbuse = [], []
    try:
        i = 1
        while i*nsamp <= max_samp:
            plot_samps = select_samples(samples, 0, i*nsamp)

            # Both names come from a common stem; substituting "png" -> "gif"
            # in the full path would also rewrite a "png" that happens to
            # occur in the random temporary directory name.
            stem = "%s/tmp_triplot.%d" % (tmpdir, i)
            fname, gifname = stem + ".png", stem + ".gif"
            if len(plot_samps) == 0:
                if opts.verbose:
                    print("No samples in range (%d, %d), skipping %s" % (0, i*nsamp, fname))
                i+=1
                continue

            if opts.verbose:
                print("creating %s" % fname)

            make_plot(plot_samps, title="Samples in range [0, %d]" % (i*nsamp))
            # Registered before writing so a failed savefig is still cleaned up.
            tbdel.append(fname)
            pyplot.savefig(fname)
            i+=1
            pyplot.clf()

            # matplotlib cannot write gifs itself, its animation interface is
            # not available until 1.1, and no Python bindings to ImageMagick
            # are installed by default, so each frame is converted with the
            # external tools instead.
            tbdel.append(gifname)
            convert_fname(fname, gifname, verbose=opts.verbose)
            if os.path.exists(gifname):
                tbuse.append(gifname)
            else:
                sys.stderr.write("Warning: could not convert %s to gif, frame dropped\n" % fname)

        if tbuse:
            make_animation(tbuse, opts.output or "triplot.gif", verbose=opts.verbose)
        else:
            sys.stderr.write("No frames were produced, animation not written\n")
    finally:
        # Clean up even if plotting or conversion failed part-way through.
        for f in tbdel:
            if not os.path.exists(f):
                continue
            if opts.verbose:
                print("removing %s" % f)
            os.remove(f)
        if opts.verbose:
            print("removing dir %s" % tmpdir)
        # rmdir rather than removedirs: only the directory created above is
        # removed, never its (possibly empty) parents.
        os.rmdir(tmpdir)
