#! /usr/bin/env python
#
#  - construct 1d cumulative (from flatfile), given argument type specification.
#     - can also construct 2d skymap
#  - evaluate the p value of the injection's location (or constant L contour)
#  - plot the cumulative
#
# USAGE: sqlite input
#    ./postprocess_1d_cumulative --save-sampler-file test.sqlite
#         # input from 'test.sqlite'
#         # output to test-*  (the '.sqlite' extension is stripped to form the output prefix)
# USAGE: ASCII INPUT
#    postprocess_1d_cumulative --save-sampler-file flatfile  --inj inj.xml --event 0
#          # input from flatfile-points.dat
#          # output to flatfile-posterior-*
#                 - skymap
#                 - 1d cumulatives in each parameter
#                 - cumulative data at injection parameters
#    postprocess_1d_cumulative --save-sampler-file flatfile  --adapt-beta 0
#          # weight as L^\beta p/ps
#          # using 0 treats all mass points the same.
#          # Note that if only a fraction of the extrinsic points are saved, using \beta=0 produces
#             a conservative skymap
#
# PHYSICS NOTES
#   - *Requires a mass prior implicitly! *
#        : this script joins results from different simulations, hence different masses.
#          Right now, this prior is *not* properly implemented when creating relative weights at different mass points:
#          In practice, we are using a uniform prior in mchirp, eta.  That's not the physical prior. 
#      ===> This needs to be corrected. <===
#
# PORTABILITY NOTES
#    np.histogram :  density argument not supported on LIGO clusters' version

# Setup. 
import numpy as np
import lal
import lalsimutils
import bisect
# Plot always, for now
import matplotlib
#matplotlib.use("Agg")  # Uncomment this line on clusters.  Need a better solution...
fExtension = "jpeg"
if matplotlib.get_backend() == "MacOSX":
    fExtension = "jpeg"
if matplotlib.get_backend() == "agg":
    fExtension  = "png"
print " matplotlib backend ", matplotlib.get_backend(), " so using file type ", fExtension
from matplotlib import pylab as plt
from mpl_toolkits.mplot3d import Axes3D


def mean_and_dev(arr, wt):
    """Return the weighted mean and weighted variance of a set of samples.

    Parameters
    ----------
    arr : array-like
        Sample values.
    wt : array-like
        Non-normalised weights, one per sample.

    Returns
    -------
    list
        ``[mean, variance]``.  Note that, despite the function name, the
        second entry is the variance and not the standard deviation.

    Raises
    ------
    ZeroDivisionError
        Raised by ``np.average`` when the weights sum to zero.
    """
    arr = np.asarray(arr, dtype=float)
    av = np.average(arr, weights=wt)
    # Use the centred second moment.  The algebraically equivalent form
    # <x^2> - <x>^2 loses all precision when |mean| is much larger than the
    # spread (e.g. a time near 1e9 s with a sub-second spread).
    var = np.average((arr - av)**2, weights=wt)
    return [av, var]

def pcum_at(val, arr, wt):
    """Return the fraction of the total weight held by samples above `val`.

    Parameters
    ----------
    val : scalar
        Threshold value.
    arr : numpy array
        Sample values.
    wt : numpy array
        Weights, aligned with `arr`.

    Returns
    -------
    The sum of ``wt / sum(wt)`` over the samples with ``arr > val``
    (strict inequality).
    """
    total_weight = np.sum(wt)
    above = np.nonzero(arr > val)
    return np.sum(wt[above] / total_weight)


# Parse standard arguments
# Command-line interface
import argparse

parser = argparse.ArgumentParser()

# --- input / injection selection ---
# Base name of the sampler output; a name ending in ".sqlite" selects the
# sqlite reader further down, anything else is treated as an ASCII prefix.
parser.add_argument(
    "--save-sampler-file",
    dest="points_file_base",
    default="sampler-output-file",
)
parser.add_argument(
    "--inj-xml",
    dest='inj',
    default=None,
    help="inspiral XML file containing injection information.",
)
parser.add_argument(
    "--event",
    dest="event_id",
    type=int,
    default=0,
    help="event ID of injection XML to use.",
)

