"""MCMC utility functions"""

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp


def is_list_like(x):
    """Return True when ``x`` is a ``list`` or ``tuple`` (or a subclass of either)."""
    sequence_types = (list, tuple)
    return isinstance(x, sequence_types)


def get_flattening_bijector(example):
    """A bijector that converts a data structure to a 1-D tensor.

    The returned bijector's ``forward`` maps a structure laid out like
    ``example`` to a vector formed by concatenating the flattened leaves
    along the last axis. Its ``inverse`` splits such a vector and rebuilds
    the original structure.

    Args:
        example: A JAX pytree whose leaves are array-likes (arrays or
            Python/NumPy scalars). Only the tree structure and the leaf
            shapes are used; leaf values are ignored.

    Returns:
        A ``tfp.bijectors.Bijector`` mapping structure -> flat vector.
    """
    flat_example, example_treedef = jax.tree.flatten(example)

    # np.shape also works for leaves without a ``.shape`` attribute
    # (e.g. Python floats), returning () for scalars.
    leaf_shapes = [tuple(np.shape(x)) for x in flat_example]

    # Cast to int: np.prod(()) is the float 1.0 for scalar leaves, but
    # split sizes should be integers.
    split = tfp.bijectors.Split(
        [int(np.prod(shape)) for shape in leaf_shapes], axis=-1
    )
    reshape = tfp.bijectors.JointMap(
        [tfp.bijectors.Reshape(shape) for shape in leaf_shapes]
    )
    # Output structure mirrors ``example`` with each leaf replaced by its
    # position in the flattened leaf list.
    restructure = tfp.bijectors.Restructure(
        output_structure=jax.tree.unflatten(
            example_treedef, range(len(flat_example))
        )
    )

    # Chain applies its bijectors right-to-left, so its forward pass is
    # vector -> split -> reshape -> restructure. Inverting it gives the
    # structure -> vector direction.
    return tfp.bijectors.Invert(
        tfp.bijectors.Chain([restructure, reshape, split])
    )
