"""Sampler for discrete-space occult events"""

from collections.abc import Callable
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax.mcmc.internal import (
    util as mcmc_util,
)

from gemlib.distributions.discrete_markov import (
    compute_state,
)
from gemlib.mcmc.discrete_time_state_transition_model.right_censored_events_proposal import (  # noqa:E501
    add_occult_proposal,
    del_occult_proposal,
)
from gemlib.mcmc.sampling_algorithm import Position
from gemlib.prng_util import sanitize_key
from gemlib.util import transition_coords

# Shorthand for the TFP (JAX substrate) distributions namespace.
tfd = tfp.distributions
# Alias used in signatures for array-valued arguments.
Tensor = np.typing.NDArray

__all__ = ["UncalibratedOccultUpdate"]


# Probability of attempting a delete (rather than an add) move in
# ``UncalibratedOccultUpdate.one_step``; a delete is only attempted when the
# events inside the update window are not all zero, otherwise an add is used.
PROB_DIRECTION = 0.5


class OccultKernelResults(NamedTuple):
    """Kernel results produced by ``UncalibratedOccultUpdate``.

    Field order is part of the interface (positional construction and tuple
    unpacking), so fields must not be reordered.
    """

    # Log ratio log q(reverse move) - log q(forward move) of the proposal.
    log_acceptance_correction: jax.Array
    # Target log density evaluated at the proposed (or initial) events.
    target_log_prob: jax.Array
    # Unit index of the proposed move.
    unit: jax.Array
    # Timepoint of the proposed move, indexed relative to the start of the
    # kernel's ``t_range`` window (the proposal acts on the sliced events).
    timepoint: jax.Array
    # True if the move added events, False if it deleted them.
    is_add: jax.Array
    # Size of the change; a delete move subtracts this many events.
    event_count: jax.Array
    # PRNG key that was passed to ``one_step`` (presumably a JAX key array;
    # the bootstrap value comes from ``sanitize_key(0)`` -- confirm type).
    seed: jax.Array


def _is_row_nonzero(m):
    return jnp.sum(m, axis=-1) > 0.0


def _maybe_expand_dims(x):
    """If x is a scalar, give it at least 1 dimension"""
    return jnp.atleast_1d(x)


def _add_events(events, unit, timepoint, target_transition_id, event_count):
    """Adds `x_star` events to metapopulation `m`,
    time `t`, transition `x` in `events`.
    """
    events = jnp.asarray(events)
    return events.at[timepoint, unit, target_transition_id].add(event_count)


