#! /usr/bin/env python
#
# GOAL
#
#
# EXAMPLES
#    python pp_RIFT --use-ini sample_pp_config.ini --use-pipe-ini proto.ini
#
# REFERENCES
#   util_RIFT_pseudo_pipe  (for parsing/restructuring arguments)
#
# DEVELOPER WARNINGS
#    - config : assignments are STRINGS, do not use as truth values (e.g., 'False' evaluates to True)

import numpy as np
import argparse
import os
import sys
import shutil

import configparser

from RIFT.misc.dag_utils import mkdir
import RIFT.lalsimutils as lalsimutils
import lal
import lalsimulation as lalsim

from igwn_ligolw import lsctables, table, utils

# Backward compatibility: try the 'lal_path2cache' executable name first and
# fall back to 'lalapps_path2cache' when the first lookup returns None.
from RIFT.misc.dag_utils import which
lalapps_path2cache = which('lal_path2cache')
if lalapps_path2cache is None:
    lalapps_path2cache = which('lalapps_path2cache')
if lalapps_path2cache is None:
    # The name is formatted into shell pipelines later in this script; warn now,
    # because a missing tool would otherwise show up only as the literal "None" in a command.
    print(" WARNING: neither lal_path2cache nor lalapps_path2cache was found; cache-file generation will fail ", file=sys.stderr)


def unsafe_config_get(config,args,verbose=False):
    """
    unsafe_config_get(config, args):

    Fetch an option from ``config`` and return it evaluated as a Python expression.

    Parameters
    ----------
    config : object with a ``get`` method (e.g. ``configparser.ConfigParser``)
    args : sequence passed positionally to ``config.get`` (typically ``[section, option]``)
    verbose : when true, print the lookup arguments and the evaluated result

    Returns
    -------
    The result of ``eval`` applied to the raw option string.

    WARNING: this runs ``eval`` on text taken from the config file, so it executes
    arbitrary code.  Only use it with ini files you trust.
    """
    if verbose:
        # Printed before evaluating so the failing option is visible if eval raises
        print(" Retrieving ", args, end=' ')
    # Evaluate exactly once, so expressions with side effects are not run twice in verbose mode
    value = eval(config.get(*args))
    if verbose:
        print(" Found ", value)
    return value


#activate_env ="source /cvmfs/oasis.opensciencegrid.org/ligo/sw/pycbc/x86_64_rhel_7/virtualenv/pycbc-v1.16.13/bin/activate" # else NRSur frames were not loading
def add_frames(channel,input_frame, noise_frame,combined_frame):
    """
    Print, and unless the global ``opts.test`` is set also run, the shell command
    ``add_frames.py <channel> <input_frame> <noise_frame> <combined_frame>``.

    All four arguments are strings placed into the command line unquoted.
    """
    cmd = " ".join([" add_frames.py", channel, input_frame, noise_frame, combined_frame])
    print(cmd)
    if opts.test:
        # test mode: show the command only
        return
    os.system(cmd)

def m1m2(mc, eta):
    """
    Convert chirp mass and symmetric mass ratio into component masses.

    Uses mc = (m1*m2)^(3/5) / (m1+m2)^(1/5) and eta = m1*m2 / (m1+m2)^2, so the
    total mass is M = mc / eta^(3/5) and m1, m2 are the roots of
    x^2 - M*x + eta*M^2 = 0.

    Parameters
    ----------
    mc : scalar chirp mass (the result is in the same units)
    eta : scalar symmetric mass ratio, must be positive

    Returns
    -------
    (m1, m2) with m1 >= m2.

    Raises
    ------
    ValueError
        If ``eta`` is not positive, or if the quadratic has no real roots
        (negative discriminant).
    """
    # A non-positive eta would otherwise divide by zero or produce a complex total mass
    if eta <= 0:
        raise ValueError("eta must be positive to solve for m1 and m2.")

    # Step 1: total mass from the chirp-mass relation
    M = mc / (eta**(3/5))

    # Step 2: product of the masses, m1*m2 = eta * M^2
    m1m2_product = eta * M**2

    # Discriminant of x^2 - M*x + m1m2_product = 0
    discriminant = M**2 - 4 * m1m2_product

    if discriminant < 0:
        raise ValueError("Discriminant is negative. No real solution for m1 and m2.")

    # Step 3: the two roots; the '+' root is the heavier mass
    root = np.sqrt(discriminant)
    m1 = (M + root) / 2
    m2 = (M - root) / 2

    return m1, m2


