#! /usr/bin/env python
#
#  Based on generate_initial_grid_based_of_gstlal_O2_overlaps.py from Caitlin Rose (?)
#
# 20190215
#
# The overlap files for the gstlal O2 template bank described in https://arxiv.org/abs/1812.05121
# were generated by Heather Fong. 
# The overlap banks are split by Mchirp, with about 500 templates per subbank.
#
# This script sets up the command line to run lalsuite rapidpe_compute_intrinsic_grid to generate
# the initial grid for this rapidpe run. 
#
# For now the lalsuite code is only set up for grids in mass1-mass2 space; it needs to be expanded.
#
# util_GridSubsetOfTemplateBank.py  --use-ini path_to_config_file
# util_GridSubsetOfTemplateBank.py --use-ini config_O3.ini --use-bank /home/sinead.walsh/20180214_O2_overlaps_from_heather_fong/

# or
# generate_initial_grid_based_of_gstlal_O2_overlaps.py path_to_config_file "{output_event_ID=name,intrinsic_param=[mass1=1.4,mass2=1.4]}"
# or with spin:
# generate_initial_grid_based_of_gstlal_O2_overlaps.py path_to_config_file "{output_event_ID=name,intrinsic_param=[mass1=1.4,mass2=1.4,spin1z=0.01,spin2z=0.01]}"
# 
# The output file is placed in the event directory
#


import argparse
import sys,os,json,ast,glob,h5py

import numpy as np
from sklearn.neighbors import BallTree
import lal
import RIFT.lalsimutils as lalsimutils
import lalsimulation as lalsim
from RIFT.misc.modules import *   # argh!
import configparser
from configparser import ConfigParser

# Short parameter labels used in this script -> long-form names handed to the
# grid-refinement command line.
remap_rpe2rift = {
    'm1': 'mass1',
    'm2': 'mass2',
    's1z': 'spin1z',
    's2z': 'spin2z',
}


def translate_params(param):
    """Return the long-form name for *param*, or *param* unchanged if it has no mapping."""
    return remap_rpe2rift.get(param, param)

# Command-line interface.  Parsing itself happens right after this block.
parser = argparse.ArgumentParser()

# --- Inputs / outputs -------------------------------------------------------
parser.add_argument("--use-ini",default=None,help="ini file (required)")
parser.add_argument("--use-bank",default=None,help="path to bank files (top level directory), required. For example /home/sinead.walsh/20180214_O2_overlaps_from_heather_fong/")
parser.add_argument("--refine-exe",default="util_AMRGrid.py",help="exe for grid refinement name (util_AMRGrid.py)")
parser.add_argument("--extra-ini-args",default=None,help="extra dictionary of kwargs (?)")
parser.add_argument("--output-path",default=".",help="path to output")

# --- Base point -------------------------------------------------------------
# See ManualOverlapGrid: same argument structure here, so the two are interoperable.
parser.add_argument("--inj", dest='inj', default=None,help="inspiral XML file containing the base point.")
parser.add_argument("--event",type=int, dest="event_id", default=None,help="event ID of injection XML to use.")
parser.add_argument("--mass1", default=1.50,type=float,help="Mass in solar masses")
parser.add_argument("--mass2", default=1.35,type=float,help="Mass in solar masses")
parser.add_argument("--mc-range",default=None,help="Manually input target chirp mass range")
parser.add_argument("--assume-nospin",action='store_true')
parser.add_argument("--s1z", default=0.,type=float,help="Spin1z")
# Help text previously read "Spin1z" (copy-paste slip from --s1z).
parser.add_argument("--s2z", default=0.,type=float,help="Spin2z")
parser.add_argument("--eff-lambda", type=float, help="Value of effective tidal parameter. Optional, ignored if not given")
parser.add_argument("--deff-lambda", type=float, help="Value of second effective tidal parameter. Optional, ignored if not given")

# --- Grid layout ------------------------------------------------------------
parser.add_argument("--parameter", action='append')
parser.add_argument("--parameter-range", action='append', type=str,help="Add a range (pass as a string evaluating to a python 2-element list): --parameter-range '[0.,1000.]'   MUST specify ALL parameter ranges (min and max) in order if used")
parser.add_argument("--random-parameter", action='append',help="These parameters are specified at random over the entire range, uncorrelated with the grid used for other parameters.  Use for variables which correlate weakly with others; helps with random exploration")
# The example in this help text previously named --parameter-range instead of this option.
parser.add_argument("--random-parameter-range", action='append', type=str,help="Add a range (pass as a string evaluating to a python 2-element list): --random-parameter-range '[0.,1000.]'   MUST specify ALL parameter ranges (min and max) in order if used.  ")
parser.add_argument("--approx",type=str,default=None)
parser.add_argument("--grid-cartesian-npts", default=100, type=int)
parser.add_argument("--match-value", type=float, default=0.01, help="Use this as the minimum match value. Default is 0.01 (i.e., keep almost everything)")
parser.add_argument("--verbose", action="store_true",default=False, help="Extra warnings")
opts=  parser.parse_args()

