DiscreteTimeStateTransitionModel

class gemlib.distributions.DiscreteTimeStateTransitionModel(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_step=0, time_delta=1.0, validate_args=False, allow_nan_stats=True, name='DiscreteTimeStateTransitionModel')

Discrete-time state transition model

A discrete-time state transition model assumes that a population of individuals is divided into a number of mutually exclusive states, and that transitions between states occur according to a Markov process. Such models are common in epidemiological and ecological applications, where rapid implementation and modification are necessary.
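To make the Markov assumption concrete: if every individual moves independently between the S states according to an S×S one-step transition probability matrix P (each row summing to one), the expected state counts evolve as counts_{t+1} = counts_t P. The NumPy sketch below uses a hypothetical, hand-picked P for an SIR model; it illustrates the idea only and does not call gemlib.

```python
import numpy as np

# Hypothetical one-step transition probability matrix for states (S, I, R).
# Row i gives the probability that an individual currently in state i is in
# each state one time step later, so every row sums to one.
P = np.array(
    [
        [0.97, 0.03, 0.00],  # S stays S, or moves to I
        [0.00, 0.87, 0.13],  # I stays I, or moves to R
        [0.00, 0.00, 1.00],  # R is absorbing
    ]
)
assert np.allclose(P.sum(axis=1), 1.0)

counts = np.array([99.0, 1.0, 0.0])  # initial counts in (S, I, R)

# Propagate expected counts forward. By the Markov property, the next
# distribution depends only on the current one.
trajectory = [counts]
for _ in range(10):
    counts = counts @ P
    trajectory.append(counts)
trajectory = np.stack(trajectory)  # shape (11, 3)

print(trajectory[-1])
print(trajectory.sum(axis=1))  # total population is conserved at 100
```

Here P is fixed, so the S to I probability ignores how many individuals are infective. In the SIR model below, that probability depends on the current state, so P changes from step to step; this state dependence is what the transition_rate_fn callables supply.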

This class provides a programmable implementation of the discrete-time state transition model, compatible with TensorFlow Probability.

Example

A homogeneously mixing SIR model implementation:

import jax.numpy as jnp
import numpy as np
from gemlib.distributions import DiscreteTimeStateTransitionModel

# Initial state, counts per compartment (S, I, R), for one
#   population
initial_state = jnp.array([[99, 1, 0]], jnp.float32)

# Note that the incidence_matrix is treated as a static parameter,
# and so is supplied as a numpy ndarray, not a jax Array.
incidence_matrix = np.array(
    [
        [-1, 0],
        [1, -1],
        [0, 1],
    ],
    dtype=np.float32,
)

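# Rate of the S -> I transition in each population, shape (N,)
def si_rate(t, state):
    return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


# Constant rate of the I -> R transition, broadcast to shape (N,)
def ir_rate(t, state):
    return jnp.full(state.shape[:-1], 0.14)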


# Instantiate model
sir = DiscreteTimeStateTransitionModel(
    transition_rate_fn=(si_rate, ir_rate),
    incidence_matrix=incidence_matrix,
    initial_state=initial_state,
    num_steps=100,
)

# One realisation of the epidemic process
sim = sir.sample(seed=0)
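As a quick check, the rate functions from the example can be evaluated directly on the initial state. The sketch below re-implements si_rate and ir_rate with NumPy in place of jax.numpy (an assumption made so that it runs without JAX or gemlib). With one infective among 100 individuals, the S to I rate is 0.28 × 1/100 = 0.0028 per unit time.

```python
import numpy as np


# NumPy stand-ins for the si_rate / ir_rate callables in the example.
def si_rate(t, state):
    # S -> I rate: 0.28 * I / (S + I + R), one value per population.
    return 0.28 * state[:, 1] / np.sum(state, axis=-1)


def ir_rate(t, state):
    # Constant I -> R rate, broadcast to one value per population.
    return np.full(state.shape[0], 0.14)


initial_state = np.array([[99, 1, 0]], np.float32)  # shape (N=1, S=3)

si = si_rate(0.0, initial_state)
ir = ir_rate(0.0, initial_state)
print(si.shape, ir.shape)  # (1,) (1,)
```

Both callables return one rate per population, matching the (N,) shape described for transition_rate_fn.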

__init__(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_step=0, time_delta=1.0, validate_args=False, allow_nan_stats=True, name='DiscreteTimeStateTransitionModel')

Initialise a discrete-time state transition model.

Parameters:
  • transition_rate_fn (Sequence[Callable[[float, ArrayLike], ArrayLike]] | Callable[[float, ArrayLike], tuple[ArrayLike, ...]]) – Either a sequence of R callables (t: float, state: ArrayLike) -> ArrayLike, each returning an Array of rates of shape (N,), or a single callable (t: float, state: ArrayLike) -> tuple[ArrayLike, ...] returning a tuple of R vectors of shape (N,). Here N is the number of units in the simulation. In the first (preferred) form, each callable returns the rate of the corresponding transition. In the second form, the single callable returns all R transition rates. Note: the second form will be deprecated in a future release of gemlib.

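  • incidence_matrix (numpy.ndarray) – incidence matrix of shape (S, R) for S states and R transitions between states. Column r gives the change in each state's count caused by one occurrence of transition r. The matrix is treated as a static parameter, so it must be a NumPy ndarray rather than a JAX Array.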

  • initial_state (ArrayLike) – Initial state of the model, of shape (N, S), containing the counts in each of N units and S states. We require initial_state.shape[-1] == incidence_matrix.shape[-2].

  • num_steps (int) – the number of time steps simulated by the model.

  • initial_step (int) – time t at the start of the simulation.

  • time_delta (float) – the duration of the discretized time step.
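Because the single-callable form of transition_rate_fn is slated for deprecation, existing code may need converting to the sequence form. A minimal sketch of one way to do this is below. It uses a hypothetical helper, split_rate_fn (not part of gemlib), that wraps the combined callable so each element returns one of its R outputs, and it uses NumPy so that it runs without JAX.

```python
from typing import Callable, Sequence

import numpy as np

RateFn = Callable[[float, np.ndarray], np.ndarray]


def split_rate_fn(combined_fn, num_transitions: int) -> Sequence[RateFn]:
    """Hypothetical helper: turn a callable returning a tuple of R rate
    vectors into a tuple of R callables, one per transition."""

    def make(r: int) -> RateFn:
        # Bind r as an argument; a bare lambda inside the loop would
        # late-bind and every callable would see the final value of r.
        return lambda t, state: combined_fn(t, state)[r]

    return tuple(make(r) for r in range(num_transitions))


def sir_rates(t, state):
    """Old-style combined rate function for an SIR model."""
    si = 0.28 * state[:, 1] / np.sum(state, axis=-1)
    ir = np.full(state.shape[0], 0.14)
    return si, ir


rate_fns = split_rate_fn(sir_rates, num_transitions=2)
state = np.array([[99.0, 1.0, 0.0], [50.0, 50.0, 0.0]])  # N=2 units
rates = [fn(0.0, state) for fn in rate_fns]
```

This wrapper re-evaluates the combined function once per transition; writing the per-transition callables directly avoids that repeated work.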

Methods

batch_shape_tensor(name='batch_shape_tensor')

Shape of a single sample from a single event index as a 1-D Tensor.

The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.

Parameters:

name – name to give to the op

Returns:

batch_shape, a 1-D Tensor.

Return type:

Tensor

compute_state(events, include_final_state=False)

Computes a state timeseries from a sequence of transition events.

Parameters:
  • events (Array) – an array of events of shape (T, N, R), where T is the number of timesteps, N=self.num_units the number of units and R=self.incidence_matrix.shape[1] the number of transitions.

  • include_final_state (bool) – If False (default) the result does not include the final state. Otherwise the result includes the final state.

Returns:

An array of shape (T, N, S) if include_final_state==False, or (T+1, N, S) if include_final_state==True, where S=self.num_states, giving the number of individuals in each state at each time point for each unit.

Return type:

Array
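The relationship between events and states can be sketched as follows, assuming (consistently with the SIR example) that column r of the incidence matrix holds the change in each state's count caused by one occurrence of transition r. The state after step t is then the initial state plus the cumulative sum of events @ incidence_matrix.T. compute_state_sketch is a hypothetical NumPy re-implementation for illustration, not gemlib's own code.

```python
import numpy as np


def compute_state_sketch(initial_state, events, incidence_matrix,
                         include_final_state=False):
    """initial_state: (N, S); events: (T, N, R); incidence_matrix: (S, R)."""
    # Net change in each state for every timestep and unit: (T, N, S).
    deltas = events @ incidence_matrix.T
    cumulative = np.cumsum(deltas, axis=0)
    # State at the start of each timestep, followed by the final state.
    states = np.concatenate(
        [initial_state[np.newaxis], initial_state[np.newaxis] + cumulative],
        axis=0,
    )  # (T + 1, N, S)
    return states if include_final_state else states[:-1]


incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)
initial_state = np.array([[99, 1, 0]], dtype=np.float32)
# Two timesteps of (S->I, I->R) event counts for a single unit.
events = np.array([[[2, 0]], [[3, 1]]], dtype=np.float32)  # (T=2, N=1, R=2)

states = compute_state_sketch(initial_state, events, incidence_matrix)
full = compute_state_sketch(initial_state, events, incidence_matrix,
                            include_final_state=True)
```

Each column of this incidence matrix sums to zero, so the total population in each unit is conserved across the whole trajectory.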

event_shape_tensor(name='event_shape_tensor')

Shape of a single sample from a single batch as a 1-D int32 Tensor.

Parameters:

name – name to give to the op

Returns:

event_shape, a 1-D int32 Tensor.

Return type:

Tensor

log_prob(value, name='log_prob', **kwargs)

Log probability density/mass function.

Parameters:
  • value – float or double Tensor.

  • name – Python str prepended to names of ops created by this function.

  • **kwargs – Named arguments forwarded to subclass implementation.

Returns:

log_prob, a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.

Return type:

Tensor
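For intuition about what log_prob evaluates, here is a sketch of a chain-binomial log-likelihood for a single-unit SIR model. It assumes that, within a timestep, each susceptible is infected independently with probability 1 - exp(-beta * I / total * time_delta) and each infective recovers with probability 1 - exp(-gamma * time_delta), so the event counts are binomial, and the log-probability of the event series is the sum over timesteps. This is a simplified illustration with hypothetical helpers (binom_logpmf, sir_log_prob); gemlib's exact likelihood, for example how competing transitions out of one state are handled, may differ.

```python
import math


def binom_logpmf(k, n, p):
    """log C(n, k) + k log p + (n - k) log(1 - p); assumes 0 < p < 1."""
    log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_choose + k * math.log(p) + (n - k) * math.log1p(-p)


def sir_log_prob(events, initial_state, beta, gamma, time_delta=1.0):
    """events: sequence of (S->I, I->R) counts; initial_state: (S, I, R)."""
    s, i, r = initial_state
    total = s + i + r
    log_p = 0.0
    for si_events, ir_events in events:
        # Convert rates to per-step probabilities over one time_delta.
        p_si = 1.0 - math.exp(-beta * i / total * time_delta)
        p_ir = 1.0 - math.exp(-gamma * time_delta)
        log_p += binom_logpmf(si_events, s, p_si)
        log_p += binom_logpmf(ir_events, i, p_ir)
        # Advance the state using this step's events.
        s, i, r = s - si_events, i + si_events - ir_events, r + ir_events
    return log_p


events = [(2, 0), (3, 1)]
lp = sir_log_prob(events, (99, 1, 0), beta=0.28, gamma=0.14)
```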

prob(value, name='prob', **kwargs)

Probability density/mass function.

Parameters:
  • value – float or double Tensor.

  • name – Python str prepended to names of ops created by this function.

  • **kwargs – Named arguments forwarded to subclass implementation.

Returns:

prob, a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.

Return type:

Tensor
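prob is the exponential of log_prob. For a state transition model, the probability of a full event history is a product of many per-step terms, so it can underflow to zero in floating point even when log_prob is perfectly well behaved; working with log_prob is usually the safer choice. The illustration below uses made-up per-step probabilities, not output from gemlib.

```python
import numpy as np

# Hypothetical per-step probabilities for 1000 timesteps (illustrative only).
step_probs = np.full(1000, 0.3, dtype=np.float64)

log_prob = np.sum(np.log(step_probs))  # about -1204, easily representable
prob = np.exp(log_prob)                # underflows to exactly 0.0 in float64
naive = np.prod(step_probs)            # the direct product underflows too
```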

sample(sample_shape=(), seed=None, name='sample', **kwargs)

Generate samples of the specified shape.

Note that a call to sample() without arguments will generate a single sample.

Additional documentation from DiscreteTimeStateTransitionModel:

Runs n simulations of the epidemic model.

Args:

  • n: number of simulations to run.

  • seed: an integer or a JAX PRNG key.

Returns:

an array of shape (n, T, N, R) for n samples, T=self.num_steps time steps, N=self.num_units units and R=self.incidence_matrix.shape[1] transitions, containing the number of events for each timestep, unit and transition.

Parameters:
  • sample_shape – 0D or 1D int32 Tensor. Shape of the generated samples.

  • seed – PRNG seed; see tfp.random.sanitize_seed for details.

  • name – name to give to the op.

  • **kwargs – Named arguments forwarded to subclass implementation.

Returns:

samples, a Tensor with prepended dimensions sample_shape.

Return type:

Tensor
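To illustrate the layout of what sample returns, the following sketch simulates a chain-binomial SIR epidemic in NumPy and stacks the per-step event counts into an array of shape (n, T, N, R). It assumes each individual leaves its current state within a step with probability 1 - exp(-rate * time_delta), with event counts drawn binomially. simulate_sir is a hypothetical helper, and gemlib's own sampler (which runs under JAX and accepts an integer seed or PRNG key) may differ in detail.

```python
import numpy as np


def simulate_sir(rng, initial_state, num_steps, beta=0.28, gamma=0.14,
                 time_delta=1.0):
    """Return event counts of shape (T, N, R) with R = 2 (S->I, I->R)."""
    state = np.array(initial_state, dtype=np.int64)  # (N, 3)
    events = np.zeros((num_steps, state.shape[0], 2), dtype=np.int64)
    for t in range(num_steps):
        s, i = state[:, 0], state[:, 1]
        total = state.sum(axis=-1)
        # Per-individual probability of each transition within one step.
        p_si = 1.0 - np.exp(-beta * i / total * time_delta)
        p_ir = 1.0 - np.exp(-gamma * time_delta)
        si = rng.binomial(s, p_si)
        ir = rng.binomial(i, p_ir)
        events[t, :, 0], events[t, :, 1] = si, ir
        # Apply the incidence matrix columns (-1, 1, 0) and (0, -1, 1).
        state = state + np.stack([-si, si - ir, ir], axis=-1)
    return events


rng = np.random.default_rng(0)
initial_state = [[99, 1, 0]]
n = 5
samples = np.stack([simulate_sir(rng, initial_state, num_steps=100)
                    for _ in range(n)])  # (n, T, N, R) = (5, 100, 1, 2)
```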

transition_prob_matrix(events)

Compute the Markov transition probability matrix.

Parameters:

events (None | ArrayLike) – None, or an array of shape (T, N, R), where T is the number of timesteps, N=self.num_units the number of units and R=self.incidence_matrix.shape[1] the number of transitions.

Returns:

Transition probability matrix. If events is None, an array of shape (N, S, S) giving the transition probability matrix associated with the initial state self.initial_state (of shape (N, S)), where S=self.num_states is the number of distinct states. Otherwise, an array of shape (T, N, S, S) giving the transition probability matrix at each timestep.

Return type:

Array
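One way to build a per-unit (N, S, S) transition probability matrix from transition rates is sketched below. It assumes each transition r moves individuals from a single source state to a single destination state (read off the -1 and +1 entries of column r of the incidence matrix), and that competing transitions out of the same source state share the total exit probability 1 - exp(-total_rate * time_delta) in proportion to their rates. This is a common discretisation, not necessarily the exact scheme gemlib uses, and transition_prob_matrix_sketch is hypothetical.

```python
import numpy as np


def transition_prob_matrix_sketch(rates, incidence_matrix, time_delta=1.0):
    """rates: (N, R) per-individual rates; incidence_matrix: (S, R).

    Returns an (N, S, S) array whose row i gives the probabilities that an
    individual in state i is in each state at the end of one time step.
    """
    num_units = rates.shape[0]
    num_states = incidence_matrix.shape[0]
    source = np.argmin(incidence_matrix, axis=0)  # state with the -1 entry
    dest = np.argmax(incidence_matrix, axis=0)    # state with the +1 entry

    # Rate matrix: entry (i, j) accumulates the rates of transitions i -> j.
    q = np.zeros((num_units, num_states, num_states))
    for r in range(incidence_matrix.shape[1]):
        q[:, source[r], dest[r]] += rates[:, r]

    total = q.sum(axis=-1, keepdims=True)  # (N, S, 1) total exit rate
    p_leave = 1.0 - np.exp(-total * time_delta)
    # Split the exit probability in proportion to each competing rate;
    # states with no exits (total == 0) get zero off-diagonal mass.
    with np.errstate(invalid="ignore", divide="ignore"):
        probs = np.where(total > 0, p_leave * q / total, 0.0)
    # The diagonal holds the probability of staying in the same state.
    idx = np.arange(num_states)
    probs[:, idx, idx] = 1.0 - p_leave[..., 0]
    return probs


incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)
rates = np.array([[0.0028, 0.14]])  # (N=1, R=2): SIR rates at the initial state
P = transition_prob_matrix_sketch(rates, incidence_matrix)
```

Each row of P sums to one, and the absorbing R state keeps probability 1 on its diagonal.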

Attributes

allow_nan_stats

Python bool describing behavior when a stat is undefined.

batch_shape

Shape of a single sample from a single event index as a TensorShape.

dtype

The DType of Tensors handled by this Distribution.

event_shape

Shape of a single sample from a single batch as a TensorShape.

experimental_shard_axis_names

The list or structure of lists of active shard axis names.

incidence_matrix

initial_state

initial_step

name

Name prepended to all ops created by this Distribution.

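num_states

The number of distinct states, S.

num_steps

The number of time steps simulated by the model.

num_units

The number of units, N.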

parameters

Dictionary of parameters used to instantiate this Distribution.

reparameterization_type

Describes how samples from the distribution are reparameterized.

source_states

time_delta

trainable_variables

transition_rate_fn

validate_args

Python bool indicating possibly expensive checks are enabled.

variables