""" """

import h5py
import numpy as np
from diffmah import mah_singlehalo
from diffmah.defaults import LGT0
from jax import grad
from jax import jit as jjit
from jax import numpy as jnp

from ..defaults import (
    DEFAULT_DIFFSTAR_U_PARAMS,
    DEFAULT_MS_PARAMS,
    DEFAULT_Q_PARAMS,
    FB,
    SFR_MIN,
    get_bounded_diffstar_params,
)
from ..kernels.main_sequence_kernels import _get_unbounded_sfr_params
from ..kernels.quenching_kernels import _get_unbounded_q_params
from ..sfh_model import calc_sfh_singlegal
from ..utils import _sigmoid, compute_fstar, cumulative_mstar_formed
from .utils import minimizer_wrapper

# Defaults for the fitting window and target conditioning. They are the default
# keyword values of diffstar_fitter and get_loss_data_default.
T_FIT_MIN = 1.0  # Only fit snapshots with t_table >= this value. Gyr units.
DLOGM_CUT = 3.5  # Only fit SMH within this dex of the final (present-day) SMH.
# Only fit SMH above this threshold. Log10(Msun) units. get_weights lowers the
# threshold to (final log10 stellar mass - 0.5) when that is smaller.
MIN_MASS_CUT = 7.0
# Averaging window of the time-averaged SFH (aka fstar). Gyr units. Snapshots
# with t_table < FSTAR_TIME_DELAY + 0.01 are excluded from the fstar loss term.
FSTAR_TIME_DELAY = 1.0
# Floor on the target sSFR history (fstar / stellar mass). 1/yr units. Inside
# the loss, the model fstar and SFH are floored at one tenth of this value.
SSFRH_FLOOR = 1e-12


# Fill value written for every diffstar parameter of a galaxy that has no fit
# (see get_outline_nofit).
NOFIT_FILL = -99.0


def diffstar_fitter(
    t_table,
    sfh_table,
    mah_params,
    dlogm_cut=DLOGM_CUT,
    t_fit_min=T_FIT_MIN,
    mass_fit_min=MIN_MASS_CUT,
    fstar_tdelay=FSTAR_TIME_DELAY,
    ssfrh_floor=SSFRH_FLOOR,
    lgt0=LGT0,
    fb=FB,
):
    """Fit the diffstar model to a single tabulated star formation history.

    Parameters
    ----------
    t_table : ndarray
        Times (Gyr) at which ``sfh_table`` is tabulated.
    sfh_table : ndarray
        Target star formation history, one value per entry of ``t_table``.
    mah_params
        Diffmah parameters of the host halo. They are held fixed; only the
        main-sequence and quenching parameters are varied.
    dlogm_cut, t_fit_min, mass_fit_min, fstar_tdelay, ssfrh_floor
        Fitting-window and target-conditioning options. They are forwarded to
        ``get_loss_data_default`` and default to the module-level constants.
    lgt0, fb
        Forwarded to ``get_loss_data_default``, which passes them on to the
        model evaluation.

    Returns
    -------
    p_best
        Best-fit diffstar parameters in bounded form.
    loss_best
        Loss value reported by the minimizer.
    success
        Success flag reported by the minimizer.
    """
    loss_kwargs = dict(
        dlogm_cut=dlogm_cut,
        t_fit_min=t_fit_min,
        mass_fit_min=mass_fit_min,
        fstar_tdelay=fstar_tdelay,
        ssfrh_floor=ssfrh_floor,
        lgt0=lgt0,
        fb=fb,
    )
    u_p_init_and_err, loss_data = get_loss_data_default(
        t_table, sfh_table, mah_params, **loss_kwargs
    )

    u_p_varied_best, loss_best, success = minimizer_wrapper(
        loss_default_clipssfrh,
        loss_grad_default_clipssfrh_np,
        u_p_init_and_err,
        loss_data,
    )

    # The minimizer works in unbounded space; map the result back to the
    # bounded parameterization before returning it.
    u_p_best = DEFAULT_DIFFSTAR_U_PARAMS._make(u_p_varied_best)
    return get_bounded_diffstar_params(u_p_best), loss_best, success


