ContinuousTimeStateTransitionModel

class gemlib.distributions.ContinuousTimeStateTransitionModel(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_time=0.0, validate_args=False, allow_nan_stats=True, name='ContinuousTimeStateTransitionModel')

Continuous time state transition model.

Example

A homogeneously mixing SIR model implementation:

import jax.numpy as jnp
import numpy as np
from gemlib.distributions import ContinuousTimeStateTransitionModel

# Initial state, counts per compartment (S, I, R), for one
#   population
initial_state = jnp.array([[99, 1, 0]], jnp.float32)

# Note that the incidence_matrix is treated as a static parameter,
# and so is supplied as a numpy ndarray, not a jax Array.
incidence_matrix = np.array(
    [
        [-1, 0],
        [1, -1],
        [0, 1],
    ],
    dtype=np.float32,
)


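# S -> I transition rate: 0.28 * I / N, one value per population
def si_rate(t, state):
    return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


# I -> R transition rate: constant 0.14
def ir_rate(t, state):
    return 0.14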


# Instantiate model
sir = ContinuousTimeStateTransitionModel(
    transition_rate_fn=(si_rate, ir_rate),
    incidence_matrix=incidence_matrix,
    initial_state=initial_state,
    num_steps=100,
)

# One realisation of the epidemic process
sim = sir.sample(seed=0)

# Compute the log probability of observing `sim`
sir.log_prob(sim)

__init__(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_time=0.0, validate_args=False, allow_nan_stats=True, name='ContinuousTimeStateTransitionModel')

Initializes a ContinuousTimeStateTransitionModel object.

Parameters:
  • transition_rate_fn (Sequence[Callable[[float, ArrayLike], ArrayLike]] | Callable[[float, ArrayLike], tuple[ArrayLike, ...]]) – Either a sequence of callables, each of the form fn(t: float, state: Tensor) -> Tensor, or a single Python callable of the form fn(t: float, state: Tensor) -> tuple(Tensor, ...). In the first (preferred) form, each callable in the sequence corresponds to the respective transition in incidence_matrix. In the second form, the callable should return a tuple of transition rate tensors, one for each transition in incidence_matrix. Note: the second form will be deprecated in future releases of gemlib.

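  • incidence_matrix (ndarray) – Matrix representing the incidence of transitions between states. In the example above it has one row per state and one column per transition, with -1 marking the source state and 1 the destination state of each transition.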

  • initial_state (Array) – A [N, S] tensor containing the initial state of the population of N units in S epidemiological classes.

  • num_steps (int) – The number of Markov jumps for a single iteration.

  • initial_time – Initial time of the model. Defaults to 0.0.

  • name (str) – Name of the model. Defaults to “ContinuousTimeStateTransitionModel”.
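
As a minimal sketch of the two transition_rate_fn forms described above, the following reuses the rate functions from the SIR example and evaluates them outside of gemlib. The name combined_rate_fn is illustrative and not part of gemlib, and the comparison against the number of incidence_matrix columns assumes the column-per-transition layout used in the example.

```python
import jax.numpy as jnp
import numpy as np

# Incidence matrix from the SIR example: rows are states (S, I, R),
# columns are transitions (S->I, I->R).
incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)


def si_rate(t, state):
    return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


def ir_rate(t, state):
    return 0.14


# First (preferred) form: one callable per transition.
transition_rate_fn = (si_rate, ir_rate)


# Second form: a single callable returning a tuple of rates
# (hypothetical helper name, for illustration only).
def combined_rate_fn(t, state):
    return si_rate(t, state), ir_rate(t, state)


state = jnp.array([[99, 1, 0]], jnp.float32)

# Both forms yield one rate per transition for the same (t, state).
rates_first = tuple(fn(0.0, state) for fn in transition_rate_fn)
rates_second = combined_rate_fn(0.0, state)
```

Either way, the i-th rate is paired with the i-th transition of incidence_matrix, so the order of the callables (or of the returned tuple) matters.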

Methods

batch_shape_tensor(name='batch_shape_tensor')

Shape of a single sample from a single event index as a 1-D Tensor.

The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.

Parameters:

name – name to give to the op

Returns:

Tensor.

Return type:

batch_shape

compute_state(event_list, include_final_state=False)

Compute the state timeseries given an event list.

Parameters:
  • event_list (EventList) – the event list, assumed to be sorted by time.

  • include_final_state (bool) – should the final state be included in the returned timeseries? If True, then the time dimension of the returned tensor will be 1 greater than the length of the event list. If False (default) these will be equal.

Returns:

A [T, N, S] tensor, where T is the number of events (plus 1 if include_final_state is True), N is the number of units, and S is the number of states.

Return type:

Array
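
The idea behind a state timeseries can be sketched in plain NumPy, without gemlib. The sketch assumes that each event is described by a transition index and a unit index, that the series starts at the initial state, and that each event adds the matching incidence_matrix column to the affected unit. The function state_timeseries and the arrays transition and unit are hypothetical stand-ins, not the EventList structure itself.

```python
import numpy as np

incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)
initial_state = np.array([[99, 1, 0]], dtype=np.float32)  # [N, S]

# Hypothetical time-sorted events: two S->I events, then one I->R event,
# all in unit 0.
transition = np.array([0, 0, 1])
unit = np.array([0, 0, 0])


def state_timeseries(initial_state, incidence_matrix, transition, unit,
                     include_final_state=False):
    num_units = initial_state.shape[0]
    # One-hot encoding of the unit affected by each event: [T, N]
    unit_one_hot = np.eye(num_units, dtype=initial_state.dtype)[unit]
    # State change caused by each event: [T, S]
    deltas = incidence_matrix.T[transition]
    # Per-event increment for every unit and state: [T, N, S]
    increments = unit_one_hot[:, :, None] * deltas[:, None, :]
    # Prepend a zero increment so that row 0 is the initial state: [T+1, N, S]
    cumulative = np.concatenate(
        [np.zeros_like(increments[:1]), np.cumsum(increments, axis=0)], axis=0
    )
    states = initial_state + cumulative
    return states if include_final_state else states[:-1]


states = state_timeseries(initial_state, incidence_matrix, transition, unit)
states_with_final = state_timeseries(
    initial_state, incidence_matrix, transition, unit, include_final_state=True
)
```

Under these assumptions, states has one entry per event and states_with_final has one extra entry holding the state after the last event.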

event_shape_tensor(name='event_shape_tensor')

Shape of a single sample from a single batch as a 1-D int32 Tensor.

Parameters:

name – name to give to the op

Returns:

Tensor.

Return type:

event_shape

log_prob(value, name='log_prob', **kwargs)

Log probability density/mass function.

Additional documentation from ContinuousTimeStateTransitionModel:

Computes the log probability of the given outcomes.

Args:
    value (EventList): an EventList object representing the outcomes.

Returns:
    float: The log probability of the given outcomes.

Parameters:
  • value – float or double Tensor.

  • name – Python str prepended to names of ops created by this function.

  • **kwargs – Named arguments forwarded to subclass implementation.

Returns:

a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.

Return type:

log_prob

prob(value, name='prob', **kwargs)

Probability density/mass function.

Parameters:
  • value – float or double Tensor.

  • name – Python str prepended to names of ops created by this function.

  • **kwargs – Named arguments forwarded to subclass implementation.

Returns:

a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.

Return type:

prob

sample(sample_shape=(), seed=None, name='sample', **kwargs)

Generate samples of the specified shape.

Note that a call to sample() without arguments will generate a single sample.

Additional documentation from ContinuousTimeStateTransitionModel:

Samples n outcomes from the continuous time state transition model.

Args:
    n (int): The number of realisations of the Markov process to sample (currently ignored).
    seed (int, optional): The seed value for random number generation. Defaults to None.

Returns:
    Sample(s) from the continuous time state transition model.

Parameters:
  • sample_shape – 0D or 1D int32 Tensor. Shape of the generated samples.

  • seed – PRNG seed; see tfp.random.sanitize_seed for details.

  • name – name to give to the op.

  • **kwargs – Named arguments forwarded to subclass implementation.

Returns:

a Tensor with prepended dimensions sample_shape.

Return type:

samples
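
To give a feel for what sampling num_steps Markov jumps involves, here is a textbook Gillespie direct-method simulation of the SIR example in plain NumPy. It is an illustrative sketch, not gemlib's implementation. It assumes that the rate functions give per-individual rates, that the total rate of a transition is that rate multiplied by the number of individuals in its source state, and that the source state is the row holding -1 in incidence_matrix. The names rates, simulate, and source_state are hypothetical.

```python
import numpy as np

incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float64)
initial_state = np.array([[99.0, 1.0, 0.0]])  # [N, S]

# Assumption: the source state of each transition is the row holding -1.
source_state = np.argmin(incidence_matrix, axis=0)  # [R]


def rates(t, state):
    """Per-individual transition rates, shape [R, N]."""
    si = 0.28 * state[:, 1] / state.sum(axis=-1)
    ir = np.full(state.shape[0], 0.14)
    return np.stack([si, ir])


def simulate(initial_state, num_steps, initial_time=0.0, seed=0):
    rng = np.random.default_rng(seed)
    state = initial_state.copy()
    t = initial_time
    events = []  # (time, transition index, unit index)
    for _ in range(num_steps):
        # Total rate of each transition in each unit: per-individual rate
        # times the number of individuals in the source state.
        hazard = rates(t, state) * state[:, source_state].T  # [R, N]
        total = hazard.sum()
        if total <= 0.0:
            break  # no further events are possible
        # Waiting time to the next event is exponential with rate `total`.
        t += rng.exponential(1.0 / total)
        # Choose which (transition, unit) pair fires, proportional to hazard.
        flat = rng.choice(hazard.size, p=hazard.ravel() / total)
        transition, unit = np.unravel_index(flat, hazard.shape)
        state[unit] += incidence_matrix[:, transition]
        events.append((t, int(transition), int(unit)))
    return events, state


events, final_state = simulate(initial_state, num_steps=100)
```

The loop stops early if the total rate reaches zero, so the sketch can return fewer than num_steps events.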

Attributes

allow_nan_stats

Python bool describing behavior when a stat is undefined.

batch_shape

Shape of a single sample from a single event index as a TensorShape.

dtype

The DType of Tensors handled by this Distribution.

event_shape

Shape of a single sample from a single batch as a TensorShape.

experimental_shard_axis_names

The list or structure of lists of active shard axis names.

incidence_matrix

Incidence matrix for the model.

initial_state

Initial state of the model.

initial_time

Initial wall clock for the model.

name

Name prepended to all ops created by this Distribution.

num_steps

Number of events to simulate.

parameters

Dictionary of parameters used to instantiate this Distribution.

reparameterization_type

Describes how samples from the distribution are reparameterized.

trainable_variables

transition_rate_fn

Transition rate function for the model.

validate_args

Python bool indicating possibly expensive checks are enabled.

variables