# def dmax_seglen_ladder(mc):
#     """
#     dmax, seglen ladder motivated by G2201558 but with nonoverlapping mass bins.  Not sure how we are supposed to do seglen < 4, so I am making that the shortest length
#     De facto enables a conditional distance prior, based on the chirp mass.
#     Not clear that a single PP run with this weird conditional distance prior is what is intended.
#     """
#     if mc>50 and mc<100:
#         return [3500,4]
#     elif mc>20 and mc<=50:
#         return [2500,4]
#     elif mc>12.3 and mc<=20:
#         return [1500,4]
#     elif mc>7.9 and mc <=12.3:
#         return [1000,8]
#     elif mc > 5.2 and mc<=7.9:
#         return [800,16]
#     elif mc > 3.4 and mc<=5.2:
#         return [400,32]
#     elif mc > 2.2 and mc<=3.4:
#         return [400,64]
#     elif mc>1.4 and mc<=2.2:
#         return [300,128]
#     elif mc < 1.4:
#         return [150,256]

#################
## Parse command line options for pp_RIFT_with_ini
## user specified
#################

parser = argparse.ArgumentParser()
parser.add_argument("--use-ini",default=None,type=str,help="Pass ini file for parsing. Intended to reproduce lalinference_pipe functionality. Overrides most other arguments. Full path recommended")
parser.add_argument("--use-gwsignal",action='store_true',help="Write signal frames with util_GWSignalWriteFrame.py and pass --use-gwsignal to util_RIFT_pseudo_pipe.py")
#parser.add_argument("--internal-distance-ladder",action='store_true',help="Modify injection distance ladder as in G2201558 ")
parser.add_argument("--use-osg",action='store_true',help="Attempt OSG operation. Command-line level option, not at ini level to make portable results")
parser.add_argument("--use-osg-cip",action='store_true',help="Set use_osg_cip in the generated pseudo-pipe ini. Command-line level option, not at ini level to make portable results")
parser.add_argument("--add-extrinsic",action='store_true',help="Add extrinsic posterior.  Corresponds to --add-extrinsic --add-extrinsic-time-resampling --batch-extrinsic for pipeline")
parser.add_argument("--test",default=False,action='store_true',help="Used to test the pipeline : prints out commands, generates workflows as much as possible without high-cost steps")
opts =  parser.parse_args()

# Everything below is driven by the ini file, so fail immediately with a usage
# message instead of a TypeError from config.read(None).
if opts.use_ini is None:
    parser.error("--use-ini is required")

config = configparser.ConfigParser(allow_no_value=True) #SafeConfigParser deprecated from py3.2
# config.read() silently skips files it cannot open and returns the list of files
# it did parse; an empty list means the ini was not loaded.
if not config.read(opts.use_ini):
    parser.error("could not read ini file {}".format(opts.use_ini))

# Directory the script was launched from; later its *.xml.gz files are linked into every run directory
base_dir = os.getcwd()

# Optional [pp] test_convergence setting (evaluated as a Python expression); False when absent
test_convergence = (
    unsafe_config_get(config, ['pp', 'test_convergence'])
    if config.has_option('pp', 'test_convergence')
    else False
)

#############
## Create, go to working directory
## with_ini will ideally copy original pp_RIFT dir structure
#############

working_dir = config.get('pp', 'working_directory')
print(" Working dir ", working_dir)
mkdir(working_dir)
os.chdir(working_dir)
# Absolute path of the working directory, taken after entering it
working_dir_full = os.getcwd()

# Define how many events are needed
n_events = unsafe_config_get(config, ['pp', 'n_events'])

# Create injection set
## random parameters based on prior ranges and n_events
## written out to mass mdc file to store all inj
mc_min = config.getfloat('priors', 'mc_min')
mc_max = config.getfloat('priors', 'mc_max')
m_min = config.getfloat('priors', 'm_min')
mc_range = [mc_min, mc_max]
eta_min = config.getfloat('priors', 'eta_min')
eta_max = config.getfloat('priors', 'eta_max')
eta_range = [eta_min, eta_max]
d_min = config.getfloat('priors', 'd_min')
d_max = config.getfloat('priors', 'd_max')
chi_max = config.getfloat('priors', 'chi_max')

# Component masses at the four corners of the (mc, eta) prior rectangle, in the order
# a=(mc_min,eta_min), b=(mc_min,eta_max), c=(mc_max,eta_min), d=(mc_max,eta_max)
(m1a, m2a), (m1b, m2b), (m1c, m2c), (m1d, m2d) = (
    m1m2(mc_corner, eta_corner)
    for mc_corner in mc_range
    for eta_corner in eta_range
)

def _ini_value_is_explicit_false(raw_value):
    """Return True only when an ini string spells out a false value (false/no/off/0)."""
    return isinstance(raw_value, str) and raw_value.strip().lower() in ('false', 'no', 'off', '0')


# use matter?
# Defaults for optional settings, so later code can test them even when the ini omits them
lambda_max=0
use_matter=False
manual_extra_ile_args=None
use_gwsignal_lmax_nyquist=None
use_4096s_data=False
if config.has_option('priors','use_matter'):
    # Presence of the option turns matter on, unless its value explicitly says false
    # (config values are strings, so 'False' must not be treated as true).
    use_matter = not _ini_value_is_explicit_false(config.get('priors','use_matter'))
    if use_matter:
        lambda_max = float(config.get('priors','lambda_max'))