if opts.mc_range:
    # User-specified chirp-mass window, given as "[lo,hi]" (brackets optional).
    # Chirp masses are real-valued, so parse as float: int() rejects an
    # input such as '[1.15,1.25]' with ValueError.
    mc_range = [float(x) for x in opts.mc_range.replace('[','').replace(']','').split(',')]
    opts.mc_range = mc_range
else:
    print(" User not specifying mc range, using event-based selection")

cfg = ConfigParser()
cfg.optionxform = str  # keep option names exactly as written in the ini file
cfgname = opts.use_ini
if not cfgname:
    print("  No input file ")
    sys.exit(1)

# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it did read; fail here with a clear message instead of a later
# NoSectionError.
if not cfg.read(cfgname):
    sys.exit("ERROR: could not read ini file {}".format(cfgname))

print(("CFGNME",cfgname))
output_parent_directory= cfg.get("General","output_parent_directory") # output parent dir for all triggers from same scenario study etc.

# Event keyword arguments: taken from --extra-ini-args when given, otherwise
# from the [Event] section of the ini file.
if opts.extra_ini_args:
    kwargs = convert_dict_string_to_dict(opts.extra_ini_args)
else:
    kwargs = convert_section_args_to_dict(cfg,"Event")


output_event_directory= opts.output_path #kwargs["output_event_ID"] #the output directory for the single event trigger being followed up here

# Reference ("base point") waveform parameters, set up as in ManualOverlapGrid (MOG):
# either copied from one row of an injection XML, or built from the command-line
# masses and spins.
P=lalsimutils.ChooseWaveformParams()
if opts.inj:
    from igwn_ligolw import lsctables, table, utils # check all are needed
    filename = opts.inj
    event = opts.event_id
    xmldoc = utils.load_filename(filename, verbose = True,contenthandler =lalsimutils.cthdler)
    sim_inspiral_table = table.get_table(xmldoc, lsctables.SimInspiralTable.tableName)
    P.copy_sim_inspiral(sim_inspiral_table[int(event)])
    P.fmin =0 #opts.fmin
    if opts.approx:
        P.approx = lalsim.GetApproximantFromString(opts.approx)
        if not (P.approx in [lalsim.TaylorT1,lalsim.TaylorT2, lalsim.TaylorT3, lalsim.TaylorT4]):
            # Do not use tidal parameters in approximant which does not implement them
            print(" Do not use tidal parameters in approximant which does not implement them ")
            P.lambda1 = 0
            P.lambda2 = 0
else:
    P.m1 = opts.mass1 *lal.MSUN_SI
    P.m2 = opts.mass2 *lal.MSUN_SI
    P.s1z = opts.s1z
    P.s2z = opts.s2z
    P.dist = 150*1e6*lal.PC_SI
    if opts.eff_lambda:
        # This branch used to reference the undefined names Psig, m1 and m2
        # (left over from MOG) and raised NameError whenever --eff-lambda was
        # given.  The tidal parameters belong on P.
        # NOTE(review): masses are passed as stored on P (SI units); this
        # assumes tidal_lambda_from_tilde depends only on the mass ratio - confirm.
        lambda1, lambda2 = lalsimutils.tidal_lambda_from_tilde(P.m1, P.m2, opts.eff_lambda, opts.deff_lambda or 0)
        P.lambda1 = lambda1
        P.lambda2 = lambda2

    P.fmin=0 #opts.fmin   # Just for comparison!  Obviously only good for iLIGO
    P.ampO=0 #opts.amplitude_order  # include 'full physics'
    P.phaseO =0 # opts.phase_order
    if opts.approx:
        P.approx = lalsim.GetApproximantFromString(opts.approx)
        if not (P.approx in [lalsim.TaylorT1,lalsim.TaylorT2, lalsim.TaylorT3, lalsim.TaylorT4]):
            # Do not use tidal parameters in approximant which does not implement them
            print(" Do not use tidal parameters in approximant which does not implement them ")
            P.lambda1 = 0
            P.lambda2 = 0
    else:
        # No approximant requested: fall back to TaylorT4.
        P.approx = lalsim.GetApproximantFromString("TaylorT4")