def get_loss_data_default(
    t_table,
    sfh_table,
    mah_params,
    dlogm_cut=DLOGM_CUT,
    t_fit_min=T_FIT_MIN,
    mass_fit_min=MIN_MASS_CUT,
    fstar_tdelay=FSTAR_TIME_DELAY,
    ssfrh_floor=SSFRH_FLOOR,
    lgt0=LGT0,
    fb=FB,
):
    """Get loss data to use with diffstar_fitter.

    Build the fitting targets, per-snapshot weights and the initial guess for
    the varied (main-sequence + quenching) diffstar parameters.

    Parameters
    ----------
    t_table : ndarray
        Snapshot times in Gyr. The last entry is treated as the present day.
    sfh_table : ndarray
        Target star formation history, one value per entry of ``t_table``.
        It is clipped below at ``SFR_MIN`` before any other use.
    mah_params
        Diffmah parameters of the host halo, passed to ``mah_singlehalo``.
    dlogm_cut, t_fit_min, mass_fit_min, fstar_tdelay
        Fitting-window options forwarded to ``get_weights``. ``fstar_tdelay``
        is also the averaging window of ``compute_fstar``.
    ssfrh_floor : float
        Floor applied to the target sSFR history (fstar / stellar mass).
    lgt0 : float
        Passed to ``mah_singlehalo`` and stored in ``loss_data``.
    fb : float
        Stored in ``loss_data`` for the model evaluation in the loss.

    Returns
    -------
    u_p_init_and_err : tuple of two ndarrays
        The first array holds the initial unbounded values of the varied MS
        and Q parameters, concatenated. The second holds one scale per
        parameter (presumably used by ``minimizer_wrapper`` as an initial
        spread — TODO confirm).
    loss_data : tuple
        Positional data unpacked by ``loss_default_clipssfrh``.
    """
    # Integrate the *clipped* SFH. This keeps mstar_target consistent with the
    # sfh_target that the loss compares against. Assuming SFR_MIN > 0, it also
    # stops zeros in the input SFH from producing zero cumulative mass, which
    # would give log10(0) = -inf and 0/0 sSFRs and turn the loss into NaN/inf.
    sfh_target = np.clip(sfh_table, SFR_MIN, np.inf)
    mstar_target = cumulative_mstar_formed(t_table, sfh_target)
    logmstar_target = np.log10(mstar_target)

    # Time-averaged SFH target. First floor it in sSFR, then floor it at
    # 1/1000 of its peak value.
    fstar_table = compute_fstar(t_table, mstar_target, fstar_tdelay)
    ssfrh_table = fstar_table / mstar_target
    ssfrh_target = np.clip(ssfrh_table, ssfrh_floor, np.inf)

    fstar_target = ssfrh_target * mstar_target
    fstar_target_min = fstar_target.max() / 1000.0
    fstar_target = np.where(
        fstar_target < fstar_target_min, fstar_target_min, fstar_target
    )
    log_fstar_target = np.log10(fstar_target)

    lgt_table = jnp.log10(t_table)
    log_mah = mah_singlehalo(mah_params, t_table, lgt0)[1]
    logmp0 = log_mah[-1]  # log halo mass at the final snapshot

    weight, weight_fstar = get_weights(
        t_table,
        logmstar_target,
        log_fstar_target,
        fstar_tdelay,
        dlogm_cut,
        t_fit_min,
        mass_fit_min,
    )

    # log10 time of peak target fstar; used by the loss's lg_qt ridge term.
    lgt_fstar_max = lgt_table[np.argmax(log_fstar_target)]

    # Initial main-sequence parameters scale with the final halo mass.
    ms_params = np.array(DEFAULT_MS_PARAMS)
    ms_params[0] = np.clip(0.3 * (logmp0 - 11.0) + 11.4, 11.0, 13.0)
    ms_params[1] = np.clip(0.2 * (logmp0 - 11.0) - 9.7, -10.5, -9.2)
    ms_params[2] = np.clip(0.7 * (logmp0 - 11.0) - 0.3, 0.2, 3.0)
    ms_params = DEFAULT_MS_PARAMS._make(ms_params)
    varied_u_ms_params = np.array(_get_unbounded_sfr_params(*ms_params))

    u_ms_params_err = np.array([0.5, 0.5, 1.0, 1.0])

    # Initial quenching parameters: entry 0 scales with the final halo mass,
    # and entry 2 is set to -2.0.
    varied_q_params = np.array(DEFAULT_Q_PARAMS)
    varied_q_params[0] = np.clip(-0.5 * (logmp0 - 11.0) + 1.5, 0.7, 1.5)
    varied_q_params[2] = -2.0
    # Build the namedtuple from the *modified* array, as is done for the MS
    # params above. Rebuilding it from DEFAULT_Q_PARAMS would silently discard
    # the two adjustments just made.
    varied_q_params = DEFAULT_Q_PARAMS._make(varied_q_params)
    varied_u_q_params = np.array(_get_unbounded_q_params(*varied_q_params))
    u_q_params_err = np.array([0.3, 0.5, 0.3, 0.3])

    # Order must match the unpacking in loss_default_clipssfrh.
    loss_data = (
        t_table,
        mah_params,
        mstar_target,
        logmstar_target,
        sfh_target,
        log_fstar_target,
        fstar_tdelay,
        ssfrh_floor,
        weight,
        weight_fstar,
        lgt_fstar_max,
        lgt0,
        fb,
    )

    u_p_init_and_err = (
        np.concatenate((varied_u_ms_params, varied_u_q_params)),
        np.concatenate((u_ms_params_err, u_q_params_err)),
    )
    return u_p_init_and_err, loss_data


