"""Utility to load waveform of different origins."""
import os
import numpy as np
import sxs
import scri
import h5py
import lal
import lalsimulation as lalsim
import warnings
import json
import spherical_functions as sf
from copy import deepcopy
from .utils import peak_time_via_quadratic_fit
from .utils import amplitude_using_all_modes
from .utils import check_kwargs_and_set_defaults
from .utils import raise_exception_if_none
from .utils import interpolate


def get_available_waveform_origins(return_dict=False):
    """List the waveform origins that this module knows how to load.

    parameters:
    -----------
    return_dict: bool
        When True, the mapping from each origin name to its loading function
        is returned. When False, only the origin names are returned.
        Default is False.

    returns:
    --------
    A dictionary `{origin: loading_function}` if `return_dict` is True,
    otherwise the list of origin names in the same order.
    """
    # (origin name, loader) pairs; the order here is the order in which the
    # origins are reported to the user.
    registry = dict((
        ("LAL", load_LAL_waveform),
        ("SXSCatalog", load_sxs_catalogformat),
        ("SXSCatalog_old", load_sxs_catalogformat_old),
        ("LVCNR", load_lvcnr_waveform),
        ("LVCNR_hack", load_lvcnr_hack),
        ("EOB", load_EOB_waveform),
        ("EMRI", load_EMRI_waveform),
    ))

    if return_dict:
        return registry
    return list(registry)


def get_load_waveform_docs(origin):
    """Show the documentation of the loading function for `origin`.

    The loader registered for `origin` (see
    `get_available_waveform_origins`) is handed to the builtin `help` and
    whatever `help` returns is passed back to the caller.

    An Exception is raised when `origin` is not an available origin.
    """
    loaders = get_available_waveform_origins(return_dict=True)
    if origin in loaders:
        return help(loaders[origin])
    # unknown origin: report the allowed names
    raise Exception(f"Unknown {origin}. Must be one of "
                    f"{list(loaders.keys())}")


def get_load_waveform_defaults(origin="LAL"):
    """Return the default kwargs accepted by the loader of a given origin.

    parameters:
    -----------
    origin: str
        The origin of the waveform. See under
        `load_data.load_waveform` for more details.
        The possible values of origin can be obtained using
        `load_data.get_available_waveform_origins`.

    returns:
    --------
    default_kwargs: dict
        A dictionary of default kwargs for the given origin. For the NR
        origins ("SXSCatalog", "SXSCatalog_old", "LVCNR", "LVCNR_hack") the
        values are picked from `get_defaults_for_nr`.

    An Exception is raised for an unknown origin.
    """
    # Waveforms generated through LALSimulation. The binary parameters have
    # no sensible default and therefore start out as None.
    if origin == "LAL":
        lal_defaults = dict.fromkeys(
            ["approximant", "q", "chi1", "chi2", "ecc", "mean_ano",
             "Momega0"])
        lal_defaults.update(deltaTOverM=0.1,
                            physicalUnits=False,
                            M=None,
                            D=None,
                            include_zero_ecc=False)
        return lal_defaults

    # NR origins: select the relevant keys, the values come from the common
    # dictionary of NR defaults.
    nr_keys = None
    if origin in ("SXSCatalog", "SXSCatalog_old"):
        # waveform files in SXS catalog format
        nr_keys = ["data_dir",
                   "metadata_path",
                   "deltaTOverM",
                   "include_zero_ecc",
                   "include_params_dict",
                   "zero_ecc_approximant",
                   "num_orbits_to_remove_as_junk",
                   "mode_array",
                   "extrap_order"]
        if origin == "SXSCatalog":
            # Waveforms in the new SXS catalog format come with memory
            # correction. It is removed by default for measuring
            # eccentricity; this kwarg lets the user keep it.
            nr_keys += ["keep_memory"]
    elif origin == "LVCNR":
        # LVCNR format file read with the recommended LALSuite function
        nr_keys = ["filepath",
                   "deltaTOverM",
                   "Momega0",
                   "include_zero_ecc",
                   "include_params_dict",
                   "zero_ecc_approximant",
                   "num_orbits_to_remove_as_junk"]
    elif origin == "LVCNR_hack":
        # LVCNR format file read directly with h5py
        nr_keys = ["filepath",
                   "deltaTOverM",
                   "include_zero_ecc",
                   "include_params_dict",
                   "zero_ecc_approximant",
                   "num_orbits_to_remove_as_junk"]
    if nr_keys is not None:
        return make_a_sub_dict(get_defaults_for_nr(), nr_keys)

    # File based models: SEOBNRv4EHM files generated using Toni's code (EOB)
    # and EMRI waveforms generated by Maarten (EMRI). EMRI takes a few more
    # kwargs on top of the EOB ones.
    if origin in ("EOB", "EMRI"):
        file_defaults = {"filepath": None,
                         "include_zero_ecc": False,
                         "filepath_zero_ecc": None}
        if origin == "EMRI":
            file_defaults.update(start_time=None,
                                 end_time=None,
                                 deltaT=None,
                                 include_geodesic_ecc=False)
        return file_defaults

    raise Exception(f"Unknown origin {origin}. Must be one of "
                    f"{get_available_waveform_origins()}.")


def make_a_sub_dict(super_dict, sub_dict_keys):
    """Extract the entries of `super_dict` listed in `sub_dict_keys`.

    parameters:
    -----------
    super_dict: dict
        The larger dictionary that provides the values.
    sub_dict_keys: list
        Keys to copy over into the smaller dictionary.

    returns:
    --------
    sub_dict:
        A dictionary with the keys from `sub_dict_keys` (in that order) and
        the corresponding values from `super_dict`.

    An Exception naming the first offending key is raised if a requested key
    does not exist in `super_dict`.
    """
    requested = list(sub_dict_keys)
    unknown = [key for key in requested if key not in super_dict]
    if unknown:
        kw = unknown[0]
        raise Exception(f"kw {kw} not found. Should be one of "
                        f"{list(super_dict.keys())}")
    return {key: super_dict[key] for key in requested}


def load_waveform(origin="LAL", **kwargs):
    """Load a waveform by dispatching to the loader of the given origin.

    Parameters
    ----------
    origin: str
        The origin of the waveform to be generated/loaded. One of

        - "LAL": Compute the waveform by a call to the LAL-library.
            (see https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/)
        - "SXSCatalog": Read a file of the SXS catalog in the new format
            (from 2023 onward), where the waveform files are named
            `Strain_N{extrap_order}.h5` for a given extrapolation order
            `extrap_order`.
            (see https://data.black-holes.org/waveforms/documentation.html)
        - "SXSCatalog_old": Read a file of the SXS catalog in the old format
            (prior to 2023), where the waveform file is named
            `rhOverM_Asymptotic_GeometricUnits_CoM.h5`.
            (see https://data.black-holes.org/waveforms/documentation.html)
        - "LVCNR": Read a file in the LVCNR-data format.
            (see https://arxiv.org/abs/1703.01076)
        - "LVCNR_hack": Read an LVCNR-data format file using h5py.
            NOTE: This is NOT the recommended way to load lvcnr file.
        - "EOB": Import an EOB waveform generated using SEOBNRv4EHM
            (see arxiv:2112.06952).
        - "EMRI": Import an EMRI waveform generated by Maarten.

    kwargs:
        Keyword arguments forwarded unchanged to the loading function.
        The allowed kwargs depend on origin. Run
        `load_data.get_load_waveform_defaults(origin)` to see the
        allowed keys and defaults.

    Returns
    -------
    dataDict:
        Dictionary of time, modes etc. For the detailed structure of the
        returned dataDict see gw_eccentricity.measure_eccentricity.

    An Exception is raised when `origin` is not an available origin.
    """
    loaders = get_available_waveform_origins(return_dict=True)
    if origin not in loaders:
        raise Exception(f"Unknown origin {origin}. "
                        f"Should be one of {list(loaders.keys())}")
    loader = loaders[origin]
    return loader(**kwargs)


def load_LAL_waveform(**kwargs):
    """Load waveforms calling the LAL Library.

    The kwargs could be the following:
    Run `load_data.get_load_waveform_defaults('LAL')` to see allowed
    keys and defaults.

    approximant: str
        Name of the waveform model to be used for generating the waveform.
    q: float
        Mass ratio of the system.
    chi1: 1d array of size 3
        3-element 1d array of spin components of the 1st Black hole.
    chi2: 1d array of size 3
        3-element 1d array of spin components of the 2nd Black hole.
    ecc: float
        Initial eccentricity of the binary at Momega0 (see below).
    mean_ano: float
        Initial Mean anomaly of the binary at Momega0 (see below).
    Momega0: float
        Starting orbital frequency in dimensionless units.
    deltaTOverM: float
        Time steps in dimensionless units.
    physicalUnits: bool
        If True, returns modes in MKS units.
    M: float
        Total mass in units of solar mass. Required when physicalUnits
        is True.
    D: float
        Luminosity distance in units of megaparsec. Required when
        physicalUnits is True.
    include_zero_ecc: bool
        If True, a quasicircular waveform is created and
        returned. The quasicircular waveform is generated using the
        same set of parameters except eccentricity set to zero.
        In some cases, e = 0 is not supported and we set it to a small value
        like e = 1e-5.

    Returns
    -------
    dataDict: dict
        Dictionary with the keys "t" and "hlm" and, when `include_zero_ecc`
        is True, additionally "t_zeroecc" and "hlm_zeroecc".
    """
    default_lal_kwargs = get_load_waveform_defaults("LAL")
    # Check the kwargs and fill in the defaults. The returned dictionary is
    # the one to use from here on, exactly as the other loaders in this
    # module do; discarding it risks missing defaulted keys below.
    kwargs = check_kwargs_and_set_defaults(
        kwargs, default_lal_kwargs, "LAL Kwargs",
        "load_data.get_load_waveform_defaults")
    raise_exception_if_none(
        kwargs,
        ["approximant", "q", "chi1", "chi2", "ecc", "mean_ano", "Momega0"],
        "LAL kwargs", "`load_data.load_LAL_waveform.py`")
    # FIXME, this assumes single mode models, talk to Vijay about
    # how to handle other models.
    dataDict = load_LAL_waveform_using_hack(
        kwargs['approximant'],
        kwargs['q'],
        kwargs['chi1'],
        kwargs['chi2'],
        kwargs['ecc'],
        kwargs['mean_ano'],
        kwargs['Momega0'],
        kwargs['deltaTOverM'],
        kwargs['physicalUnits'],
        kwargs['M'],
        kwargs['D'])

    if kwargs['include_zero_ecc']:
        # Keep all other params fixed but set ecc=0.
        zero_ecc_kwargs = kwargs.copy()
        # EccentricTD does not support eccentricity < 1e-5
        if kwargs["approximant"] == "EccentricTD":
            zero_ecc_kwargs['ecc'] = 1e-5
        else:
            zero_ecc_kwargs['ecc'] = 0
        zero_ecc_kwargs['include_zero_ecc'] = False   # to avoid infinite loops
        # `load_waveform` defaults to origin="LAL", so this comes back here.
        dataDict_zero_ecc = load_waveform(**zero_ecc_kwargs)
        # To make sure that we can compute the residual amplitude/frequency, we
        # need the zeroecc data to be longer than the ecc data so that we can
        # interpolate the zeroecc data on the same times as the ecc data.
        # Lowering the starting frequency is repeated until the zeroecc data
        # starts no later than the ecc data.
        while dataDict_zero_ecc['t'][0] > dataDict["t"][0]:
            zero_ecc_kwargs["Momega0"] *= 0.9
            dataDict_zero_ecc = load_waveform(**zero_ecc_kwargs)
        dataDict.update({'t_zeroecc': dataDict_zero_ecc['t'],
                         'hlm_zeroecc': dataDict_zero_ecc['hlm']})
    return dataDict