# Trigger parameters that seed the grid: component masses converted back from
# SI via lal.MSUN_SI, plus the aligned spins unless --assume-nospin was given.
intrinsic_param = {
    "m1": P.m1/lal.MSUN_SI,
    "m2": P.m2/lal.MSUN_SI,
}
if not opts.assume_nospin:
    intrinsic_param.update(s1z=P.s1z, s2z=P.s2z)

# Optional ini-file settings; each falls back to an empty string when absent.
if cfg.has_option("GridRefine","distance-coordinates"):
    distance_coordinates = cfg.get("GridRefine","distance-coordinates")
else:
    distance_coordinates = ""

if cfg.has_section("InitialGridOnly"):
    additional_command_line_args = convert_cfg_section_to_cmd_line(cfg,"InitialGridOnly")
else:
    additional_command_line_args = ""

# main() creates this directory if needed and runs from inside it.
output_dir = output_event_directory

def _chirp_mass(mass1, mass2):
    """Return (m1*m2)**(3/5) / (m1+m2)**(1/5); works on scalars or arrays."""
    return ((mass1*mass2)**(3/5.0)) / ((mass1 + mass2)**(1/5.0))


def _symmetric_mass_ratio(mass1, mass2):
    """Return m1*m2 / (m1+m2)**2; works on scalars or arrays."""
    return (mass1*mass2) / ((mass1 + mass2)**2.)