approx_str = config.get('waveform','approx')
approx_recover_str = approx_str
if config.has_option('waveform','approx_template'):
    approx_recover_str = config.get('waveform','approx_template')
# danger: lmax only used in analysis, not in generation usually
l_max = 2
if config.has_option('waveform','lmax'):
    l_max = int(config.get('waveform','lmax')  )
use_hlms_as_injections = unsafe_config_get(config,['waveform','use_hlms_as_injections'])
fmin_template = float(config.get('waveform','fmin_template'))
fref = float(config.get('data','fref'))
fmax = float(config.get('data','fmax'))
flow_dict = unsafe_config_get(config,['data','flow'])
srate_data = float(config.get('data', 'srate_data'))
seglen_data = float(config.get('data','seglen_data'))
seglen_analysis = float(config.get('data','seglen_analysis'))
## add option to change analysis waveform IF frames are pregenerated
change_approx =unsafe_config_get(config,['waveform','change_approx'])

# Extra settings that are only read (and only defined) for particular approximants
if 'NRHybSur' in approx_str:       #for NRSur models
    group_str = config.get('waveform','group')
    param_str = config.get('waveform','param')

if 'TaylorF2Ecc' in approx_str:
    # NOTE(review): f_ecc is kept as the raw string, unlike amporder -- confirm that is what P.fecc expects
    fecc = config.get('waveform','f_ecc')
    amporder = int(config.get('waveform','amporder'))

no_spin =unsafe_config_get(config,['priors','no_spin'])
aligned_spin =unsafe_config_get(config,['priors','aligned_spin'])
precessing_spin=unsafe_config_get(config,['priors','precessing_spin'])
if aligned_spin:
    print(" === aligned-spin PP ==== ")
elif no_spin:
    print(" === zero-spin PP ==== ")
else:
    print(" === precessing PP ==== ")

volumetric_spin =unsafe_config_get(config,['priors','volumetric_spin'])
if volumetric_spin:
    print("  (volumetric spin prior)  ")

fix_sky_location=unsafe_config_get(config,['priors','fix_sky_location'])
fiducial_ra=float(config.get('priors','fiducial_ra'))
fiducial_dec=float(config.get('priors','fiducial_dec'))
fiducial_event_time=float(config.get('priors','fiducial_event_time'))

use_eccentric = unsafe_config_get(config,['priors','use_eccentric'])
use_meanPerAno = unsafe_config_get(config,['priors','use_meanPerAno'])
if use_eccentric:
    print(" === eccentric PP ==== ")
    ecc_min = float(config.get('priors','ecc_min'))
    ecc_max = float(config.get('priors','ecc_max'))
if use_meanPerAno:
    print(" === meanPerAno PP ==== ")
    meanPerAno_min = float(config.get('priors','meanPerAno_min'))
    meanPerAno_max = float(config.get('priors','meanPerAno_max'))


# Workflow settings.  Values read with config.get stay raw strings and are copied
# into the generated ini / command lines as text.
ifos = unsafe_config_get(config,['data','ifos'])
username = config.get('make_workflow','username')
group = config.get('make_workflow','accounting_group')
cip_sampler = config.get('make_workflow', 'cip_sampler_method')
ile_sampler = config.get('make_workflow', 'ile_sampler_method')
ile_n_eff = config.get('make_workflow','ile_n_eff')
if config.has_option('make_workflow','manual_extra_ile_args'):
    manual_extra_ile_args = config.get('make_workflow','manual_extra_ile_args')
if config.has_option('make_workflow','use_gwsignal_lmax_nyquist'):
    use_gwsignal_lmax_nyquist= config.get('make_workflow','use_gwsignal_lmax_nyquist')
if opts.use_gwsignal and use_gwsignal_lmax_nyquist is None:
    # The gwsignal frame-writer command always passes this value, so stop here with a
    # clear message rather than failing later on an undefined/empty setting.
    raise ValueError("--use-gwsignal requires a value for 'use_gwsignal_lmax_nyquist' in the [make_workflow] section of the ini file")
ile_n_copies = config.get('make_workflow','ile_n_copies')
ile_request_memory = config.get('make_workflow','ile_request_memory')
n_events_per_worker = config.get('make_workflow','n_events_per_worker')
ile_request_disk = config.get('make_workflow','ile_request_disk')
cip_fit_method = config.get('make_workflow', 'cip_fit_method')
cip_explode_jobs = config.get('make_workflow','cip_explode_jobs')
cip_explode_jobs_last = config.get('make_workflow','cip_explode_jobs_last')
cip_request_memory = config.get('make_workflow','cip_request_memory')
cip_request_disk = config.get('make_workflow','cip_request_disk')
if config.has_option('make_data','use_4096s_data'):
    # Used as a truth value later: an empty/missing value is off, an explicit false
    # string is off, anything else is on.
    _raw_use_4096s_data = config.get('make_data','use_4096s_data')
    use_4096s_data = bool(_raw_use_4096s_data) and not _ini_value_is_explicit_false(_raw_use_4096s_data)
