"""Pseudorandom number utilities."""

import numbers

import jax
from jax import Array


def sanitize_key(key: int | Array | None):
    """Sanitize a user-supplied key

    Args:
      key: either an int or a JAX pseudorandom key

    Returns:
      an instance of a JAX pseudorandom key
    """
    if key is None:
        raise ValueError("Explicit key required")

    if isinstance(key, int):
        key = jax.random.key(key)

    return key
