DiscreteTimeStateTransitionModel
- class gemlib.distributions.DiscreteTimeStateTransitionModel(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_step=0, time_delta=1.0, validate_args=False, allow_nan_stats=True, name='DiscreteTimeStateTransitionModel')
Discrete-time state transition model
A discrete-time state transition model assumes that a population of individuals is divided into a number of mutually exclusive states, and that transitions between states occur according to a Markov process. Such models are common in epidemiological and ecological applications, where rapid implementation and modification are necessary.
This class provides a programmable implementation of the discrete-time state transition model, compatible with TensorFlow Probability.
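The bookkeeping that links transition events to state counts can be written as one linear update per time step. In the notation below (introduced here for illustration), $\mathbf{x}^{(n)}_t$ is the length-$S$ vector of counts in unit $n$ at step $t$, $\mathbf{y}^{(n)}_t$ is the length-$R$ vector of transition events in that unit during step $t$, and $A$ is the $(S, R)$ incidence matrix whose column $r$ records the change in each state caused by one occurrence of transition $r$:

$$
\mathbf{x}^{(n)}_{t+1} = \mathbf{x}^{(n)}_{t} + A\,\mathbf{y}^{(n)}_{t}, \qquad n = 1, \dots, N,
\qquad
A_{\text{SIR}} = \begin{pmatrix} -1 & 0 \\ 1 & -1 \\ 0 & 1 \end{pmatrix}.
$$

For the SIR model in the example below, the first column of $A_{\text{SIR}}$ moves one individual from S to I (infection) and the second moves one from I to R (removal).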
Example
A homogeneously mixing SIR model implementation:
```python
import jax.numpy as jnp
import numpy as np

from gemlib.distributions import DiscreteTimeStateTransitionModel

# Initial state, counts per compartment (S, I, R), for one
# population
initial_state = jnp.array([[99, 1, 0]], jnp.float32)

# Note that the incidence_matrix is treated as a static parameter,
# and so is supplied as a numpy ndarray, not a jax Array.
# Rows are states (S, I, R); columns are transitions (S->I, I->R).
incidence_matrix = np.array(
    [
        [-1, 0],
        [1, -1],
        [0, 1],
    ],
    dtype=np.float32,
)


# Infection rate per susceptible: 0.28 * I / (S + I + R)
def si_rate(t, state):
    return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


# Constant removal rate per infected individual
def ir_rate(t, state):
    return 0.14


# Instantiate model
sir = DiscreteTimeStateTransitionModel(
    transition_rate_fn=(si_rate, ir_rate),
    incidence_matrix=incidence_matrix,
    initial_state=initial_state,
    num_steps=100,
)

# One realisation of the epidemic process: event counts per
# timestep, unit and transition
sim = sir.sample(seed=0)
```
- __init__(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_step=0, time_delta=1.0, validate_args=False, allow_nan_stats=True, name='DiscreteTimeStateTransitionModel')
Initialise a discrete-time state transition model.
- Parameters:
transition_rate_fn (Sequence[Callable[[float, ArrayLike], ArrayLike]] | Callable[[float, ArrayLike], tuple[ArrayLike, ...]]) – Either a sequence of R callables (t: float, state: ArrayLike) -> ArrayLike, each returning an Array of rates of shape (N,), or a single callable (t: float, state: ArrayLike) -> tuple[ArrayLike, ...], returning a tuple of R vectors of shape (N,). Here N is the number of units in the simulation. In the first (preferred) form, each callable returns the rate of the corresponding transition. In the second form, the single callable returns all R transition rates. Note: the second form will be deprecated in future releases of gemlib.
incidence_matrix (numpy.ndarray) – incidence matrix of shape (S, R) for S states and R transitions between states.
initial_state (ArrayLike) – initial state of the model, of shape (N, S), containing the counts in each of N units and S states. We require initial_state.shape[-1] == incidence_matrix.shape[-2].
num_steps (int) – the number of time steps simulated by the model.
initial_step (int) – time t at the start of the simulation.
time_delta (float) – the duration of the discretized time step.
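The shape contract for these parameters can be checked without gemlib. The NumPy sketch below builds a two-unit SIR setup, checks the stated requirement initial_state.shape[-1] == incidence_matrix.shape[-2], and writes the rates in both accepted forms of transition_rate_fn. As an assumption made for clarity, ir_rate here returns an explicit (N,) vector rather than the scalar used in the example; the name all_rates is hypothetical.

```python
import numpy as np

# SIR incidence matrix: rows are states (S, I, R), columns are
# transitions (S->I, I->R), so its shape is (S, R) = (3, 2).
incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)

# Two units (N = 2), each with counts in the three states: shape (N, S).
initial_state = np.array([[99, 1, 0], [50, 5, 0]], dtype=np.float32)

# The state axis of initial_state must match the incidence matrix rows.
assert initial_state.shape[-1] == incidence_matrix.shape[-2]


# Preferred form: one callable per transition, each returning shape (N,).
def si_rate(t, state):
    return 0.28 * state[:, 1] / np.sum(state, axis=-1)


def ir_rate(t, state):
    # Broadcast the constant rate to shape (N,) explicitly.
    return np.full(state.shape[:-1], 0.14, dtype=state.dtype)


rates_seq = (si_rate, ir_rate)


# Second (to-be-deprecated) form: one callable returning a tuple of
# R vectors, each of shape (N,). Hypothetical helper name.
def all_rates(t, state):
    return si_rate(t, state), ir_rate(t, state)


rates_a = [fn(0.0, initial_state) for fn in rates_seq]
rates_b = all_rates(0.0, initial_state)
```

Both forms yield R = incidence_matrix.shape[-1] rate vectors, one entry per unit.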
Methods
- batch_shape_tensor(name='batch_shape_tensor')
Shape of a single sample from a single event index as a 1-D Tensor.
The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.
- Parameters:
name – name to give to the op
- Returns:
batch_shape – a 1-D Tensor.
- Return type:
Tensor
- compute_state(events, include_final_state=False)
Computes a state timeseries from a sequence of transition events.
- Parameters:
events (Array) – an array of events of shape (T, N, R), where T is the number of timesteps, N = self.num_units the number of units and R = self.incidence_matrix.shape[-1] the number of transitions.
include_final_state (bool) – If False (default), the result does not include the final state. Otherwise the result includes the final state.
- Returns:
- An array of shape (T, N, S) if include_final_state == False, or (T+1, N, S) if include_final_state == True, where S = self.num_states, giving the number of individuals in each state at each time point for each unit.
- Return type:
Array
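The state timeseries is the initial state plus the running sum of per-step changes, each change being events @ incidence_matrix.T. The NumPy sketch below reproduces that shape contract; it assumes index t of the result is the state at the start of step t (so index 0 equals initial_state), and the name compute_state_sketch is hypothetical, not part of gemlib.

```python
import numpy as np


def compute_state_sketch(initial_state, events, incidence_matrix,
                         include_final_state=False):
    """Hypothetical stand-in for compute_state.

    initial_state: (N, S) counts; events: (T, N, R) event counts;
    incidence_matrix: (S, R). Returns (T, N, S), or (T+1, N, S) when
    include_final_state is True.
    """
    # Per-step change in each state: (T, N, R) @ (R, S) -> (T, N, S).
    deltas = events @ incidence_matrix.T
    # Running total of changes after each step.
    cumulative = np.cumsum(deltas, axis=0)
    # State at the start of every step, followed by the final state.
    states = np.concatenate(
        [initial_state[np.newaxis], initial_state[np.newaxis] + cumulative],
        axis=0,
    )
    return states if include_final_state else states[:-1]


incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)
initial_state = np.array([[99, 1, 0]], dtype=np.float32)
# Two timesteps, one unit: 3 infections, then 2 infections and 1 removal.
events = np.array([[[3, 0]], [[2, 1]]], dtype=np.float32)

states = compute_state_sketch(initial_state, events, incidence_matrix)
final = compute_state_sketch(initial_state, events, incidence_matrix,
                             include_final_state=True)
```

Here states has shape (2, 1, 3) and final has shape (3, 1, 3), ending at (94, 5, 1); the total population stays at 100 throughout.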
- event_shape_tensor(name='event_shape_tensor')
Shape of a single sample from a single batch as a 1-D int32 Tensor.
- Parameters:
name – name to give to the op
- Returns:
event_shape – a 1-D int32 Tensor.
- Return type:
Tensor
- log_prob(value, name='log_prob', **kwargs)
Log probability density/mass function.
- Parameters:
value – float or double Tensor.
name – Python str prepended to names of ops created by this function.
**kwargs – Named arguments forwarded to subclass implementation.
- Returns:
log_prob – a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.
- Return type:
Tensor
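The generic docstring above does not say what density this model uses. One natural reading, assumed here and not taken from gemlib's source, is a chain-binomial likelihood: in each step, the event count of each transition is binomial given the count in its source state and the probability 1 - exp(-rate * time_delta). In the SIR example every state has at most one outgoing transition, so these binomial terms are independent within a step. The names binom_logpmf and sir_log_prob_sketch are hypothetical.

```python
import math

import numpy as np

# Incidence matrix of the SIR example: rows (S, I, R), columns (S->I, I->R).
INCIDENCE = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float64)


def binom_logpmf(k, n, p):
    """Log Binomial(n, p) mass at k, for scalar k, n and 0 <= p <= 1."""
    if k < 0 or k > n:
        return -math.inf
    # Boundary probabilities: only one outcome is possible.
    if p == 0.0:
        return 0.0 if k == 0 else -math.inf
    if p == 1.0:
        return 0.0 if k == n else -math.inf
    log_coef = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_coef + k * math.log(p) + (n - k) * math.log1p(-p)


def sir_log_prob_sketch(events, initial_state, beta=0.28, gamma=0.14,
                        time_delta=1.0):
    """Hypothetical chain-binomial log-likelihood for the SIR example.

    events: (T, N, 2) counts of S->I and I->R; initial_state: (N, 3).
    """
    state = np.asarray(initial_state, dtype=np.float64)
    total = 0.0
    for y in np.asarray(events, dtype=np.float64):
        for (s, i, r), (y_si, y_ir) in zip(state, y):
            # Per-step probabilities from the two rates in the example.
            p_si = 1.0 - math.exp(-beta * i / (s + i + r) * time_delta)
            p_ir = 1.0 - math.exp(-gamma * time_delta)
            total += binom_logpmf(int(y_si), int(s), p_si)
            total += binom_logpmf(int(y_ir), int(i), p_ir)
        # Advance to the next step's state: (N, R) @ (R, S) -> (N, S).
        state = state + y @ INCIDENCE.T
    return total
```

Events that remove more individuals than a state holds get log-probability -inf, and for a single step the probabilities of all feasible event pairs sum to one.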
- prob(value, name='prob', **kwargs)
Probability density/mass function.
- Parameters:
value – float or double Tensor.
name – Python str prepended to names of ops created by this function.
**kwargs – Named arguments forwarded to subclass implementation.
- Returns:
prob – a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.
- Return type:
Tensor
- sample(sample_shape=(), seed=None, name='sample', **kwargs)
Generate samples of the specified shape.
Note that a call to sample() without arguments will generate a single sample.
Additional documentation from DiscreteTimeStateTransitionModel:
Runs n simulations of the epidemic model.
- Args:
n: number of simulations to run.
seed: an integer or a JAX PRNG key.
- Returns:
an array of shape (n, T, N, R) for n samples, T = self.num_steps time steps, N = self.num_units units and R = self.incidence_matrix.shape[-1] transitions, containing the number of events for each timestep, unit and transition.
- Parameters:
sample_shape – 0D or 1D int32 Tensor. Shape of the generated samples.
seed – PRNG seed; see tfp.random.sanitize_seed for details.
name – name to give to the op.
**kwargs – Named arguments forwarded to subclass implementation.
- Returns:
samples – a Tensor with prepended dimensions sample_shape.
- Return type:
Tensor
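To make the documented (n, T, N, R) output concrete, the sketch below simulates the SIR example as a chain-binomial process in NumPy. It assumes each transition's events are drawn binomially from its source state with probability 1 - exp(-rate * time_delta); gemlib's actual sampler may differ, for example in how competing transitions out of one state are handled. The function name simulate_sir_sketch is hypothetical.

```python
import numpy as np


def simulate_sir_sketch(initial_state, num_steps, n=1, beta=0.28,
                        gamma=0.14, time_delta=1.0, seed=0):
    """Hypothetical chain-binomial simulator for the SIR example.

    Returns integer event counts of shape (n, T, N, R), T = num_steps, R = 2.
    """
    rng = np.random.default_rng(seed)
    incidence = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.int64)
    initial = np.asarray(initial_state, dtype=np.int64)
    out = np.zeros((n, num_steps) + initial.shape[:1] + (2,), dtype=np.int64)
    for k in range(n):
        state = initial.copy()
        for t in range(num_steps):
            s, i = state[:, 0], state[:, 1]
            # Transition probabilities over one step of length time_delta.
            p_si = 1.0 - np.exp(-beta * i / state.sum(axis=-1) * time_delta)
            p_ir = 1.0 - np.exp(-gamma * time_delta)
            # Binomial draws from each source state: shape (N, R).
            y = np.stack([rng.binomial(s, p_si), rng.binomial(i, p_ir)],
                         axis=-1)
            out[k, t] = y
            # Update counts: (N, R) @ (R, S) -> (N, S).
            state = state + y @ incidence.T
    return out


events = simulate_sir_sketch([[99, 1, 0]], num_steps=100, n=3)
```

Because every draw is bounded by its source count, cumulative infections never exceed the initial susceptibles and cumulative removals never exceed the initial infected plus all infections.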
- transition_prob_matrix(events)
Compute the Markov transition probability matrix.
- Parameters:
events (None | ArrayLike) – None, or an array of shape (T, N, R), where T is the number of timesteps, N = self.num_units the number of units and R = self.incidence_matrix.shape[-1] the number of transitions.
- Returns:
Transition probability matrix. If events is None, this matrix has shape (N, S, S) and is the transition probability matrix associated with the initial state (self.initial_state, of shape (N, S)), where S = self.num_states is the number of distinct states. Otherwise, this matrix has shape (T, N, S, S), giving the transition probability matrix at each timestep.
- Return type:
Array
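The sketch below builds an (N, S, S) one-step transition probability matrix for a single state of shape (N, S). It assumes a competing-hazards construction: the rates out of each state are summed, the probability of leaving is 1 - exp(-total_rate * time_delta), that probability is split in proportion to the individual rates, and the remainder stays on the diagonal. This is an illustrative assumption, not a statement of gemlib's internals. It also assumes each incidence-matrix column has exactly one -1 (source) and one +1 (destination); transition_prob_matrix_sketch is a hypothetical name.

```python
import numpy as np


def transition_prob_matrix_sketch(state, rate_fns, incidence_matrix,
                                  time_delta=1.0, t=0.0):
    """Hypothetical (N, S, S) one-step transition matrix for state (N, S)."""
    num_units, num_states = state.shape
    rate_matrix = np.zeros((num_units, num_states, num_states))
    for r, fn in enumerate(rate_fns):
        # Assumes column r holds one -1 (source) and one +1 (destination).
        src = int(np.argmin(incidence_matrix[:, r]))
        dst = int(np.argmax(incidence_matrix[:, r]))
        rate_matrix[:, src, dst] += np.broadcast_to(fn(t, state), (num_units,))
    total = rate_matrix.sum(axis=-1, keepdims=True)
    # Probability of leaving each state during one step.
    leave = 1.0 - np.exp(-total * time_delta)
    # Split the leaving probability in proportion to each outgoing rate.
    share = np.divide(rate_matrix, total, out=np.zeros_like(rate_matrix),
                      where=total > 0)
    probs = share * leave
    # Probability of staying put goes on the diagonal.
    idx = np.arange(num_states)
    probs[:, idx, idx] = 1.0 - probs.sum(axis=-1)
    return probs


incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)
initial_state = np.array([[99, 1, 0]], dtype=np.float32)


def si_rate(t, state):
    return 0.28 * state[:, 1] / np.sum(state, axis=-1)


def ir_rate(t, state):
    return 0.14


P = transition_prob_matrix_sketch(initial_state, (si_rate, ir_rate),
                                  incidence_matrix)
```

Each row of P sums to one; the absorbing R state has no outgoing rates, so its diagonal entry is 1.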
Attributes
allow_nan_stats – Python bool describing behavior when a stat is undefined.
batch_shape – Shape of a single sample from a single event index as a TensorShape.
dtype – The DType of Tensors handled by this Distribution.
event_shape – Shape of a single sample from a single batch as a TensorShape.
experimental_shard_axis_names – The list or structure of lists of active shard axis names.
incidence_matrix
initial_state
initial_step
name – Name prepended to all ops created by this Distribution.
num_states
num_steps
num_units
parameters – Dictionary of parameters used to instantiate this Distribution.
reparameterization_type – Describes how samples from the distribution are reparameterized.
source_states
time_delta
trainable_variables
transition_rate_fn
validate_args – Python bool indicating possibly expensive checks are enabled.
variables