# --- weighting / sample reduction ---
# Exponent applied to the likelihood when building weights (L^beta p/ps).
parser.add_argument("--adapt-beta", type=float, default=1)
# Cumulative-weight cutoff used to drop the lightest samples; 0 keeps all.
parser.add_argument("--save-P", type=float, default=0.)

# --- markers on the density plots ---
parser.add_argument(
    "--star-P",
    type=float,
    default=0.01,
    help="For density plots, star the top fraction of the cumulative weights. Helpful for identifying 1d regions that are undersampled and overweighted.",
)
parser.add_argument(
    "--star-pratio",
    type=float,
    default=0.3,
    help="For density plots, star conditions when p/ps > pratio.  Useful for identifying overweighted points (e.g. unusually low ps)",
)

# --- stage switches ---
parser.add_argument("--disable-triplot", action="store_true", default=False)
parser.add_argument("--disable-1d-density", action="store_true", default=False)
# The default is already True, so this flag cannot change anything: the
# skymap stage is always skipped (disabled because of bugs with skymaps).
parser.add_argument("--disable-skymap", action="store_true", default=True)

opts = parser.parse_args()

values ={}
metadata = {}
ppdata = {}

# TEST IF SQLITE FILE SPECIFIED
fExtensionInputFile = opts.points_file_base.split(".")[-1]
if not fExtensionInputFile == "sqlite":

    # NORMAL ASCII INPUT
    # Import flatfile (default).  Modify in future to take xml
    # Should use first line to establish associations!
    # For now, assume output in format provided by 'convert_output_format_ile2inference':
    # Columns assumed sorted as <stuff>, tref, phi, incl, psi, ra, dec, dist, lnL,p, ps
    #  * assume* we have already sorted these by importance?
    ret = np.loadtxt(opts.points_file_base+"-points.dat")
    lnLmax = np.max(ret[:,-1])  # only relative weights are needed, so avoid loss of precision and infinty
    weights = (np.exp(opts.adapt_beta*(ret[:,-1] - lnLmax))*ret[:,-3]/ret[:,-2])*ret[:,-3-8]   # add number of points to weights: NZ is additive; the overall # normalization is created elsewhere.  Note we can't compute the evidence in this strategy
    p = ret[:,-3]
    ps = ret[:,-2]
# if opts.adapt_beta == 0:
#     weights = np.ones(len(weights))

    # Metadata and pp data
    # Note mc, eta not always supported, but m1, m2 *are very likely* always be the first two
    values["ra"] = ret[:,-3-7]
    values["dec"] = ret[:,-3-6]
    values["tref"] = ret[:,-3-5]
    values["phi"] = ret[:,-3-4]
    values["incl"] = ret[:,-3-3]
    values["psi"] = ret[:,-3-2]
    values["dist"] = ret[:,-3-1]
    values["mc"] = lalsimutils.mchirp(ret[:,0],ret[:,1])
    values["eta"] = lalsimutils.symRatio(ret[:,0],ret[:,1])
    values["lnL"] = ret[:,-1]
    values["m1"] = ret[:,0]
    values["m2"] = ret[:,1]
else:
    import xmlutils 
    from glue.ligolw import lsctables, utils, table

    # SQLITE INPUT
    fnameIn =opts.points_file_base
    opts.points_file_base = ".".join(opts.points_file_base.split(".")[:-1])   # strip the .sqlite extension
    print "Writing to ", opts.points_file_base

    # Get sample array
    samples = xmlutils.db_to_samples(fnameIn, lsctables.SimInspiralTable,["mass1", "mass2", "simulation_id", "inclination", "distance", "longitude", "latitude", "geocent_end_time", "coa_phase", "polarization","alpha1", "alpha2","alpha3"])

    # Extract the number of samples used. *Not* available -- must count sim_id.  
    # For now, ignore it: assume all points come from the same # of trials

    # Populate data structures described above
    values["lnL"] = np.array([s.alpha1 for s in samples])
    p = np.array([s.alpha2 for s in samples])
    ps = np.array([s.alpha3 for s in samples])
    lnLmax = np.max(values["lnL"])
    weights = np.exp(values["lnL"]-lnLmax)*p/ps

    # For compatibility with ascii input, use the same labels as above.  Obviously a cleaner python-y implementation would loop over parameters, using 'getattr'
    values["ra"]   = np.array([s.longitude for s in samples])
    values["dec"]  = np.array([s.latitude for s in samples])
    values["tref"]  = np.array([s.geocent_end_time for s in samples])
    values["phi"]  = np.array([s.coa_phase for s in samples])
    values["incl"] = np.array([s.inclination for s in samples])
    values["dist"] = np.array([s.distance for s in samples])
    values["psi"]  = np.array([s.polarization for s in samples])
    values["m1"] = np.array([s.mass1 for s in samples])
    values["m2"] = np.array([s.mass2 for s in samples])
    values["mc"]  = lalsimutils.mchirp(values["m1"],values["m2"])
    values["eta"]  = lalsimutils.symRatio(values["m1"],values["m2"])