def get_weights(
    t_table,
    log_smah_sim,
    log_fstar_sim,
    fstar_tdelay,
    dlogm_cut,
    t_fit_min,
    mass_fit_min,
):
    """Per-snapshot weights for the stellar-mass and fstar loss terms.

    The loss divides each residual by its weight. A weight of 1e10 therefore
    effectively removes a snapshot from the fit, 1 is nominal, and 0.5
    doubles the residual (up-weighting the snapshot).

    A snapshot is fit (weight 1) only if all of the following hold:
    its log stellar mass is within ``dlogm_cut`` dex of the final value,
    it is above ``mass_fit_min``, and ``t_table >= t_fit_min``.
    ``mass_fit_min`` is first lowered to (final log mass - 0.5) when that is
    smaller, so low-mass galaxies still get fit.

    Parameters
    ----------
    t_table : ndarray
        Snapshot times (Gyr).
    log_smah_sim : ndarray
        Target log10 stellar mass history; the last entry is the final value.
    log_fstar_sim : ndarray
        Target log10 time-averaged SFH.
    fstar_tdelay : float
        fstar averaging window. Snapshots with
        ``t_table < fstar_tdelay + 0.01`` are excluded from the fstar term.
    dlogm_cut, t_fit_min, mass_fit_min : float
        Fitting-window cuts described above.

    Returns
    -------
    weight : jax array
        Weights for the stellar-mass and SFR terms. Snapshots within 0.1 dex
        of the final stellar mass get 0.5.
    weight_fstar : ndarray
        Weights for the fstar term. Snapshots within 0.1 dex of the peak
        fstar get 0.5.
    """
    log_sm_final = log_smah_sim[-1]
    mass_fit_min = min(log_sm_final - 0.5, mass_fit_min)

    mask = log_smah_sim > (log_sm_final - dlogm_cut)
    mask &= log_smah_sim > mass_fit_min
    mask &= t_table >= t_fit_min

    weight = np.ones_like(t_table)
    weight[~mask] = 1e10
    weight[log_sm_final - log_smah_sim < 0.1] = 0.5
    weight = jnp.array(weight)

    weight_fstar = np.ones_like(t_table)
    weight_fstar[~mask] = 1e10
    weight_fstar[log_fstar_sim.max() - log_fstar_sim < 0.1] = 0.5
    # fstar at early times is averaged over an incomplete window; exclude it.
    weight_fstar[t_table < fstar_tdelay + 0.01] = 1e10

    return weight, weight_fstar


@jjit
def loss_default_clipssfrh(u_params, loss_data):
    """Loss for fitting the diffstar model to one galaxy's star formation history.

    Parameters
    ----------
    u_params : array
        Unbounded diffstar parameters, ordered like the fields of
        ``DEFAULT_DIFFSTAR_U_PARAMS``.
    loss_data : tuple
        The ``loss_data`` tuple built by ``get_loss_data_default``.

    Returns
    -------
    loss : scalar
        The sum of five terms:

        * 30 x the weighted mean squared residual in log10 stellar mass.
        * The weighted mean squared residual in log10 fstar.
        * The weighted mean squared SFR residual. It is scaled by
          1e8 / target stellar mass and clipped to [-1, 1].
        * 0.5 x the squared log10 fstar residual at the last snapshot.
        * A sigmoid ridge penalty on ``lg_qt - lgt_fstar_max``.

        Residuals are divided by the per-snapshot weights, so a larger weight
        down-weights a snapshot.
    """
    (
        t_table,
        mah_params,
        sm_target,
        log_sm_target,
        sfh_target,
        log_fstar_target,
        fstar_tdelay,
        ssfrh_floor,
        weight,
        weight_fstar,
        lgt_fstar_max,
        lgt0,
        fb,
    ) = loss_data

    diffstar_params = get_bounded_diffstar_params(
        DEFAULT_DIFFSTAR_U_PARAMS._make(u_params)
    )
    sfh_model, smh_model = calc_sfh_singlegal(
        diffstar_params, mah_params, t_table, lgt0=lgt0, fb=fb, return_smh=True
    )

    # Shared floor for the model fstar and SFH: one tenth of the sSFR floor,
    # times the model stellar mass. This keeps log10(fstar) finite.
    sfr_floor = smh_model * (ssfrh_floor / 10.0)
    fstar_model = jnp.clip(
        compute_fstar(t_table, smh_model, fstar_tdelay), sfr_floor, jnp.inf
    )
    sfh_model = jnp.clip(sfh_model, sfr_floor, jnp.inf)

    dlog_smh = jnp.log10(smh_model) - log_sm_target
    dlog_fstar = jnp.log10(fstar_model) - log_fstar_target
    sfr_resid = jnp.clip(1e8 * (sfh_model - sfh_target) / sm_target, -1.0, 1.0)

    loss = 30 * jnp.mean((dlog_smh / weight) ** 2)
    loss += jnp.mean((dlog_fstar / weight_fstar) ** 2)
    loss += jnp.mean((sfr_resid / weight) ** 2)
    loss += 0.5 * dlog_fstar[-1] ** 2

    # Ridge term on the quenching time relative to the time of peak fstar.
    # NOTE(review): assumes _sigmoid(x, x0, k, ymin, ymax). If so, the penalty
    # is ~100 when lg_qt is well below lgt_fstar_max and ~0 well above it.
    loss += _sigmoid(diffstar_params.lg_qt - lgt_fstar_max, 0.0, 50.0, 100.0, 0.0)
    return loss