def load_LAL_waveform_using_hack(approximant, q, chi1, chi2, ecc, mean_ano,
                                 Momega0, deltaTOverM, physicalUnits, M, D):
    """Load the (2, 2) mode of a LAL waveform from its face-on strain.

    Many LAL models don't return the modes. So, to get h22 we evaluate the
    strain at (incl, phi)=(0,0) and divide by the spin-weighted spherical
    harmonic Ylm(0,0) with spin weight -2 and (l, m) = (2, 2).
    NOTE: This only works if the only mode is the (2,2) mode.

    Parameters
    ----------
    approximant, q, chi1, chi2, Momega0, deltaTOverM, physicalUnits, M, D:
        Forwarded to `generate_LAL_waveform`; see there for details.
    ecc: float
        Eccentricity, forwarded as `eccentricity`.
    mean_ano: float
        Mean anomaly, forwarded as `meanPerAno`.

    Returns
    -------
    dataDict: dict
        Dictionary with the time array under "t" and the dictionary of modes,
        containing only the (2, 2) mode, under "hlm". The time array is
        shifted by the peak time obtained from
        `peak_time_via_quadratic_fit`.
    """
    phi_ref = 0
    inclination = 0

    # h = hp -1j * hc
    # The mean anomaly has to be forwarded explicitly, otherwise the waveform
    # is always generated with the default meanPerAno=0.
    t, h = generate_LAL_waveform(approximant, q, chi1, chi2,
                                 deltaTOverM, Momega0, eccentricity=ecc,
                                 meanPerAno=mean_ano,
                                 phi_ref=phi_ref, inclination=inclination,
                                 physicalUnits=physicalUnits, M=M, D=D)

    Ylm = lal.SpinWeightedSphericalHarmonic(inclination, phi_ref, -2, 2, 2)
    mode_dict = {(2, 2): h/Ylm}
    # Make t = 0 at the merger. This would help when getting
    # residual amplitude by subtracting quasi-circular counterpart
    t = t - peak_time_via_quadratic_fit(
        t,
        amplitude_using_all_modes(mode_dict))[0]

    return {"t": t, "hlm": mode_dict}


def generate_LAL_waveform(approximant, q, chi1, chi2, deltaTOverM, Momega0,
                          inclination=0, phi_ref=0., longAscNodes=0,
                          eccentricity=0, meanPerAno=0,
                          alignedSpin=True, lambda1=None, lambda2=None,
                          physicalUnits=False, M=None, D=None):
    """Generate waveform for a given approximant using LALSuite.

    Returns time and complex strain, dimensionless unless `physicalUnits`
    is True.

    Parameters
    ----------
    approximant : str
        Name of approximant.
    q : float
        Mass ratio q>=1.
    chi1 : array/list of len=3
        Dimensionless spin vector of larger BH.
    chi2 : array/list of len=3
        Dimensionless spin vector of smaller BH.
    deltaTOverM : float
        Dimensionless time step size.
    Momega0 : float
        Dimensionless starting orbital frequency for waveform. It is
        converted to the starting frequency passed to LAL as
        Momega0 / pi / (total mass in seconds).
    inclination : float
        Inclination angle in radians.
    phi_ref : float
        Reference phase passed to LALSimulation.
    longAscNodes : float
        Longitude of ascending nodes.
    eccentricity : float
        Eccentricity.
    meanPerAno : float
        Mean anomaly of periastron.
    alignedSpin
        Assume aligned spin approximant. In-plane spin components larger
        than 1e-5 (summed in absolute value) are rejected, smaller ones are
        set to zero.
    lambda1
        Tidal parameter for larger BH. Only inserted into the LAL parameter
        dictionary when it is not None.
    lambda2
        Tidal parameter for smaller BH. Only inserted into the LAL parameter
        dictionary when it is not None.
    physicalUnits
        If True, return in physical units.
    M
        Total mass in units of solar mass. Required when physicalUnits is True.
        Ignored otherwise.
    D
        Luminosity distance in units of mega parsec. Required when
        physicalUnits is True. Ignored otherwise.

    Returns
    -------
    t : array
        Time, starting at zero. Dimensionless unless physicalUnits is True.
    h : complex array
        Complex strain h_{+} -i*h_{x}. Dimensionless unless physicalUnits is
        True.

    Raises
    ------
    Exception
        If a spin vector does not have size 3, has magnitude larger than 1 or
        has in-plane components while `alignedSpin` is True, or if `M` or `D`
        is missing while `physicalUnits` is True.
    """
    # Work on float copies so that the caller's spins are never modified and
    # LAL always receives floating point numbers.
    chi1 = np.array(chi1, dtype=float)
    chi2 = np.array(chi2, dtype=float)

    # Validate the size before any component is accessed.
    if chi1.shape != (3,):
        raise Exception('chi1 must have size 3.')
    if chi2.shape != (3,):
        raise Exception('chi2 must have size 3.')

    if alignedSpin:
        if (np.sum(np.abs(chi1[:2])) > 1e-5
                or np.sum(np.abs(chi2[:2])) > 1e-5):
            raise Exception("Got precessing spins for aligned spin "
                            "approximant.")
        # remove negligible in-plane components
        chi1[:2] = 0
        chi2[:2] = 0

    # sanity checks
    if np.sqrt(np.sum(chi1**2)) > 1:
        raise Exception('chi1 out of range.')
    if np.sqrt(np.sum(chi2**2)) > 1:
        raise Exception('chi2 out of range.')

    # if physicalUnits is True then M and D must be provided
    if physicalUnits:
        if M is None:
            raise Exception("Must provide total mass `M` for physical units.")
        if D is None:
            raise Exception("Must provide luminosity distance `D` for physical units.")
        distance = D * 1e6 * lal.PC_SI
    else:
        # use M=10 and distance=1 Mpc, but will scale these out before outputting h
        M = 10      # dimless mass
        distance = 1.0e6 * lal.PC_SI

    approxTag = lalsim.GetApproximantFromString(approximant)
    MT = M * lal.MTSUN_SI  # total mass in seconds
    f_low = Momega0/np.pi/MT
    f_ref = f_low

    # component masses of the binary
    m1_kg = M * lal.MSUN_SI * q / (1. + q)
    m2_kg = M * lal.MSUN_SI / (1. + q)

    # Tidal parameters if given. Each one is inserted only when provided so
    # that None is never handed over to LAL.
    dictParams = None
    if lambda1 is not None or lambda2 is not None:
        dictParams = lal.CreateDict()
        if lambda1 is not None:
            lalsim.SimInspiralWaveformParamsInsertTidalLambda1(
                dictParams, lambda1)
        if lambda2 is not None:
            lalsim.SimInspiralWaveformParamsInsertTidalLambda2(
                dictParams, lambda2)

    hp, hc = lalsim.SimInspiralChooseTDWaveform(
        m1_kg, m2_kg, chi1[0], chi1[1], chi1[2], chi2[0], chi2[1], chi2[2],
        distance, inclination, phi_ref,
        longAscNodes, eccentricity, meanPerAno,
        deltaTOverM*MT, f_low, f_ref, dictParams, approxTag)

    h = np.array(hp.data.data - 1.j*hc.data.data)
    if physicalUnits:
        t = deltaTOverM * MT * np.arange(len(h))
        return t, h
    # scale out the fiducial mass and distance used above
    t = deltaTOverM * np.arange(len(h))
    return t, h * distance/MT/lal.C_SI


def time_dimless_to_mks(M):
    """Return the factor that converts dimensionless time to SI units.

    parameters
    ----------
    M:
        Mass of system in the units of solar mass.

    Returns
    -------
    The conversion factor `M * lal.MTSUN_SI`; multiplying a dimensionless
    time by it gives the time in SI units.
    """
    solar_mass_in_seconds = lal.MTSUN_SI
    return M * solar_mass_in_seconds


def amplitude_dimless_to_mks(M, D):
    """Return the factor that rescales dimensionless amplitude to SI units.

    parameters
    ----------
    M:
        Mass of the system in units of solar mass.
    D:
        Luminosity distance in units of megaparsecs.

    Returns
    -------
    The scaling factor G * M / (c^2 * D), with M and D converted to SI units.
    """
    # G * M with the mass converted from solar masses to kg
    numerator = lal.G_SI * M * lal.MSUN_SI
    # c^2 * D with the distance converted from megaparsecs to meters
    denominator = lal.C_SI**2 * D * 1e6 * lal.PC_SI
    return numerator / denominator


def get_defaults_for_nr():
    """Return the default values of every kwarg used by the NR loaders.

    The returned dictionary is the superset of the kwargs of all NR loading
    functions. A given NR loading function uses only a subset of it, see
    `get_load_waveform_defaults`.

    The keys are the following:
    filepath: str
        Path to the nr file. Default is None.

    data_dir: str
        Directory to look for the files necessary to load the NR waveform.
        Default is None.

    deltaTOverM: float
        Time step in dimensionless units. Default is 0.1.

    Momega0: float
        Lower frequency to start waveform generation.
        If Momega0 = 0, uses the entire NR data. The actual Momega0 will be
        returned.
        NOTE: This key might be used only for LVCNR format waveform.
        Default is 0.

    include_zero_ecc: bool
        If True, returns waveform mode for the same set of parameters
        except eccentricity set to zero.
        Default is False.

    include_params_dict: bool
        If True, returns a dictionary containing parameters of the binary.
        Default is False.

    zero_ecc_approximant: str
        Waveform model to generate zero ecc waveform when
        `include_zero_ecc` is True.
        Default is IMRPhenomT.

    metadata_path: str
        NOTE: Only for SXS catalog format waveform.
        Path to the sxs metadata file. This file generally can be found in the
        same directory as the waveform file and has the name `metadata.txt`
        (for SXSCatalog_old) or `metadata.json` (for SXSCatalog). It contains
        the metadata including binary parameters along with other information
        related to the NR simulation performed to obtain the waveform modes.
        Required when `include_zero_ecc` or `include_params_dict` is True, or
        when `keep_memory` (available only for `SXSCatalog`) is False. If
        provided, a dictionary containing binary mass ratio, spins and the
        relaxation time is returned.
        Default is None.

    num_orbits_to_remove_as_junk: float
        Number of orbits to throw away as junk from the beginning of the NR
        data.
        Default is 2.

    mode_array: 1d array
        1d array of modes to load.
        Default is [(2, 2)] which loads only the (2, 2) mode.

    extrap_order: int
        Extrapolation order to use for loading the waveform data.
        NOTE: This is used only for sxs catalog formatted waveforms.
        Default is 2.

    keep_memory: bool
        If False, remove memory contribution from the waveform modes. This
        will require the metadata file to find t_relax which is used to start
        the integration for computing memory contribution. NOTE: This can be
        used only in the newer sxs catalog formatted waveforms with
        origin=`SXSCatalog`.
        Default is False.
    """
    # A new dictionary (and a new `mode_array` list) is created on every call
    # so that callers can modify the result freely.
    nr_defaults = dict(
        filepath=None,
        data_dir=None,
        deltaTOverM=0.1,
        Momega0=0.0,
        include_zero_ecc=False,
        include_params_dict=False,
        zero_ecc_approximant="IMRPhenomT",
        metadata_path=None,
        num_orbits_to_remove_as_junk=2,
        mode_array=[(2, 2)],
        extrap_order=2,
        keep_memory=False,
    )
    return nr_defaults


