#!/usr/bin/env python
#
"""
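Integrate the extrinsic parameters of the prefactored likelihood function.

No-op version: no integration is performed.  The likelihood is replaced by a
synthetic multivariate Gaussian evaluated at the template's masses and spins.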
"""

import sys
import functools
from optparse import OptionParser, OptionGroup

import scipy.stats

import numpy

import lal
from igwn_ligolw import utils, lsctables, table, ligolw
from igwn_ligolw.utils import process
import glue.lal
#import pylal

import RIFT.lalsimutils as lalsimutils
import RIFT.misc.xmlutils as xmlutils

#from lalinference.bayestar import fits as bfits

__author__ = "Evan Ochsner <evano@gravity.phys.uwm.edu>, Chris Pankow <pankow@gravity.phys.uwm.edu>, R. O'Shaughnessy"

#
# Pinnable parameters -- for command line processing
#
LIKELIHOOD_PINNABLE_PARAMS = [
    "right_ascension",
    "declination",
    "psi",
    "distance",
    "phi_orb",
    "t_ref",
    "inclination",
]

def get_pinned_params(opts):
    """
    Collect the pinnable parameters to which the user assigned a value.

    Returns a dict mapping every attribute of ``opts`` that is named in
    LIKELIHOOD_PINNABLE_PARAMS and is not None to its value.
    """
    pinned = {}
    for name, value in opts.__dict__.items():
        # Unset options are stored as None and are therefore not pinned.
        if value is None or name not in LIKELIHOOD_PINNABLE_PARAMS:
            continue
        pinned[name] = value
    return pinned

def get_unpinned_params(opts, params):
    """
    Return the members of ``params`` that the user has not pinned.

    A parameter counts as pinned when ``opts`` carries a non-None attribute
    of that name and the name is listed in LIKELIHOOD_PINNABLE_PARAMS.
    ``params`` must support set difference (e.g. a set).
    """
    pinned_names = set()
    for name, value in opts.__dict__.items():
        if name in LIKELIHOOD_PINNABLE_PARAMS and value is not None:
            pinned_names.add(name)
    return params - pinned_names

#
# Option parsing
#

# Command-line parser.  Each option is later read from `opts` under its long
# name with dashes replaced by underscores (or under `dest` when one is given).
optp = OptionParser()

# --- Data inputs -----------------------------------------------------------
optp.add_option("-c", "--cache-file", default=None, help="LIGO cache file containing all data needed.")
optp.add_option("-C", "--channel-name", action="append", help="instrument=channel-name, e.g. H1=FAKE-STRAIN. Can be given multiple times for different instruments.")
optp.add_option("-p", "--psd-file", action="append", help="instrument=psd-file, e.g. H1=H1_PSD.xml.gz. Can be given multiple times for different instruments.")
optp.add_option("-k", "--skymap-file", help="Use skymap stored in given FITS file.")
optp.add_option("-x", "--coinc-xml", help="gstlal_inspiral XML file containing coincidence information.")
optp.add_option("-I", "--sim-xml", help="XML file containing mass grid to be evaluated")
# --event is the index used to select one entry from the --sim-xml parameter list.
optp.add_option("-E", "--event", default=0,type=int, help="Event number used for this run")
optp.add_option("--soft-fail-event-range",action='store_true',help='Soft failure (exit 0) if event ID is out of range. This happens in pipelines, if we have pre-built a DAG attempting to analyze more points than we really have')

# --- Frequency settings ----------------------------------------------------
optp.add_option("-f", "--reference-freq", type=float, default=100.0, help="Waveform reference frequency. Required, default is 100 Hz.")
optp.add_option("--fmin-template", dest='fmin_template', type=float, default=40, help="Waveform starting frequency.  Default is 40 Hz. Also equal to starting frequency for integration") 
optp.add_option("--fmin-ifo", action='append' , help="Minimum frequency for each IFO. Implemented by setting the PSD=0 below this cutoff. Use with care.") 

# --- NR / ROM / waveform-model selection -----------------------------------
#optp.add_option("--nr-params",default=None, help="List of specific NR parameters and groups (and masses?) to use for the grid.")
#optp.add_option("--nr-index",type=int,default=-1,help="Index of specific NR simulation to use [integer]. Mass used: mtot= m1+m2")
optp.add_option('--nr-group', default=None,help="If using a *ssingle specific simulation* specified on the command line, provide it here")
optp.add_option('--nr-param', default=None,help="If using a *ssingle specific simulation* specified on the command line, provide it here")
optp.add_option("--nr-lookup",action='store_true', help=" Look up parameters from an NR catalog, instead of using the approximant specified")
optp.add_option("--nr-lookup-group",action='append', help="Restriction on 'group' for NR lookup")
optp.add_option("--nr-hybrid-use",action='store_true',help="Enable use of NR (or ROM!) hybrid, using --approx as the default approximant and with a frequency fmin")
optp.add_option("--nr-hybrid-method",default="taper_add",help="Hybridization method for NR (or ROM!).  Passed through to LALHybrid. pseudo_aligned_from22 will provide ad-hoc higher modes, if the early-time hybridization model only includes the 22 mode")
optp.add_option("--rom-group",default=None)
optp.add_option("--rom-param",default=None)
optp.add_option("--rom-use-basis",default=False,action='store_true',help="Use the ROM basis for inner products.")
optp.add_option("--rom-limit-basis-size-to",default=None,type=int)
optp.add_option("--rom-integrate-intrinsic",default=False,action='store_true',help='Integrate over intrinsic variables. REQUIRES rom_use_basis at present. ONLY integrates in mass ratio as present')
optp.add_option("--nr-perturbative-extraction",default=False,action='store_true')
optp.add_option("--nr-use-provided-strain",default=False,action='store_true')
optp.add_option("--no-memory",default=False,action='store_true', help="At present, turns off m=0 modes. Use with EXTREME caution only if requested by model developer")
optp.add_option("--use-external-EOB",default=False,action='store_true')
optp.add_option("--maximize-only",default=False, action='store_true',help="After integrating, attempts to find the single best fitting point")
optp.add_option("--dump-lnL-time-series",default=False, action='store_true',help="(requires --sim-xml) Dump lnL(t) at the injected parameters")
# The default "TaylorT4" also acts as a sentinel: when --sim-xml is used, the
# approximant stored in the XML is only overridden if this option differs from it.
optp.add_option("-a", "--approximant", default="TaylorT4", help="Waveform family to use for templates. Any approximant implemented in LALSimulation is valid.")
optp.add_option("-A", "--amp-order", type=int, default=0, help="Include amplitude corrections in template waveforms up to this e.g. (e.g. 5 <==> 2.5PN), default is Newtonian order.")
optp.add_option("--l-max", type=int, default=2, help="Include all (l,m) modes with l less than or equal to this value.")

# --- Data segment and conditioning -----------------------------------------
optp.add_option("-s", "--data-start-time", type=float, default=None, help="GPS start time of data segment. If given, must also give --data-end-time. If not given, sane start and end time will automatically be chosen.")
optp.add_option("-e", "--data-end-time", type=float, default=None, help="GPS end time of data segment. If given, must also give --data-start-time. If not given, sane start and end time will automatically be chosen.")
optp.add_option("-F", "--fmax", type=float, help="Upper frequency of signal integration. Default is use PSD's maximum frequency.")
optp.add_option("-t", "--event-time", type=float, help="GPS time of the event --- probably the end time. Required if --coinc-xml not given.")
optp.add_option("-i", "--inv-spec-trunc-time", type=float, default=8., help="Timescale of inverse spectrum truncation in seconds (Default is 8 - give 0 for no truncation)")
optp.add_option("-w", "--window-shape", type=float, default=0, help="Shape of Tukey window to apply to data (default is no windowing)")
optp.add_option("-m", "--time-marginalization", action="store_true", help="Perform marginalization over time via direct numerical integration. Default is false.")

# --- Output ----------------------------------------------------------------
# The result is written to "<output-file>.dat" by the output section of this script.
optp.add_option("-o", "--output-file", help="Save result to this file.")
optp.add_option("-O", "--output-format", default='xml', help="[xml|hdf5]")
optp.add_option("-S", "--save-samples", action="store_true", help="Save sample points to output-file. Requires --output-file to be defined.")
optp.add_option("-L", "--save-deltalnL", type=float, default=float("Inf"), help="Threshold on deltalnL for points preserved in output file.  Requires --output-file to be defined")
optp.add_option("-P", "--save-P", type=float,default=0, help="Threshold on cumulative probability for points preserved in output file.  Requires --output-file to be defined")
optp.add_option("--verbose",action='store_true')

#
# Add the integration options
#
integration_params = OptionGroup(optp, "Integration Parameters", "Control the integration with these options.")
# optparse only applies the `type` conversion to string defaults, so a float
# default such as 1e7 would reach the script as a float despite type=int.
# The default is therefore spelled as an integer, and the help text states it.
integration_params.add_option("--n-max", type=int, default=10000000, help="Total number of samples points to draw. If this number is hit before n_eff, then the integration will terminate. Default is 10000000.")
integration_params.add_option("--n-eff", type=int, default=100, help="Total number of effective samples points to calculate before the integration will terminate. Default is 100")
integration_params.add_option("--n-chunk", type=int, help="Chunk'.",default=100)
integration_params.add_option("--convergence-tests-on",default=False,action='store_true')
integration_params.add_option("--seed", type=int, help="Random seed to use. Default is to not seed the RNG.")
integration_params.add_option("--no-adapt", action="store_true", help="Turn off adaptive sampling. Adaptive sampling is on by default.")
integration_params.add_option("--no-adapt-distance", action="store_true", help="Turn off adaptive sampling, just for distance. Adaptive sampling is on by default.")
integration_params.add_option("--adapt-weight-exponent", type=float, default=1.0, help="Exponent to use with weights (likelihood integrand) when doing adaptive sampling. Used in tandem with --adapt-floor-level to prevent overconvergence. Default is 1.0.")
integration_params.add_option("--adapt-floor-level", type=float, default=0.1, help="Floor to use with weights (likelihood integrand) when doing adaptive sampling. This is necessary to ensure the *sampling* prior is non zero during adaptive sampling and to prevent overconvergence. Default is 0.1 (no floor)")
integration_params.add_option("--adapt-adapt",action='store_true',help="Adapt the tempering exponent")
integration_params.add_option("--adapt-log",action='store_true',help="Use a logarithmic tempering exponent")
# NOTE(review): this option has no action, so it consumes a string argument and
# any supplied value (even "False") is truthy.  Left as-is so that existing
# command lines passing a value keep parsing; confirm whether store_true was intended.
integration_params.add_option("--interpolate-time", default=False,help="If using time marginalization, compute using a continuously-interpolated array. (Default=false)")
integration_params.add_option("--d-max", default=10000,type=float,help="Maximum distance in volume integral. Used to SET THE PRIOR; changing this value changes the numerical answer.")
integration_params.add_option("--declination-cosine-sampler",action='store_true',help="If specified, the parameter used for declination is cos(dec), not dec")
integration_params.add_option("--inclination-cosine-sampler",action='store_true',help="If specified, the parameter used for inclination is cos(incl), not incl")
integration_params.add_option("--manual-logarithm-offset",type=float,default=0,help="Target value of logarithm lnL. Integrand is reduced by exp(-manual_logarithm_offset).  Important for high-SNR sources!   Should be set dynamically")
optp.add_option_group(integration_params)

#
# Add the intrinsic parameters
#
intrinsic_params = OptionGroup(
    optp,
    "Intrinsic Parameters",
    "Intrinsic parameters (e.g component mass) to use.",
)
intrinsic_params.add_option(
    "--pin-to-sim",
    help="Pin values to sim_inspiral table entry.",
)
intrinsic_params.add_option(
    "--pin-distance-to-sim",
    action='store_true',
    help="Pin *distance* value to sim entry. Used to enable source frame reconstruction with NR.",
)
intrinsic_params.add_option(
    "--mass1",
    type=float,
    help="Value of first component mass, in solar masses. Required if not providing coinc tables.",
)
intrinsic_params.add_option(
    "--mass2",
    type=float,
    help="Value of second component mass, in solar masses. Required if not providing coinc tables.",
)
intrinsic_params.add_option(
    "--eff-lambda",
    type=float,
    help="Value of effective tidal parameter. Optional, ignored if not given.",
)
intrinsic_params.add_option(
    "--deff-lambda",
    type=float,
    help="Value of second effective tidal parameter. Optional, ignored if not given",
)
optp.add_option_group(intrinsic_params)


#
# Options for integrating over intrinsic parameters.  These follow the same
# conventions as util_ManualOverlapGrid.py: parameters have special names,
# and the priors adopted are keyed on those names.
# NOTE: Only 'q' implemented
#
intrinsic_int_params = OptionGroup(
    optp,
    "Intrinsic integrated parameters",
    "Intrinsic parameters to integrate over. ONLY currently used with ROM version",
)
intrinsic_int_params.add_option("--parameter", action='append')
intrinsic_int_params.add_option("--parameter-range", action='append', type=str)
intrinsic_int_params.add_option("--adapt-intrinsic", action='store_true')
optp.add_option_group(intrinsic_int_params)


#
# Pinnable parameters: one float-valued option per entry of
# LIKELIHOOD_PINNABLE_PARAMS, with underscores turned into dashes.
#
pinnable = OptionGroup(
    optp,
    "Pinnable Parameters",
    "Specifying these command line options will pin the value of that parameter to the specified value with a probability of unity.",
)
for pin_param in LIKELIHOOD_PINNABLE_PARAMS:
    option = "--%s" % pin_param.replace("_", "-")
    pinnable.add_option(
        option,
        type=float,
        help="Pin the value of %s." % pin_param,
    )
optp.add_option_group(pinnable)

# Parse sys.argv; unset pinnable options are left as None on `opts`.
opts, args = optp.parse_args()

# Offset added back to the reported lnL when the result is written out.
manual_avoid_overflow_logarithm = opts.manual_logarithm_offset


#
# Integrator options
#
n_max = opts.n_max  # Max number of extrinsic points to evaluate at
n_eff = opts.n_eff  # Effective number of points evaluated


#
# Gather information from the detection pipeline
#
# The event time comes from the coinc table when --coinc-xml is given,
# otherwise from --event-time.  Validation raises ValueError explicitly so it
# is not stripped when Python runs with -O (as `assert` would be).
if opts.coinc_xml is not None:
    xmldoc = utils.load_filename(opts.coinc_xml)
    coinc_table = table.get_table(xmldoc, lsctables.CoincInspiralTable.tableName)
    if len(coinc_table) != 1:
        raise ValueError("Expected exactly one coinc_inspiral row in %s, found %d." % (opts.coinc_xml, len(coinc_table)))
    coinc_row = coinc_table[0]
    event_time = coinc_row.get_end()
    print("Coinc XML loaded, event time: %s" % str(event_time))
elif opts.event_time is not None:
    event_time = glue.lal.LIGOTimeGPS(opts.event_time)
    print("Event time from command line: %s" % str(event_time))
else:
    raise ValueError("Either --coinc-xml or --event-time must be provided to parse event time.")


#
# Set masses
#
# Precedence: explicit --mass1/--mass2, then the sngl_inspiral rows of the
# coinc file.  With --sim-xml alone, nothing is set here: the masses are read
# from the selected sim entry when the template is built.
if opts.mass1 is not None and opts.mass2 is not None:
    m1, m2 = opts.mass1, opts.mass2
elif opts.coinc_xml is not None:
    sngl_inspiral_table = table.get_table(xmldoc, lsctables.SnglInspiralTable.tableName)
    n_ifos = len(coinc_row.ifos.split(","))
    if len(sngl_inspiral_table) != n_ifos:
        raise ValueError("Found %d sngl_inspiral rows but the coinc lists %d instruments." % (len(sngl_inspiral_table), n_ifos))
    m1, m2 = None, None
    for sngl_row in sngl_inspiral_table:
        # NOTE: gstlal is exact match, but other pipelines may not be
        if m1 is not None and not (sngl_row.mass1 == m1 and sngl_row.mass2 == m2):
            raise ValueError("sngl_inspiral rows disagree on the template masses.")
        m1, m2 = sngl_row.mass1, sngl_row.mass2
elif opts.sim_xml:
    pass  # masses are taken from the sim entry later
else:
    raise ValueError("Need either --mass1 --mass2, --coinc-xml, or --sim-xml to retrieve masses.")


#
# Template descriptors
#

# Reference epoch of the template as a float GPS time in seconds.
fiducial_epoch = event_time.seconds + 1e-9*event_time.nanoseconds   # no more direct access to gpsSeconds

# Fiducial template distance in Mpc, shared by both branches below.
# NOTE(review): the command-line branch originally referenced
# factored_likelihood.distMpcRef, but no `factored_likelihood` module is
# imported in this file (NameError).  The value used by the --sim-xml branch is
# adopted instead; the synthetic likelihood below does not read P.dist.
# Confirm whether distMpcRef's value should be used here.
template_dist_mpc = 100

# Struct to hold template parameters
if opts.sim_xml:
    print("====Loading injection XML:", opts.sim_xml, opts.event, " =======")
    P_list = lalsimutils.xml_to_ChooseWaveformParams_array(str(opts.sim_xml))
    try:
        P = P_list[opts.event]
    except IndexError:
        # Honour --soft-fail-event-range: a pre-built DAG may request more
        # points than the file holds, which should not count as a failure.
        if opts.soft_fail_event_range:
            print(" Event %d is out of range for %s; soft failure requested, exiting." % (opts.event, opts.sim_xml))
            sys.exit(0)
        raise
    m1 = P.m1/lal.MSUN_SI
    m2 = P.m2/lal.MSUN_SI
    lambda1, lambda2 = P.lambda1, P.lambda2
    if opts.pin_distance_to_sim:
        dist_in = P.dist
        opts.distance = dist_in/lal.PC_SI/1e6  # create de facto pinnable parameter. Use Mpc unit
    P.dist = template_dist_mpc * 1.e6 * lal.PC_SI   # use *nonstandard* distance
    P.phi = 0.0
    P.psi = 0.0
    P.incl = 0.0       # only works for aligned spins. Be careful.
    P.fref = opts.reference_freq
    if opts.approximant != "TaylorT4":  # not default setting
        # allow user to override the approx setting. Important for NR followup, where no approx set in sim_xml!
        P.approx = lalsimutils.lalsim.GetApproximantFromString(opts.approximant)
else:
    lambda1, lambda2 = 0, 0
    if opts.eff_lambda is not None:
        lambda1, lambda2 = lalsimutils.tidal_lambda_from_tilde(m1, m2, opts.eff_lambda, opts.deff_lambda or 0)
    P = lalsimutils.ChooseWaveformParams(
        approx = lalsimutils.lalsim.GetApproximantFromString(opts.approximant),
        fmin = opts.fmin_template,   # was the undefined name `template_min_freq`
        radec = False,   # do NOT propagate the epoch later
        incl = 0.0,       # only works for aligned spins. Be careful.
        phiref = 0.0,
        theta = 0.0,
        phi = 0.0,
        psi = 0.0,
        m1 = m1 * lal.MSUN_SI,
        m2 = m2 * lal.MSUN_SI,
        lambda1 = lambda1,
        lambda2 = lambda2,
        ampO = opts.amp_order,
        fref = opts.reference_freq,
        tref = fiducial_epoch,
        dist = template_dist_mpc * 1.e6 * lal.PC_SI
        )


print(" --- Template for intrinsic parameters ---- ")
P.print_params()




# Synthetic likelihood: a weighted sum of multivariate Gaussians in the eight
# coordinates (m1, m2, s1x, s1y, s1z, s2x, s2y, s2z), the same order in which
# the evaluation point is assembled from the template further down.
#
# A file-driven variant was sketched here previously: weights read from
# "wt.dat", and per-component means/covariances read from "x0_<i>.dat" and
# "sigma_<i>.dat".  Only the single hard-coded component below is active.
wt_array = 1
try:
    n_gauss = len(wt_array)
except TypeError:
    # A scalar weight has no len(); promote it to a one-component list.
    n_gauss = 1
    wt_array = [wt_array]
x_0_list = []
Sigma_list = []
rv_list = []
sigma_1d = 0.1
x_0 = numpy.array([35,30,0.0,0.0,0.0,0,0,0])
Sigma = numpy.diag([16,16, 0.01, 0.01, 0.01, 0.01,0.01,0.01])
print(x_0.shape, Sigma.shape)
# Record mean, covariance and frozen distribution together, one entry per component.
x_0_list.append(x_0)
Sigma_list.append(Sigma)
rv_list.append(scipy.stats.multivariate_normal(mean=x_0, cov=Sigma))  # 8-dimensional Gaussian
n_dim = len(x_0)


# Grid.  
# Evaluation point: the two component masses (converted from SI via
# lal.MSUN_SI) followed by the s1x..s2z attributes of the template P.
x_now = numpy.array(
    [
        P.m1/lal.MSUN_SI,
        P.m2/lal.MSUN_SI,
        P.s1x,
        P.s1y,
        P.s1z,
        P.s2x,
        P.s2y,
        P.s2z,
    ],
    dtype=float,
)


# Weighted sum of the Gaussian component densities at the evaluation point.
Lval = sum(rv_list[indx].pdf(x_now) * wt_array[indx] for indx in range(n_gauss))

# Synthetic integration result and its bookkeeping values.
res = Lval*numpy.exp(250)  # manual offset
var = 0.1*res*res
neff = 100
ntotal = 1e5

# Output
if opts.output_file:
    fname_output_txt = opts.output_file + ".dat"
    # The event index is only meaningful when the point came from --sim-xml.
    if opts.sim_xml and opts.event is not None:
        event_id = opts.event
    else:
        event_id = -1

    # Columns common to every output format.
    row = [event_id, m1, m2, P.s1x, P.s1y, P.s1z, P.s2x, P.s2y, P.s2z]
    if P.lambda1 > 0 or P.lambda2 > 0:
        # Alternative output format if lambda is active
        row += [P.lambda1, P.lambda2]
    elif opts.pin_distance_to_sim:
        # Use case for this scenario is NR, where lambda is not present.
        # The pinned distance is looked up here; the previous code read an
        # undefined `pinned_params` name and could never run.
        pinned_params = get_pinned_params(opts)
        if "distance" not in pinned_params:
            raise ValueError("--pin-distance-to-sim needs a distance, from --sim-xml or --distance.")
        row.append(pinned_params["distance"])
    # Trailing columns: lnL (offset restored), relative error, sample counts.
    row += [numpy.log(res)+manual_avoid_overflow_logarithm, numpy.sqrt(var)/res, ntotal, neff]
    numpy.savetxt(fname_output_txt, numpy.array([row]))

