#!/usr/bin/env python
import sys
import types
from optparse import OptionParser

import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot
import matplotlib.cm
from matplotlib.mlab import griddata

from scipy import interpolate

import numpy

from glue.ligolw import utils, table, lsctables

import xmlutils

def calculate_evidence(snr_ar, log=False):
    """
    Average the exponentials of the entries of snr_ar.

    snr_ar -- non-empty sequence of numbers (the entries are exponentiated,
        so they are treated as logarithmic quantities).
    log -- if True, return the logarithm of the average, accumulated with
        numpy.logaddexp so that large entries do not overflow; otherwise
        return the average of exp(entry) directly.

    The reduction runs along the first axis and is normalised by the length
    of that axis.

    Raises ValueError if snr_ar is empty.
    """
    # FIXME: Doesn't understand runs of varying length yet
    samples = numpy.asarray(snr_ar, dtype=float)
    count = len(samples)
    if count == 0:
        raise ValueError("calculate_evidence requires at least one sample")

    if log:
        # ufunc.reduce replaces the builtin reduce(), which no longer exists
        # on Python 3.
        return numpy.logaddexp.reduce(samples, axis=0) - numpy.log(count)
    return numpy.exp(samples).sum(axis=0) / count

def mchirp_eta(m1, m2):
    """
    Convert a pair of component masses into the tuple (mchirp, eta), where
    eta = m1*m2/(m1+m2)**2 and mchirp = eta**(3/5) * (m1+m2).
    """
    total_mass = m1 + m2
    eta = m1*m2/total_mass**2
    chirp_mass = eta**(3.0/5)*total_mass
    return chirp_mass, eta

# Command line interface.  Each entry is (flag, add_option keyword arguments);
# the entries are registered in the order listed, which is also the order
# they are shown in by --help.
optp = OptionParser()
for _flag, _settings in (
    ("--dimension1", {"help": "Plot this column along the x-axis"}),
    ("--dimension2", {"help": "Plot this column along the y-axis"}),
    ("--log-evidence", {"action": "store_true", "help": "Plot log evidence"}),
    ("--partial-evidence", {"action": "store_true", "help": "Plot the partial evidence (L*p)"}),
    ("--show-points", {"action": "store_true", "help": "Scatter plot the points used."}),
    ("--show-mdc", {"help": "Mark MDC values from provided XML on the plots."}),
    ("--no-fill-contours", {"action": "store_true", "help": "Don't fill contours created."}),
    ("--fig-extension", {"help": "Type of figure (png, pdf, jpg)", "default": "png"}),
    ("--input-cache", {"help": "Read files in from this cache"}),
):
    optp.add_option(_flag, **_settings)

# opts holds the parsed options, args the positional input file names.
opts, args = optp.parse_args()

from glue.lal import Cache
# Files listed in the cache are appended to the positional input files.
if opts.input_cache:
    with open(opts.input_cache) as cache_file:
        cached_files = Cache.fromfile(cache_file).pfnlist()
    args.extend(cached_files)

# Three modes, selected by the requested dimensions:
#   * both dimensions are mass-like: one (point, evidence) pair per input
#     file, the evidence being computed from that file's snr samples;
#   * only --dimension1 given: 1-D likelihood / histogram plot, after which
#     the script exits;
#   * otherwise: one (point, likelihood) pair per row of the first input file.
# The first and last modes leave d1, d2 and like for the gridding code below.
if opts.dimension1 in ["mass1", "mass2", "mchirp", "eta"] and opts.dimension2 in ["mass1", "mass2", "mchirp", "eta"]:
    points, like = [], []
    for arg in args:
        # Try the file as LIGO_LW XML first and fall back to the database
        # reader if that fails.
        try:
            tbl = table.get_table(utils.load_filename(arg), lsctables.SnglInspiralTable.tableName)
        except Exception:
            # FIXME: hardcoding
            tbl = xmlutils.db_to_samples(arg, lsctables.SnglInspiralTable, ["mass1", "mass2", "snr"])

        # The accumulation below runs for both loaders.  It used to sit inside
        # the except clause, so files that loaded as XML contributed nothing.
        like.append(calculate_evidence([row.snr for row in tbl if row.snr is not None], log=opts.log_evidence))

        # Only the first row's masses are used to locate this file's point.
        first_row = next(iter(tbl))

        # FIXME: order
        if opts.dimension1 in ["mchirp", "eta"] and opts.dimension2 in ["mchirp", "eta"]:
            points.append( mchirp_eta(first_row.mass1, first_row.mass2) )
        else:
            points.append( (first_row.mass1, first_row.mass2) )

    d1, d2 = numpy.array(points).T