def load_lvcnr_waveform(**kwargs):
    """Load modes from lvcnr files.

    Loading waveform modes from files in lvcnr format require the file
    `SEOBNRv4ROM_v2.0.hdf5` in `LAL_DATA_PATH`. This file can be downloaded
    from
    https://git.ligo.org/lscsoft/lalsuite-extra/-/blob/master/data/lalsimulation/SEOBNRv4ROM_v2.0.hdf5
    and the path can be set using `export
    LAL_DATA_PATH=/path/to/directory/containing/seobnrv4rom_file/`.

    Parameters
    ----------
    kwargs: Could be the following.
    Run `load_data.get_load_waveform_defaults('LVCNR')` to see allowed
    keys and defaults.

    filepath: str
        Path to lvcnr file in format described in arXiv:1703.01076.

    deltaTOverM: float
        Time step in dimensionless units.

    Momega0: float
        Lower frequency to start waveform generation.
        If Momega0 = 0, uses the entire NR data.

    include_zero_ecc: bool
        If True, returns zero eccentricity waveform mode (only (2, 2) mode) for
        the same set of parameters except eccentricity set to zero.
        The zero eccentricity waveform is generated using the waveform model
        provided via `zero_ecc_approximant` (see below).

    include_params_dict: bool
        If True, returns a dictionary of binary parameters.

    zero_ecc_approximant: str
        Waveform model to generate zero eccentricity waveform.

    num_orbits_to_remove_as_junk: float
        Number of orbits to throw away as junk from the beginning of the NR
        data.

    Returns
    -------
        Dictionary of time and modes dictionary. Optionally the returned
        dictionary includes a dictionary of binary parameters, dictionary of
        zero eccentricity modes etc.

    t:
        Time array in dimensionless units. This already discards the first
        `num_orbits_to_remove_as_junk` orbits.
    hlm:
        Dictionary of modes in dimensionless units. This already discards the
        first `num_orbits_to_remove_as_junk` orbits.  To get a particular mode,
        do h22 = hlm[(2, 2)].

    Optionally,
    params_dict:
        Dictionary of parameters of the binary. Returned when
        `include_params_dict` is True.

    t_zeroecc:
        1d uniform array of times corresponding to zero eccentricity
        modes in dimensionless units.
        Returned when `include_zero_ecc` is True.

    hlm_zeroecc:
        Dictionary of modes created using `zero_ecc_approximant` model
        with the same mass ratio and spin components as the NR
        simulation and eccentricity set to zero. Currently, it contains
        only the (2, 2) mode.
        Returned when `include_zero_ecc` is True.
    """
    kwargs = check_kwargs_and_set_defaults(
        kwargs,
        get_load_waveform_defaults("LVCNR"),
        "LVCNR kwargs",
        "`load_data.get_defaults_for_nr`"
    )

    filepath = kwargs["filepath"]
    M = 10  # will be factored out
    dt = kwargs["deltaTOverM"] * time_dimless_to_mks(M)
    dist_mpc = 1  # will be factored out
    f_low = kwargs["Momega0"] / np.pi / time_dimless_to_mks(M)

    # Read everything that is needed from the file attributes up front. The
    # context manager guarantees that the file handle is released even when
    # an attribute is missing or a later step fails.
    with h5py.File(filepath, "r") as NRh5File:
        attrs = NRh5File.attrs
        # Metadata parameters masses:
        m1 = attrs["mass1"]
        m2 = attrs["mass2"]
        # If f_low == 0, update it to the start frequency so that
        # we get the right start frequency
        if f_low == 0:
            f_low = attrs["f_lower_at_1MSUN"] / M
        # Values that cannot be converted to a float are reported as None.
        try:
            eccentricity = float(attrs["eccentricity"])
        except ValueError:
            eccentricity = None
        try:
            mean_anomaly = float(attrs["mean_anomaly"])
        except ValueError:
            mean_anomaly = None

    params_NR = lal.CreateDict()
    lalsim.SimInspiralWaveformParamsInsertNumRelData(params_NR, filepath)

    m1SI = m1 * M / (m1 + m2) * lal.MSUN_SI
    m2SI = m2 * M / (m1 + m2) * lal.MSUN_SI

    distance = dist_mpc * 1.0e6 * lal.PC_SI
    f_ref = 0  # Non zero f_ref is not supported since the lvcnr format of the
    # files we are testing is format 1.
    spins = lalsim.SimInspiralNRWaveformGetSpinsFromHDF5File(f_ref, M,
                                                             filepath)
    s1x, s1y, s1z = spins[0], spins[1], spins[2]
    s2x, s2y, s2z = spins[3], spins[4], spins[5]

    # Generating the NR modes
    values_mode_array = lalsim.SimInspiralWaveformParamsLookupModeArray(
        params_NR)
    _, modes = lalsim.SimInspiralNRWaveformGetHlms(
        dt,
        m1SI,
        m2SI,
        distance,
        f_low,
        f_ref,
        s1x,
        s1y,
        s1z,
        s2x,
        s2y,
        s2z,
        filepath,
        values_mode_array)

    # `modes` is traversed through its `next` attribute until it is None;
    # each mode is rescaled to dimensionless units.
    modes_dict = {}
    while modes is not None:
        modes_dict[(modes.l, modes.m)] = (
            modes.mode.data.data
            / amplitude_dimless_to_mks(M, dist_mpc))
        modes = modes.next

    t = np.arange(len(modes_dict[(2, 2)])) * dt
    t = t / time_dimless_to_mks(M)
    # shift the times to make merger a t = 0
    t = t - peak_time_via_quadratic_fit(
        t,
        amplitude_using_all_modes(modes_dict))[0]

    q = m1SI/m2SI

    # remove junk from the beginning of the data
    t, modes_dict = reomve_junk_from_nr_data(
        t,
        modes_dict,
        kwargs["num_orbits_to_remove_as_junk"])

    return_dict = {"t": t,
                   "hlm": modes_dict}

    # The binary parameters are needed both for returning them and for
    # generating the zero eccentricity counterpart.
    if kwargs["include_params_dict"] or kwargs["include_zero_ecc"]:
        params_dict = {"q": q,
                       "chi1": [s1x, s1y, s1z],
                       "chi2": [s2x, s2y, s2z],
                       "ecc": eccentricity,
                       "mean_ano": mean_anomaly}

    if kwargs["include_zero_ecc"]:
        params_dict_zero_ecc = params_dict.copy()
        params_dict_zero_ecc.update(
            {"approximant": kwargs["zero_ecc_approximant"]})
        dataDict_zero_ecc = get_zeroecc_dataDict_for_nr(
            return_dict, params_dict_zero_ecc)
        return_dict.update(dataDict_zero_ecc)
    if kwargs["include_params_dict"]:
        return_dict.update({"params_dict": params_dict})
    return return_dict


