#!/usr/bin/env python
import sys
import operator
from optparse import OptionParser

import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot
import matplotlib.cm

import numpy

from glue.ligolw import ligolw, table, lsctables
lsctables.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import utils

optp = OptionParser()
optp.add_option("-p", "--parameter", default="alpha2", help="Parameter to plot on the x-axis. Default is the prior.")
optp.add_option("-r", "--relabel-parameter", help="Label the parameter this on the plot.")
optp.add_option("-l", "--log-scale-parameter", action="store_true", help="Plot the parameter in logscale. Default is no.")
optp.add_option("-a", "--plot-adaptive-samples", type=int, help="Indicate which samples were in the adaptive period. A number must be supplied to indicate the index of the last adaptive sample. Default is off.")
optp.add_option("-o", "--plot-outliers", type=float, help="Indicate which samples account for more than a given percentage of the total summed weight. A number must be supplied to indicate the threshold of percentage to plot. Default is off.")
optp.add_option("-P", "--plot-prior-ratio", type=float, help="Indicate which samples have prior to sampling prior in excess of a given threshold. A number must be supplied to indicate the threshold. Default is off.")
optp.add_option("-O", "--plot-name", default="sample_scatter.png", help="Output plot to this filename.")
opts, args = optp.parse_args()

# Read the SimInspiral table from every input file and concatenate the rows
# into a single table, ``tbl``.
if not args:
    # Without this check ``tbl`` would stay None and the script would fail
    # later with an unhelpful TypeError when iterating over it.
    optp.error("at least one input XML file must be given")

tbl = None
for arg in args:
    # Use the same content handler for every file; the original only passed
    # it for the first file.
    xmldoc = utils.load_filename(arg, contenthandler=ligolw.LIGOLWContentHandler)
    sim_tbl = table.get_table(xmldoc, lsctables.SimInspiralTable.tableName)
    if tbl is None:
        tbl = sim_tbl
    else:
        tbl.extend(sim_tbl)

# Per-sample log weight: alpha1 (exponentiated below to give the plotted
# likelihood) plus the log of alpha2/alpha3 (the prior over the sampling prior,
# per the --plot-prior-ratio handling).
pevid = numpy.array([s.alpha1 + numpy.log(s.alpha2 / s.alpha3) for s in tbl])

# Sort ascending by log weight. ``idxsorted`` is reused to put every other
# per-sample array in the same order, so pevid[0] and pevid[-1] are the
# extremes of the color scale.
idxsorted = pevid.argsort()
pevid = pevid[idxsorted]

# Log of the total summed weight, accumulated in log space. The ufunc's own
# reduce replaces the Python-level builtin ``reduce`` (which is not a builtin
# on Python 3) and runs in native code.
norm = numpy.logaddexp.reduce(pevid)

#
# Scatter all samples in parameter/likelihood space, colored by log weight
#
# Values are read in table order and then permuted with idxsorted so that they
# line up element-for-element with the sorted pevid array.
get_param = operator.attrgetter(opts.parameter)
xaxis = numpy.array([get_param(s) for s in tbl])[idxsorted]
like = numpy.array([numpy.exp(s.alpha1) for s in tbl])[idxsorted]
pyplot.scatter(
    xaxis, like,
    c=pevid, cmap=matplotlib.cm.spectral,
    s=3, edgecolor='none',
)

# The likelihood axis is always logarithmic; the parameter axis only on request.
set_axes_scale = pyplot.loglog if opts.log_scale_parameter else pyplot.semilogy
set_axes_scale()

#pyplot.xlim(min(xaxis)/10, max(xaxis)*10)
# Leave a decade of head room above and below the likelihood range.
like_lo, like_hi = min(like), max(like)
pyplot.ylim(like_lo/10, like_hi*10)

#
# Highlight samples whose prior to sampling prior ratio exceeds the requested
# threshold
#
prior_outlier = []
if opts.plot_prior_ratio is not None:
    # alpha2 holds the prior and alpha3 the sampling prior, both reordered to
    # match the sorted pevid array.
    prior = numpy.array([s.alpha2 for s in tbl])[idxsorted]
    ps = numpy.array([s.alpha3 for s in tbl])[idxsorted]
    prior_outlier = numpy.argwhere(prior/ps > opts.plot_prior_ratio)
    if len(prior_outlier) > 0:
        # Pin vmin/vmax to the full weight range so the subset shares the
        # color scale of the main scatter.
        pyplot.scatter(
            xaxis[prior_outlier], like[prior_outlier],
            c=pevid[prior_outlier], cmap=matplotlib.cm.spectral,
            vmin=pevid[0], vmax=pevid[-1],
            s=20, marker="s",
            label="p/ps > %1.2f" % opts.plot_prior_ratio,
        )

#
# Plot samples which individually account for more than a given fraction of
# the total weight
#
outliers = []
# The threshold must be a positive fraction; None (option not given) and
# non-positive values leave the feature off.
if opts.plot_outliers is not None and opts.plot_outliers > 0:
    # exp(pevid) > exp(norm) * f  is equivalent to  pevid > norm + log(f).
    # Comparing in log space avoids overflowing exp() to inf (or underflowing
    # it to 0) for large-magnitude log weights, which would make the
    # linear-space comparison false for every sample.
    log_threshold = norm + numpy.log(opts.plot_outliers)
    outliers = numpy.argwhere(pevid > log_threshold)
if len(outliers) > 0:
    # vmin/vmax are pinned to the full weight range so the highlighted subset
    # uses the same color scale as the main scatter.
    pyplot.scatter(
        xaxis[outliers], like[outliers],
        s=20, marker="v",
        cmap=matplotlib.cm.spectral, c=pevid[outliers],
        vmin=pevid[0], vmax=pevid[-1],
        label="> %d%% evidence" % int(100*opts.plot_outliers),
    )
#print numpy.hstack((numpy.exp(pevid[outliers])/numpy.exp(norm), numpy.log10(prior[outliers]/ps[outliers]), ids[outliers]))

#
# Plot samples which were generated during the adaptive period
#
early = []
# Test against None rather than truthiness so that an explicitly supplied
# index of 0 is honored instead of silently disabling the feature; this also
# matches how --plot-prior-ratio is checked.
if opts.plot_adaptive_samples is not None:
    # simulation_id is converted to a plain integer and reordered to match
    # the sorted pevid array.
    ids = numpy.array([int(s.simulation_id) for s in tbl])[idxsorted]
    early = numpy.argwhere(ids <= opts.plot_adaptive_samples)
if len(early) > 0:
    # vmin/vmax are pinned to the full weight range so the highlighted subset
    # uses the same color scale as the main scatter.
    pyplot.scatter(
        xaxis[early], like[early],
        s=20, marker="^",
        cmap=matplotlib.cm.spectral, c=pevid[early],
        vmin=pevid[0], vmax=pevid[-1],
        label="in adaptive period",
    )

# Decorate the figure and write it out.
cbar = pyplot.colorbar()
cbar.set_label("log weight")
pyplot.grid()

# Prefer the user-supplied label; fall back to the raw parameter name.
if opts.relabel_parameter:
    x_label = opts.relabel_parameter
else:
    x_label = opts.parameter
pyplot.xlabel(x_label)
pyplot.ylabel("likelihood")

pyplot.legend(loc="lower right")
pyplot.savefig(opts.plot_name)
