#!/usr/bin/env python
import sys
import math
import gc
import operator
import os
import glob
from xml import sax
import json
import bisect
from collections import defaultdict

from optparse import OptionParser

import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot
from matplotlib import gridspec

import numpy

from glue.lal import Cache
from glue.ligolw import table, utils, lsctables

import xmlutils

# FIXME: Move to module level constants
# Per-parameter pair of values, keyed by parameter name.
# NOTE(review): presumably the allowed (lower, upper) range of each parameter,
# with (None, None) meaning "no fixed range" -- confirm. Nothing else in this
# script reads RANGES.
_TWO_PI = 2*numpy.pi
_HALF_PI = numpy.pi/2
RANGES = dict(
    distance=(0, 200),
    inclination=(0, numpy.pi),
    latitude=(-_HALF_PI, _HALF_PI),
    polarization=(0, _TWO_PI),
    longitude=(0, _TWO_PI),
    geocent_end_time=(None, None),
    phi_orb=(0, _TWO_PI),
)

def get_cump_and_pval_unstructured(tbl, p_idx, injval):
    """
    Takes a 2-D array 'tbl' whose rows are samples and orders the rows by column 'p_idx'. The last column of 'tbl' is used as the sample weight. Finds the cumulative probability which corresponds to 'injval' by bisection.

    Returns ((prm, wt), pval) where 'prm' is the sorted parameter column, 'wt' is the cumulative sum of the weights normalized so that its last entry is 1, and 'pval' is the entry of 'wt' at the bisection index of 'injval' in 'prm'. When 'injval' is at or beyond the largest sample, 'pval' is the last entry of 'wt' (same convention as get_cump_and_pval) instead of raising IndexError.

    'tbl' itself is not modified; the sort produces a reordered copy.
    NOTE: This is deprecated in favor of get_cump_and_pval.
    """
    # Fancy indexing with the argsort result yields a reordered copy
    tbl = tbl[tbl[:, p_idx].argsort()]
    prm, wt = tbl[:, [p_idx, -1]].T
    wt = wt.cumsum()
    wt /= wt[-1]

    # bisect returns len(prm) when injval >= every sample, which is one past
    # the end of wt -- clamp to the last entry.
    pval_idx = min(bisect.bisect(prm, injval), len(wt) - 1)
    pval = wt[pval_idx]
    return (prm, wt), pval

def get_cump_and_pval(tbl, param, injval):
    """
    Takes a structured array 'tbl' of samples and orders the samples according to the field 'param'. The field "weights" is used as the sample weight. Finds the cumulative probability which corresponds to 'injval' by bisection.

    Returns ((prm, wt), pval) where 'prm' is the sorted 'param' field, 'wt' is the cumulative sum of the weights normalized so that its last entry is 1, and 'pval' is the entry of 'wt' at the bisection index of 'injval' in 'prm' (the last entry of 'wt' when 'injval' is at or beyond the largest sample).

    'tbl' is left untouched: the sort is done on a copy. Callers pass slices (views) of one large table, and an in-place sort would reorder the rows of that table underneath them.
    """
    ordered = numpy.sort(tbl, order=param)
    prm = ordered[param]
    wt = ordered["weights"].cumsum()
    wt /= wt[-1]

    # bisect returns len(prm) when injval >= every sample, which is one past
    # the end of wt -- clamp to the last entry.
    # Edge case -- FIXME: Is this right
    pval_idx = min(bisect.bisect(prm, injval), len(wt) - 1)
    pval = wt[pval_idx]
    return (prm, wt), pval

# Command line interface. Positional arguments are treated as input files
# (they seed the 'args' list that later sections extend).
optp = OptionParser()
optp.add_option("-i", "--input-cache", help="Cache of files to process.")
optp.add_option("-p", "--posterior-param", action="append", help="Parameter to calculate posterior over.")
optp.add_option("-I", "--injection-xml", help="Use injection information in a sim_inspiral table from this document.")
optp.add_option("-s", "--show-all", action="store_true", help="Show results from all mass points.")
optp.add_option("-v", "--verbose", action="store_true", help="Be verbose.")
optp.add_option("-O", "--output-tag", help="Use this tag in the plot name.")
optp.add_option("-o", "--output-file", help="Output JSON file with key-values corresponding to p-values for each parameter.")
optp.add_option("-S", "--second-stage", action="store_true", help="Plot cumulative probabilities from first stage -- make a P-P plot.")
optp.add_option("-d", "--search-dir", help="Search this directory for sqlite files.")
optp.add_option("-n", "--dont-ignore-nan", action="store_true", help="Instead of ignoring a nan value, will instead die.")
opts, args = optp.parse_args()

