#!/usr/bin/env python
# Copyright (C) 2013 Chris Pankow
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys

import matplotlib
matplotlib.use("Agg")
from matplotlib import pyplot
import matplotlib.cm
from matplotlib.collections import LineCollection

import numpy

from glue.ligolw import utils, table, lsctables
from glue.segments import segment

from optparse import OptionParser

import xmlutils

optp = OptionParser()
optp.add_option("--plot-param", action="append", help="Plot this parameter, give multiple times for multiple arguments.")
optp.add_option("--top-n", type=int, default=1000, help="Take only the top N events in likelihood. Default is 1000.")
opts, args = optp.parse_args()

#
# Load the samples to plot.
#
# Each positional argument is first read as a LIGOLW XML file holding a
# sim_inspiral table; if that fails the argument is handed to
# xmlutils.db_to_samples instead. If anything in that process raises, the
# (single) argument is read as a plain ASCII table as a last resort.
#
try:
    # --plot-param is an "append" option, so it is None (not an empty list)
    # when the flag is never given. Test truthiness instead of len() so the
    # defaults below are reachable rather than raising TypeError.
    if opts.plot_param:
        dims = opts.plot_param
    else:
        #FIXME: Bad default
        dims = ["polarization", "coa_phase", "distance", "inclination", "latitude", "longitude", "geocent_end_time", "alpha1"]

    tmp = None
    for arg in args:
        try:
            samples = table.get_table(utils.load_filename(arg), lsctables.SimInspiralTable.tableName)

            #
            # Hack around nanoseconds: fold the nanosecond column into the
            # end time so it can be plotted as a single float
            #
            for row in samples:
                row.geocent_end_time = row.geocent_end_time + 1e-9*row.geocent_end_time_ns
        # FIXME: Catch the right exception. Exception (rather than a bare
        # except) at least lets KeyboardInterrupt/SystemExit through.
        except Exception:
            samples = xmlutils.db_to_samples(arg, lsctables.SimInspiralTable, dims)

        # Accumulate the rows of every input file into one container
        if tmp is None:
            tmp = samples
        else:
            tmp.extend(samples)

    samples = tmp

# FIXME: Catch the right exception
except Exception:
    if len(args) != 1:
        raise ValueError("Can't handle more than one ASCII file argument currently.")

    # FIXME: Need proper header
    dims = ["efac", "equad", "jitter", "red_amp", "red_index", "alpha1"]
    from collections import namedtuple
    Sample = namedtuple("Sample", dims)
    # ndmin=2 keeps a one-row file as a single row; without it loadtxt
    # returns a 1-D array and iterating it yields scalars, not rows.
    samples = [Sample(**dict(zip(dims, row))) for row in numpy.loadtxt(args[0], ndmin=2)]

# samples is still None when no input file was given at all
if not samples:
    raise ValueError("No samples were loaded from the input file(s).")

# TODO: Density of lines will tell us how 'high' this should be
pyplot.figure()
# The pseudo axes are drawn at x = 0, 2, 4, ..., 2*(len(dims)-1). Size the
# view from the number of dimensions (one unit of margin past the last axis)
# so none of them is cut off; this gives the previous [0, 9] for five
# dimensions.
pyplot.xlim([0, max(1, 2*len(dims) - 1)])
# Every dimension is normalized to [0, 1] before plotting
pyplot.ylim([0, 1])

def _log_evidence(row):
    """Return the value used to rank and color one sample row.

    This is ``alpha1 + log(alpha2/alpha3)`` when the row has all three
    attributes. Rows without alpha2 or alpha3 (the namedtuples built for
    ASCII input only define alpha1) would otherwise raise AttributeError, so
    they are ranked on alpha1 alone.
    NOTE(review): confirm alpha1 by itself is the intended fallback weight.
    """
    try:
        return row.alpha1 + numpy.log(row.alpha2/row.alpha3)
    except AttributeError:
        return row.alpha1

if not samples:
    raise ValueError("No samples to plot.")

#
# Plot high likelihoods last
#
samples = sorted(samples, key=_log_evidence)
# Only the last (highest ranked) top_n samples are drawn
top_samples = samples[-opts.top_n:]