use_gpu = config.getboolean('make_workflow', 'use_gpu')
use_distance_marginalization = config.getboolean('make_workflow','use_distance_marginalization')
transverse_puff = config.getboolean('make_workflow','use_transverse_puff')
#use_mu_coord = config.getboolean('make_workflow','use_mu_coord') # will ignore and always do this, if aligned spin


## "default" incompatible with --force-reset-all in ILE
if ile_sampler == "default":
    ile_sampler = 'adaptive_cartesian_gpu'
# Based on priors and options, create set of signals with specified parameters
## If mdc already present, do not duplicate
## If changing inj approx for direct comparison, ensure tfe has necessary options
## change_approx needs to change call to util_ManualOverlapGrid as well - just change mdc file?

# Injection set: reuse an existing mdc.xml.gz in the working directory when present,
# otherwise draw n_events new injections from the configured priors.
mdc_path = working_dir_full+"/mdc.xml.gz"
if os.path.exists(mdc_path):
    if change_approx:
        ## make sure fecc is set in mdc
        filename = mdc_path
        # Parse the existing injection file once; only the row index changes per event
        xmldoc = utils.load_filename(filename, verbose = True, contenthandler =lalsimutils.cthdler)
        sim_inspiral_table = lsctables.SimInspiralTable.get_table(xmldoc)
        P_list = []; indx=0
        while len(P_list) < n_events:
            P = lalsimutils.ChooseWaveformParams()
            P.copy_sim_inspiral(sim_inspiral_table[indx])
            # Overwrite the stored approximant with the one requested in the ini
            P.approx = lalsim.GetApproximantFromString(approx_str)
            if 'TaylorF2Ecc' in approx_str:
                P.ampO = amporder
                P.fecc = fecc
            P_list.append(P)
            indx+=1
        lalsimutils.ChooseWaveformParams_array_to_xml(P_list,"mdc")
    else:
        P_list = lalsimutils.xml_to_ChooseWaveformParams_array(mdc_path)
# create new injection set
else:
    # Bounding box in (m1, m2) spanned by the four corners of the (mc, eta) prior;
    # these bounds do not change between draws.
    m1_lo = np.min([m1a,m1b,m1c,m1d])
    m1_hi = np.max([m1a,m1b,m1c,m1d])
    m2_lo = np.min([m2a,m2b,m2c,m2d])
    m2_hi = np.max([m2a,m2b,m2c,m2d])

    P_list =[]; indx=0
    while len(P_list) < n_events:
        P = lalsimutils.ChooseWaveformParams()
        # Randomize (sky location, etc)
        P.randomize(dMax=d_max,dMin=d_min,aligned_spin_Q=aligned_spin,volumetric_spin_prior_Q=volumetric_spin,sMax=chi_max)
        P.tref = fiducial_event_time
        P.fmin=fmin_template
        P.deltaF = 1./seglen_data
        P.deltaT = 1./srate_data
        # sky location
        if fix_sky_location:
            P.theta = fiducial_dec
            P.phi = fiducial_ra
        # some redundancy
        if no_spin:
            P.s1x=P.s1y=P.s1z=0
            P.s2x=P.s2y=P.s2z=0
        elif aligned_spin:
            P.s1x = P.s1y=0
            P.s2x = P.s2y=0
        if use_matter:
            P.lambda1 = np.random.uniform(0,lambda_max)
            P.lambda2 = np.random.uniform(0,lambda_max)
        if use_eccentric:
            P.eccentricity = np.random.uniform(ecc_min, ecc_max)
        if use_meanPerAno:
            P.meanPerAno = np.random.uniform(meanPerAno_min, meanPerAno_max)
        if 'TaylorF2Ecc' in approx_str:
            P.ampO = amporder

        if not(opts.use_gwsignal):  # with gwsignal, we just pass the approximant on the command line, and must track it differently : XML files have no string approx fields.
            P.approx=lalsim.GetApproximantFromString(approx_str)

        # Uniform in m1 and m2:
        #m1 = np.random.uniform(mc_range[0],mc_range[1]*2)
        #m2 = np.random.uniform(m_min,mc_range[1]*1.5)
        #m1,m2 = [np.maximum(m1,m2), np.minimum(m1,m2)]
        m1 = np.random.uniform (m1_lo,m1_hi)
        m2 = np.random.uniform (m2_lo,m2_hi)
        # reject m1<m2.  Be wary about folding -- some objects like rectangles do not fold properly, though we should be symmetric
        if m1<m2:
            continue
        P.m1 = m1*lal.MSUN_SI
        P.m2 = m2*lal.MSUN_SI
        # ...but downselect in mchirp, eta
        #mc_val = P.extract_param('mc')/lal.MSUN_SI
        #eta_val = P.extract_param('eta')
        mc_val , eta_val = lalsimutils.Mceta(m1,m2)
        if mc_val < mc_range[0] or mc_val > mc_range[1]:
            continue
        if eta_val < eta_range[0] or eta_val > eta_range[1]:
            continue

        # rescale distance of injection based on mc
        # if opts.internal_distance_ladder:
        #     dmax_scale, seglen_scale = dmax_seglen_ladder(mc_val)
        #     P.dist *=  dmax_scale/d_max   # direct scale
        #     P.deltaF = 1./seglen_scale

        # Draw survived every cut: keep it
        P_list.append(P)
        indx+=1

    lalsimutils.ChooseWaveformParams_array_to_xml(P_list,"mdc")