# For density plots, create index of points to highlight on the x axis.
if opts.star_P:
    idx_sorted_index = np.lexsort((np.arange(len(weights)), weights))  # Sort the array of weights, recovering index values
    indx_list = np.array( [[k, weights[k]] for k in idx_sorted_index])   # pair up with the weights again
    cum_sum = np.cumsum(indx_list[:,1])                                     # find the cumulative sum
    cum_sum = cum_sum/cum_sum[-1]                                         # normalize the cumulative sum
    indx_star_list = [indx_list[k, 0] for k, value in enumerate(cum_sum > opts.star_P) if value]  # find the indices that preserve > 1-star_P  of total probability

if opts.star_pratio and opts.star_P:
    my_ratio = p/ps
    indx_ratio_list = [k for k, value in enumerate(my_ratio > opts.star_pratio ) if value]  # find the indices that pass the threshold
    indx_ratio_list = list( set(indx_ratio_list) & set(indx_star_list))                             # Only show if p/ps is large



# Truncate the sample set using save-P.  No science loss; massive sample reduction; speeds up the very slow triplot stage
# Note we will sort *again* later for the cumulative  Going forward, reduce redundant worting
if opts.save_P:
    print " Reducing sample size using cumulative cutoff ", opts.save_P
    print " Original sample size : ", len(values["phi"])
    idx_sorted_index = np.lexsort((np.arange(len(weights)), weights))  # Sort the array of weights, recovering index values
    indx_list = np.array( [[k, weights[k]] for k in idx_sorted_index])   # pair up with the weights again
    cum_sum = np.cumsum(indx_list[:,1])                                     # find the cumulative sum
    cum_sum = cum_sum/cum_sum[-1]                                         # normalize the cumulative sum
    indx_list = [indx_list[k, 0] for k, value in enumerate(cum_sum > opts.save_P) if value]  # find the indices that preserve > sampP  of total probability
    valuesNew = {}
    for key in values.keys():
         valuesNew[key] = np.array([values[key][int(indx)] for indx in indx_list] )
    weightsNew = np.array([weights[int(indx)] for indx in indx_list] )
    values = {}
    values = valuesNew
    weights = weightsNew
    print " Reduced sample size  : ",  len(values["phi"])


# Summarise every parameter with the [mean, variance] pair that
# mean_and_dev computes from its samples and the shared weights.
metadata.update(
    (name, mean_and_dev(column, weights)) for name, column in values.items()
)