def load_sxs_catalogformat(**kwargs):
    """Load waveform modes from files in the new sxs catalog format.

    The new format (from 2023 onward) stores the waveform in files named
    `Strain_N{extrap_order}.h5`, see `data_dir` below.
    (Also see https://data.black-holes.org/waveforms/documentation.html).

    To load sxs catalog waveforms in the old format, where the waveform file
    is named `rhOverM_Asymptotic_GeometricUnits_CoM.h5`, use
    `load_sxs_catalogformat_old`. To load lvcnr format files, use
    `load_lvcnr_waveform`.

    Parameters
    ----------
    kwargs: Dictionary with the following keys.
    Run `load_data.get_load_waveform_defaults('SXSCatalog')` to see allowed
    keys and defaults.

    data_dir: str
        Path to the directory containing the waveform files in sxs catalog
        format. Based on `extrap_order`, the following files are looked for in
        `data_dir`:

        1. The strain file `Strain_N{extrap_order}.h5` (required)
        2. The corresponding json file `Strain_N{extrap_order}.json` (required)
        3. The metadata file `metadata.json` (required when `include_zero_ecc`
          or `include_params_dict` is True or `keep_memory` is False)
        4. The horizon file `Horizons.h5` (optional)

        `Strain_N{extrap_order}.h5` contains the waveform extrapolated to
        future null-infinity and corrected for initial center-of-mass
        drift. Together with `Strain_N{extrap_order}.json` it must be present
        to load the waveform modes.

        `metadata.json` provides the parameters used in the NR simulation and
        is needed when `include_zero_ecc` or `include_params_dict` is True or
        `keep_memory` is False. See more under
        `get_params_dict_from_sxs_metadata`.

        `Horizons.h5`, if present, is used to get a better estimate of the
        duration of an orbit from phase data for removing junk radiation.
        See more about it under `num_orbits_to_remove_as_junk`.

    deltaTOverM: float
        Time step to use for interpolating the waveform modes. The unit is
        the same as that of the time array in the sxs catalog format waveform
        file, which is dimensionless.

    include_zero_ecc: bool
        If True, returns waveform mode (only (2, 2) mode) for the same set of
        parameters except with eccentricity set to zero.

        When set to True, the `metadata.json` file is searched for in the
        `data_dir` directory. Typically it is located in the same directory
        as the waveform file within the sxs catalog. It is needed to extract
        the binary parameters, since it contains the information about the
        binary parameters and the NR simulation used to generate the waveform
        modes.

        The zero eccentricity waveform is generated using the approximant
        provided via `zero_ecc_approximant` (see below).

        Also needs `SEOBNRv4ROM_v2.0.hdf5` in `LAL_DATA_PATH`.
        Currently, it is used to get an estimate for the initial
        frequency to use for generating zero eccentricity waveform based on
        the inspiral time of the NR waveform.  Download it from
        https://git.ligo.org/lscsoft/lalsuite-extra/-/blob/master/data/lalsimulation/SEOBNRv4ROM_v2.0.hdf5
        and set the path using
        `export LAL_DATA_PATH=/path/to/directory/containing/seobnrv4rom_file/`.

    include_params_dict:
        If True, returns a dictionary of binary parameters.

    zero_ecc_approximant: str
        Waveform model to generate zero eccentricity waveform when
        `include_zero_ecc` is True.

    num_orbits_to_remove_as_junk: float
        Number of orbits to throw away as junk from the beginning of the NR
        data. If `Horizons.h5` is located within the `data_dir`, the orbital
        phase data contained in it is used to estimate the duration of a
        single orbit. Otherwise the duration of one orbit is derived from the
        phase of the (2, 2) mode, which is a less accurate estimate because
        the waveform data is contaminated by junk radiation.

    mode_array: 1d array
        1d array of modes to load. Should have the format `[(l1, m1), (l2,
        m2),..]`

    extrap_order: int
        The extrapolation order determines the name of the strain file that
        is looked for in `data_dir`, namely `Strain_N{extrap_order}.h5`.

    keep_memory: bool
        If False, remove memory contribution from the waveform modes.
        This will require metadata file to find t_relax which is used
        to start the integration for computing memory contribution.

    Returns
    -------
    Returns a dictionary with the following quantities:
    t:
        1d array of times, in dimensionless units, with uniform time
        step `dt`, shifted such that t=0 coincides with the peak waveform
        amplitude (obtained using all requested modes in
        mode_array). This already discards the first
        `num_orbits_to_remove_as_junk` orbits.

    hlm:
        Dictionary of NR waveform modes interpolated onto the time
        array, `t`. This already discards the first
        `num_orbits_to_remove_as_junk` orbits. The dictionary contains
        all requested modes in `mode_array`.  To get a particular mode,
        do h22 = hlm[(2, 2)].

    Optionally,
    params_dict:
        Dictionary of parameters containing mass ratio and spins.
        Returned when `metadata_path` is provided and `include_params_dict` is
        True.
    t_zeroecc:
        1d uniform array of times corresponding to zero eccentricity
        modes in dimensionless units.
        Returned when `include_zero_ecc` is True.
    hlm_zeroecc:
        Dictionary of modes created using `zero_ecc_approximant` model
        with the same mass ratio and spin components as the NR
        simulation and eccentricity set to zero. Currently, it contains
        only the (2, 2) mode.
        Returned when `include_zero_ecc` is True.
    """
    kwargs = check_kwargs_and_set_defaults(
        kwargs,
        get_load_waveform_defaults("SXSCatalog"),
        "SXSCatalog kwargs",
        "`load_data.get_defaults_for_nr`")

    # Step 1: check the data directory; the result is passed on as the
    # "horizon file exists" flag.
    has_horizon_file = check_sxs_data_dir("SXSCatalog", **kwargs)
    # Step 2: read the time array and the requested modes.
    times, modes = get_modes_dict_from_sxs_catalog_format(**kwargs)
    # Step 3: assemble the dictionary to return. This
    # - cleans the original modes by removing junk radiation
    # - shifts the time axis such that the global amplitude peak is at t = 0
    # - adds zeroecc data if `include_zero_ecc` is True
    # - adds params dict if `include_params_dict` is True
    # see `make_return_dict_for_sxs_catalog_format` for more details.
    return make_return_dict_for_sxs_catalog_format(
        times, modes, has_horizon_file, **kwargs)


def load_sxs_catalogformat_old(**kwargs):
    """Load waveform modes stored in the old (pre-2023) sxs catalog format.

    In the old format all extrapolated waveform modes live in a single file
    named `rhOverM_Asymptotic_GeometricUnits_CoM.h5`, and the modes for the
    requested `extrap_order` are read from that file. For the newer format,
    where the waveform files are named `Strain_N{extrap_order}.h5`, use
    `load_sxs_catalogformat` instead.

    The allowed kwargs and their defaults are the same as for
    `load_sxs_catalogformat`; see the docstring of that function for a
    detailed description. Only the set of files expected inside `data_dir`
    differs:

    1. `rhOverM_Asymptotic_GeometricUnits_CoM.h5` (mandatory).
    2. `metadata.txt` (required when `include_zero_ecc` or
        `include_params_dict` is True). It contains the same information as
        `metadata.json` in the newer sxs catalog format
        (origin=`SXSCatalog`).
    3. `Horizons.h5` (optional). For more details, see `data_dir` under
        `load_sxs_catalogformat`.

    Returns
    -------
    The dictionary built by `make_return_dict_for_sxs_catalog_format`.
    """
    defaults = get_load_waveform_defaults("SXSCatalog_old")
    kwargs = check_kwargs_and_set_defaults(
        kwargs,
        defaults,
        "SXSCatalog_old kwargs",
        "`load_data.get_defaults_for_nr`")

    # Validate `data_dir`; the returned flag tells whether `Horizons.h5` is
    # available for the junk removal step.
    has_horizons_file = check_sxs_data_dir("SXSCatalog_old", **kwargs)
    # Read the requested modes from the old-format waveform file.
    times, modes = get_modes_dict_from_sxs_catalog_old_format(**kwargs)
    # Post-process and package the data. This
    # - removes junk radiation from the start of the modes,
    # - shifts the time axis such that the peak occurs at t = 0,
    # - adds zeroecc data if `include_zero_ecc` is True,
    # - adds the params dict if `include_params_dict` is True.
    # See `make_return_dict_for_sxs_catalog_format` for more details.
    return make_return_dict_for_sxs_catalog_format(
        times, modes, has_horizons_file, **kwargs)


def check_sxs_data_dir(origin, **kwargs):
    """Check if the necessary files exist for loading sxs catalog format.

    Depending on the origin, it looks for a set of files needed to extract
    the waveform modes, get the parameters of the NR simulation and to
    clean the modes by removing junk radiation before returning the modes.
    These files are

    - Files to extract the modes
       - If origin = "SXSCatalog", i.e., for the format from 2023 onwards,
         - `Strain_N{extrap_order}.h5`
         - `Strain_N{extrap_order}.json` where `extrap_order` is the
           extrapolation order provided in the `kwargs`
       - If origin = "SXSCatalog_old", i.e., for the format before 2023,
         - `rhOverM_Asymptotic_GeometricUnits_CoM.h5`.

        These files are required to extract the waveform modes successfully.
    - `metadata.txt` or `metadata.json` file to get the parameters of the NR
        Simulation. `metadata.txt` is required for `SXSCatalog_old`. In
        `SXSCatalog`, the newer format of sxs catalog, it is replaced by
        `metadata.json`. This file is required when `include_zero_ecc` or
        `include_params_dict` is True or `keep_memory` (available only
        for `SXSCatalog`) is False.
    - `Horizons.h5` file to estimate the duration of an orbit using the orbital
        phase data. This file is optional. If it is not found, we use the phase
        of the (2, 2) mode to get the duration of an orbit assuming a phase
        change of 4pi occurs over an orbit. However, since the waveform data is
        affected by the junk radiation, this estimate may not be very accurate.

    Parameters
    ----------
    origin : str
        Either "SXSCatalog" or "SXSCatalog_old".
    kwargs : dict
        kwargs for loading the sxs catalog format files. The keys `data_dir`,
        `extrap_order`, `include_zero_ecc` and `include_params_dict` are
        read; `keep_memory` is read if present and treated as True otherwise.

    Returns
    -------
    True if `Horizons.h5` file exists else False. A warning is issued when the
    file is not found.

    Raises
    ------
    Exception
        If `data_dir` is None.
    FileNotFoundError
        If `data_dir` is not an existing directory or if one of the required
        files is missing from it.
    """
    data_dir = kwargs["data_dir"]
    if data_dir is None:
        raise Exception(
            "Must provide path to the directory containing waveform files. "
            "`data_dir` can not be None.")
    # `data_dir` must be a directory. A path to a regular file would pass a
    # bare existence check and only fail later with a misleading message about
    # a missing waveform file.
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(
            f"Can not find the directory {data_dir}.")
    extrap_order = kwargs["extrap_order"]
    required_files_dict = {
        "SXSCatalog": [f"Strain_N{extrap_order}.h5",
                       f"Strain_N{extrap_order}.json"],
        "SXSCatalog_old": ["rhOverM_Asymptotic_GeometricUnits_CoM.h5"]}
    message_dict = {
        "SXSCatalog": " You should provide the h5 and json file named "
        f"`Strain_N{extrap_order}` since `extrap_order` is "
        f"{extrap_order}. If you are using the old format, "
        "you should provide the `rhOverM_Asymptotic_GeometricUnits_CoM.h5` "
        "file.",
        "SXSCatalog_old": " You should provide the "
        "`rhOverM_Asymptotic_GeometricUnits_CoM.h5` file. "
        "If you are using the new format, You should provide the h5 and json "
        f"file named `Strain_N{extrap_order}` since `extrap_order` "
        f"is {extrap_order}."}
    needs_metadata = any([kwargs["include_zero_ecc"],
                          kwargs["include_params_dict"],
                          not kwargs.get("keep_memory", True)])
    # In newer versions of sxscatalog format, metadata.txt files are
    # replaced by metadata.json file.
    required_metadata_file = (
        "metadata.json" if origin == "SXSCatalog" else "metadata.txt")
    # Work on a copy of the list for the requested origin only.
    required_files = list(required_files_dict[origin])
    if needs_metadata:
        required_files.append(required_metadata_file)
    # Check if all the required files exist
    for filename in required_files:
        if os.path.exists(os.path.join(data_dir, filename)):
            continue
        if "metadata" in filename:
            message = (
                f" {required_metadata_file} file is required when "
                "`include_zero_ecc` or `include_params_dict` is True or "
                "`keep_memory` is set to False to get the binary "
                "parameters of the NR simulation.")
        else:
            message = message_dict[origin]
        raise FileNotFoundError(
            f"Can not find `{filename}` in `{data_dir}`."
            + message)
    # Check if the Horizons.h5 file exists. If it exists we return True, else
    # False.
    if os.path.exists(os.path.join(data_dir, "Horizons.h5")):
        return True
    warnings.warn(
        f"Can not find `Horizons.h5` in {data_dir}. "
        "Phase of the (2, 2) mode will be used to estimate the duration "
        "of `num_orbits_to_remove_as_junk` orbits which may not be "
        "accurate since the (2, 2) mode phase contains junk radiation "
        "in the initial part.")
    return False


