#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Chris Pankow
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Creates a dag workflow to perform extrinsic marginalization calculation.
"""

from __future__ import print_function

import sys
import os
import ast
import json
import re
from argparse import ArgumentParser

import numpy as np

import lal

from ligo.lw import utils, ligolw, lsctables, table
lsctables.use_in(ligolw.LIGOLWContentHandler)
from ligo.lw.utils import process

from glue import pipeline

from rapid_pe import common_cl, dagutils
from rapid_pe.common_cl import JOB_PRIORITIES, MAXJOBS

__author__ = "Evan Ochsner <evano@gravity.phys.uwm.edu>, Chris Pankow <pankow@gravity.phys.uwm.edu>, R. O'Shaughnessy"

#
# Option parsing
#

# Command line interface. Options added directly here are consumed by this
# script; the common_cl helpers further down register additional option groups
# on the same parser.
argp = ArgumentParser()
# Options needed by this program only.

argp.add_argument("-T", "--template-bank-xml", help="Input template bank as a sim_inspiral or sngl_inspiral table. Required.")
argp.add_argument("-D", "--working-directory", default="./", help="Directory in which to stage DAG components.")
argp.add_argument("-l", "--log-directory", default="./", help="Directory in which to place condor logs.")
argp.add_argument("-W", "--web-output", default="./", help="Directory to place web accessible plots and webpages.")
argp.add_argument("-O", "--output-name", default="marginalize_extrinsic_parameters", help="Filename (without extension) to write DAG to.")
# type=int keeps a value given on the command line the same type as the
# integer default (without it argparse hands back a string).
argp.add_argument("--n-copies", type=int, default=1, help="Number of copies of each integrator instance to run per mass point. Default is one.")
# This is a plain flag: the script name is derived from --output-name.
argp.add_argument("--write-script", action="store_true", help="In addition to the DAG, write a shell script (<output-name>.sh) to execute the workflow.")
argp.add_argument("--write-eff-lambda", action="store_true", help="Use psi0 column of template bank XML as effective lambda point to calculate in DAG.")
argp.add_argument("--write-deff-lambda", action="store_true", help="Use psi3 column of template bank XML as delta effective lambda point to calculate in DAG.")
argp.add_argument("--condor-command", action="append", help="Append these condor commands to the submit files. Useful for account group information.")
argp.add_argument("--accounting-group-user",default=None, help="accounting-group-user for condor")
argp.add_argument("--exe-integrate-likelihood", default=None, help="This is executable to use to integrate the extrinsic likelihood per intrinsic grid point. It will default to the lalsuite rapidpe_integrate_extrinsic_likelihood.")
argp.add_argument("--integration-args-dict", default="", help="Pass these options as the kwargs input of the integrate dag creation function. They will be set, without editing, as input to the integration exe. If you use this, it will not be possible to also pass other command line arguments to the integration executable via the create_event_dag command line.")
argp.add_argument("--getenv", nargs="+", help="JSON encoded list of environment variable names to pass to Condor.")
argp.add_argument("--environment", help="JSON encoded dictionary of environment variables to pass to Condor.")
argp.add_argument("--iteration-level", default=None, help="integer denoting the iteration level")
argp.add_argument("--cProfile", action="store_true", help="if True, cProfile each ILE job.")

# One --maxjobs-<category> option per entry of MAXJOBS, defaulting to the
# value stored there.
for cat, val in MAXJOBS.items():
    optname = "--maxjobs-%s" % cat.lower().replace("_", "-")
    argp.add_argument(optname, type=int, default=val, help="Set MAXJOBS in DAGs for category %s. Default is %s" % (cat, str(val)))

# Options transferred to ILE
common_cl.add_datasource_params(argp)
common_cl.add_integration_params(argp)
common_cl.add_output_params(argp)
common_cl.add_intrinsic_params(argp)
common_cl.add_pinnable_params(argp)

opts = argp.parse_args()

# sys.exit rather than the interactive-helper exit(), which only exists when
# the site module has been loaded.
if not opts.template_bank_xml:
    sys.exit("Option --template-bank-xml is required.")

# Extra condor submit-file commands as a {name: value} dict. Stays None when
# neither --condor-command nor --accounting-group-user was given.
condor_commands = None
if opts.condor_command is not None:
    # Split on the first "=" only so that values may themselves contain "=".
    condor_commands = dict(c.split("=", 1) for c in opts.condor_command)
if opts.accounting_group_user:
    # --accounting-group-user can be given without any --condor-command, in
    # which case there is no dict to add to yet.
    if condor_commands is None:
        condor_commands = {}
    condor_commands['accounting_group_user'] = opts.accounting_group_user

# --environment is a JSON encoded dictionary; None when the option is absent.
environment_dict = None if opts.environment is None else json.loads(opts.environment)

#
# Get trigger information from coinc xml file
#
# FIXME: CML should package this for us

# Get end time from coinc inspiral table or command line
xmldoc = None
if opts.coinc_xml is not None:
    xmldoc = utils.load_filename(opts.coinc_xml, contenthandler=ligolw.LIGOLWContentHandler)
    coinc_table = table.Table.get_table(
        xmldoc, lsctables.CoincInspiralTable.tableName)
    # Explicit check rather than an assert: this validates an input file and
    # must not disappear when Python runs with -O.
    if len(coinc_table) != 1:
        raise ValueError("Expected exactly one row in the coinc_inspiral table of %s, found %d." % (opts.coinc_xml, len(coinc_table)))
    coinc_row = coinc_table[0]
    event_time = coinc_row.get_end()
    print("Coinc XML loaded, event time: %s" % str(event_time))
elif opts.event_time is not None:
    # FIXME: Bad hack to make the ILE sub writer ignore this parameter
    opts.coinc_xml = False
    event_time = lal.LIGOTimeGPS(opts.event_time)
    print("Event time from command line: %s" % str(event_time))
else:
    raise ValueError("Either --coinc-xml or --event-time must be provided to parse event time.")

# Load the template bank document. A sim_inspiral table is tried first; when
# that lookup raises ValueError a sngl_inspiral table is used instead.
bank_path = opts.template_bank_xml
xmldoc = utils.load_filename(bank_path, contenthandler=ligolw.LIGOLWContentHandler)

tmplt_bnk = None
try:
    tmplt_bnk = lsctables.SimInspiralTable.get_table(xmldoc)
except ValueError:
    print("Exactly one sim_inspiral table was not found in %s, trying sngl_inspiral" % bank_path, file=sys.stderr)

# Still unset here means the sim_inspiral lookup did not produce a table.
if tmplt_bnk is None:
    tmplt_bnk = lsctables.SnglInspiralTable.get_table(xmldoc)

#
# Post processing options
#
# FIXME: Remove these entirely
# Hard-wired switch for the Bayesian PE post-processing branches below.
use_bayespe_postproc = False

# DAG that receives the ILE nodes
dag = pipeline.CondorDAG(log=opts.log_directory)

ile_maxjobs = opts.maxjobs_ile
if ile_maxjobs is not None:
    dag.add_maxjobs_category("ILE", ile_maxjobs)

# Separate DAG for plotting and postprocessing, so that those jobs don't
# block completion of an individual event's ILEs
ppdag = pipeline.CondorDAG(log=opts.log_directory)
for pp_category in ("SQL", "PLOT"):
    ppdag.add_maxjobs_category(pp_category, MAXJOBS[pp_category])

# Create the directory that holds the job log files if it is not there yet
log_dir_missing = not os.path.exists(opts.log_directory)
if log_dir_missing:
    os.makedirs(opts.log_directory)

# Names of the intrinsic parameters being gridded over. The set is passed as
# ``intr_prms`` to the ILE submit-file writer further down.
# NOTE(review): an earlier comment here said this is not used anywhere and
# broke off mid-sentence -- confirm how dagutils actually consumes it.
intr_prms = {"mass1", "mass2"}
# Spin columns are included when the first bank row carries them.
# FIXME: Add all
intr_prms.update(
    spin_prm for spin_prm in ("spin1z", "spin2z")
    if hasattr(tmplt_bnk[0], spin_prm)
)

# These have explicit options because they map to non-standard columns and I
# want the user to explicitly use these columns if they've written them
for requested, prm_name in (
        (opts.write_eff_lambda, "eff_lambda"),
        (opts.write_deff_lambda, "deff_lambda"),
):
    if requested:
        intr_prms.add(prm_name)

# Integrator executable name; falls back to the default when the option is
# not given. It is not handed to the submit-file writer (that receives the
# Python interpreter as ``exe``); it is spliced into the submit file later.
exe = opts.exe_integrate_likelihood
if exe is None:
    exe = "rapidpe_integrate_extrinsic_likelihood"
python_exe = sys.executable

# Keyword arguments that open both variants of the call below. The keyword
# order of each variant is kept exactly as it was.
leading_sub_kwargs = dict(
    tag='integrate',
    condor_commands=condor_commands,
    intr_prms=intr_prms,
    log_dir=opts.log_directory,
)

if opts.integration_args_dict == "":
    # No user supplied kwargs: forward the individual command line options.
    ile_job_type, ile_sub_name = dagutils.write_integrate_likelihood_extrinsic_sub(
        **leading_sub_kwargs,
        cache_file=opts.cache_file,
        channel_name=opts.channel_name,
        psd_file=opts.psd_file,
        coinc_xml=opts.coinc_xml,
        reference_freq=opts.reference_freq,
        fmax=(opts.fmax or 2048),
        fmin_template=opts.fmin_template,
        approximant=opts.approximant,
        amp_order=opts.amp_order,
        l_max=opts.l_max,
        event_time=event_time,
        exe=python_exe,
        time_marginalization=opts.time_marginalization,
        # FIXME after adding psi phi back
        # psi_phi_marginalization=opts.psi_phi_marginalization,
        save_samples=opts.save_samples,
        output_file=opts.output_file,
        iteration_level=opts.iteration_level,
        getenv=opts.getenv,
        environment_dict=environment_dict,
        n_eff=opts.n_eff,
        n_max=opts.n_max,
        ncopies=opts.n_copies,
        save_P=opts.save_P,
        n_chunk=opts.n_chunk,
        adapt_floor_level=opts.adapt_floor_level,
        adapt_weight_exponent=opts.adapt_weight_exponent,
        skymap_file=(opts.skymap_file or False),
        distance_maximum=opts.distance_maximum,
    )
else:
    # --integration-args-dict holds a Python literal dict; its entries are
    # appended as further keyword arguments.
    ile_job_type, ile_sub_name = dagutils.write_integrate_likelihood_extrinsic_sub(
        **leading_sub_kwargs,
        exe=python_exe,
        ncopies=opts.n_copies,
        output_file=opts.output_file,
        iteration_level=opts.iteration_level,
        getenv=opts.getenv,
        environment_dict=environment_dict,
        **ast.literal_eval(opts.integration_args_dict)
    )

# Place the submit file in the working directory and write it out
ile_sub_name = os.path.join(opts.working_directory, ile_sub_name)
ile_job_type.set_sub_file(ile_sub_name)
ile_job_type.write_sub_file()

# Replacement template (also reused when the shell script is patched). It
# re-emits the matched line with the integrator executable inserted after the
# first group; with --cProfile it additionally inserts "-m cProfile -o <file>"
# where the stats file name embeds the matched ILE_iteration_<n> token.
if not opts.cProfile:
    replacement_pattern =  rf'\1 {exe} \2\3\4'
else:
    uniq_str = "$(cluster)-$(process)"
    cprofile_stats_file = os.path.join(opts.log_directory, f'cprofile_integrate-{uniq_str}')
    replacement_pattern = rf'\1 -m cProfile -o {cprofile_stats_file}-\3.out {exe} \2\3\4'

# Matches an ``arguments = "...`` line that mentions ILE_iteration_<n>
arguments_line_re = re.compile(
    r'^(\s*arguments\s*=\s*\")(.*)(ILE_iteration_[0-9]+)(.*)$',
    flags=re.MULTILINE,
)

# Read the freshly written submit file, patch it, and write it back in place
with open(ile_sub_name, 'r') as integrate_sub_file:
    integrate_subfile_content = integrate_sub_file.read()

replaced_integrate_subfile_content = arguments_line_re.sub(
    replacement_pattern, integrate_subfile_content)

with open(ile_sub_name, 'w') as integrate_sub_file:
    integrate_sub_file.write(replaced_integrate_subfile_content)

# Only reached when use_bayespe_postproc is switched on: create the plotting
# job and attach a single node for it to the postprocessing DAG.
if use_bayespe_postproc:
    web_dir = opts.web_output
    if not os.path.exists(web_dir):
        os.makedirs(web_dir)

    bpp_plot_job_type, bpp_plot_job_name = dagutils.write_bayes_pe_postproc_sub(
        tag="bayes_pp_plot",
        log_dir=opts.log_directory,
        web_dir=web_dir,
        getenv=opts.getenv,
        environment_dict=environment_dict,
    )
    bpp_plot_job_type.write_sub_file()

    bpp_plot_node = pipeline.CondorDAGNode(bpp_plot_job_type)
    bpp_plot_node.set_category("PLOT")
    bpp_plot_node.set_pre_script(dagutils.which("bayes_pe_preprocess"))
    ppdag.add_node(bpp_plot_node)

# TODO: Mass index table
# One ILE node per row of the template bank.
for bank_index, tmplt in enumerate(tmplt_bnk):
    mass_grouping = "MASS_SET_%d" % bank_index

    ile_node = pipeline.CondorDAGNode(ile_job_type)
    ile_node.set_priority(JOB_PRIORITIES["ILE"])

    # Macros every node gets
    for macro_name, macro_value in (
            ("macromass1", tmplt.mass1),
            ("macromass2", tmplt.mass2),
            ("iterlevel", opts.iteration_level),
    ):
        ile_node.add_macro(macro_name, macro_value)

    # Optional macros: lambda columns on request, spins when the row has them
    if opts.write_eff_lambda:
        ile_node.add_macro("macroefflambda", tmplt.psi0)
    if opts.write_deff_lambda:
        ile_node.add_macro("macrodefflambda", tmplt.psi3)
    for spin_column in ("spin1z", "spin2z"):
        if hasattr(tmplt, spin_column):
            ile_node.add_macro("macro" + spin_column, getattr(tmplt, spin_column))

    if use_bayespe_postproc:
        # With the Bayesian PE post processing script enabled, dump the data
        ile_node.set_post_script(dagutils.which("process_ile_output"))
        for post_arg in (
                "--output ILE_%s.txt" % mass_grouping,
                "--glob *-%s-*.xml.gz" % mass_grouping,
        ):
            ile_node.add_post_script_arg(post_arg)

    # Identifies output from groupings of the same mass point
    ile_node.add_macro("macromassid", mass_grouping)

    ile_node.set_category("ILE")
    dag.add_node(ile_node)


# Write the DAG and, when requested, the shell script version of it.
dag_name=opts.output_name
dag.set_dag_file(dag_name)
dag.write_concrete_dag()
if opts.write_script:
    dag.write_script()

    # The script is only produced under --write-script, so it can only be
    # patched in that case; reading it unconditionally fails with
    # FileNotFoundError otherwise.
    # NOTE(review): assumes write_script() writes "<output_name>.sh", as the
    # path read here always implied -- confirm against glue.pipeline.
    ile_sh_file_name = f'{opts.output_name}.sh'
    with open(ile_sh_file_name, 'r') as ile_sh_file:
        ile_sh_file_content = ile_sh_file.read()

    # The interpreter path is literal text, not a regular expression, so it is
    # escaped before being embedded in the pattern. Lines starting with the
    # interpreter and mentioning ILE_iteration_<n> are rewritten with the same
    # replacement template used for the submit file.
    ile_sh_line_pattern = r'^(' + re.escape(python_exe) + r')(.*)(ILE_iteration_[0-9]+)(.*)$'
    replaced_ile_sh_file_content = re.sub(ile_sh_line_pattern, replacement_pattern, ile_sh_file_content, flags=re.MULTILINE)

    with open(ile_sh_file_name, 'w') as ile_sh_file:
        ile_sh_file.write(replaced_ile_sh_file_content)

print("Created a DAG named %s\n" % dag_name)
print("This will run %i instances of %s in parallel\n" % (len(tmplt_bnk), ile_sub_name))

# FIXME: Adjust name on command line
# Only reached when use_bayespe_postproc is switched on: write out the
# postprocessing DAG (and its script when --write-script was given).
if use_bayespe_postproc:
    ppdag_name="posterior_pp"
    ppdag.set_dag_file(ppdag_name)
    for late_category in ("ANALYSIS", "POST"):
        ppdag.add_maxjobs_category(late_category, MAXJOBS[late_category])
    ppdag.write_concrete_dag()
    if opts.write_script:
        ppdag.write_script()

    print("Created a postprocessing DAG named %s\n" % ppdag_name)

#xmldoc = ligolw.Document()
#xmldoc.appendChild(ligolw.LIGO_LW())
#process.register_to_xmldoc(xmldoc, sys.argv[0], opts.__dict__)
#utils.write_filename(xmldoc, opts.output_name + ".xml.gz")
