#! /usr/bin/env python
#
#  - construct 1d cumulative (from flatfile), given argument type specification.
#     - can also construct 2d skymap
#  - evaluate the p value of the injection's location (or constant L contour)
#  - plot the cumulative
#
# USAGE
#    postprocess_1d_cumulative --save-sampler-file flatfile  --inj inj.xml --event 0
#          # input from flatfile-points.xml.gz
#          # output to flatfile-posterior-*
#                 - skymap
#                 - 1d cumulatives in each parameter
#                 - cumulative data at injection parameters
#    postprocess_1d_cumulative --save-sampler-file flatfile  --adapt-beta 0
#          # weight as L^\beta p/ps
#          # using 0 treats all mass points the same.
#          # Note that if only a fraction of the extrinsic points are saved, using \beta=0 produces
#             a conservative skymap
#
# PHYSICS NOTES
#   - *Requires a mass prior implicitly! *
#        : this script joins results from different simulations, hence different masses.
#          Right now, this prior is *not* properly implemented when creating relative weights at different mass points:
#          In practice, we are using a uniform prior in mchirp, eta.  That's not the physical prior. 
#      ===> This needs to be corrected. <===
#
# PORTABILITY NOTES
#    np.histogram :  density argument not supported on LIGO clusters' version

# Setup. 
import numpy as np
import lal
import bisect
# Plot always, for now
import matplotlib
#matplotlib.use("Agg")  # Uncomment this line on clusters.  Need a better solution...
# Choose the image format used when figures are saved at the end of the script.
# "jpeg" is the default (the MacOSX backend used the same value, so it needs no
# branch of its own); "png" is used when the active backend is Agg.
# The backend name is compared case-insensitively, since the spelling reported
# by matplotlib may differ between versions ("agg" vs "Agg").
backend_name = matplotlib.get_backend()
fExtension = "jpeg"
if backend_name.lower() == "agg":
    fExtension = "png"
# Single pre-formatted string: prints identically under Python 2 and Python 3.
print(" matplotlib backend  %s  so using file type  %s" % (backend_name, fExtension))
import lalsimutils
from matplotlib import pylab as plt
from mpl_toolkits.mplot3d import Axes3D

import sys



# Import all ASCII files named on the command line and overlay, per parameter,
# the weighted 1d cumulative distribution of every file on a shared figure.
# TODO: use the first line of each file to establish the column associations.
# For now, assume the output format of 'convert_output_format_ile2inference'.
# Columns are addressed from the END of each row.  This code reads:
#     -11     : number of points (multiplies the weights; its maximum is used as the trial count)
#     -10 ra,  -9 dec,  -7 phi,  -6 incl,  -5 psi,  -4 dist
#     -3, -2  : the two factors whose ratio [-3]/[-2] rescales the likelihood
#               (presumably prior p and sampling prior ps -- confirm against the converter)
#     -1      : lnL
# Column -8 is not read here (presumably tref -- confirm).
# NOTE(review): this comment used to give the order as
#   "<stuff>, tref, phi, incl, psi, ra, dec, dist, lnL, p, ps",
# which does not match the indices used below; verify against the file writer.
#  *assume* we have already sorted these by importance?

error_report = {}   # fname -> [ln(integral estimate), ln(relative deviation), neff]

import sys

# (parameter name, column offset from the end of a row), in plotting order.
# The position in this list is also the matplotlib figure number of the parameter.
param_columns = [("ra", -10), ("dec", -9), ("phi", -7), ("incl", -6),
                 ("psi", -5), ("dist", -4), ("lnL", -1)]

for fname in sys.argv[1:]:
    print(fname)
    # loadtxt collapses a one-row file to a 1d array; force 2d so the column slices below work.
    ret = np.atleast_2d(np.loadtxt(fname))

    lnL = ret[:, -1]
    lnLmax = np.max(lnL)  # only relative weights are needed, so avoid loss of precision and infinity
    weights_raw = np.exp(lnL - lnLmax) * ret[:, -3] / ret[:, -2]
    # Add number of points to weights: NZ is additive; the overall normalization is
    # created elsewhere.  Note we can't compute the evidence in this strategy.
    weights = weights_raw * ret[:, -11]
    # NOTE(review): the '--adapt-beta 0' mode from the file header (all weights equal)
    # does not appear to be wired up in this script.

    # Provide integral value and error, to independently check what is reported at the end of the code
    n_samples = np.max(ret[:, -11])
    neff = np.sum(weights_raw) / np.max(weights_raw)
    print(" Number of trials, stored points, and neff  %s %s %s" % (n_samples, len(ret), neff))
    # Correct for the fact that only some points are stored -- do *not* use np.mean
    LmeanScaled = np.sum(weights_raw) / n_samples
    LdevScaled = np.sqrt(np.sum(weights_raw * weights_raw) / n_samples - LmeanScaled * LmeanScaled)
    error_report[fname] = [np.log(LmeanScaled) + lnLmax, np.log(LdevScaled / LmeanScaled), neff]
    print(" ln\\int Lpdx estimate, and error estimate:  %s" % (error_report[fname],))

    # Metadata and pp data
    # Note mc, eta not always supported, but m1, m2 *are very likely* always be the first two
    values = dict((name, ret[:, col]) for name, col in param_columns)
    keynames = [name for name, _ in param_columns]   # also read by the figure-saving pass after this loop

    for nFig, param in enumerate(keynames):
        plt.figure(nFig)   # re-select the shared figure for this parameter

        # Stage 2: Cumulative.  *High detail* Note this can be done with 'cumulative=True' in np
        # This also is used to set the plot range, at the 99.9% and 0.1% confidence intervals -- very important for parameters like mc, eta
        print(" Plotting  %s  items for  %s" % (len(values[param]), param))
        # Sort by value, breaking ties by original position.  Indexing with the
        # permutation builds new arrays, so values[param] and weights are untouched.
        idx_sorted_util = np.lexsort((np.arange(len(values[param])), values[param]))
        vals = values[param][idx_sorted_util]
        wts = weights[idx_sorted_util]
        cum_wts = np.cumsum(wts)
        cum_wts = cum_wts / cum_wts[-1]   # normalize so the last entry is 1
        xlow = vals[bisect.bisect(cum_wts, 0.001)]
        xhigh = vals[bisect.bisect(cum_wts, 0.999)]
        # NOTE(review): every input file gets the same legend label, so overlaid
        # curves cannot be told apart by file.
        plt.plot(vals, cum_wts, label="posterior:" + param)
        if xlow < xhigh:  # can be a problem with t
            plt.xlim(xlow, xhigh)

        plt.xlabel(param)
        plt.title("1d cumulative for " + param)
        plt.legend()


# Save one overlay figure per parameter, then echo the per-file summary
# (first two entries of each error_report record) and exit successfully.
# `keynames` is only bound once at least one input file has been processed, so
# the figure pass is skipped when no files were given (error_report is then
# empty) instead of failing with a NameError.
if error_report:
    for nFig, param in enumerate(keynames):
        plt.figure(nFig)   # make the parameter's figure current so savefig writes it
        plt.savefig("posterior-cumulative-multiplot-" + param + "-1d." + fExtension)

for key in error_report:
    # Single pre-formatted string: prints identically under Python 2 and Python 3.
    print("%s %s %s" % (key, error_report[key][0], error_report[key][1]))

sys.exit(0)



 