# Sanity-check scatter plot of the injected component masses (in units of lal.MSUN_SI),
# saved as prior_m1_m2.png in the current directory -- helpful to assess if accidents happen
from matplotlib import pyplot as plt
inj_m1_msun = np.array([inj.m1/lal.MSUN_SI for inj in P_list], dtype=float)
inj_m2_msun = np.array([inj.m2/lal.MSUN_SI for inj in P_list], dtype=float)
plt.scatter(inj_m1_msun, inj_m2_msun)
plt.savefig("prior_m1_m2.png")
  

# Create data files and cache file
#   - probably want to create workflow to build these, rather than do it all in one go... very slow!
from ast import literal_eval

print(" === Writing signal files === ")

# made up durations for now
# must be compatible with duration of the noise frames
if use_4096s_data:
    t_start = int(fiducial_event_time)-2048
    t_stop = int(fiducial_event_time)+2048
else:
    t_start = int(fiducial_event_time)-150
    t_stop = int(fiducial_event_time)+150
mkdir('signal_frames')
# Zero-noise network SNR per event index; needed later when each event's ini is localized
snr_dict = {}
# Run from inside an event's signal_frames/event_N directory; the PSD paths are relative to it
get_snr = "util_FrameZeroNoiseSNR.py --cache signals.cache --psd-file H1=../../../H1-psd.xml.gz --psd-file L1=../../../L1-psd.xml.gz --psd-file V1=../../../V1-psd.xml.gz"
for indx in np.arange(n_events):
    os.chdir(working_dir_full)
    target_subdir = 'signal_frames/event_{}'.format(indx)
    # Test if directory already exists
    if os.path.exists(target_subdir):
        print(" Signal frames exist for event {}, skipping ".format(indx))
        os.chdir(target_subdir)
    else:
        print(" Writing ", indx)
        mkdir(target_subdir)
        os.chdir(working_dir_full+"/"+target_subdir)
        # Loop over instruments, write files
        for ifo in ifos:
            if opts.use_gwsignal:
                cmd = "util_GWSignalWriteFrame.py  --inj " + working_dir_full+"/mdc.xml.gz --event {} --start {} --stop {} --instrument {} --seglen {} --approx {} --fref {} --use-gwsignal-lmax-nyquist {}".format(indx, t_start, t_stop, ifo, seglen_analysis, approx_str, fref, use_gwsignal_lmax_nyquist)
                if use_hlms_as_injections: #approx_str == "SEOBNRv5EHM" or approx_str == "TEOBResumSDALI":
                    cmd += " --use-hlms-as-injections "
            elif 'NRHybSur' in approx_str:
                cmd = "util_ROMWriteFrame.py --inj " + working_dir_full+"/mdc.xml.gz --event {} --start {} --stop {} --instrument {} --seglen {} --group {} --param {} --lmax {}".format(indx, t_start, t_stop, ifo, seglen_analysis, group_str, param_str, l_max)
            else:
                cmd = "util_LALWriteFrame.py --inj " + working_dir_full+"/mdc.xml.gz --event {} --start {}  --stop {}  --instrument {} --seglen {} --approx {}".format(indx, t_start,t_stop,ifo, seglen_analysis, approx_str)
            print(cmd)
            if not opts.test:
                os.system(cmd)

        if not opts.test:
            cmd = "/bin/find .  -name '*.gwf' | {} > signals.cache".format(lalapps_path2cache)
            os.system(cmd)
            # Note: this must be COPIED into the run directory if we are using osg

    # Evaluate zero-noise SNR (shared by new and pre-existing event directories):
    # generate the report only if it is missing, then ALWAYS read it so that
    # snr_dict has an entry for every event.
    #print(" --> FIXME: evaluate zero-noise SNR, to use as hint in place of search <-- ")
    if not(os.path.exists('snr-report.txt')):
        os.system(get_snr)
    with open('snr-report.txt', 'r') as f:
        # The last line of the report is parsed as a Python literal with a 'Network' entry
        this_snr_dict = literal_eval(f.readlines()[-1])
    snr_dict[indx] = this_snr_dict['Network']

