"""Test modules for multiscan kernel"""

from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np

from .mcmc_sampler import mcmc
from .multi_scan import MultiScanKernelInfo, multi_scan
from .test_util import CountingKernelInfo, counting_kernel


class DummyPosition(NamedTuple):
    """Minimal single-field position container handed to the samplers under test."""

    # Coordinate value; tests construct it from a Python float or a jnp array.
    x: float


def get_seed():
    """Return a JAX PRNG key built from a fixed seed so test runs are reproducible."""
    fixed_seed = 42
    return jax.random.key(fixed_seed)


def test_one_multi_scan():
    """Check that one multi-scan step runs the inner counting kernel N times.

    A single ``sampler.step`` with ``counting_kernel`` as the inner kernel is
    expected to leave the position, the log density and the inner kernel's
    invocation count all equal to ``multi_scan_iterations``. The step should
    also report the last inner result as accepted.
    """
    multi_scan_iterations = 100

    sampler = multi_scan(multi_scan_iterations, counting_kernel())

    def tlp(x):
        return x

    initial_position = DummyPosition(0.0)

    state = sampler.init(tlp, initial_position)
    (chain_state, kernel_state), info = sampler.step(
        target_log_prob_fn=tlp, current_state=state, seed=get_seed()
    )

    # Derive the expected position from the iteration count instead of a
    # literal, so the test stays coherent if the constant above changes.
    np.testing.assert_equal(
        chain_state.position,
        DummyPosition(jnp.asarray(float(multi_scan_iterations))),
    )
    np.testing.assert_equal(
        chain_state.log_density, jnp.asarray(multi_scan_iterations)
    )
    np.testing.assert_equal(
        kernel_state.last_results.invocation, jnp.asarray(multi_scan_iterations)
    )
    np.testing.assert_equal(info.last_results.is_accepted, jnp.asarray(True))


def test_many_multi_scan():
    """Check that ``mcmc`` driving a multi-scan sampler yields evenly spaced samples.

    With ``counting_kernel`` as the inner kernel, the recorded positions are
    expected to be ``N, 2N, ..., num_samples * N``, where
    ``N = multi_scan_iterations``. Every recorded step should report its last
    inner result as accepted.
    """
    num_samples = 5
    multi_scan_iterations = 100

    sampler = multi_scan(multi_scan_iterations, counting_kernel())

    def tlp(x):
        return x

    initial_position = DummyPosition(0.0)

    samples, info = mcmc(
        num_samples=num_samples,
        sampling_algorithm=sampler,
        target_density_fn=tlp,
        initial_position=initial_position,
        seed=get_seed(),
    )

    # Build N, 2N, ..., num_samples * N from an integer range scaled by N.
    # This avoids a hard-coded start value and the length pitfalls of
    # float-stepped np.arange.
    expected_x = multi_scan_iterations * np.arange(
        1, num_samples + 1, dtype=np.float64
    )
    np.testing.assert_equal(samples, DummyPosition(expected_x))
    np.testing.assert_equal(
        info,
        MultiScanKernelInfo(
            last_results=CountingKernelInfo(np.full(num_samples, True))
        ),
    )