def make_return_dict_for_sxs_catalog_format(t, modes_dict, horizon_file_exits,
                                            **kwargs):
    """Build the dictionary returned by the sxs catalog format loaders.

    Starting from the modes extracted from the sxs catalog format files, the
    following steps are applied in order:

    - Junk is removed from the beginning of the data, using `Horizons.h5`
      when `horizon_file_exits` is True and the (2, 2) mode phase otherwise.
    - The time axis is shifted so that the peak of the amplitude computed
      from all modes is at t = 0.
    - Zeroecc data is added if `include_zero_ecc` is True.
    - The params dict is added if `include_params_dict` is True.

    Parameters
    ----------
    t:
        Time array associated with `modes_dict`.
    modes_dict:
        Dictionary of waveform modes.
    horizon_file_exits: bool
        Whether `Horizons.h5` exists inside `kwargs["data_dir"]`.
    kwargs:
        Loader kwargs. The keys `num_orbits_to_remove_as_junk`, `data_dir`,
        `include_zero_ecc`, `include_params_dict`, optionally `keep_memory`
        and, when zeroecc data is requested, `zero_ecc_approximant` are read.

    Returns
    -------
    dataDict:
        Dictionary with keys "t" and "hlm" and, depending on the kwargs,
        "t_zeroecc", "hlm_zeroecc" and "params_dict".
    """
    num_junk_orbits = kwargs["num_orbits_to_remove_as_junk"]
    if horizon_file_exits:
        horizons_path = os.path.join(kwargs["data_dir"], "Horizons.h5")
        t, modes_dict = remove_junk_from_sxs_catalogformat_using_horizons_data(
            t, modes_dict, num_junk_orbits, horizons_path)
    else:
        t, modes_dict = reomve_junk_from_nr_data(
            t, modes_dict, num_junk_orbits)

    # Locate the amplitude peak and move it to t = 0.
    peak_amplitude_time = peak_time_via_quadratic_fit(
        t, amplitude_using_all_modes(modes_dict))[0]
    dataDict = {"t": t - peak_amplitude_time,
                "hlm": modes_dict}

    include_zero_ecc = kwargs["include_zero_ecc"]
    include_params_dict = kwargs["include_params_dict"]
    memory_removed = not kwargs.get("keep_memory", True)
    if include_zero_ecc or include_params_dict or memory_removed:
        # Prefer `metadata.txt` when it exists, otherwise fall back to
        # `metadata.json`.
        metadata_path = os.path.join(kwargs["data_dir"], "metadata.txt")
        if not os.path.exists(metadata_path):
            metadata_path = os.path.join(kwargs["data_dir"], "metadata.json")
        params_dict = get_params_dict_from_sxs_metadata(metadata_path)

    if include_zero_ecc:
        zeroecc_params = params_dict.copy()
        # `t_relax` is not a parameter of the zero eccentricity waveform.
        zeroecc_params.pop("t_relax", None)
        # Approximant used to generate the zero eccentricity waveform.
        zeroecc_params["approximant"] = kwargs["zero_ecc_approximant"]
        # Needed for LAL waveform models used to generate zero eccentricity
        # waveform.
        zeroecc_params["mean_ano"] = 0.0
        zeroecc_data = get_zeroecc_dataDict_for_nr(dataDict, zeroecc_params)
        dataDict["t_zeroecc"] = zeroecc_data["t_zeroecc"]
        dataDict["hlm_zeroecc"] = zeroecc_data["hlm_zeroecc"]

    if include_params_dict:
        dataDict["params_dict"] = params_dict
    return dataDict


def get_modes_dict_from_sxs_catalog_old_format(**kwargs):
    """Get modes from sxs catalog old format files.

    The modes are read from the group `Extrapolated_N{extrap_order}.dir` of
    `rhOverM_Asymptotic_GeometricUnits_CoM.h5` inside `data_dir` and
    interpolated onto a uniform time array with step `deltaTOverM`.

    See documentation of `load_sxs_catalogformat` for allowed kwargs and
    default values. The keys `data_dir`, `extrap_order`, `mode_array` and
    `deltaTOverM` are read here.

    Returns
    -------
    t:
        Uniform time array built from the first and last time of the first
        requested mode.
    modes_dict:
        Dictionary of interpolated complex modes with keys `(ell, m)`.

    Raises
    ------
    ValueError
        If `mode_array` contains no modes.
    """
    filepath = os.path.join(
        kwargs["data_dir"],
        "rhOverM_Asymptotic_GeometricUnits_CoM.h5")
    modes_dict = {}
    time = None
    t = None
    # Use a context manager so that the file is closed even if reading a mode
    # or interpolating fails.
    with h5py.File(filepath, "r") as data:
        waveform_data = data[f"Extrapolated_N{kwargs['extrap_order']}.dir"]
        for ell, m in kwargs["mode_array"]:
            # Read the whole dataset into memory once; the columns are used as
            # time, real part and imaginary part.
            mode_data = waveform_data[f"Y_l{ell}_m{m}.dat"][:]
            # create the time array only once
            if t is None:
                time = mode_data[:, 0]
                t = np.arange(time[0], time[-1], kwargs["deltaTOverM"])
            hlm = mode_data[:, 1] + 1j * mode_data[:, 2]
            # See comments under `get_modes_dict_from_sxs_catalog_format` on
            # why we interpolate real and imaginary parts instead of the
            # amplitude and phase.
            real_interp = interpolate(t, time, np.real(hlm))
            imag_interp = interpolate(t, time, np.imag(hlm))
            modes_dict[(ell, m)] = real_interp + 1j * imag_interp
    if t is None:
        raise ValueError(
            "`mode_array` must contain at least one mode to load.")
    return t, modes_dict


def get_modes_dict_from_sxs_catalog_format(**kwargs):
    """Get modes from sxs catalog format files.

    The waveform `Strain_N{extrap_order}` inside `data_dir` is loaded with
    `sxs.load`. If `keep_memory` is False, the memory contribution computed
    with `sxs.waveforms.memory.J_E` is subtracted from the modes. The
    requested modes are then interpolated onto a uniform time array with
    step `deltaTOverM`.

    See documentation of `load_sxs_catalogformat` for allowed kwargs and
    default values.

    Returns
    -------
    t:
        Uniform time array spanning the time range of the loaded waveform.
    modes_dict:
        Dictionary of interpolated complex modes with keys `(ell, m)`.
    """
    strain_path = os.path.join(
        kwargs["data_dir"], f"Strain_N{kwargs['extrap_order']}")
    waveform = sxs.load(strain_path)

    if not kwargs["keep_memory"]:
        # The relaxation time `t_relax` from the metadata file is used as the
        # starting time of the integration that computes the memory
        # contribution.
        metadata_params = get_params_dict_from_sxs_metadata(
            os.path.join(kwargs["data_dir"], "metadata.json"))
        memory_waveform = sxs.waveforms.memory.J_E(
            waveform, integration_start_time=metadata_params["t_relax"])

        # NOTE: This is currently required because the ell = 0, 1 modes get
        # included by sxs when computing the memory. So, we drop all modes
        # before the first nonzero mode (2, -2). This should eventually not
        # be required if fixed in sxs, but that should not break this code
        # anyway.
        first_kept_column = sf.LM_index(2, -2, memory_waveform.ell_min)
        memory_modes = memory_waveform.data[:, first_kept_column:]
        # Waveform modes without the memory.
        waveform_modes = waveform.data - memory_modes
    else:
        waveform_modes = waveform.data

    time = waveform.t
    # Uniform time array with step = dt on which the modes are resampled.
    t = np.arange(time[0], time[-1], kwargs["deltaTOverM"])

    def resample(values):
        # Interpolate a real valued series from `time` onto `t`.
        return interpolate(t, time, values)

    modes_dict = {}
    for ell, m in kwargs["mode_array"]:
        hlm = waveform_modes[:, waveform.index(ell, m)]
        # NOTE: The real and imaginary parts are interpolated separately
        # rather than the amplitude and phase. For systems with high
        # eccentricity and extreme precession, interpolating amplitude and
        # phase over smaller deltaTOverM values was seen to introduce
        # artificial spikes in the frequency that are absent in the original
        # data, and these become more pronounced as the spline order
        # increases. Interpolating the real and imaginary parts avoids these
        # issues.
        real_part = resample(np.real(hlm))
        imag_part = resample(np.imag(hlm))
        modes_dict[(ell, m)] = real_part + 1j * imag_part
    return t, modes_dict


def get_params_dict_from_sxs_metadata(metadata_path):
    """Get binary parameters from sxs metadata file.

    This file is usually located in the same directory as the waveform file and
    has the name `metadata.txt` or `metadata.json`. It contains metadata
    related to the NR simulation performed to obtain the waveform modes.

    Parameters
    ----------
    metadata_path: str
        Path to the metadata file. The parser is chosen by checking whether
        the path contains `metadata.txt` (lines of the form `key = value`) or
        `metadata.json`.

    Returns
    -------
    params_dict: dict
        Dictionary with the keys
        - "q": mass ratio m1/m2, set to 1 when it is below 1 by less than
          5e-4,
        - "chi1", "chi2": reference dimensionless spins,
        - "t_relax": relaxation time.

    Raises
    ------
    ValueError
        If one of the required values could not be read, for example because
        the file name is not recognised or a line is missing in
        `metadata.txt`.
    KeyError
        If a required key is missing in `metadata.json`.
    Exception
        If m1 < m2 with 1 - m1/m2 >= 5e-4.
    """
    values = {}
    if "metadata.txt" in metadata_path:
        def parse_vector(text):
            return [float(x.strip()) for x in text.split(",")]

        def parse_scalar(text):
            return float(text.strip())

        # (text searched for in a line, name of the value, parser)
        txt_fields = (
            ("reference-dimensionless-spin1", "chi1", parse_vector),
            ("reference-dimensionless-spin2", "chi2", parse_vector),
            ("reference-mass1", "m1", parse_scalar),
            ("reference-mass2", "m2", parse_scalar),
            ("relaxation-time", "t_relax", parse_scalar))
        with open(metadata_path, "r") as fl:
            for line in fl:
                for tag, name, parse in txt_fields:
                    if tag in line:
                        # The value is whatever follows the last "=".
                        values[name] = parse(line.split("=")[-1])
    if "metadata.json" in metadata_path:
        with open(metadata_path, "r") as fl:
            data = json.load(fl)
        values["chi1"] = data["reference_dimensionless_spin1"]
        values["chi2"] = data["reference_dimensionless_spin2"]
        values["m1"] = data["reference_mass1"]
        values["m2"] = data["reference_mass2"]
        values["t_relax"] = data["relaxation_time"]
    missing = [name for name in ("chi1", "chi2", "m1", "m2", "t_relax")
               if name not in values]
    if missing:
        raise ValueError(
            f"Could not read {missing} from `{metadata_path}`. The file name "
            "must contain `metadata.txt` or `metadata.json` and the file must "
            "provide the reference spins, the reference masses and the "
            "relaxation time.")
    m1 = values["m1"]
    m2 = values["m2"]
    # numerical noise can make m1 slightly lesser than m2. Catch this whenever
    # it happens. Typically dq = (1 - q) is very small (dq <~ 1e-7) but for few
    # cases it can be dq ~ 1e-4. Therefore, if dq < 5e-4, we treat it as 1,
    # otherwise raise exception.
    q = m1/m2
    dq = 1 - q
    dq_tol = 5e-4
    if dq > 0:
        if dq >= dq_tol:
            raise Exception(f"SXS metadata gives m1 = {m1} < m2 = {m2} -> "
                            f"1 - (m1/m2) = {dq} > {dq_tol}.")
        # dq < dq_tol, treat q as 1.
        warnings.warn(
            f"SXS metadata gives m1 = {m1} < m2 = {m2} but "
            f"1 - (m1/m2) = {dq} < {dq_tol}. Setting q = m1/m2 = 1.")
        q = 1.0
    return {"q": q,
            "chi1": values["chi1"],
            "chi2": values["chi2"],
            "t_relax": values["t_relax"]}