elif opts.dimension2 is None:

    # Value of the plotted column for the single injection in the MDC file.
    if opts.show_mdc:
        sim = table.get_table(utils.load_filename(opts.show_mdc), lsctables.SimInspiralTable.tableName)
        assert len(sim) == 1
        if opts.dimension1 == "geocent_end_time":
            sim = float(sim[0].get_end())
        else:
            sim = getattr(sim[0], opts.dimension1)

    try:
        points = table.get_table(utils.load_filename(args[0]), lsctables.SimInspiralTable.tableName)

        # FIXME: We really need to get over this whole seconds/nanoseconds thing
        for point in points:
            point.geocent_end_time = point.geocent_end_time + 1e-9*point.geocent_end_time_ns

    except Exception:
        # FIXME: hardcoding
        points = xmlutils.db_to_samples(args[0], lsctables.SimInspiralTable, [opts.dimension1, "alpha1", "alpha2", "alpha3"])

    fig = pyplot.figure()
    bottom = pyplot.subplot2grid((3,1), (2,0))
    top = pyplot.subplot2grid((3,1), (0,0), rowspan=2)

    if opts.show_mdc:
        top.axvline(sim)
        bottom.axvline(sim)

    # FIXME: Hardcoded bad column name
    # After sorting and transposing: points[0] is the plotted column,
    # points[1] is alpha1 and points[2] is alpha2, ordered by the column value.
    points = [(getattr(row, opts.dimension1), row.alpha1, row.alpha2) for row in points]
    points = numpy.array(sorted(points)).T

    if opts.partial_evidence:
        top.plot(points[0], points[1]*numpy.log(points[2]), "k-")
    else:
        top.plot(points[0], points[1], "k-")

    # FIXME: For better plotting resolution, bin by dL/dx rather than linearly
    phist, bins = numpy.histogram(points[0], bins=100)
    like_intp = interpolate.interp1d(points[0], points[1])
    like_bins = like_intp(bins)

    # FIXME: Center bins
    # Redraw the curve bin by bin, colouring each piece by the number of
    # samples that fell into that bin.
    bin_norm = matplotlib.colors.Normalize(min(phist), max(phist))
    sm = matplotlib.cm.ScalarMappable(norm=bin_norm)
    begin = numpy.vstack((bins[:-1], bins[1:])).T
    end = numpy.vstack((like_bins[:-1], like_bins[1:])).T
    # list() keeps this working where zip() returns an iterator.
    segments = numpy.array(list(zip(begin, end)))
    for i, seg in enumerate(segments):
        if i >= len(phist): break
        top.plot(seg[0], seg[1], color=matplotlib.cm.hsv(sm.norm(phist[i])) )
    sm.set_array(like_bins)
    sm.cmap = matplotlib.cm.hsv
    pyplot.gcf().colorbar(sm)

    # NOTE: Reenable for inset histogram
    #ax_inset = fig.add_axes([0.4, 0.2, 0.3, 0.3])
    #ax_inset.hist(points[0], bins=100)
    bottom.hist(points[0], bins=100, normed=True)

    # Plot prior
    bottom.plot(points[0], points[2], 'k-')

    bottom.set_xlabel(opts.dimension1)
    top.grid()
    bottom.grid()

    pyplot.savefig(opts.dimension1+"_likelihood." +opts.fig_extension )
    # The 1-D plot is complete; skip the 2-D gridding below.
    sys.exit()