print(" === Using fiducial noise frames === ")

# Directory holding the pre-made noise frames, laid out as <dir>/<ifo>/<8-digit event index>.gwf
noise_frame_dir = config.get("make_data","fiducial_noise_frames")

print(" === Joining synthetic signals to reference noise === ")


# Duration covered by the signal frames; it is part of the frame file names
seglen_actual = t_stop - t_start
os.chdir(working_dir_full)
mkdir('combined_frames')
for indx in np.arange(n_events):
    os.chdir(working_dir_full)
    target_subdir = f'combined_frames/event_{indx}'
    if os.path.exists(target_subdir):
        # Already produced on an earlier run: leave it untouched
        print(" Combined frames exist for event {}, skipping ".format(indx))
        continue
    print(" Writing ", indx)
    mkdir(target_subdir)
    os.chdir(f"{working_dir_full}/{target_subdir}")
    for ifo in ifos:
        # Frame file names start with the first letter of the detector name
        frame_suffix = f"{int(t_start)}-{seglen_actual}.gwf"
        fname_input = f"{working_dir_full}/signal_frames/event_{indx}/{ifo[0]}-fake_strain-{frame_suffix}"
        fname_noise = noise_frame_dir + "/" + ifo + ("/%08d.gwf" % indx)
        fname_output = f"{working_dir_full}/{target_subdir}/{ifo[0]}-combined-{frame_suffix}"
        add_frames(ifo+":FAKE-STRAIN", fname_input, fname_noise, fname_output)


print(" === WRITING INI FILES === ")
os.chdir(working_dir_full)
# From here on `config` is the OUTPUT ini being assembled, no longer the user's input ini
config=configparser.ConfigParser()
config.optionxform = str  # keep option names exactly as written (no lower-casing)
#config.read('pseudo_ini.ini')

# Create the output sections up front, in the order they appear in the written file
for section_name in ('analysis', 'data', 'condor', 'lalinference', 'engine', 'rift-pseudo-pipe'):
    if not config.has_section(section_name):
        config.add_section(section_name)
analysis_cfg = config['analysis']
data_cfg = config['data']
condor_cfg = config['condor']
lalinf_cfg = config['lalinference']
engine_cfg = config['engine']
pipe_cfg = config['rift-pseudo-pipe']

## ANALYSIS
# NOTE(review): detector list is fixed here rather than taken from the user's `ifos` -- confirm intended
analysis_cfg['ifos'] = "['H1','L1','V1']"
# find out appropriate osg options
analysis_cfg['singularity'] = 'False'
analysis_cfg['osg'] = 'False'

## DATA
data_cfg['channels'] = "{'H1': 'FAKE-STRAIN', 'L1':'FAKE-STRAIN','V1':'FAKE-STRAIN'}"

## CONDOR
condor_cfg['accounting_group'] = group
condor_cfg['accounting_group_user'] = username

## LALINFERENCE
# Fall back to fixed per-detector values when the user's setting is empty/zero
lalinf_cfg['flow'] = str(flow_dict) if flow_dict else "{'H1': 20, 'L1': 20, 'V1': 20}"
lalinf_cfg['fhigh'] = (
    f"{{'H1': {fmax}, 'L1': {fmax}, 'V1': {fmax}}}"
    if fmax
    else "{ 'H1': 2000, 'L1': 2000, 'V1': 2000}"
)

## ENGINE
engine_cfg['fref'] = str(fref) #str(fmin_template) ## check this!!!!
engine_cfg['approx'] = approx_str
engine_cfg['amporder'] = str(amporder) if 'TaylorF2Ecc' in approx_str else '0'
engine_cfg['seglen'] = str(seglen_analysis) if seglen_analysis else '16'
engine_cfg['srate'] = str(srate_data) # get this from user ini
engine_cfg['distance-max'] = str(d_max)
engine_cfg['distance-min'] = str(d_min)

## RIFT-PSEUDO-PIPE

# waveform/grid stuff
pipe_cfg['approx'] = f'"{approx_str}"'
pipe_cfg['fmin-template'] = str(fmin_template)

# ILE stuff
#config['rift-pseudo-pipe']['manual-extra-ile-args'] = str(manual_extra_ile_args)
# Combine the optional gwsignal nyquist setting with any user-supplied extra ILE
# arguments; the key is left out entirely when neither is set.
if use_gwsignal_lmax_nyquist and manual_extra_ile_args:
    pipe_cfg['manual-extra-ile-args'] = ' --use-gwsignal-lmax-nyquist {} {} '.format(use_gwsignal_lmax_nyquist,manual_extra_ile_args)
elif use_gwsignal_lmax_nyquist:
    pipe_cfg['manual-extra-ile-args'] = ' --use-gwsignal-lmax-nyquist {} '.format(use_gwsignal_lmax_nyquist)
