#!/usr/bin/env python
import sys
import os
from xml.sax import SAXException

import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot

import numpy

from glue.ligolw import lsctables, table, utils

from statutils import cumvar
import xmlutils

from optparse import OptionParser
# Command-line interface: one option naming the output image. Whatever is left
# on the command line after option parsing ends up in ``args``.
optp = OptionParser()
optp.add_option(
    "-o",
    "--output",
    dest="output",
    default="integral.png",
    help="Output file",
)
opts, args = optp.parse_args()


def plot_integral(samp, fname=None):
    """
    Plot the convergence history of a Monte Carlo integral from stored samples.

    Effectively regenerates the process which happens in MCSampler.integrate,
    but keeps track of the running values for plotting. Uses the sample index
    to enumerate the samples, so it stays valid if only a subset of the
    samples was returned.

    Three stacked panels are written to one image:
      1. running integral estimate with an error band (log-log),
      2. log of the running maximum of the integrand (semilog-x),
      3. effective sample count and its ratio to the iteration number.

    Parameters
    ----------
    samp : iterable of sample rows
        Each row must provide ``simulation_id`` (convertible with ``int``,
        used as the zero-based sample index), ``alpha1`` (used as the log of
        the function value), ``alpha2`` (used as the joint prior) and
        ``alpha3`` (used as ``joint_p_s``, presumably the joint sampling
        density).
    fname : str, optional
        Path of the image to write. When omitted (or empty) the module-level
        ``opts.output`` is used instead.

    Raises
    ------
    ValueError
        If ``samp`` contains no rows.
    """
    rows = [(int(s.simulation_id), s.alpha1, s.alpha2, s.alpha3) for s in samp]
    if not rows:
        # Fail early with a readable message instead of an opaque
        # tuple-unpacking error, and before any figure is created.
        raise ValueError("plot_integral: no samples to plot")
    idx, logfval, joint_prior, joint_p_s = numpy.array(rows).T

    # Put the samples back in the order they were originally drawn (in case
    # they were re-ordered). A stable sort keeps rows with equal indices in
    # their incoming relative order.
    order = numpy.argsort(idx, kind="mergesort")
    logfval = logfval[order]
    joint_prior = joint_prior[order]
    joint_p_s = joint_p_s[order]
    # Iteration number, allowing for discarded samples. Sorted by construction.
    n_itr = idx[order] + 1

    # Reproduce construction of weights
    int_val = numpy.exp(logfval) / joint_p_s * joint_prior
    cum_int = int_val.cumsum()

    # Running maximum of the integrand, skipping exact zeros; it stays at -Inf
    # until the first non-zero value shows up.
    first = int_val[0]
    maxval = [first if first != 0 else -float("Inf")]
    for v in int_val[1:]:
        maxval.append(v if v > maxval[-1] and v != 0 else maxval[-1])
    maxval = numpy.array(maxval)
    eff_samp = cum_int / maxval

    int_est = cum_int / n_itr
    # Error goes as 1/sqrt(n). Beware: cumvar *already has* an implicit
    # divide-by-sample-size-to-this-point, which must be undone.
    # NOTE(review): the undo factor starts at 0 (numpy.arange), not 1 --
    # confirm this matches the normalisation cumvar actually uses.
    err_est = numpy.sqrt(cumvar(int_val) * numpy.arange(len(int_val)) / (n_itr * n_itr))

    fig = pyplot.figure()
    try:
        pyplot.subplot(311)
        pyplot.title("Integral estimate")
        pyplot.loglog()

        # Clamp the lower edge of the band so it stays positive on the log axis.
        band_floor = numpy.ones(int_est.shape) * 1e-100
        pyplot.fill_between(n_itr, numpy.maximum(int_est - err_est, band_floor), int_est + err_est, alpha=0.5)
        # NOTE(review): if int_est contains zeros the requested lower limit is
        # 0 on a log axis -- confirm the resulting y-range is what is wanted.
        pyplot.ylim(min(int_est) / 1e3, 2 * max(int_est))
        pyplot.xlim(1, None)

        pyplot.plot(n_itr, int_est, 'r-')
        pyplot.ylabel("integral val")
        pyplot.grid()

        pyplot.subplot(312)
        pyplot.title("Maximum lnL over iterations")
        pyplot.semilogx()
        pyplot.plot(n_itr, numpy.log(maxval), 'k-')
        pyplot.xlim(1, None)
        pyplot.grid()

        pyplot.subplot(313)
        pyplot.title("Effective samples")
        pyplot.semilogx()
        pyplot.ylabel("N_eff")
        pyplot.plot(n_itr, eff_samp, 'k-')
        pyplot.xlim(1, None)
        # Second y-axis on the same panel for the normalised ratio.
        pyplot.twinx()
        pyplot.ylabel("ratio N_eff/N")
        pyplot.plot(n_itr, eff_samp / n_itr, 'b-')
        pyplot.grid()

        pyplot.subplots_adjust(hspace=0.5)
        pyplot.savefig(fname or (opts.output))
    finally:
        # Always release the figure, even if plotting or saving failed.
        pyplot.close(fig)

#
# Main
#

# Split the requested output once so the per-input prefix is applied to the
# file name only; otherwise "-o plots/integral.png" would yield
# "<input>_plots/integral.png" instead of "plots/<input>_integral.png".
out_dir, out_file = os.path.split(opts.output)

for arg in args:
    try:
        samples = table.get_table(utils.load_filename(arg), lsctables.SimInspiralTable.tableName)
    # FIXME: Get right exception
    except SAXException:
        # Could not be parsed as XML; retry through xmlutils.db_to_samples
        # (presumably the input is a database file -- confirm).
        samples = xmlutils.db_to_samples(arg, lsctables.SimInspiralTable, ("alpha1", "alpha2", "alpha3"))

    # One plot per input: "<input name up to its first dot>_<output file name>".
    base_name = os.path.basename(arg).split(".")[0]
    plot_name = os.path.join(out_dir, base_name + "_" + out_file)
    plot_integral(samples, plot_name)