def main():
    """Build and run the grid-refinement command that creates the initial grid.

    Reads the module-level settings (``opts``, ``intrinsic_param``,
    ``distance_coordinates``, ``additional_command_line_args``, ``output_dir``),
    changes into ``output_dir``, selects the overlap files under ``--use-bank``
    whose chirp-mass span intersects the target chirp-mass window, and runs
    ``--refine-exe`` with one ``--use-overlap`` per selected file.

    Raises:
        SystemExit: if the bank directory is missing, or if the refinement
            command exits with a non-zero status.
    """
    print("kwargs",kwargs)
    print("the following will be added to the grid generation command line",additional_command_line_args)
    print("The directory for the output file is",output_dir)
    # Create the output directory without going through a shell.
    os.makedirs(output_dir, exist_ok=True)

    # Make the output directory the working directory.
    os.chdir(output_dir)

    exe = opts.refine_exe
    path_to_olap_files = opts.use_bank # "/home/sinead.walsh/20180214_O2_overlaps_from_heather_fong/"
    # --use-bank defaults to None; treat that like a missing directory rather
    # than letting os.path.isdir(None) raise TypeError.
    if not path_to_olap_files or not os.path.isdir(path_to_olap_files):
        sys.exit("ERROR: path to overlap files doesn't exist. Make sure you are on the correct filesystem / host cluster for {}".format(path_to_olap_files))
    intrinsic_grid_name_base = "intrinsic_grid"
    initial_grid_xml = intrinsic_grid_name_base+"_iteration_0.xml.gz"
    initial_grid_hdf = intrinsic_grid_name_base+"_all_iterations.hdf"

    cmd = exe + " --verbose --no-exact-match --setup "+initial_grid_hdf+" --output-xml-file-name "+initial_grid_xml
    if distance_coordinates != "":
        cmd += " -d "+distance_coordinates

    # Add the event trigger parameters; the initial grid will include all points
    # in the overlap bank with overlap < the -T value.
    for param,val in intrinsic_param.items():
        print(param,val)
        translated_param = translate_params(param)  # variable name convention for lower-level variables is different
        cmd += " -i "+translated_param+"="+str(val)

    cmd += additional_command_line_args

    # The overlap files are split by Mchirp and it takes time to read them all,
    # so only the files whose Mchirp span intersects the target window are
    # passed on.
    m1 = float(intrinsic_param["m1"])
    m2 = float(intrinsic_param["m2"])
    use_spin = "s1z" in intrinsic_param
    s1 = s2 = 0
    if use_spin:
        s1 = float(intrinsic_param["s1z"])
        s2 = float(intrinsic_param["s2z"])

    chi_eff_event = transform_s1zs2z_chi(m1,m2,s1,s2)
    Mchirp_event = _chirp_mass(m1, m2)
    eta_event = _symmetric_mass_ratio(m1, m2)
    print("Event mchirp",Mchirp_event,eta_event)

    # from helper code: choose some mc range that's plausible, not a delta function at trigger mass
    fmin_fiducial = 20
    v_PN_param = (np.pi* Mchirp_event*fmin_fiducial*lalsimutils.MsunInSec)**(1./3.)  # 'v' parameter
    snr_fac = 1 # not using that information
    v_PN_param_max = 0.2
    fac_search_correct = 1.5   # if this is too large we can get duration effects / seglen limit problems when mimicking LI
    ln_mc_error_pseudo_fisher = 1.5*np.array(fac_search_correct)*0.3*(v_PN_param/v_PN_param_max)**(7.)/snr_fac
    if ln_mc_error_pseudo_fisher  >1:
        ln_mc_error_pseudo_fisher =0.8   # stabilize
    mc_max = np.exp( ln_mc_error_pseudo_fisher) * Mchirp_event
    mc_min = np.exp( -ln_mc_error_pseudo_fisher) * Mchirp_event
    if opts.mc_range:
        # An explicit --mc-range overrides the event-based window.
        mc_min = opts.mc_range[0]
        mc_max = opts.mc_range[1]

    # Reduce the list of files to those in the mchirp range.
    olap_filenames = glob.glob(path_to_olap_files+"/*.hdf")
    count_files = 0
    # Closest-template bookkeeping; not yet used to build the command (see FIXME below).
    min_dist = -1
    min_dist_filename = ""
    if len(olap_filenames)==1:
        # If only one file is provided, don't bother looking inside it.
        cmd += " --use-overlap " + olap_filenames[0]
        count_files +=1
    else:
        for hdf_filename in olap_filenames:
            # Context manager so each bank file is closed again; previously
            # every file handle stayed open for the rest of the run.
            with h5py.File(hdf_filename,"r") as h5file:
                wfrm_fam = list(h5file.keys())[0]
                mdata = h5file[wfrm_fam]
                ntemplates = len(mdata["overlaps"])
                # Bank arrays get their own names so they do not overwrite the
                # event scalars m1/m2/s1/s2 defined above.
                bank_m1, bank_m2 = mdata["mass1"][:ntemplates], mdata["mass2"][:ntemplates]
                Mchirps = _chirp_mass(bank_m1, bank_m2)
                bank_mc_lo, bank_mc_hi = Mchirps.min(), Mchirps.max()

                # Closed-interval overlap test.  The earlier three-way test with
                # strict inequalities skipped a file whose Mchirp span lay inside
                # [mc_min, mc_max] but shared an endpoint with it.
                if mc_max < bank_mc_lo or mc_min > bank_mc_hi:
                    continue

                print(hdf_filename)
                etas = _symmetric_mass_ratio(bank_m1, bank_m2)
                #FIXME:even if youre not searching over spin, you want to find the file with the closest template assuming spin=0
                #implement above here at same time as code
                if use_spin:
                    bank_s1, bank_s2 = mdata["spin1z"][:ntemplates], mdata["spin2z"][:ntemplates]
                    chi_effs = transform_s1zs2z_chi(bank_m1,bank_m2,bank_s1,bank_s2)
                    list_for_tree = np.asarray([Mchirps,etas,chi_effs]).T
                    pt = np.asarray([Mchirp_event,eta_event,chi_eff_event])
                else:
                    list_for_tree = np.asarray([Mchirps,etas]).T
                    pt = np.asarray([Mchirp_event,eta_event])

            # Distance from the event to the nearest template in this file.
            tree = BallTree(list_for_tree)
            dist, _ = tree.query(np.atleast_2d(pt), k=1)
            nearest = float(dist[0][0])
            if min_dist_filename == "" or nearest < min_dist:
                min_dist = nearest
                min_dist_filename = hdf_filename

            count_files += 1
            #FIXME: rapidpe compute grid doesn't consider spin when checking for the closest template, which can lead to incorrect overlaps with other templates because the spin of the closest template assuming zero spin is actually non-zero, and if that was taken into account it wouldn't be the closest template
            cmd += " --use-overlap "+hdf_filename

    print(("CLA",cmd))
    print(("Command line includes",count_files,"files to read"))
    exit_status = os.system(cmd)
    if exit_status != 0:
        print(("ERROR with",cmd))
        sys.exit("ERROR: non zero exit status"+str(exit_status))

    # Final line reports where the grid files were written.
    print("[initial_grid_xml="+initial_grid_xml+",initial_grid_hdf="+initial_grid_hdf+"]")

    return

if __name__ == '__main__':
    main()