#
# Normalize the sample values between 0-1
#
# sample_dict: dimension name -> array of normalized values
# prange:      dimension name -> segment(min, max) of the raw values
sample_dict, prange = {}, {}
for dim in dims:
    d = [getattr(r, dim) for r in top_samples]
    dmin, dmax = min(d), max(d)
    drange = dmax-dmin
    if drange == 0:
        # Constant dimension: put every sample in the middle of the axis
        # rather than dividing by zero
        sample_dict[dim] = numpy.ones(len(d))*0.5
    else:
        sample_dict[dim] = (numpy.array(d)-dmin)/drange
    # Same text as the Python 2 print statement, valid on Python 3 too
    print("%s %s %s %s" % (dim, dmin, dmax, drange))
    prange[dim] = segment(dmin, dmax)

#
# Get a color scale for the integrand value
#
logevid = numpy.array([_log_evidence(r) for r in top_samples])
lmin, lmax = min(logevid), max(logevid)
sample_dict["logevidence"] = logevid

intg_norm = matplotlib.colors.Normalize(lmin, lmax)
sm = matplotlib.cm.ScalarMappable(norm=intg_norm)

# x positions of the pseudo axes, one every two units
axes_xloc = list(range(0, 2*len(dims), 2))

#
# Reorder the arrays to correspond to log evidence --- this is to ensure the
# highest weighted samples are plotted on top
#
resort = logevid.argsort()
for key in sample_dict:
    sample_dict[key] = sample_dict[key][resort]
logevid = sorted(logevid)

#
# Build one polyline per sample: it visits every pseudo axis in order, at the
# height of the sample's normalized value for that dimension, and is colored
# by the sample's log evidence.
#
line_segs, colors = [], []
for j, loge in enumerate(logevid):
    colors.append(matplotlib.cm.hot_r(sm.norm(loge)))
    # An explicit list of (x, y) points; a bare zip() is an iterator on
    # Python 3 and is not a valid segment for LineCollection.
    line_segs.append([(axes_xloc[i], sample_dict[d][j]) for i, d in enumerate(dims)])

# A single LineCollection draws all the lines at once
lc = LineCollection(line_segs, linewidth=0.5, color=colors)
pyplot.gca().add_collection(lc)

#
# Draw the pseudo axes for the sample dimensions
#
for xloc in axes_xloc:
    # Full-height (ymin=0, ymax=1 in axes coordinates) thick black line
    pyplot.axvline(xloc, 0, 1, color='k', linewidth=5)

#
# So we know what the colors actually mean
#
sm.set_array(sample_dict["logevidence"])
sm.set_cmap(matplotlib.cm.hot_r)
# sm is a free-standing mappable (it was never attached to an Axes), so tell
# the colorbar explicitly which Axes to take its space from.
pyplot.colorbar(sm, ax=pyplot.gca())

#
# Label the pseudo axes properly: the bottom ticks carry the dimension name
# and the minimum of its plotted range
#
# Work on explicit Axes handles instead of whatever pyplot considers the
# current Axes, which changes once the twin axis is created.
ax1 = pyplot.gca()
ax1.set_xticks(axes_xloc)
ax1.set_xticklabels(["%s %1.2g" % (d, prange[d][0]) for d in dims], rotation=60)

#
# Get another axis to label the extent: the top ticks carry the maximum of
# each dimension's plotted range
#
ax2 = ax1.twiny()
for xloc in axes_xloc:
    ax2.axvline(xloc, 0, 1, color='k', linewidth=5)
ax2.set_xlim(ax1.get_xlim())
ax2.set_xticks(axes_xloc)
# Build the labels directly rather than overwriting prange with strings
ax2.set_xticklabels(["%1.2g" % prange[d][1] for d in dims])

# The y axis is in normalized units, so its ticks carry no information
ax1.set_yticks([])
ax2.set_yticks([])

# the 'tight' is so the figure doesn't cut off the bottom of the tick labels
pyplot.savefig("parallel.png", bbox_inches='tight')