# Both stages iterate over (or take the length of) opts.posterior_param, which
# is None when the option was never given. Stop here with a readable message
# rather than a TypeError further down.
if not opts.posterior_param:
    sys.exit("At least one --posterior-param must be given")

#
# Input / Argument gathering
#

# First stage only: pick up sqlite files from the search directory, falling
# back to its "sqlite" subdirectory when nothing has been collected so far.
# --search-dir is optional, so skip the globbing when it was not given
# (concatenating None with a string would raise a TypeError).
if not opts.second_stage and opts.search_dir is not None:
    args.extend(glob.glob(os.path.join(opts.search_dir, "*.sqlite")))
    if len(args) == 0:
        args.extend(glob.glob(os.path.join(opts.search_dir, "sqlite", "*.sqlite")))

# Files listed in a cache file are appended to whatever was found above
if opts.input_cache is not None:
    with open(opts.input_cache) as cfile:
        args.extend(Cache.fromfile(cfile).pfnlist())

if len(args) == 0:
    sys.exit("No files to process")

if opts.injection_xml is not None:
    sim_table = table.get_table(utils.load_filename(opts.injection_xml), lsctables.SimInspiralTable.tableName)

    # Time-ordered list of (geocent_end_time, row) pairs; the first stage
    # indexes into it to find the injection closest to an event time.
    sim_table = sorted(sim_table, key=operator.attrgetter("geocent_end_time"))
    sim_table = [(row.geocent_end_time, row) for row in sim_table]

#
# Agglomerate data into a P-P plot
# This is the "second stage" -- the results of the first stage must be avaialble
# in the form of JSON key value pairs. This simply reorders the pairing by
# probability and makes the P-P plot
#
if opts.second_stage:
    pdict = defaultdict(list)
    for arg in args:

        if opts.verbose:
            print "Processing %s ... " % arg

        with open(arg) as jsonf:
            tmpdict = json.load(jsonf)
            if opts.dont_ignore_nan and any(map(math.isnan, tmpdict.values())):
                print >>sys.stderr, "NaN detected in parameters %s" % arg
                sys.exit(1)
            for p in opts.posterior_param:
                pdict[p].append(tmpdict[p])

    for p, vals in pdict.iteritems():
        # NOTE: Filtering out NaNs
        vals = sorted(filter(lambda v: not math.isnan(v), vals))
        pyplot.plot(numpy.linspace(0, 1.0, len(vals)), vals, '-', label=p)

    pyplot.plot([0,1], [0,1], 'k-.')
    pyplot.grid()
    pyplot.legend(loc="lower right")
    pyplot.xlabel("cumulative probability")
    pyplot.ylabel("cumulative confidence region")

    if opts.output_tag is not None:
        fname = "%s_pp_plot.png" % (opts.output_tag)
    else:
        fname = "%s_1d_pp_plot.png" % param
    pyplot.savefig(fname)

    #
    # No first stage
    #
    exit()

#
# First stage -- effectively construct the cumulant of the posterior in a given
# parameter
#
pvals = defaultdict(list)

all_weights, all_params = numpy.empty(0), numpy.empty((len(opts.posterior_param),0))

run_lengths = [0]
partial_evidence = []
for arg in args:

    if opts.verbose:
        sys.stdout.write("Processing %s ... " % arg)

    #
    # Load samples
    #
    try:
        samples = table.get_table(utils.load_filename(arg), lsctables.SimInspiralTable)
        runs = table.get_table(utils.load_filename(arg), lsctables.SnglInspiralTable)
    except sax.SAXParseException:
        samples = xmlutils.db_to_samples(arg, lsctables.SimInspiralTable, ["alpha1", "alpha2", "alpha3"] + opts.posterior_param)
        runs = xmlutils.db_to_samples(arg, lsctables.SnglInspiralTable, ["process_id", "snr"])

    if opts.verbose:
        print "read %d samples (prev. total %d)" % (len(samples), len(all_weights))

    #
    # Get event time, total samples
    #
    for pid, evid in runs:
        event_time = float(xmlutils.db_identify_param(arg, pid, "event_time"))
        # FIXME: Won't work if we terminate early
        n_max = int(xmlutils.db_identify_param(arg, pid, "n_max"))

        partial_evidence.append((n_max, evid))

    inj = None
    if opts.injection_xml is not None:
        #
        # Locate time ordered injection
        #
        idx = bisect.bisect(sim_table, event_time)
        idx = min(idx, len(sim_table)-1)

        #
        # Determine which injection is closer, left or right
        #
        tdiff_left = abs(sim_table[idx-1][0] - event_time)
        tdiff_right = abs(sim_table[idx][0] - event_time)
        inj = sim_table[idx-1][1] if tdiff_left < tdiff_right else sim_table[idx][1]

        if min(tdiff_left, tdiff_right) > 1.0:
            print >>sys.stderr, "WARNING: event located is %f seconds away from the true time." % min(tdiff_left, tdiff_right)
        if opts.verbose:
            print "Located injection at %10.6f" % inj.get_time_geocent()

    # Hook together columns of data
    all_params = numpy.hstack((all_params, [map(operator.attrgetter(param), samples) for param in opts.posterior_param]))

    # Append weights
    # NOTE: Be careful since we use 128 bit precision, evidence values can get
    # large enough to overflow otherwise
    #weights = numpy.exp(numpy.array([s.alpha1 for s in samples], dtype=numpy.float128))
    #weights = numpy.array([s.alpha2/s.alpha3 for s in samples])*weights
    weights = numpy.array([numpy.exp(s.alpha1)*s.alpha2/s.alpha3 for s in samples])
    all_weights = numpy.hstack((all_weights, weights))
    run_lengths.append(len(weights))