# With an injection file, record for each parameter the injected value and
# the posterior weight fraction lying above it:
#     ppdata[name] = [injected value, pcum_at(injected value, samples, weights)]
if opts.inj:
    # Load the physical parameters of the selected injection.
    Psig = lalsimutils.xml_to_ChooseWaveformParams_array(str(opts.inj))[opts.event_id]

    # Injection attribute paired with the sample key it is compared against.
    # NOTE(review): the distance is divided by 1e6*lsu_PC, presumably to
    # convert to Mpc so that it matches the units of values["dist"] -- confirm.
    for key, truth in [("ra", Psig.phi),
                       ("dec", Psig.theta),
                       ("phi", Psig.phiref),
                       ("incl", Psig.incl),
                       ("psi", Psig.psi),
                       ("dist", Psig.dist/(1e6*lalsimutils.lsu_PC))]:
        ppdata[key] = [truth, pcum_at(truth, values[key], weights)]
    # No cumulative is evaluated for tref; the placeholder 0 only makes sure
    # the key exists for every parameter.
    ppdata['tref'] = [Psig.tref, 0]

    # Mass parameters.  See the comment about the mass prior at the top of the file!
    # NOTE(review): masses are divided by lsu_MSUN, presumably converting to
    # solar masses to match values["m1"], values["m2"] -- confirm.
    m1Sun = Psig.m1/lalsimutils.lsu_MSUN
    m2Sun = Psig.m2/lalsimutils.lsu_MSUN
    ppdata["m1"] = [m1Sun, pcum_at(m1Sun, values["m1"], weights)]
    ppdata["m2"] = [m2Sun, pcum_at(m2Sun, values["m2"], weights)]
    mc = lalsimutils.mchirp(m1Sun,m2Sun)
    eta = lalsimutils.symRatio(m1Sun,m2Sun)
    ppdata['mc'] = [mc, pcum_at(mc, values["mc"], weights)]
    # The eta cumulative must be evaluated at the injected eta (not at mc).
    ppdata['eta'] = [eta, pcum_at(eta, values["eta"], weights)]

    # One line per parameter: name, injected value, cumulative probability.
    # Only the extrinsic parameters listed here are written.
    with open(opts.points_file_base+"-postprocess-pp.dat",'w') as f:
        for key in ['ra','dec', 'phi', 'incl', 'psi', 'dist']:
            f.write(key + " " + str(ppdata[key][0]) + ' '+ str(ppdata[key][1]) + '\n')


###
### 1d cumulative plots
###


# Construct 1d plots.  See 'ourio.py'.  Prior code not used since sampler unavailable here to set range limits
# One figure per parameter: first the weighted cumulative distribution, then
# (unless --disable-1d-density is given) a binned density estimate.
nFig = 0
keynames = list(values.keys())   # list() so the keys can be indexed/enumerated

# Marker indices are loop-invariant.  Cast to integers: numpy rejects
# non-integer index arrays.  indx_ratio_list is only consulted under the same
# condition under which it is built (both star options set).
star_idx = None
ratio_idx = None
if opts.star_P:
    star_idx = np.asarray(indx_star_list, dtype=int)
if opts.star_pratio and opts.star_P:
    ratio_idx = np.asarray(indx_ratio_list, dtype=int)

for nFig, param in enumerate(keynames):
    plt.figure(nFig)
    plt.clf()
    # An injected value is available for every parameter stored in ppdata
    # (there is none for lnL).
    has_truth = bool(opts.inj) and (param in ppdata)

    # Stage 2: Cumulative, at full sample resolution.
    # Sort the samples by value (stable), carrying the weights along.
    idx_sorted_util = np.argsort(values[param], kind="mergesort")
    vals = np.asarray(values[param])[idx_sorted_util]
    wts = np.asarray(weights)[idx_sorted_util]
    cum_wts = np.cumsum(wts)
    cum_wts = cum_wts/cum_wts[-1]
    # Plot range: the 0.1% and 99.9% points of the cumulative -- very
    # important for narrow parameters like mc, eta.  Clamp the index so a
    # lookup can never run past the last sample.
    last = len(vals) - 1
    xlow = vals[min(bisect.bisect(cum_wts, 0.001), last)]
    xhigh = vals[min(bisect.bisect(cum_wts, 0.999), last)]
    plt.plot(vals,cum_wts,label="posterior:"+param)
    if xlow < xhigh:  # the range can be degenerate, e.g. for the time
        plt.xlim(xlow,xhigh)
    if has_truth:
        # Vertical dashed line at the injected value, spanning the full
        # cumulative range 0..1.
        plt.plot([ppdata[param][0],ppdata[param][0]], [0,1], color='k',linestyle='--')
    plt.xlabel(param)
    plt.title("1d cumulative for "+param)
    plt.legend()
    plt.savefig(opts.points_file_base+"-posterior-cumulative-"+param+"-1d."+fExtension)

    plt.clf()
    # Stage 1: PDF
    if not opts.disable_1d_density:
        # Normalise by hand: the 'density' argument of np.histogram is not
        # available on every supported numpy version.
        hist, bins = np.histogram(vals, bins=100,  weights=wts)  # vals, wts are the sorted copies
        dx = bins[1]-bins[0]
        hist = hist/(np.sum(hist)*dx)
        center = (bins[:-1]+bins[1:])/2
        plt.plot(center,hist,label="weighted samples:"+param)
        peak = np.max(hist)
        hist, bins  = np.histogram(vals,bins=100)
        dx = bins[1]-bins[0]
        hist = hist/(np.sum(hist)*dx)
        center = (bins[:-1]+bins[1:])/2
        plt.plot(center,hist,label="unweighted samples:"+param+"",linestyle='--')
        peak = max(peak, np.max(hist))
        # Markers along the x axis; the indices refer to the unsorted sample arrays.
        if star_idx is not None:
            plt.scatter(values[param][star_idx],np.zeros(len(star_idx)),label="significant weight" )
        if ratio_idx is not None:
            plt.scatter(values[param][ratio_idx],np.zeros(len(ratio_idx)),label="significant p/ps", c='r' )
        if has_truth:
            # Vertical dashed line at the injected value, from 0 to the tallest bin.
            plt.plot([ppdata[param][0],ppdata[param][0]], [0,peak], color='k',linestyle='--')
        plt.xlabel(param)
        plt.title("1d density for "+param)
        plt.legend()
        plt.savefig(opts.points_file_base+"-posterior-density-"+param+"-1d."+fExtension)
        plt.clf()


