Gemlib#

Gemlib is a Python library for infectious disease modelling. It provides building blocks for assembling complex models and fitting them to data.

Features#

  • programmable classes for deterministic, continuous- and discrete-time state transition models

  • the library is compatible with TensorFlow Probability (JAX backend), allowing complex hierarchical Bayesian models to be built around infectious disease models

  • a suite of MCMC samplers for parameter inference, including random walk Metropolis-Hastings, Hamiltonian Monte Carlo, and specialised samplers for integrating out censored epidemiological event data (e.g. infection times)

Installation#

Warning

As of version 0.14.0, Gemlib is based on JAX. For use with TensorFlow, install a version earlier than 0.14.0.

The latest release of Gemlib can be installed from PyPI using pip:

$ pip install gemlib

or if you have an NVIDIA GPU:

$ pip install "gemlib[gpu]"

The current development version of the library can be installed at any time from our GitLab repository with:

$ pip install git+https://gitlab.com/gem-epidemics/gemlib

System requirements:

  • A computer running Linux, Windows 7 (or later), or macOS

  • Python >=3.11,<3.14

  • An NVIDIA GPU compatible with the latest version of JAX, if maximum performance is required
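
The Python constraint above can be checked before installing. The following is a minimal sketch using only the standard library; the helper name `python_supported` is hypothetical and is not part of Gemlib.

```python
import sys

# Bounds taken from the requirement "Python >=3.11,<3.14" listed above.
MIN_VERSION = (3, 11)  # inclusive lower bound
MAX_VERSION = (3, 14)  # exclusive upper bound


def python_supported(version=None):
    """Return True if `version` (major, minor) satisfies >=3.11,<3.14.

    Hypothetical helper for illustration only; defaults to the running
    interpreter when no version is given.
    """
    if version is None:
        version = sys.version_info[:2]
    return MIN_VERSION <= tuple(version[:2]) < MAX_VERSION


major, minor = sys.version_info[:2]
status = "within" if python_supported() else "outside"
print(f"Python {major}.{minor} is {status} the supported range")
```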

Quick example#

Gemlib provides an API for constructing Markov state transition models of the kind used in infectious disease modelling. Here is a quick example of how to implement a stochastic, homogeneously mixing SIR model in discrete time:

import numpy as np
import matplotlib.pyplot as plt
from gemlib.distributions import DiscreteTimeStateTransitionModel

# Represent the S -> I -> R model as a graph incidence matrix
incidence_matrix = np.array([[-1,  0],
                            [ 1, -1],
                            [ 0,  1]], dtype=np.float32)

# Initial S, I, and R states for a single population
initial_state = np.array([[99, 1, 0]], dtype=np.float32)

# Define the transition rates
def si_rate(t, state):
    return 0.2 * state[:, 1] / state.sum(axis=-1)

def ir_rate(t, state):
    return 0.14

# Instantiate the model
model = DiscreteTimeStateTransitionModel(
    [si_rate, ir_rate], incidence_matrix, initial_state, num_steps=50
)

# Draw a realisation of the epidemic process
sample = model.sample(seed=5)

# Compute the probability of observing `sample` given the model
log_prob = model.log_prob(sample)

# Convert the transition event tensor output to numbers in each state over time
sample_state = model.compute_state(sample)

# Plot simulation
plt.plot(np.sum(sample_state, axis=1), label=["S", "I", "R"])
plt.xlabel("Time")
plt.ylabel("Number of individuals")
_ = plt.legend()

Figure: SIR model simulation
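
To make the roles of the incidence matrix and the transition event output more concrete, here is a sketch of a discrete-time SIR simulation in plain NumPy. It is not Gemlib's implementation. It assumes a single population, a unit time step, a chain-binomial scheme in which an individual leaves its state with probability 1 - exp(-rate), and states reported at the start of each step. The helper names `simulate_events` and `events_to_states` are hypothetical.

```python
import numpy as np

# Same S -> I -> R incidence matrix as in the example above: rows are states,
# columns are transitions (S->I, I->R).
incidence_matrix = np.array([[-1, 0], [1, -1], [0, 1]])
initial_state = np.array([99, 1, 0])


def simulate_events(initial_state, num_steps, beta=0.2, gamma=0.14, seed=0):
    """Draw S->I and I->R event counts for each step (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    state = initial_state.copy()
    events = np.zeros((num_steps, 2), dtype=np.int64)
    for t in range(num_steps):
        s, i, _ = state
        # Per-capita rates, mirroring si_rate and ir_rate in the example above
        si_rate = beta * i / state.sum()
        ir_rate = gamma
        # Assumption: each individual leaves its state within one unit time
        # step with probability 1 - exp(-rate)
        n_si = rng.binomial(s, 1.0 - np.exp(-si_rate))
        n_ir = rng.binomial(i, 1.0 - np.exp(-ir_rate))
        events[t] = (n_si, n_ir)
        # Each event moves individuals according to its incidence column
        state = state + incidence_matrix @ events[t]
    return events


def events_to_states(initial_state, events):
    """Rebuild the state at the start of each step from the event counts."""
    # Net change per step is incidence_matrix @ events[t]; accumulate over time
    deltas = events @ incidence_matrix.T
    cumulative = np.cumsum(deltas, axis=0)
    # Step 0 starts from the initial state; the state after the final step
    # is dropped so that there is one row per step
    return np.vstack([initial_state, initial_state + cumulative[:-1]])


events = simulate_events(initial_state, num_steps=50)
states = events_to_states(initial_state, events)
print(states[:5])
```

Storing events rather than states keeps the simulated quantity and the quantity being modelled the same, and the state trajectory can always be recovered from the events by a cumulative sum.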

Tip

Since Gemlib is based on JAX, all functions and methods can be compiled with jax.jit. For the complex models Gemlib is designed for, this often gives large speedups over uncompiled code:

import jax
fast_sample = jax.jit(model.sample)
# %time is an IPython/Jupyter magic, so run this line in a notebook or IPython shell
%time sample = fast_sample(seed=jax.random.key(5))
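
In a plain Python script, a similar measurement can be sketched with `time.perf_counter`. The example below uses a stand-in function, `toy_step`, which is hypothetical and unrelated to Gemlib's models. It illustrates two general points about `jax.jit`: the first call includes tracing and compilation, and `block_until_ready()` is needed for honest timings because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(x):
    """Hypothetical stand-in for an expensive model computation."""
    return jnp.sum(jnp.tanh(x) ** 2)


fast_step = jax.jit(toy_step)
x = jnp.arange(1000.0)

# The first call traces and compiles the function, so it is comparatively slow
start = time.perf_counter()
first = fast_step(x).block_until_ready()
compile_and_run = time.perf_counter() - start

# Later calls with the same input shape and dtype reuse the compiled code
start = time.perf_counter()
second = fast_step(x).block_until_ready()
run_only = time.perf_counter() - start

print(f"first call: {compile_and_run:.4f} s, second call: {run_only:.4f} s")
```

When benchmarking, it is usually the second and later calls that reflect steady-state performance.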

Acknowledgements#

The Gemlib team is indebted to the TensorFlow Probability and BlackJAX teams, whose wonderful ideas have inspired our library architecture. If you haven’t already, go and check out these fantastic libraries and spot the similarities with Gemlib.