def get_zeroecc_dataDict_for_nr(nr_dataDict, params_dict):
    """Get the zero ecc data dict corresponding to a nr data.

    A zero eccentricity waveform is generated with
    `load_waveform(origin="LAL", ...)` using `params_dict` with `ecc` set to
    zero. Its starting frequency is chosen so that it is at least as long as
    the NR waveform, and it is then truncated to start about 10 samples
    before the start of the NR waveform.

    Params:
    -------
    nr_dataDict:
        Data Dictionary containing NR data. The keys "t" and "hlm" are used.
    params_dict:
        Dictionary of parameters to generate zero eccentricity waveform. It
        must contain "q", "chi1" and "chi2"; a copy of it, with "ecc",
        "include_zero_ecc" and "Momega0" set here, is passed on to
        `load_waveform`.
    Returns:
    -------
    dataDict_zeroecc:
        Data Dictionary containing zero ecc data with keys "t_zeroecc" and
        "hlm_zeroecc".
    """
    # Keep all other params fixed but set ecc = 0 and generate
    # waveform with the approximant provided in params_dict
    zero_ecc_kwargs = params_dict.copy()
    zero_ecc_kwargs["ecc"] = 0.0
    zero_ecc_kwargs["include_zero_ecc"] = False  # to avoid double calc
    # calculate the Momega0 so that the length is >= the length of the NR
    # waveform.
    # First we compute the inspiral time of the NR waveform.
    # get time at merger of the NR waveform
    t_merger = peak_time_via_quadratic_fit(
            nr_dataDict["t"],
            amplitude_using_all_modes(nr_dataDict["hlm"]))[0]
    M = 10  # will be factored out
    inspiralTime = (t_merger
                    - nr_dataDict["t"][0]) * time_dimless_to_mks(M)
    # get the initial frequency to generate waveform of inspiral time
    # roughly equal to that of the NR one.
    # The following function that estimates the initial frequency to
    # generate a waveform with given time to merger needs
    # the file at
    # https://git.ligo.org/lscsoft/lalsuite-extra/-/blob/master/data/lalsimulation/SEOBNRv4ROM_v2.0.hdf5
    # to be present at LAL_DATA_PATH
    # TODO: Replace this function with one from Phenom models
    q = zero_ecc_kwargs["q"]
    m1SI = q * M / (1 + q) * lal.MSUN_SI
    m2SI = M / (1 + q) * lal.MSUN_SI
    s1z = zero_ecc_kwargs["chi1"][2]
    s2z = zero_ecc_kwargs["chi2"][2]
    f0 = lalsim.SimIMRSEOBNRv4ROMFrequencyOfTime(
        inspiralTime, m1SI, m2SI, s1z, s2z)
    # convert to omega and make dimensionless
    Momega0_zeroecc = f0 * time_dimless_to_mks(M) * np.pi
    zero_ecc_kwargs["Momega0"] = Momega0_zeroecc

    dataDict_zeroecc = load_waveform(origin="LAL", **zero_ecc_kwargs)
    t_zeroecc = dataDict_zeroecc["t"]

    # if f0 is too small and generate too long zero ecc waveform
    # report that
    if -t_zeroecc[0] >= - 2 * nr_dataDict["t"][0]:
        warnings.warn("zeroecc waveform is too long. It's "
                      f"{t_zeroecc[0]/nr_dataDict['t'][0]:.2f}"
                      " times the ecc waveform.")
    # We need the zeroecc modes to be long enough, at least the same length
    # as the eccentric one to get the residual amplitude correctly.
    # In case the zeroecc waveform is not long enough we reduce the
    # initial Momega0 by a factor of 2 and generate the waveform again
    # NEED A BETTER SOLUTION to this later
    num_tries = 0
    while t_zeroecc[0] > nr_dataDict["t"][0]:
        zero_ecc_kwargs["Momega0"] = zero_ecc_kwargs["Momega0"] / 2
        dataDict_zeroecc = load_waveform(origin="LAL", **zero_ecc_kwargs)
        t_zeroecc = dataDict_zeroecc["t"]
        num_tries += 1
    if num_tries >= 2:
        warnings.warn("Too many tries to reset Momega0 for generating"
                      " zeroecc modes. Total number of tries = "
                      f"{num_tries}")
    hlm_zeroecc = dataDict_zeroecc["hlm"]
    # Finally we want to return zeroecc data only about the length of the
    # eccentric waveform and truncate the rest of the waveform to avoid
    # wasting computing resources. We keep 10 extra samples before the start
    # of the NR waveform. The index is clamped at 0: when the zeroecc waveform
    # starts fewer than 10 samples before the NR one, the index would become
    # negative and the slice would keep only the last few samples.
    start_zeroecc_idx = max(
        np.argmin(np.abs(t_zeroecc - nr_dataDict["t"][0])) - 10,
        0)
    hlm_zeroecc = {key: mode[start_zeroecc_idx:]
                   for key, mode in hlm_zeroecc.items()}

    return {"t_zeroecc": t_zeroecc[start_zeroecc_idx:],
            "hlm_zeroecc": hlm_zeroecc}


def reomve_junk_from_nr_data(t, modes_dict, num_orbits_to_remove_as_junk):
    """Drop the initial junk part of NR data using the (2, 2) mode phase.

    The number of orbits is counted with the unwrapped phase of the (2, 2)
    mode, assuming that this phase changes by 4pi over one orbit.

    Parameters
    ----------
    t:
        Time array for the NR data.
    modes_dict:
        Dictionary containing modes array. It must contain the (2, 2) mode.
    num_orbits_to_remove_as_junk:
        Number of orbits to remove as junk from the beginning of NR data.

    Returns
    -------
    t_clean:
        Time array corresponding to clean NR data.
    modes_dict_clean:
        modes_dict with `num_orbits_to_remove_as_junk` orbits removed from the
        beginning of modes array.
    """
    phase22 = - np.unwrap(np.angle(modes_dict[(2, 2)]))
    # Phase of the (2, 2) mode after the requested number of orbits; one
    # orbit corresponds to a 4pi change.
    phase_after_junk = phase22[0] + num_orbits_to_remove_as_junk * 4 * np.pi
    # Index of the sample whose phase is closest to that value.
    first_clean_idx = np.argmin(np.abs(phase22 - phase_after_junk))
    modes_dict_clean = {mode: values[first_clean_idx:]
                        for mode, values in modes_dict.items()}
    return t[first_clean_idx:], modes_dict_clean


def remove_junk_from_sxs_catalogformat_using_horizons_data(
        t, modes_dict, num_orbits_to_remove_as_junk, horizon_filepath):
    """Remove first `num_orbits_to_remove_as_junk` orbits as junk.

    The duration of the orbits to remove is obtained from the `Horizons.h5`
    file via `get_num_orbits_duration_from_horizon_data`.

    Parameters
    ----------
    t: array-like
        1d array of times associated with the waveform modes.

    modes_dict: dict
        Dictionary containing waveform modes.

    num_orbits_to_remove_as_junk: float
        Number of orbits to remove from the start of the waveform modes.

    horizon_filepath: str
        Path to the `Horizons.h5` file. This file typically can be found in the
        same directory where the waveform files are located in the SXS catalog.

    Returns
    -------
    t_clean: array-like
        Time array corresponding to clean NR data.

    modes_dict_clean: dict
        modes_dict with `num_orbits_to_remove_as_junk` orbits removed from the
        beginning of modes array.
    """
    junk_duration = get_num_orbits_duration_from_horizon_data(
        horizon_filepath, num_orbits_to_remove_as_junk)

    # The times in `Horizons.h5` are the time coordinate of the NR simulation
    # and start at t = 0, whereas the times in the waveform file are related
    # to retarded time, shifted back to the origin, and begin at t < 0,
    # typically around -100M or so. Despite the different starting points,
    # the t = 0 point of both files can be roughly associated with each
    # other.
    #
    # For this reason the waveform is truncated at t > junk_duration and not
    # at t > t[0] + junk_duration: t[0] is located at around -100M, and
    # starting the truncation there would be closer to the unwanted junk than
    # desired.
    #
    # NOTE: This assumes that `t` has not been shifted in any way and is the
    # original time array contained in the waveform file.
    first_clean_idx = np.argmin(np.abs(t - junk_duration))
    modes_dict_clean = {mode: values[first_clean_idx:]
                        for mode, values in modes_dict.items()}
    return t[first_clean_idx:], modes_dict_clean


def get_num_orbits_duration_from_horizon_data(horizon_filepath, num_orbits):
    """Get the duration of `num_orbits` from phase data in `Horizons.h5`.

    Parameters
    ----------
    horizon_filepath: str
        Path to the `Horizons.h5` file. This file typically can be found in the
        same directory where the waveform files are located.

    num_orbits: float
        Number orbits from the start of the simulation. This function estimates
        the duration of `num_orbits` from orbital phase which is obtained from
        the coordinates of the black holes stored in the `Horizons.h5` file.

    Returns
    -------
    num_obits_duration : float
        Duration of first `num_orbits`.
    """
    # Read both datasets into memory and close the file right away, also when
    # one of the reads fails.
    with h5py.File(horizon_filepath, "r") as horizons_data:
        xA_data = horizons_data["AhA.dir"]["CoordCenterInertial.dat"][:]
        xB_data = horizons_data["AhB.dir"]["CoordCenterInertial.dat"][:]
    # First column is the time, the remaining columns are the coordinates.
    time = xA_data[:, 0]
    separation_vec = xA_data[:, 1:] - xB_data[:, 1:]
    # We will use the x-y coordinates, i.e., the motion projected
    # onto the x-y plane to compute the phase and then the number of
    # orbits using that phase.
    # NOTE: This assumption breaks down when the orbital plane itself contains
    # the z-axis, which can happen if there is extreme precession, or in the
    # case of EMRIs.
    # Get the orbital phase
    phase_orb_projected = np.unwrap(
        np.arctan2(separation_vec[:, 1], separation_vec[:, 0]))
    # Find the duration of first num_orbits assuming that the orbital phase
    # changes by 2pi over one orbit
    phase_at_num_orbits = phase_orb_projected[0] + num_orbits * 2 * np.pi
    idx_at_num_obits_from_start = np.argmin(
        np.abs(phase_orb_projected - phase_at_num_orbits))
    num_obits_duration = (time[idx_at_num_obits_from_start]
                          - time[0])
    return num_obits_duration