#
# Calculate run evidence
#
log_evidence = reduce(numpy.logaddexp, [ev + numpy.log(nm) for (nm, ev) in partial_evidence])
log_evidence -= numpy.log(sum([nm for (nm, ev) in partial_evidence]))

#
# Compare calculated evidence with ILE runs
#
# TODO: move this version into plot_like_contours
evid = reduce(numpy.logaddexp, numpy.log(all_weights)) - numpy.log(sum([nm for (nm, ev) in partial_evidence]))
print "Compare run log evidence %e with calculated log evidence %e" % (log_evidence, evid)

#
# Make a pivot table... kind of
#
all_params = numpy.rec.fromarrays(numpy.vstack((all_params, all_weights)), list(zip(opts.posterior_param + ["weights"], [numpy.float64]*(len(opts.posterior_param)+1))))

# FIXME: Explicit evocations of the garbage collector are never a good idea
gc.collect()

run_lengths = numpy.cumsum(run_lengths)

if opts.show_all:
    gs = gridspec.GridSpec(1, 2, width_ratios=[2,1])
else:
    gs = gridspec.GridSpec(1, 1, width_ratios=[2,1])

pyplot.figure()
for i, param in enumerate(opts.posterior_param):
    if opts.verbose:
        print "Analyzing parameter %s" % param

    #
    # Plot 1-D posteriors
    #
    pyplot.clf()

    if opts.show_all:
        # Plot individual mass points
        #pyplot.subplot(gs[0])
        pyplot.subplot(121)
        sub_pvals = []
        for s1, s2 in zip(run_lengths[:-1], run_lengths[1:]):
            (prm, wt), pval = get_cump_and_pval(all_params[s1:s2], param, getattr(inj, param))
            sub_pvals.append(pval)
            pyplot.plot(prm, wt, 'k-', alpha=0.3)

        # and histogram their pvals
        #pyplot.subplot(gs[1])
        pyplot.subplot(122)
        pyplot.hist(sub_pvals, histtype='step', orientation='horizontal')
        pyplot.ylim([0,1])
        pyplot.gca().get_yaxis().set_ticks([])
        pyplot.grid()

    # Mass averaged values
    if opts.show_all:
        #pyplot.subplot(gs[0])
        pyplot.subplot(121)

    (prm, wt), pval = get_cump_and_pval(all_params, param, getattr(inj, param))

    pyplot.plot(prm, wt, 'r-')

    # These arrays tend to be big blocks of memory, releasing them helps a bit
    del prm
    # FIXME: Explicit evocations of the garbage collector are never a good idea
    gc.collect()

    del wt
    # FIXME: Explicit evocations of the garbage collector are never a good idea
    gc.collect()

    pyplot.xlabel(param)
    pyplot.ylim([0, 1])
    pyplot.grid()

    if inj is not None:
        pyplot.axvline(getattr(inj, param))

    if opts.output_tag is not None:
        fname = "%s_%s_1d_posterior.png" % (opts.output_tag, param)
    else:
        fname = "%s_1d_posterior.png" % param
    pyplot.savefig(fname)

    if opts.verbose:
        print param, pval
    pvals[param] = pval
    pyplot.clf()

    # FIXME: Explicit evocations of the garbage collector are never a good idea
    gc.collect()

if opts.output_file is not None:
    with open(opts.output_file, "w") as jsonf:
        json.dump(pvals, jsonf)
