"""Tests for the ``mcmc`` driver in ``mcmc_sampler``."""

from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from .mcmc_sampler import mcmc
from .test_util import CountingKernelInfo, counting_kernel


# Two-field position pytree used as the sampler state in test_mcmc.
# Each field starts as a scalar; the expected samples hold one array per
# field (one entry per step).
DummyPosition = NamedTuple("DummyPosition", [("x", float), ("y", float)])


NUM_SAMPLES = 100


@pytest.fixture
def seed():
    """Return a fixed JAX PRNG key (from ``jax.random.key``) for reproducible runs."""
    seed_value = 42
    prng_key = jax.random.key(seed_value)
    return prng_key


def test_mcmc(seed):
    """Run `mcmc` with the counting kernel and check the full trajectory.

    Starting from ``(0.0, -100.0)``, the expected samples advance both
    coordinates by exactly 1.0 per step for ``NUM_SAMPLES`` steps. The
    expected info record holds ``True`` at every step.
    """
    kernel = counting_kernel()

    start = DummyPosition(0.0, -100.0)

    def tlp(x, y):  # noqa: ARG001
        # Constant log-density; the expected trajectory below does not
        # depend on its value.
        return jnp.array(0.0)

    samples, info = mcmc(
        num_samples=NUM_SAMPLES,
        sampling_algorithm=kernel,
        target_density_fn=tlp,
        initial_position=start,
        seed=seed,
    )

    # Step k (1-based) should sit exactly k units past the start position.
    steps = np.arange(1, NUM_SAMPLES + 1, dtype=np.float64)
    expected_samples = DummyPosition(x=start.x + steps, y=start.y + steps)
    expected_info = CountingKernelInfo(np.ones(NUM_SAMPLES, dtype=bool))

    np.testing.assert_equal(samples, expected_samples)
    np.testing.assert_equal(info, expected_info)
