"""Test MCMC utilities"""

from typing import NamedTuple

import jax.numpy as jnp
import numpy as np

from .mcmc_util import get_flattening_bijector


def test_get_flattening_bijector_scalar():
    """Flatten a single float32 scalar.

    The bijector built from the scalar is applied to that same scalar; the
    result is expected to have a trailing dimension of size 1 holding the
    original value.
    """
    val = np.float32(0.1)

    bijector = get_flattening_bijector(val)

    flat_val = bijector(val)

    assert flat_val.shape[-1] == 1
    # Compare against an explicit float32 value: float32(0.1) differs from
    # the float64 literal 0.1, so a bare ``== 0.1`` only passes through
    # weak-scalar type promotion.  ``np.all`` avoids relying on the implicit
    # truthiness of a one-element array.
    assert np.all(flat_val == np.float32(0.1))


def test_get_flattening_bijector_list():
    """Flatten a list of two float32 scalars.

    The flattened result is expected to have a trailing dimension of size 2
    containing the list elements in their original order.
    """
    structure = [np.float32(0.1), np.float32(0.2)]

    # Build the bijector from the structure, then push the structure through.
    flatten = get_flattening_bijector(structure)
    flattened = flatten(structure)

    expected_size = 2
    expected_values = np.array([0.1, 0.2], np.float32)

    assert flattened.shape[-1] == expected_size
    assert np.all(flattened == expected_values)


def test_get_flattening_bijector_namedtuple():
    """Flatten a NamedTuple of float32 arrays with differing ranks.

    The structure holds a 0-d array, a 2-element vector and a 2x2 matrix, so
    the flattened result is expected to have 1 + 2 + 4 = 7 trailing elements,
    laid out in field order.
    """

    class Struct(NamedTuple):
        # Fields hold NumPy arrays of different ranks, not Python floats.
        foo: np.ndarray
        bar: np.ndarray
        baz: np.ndarray

    val = Struct(
        np.array(0.1, dtype=np.float32),
        np.array([0.2, 0.3], dtype=np.float32),
        np.array([[0.4, 0.5], [0.6, 0.7]], dtype=np.float32),
    )

    bijector = get_flattening_bijector(val)

    flat_val = bijector(val)

    flat_shape = 7
    assert flat_val.shape[-1] == flat_shape
    # Use an array-aware reduction rather than the builtin ``all``, which
    # iterates the comparison result one element at a time.
    assert jnp.all(
        flat_val == jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], np.float32)
    )