def package_modes_for_scri(modes_dict, ell_min, ell_max):
    """Package modes in an ordered array to use as input data to `scri.WaveformModes`.

    Parameters
    ----------
    modes_dict: dict
        Dictionary of waveform modes with keys `(ell, m)`. For every `ell`
        from `ell_min` to `ell_max`, all modes with `m` from `-ell` to `+ell`
        must be present.

    ell_min: int
        Minimum `ell` value to use.

    ell_max: int
        Maximum `ell` value to use.

    Returns
    -------
    2d complex array with one row per time sample and one column per mode.
    The columns are ordered by increasing `ell` and, for a given `ell`, by
    increasing `m`, that is, for `ell`=2 the column order is
    [(2, -2), (2, -1), (2, 0), (2, 1), (2, 2)].

    Raises
    ------
    ValueError
        If `ell_min` is larger than `ell_max`.
    Exception
        If one of the required modes is missing in `modes_dict`.
    """
    if ell_min > ell_max:
        raise ValueError(
            f"ell_min = {ell_min} can not be larger than ell_max = {ell_max}.")
    # All modes in the order expected by scri.
    required_modes = [(ell, m)
                      for ell in range(ell_min, ell_max + 1)
                      for m in range(-ell, ell + 1)]
    for ell, m in required_modes:
        if (ell, m) not in modes_dict:
            # for a given ell, all (ell, m) modes should exist in the
            # modes_dict
            raise Exception(
                f"{ell, m} mode for ell={ell} does not exist in the "
                "modes dict. To get the coprecessing modes accurately, "
                "all the `(ell, m)` modes for a given `ell` should exist "
                "in the input modes dict.")
    # Take the number of time samples from the first required mode, so that
    # the (2, 2) mode is only needed when it is actually requested. The number
    # of columns is the number of required modes, which is valid for any
    # `ell_min` and not only for `ell_min` = 2.
    n_times = modes_dict[required_modes[0]].shape[0]
    result = np.zeros((n_times, len(required_modes)), dtype=np.complex128)
    for column, mode in enumerate(required_modes):
        result[:, column] = modes_dict[mode]
    return result


def unpack_scri_modes(w):
    """Convert a `scri.WaveformModes` object into a dictionary of modes.

    This gives the modes back in the dict format required by
    `gw_eccentricity`.

    Parameters
    ----------
    w: scri.WaveformModes
        `scri.WaveformModes` object. Its `LM`, `data` and `index` members are
        used.

    Returns
    -------
    Waveform modes in dict with key `(ell, m)`.
    """
    # Multiplying by 1 gives every mode its own array instead of a view into
    # `w.data`.
    return {(lm[0], lm[1]): 1 * w.data[:, w.index(lm[0], lm[1])]
            for lm in w.LM}


def get_coprecessing_data_dict(data_dict, ell_min=2, ell_max=2, tag=""):
    """Transform the modes of `data_dict` to the coprecessing frame.

    The inertial frame modes stored in `data_dict` together with the
    associated time are used to build a `scri.WaveformModes` object, which is
    then transformed to the coprecessing frame.

    For every `ell` from `ell_min` to `ell_max`, `data_dict` should contain
    the modes for all `m` values from `-ell` to `+ell`.

    Parameters
    ----------
    data_dict: dict
        Dictionary of waveform modes in the inertial frame and the associated
        time. It should have the same structure as `dataDict` in
        `gw_eccentricity.measure_eccentricity`.

    ell_min: int, default=2
        Minimum `ell` value to use.

    ell_max: int, default=2
        Maximum `ell` value to use.

    tag: str, default=""
        Suffix selecting which inertial frame data to transform: the modes
        are read from `data_dict["hlm" + tag]` and the times from
        `data_dict["t" + tag]`. For example, `tag="_zeroecc"` selects the
        "zeroecc" (non-eccentric) data, while the default value (`""`)
        selects the data of the eccentric case.

    Returns
    -------
    A deep copy of `data_dict` in which the modes under `"hlm" + tag` are
    replaced by the corresponding modes in the coprecessing frame.
    """
    modes_key = "hlm" + tag
    time_key = "t" + tag
    # Modes arranged in the column order expected by `scri.WaveformModes`.
    packed_modes = package_modes_for_scri(
        data_dict[modes_key],
        ell_min=ell_min,
        ell_max=ell_max)

    inertial_waveform = scri.WaveformModes(
        dataType=scri.h,
        t=data_dict[time_key],
        data=packed_modes,
        ell_min=ell_min,
        ell_max=ell_max,
        frameType=scri.Inertial,
        r_is_scaled_out=True,
        m_is_scaled_out=True)

    # The transformation is applied to a copy of the inertial waveform.
    coprecessing_waveform = deepcopy(inertial_waveform).to_coprecessing_frame()
    # Copy the input and swap the inertial frame modes for the coprecessing
    # frame ones.
    data_dict_copr = deepcopy(data_dict)
    data_dict_copr[modes_key] = unpack_scri_modes(
        deepcopy(coprecessing_waveform))
    return data_dict_copr


def load_h22_from_EOBfile(EOB_file):
    """Load data from EOB files.

    Parameters
    ----------
    EOB_file: str
        Path to the EOB file. The datasets `data/t`,
        `data/hCoOrb/Amp_l2m2`, `data/hCoOrb/phi_l2m2`,
        `nonecc_data/hCoOrb/Amp_l2m2` and `nonecc_data/hCoOrb/phi_l2m2` are
        read from it.

    Returns
    -------
    dataDict: dict
        Dictionary with keys "t", "hlm", "t_zeroecc" and "hlm_zeroecc". Here
        "hlm" and "hlm_zeroecc" are complex arrays built as
        `amplitude * exp(1j * phase)`, not dictionaries of modes.
    """
    # The context manager closes the file also when a dataset is missing.
    with h5py.File(EOB_file, "r") as fp:
        t_ecc = fp['data/t'][:]
        amp22_ecc = fp['data/hCoOrb/Amp_l2m2'][:]
        phi22_ecc = fp['data/hCoOrb/phi_l2m2'][:]

        # NOTE(review): the non-eccentric times are read from `data/t`, the
        # same dataset as the eccentric times - confirm that this is intended.
        t_nonecc = fp['data/t'][:]
        amp22_nonecc = fp['nonecc_data/hCoOrb/Amp_l2m2'][:]
        phi22_nonecc = fp['nonecc_data/hCoOrb/phi_l2m2'][:]

    dataDict = {"t": t_ecc, "hlm": amp22_ecc * np.exp(1j * phi22_ecc),
                "t_zeroecc": t_nonecc,
                "hlm_zeroecc": amp22_nonecc * np.exp(1j * phi22_nonecc)}
    return dataDict


def load_EOB_EccTest_file(**kwargs):
    """Load EOB files for testing EccDefinition.

    These files were generated using SEOBNRv4EHM model.
    Allowed kwargs are:

    filepath:
        Path to the EOB file. The datasets "t" and "(2, 2)" are read from it.
    include_zero_ecc:
        If True, loads the quasicircular waveform modes also.
        This requires providing the path to quasicircular waveform
        file, see "filepath_zero_ecc" below.
    filepath_zero_ecc:
        Path to the waveform file containing quasicircular waveform
        modes. Required only if include_zero_ecc is True.

    Returns
    -------
    dataDict: dict
        Dictionary with keys "t" and "hlm", where the time is shifted by the
        peak time of the amplitude and "hlm" contains the (2, 2) mode. If
        `include_zero_ecc` is True, it also contains "t_zeroecc" and
        "hlm_zeroecc" loaded in the same way from `filepath_zero_ecc`.
    """
    # Read the data inside a context manager so that the file gets closed.
    with h5py.File(kwargs["filepath"], "r") as f:
        t = f["t"][:]
        hlm = {(2, 2): f["(2, 2)"][:]}
    # make t = 0 at the merger
    t = t - peak_time_via_quadratic_fit(
        t,
        amplitude_using_all_modes(hlm))[0]
    dataDict = {"t": t, "hlm": hlm}
    if kwargs.get("include_zero_ecc"):
        if "filepath_zero_ecc" not in kwargs:
            raise Exception("Mus provide file path to zero ecc waveform.")
        # Load the quasicircular file with the same function; switch off
        # `include_zero_ecc` so that it is not loaded recursively.
        zero_ecc_kwargs = kwargs.copy()
        zero_ecc_kwargs["filepath"] = kwargs["filepath_zero_ecc"]
        zero_ecc_kwargs["include_zero_ecc"] = False
        dataDict_zero_ecc = load_EOB_EccTest_file(**zero_ecc_kwargs)
        dataDict.update({"t_zeroecc": dataDict_zero_ecc["t"],
                         "hlm_zeroecc": dataDict_zero_ecc["hlm"]})
    return dataDict


def load_EOB_waveform(**kwargs):
    """Load an EOB waveform from file.

    Run `load_data.get_load_waveform_defaults('EOB')` to see allowed
    keys and defaults.

    Only files whose path contains the substring "EccTest" are supported;
    they are loaded with `load_EOB_EccTest_file`. These files were generated
    using SEOBNRv4EHM model.

    Allowed kwargs are:

    filepath:
        Path to the EOB file. Must not be None.
    include_zero_ecc:
        If True, loads the quasicircular waveform modes also.
        This requires providing the path to quasicircular waveform
        file, see "filepath_zero_ecc" below.
    filepath_zero_ecc:
        Path to the waveform file containing quasicircular waveform
        modes. Required only if include_zero_ecc is True.

    returns:
    --------
    The dictionary returned by `load_EOB_EccTest_file`.

    raises:
    -------
    Exception
        If `filepath` is None, if `include_zero_ecc` is True while
        `filepath_zero_ecc` is None, or if `filepath` does not contain
        "EccTest".
    """
    # fill in defaults for missing keys and validate the provided ones
    defaults = get_load_waveform_defaults("EOB")
    kwargs = check_kwargs_and_set_defaults(
        kwargs, defaults, "EOB kwargs",
        "`load_data.get_load_waveform_defaults`")

    filepath = kwargs["filepath"]
    if filepath is None:
        raise Exception("Must provide file path to EOB waveform")

    if kwargs["include_zero_ecc"]:
        if kwargs["filepath_zero_ecc"] is None:
            raise Exception(
                "Must provide `filepath_zero_ecc`, file path to zeroecc "
                "EOB waveform, when `include_zero_ecc` is `True`.")

    # the file type is identified from the path itself
    if "EccTest" not in filepath:
        raise Exception("Unknown filepath pattern.")
    return load_EOB_EccTest_file(**kwargs)