###
### Sky plots
###
if not opts.disable_skymap:
    # Add posterior skymap: use a *fixed* resolution.  (Better adaptive solutions in routines by Leo Singer)
    import healpy
    nside = 32   # must be a power of 2.  This is 12,300 sky points, good enough for moderate-SNR work
    npix = healpy.nside2npix(nside)
    smap = np.zeros(npix)
    print " Sky resolution note: ", len(smap), len(weights)

    # Convert to theta, phi from ra, dec.  Be careful -- don't clobber phi_orb
    values["thetaSkyplot"] = 0.5*np.pi  - values["dec"]
    values["phiSkyplot"] = values["ra"]

    # Convert sky locations to discrete pixels
    values["index"] = healpy.ang2pix(nside, values["thetaSkyplot"], values["phiSkyplot"],nest=False)

    # Loop over all pixel values.  Add more to the skymap.  [*slow* loop. Probably better python-y solution]
    for i in np.arange(len(weights)):
        smap[values["index"][i]]+= weights[i]
    smap = smap/np.sum(smap)

    # Write skymap fits file
    import lalinference.bayestar.fits
    lalinference.bayestar.fits.write_sky_map(opts.points_file_base+"-skymap.fits.gz", smap,
                                             creator="postprocess_1d_cumulative")


    # Create plot
    from lalinference.bayestar import fits as bfits
    from lalinference.bayestar import plot as bplot
    plt.clf()
    plt.subplot(111, projection='astro mollweide')
    # To keep using this, you will need to edit site-packages/lalinference/plot.py (!) to expose these functions
    bplot.healpix_heatmap(smap)    # Code has changed to make this not valid overall...need to make available
#    bplot.heatmap(functools.partial(_healpix_loopkup,map),smap)    # Code has changed to make this not valid overall...need to make available
#    bplot.colorbar()  # Match Leo
    bplot.outline_text(plt.gca())    # Match Leo
    plt.savefig(opts.points_file_base+"-posterior-skymap."+fExtension)

###
### Bayestar sky localization data (using skymap)
###
#   - helpful for sky localization data
#   - note coinc not available -- I am using the injection itself.

