ContinuousTimeStateTransitionModel
- class gemlib.distributions.ContinuousTimeStateTransitionModel(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_time=0.0, validate_args=False, allow_nan_stats=True, name='ContinuousTimeStateTransitionModel')
Continuous time state transition model.
Example
A homogeneously mixing SIR model implementation:
```python
import jax.numpy as jnp
import numpy as np

from gemlib.distributions import ContinuousTimeStateTransitionModel

# Initial state, counts per compartment (S, I, R), for one
# population
initial_state = jnp.array([[99, 1, 0]], jnp.float32)

# Note that the incidence_matrix is treated as a static parameter,
# and so is supplied as a numpy ndarray, not a jax Array.
# Rows are states (S, I, R); columns are transitions (S->I, I->R).
incidence_matrix = np.array(
    [
        [-1, 0],
        [1, -1],
        [0, 1],
    ],
    dtype=np.float32,
)


# Rate for the S->I transition (first column of incidence_matrix)
def si_rate(t, state):
    return 0.28 * state[:, 1] / jnp.sum(state, axis=-1)


# Rate for the I->R transition (second column of incidence_matrix)
def ir_rate(t, state):
    return 0.14


# Instantiate model
sir = ContinuousTimeStateTransitionModel(
    transition_rate_fn=(si_rate, ir_rate),
    incidence_matrix=incidence_matrix,
    initial_state=initial_state,
    num_steps=100,
)

# One realisation of the epidemic process
sim = sir.sample(seed=0)

# Compute the log probability of observing `sim`
sir.log_prob(sim)
```
- __init__(transition_rate_fn, incidence_matrix, initial_state, num_steps, initial_time=0.0, validate_args=False, allow_nan_stats=True, name='ContinuousTimeStateTransitionModel')
Initializes a ContinuousTimeStateTransitionModel object.
- Parameters:
transition_rate_fn (Sequence[Callable[[float, ArrayLike], ArrayLike]] | Callable[[float, ArrayLike], tuple[ArrayLike, ...]]) – Either a list of callables of the form fn(t: float, state: Tensor) -> Tensor, or a single Python callable of the form fn(t: float, state: Tensor) -> tuple(Tensor, ...). In the first (preferred) form, each callable in the list corresponds to the respective transition (column) in incidence_matrix. In the second form, the callable should return a tuple of transition rate tensors corresponding to the transitions in incidence_matrix. Note: the second form will be deprecated in future releases of gemlib.
incidence_matrix (numpy.ndarray) – Matrix representing the incidence of transitions between states, with one row per epidemiological state and one column per transition; column r gives the change in each state's count when transition r occurs (in the SIR example, -1 for the source state and +1 for the destination state). It is treated as a static parameter, so supply a NumPy ndarray rather than a JAX Array.
initial_state (Array) – A [N, S] tensor containing the initial state of the population of N units in S epidemiological classes.
num_steps (int) – The number of Markov jumps (events) to simulate in a single iteration.
initial_time (float) – Initial time of the model. Defaults to 0.0.
name (str) – Name of the model. Defaults to “ContinuousTimeStateTransitionModel”.
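To make the roles of incidence_matrix and transition_rate_fn concrete, here is a plain NumPy sketch (it does not call gemlib) that applies one transition to a [N, S] state and writes the SIR rates in both accepted forms. The helper apply_transition and the name sir_rates are hypothetical, introduced only for illustration; gemlib performs the equivalent bookkeeping internally and its code may differ.

```python
import numpy as np

# Incidence matrix from the SIR example: rows are states (S, I, R),
# columns are transitions (S->I, I->R).
incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)
state = np.array([[99.0, 1.0, 0.0]], dtype=np.float32)  # [N=1 unit, S=3 states]


def apply_transition(state, incidence_matrix, n, r):
    """Hypothetical helper: fire transition r in unit n by adding column r."""
    new_state = state.copy()
    new_state[n] += incidence_matrix[:, r]
    return new_state


after_infection = apply_transition(state, incidence_matrix, n=0, r=0)  # [[98, 2, 0]]


# Preferred form: one callable per transition (per column of incidence_matrix).
def si_rate(t, state):
    return 0.28 * state[:, 1] / np.sum(state, axis=-1)


def ir_rate(t, state):
    return 0.14


rates_as_list = (si_rate, ir_rate)


# Second form (slated for deprecation): one callable returning a tuple of rates,
# in the same order as the columns of incidence_matrix.
def sir_rates(t, state):
    return si_rate(t, state), ir_rate(t, state)
```

In this closed-population example every column of incidence_matrix sums to zero, so each transition moves individuals between states without changing the total.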
Methods
- batch_shape_tensor(name='batch_shape_tensor')
Shape of a single sample from a single event index as a 1-D Tensor.
The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.
- Parameters:
name – name to give to the op
- Returns:
batch_shape, a 1-D Tensor.
- Return type:
Tensor
- compute_state(event_list, include_final_state=False)
Compute the state timeseries given an event list.
- Parameters:
event_list (EventList) – the event list, assumed to be sorted by time.
include_final_state (bool) – should the final state be included in the returned timeseries? If True, then the time dimension of the returned tensor will be 1 greater than the length of the event list. If False (default) these will be equal.
- Returns:
A [T, N, S] tensor where T is the number of events (plus one if include_final_state is True), N is the number of units, and S is the number of states.
- Return type:
Array
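One way to picture compute_state: each event adds the incidence_matrix column for its transition to the state of the affected unit, so the state timeseries is a running sum over events. The NumPy sketch below assumes the event list can be reduced to two time-sorted integer arrays, transition and unit (hypothetical names; gemlib's EventList fields are not reproduced here), and that entry t of the result is the state immediately before event t, which matches the include_final_state behaviour described above.

```python
import numpy as np


def compute_state_sketch(initial_state, incidence_matrix, transition, unit,
                         include_final_state=False):
    """Hypothetical NumPy sketch of compute_state.

    initial_state:    [N, S] counts per unit and state.
    incidence_matrix: [S, R] state changes, one column per transition.
    transition:       [T] transition index of each (time-sorted) event.
    unit:             [T] unit index of each event.
    """
    transition = np.asarray(transition, dtype=int)
    unit = np.asarray(unit, dtype=int)
    num_events = transition.shape[0]
    num_units, num_states = initial_state.shape

    # Scatter each event's incidence column into a [T, N, S] increment tensor.
    increments = np.zeros((num_events, num_units, num_states), dtype=initial_state.dtype)
    increments[np.arange(num_events), unit, :] = incidence_matrix[:, transition].T

    # Exclusive running sum: state before event t = initial + events 0..t-1.
    previous = np.cumsum(increments, axis=0) - increments
    states = initial_state[np.newaxis] + previous

    if include_final_state:
        # Append the state after the last event, making the time axis T + 1.
        final_state = initial_state + increments.sum(axis=0)
        states = np.concatenate([states, final_state[np.newaxis]], axis=0)
    return states


initial_state = np.array([[99.0, 1.0, 0.0]], dtype=np.float32)
incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float32)
# Two events for unit 0: an infection (S->I) followed by a removal (I->R).
states = compute_state_sketch(initial_state, incidence_matrix,
                              transition=[0, 1], unit=[0, 0],
                              include_final_state=True)
# states[:, 0] -> [[99, 1, 0], [98, 2, 0], [98, 1, 1]]
```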
- event_shape_tensor(name='event_shape_tensor')
Shape of a single sample from a single batch as a 1-D int32 Tensor.
- Parameters:
name – name to give to the op
- Returns:
event_shape, a 1-D int32 Tensor.
- Return type:
Tensor
- log_prob(value, name='log_prob', **kwargs)
Log probability density/mass function.
Additional documentation from ContinuousTimeStateTransitionModel:
Computes the log probability of the given outcomes.
- Args:
- value (EventList): an EventList object representing the outcomes.
- Returns:
float: The log probability of the given outcomes.
- Parameters:
value – The outcomes to evaluate; for this model, an EventList.
name – Python str prepended to names of ops created by this function.
**kwargs – Named arguments forwarded to subclass implementation.
- Returns:
log_prob, a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.
- Return type:
Tensor
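For a fully observed path, the textbook log-likelihood of a continuous-time Markov jump process sums, over events, the log hazard of the transition that occurred minus the total hazard multiplied by the preceding waiting time. The NumPy sketch below computes that quantity for a single unit; it is not gemlib's implementation. Treating each rate function's output as a per-individual rate that is multiplied by the count in the transition's source state, and assuming rates stay constant between events (true for the time-independent SIR rates in the example), are assumptions made for illustration.

```python
import math

import numpy as np


def ctmc_log_prob_sketch(event_times, transitions, initial_state,
                         incidence_matrix, rate_fns, initial_time=0.0):
    """Hypothetical log-likelihood of a fully observed single-unit jump path.

    Assumes each rate_fn(t, state) returns a per-individual rate, so the hazard
    of transition r is rate_r * (count in r's source state), and that rates are
    constant between consecutive events.
    """
    state = np.asarray(initial_state, dtype=np.float64)  # shape [S]
    # Source state of each transition: the row holding -1 in that column.
    source = np.argmax(incidence_matrix == -1, axis=0)
    t_prev = initial_time
    log_prob = 0.0
    for t, r in zip(event_times, transitions):
        per_capita = np.array([fn(t_prev, state) for fn in rate_fns], dtype=np.float64)
        hazards = per_capita * state[source]
        # Log hazard of the observed event minus total hazard over the wait.
        log_prob += math.log(hazards[r]) - hazards.sum() * (t - t_prev)
        state = state + incidence_matrix[:, r]
        t_prev = t
    return log_prob


# Usage: a 3-person SIR population with a single infection at t = 1.
incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float64)
rate_fns = (lambda t, s: 0.28 * s[1] / s.sum(), lambda t, s: 0.14)
lp = ctmc_log_prob_sketch([1.0], [0], [2.0, 1.0, 0.0], incidence_matrix, rate_fns)
```

Here the infection hazard is 0.28 * (1/3) * 2 and the removal hazard is 0.14, so lp equals log(0.28 * 2/3) minus the sum of both hazards over the unit waiting time.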
- prob(value, name='prob', **kwargs)
Probability density/mass function.
- Parameters:
value – The outcomes to evaluate; for this model, an EventList.
name – Python str prepended to names of ops created by this function.
**kwargs – Named arguments forwarded to subclass implementation.
- Returns:
prob, a Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.
- Return type:
Tensor
- sample(sample_shape=(), seed=None, name='sample', **kwargs)
Generate samples of the specified shape.
Note that a call to sample() without arguments will generate a single sample.
Additional documentation from ContinuousTimeStateTransitionModel:
Samples n outcomes from the continuous time state transition model.
- Args:
- n (int): The number of realisations of the Markov process to sample (currently ignored).
- seed (int, optional): The seed value for random number generation. Defaults to None.
- Returns:
Sample(s) from the continuous time state transition model.
- Parameters:
sample_shape – 0D or 1D int32 Tensor. Shape of the generated samples.
seed – PRNG seed; see tfp.random.sanitize_seed for details.
name – name to give to the op.
**kwargs – Named arguments forwarded to subclass implementation.
- Returns:
samples, a Tensor with prepended dimensions sample_shape.
- Return type:
Tensor
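A standard way to draw a realisation of a Markov jump process is Gillespie's direct method: draw an exponential waiting time whose rate is the total hazard, then choose which transition fired with probability proportional to its hazard. The NumPy sketch below illustrates that method for a single unit. It is not gemlib's sampler; it assumes each rate function returns a per-individual rate (multiplied by the count in the transition's source state), and it simply stops early if no further events are possible, whereas how gemlib fills the remaining num_steps in that case is not covered here.

```python
import numpy as np


def gillespie_sketch(initial_state, incidence_matrix, rate_fns, num_steps,
                     initial_time=0.0, seed=0):
    """Hypothetical Gillespie direct-method simulator for one unit.

    Returns event times, transition indices, and the final state.
    """
    rng = np.random.default_rng(seed)
    state = np.asarray(initial_state, dtype=np.float64)  # shape [S]
    # Source state of each transition: the row holding -1 in that column.
    source = np.argmax(incidence_matrix == -1, axis=0)
    t = initial_time
    times, transitions = [], []
    for _ in range(num_steps):
        per_capita = np.array([fn(t, state) for fn in rate_fns], dtype=np.float64)
        hazards = per_capita * state[source]
        total = hazards.sum()
        if total <= 0.0:
            break  # absorbing state: no further events can occur
        t += rng.exponential(1.0 / total)  # waiting time ~ Exponential(total hazard)
        r = rng.choice(len(hazards), p=hazards / total)  # which transition fired
        state = state + incidence_matrix[:, r]
        times.append(t)
        transitions.append(r)
    return np.array(times), np.array(transitions, dtype=int), state


incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]], dtype=np.float64)
rate_fns = (lambda t, s: 0.28 * s[1] / s.sum(), lambda t, s: 0.14)
times, transitions, final_state = gillespie_sketch(
    [99.0, 1.0, 0.0], incidence_matrix, rate_fns, num_steps=100, seed=0
)
```

Because only transitions with positive hazard can be chosen, no compartment count goes negative, and the population total is conserved.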
Attributes
allow_nan_stats – Python bool describing behavior when a stat is undefined.
batch_shape – Shape of a single sample from a single event index as a TensorShape.
dtype – The DType of Tensors handled by this Distribution.
event_shape – Shape of a single sample from a single batch as a TensorShape.
experimental_shard_axis_names – The list or structure of lists of active shard axis names.
incidence_matrix – Incidence matrix for the model.
initial_state – Initial state of the model.
initial_time – Initial wall-clock time for the model.
name – Name prepended to all ops created by this Distribution.
num_steps – Number of events to simulate.
parameters – Dictionary of parameters used to instantiate this Distribution.
reparameterization_type – Describes how samples from the distribution are reparameterized.
trainable_variables
transition_rate_fn – Transition rate function for the model.
validate_args – Python bool indicating possibly expensive checks are enabled.
variables