def load_lvcnr_hack(**kwargs):
    """Load 22 mode from lvcnr files using h5py and interpolation.

    NOTE: This is not the recommended way to load lvcnr files.
    Use `load_lvcnr_waveform` for that. Currently that function
    has some issues where it fails to load due to too low f_low,
    or takes too long to load or loads only last few cycles.

    This is a simple hack to load the NR files using h5py and then
    interpolate the data. Also we only load 22 modes here for simplicity.
    This function is mostly for testing measurement of eccentricity
    of NR waveforms.

    Parameters
    ----------
    kwargs: Could be the followings.
    Run `load_data.get_load_waveform_defaults('LVCNR_hack')` to see allowed
    keys and defaults.

    filepath: str
        Path to lvcnr file.

    deltaTOverM: float
        Time step. The loaded data will be interpolated using this time step.

    include_zero_ecc: bool
        If True, the zero eccentricity data returned by
        `get_zeroecc_dataDict_for_nr` for the parameters read from the file
        are added to the returned dictionary.

    zero_ecc_approximant: str
        Passed as "approximant" to `get_zeroecc_dataDict_for_nr`. Used only
        if include_zero_ecc is True.

    include_params_dict: bool
        If True, returns a dictionary of binary parameters.

    num_orbits_to_remove_as_junk: float
        Number of orbits to throw away as junk from the beginning of the NR
        data.

    Returns
    -------
        Dictionary of modes dict, parameter dict and also zero ecc mode dict if
        include_zero_ecc is True.

    t:
        Time array.
    hlm:
        Dictionary of modes.
    Optionally,
    params_dict:
        Dictionary of parameters with keys "q", "chi1", "chi2", "ecc" and
        "mean_ano".
    t_zeroecc:
        Time array for zero ecc modes.
    hlm_zeroecc:
        Mode dictionary for zero eccentricity.
    """
    kwargs = check_kwargs_and_set_defaults(
        kwargs,
        get_load_waveform_defaults("LVCNR_hack"),
        "LVCNR hack kwargs",
        "`load_data.get_defaults_for_nr`")

    # Read everything that is needed from the file up front, inside a context
    # manager, so that the HDF5 handle is closed even if interpolation or
    # junk removal raises later. The file is opened explicitly read-only.
    with h5py.File(kwargs["filepath"], "r") as f:
        t_for_amp22 = f["amp_l2_m2"]["X"][:]
        amp22 = f["amp_l2_m2"]["Y"][:]

        t_for_phase22 = f["phase_l2_m2"]["X"][:]
        phase22 = f["phase_l2_m2"]["Y"][:]

        # params
        s1x = f.attrs["spin1x"]
        s1y = f.attrs["spin1y"]
        s1z = f.attrs["spin1z"]
        s2x = f.attrs["spin2x"]
        s2y = f.attrs["spin2y"]
        s2z = f.attrs["spin2z"]
        m1 = f.attrs["mass1"]
        m2 = f.attrs["mass2"]
        ecc = f.attrs["eccentricity"]
        mean_ano = f.attrs["mean_anomaly"]

    # amplitude and phase are stored on their own time grids; interpolate
    # only over the interval covered by both
    tstart = max(t_for_amp22[0], t_for_phase22[0])
    tend = min(t_for_amp22[-1], t_for_phase22[-1])

    t_interp = np.arange(tstart, tend, kwargs["deltaTOverM"])

    # NOTE: The data were downsampled using romspline
    # (https://arxiv.org/abs/1611.07529), which uses higher order splines as
    # appropriate, but we are now upsampling with only cubic splines.
    # This can lead to inaccuracies.
    amp22_interp = interpolate(t_interp, t_for_amp22, amp22)
    phase22_interp = interpolate(t_interp, t_for_phase22, phase22)
    h22_interp = amp22_interp * np.exp(1j * phase22_interp)

    # remove junk data from the beginning
    t, modes_dict = reomve_junk_from_nr_data(
        t_interp,
        {(2, 2): h22_interp},
        kwargs["num_orbits_to_remove_as_junk"])

    return_dict = {"t": t,
                   "hlm": modes_dict}

    if kwargs["include_zero_ecc"] or kwargs["include_params_dict"]:
        params_dict = {"q": m1/m2,
                       "chi1": [s1x, s1y, s1z],
                       "chi2": [s2x, s2y, s2z],
                       "ecc": ecc,
                       "mean_ano": mean_ano}
    if kwargs["include_params_dict"]:
        return_dict.update({"params_dict": params_dict})

    if kwargs["include_zero_ecc"]:
        # work on a copy so that the approximant does not leak into the
        # params_dict returned to the caller
        params_dict_zero_ecc = params_dict.copy()
        params_dict_zero_ecc.update(
            {"approximant": kwargs["zero_ecc_approximant"]})
        dataDict_zeroecc = get_zeroecc_dataDict_for_nr(
            return_dict, params_dict_zero_ecc)
        return_dict.update(dataDict_zeroecc)

    return return_dict


def load_EMRI_waveform(**kwargs):
    """Load EMRI waveforms data.

    kwargs dictionary can have the following keys.
    Run `load_data.get_load_waveform_defaults('EMRI')` to see allowed
    keys and defaults.

    filepath: str
        Path to the eccentric EMRI waveform. The file must contain a dataset
        "Dataset1" whose columns 0, 1 and 2 are read as the time and the real
        and imaginary parts of the (2, 2) mode, respectively.
    include_zero_ecc: bool
        If true, load circular EMRI waveform that has the same set of
        parameters as used for eccentric EMRI except the eccentricity being set
        to zero. Only the file path is forwarded when loading the circular
        waveform; start_time, end_time and deltaT are not. The loaded circular
        (2, 2) mode is divided by sqrt(2 pi).
    filepath_zero_ecc: str
        Path to the circular EMRI waveform. If None, a filepath would be
        generated based on the filepath of the eccentric waveform: everything
        from the first occurrence of "e0" onwards is replaced by "e0.000.h5".
    start_time: float
        Since the EMRI waveforms could be very long, one can opt to load the
        waveform only from start_time, where start_time is to be provided
        following the convention of merger being at t=0. Since EMRI waveform
        does not include an actual merger, t=0 corresponds to the global
        maximum of the amplitude.
        The waveform starts at the sample closest to start_time.
        If None, start_time would be the time at the start of the waveform.
    end_time: float
        Similar to start_time, one could provide an end_time, time up to which
        the waveform is to be loaded. The sample closest to end_time is itself
        excluded. If None, the waveform is kept up to, but not including, its
        last sample.
    deltaT: float
        If provided, it would be used to interpolate the waveform with this
        time step. If None, it would not do any interpolation.
    include_geodesic_ecc: bool
        If True, loads geodesic eccentricity data from column 1 of "Dataset1"
        in the file obtained by replacing the last three characters of
        filepath with "_ecc.h5". It is restricted to the same sample range as
        the waveform and, if deltaT is provided, interpolated onto the same
        times.

    returns:
    --------
    dataDict: dict
        Dictionary with keys "t" and "hlm" (dictionary with the single key
        (2, 2)). Depending on the kwargs it also contains "t_zeroecc",
        "hlm_zeroecc" and "e_geodesic".

    raises:
    -------
    KeyError
        If filepath is None.
    FileNotFoundError
        If include_zero_ecc is True, filepath_zero_ecc is None and filepath
        does not contain "e0", so that no path can be derived.
    """
    kwargs = check_kwargs_and_set_defaults(
        kwargs,
        get_load_waveform_defaults("EMRI"),
        "EMRI kwargs",
        "load_data.get_load_waveform_defaults")
    if kwargs["filepath"] is None:
        raise KeyError("path to the eccentric EMRI waveform cannot be None.")
    # Read the needed columns inside a context manager so that the HDF5
    # handle is closed instead of being left open.
    with h5py.File(kwargs["filepath"], "r") as emri_file:
        emri_data = emri_file["Dataset1"]
        t = emri_data[:, 0]
        h22 = emri_data[:, 1] + 1j * emri_data[:, 2]
    # make t = 0 at the peak of the amplitude
    tpeak = peak_time_via_quadratic_fit(
        t,
        amplitude_using_all_modes({(2, 2): h22}))[0]
    t -= tpeak
    # indices of the samples closest to the requested start and end times
    if kwargs["start_time"] is not None:
        start = np.argmin(np.abs(t - kwargs["start_time"]))
    else:
        start = 0
    if kwargs["end_time"] is not None:
        end = np.argmin(np.abs(t - kwargs["end_time"]))
    else:
        # the end index is exclusive, so this drops the last sample
        end = -1
    t_new = t[start: end]
    h22_new = h22[start: end]
    dataDict = {"t": t_new,
                "hlm": {(2, 2): h22_new}}
    if kwargs["deltaT"] is not None:
        t_interp = np.arange(t_new[0], t_new[-1], kwargs["deltaT"])
        # make t_interp within the bounds of t_new to avoid extrapolation
        t_interp = t_interp[np.logical_and(t_interp >= t_new[0],
                                           t_interp <= t_new[-1])]
        # interpolate amplitude and unwrapped phase separately and recombine
        amp22_interp = interpolate(t_interp, t_new, np.abs(h22_new))
        phase22_interp = interpolate(
            t_interp, t_new, np.unwrap(np.angle(h22_new)))
        h22_interp = amp22_interp * np.exp(1j * phase22_interp)
        dataDict["t"] = t_interp
        dataDict["hlm"] = {(2, 2): h22_interp}

    if kwargs["include_zero_ecc"]:
        if kwargs["filepath_zero_ecc"] is None:
            idx = kwargs["filepath"].find("e0")
            if idx < 0:
                # str.find returns -1 when "e0" is absent; slicing with it
                # would silently build a meaningless path.
                raise FileNotFoundError(
                    "Cannot derive `filepath_zero_ecc` from "
                    f"{kwargs['filepath']!r}: it does not contain 'e0'. "
                    "Provide `filepath_zero_ecc` explicitly.")
            kwargs["filepath_zero_ecc"] = (
                kwargs["filepath"][:idx] + "e0.000.h5")
        kwargs_zero_ecc = {
            "filepath": kwargs["filepath_zero_ecc"],
            "include_zero_ecc": False}
        dataDict_zero_ecc = load_EMRI_waveform(**kwargs_zero_ecc)
        dataDict.update({
            "t_zeroecc": dataDict_zero_ecc["t"],
            "hlm_zeroecc": {(2, 2): dataDict_zero_ecc["hlm"][(2, 2)]
                            / np.sqrt(2*np.pi)}})
    if kwargs["include_geodesic_ecc"]:
        e_geodesic_file = kwargs["filepath"][:-3] + "_ecc.h5"
        with h5py.File(e_geodesic_file, "r") as ecc_file:
            # same sample range as the waveform
            e_geodesic = ecc_file["Dataset1"][:, 1][start: end]
        dataDict.update({"e_geodesic": e_geodesic})
        if kwargs["deltaT"] is not None:
            e_geodesic_interp = interpolate(
                t_interp, t_new, np.abs(e_geodesic))
            dataDict.update({"e_geodesic": e_geodesic_interp})
    return dataDict
