"""Unit tests for MwgStep"""

from collections.abc import Callable
from typing import NamedTuple

import jax
import pytest

from gemlib.mcmc.mwg_step import MwgStep, _make_position_projector


class SamplingAlgorithmMock(NamedTuple):
    """Minimal stand-in for a sampling algorithm: an `init`/`step` pair."""

    # Built in this file from a function taking (target_log_prob_fn, position).
    init: Callable
    # Built in this file from a function taking
    # (target_log_prob_fn, chain_and_kernel_state, seed).
    step: Callable


class ChainStateMock(NamedTuple):
    """Stand-in chain state produced by the mock kernel's `init` function."""

    # The position handed to the mock kernel, stored unchanged.
    position: NamedTuple
    # The mock `init_fn` in this file always sets this to 0.0.
    log_density: float
    # The mock `init_fn` in this file fills this with an empty tuple, despite
    # the `float` annotation.
    log_density_grad: float


def structure_assertion(expected_structure):
    """Build a mock sampling algorithm that checks position structure.

    Both the `init` and `step` callbacks take the tree structure of the
    position they receive, call its `from_iterable_tree` method on
    `expected_structure`, and assert on the result.

    NOTE(review): the assertion tests the truthiness of whatever
    `from_iterable_tree` returns; a structural mismatch presumably surfaces
    as an exception raised by that call rather than as a falsy value.
    Confirm against the JAX pytree API before relying on it.

    Args:
        expected_structure: example pytree the position is checked against.

    Returns:
        A `SamplingAlgorithmMock` holding the checking `init` and `step`
        functions.
    """

    def _check_position(position):
        # Shared by both callbacks so the check is written only once.
        treedef = jax.tree.structure(position)
        assert treedef.from_iterable_tree(expected_structure)

    def init_fn(_target_log_prob_fn, position):
        _check_position(position)
        chain_state = ChainStateMock(
            position=position, log_density=0.0, log_density_grad=()
        )
        # The mock carries no kernel state, hence the empty tuple.
        return chain_state, ()

    def step_fn(_target_log_prob_fn, chain_and_kernel_state, _seed):
        chain_state, _kernel_state = chain_and_kernel_state
        _check_position(chain_state.position)
        # The state is passed through untouched; the info slot is empty.
        return chain_and_kernel_state, ()

    return SamplingAlgorithmMock(init=init_fn, step=step_fn)


class Position(NamedTuple):
    """Three-component position used as the chain position in these tests."""

    alpha: float
    beta: float
    gamma: float


@pytest.fixture
def position() -> Position:
    """Provide a `Position` whose three components are all 0.1."""
    initial_value = 0.1
    return Position(
        alpha=initial_value, beta=initial_value, gamma=initial_value
    )


def test_project_position(position):
    """Check how the position projector splits a position.

    The projector built by `_make_position_projector` is expected to return
    the targeted components as a tuple and the remaining components as a
    dict keyed by name.
    """
    # A single name given as a bare string yields a singleton tuple.
    project = _make_position_projector("alpha")
    target, target_compl = project(position)
    assert isinstance(target, tuple)
    assert isinstance(target_compl, dict)
    assert len(target) == 1
    assert set(target_compl) == {"beta", "gamma"}

    # A list of two names yields a 2-tuple and leaves only the third name
    # in the complement.
    project = _make_position_projector(["alpha", "beta"])
    target, target_compl = project(position)
    assert isinstance(target, tuple)
    assert isinstance(target_compl, dict)
    assert len(target) == 2
    assert len(target_compl) == 1
    assert set(target_compl) == {"gamma"}


@pytest.mark.parametrize(
    "target_names,shape_example",
    [("alpha", 1.0), (["alpha"], (1.0,)), (["alpha", "beta"], (1.0, 2.0))],
)
def test_mwg_state_shape(position, shape_example, target_names):
    """Run `MwgStep` init and step around a structure-checking mock kernel.

    The wrapped mock kernel checks the position it is given against
    `shape_example`, so the test fails if either `init` or `step` raises.
    """

    def log_prob(_alpha, _beta, _gamma):
        # Constant log density; the mock kernel never evaluates it.
        return 1.0

    kernel = MwgStep(
        structure_assertion(shape_example), target_names=target_names
    )
    cks = kernel.init(log_prob, position)
    kernel.step(log_prob, cks, seed=None)
