"""Helper functions for Wuhan Covid-19 analysis using gemlib"""

import jax
import jax.numpy as jnp
import numpy as np

# Type aliases used in the public function signatures of this module.
JaxArray = jax.Array  # array type produced by the helpers
ArrayLike = jax.typing.ArrayLike  # JAX's union of array-convertible inputs

# Names exported by ``from <module> import *``.
__all__ = [
    "expected_observed_cases",
    "make_initial_state",
]


def make_initial_state(
    num_initial_infections: float, popsize: ArrayLike, initial_index: ArrayLike
) -> JaxArray:
    """Construct an initial state tensor for a SEIR state transition model

    This function returns an Array of shape ``(M, 4)``, where ``M`` is the
    number of metapopulations and equals ``popsize.shape[-1]``.

    Each metapopulation ``i`` is assigned ``S[i] = popsize[i]`` and
    ``E[i] = I[i] = R[i] = 0``. The exception is the ``initial_index``th
    metapopulation, which is assigned
    ``S[i] = popsize[i] - num_initial_infections``, ``E[i] = 0``,
    ``I[i] = num_initial_infections`` and ``R[i] = 0``.

    Args:
      num_initial_infections: the number of initial Infected individuals
        in the ``initial_index``th metapopulation.
      popsize: an ``ArrayLike`` of shape ``(M,)`` containing the population
        sizes for each metapopulation. Plain Python sequences are accepted.
      initial_index: the index of the metapopulation containing the
        initial infected individuals. Following ``jax.nn.one_hot``, an
        index outside ``[0, M)`` seeds no infections.

    Returns:
      an ``Array`` of shape ``(M, 4)`` representing ``M`` metapopulations,
      each with ``4`` state compartments (S, E, I, R) for the SEIR model.
      The dtype is the promotion of ``popsize``'s dtype with ``float32``.
    """
    # Convert up front so that inputs without a ``.shape`` attribute
    # (e.g. Python lists), which the ``ArrayLike`` annotation admits,
    # are supported.
    popsize = jnp.asarray(popsize)
    num_metapops = popsize.shape[-1]

    # A one-hot vector scaled by the seed count places all initial
    # infections in a single metapopulation.
    initial_infections = num_initial_infections * jax.nn.one_hot(
        initial_index, num_metapops, dtype=np.float32
    )

    zeros = jnp.zeros_like(popsize)
    return jnp.stack(
        [
            popsize - initial_infections,  # S
            zeros,  # E
            initial_infections,  # I
            zeros,  # R
        ],
        axis=-1,
    )


def expected_observed_cases(
    states: ArrayLike, aggregate_up_to: int = 11
) -> JaxArray:
    """Return a partially-aggregated timeseries of R-compartment increments

    Given a ``[T, M, 4]`` array representing an SEIR epidemic with ``T``
    timepoints, ``M`` metapopulations and 4 states, compute the per-step
    increments of the R compartment, ``R[t+1] - R[t]`` (shape ``(T-1, M)``).
    The first ``aggregate_up_to`` increments are then collapsed into a
    single leading row. The result has shape ``(T-aggregate_up_to, M)``:

    * row 0 is the sum of increments in the closed-open interval
      ``[0, aggregate_up_to)``, i.e. ``R[aggregate_up_to] - R[0]``;
    * row ``k >= 1`` is the single increment
      ``R[aggregate_up_to + k] - R[aggregate_up_to + k - 1]``.

    Args:
      states: an ``ArrayLike`` of shape ``[T, M, 4]`` of SEIR states for each
        of ``M`` metapopulations for timepoints ``[0, T)``.
      aggregate_up_to: number of leading increments, in the closed-open
        interval ``[0, aggregate_up_to)``, summed into the first output row.
        Expected to satisfy ``0 <= aggregate_up_to <= T - 1``. Otherwise,
        Python slicing rules apply and the documented output shape does not
        hold.

    Returns:
      a ``JaxArray`` of shape ``(T-aggregate_up_to, M)`` containing
      increments of the R state compartment.
    """
    r_state = jnp.asarray(states)[..., 3]
    incr = r_state[1:] - r_state[:-1]
    head = incr[:aggregate_up_to].sum(axis=0, keepdims=True)
    # Resume exactly where the aggregated window ends so that every
    # increment appears once, whatever the value of ``aggregate_up_to``.
    tail = incr[aggregate_up_to:]
    return jnp.concatenate([head, tail], axis=0)