else:
    # The MDC values for this mode are looked up by the plotting code further
    # down, right before they are drawn, so nothing is loaded for them here.

    try:
        points = table.get_table(utils.load_filename(args[0]), lsctables.SimInspiralTable.tableName)

        # FIXME: We really need to get over this whole seconds/nanoseconds thing
        for point in points:
            point.geocent_end_time = point.geocent_end_time + 1e-9*point.geocent_end_time_ns

    except Exception:
        # FIXME: hardcoding
        points = xmlutils.db_to_samples(args[0], lsctables.SimInspiralTable, [opts.dimension1, opts.dimension2, "alpha1", "alpha2"])

    # FIXME: Hardcoded bad column name
    like = [row.alpha1 for row in points]
    # FIXME: Hardcoded bad column name and not used in some cases
    prior = [row.alpha2 for row in points]

    if opts.partial_evidence:
        # like is a plain list, so "like += array" would extend it with the
        # log-prior values instead of adding them element-wise.
        # NOTE(review): adding log(prior) presumes alpha1 is a log likelihood
        # -- confirm.
        like = numpy.asarray(like) + numpy.log(prior)

    points = [(getattr(row, opts.dimension1), getattr(row, opts.dimension2)) for row in points]
    d1, d2 = numpy.array(points).T

#
# Determine grid limits
#
d1min, d1max = min(d1), max(d1)
d2min, d2max = min(d2), max(d2)
d1step, d2step = (d1max-d1min)/100.0, (d2max-d2min)/100.0

# A zero step means every sample shares one value along that axis: there is
# no plane to grid, and numpy.arange cannot build an axis with a zero step.
if d1step == 0 or d2step == 0:
    sys.exit("Cannot grid %s against %s: all samples share a single value along at least one axis." % (opts.dimension1, opts.dimension2))

#
# Create grid to interpolate to
#
d1g, d2g = numpy.meshgrid(numpy.arange(d1min, d1max, d1step), numpy.arange(d2min, d2max, d2step))

# FIXME: Reenable when available
#from scipy.interpolate import griddata
#int_like = griddata((d1, d2), like, (d1g, d2g))
int_like = griddata(d1, d2, like, d1g, d2g)


# Contours of the interpolated surface, 20 levels, filled unless disabled.
if opts.no_fill_contours:
    pyplot.contour(d1g, d2g, int_like, 20)
else:
    pyplot.contourf(d1g, d2g, int_like, 20)

if opts.show_mdc:

    # Mark the single injection in the MDC file with a crosshair.
    sim = table.get_table(utils.load_filename(opts.show_mdc), lsctables.SimInspiralTable.tableName)
    assert len(sim) == 1
    injection = sim[0]
    # geocent_end_time is read through get_end() rather than as a plain
    # column; every other dimension is a direct attribute of the row.
    d1_sim, d2_sim = [
        float(injection.get_end()) if column == "geocent_end_time" else getattr(injection, column)
        for column in (opts.dimension1, opts.dimension2)
    ]
    pyplot.axvline(d1_sim, color='k', linewidth=2, alpha=0.5)
    pyplot.axhline(d2_sim, color='k', linewidth=2, alpha=0.5)

pyplot.colorbar()
pyplot.grid()
if opts.show_points:
    pyplot.scatter(d1, d2, s=5.0, c=like)
pyplot.xlabel(opts.dimension1)
pyplot.ylabel(opts.dimension2)

# Output name: <dimension1>_<dimension2>_(log)evidence.<extension>
suffix = "logevidence" if opts.log_evidence else "evidence"
pyplot.savefig("%s_%s_%s.%s" % (opts.dimension1, opts.dimension2, suffix, opts.fig_extension))