# Report on found injection data
#   - Write single line in format Leo Singer's "aggregate_found_injections" expects
#   - Warning: Leo maps to the *coinc* ID, rather than the injection ID
# Write <base>-skyloc.dat: a header line plus one data line, in the
# tab-separated layout the original author intended for Leo Singer's
# "aggregate_found_injections" tooling.
if opts.inj:
    contours = [0.68, 0.99, 0.999]
    true_ra = ppdata['ra'][0]
    true_dec = ppdata['dec'][0]
    if not opts.disable_skymap:
        from lalinference.bayestar import postprocess
        injFound = postprocess.find_injection(smap, true_ra, true_dec,contours=contours)
        searched_area = injFound.searched_area
        searched_prob = injFound.searched_prob
        offset = injFound.offset
        # list() so the concatenation below works whatever sequence type is returned.
        contour_areas = list(injFound.contour_areas)
    else:
        # No sky map: write zero placeholders.  There must be one area per
        # contour so the data row has as many fields as the header.
        searched_area = 0
        searched_prob = 0
        offset = 0
        contour_areas = [0]*len(contours)
    # Hack: SNR guessed from the largest lnL found (sqrt(2 lnLmax)).  Would be
    # better to use the coinc.
    snrGuess = np.sqrt(np.max(2*lnLmax))
    # `level` (not `p`): a python-2 list comprehension leaks its loop
    # variable, which would overwrite the module-level array `p`.
    colnames = ['coinc_event_id', 'simulation_id', 'snr', 'far', 'searched_area',
                'searched_prob', 'offset', 'runtime'] + ["area({0:g})".format(level)
                                                         for level in contours]
    # The coinc id, far and runtime fields are fixed placeholders (0, 1, 0);
    # the injection's event id stands in for the simulation id.
    row = [0, opts.event_id, snrGuess, 1, searched_area, searched_prob, offset, 0] + contour_areas
    with open(opts.points_file_base+"-skyloc.dat", 'w') as f:
        f.write( "\t".join(colnames+["\n"]))
        f.write( "\t".join(map(str, row+['\n'])))

###
### Lp/ps vs L: a diagnostic for 'outliers' (e.g., poor convergence, spots where ps is anomalously low)
###
# Scatter of ln(weight) against ln L for every sample: a diagnostic for
# outliers (poor convergence, spots where ps is anomalously low).
plt.clf()
plt.figure(99)
lnw = np.log(weights)+lnLmax   # restore lnLmax, so the scale is absolute, not relative
plt.scatter(values["lnL"], lnw)
# Red points for samples flagged with a high p/ps ratio.  The index list is
# only built when both star options are set, so it is only consulted then;
# the cast guards against non-integer index values, which numpy rejects.
# NOTE(review): the indices must refer to the current ordering of
# values/weights (i.e. after any --save-P truncation) -- verify.
if opts.star_pratio and opts.star_P:
    flagged = np.asarray(indx_ratio_list, dtype=int)
    plt.scatter(values["lnL"][flagged], lnw[flagged],c='r')
plt.xlim(lnLmax-20,lnLmax+5)
lnwmax = np.max(lnw)
plt.ylim(lnwmax-30,lnwmax+5)
plt.xlabel("ln L")
plt.ylabel("ln w")
plt.savefig(opts.points_file_base+ "-wL."+fExtension)


###
### Triplots : because sql locks are ubiquitous
###   Note the triplot code does many things to the environment; it's easiest to do this last
###

if not opts.disable_triplot:
    import bayesutils
    p_syms = ["$\\iota$", "d", "$\\delta$", "$\\alpha$", "t", "$\\phi$", "$\\psi$"]
# This causes a lot of memory duplication.  Would be better to use 'ret', above,
    samples = (np.array([values["incl"],values["dist"], values["dec"], values["ra"], values["tref"], values["phi"], values["psi"]])).T
    print len(samples), samples.shape
    if opts.inj:
        injection = np.array([ppdata['incl'][0], ppdata['dist'][0], ppdata['dec'][0], ppdata["ra"][0], ppdata["tref"][0], ppdata["phi"][0], ppdata["psi"][0]])
    else:
        injection=None
    bayesutils.triplot(samples, weights=weights, labels=p_syms, title=None, inj=injection)
    if matplotlib.get_backend() == "MacOSX":
        plt.savefig(opts.points_file_base+ "-triplot.pdf")   # Workaround for Richard's laptop
    else:
        plt.savefig(opts.points_file_base+ "-triplot."+fExtension)


