"""Test the UniformKCategorical distribution"""

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from gemlib.distributions.kcategorical import UniformKCategorical


@pytest.fixture
def seed():
    """Provide a fixed JAX PRNG key so that sampling tests are repeatable."""
    prng_key = jax.random.key(10402302)
    return prng_key


@pytest.fixture
def mask():
    """Provide a length-7 boolean mask that is True at positions 0, 3, 4, 6."""
    flags = [True, False, False, True, True, False, True]
    return np.array(flags)


def test_sample(mask, seed):
    """Draw one batch of k indices and check its shape, dtype and support.

    A single draw (``sample_shape=(1,)``) with ``k=3`` is expected to have
    shape ``(1, 3)`` and dtype int32, and every sampled index must refer to a
    position where ``mask`` is True.
    """
    k = 3
    sample_shape = (1,)

    X = UniformKCategorical(k=k, mask=mask)
    x = X.sample(seed=seed, sample_shape=sample_shape)

    assert x.shape == (*sample_shape, k)
    assert x.dtype == jnp.int32

    # Derive the admissible indices from the mask fixture instead of
    # hard-coding them, so this check cannot drift if the fixture changes.
    valid_indices = jnp.flatnonzero(mask)
    assert jnp.all(jnp.isin(x, valid_indices))


def test_log_prob_float32(mask):
    """Check ``log_prob`` value and dtype for one realisation (float32)."""
    realisation = jnp.array([[0, 6, 3]], dtype=jnp.int32)
    k = realisation.shape[-1]

    dist = UniformKCategorical(
        k=k,
        mask=mask,
        float_dtype=jnp.float32,
    )
    log_prob = dist.log_prob(realisation)

    # -1.3862944 is log(1/4) to float32 precision.
    expected = -1.3862944
    np.testing.assert_almost_equal(log_prob, expected, decimal=5)
    assert log_prob.dtype == jnp.float32


def test_log_prob_float64(mask):
    """Check ``log_prob`` value and dtype for one realisation (float64).

    NOTE(review): a float64 result requires JAX's 64-bit mode; presumably it
    is enabled by the test configuration outside this file -- confirm.
    """
    realisation = jnp.array([[0, 6, 3]], dtype=jnp.int32)
    k = realisation.shape[-1]

    dist = UniformKCategorical(
        k=k,
        mask=mask,
        float_dtype=jnp.float64,
    )
    log_prob = dist.log_prob(realisation)

    # -1.3862944 is log(1/4) rounded to 7 decimals; compared to 5 below.
    expected = -1.3862944
    np.testing.assert_almost_equal(log_prob, expected, decimal=5)
    assert log_prob.dtype == jnp.float64