elif manual_extra_ile_args:
    pipe_cfg['manual-extra-ile-args'] = str(manual_extra_ile_args)
pipe_cfg['ile-n-eff'] = str(ile_n_eff)
pipe_cfg['ile-copies'] = str(ile_n_copies)
pipe_cfg['ile-sampler-method'] = f'"{ile_sampler}"'
pipe_cfg['internal-ile-freezeadapt'] = 'False'
pipe_cfg['ile-runtime-max-minutes'] = '700'
pipe_cfg['ile-jobs-per-worker'] = str(n_events_per_worker)
pipe_cfg['internal-ile-auto-logarithm-offset'] = 'False'
pipe_cfg['internal-ile-use-lnL'] = 'False'
pipe_cfg['internal-ile-reset-adapt'] = 'True'
pipe_cfg['internal-ile-request-disk'] = str(ile_request_disk) #'"10M"'
pipe_cfg['internal-ile-request-memory'] = str(ile_request_memory)
# Exactly one of the two GPU switches is written
pipe_cfg['ile-force-gpu' if use_gpu else 'ile-no-gpu'] = 'True'

# CIP stuff
#  * Fit method can be 'rf', 'gp'
pipe_cfg['cip-fit-method'] = f'"{cip_fit_method}"' #'"rf"'
#  * sampler method can be 'default', 'GMM', 'adaptive_cartesian_gpu', 'AV'
pipe_cfg['cip-sampler-method'] = f'"{cip_sampler}"'
print("Using {} for fitting and {} for sampling".format(pipe_cfg['cip-fit-method'], pipe_cfg['cip-sampler-method']))
# run CIP workers simultaneously - larger = faster
pipe_cfg['cip-explode-jobs'] = str(cip_explode_jobs)
pipe_cfg['cip-explode-jobs-last'] = str(cip_explode_jobs_last)
pipe_cfg['internal-cip-request-memory'] = str(cip_request_memory)
pipe_cfg['internal-cip-request-disk'] = str(cip_request_disk) #'"10M"'
pipe_cfg['internal-propose-converge-last-stage'] = 'True'

# spin settings
pipe_cfg['l-max'] = str(l_max)
if no_spin:
    pipe_cfg['assume-nospin'] = 'True'
    pipe_cfg['assume-nonprecessing'] = 'False'
    pipe_cfg['assume-precessing'] = 'False'
# aligned spin options
elif aligned_spin:
    engine_cfg['aligned-spin'] = ''
    engine_cfg['alignedspin-zprior'] = ''
    pipe_cfg['assume-nonprecessing'] = 'True'
    pipe_cfg['assume-precessing'] = 'False'
    pipe_cfg['internal-use-aligned-phase-coordinates'] = 'True'
    pipe_cfg['internal-correlate-default'] = 'True'
# spin precession ** ask about other necessary options
elif precessing_spin:
    pipe_cfg['assume-nonprecessing'] = 'False'
    pipe_cfg['assume-precessing'] = 'True'
    pipe_cfg['internal-use-aligned-phase-coordinates'] = 'True'
    pipe_cfg['internal-correlate-default'] = 'True'
    # for now set to false:
    pipe_cfg['internal-use-rescaled-transverse-spin-coordinates'] = 'False'
    # ability to turn transverse puff on/off - after testing will default to true?
    if transverse_puff:
        pipe_cfg['internal-puff-transverse'] = str(transverse_puff)
# something about volmetric prior?
else:
    pipe_cfg['assume-precessing'] = 'False'

# eccentricity
if use_eccentric:
    engine_cfg['ecc_min'] = str(ecc_min)
    engine_cfg['ecc_max'] = str(ecc_max)
    pipe_cfg['assume-eccentric'] = 'True'
    pipe_cfg['force-ecc-max'] = str(ecc_max)
    pipe_cfg['force-ecc-min'] = str(ecc_min)
else:
    pipe_cfg['assume-eccentric'] = 'False'
# meanPerAno
if use_meanPerAno:
    pipe_cfg['use-meanPerAno'] = 'True'
    pipe_cfg['force-meanPerAno-max'] = str(meanPerAno_max)
    pipe_cfg['force-meanPerAno-min'] = str(meanPerAno_min)
# osg
osg_flag = 'True' if opts.use_osg else 'False'
pipe_cfg['use_osg'] = osg_flag
pipe_cfg['use_osg_file_transfer'] = osg_flag
if opts.use_osg:
    pipe_cfg['internal_truncate_files_for_osg_file_transfer'] = 'True'
pipe_cfg['use_osg_cip'] = 'True' if opts.use_osg_cip else 'False'

if use_distance_marginalization:
    pipe_cfg['internal-marginalize-distance'] = 'True'
