"""Right-censored event time proposal mechanism"""

import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from gemlib.distributions.uniform_integer import UniformInteger

# Short alias for the TFP distributions namespace, plus the `Root` wrapper
# used when yielding independent components from a coroutine joint
# distribution.
tfd = tfp.distributions
Root = tfp.distributions.JointDistribution.Root

# Loose type alias for array-like arguments; the functions below convert
# their inputs with `jnp.asarray` before use.
Tensor = np.typing.NDArray


def _slice_min(state_tensor, start):
    """Compute `min(state_tensor[..., start:])` in an XLA-safe way.

    Slicing with a traced `start` would give a data-dependent shape, which
    XLA cannot compile. Instead, positions before `start` are masked with
    `+inf` and the minimum is taken over the whole last axis.

    Args
    ----
    state_tensor: a tensor; the minimum is taken over its last axis. Any
        leading axes are kept as batch dimensions (a 1-D input yields a
        scalar).
    start: a scalar index into the last axis of `state_tensor`.

    Return
    ------
    `min(state_tensor[..., start:])`, reduced over the last axis. If the
    slice is empty (`start >= state_tensor.shape[-1]`), the result is `+inf`,
    the identity for `min`. Because the mask value is `np.inf`, integer
    inputs are promoted to a floating-point result.
    """
    state_tensor = jnp.asarray(state_tensor)

    # True for every position that lies before `start` along the last axis.
    before_start = jnp.arange(state_tensor.shape[-1]) < start
    masked_state_tensor = jnp.where(before_start, np.inf, state_tensor)

    # Reduce over the last axis only so that leading batch axes survive;
    # identical to a full reduction for 1-D input.
    return jnp.min(masked_state_tensor, axis=-1)


def add_occult_proposal(
    count_max: int,
    events: Tensor,
    src_state: Tensor,
    name=None,
):
    """Build a proposal distribution for adding occult (right-censored) events.

    The returned joint distribution has three components, drawn in order:

    * ``unit``: ``UniformInteger(low=0, high=num_units)``;
    * ``timepoint``: ``UniformInteger(low=0, high=num_times)``;
    * ``event_count``: ``UniformInteger(low=min(1, bound), high=bound + 1)``,
      where ``bound = min(count_max, min(src_state[..., unit][timepoint:]))``
      cast to ``int32``.

    Args
    ----
    count_max: upper limit on the number of events added in one proposal.
    events: array of events; ``events.shape[-2]`` is used as the number of
        timepoints and ``events.shape[-1]`` as the number of units. Its dtype
        is passed as ``float_dtype`` to every ``UniformInteger`` component.
    src_state: array indexed as ``src_state[..., unit]`` and then reduced from
        ``timepoint`` onwards to bound the number of events added.
    name: optional name for the returned distribution.

    Return
    ------
    A ``tfd.JointDistributionCoroutineAutoBatched`` over
    ``(unit, timepoint, event_count)``.
    """
    events = jnp.asarray(events)
    src_state = jnp.asarray(src_state)

    num_times, num_units = events.shape[-2], events.shape[-1]
    float_dtype = events.dtype

    def proposal():
        # Neither the unit nor the timepoint depends on earlier draws.
        chosen_unit = yield Root(
            UniformInteger(
                low=0, high=num_units, float_dtype=float_dtype, name="unit"
            )
        )
        chosen_time = yield Root(
            UniformInteger(
                low=0, high=num_times, float_dtype=float_dtype, name="timepoint"
            )
        )

        # Cap the count by the smallest source-state value from `chosen_time`
        # onwards, and by `count_max`.
        unit_src_state = src_state[..., chosen_unit]
        max_count = jnp.minimum(
            _slice_min(unit_src_state, chosen_time), count_max
        ).astype(np.int32)

        yield UniformInteger(
            low=jnp.minimum(1, max_count),
            high=max_count + 1,
            float_dtype=float_dtype,
            name="event_count",
        )

    return tfd.JointDistributionCoroutineAutoBatched(proposal, name=name)


def _uniform_probs_over_mask(mask):
    """Return probabilities that are uniform over the `True` entries of `mask`.

    The mask is divided by its count of `True` entries along the last axis
    (with `keepdims=True`, so any leading axes normalise independently). A row
    with no `True` entries gives `0 / 0`, i.e. NaN probabilities.
    """
    return mask / jnp.sum(mask, axis=-1, keepdims=True)


def del_occult_proposal(
    count_max: int,
    events: Tensor,
    dest_state: Tensor,
    name=None,
):
    """Build a proposal distribution for deleting occult (right-censored) events.

    The returned joint distribution has three components, drawn in order:

    * ``unit``: categorical, uniform over units that have at least one
      positive entry in ``events`` along the time axis (axis -2);
    * ``timepoint``: categorical, uniform over the timepoints at which the
      chosen unit has a positive entry in ``events``;
    * ``event_count``: ``UniformInteger(low=min(1, bound), high=bound + 1)``,
      where ``bound`` is the minimum of ``count_max``,
      ``events[..., timepoint, unit]`` and
      ``min(dest_state[..., unit][timepoint + 1:])``, cast to ``int32``.

    Args
    ----
    count_max: upper limit on the number of events deleted in one proposal.
    events: array of events; axis -2 is treated as time and axis -1 as unit.
        Its dtype is passed as ``float_dtype`` to ``event_count``.
    dest_state: array indexed as ``dest_state[..., unit]`` and then reduced
        from ``timepoint + 1`` onwards to bound the number of events deleted.
    name: optional name for the returned distribution.

    Return
    ------
    A ``tfd.JointDistributionCoroutineAutoBatched`` over
    ``(unit, timepoint, event_count)``.
    """
    events = jnp.asarray(events)
    dest_state = jnp.asarray(dest_state)

    def proposal():
        # Select a unit uniformly among those with at least one event.
        nonzero_units = jnp.any(events > 0, axis=-2)
        unit = yield Root(
            tfd.Categorical(
                probs=_uniform_probs_over_mask(nonzero_units),
                name="unit",
            )
        )
        # If no unit has any events the probabilities above are NaN, and the
        # draw is not guaranteed to be a valid unit index. Clip it so the
        # indexing below cannot go out of bounds.
        unit = jnp.clip(unit, min=0, max=events.shape[-1] - 1)

        # Select a timepoint uniformly among those at which `unit` has events.
        unit_events = events[..., unit]  # T
        timepoint = yield tfd.Categorical(
            probs=_uniform_probs_over_mask(unit_events > 0),
            name="timepoint",
        )

        # Clip for the same reason as above when `unit` has no events.
        timepoint = jnp.clip(timepoint, min=0, max=events.shape[-2] - 1)

        # Draw the number to delete. It is bounded by the events present at
        # (timepoint, unit), by `count_max`, and by the minimum value of the
        # destination state over [timepoint + 1, num_times). An empty range
        # yields +inf from `_slice_min`, so it does not constrain the bound.
        unit_dest_state = dest_state[..., unit]
        state_bound = _slice_min(unit_dest_state, timepoint + 1)
        bound = jnp.minimum(state_bound, events[..., timepoint, unit])
        bound = jnp.minimum(bound, count_max).astype(np.int32)

        yield UniformInteger(
            low=jnp.minimum(1, bound),
            high=bound + 1,
            float_dtype=events.dtype,
            name="event_count",
        )

    return tfd.JointDistributionCoroutineAutoBatched(proposal, name=name)