# Jitted gradient of the loss with respect to its first argument (the
# unbounded parameters); loss_data is treated as a constant.
loss_grad_default_clipssfrh = jjit(grad(loss_default_clipssfrh, argnums=0))


def loss_grad_default_clipssfrh_np(params, data):
    """Evaluate ``loss_grad_default_clipssfrh`` and return a float64 NumPy array.

    The minimizer is handed NumPy float64 values rather than JAX arrays.
    """
    return np.array(loss_grad_default_clipssfrh(params, data), dtype=float)


def get_header():
    """Return the header line and column names of the fit-results table.

    Returns
    -------
    header_str : str
        ``"# "`` followed by the space-separated column names and a newline.
    colnames : list of str
        ``"halo_id"``, then the main-sequence and the quenching parameter
        names, then ``"loss"`` and ``"success"``.
    """
    colnames = [
        "halo_id",
        *DEFAULT_MS_PARAMS._fields,
        *DEFAULT_Q_PARAMS._fields,
        "loss",
        "success",
    ]
    header_str = f"# {' '.join(colnames)}\n"
    return header_str, colnames


def get_outline(halo_id, p_best, loss_best, success):
    """Format one row of fit results as a newline-terminated line of text.

    The row is ``halo_id``, then every entry of ``p_best`` (cast to float32)
    and ``loss_best`` in ``{:.5e}`` notation, then ``success``. Fields are
    separated by single spaces.
    """
    params_f4 = np.asarray(p_best).astype("f4")
    numbers = [*params_f4, float(loss_best)]
    fields = [str(halo_id), *("{:.5e}".format(val) for val in numbers), str(success)]
    return " ".join(fields) + "\n"


def get_outline_nofit(halo_id):
    """Format a results row for a galaxy that has no diffstar fit.

    Every main-sequence and quenching parameter is written as ``NOFIT_FILL``,
    the loss as -1.0 and the success flag as -1. Numbers use ``{:.5e}``
    notation, fields are space-separated, and the line ends with a newline.
    """
    n_params = len(DEFAULT_MS_PARAMS) + len(DEFAULT_Q_PARAMS)
    loss_fill = -1.0
    success_fill = -1

    fill_values = np.full(n_params, NOFIT_FILL, dtype=float)
    numbers = [*fill_values, float(loss_fill)]
    fields = [str(halo_id)]
    fields.extend("{:.5e}".format(val) for val in numbers)
    fields.append(str(success_fill))
    return " ".join(fields) + "\n"


def write_collated_data(outname, data, colnames):
    """Write a table of fit results to an HDF5 file, one dataset per column.

    Parameters
    ----------
    outname : str or path-like
        Output HDF5 filename. An existing file is overwritten (mode ``"w"``).
    data : array-like, shape (n_rows, n_cols)
        Collated fit results. Anything ``np.asarray`` accepts works, e.g. a
        list of rows.
    colnames : sequence of str
        Dataset name for each column of ``data``. The ``"halo_id"`` and
        ``"success"`` columns are stored as integers; all others as floats.

    Raises
    ------
    ValueError
        If ``data`` is not 2-D or its column count differs from
        ``len(colnames)``.
    """
    data = np.asarray(data)
    # Use an explicit check rather than ``assert``, which ``python -O`` strips.
    # A silent mismatch would write misaligned or truncated columns.
    if data.ndim != 2 or data.shape[1] != len(colnames):
        raise ValueError("data mismatched with header")

    int_columns = ("halo_id", "success")
    with h5py.File(outname, "w") as hdf:
        for i, name in enumerate(colnames):
            dtype = int if name in int_columns else float
            hdf[name] = data[:, i].astype(dtype)