#    config['rift-pseudo-pipe']['ile-distance-prior']='"pseudo_cosmo"'
    pipe_cfg['internal-distance-max'] = str(d_max)
    # build distance marginalization table NOW
    cmd_here = " util_InitMargTable --d-min {} --d-max {} ".format(d_min,d_max)
    #cmd_here += " --d-prior {} ".format(opts.ile_distance_prior)
    os.system(cmd_here)
    here = os.getcwd()
    # The ini points at the lookup file by absolute path in the current directory
    pipe_cfg['internal-marginalize-distance-file'] = f"'{here}/distance_marginalization_lookup.npz'"


# extrinsic - add eventually
if opts.add_extrinsic:
    for extrinsic_key in ('add-extrinsic', 'batch-extrinsic', 'add-extrinsic-time-resampling'):
        pipe_cfg[extrinsic_key] = 'True'

# parameter ranges - may be modified by localize_ini.py
# for grid making but be limiting/redundant?
pipe_cfg['force-mc-range'] = f'[{mc_min},{mc_max}]'
pipe_cfg['force-eta-range'] = f'[{eta_min},{eta_max}]'
#config['rift-pseudo-pipe']['force-hint-snr'] = snr_dict['Network']

## Write pseudo-pipe ini file with standard defaults
with open("pseudo_ini.ini",'w') as f:
    config.write(f)

##########
## Create event directory structure
##########

P = list(P_list)
for indx in np.arange(n_events):
    print(f" ::: Event {indx} ::: ")
    # Create dir for each event
    dir_target = f"{working_dir_full}/analysis_event_{indx}"
    mkdir(dir_target)
    os.chdir(dir_target)
    # Link this event's combined frames and index every *.gwf reachable through the link
    os.system(f"ln -sf ../combined_frames/event_{indx} .")
    os.system(f"/bin/find -L . -name '*.gwf' | {lalapps_path2cache} > local.cache")
    here = os.getcwd()
    ## create a coinc with individual event from P
    ### does not look like this can handle eccentricity? CHECK
    coinc_cmd = f"util_SimInspiralToCoinc.py --sim-xml ../mdc.xml.gz --event {indx}"
    coinc_cmd += "".join("  --ifo {} ".format(ifo) for ifo in ifos)
    os.system(coinc_cmd)
    # Copy main pseudo_ini here and run pp_localize_ini.py to see if defaults should change
    os.system('cp ../pseudo_ini.ini pseudo_ini.ini')
    localize = 'python ../../pp_localize_ini.py --event {} --ini pseudo_ini.ini --sim-xml ../mdc.xml.gz --guess-snr {}  --mc-range "[{}, {}]"'.format(indx, snr_dict[indx], mc_min, mc_max)
    os.system(localize)
    ## Make call to pseudo_pipe - should create rundir within event_dir
    # Prefer the localized ini when the previous step produced one
    use_ini = 'localize.ini' if os.path.exists('localize.ini') else 'pseudo_ini.ini'
    manual_extra_ile = " --d-min {} --d-max {} ".format(d_min, d_max)
    pipe_cmd_parts = [
        'util_RIFT_pseudo_pipe.py --use-coinc `pwd`/coinc.xml --use-ini `pwd`/{} --use-rundir `pwd`/rundir --fake-data-cache `pwd`/local.cache '.format(use_ini)
    ]
    if fix_sky_location:
        pipe_cmd_parts.append(" --fix-bns-sky --declination {} --right-ascension {} ".format(fiducial_dec, fiducial_ra))
    #if not use_distance_marginalization: # config['rift-pseudo-pipe']['internal-marginalize-distance']:
    pipe_cmd_parts.append(" --manual-extra-ile-args ' {} ' ".format(manual_extra_ile))
    if cip_sampler == 'AV':
        pipe_cmd_parts.append(" --manual-extra-cip-args '--downselect-parameter mc --downselect-parameter-range '[{},{}]''".format(mc_min, mc_max))
    if opts.use_gwsignal:
        pipe_cmd_parts.append(" --use-gwsignal ")
    cmd = "".join(pipe_cmd_parts)
    print(cmd)
    os.system(cmd)
    os.chdir("rundir")
    # Link the *.xml.gz files found in the launch directory into the run directory
    os.system(f"ln -sf {base_dir}/*.xml.gz .")
    os.system("cp ../local.cache .")  # make sure we have a copy of the cache file, to generate truncated files for the OSG if we are using OSG mode
    if opts.use_osg:
        os.system('util_ForOSG_MakeTruncatedLocalFramesDir.sh .')


##########
## Consolidate DAGs
##########

os.chdir(working_dir_full)
cmd = 'util_ConsolidateDAGsUnderMaster.sh analysis_event*/rundir'
print(cmd)
if not opts.test:
    # 1) build master.dag from every event's rundir,
    # 2) copy it to master_clean.dag without any line containing 'cip',
    # 3) remove original file, confusing otherwise
    for shell_cmd in (cmd, 'grep -v cip master.dag > master_clean.dag', 'rm master.dag'):
        os.system(shell_cmd)