class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
    """Uncalibrated add/delete proposal for occult events.

    Each call to ``one_step`` proposes either adding events to, or deleting
    events from, transition ``target_transition_id`` of a ``[T, M, R]``
    events tensor, restricted to the time window ``slice(*t_range)``. The
    returned state is the *proposed* events tensor. Because ``is_calibrated``
    is False, the kernel is meant to be composed with an accept/reject kernel
    (e.g. ``tfp.mcmc.MetropolisHastings``) that consumes ``target_log_prob``
    and ``log_acceptance_correction`` from the results.
    """

    def __init__(
        self,
        target_log_prob_fn: Callable[[Position], float],
        incidence_matrix: Tensor,
        initial_conditions: Tensor,
        target_transition_id: int,
        count_max: int,
        t_range,
        name=None,
    ):
        """An uncalibrated add/delete proposal for occult events.

        :param target_log_prob_fn: the log density of the target distribution,
                                   called with the events tensor
        :param incidence_matrix: the state-transition incidence matrix, used
                                 to reconstruct the state from events and to
                                 look up the source/destination states of
                                 ``target_transition_id``
        :param initial_conditions: the initial state passed to
                                   ``compute_state``; its dtype is also used
                                   for the chain state in
                                   ``bootstrap_results``
        :param target_transition_id: the position in the last dimension of
                                     the events tensor that we wish to update
        :param count_max: forwarded as ``count_max`` to the add/delete occult
                          proposals (presumably the maximum number of events
                          moved per proposal -- confirm)
        :param t_range: unpacked into ``slice(*t_range)`` over the time axis,
                        e.g. ``(start, stop)`` with ``stop`` exclusive
        :param name: the name of the update step; defaults to
                     ``"uncalibrated_occult_update"``
        """
        # Record constructor arguments explicitly. ``dict(locals())`` would
        # also capture ``self``, which breaks ``TransitionKernel.copy()``
        # (it calls ``type(self)(**self.parameters)``) and creates a
        # reference cycle.
        self._parameters = {
            "target_log_prob_fn": target_log_prob_fn,
            "incidence_matrix": incidence_matrix,
            "initial_conditions": initial_conditions,
            "target_transition_id": target_transition_id,
            "count_max": count_max,
            "t_range": t_range,
            "name": name,
        }
        self._name = name or "uncalibrated_occult_update"
        self._dtype = jnp.asarray(initial_conditions).dtype

    @property
    def target_log_prob_fn(self):
        return self._parameters["target_log_prob_fn"]

    @property
    def incidence_matrix(self):
        return self._parameters["incidence_matrix"]

    @property
    def target_transition_id(self):
        return self._parameters["target_transition_id"]

    @property
    def initial_conditions(self):
        return self._parameters["initial_conditions"]

    @property
    def count_max(self):
        return self._parameters["count_max"]

    @property
    def t_range(self):
        return self._parameters["t_range"]

    @property
    def name(self):
        # Return the resolved name; ``self._parameters["name"]`` holds the raw
        # constructor argument, which is None when no name was supplied.
        return self._name

    @property
    def parameters(self):
        """Return `dict` of ``__init__`` arguments and their values."""
        return self._parameters

    @property
    def is_calibrated(self):
        return False

    def one_step(self, current_events, previous_kernel_results, seed):  # noqa: ARG002
        """Propose adding or deleting occult events within ``t_range``.

        :param current_events: a [T, M, R] tensor containing number of events
                               per time t, metapopulation m,
                               and transition r.
        :param previous_kernel_results: unused; present to satisfy the
                                        ``TransitionKernel`` interface.
        :param seed: a JAX PRNG key, split into a key for the proposal and a
                     key for choosing between an add and a delete move.
        :returns: a tuple ``(next_events, OccultKernelResults)`` where
                  ``next_events`` is the proposed (not yet accepted or
                  rejected) events tensor.
        """
        with jax.named_scope("occult_rw/onestep"):
            current_events = jnp.asarray(current_events)
            initial_conditions = jnp.asarray(self.initial_conditions)

            proposal_seed, add_del_seed = jax.random.split(seed)
            t_range_slice = slice(*self.t_range)

            state = compute_state(
                initial_conditions,
                current_events,
                self.incidence_matrix,
            )
            # (source, destination) state indices of the target transition.
            src_dest_ids = transition_coords(self.incidence_matrix)[
                self.target_transition_id, :
            ]

            # Pull out the section of events and state that are within
            # the requested time interval - we focus on this, and insert
            # updated values back into the full events at the end.
            range_events = current_events[t_range_slice]
            state_slice = state[t_range_slice]

            def add_occult_fn():
                with jax.named_scope("true_fn"):
                    # Forward move: add events, limited by the source state.
                    proposal = add_occult_proposal(
                        count_max=self.count_max,
                        events=range_events[..., self.target_transition_id],
                        src_state=state_slice[..., src_dest_ids[0]],
                    )
                    update = proposal.sample(seed=proposal_seed)

                    next_events = _add_events(
                        events=range_events,
                        unit=update.unit,
                        timepoint=update.timepoint,
                        target_transition_id=self.target_transition_id,
                        event_count=update.event_count,
                    )

                    # Reverse move is a delete from the proposed state, so
                    # rebuild the window's state from its first timepoint.
                    next_dest_state = compute_state(
                        state_slice[0], next_events, self.incidence_matrix
                    )
                    reverse = del_occult_proposal(
                        count_max=self.count_max,
                        events=next_events[..., self.target_transition_id],
                        dest_state=next_dest_state[..., src_dest_ids[1]],
                    )
                    q_fwd = jnp.sum(proposal.log_prob(update))
                    q_rev = jnp.sum(reverse.log_prob(update))
                    log_acceptance_correction = q_rev - q_fwd

                return (
                    update,
                    next_events,
                    log_acceptance_correction,
                    True,
                )

            def del_occult_fn():
                with jax.named_scope("false_fn"):
                    # Forward move: delete events, limited by the
                    # destination state.
                    proposal = del_occult_proposal(
                        count_max=self.count_max,
                        events=range_events[..., self.target_transition_id],
                        dest_state=state_slice[..., src_dest_ids[1]],
                    )
                    update = proposal.sample(seed=proposal_seed)

                    next_events = _add_events(
                        events=range_events,
                        unit=update.unit,
                        timepoint=update.timepoint,
                        target_transition_id=self.target_transition_id,
                        event_count=-update.event_count,
                    )

                    # Reverse move is an add from the proposed state.
                    next_src_state = compute_state(
                        state_slice[0], next_events, self.incidence_matrix
                    )
                    reverse = add_occult_proposal(
                        count_max=self.count_max,
                        events=next_events[..., self.target_transition_id],
                        src_state=next_src_state[..., src_dest_ids[0]],
                    )
                    q_fwd = jnp.sum(proposal.log_prob(update))
                    q_rev = jnp.sum(reverse.log_prob(update))
                    log_acceptance_correction = q_rev - q_fwd

                return (
                    update,
                    next_events,
                    log_acceptance_correction,
                    False,
                )

            # Attempt a delete with probability PROB_DIRECTION, but only if
            # the window holds any events; otherwise always add.
            # NOTE(review): the count covers every transition in
            # ``range_events``, not only ``target_transition_id``, and the
            # forced add on an empty window is not reflected in
            # ``log_acceptance_correction`` -- confirm this is intended.
            u = tfd.Uniform().sample(seed=add_del_seed)
            delta, next_range_events, log_acceptance_correction, is_add = (
                jax.lax.cond(
                    (u < PROB_DIRECTION)
                    & (jnp.count_nonzero(range_events) > 0),
                    del_occult_fn,
                    add_occult_fn,
                )
            )

            # Update current_events with the new next_range_events tensor
            next_events = current_events.at[t_range_slice].set(
                next_range_events
            )
            next_target_log_prob = self.target_log_prob_fn(next_events)

            return (
                next_events,
                OccultKernelResults(
                    log_acceptance_correction=log_acceptance_correction,
                    target_log_prob=next_target_log_prob,
                    unit=delta.unit,
                    timepoint=delta.timepoint,
                    is_add=is_add,
                    event_count=delta.event_count,
                    seed=seed,
                ),
            )

    def bootstrap_results(self, init_state):
        """Build initial ``OccultKernelResults`` for ``init_state``.

        The state is cast to the dtype of ``initial_conditions`` before the
        target log density is evaluated; the per-move fields are filled with
        zero / True placeholders.
        """
        with jax.named_scope("uncalibrated_event_times_rw/bootstrap_results"):
            if not mcmc_util.is_list_like(init_state):
                init_state = [init_state]

            init_state = [jnp.asarray(x, dtype=self._dtype) for x in init_state]
            init_target_log_prob = self.target_log_prob_fn(*init_state)
            return OccultKernelResults(
                log_acceptance_correction=jnp.asarray(
                    0.0, dtype=init_target_log_prob.dtype
                ),
                target_log_prob=init_target_log_prob,
                unit=jnp.zeros((), dtype=np.int32),
                timepoint=jnp.zeros((), dtype=np.int32),
                is_add=jnp.asarray(True),
                event_count=jnp.zeros((), dtype=np.int32),
                seed=sanitize_key(0),
            )
