# Dependency imports
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest
from jax import random

from gemlib.distributions.uniform_integer import UniformInteger


class TestUniformInteger(absltest.TestCase):
    """Tests for sampling and log-probability of ``UniformInteger``.

    NOTE(review): the int64 / float64 tests only observe 64-bit dtypes when
    JAX's ``jax_enable_x64`` flag is enabled; presumably that is configured
    outside this file (e.g. conftest or environment) — confirm.
    """

    def setUp(self):
        super().setUp()
        # Fixed PRNG key so sampling tests are reproducible run to run.
        self.rng = random.key(10402302)
        # 3x3 integer values, all inside [low, high), used for log_prob.
        self.fixture = jnp.array([[8, 4, 8], [2, 7, 9], [6, 0, 9]])
        self.low = 0
        self.high = 10

    def _expected_total_log_prob(self):
        """Return the summed log_prob of ``self.fixture``.

        Each in-support value has probability 1 / (high - low), so the total
        is ``fixture.size * -log(high - low)`` (9 * -log(10) ~= -20.723265).
        """
        return self.fixture.size * -np.log(self.high - self.low)

    def test_sample_n_int32(self):
        """Python-int bounds give int32 samples of the requested shape."""
        key = random.fold_in(self.rng, 1)
        dist = UniformInteger(self.low, self.high)
        samples = dist.sample([3, 3], seed=key)
        self.assertEqual(samples.dtype, jnp.int32)
        self.assertEqual(samples.shape, (3, 3))
        in_support = (samples >= self.low) & (samples < self.high)
        self.assertTrue(jnp.all(in_support))

    def test_sample_n_int64(self):
        """np.int64 bounds give int64 samples of the requested shape."""
        key = random.fold_in(self.rng, 2)
        dist = UniformInteger(np.int64(self.low), np.int64(self.high))
        samples = dist.sample([3, 3], seed=key)
        self.assertEqual(samples.dtype, jnp.int64)
        self.assertEqual(samples.shape, (3, 3))
        in_support = (samples >= self.low) & (samples < self.high)
        self.assertTrue(jnp.all(in_support))

    def test_log_prob_float32(self):
        """log_prob with the default float dtype returns float32 values."""
        dist = UniformInteger(self.low, self.high)
        lp = dist.log_prob(self.fixture)
        self.assertEqual(lp.shape, (3, 3))
        self.assertEqual(lp.dtype, jnp.float32)
        np.testing.assert_almost_equal(
            jnp.sum(lp), self._expected_total_log_prob(), decimal=5
        )

    def test_log_prob_float64(self):
        """log_prob with ``float_dtype=jnp.float64`` returns float64 values."""
        dist = UniformInteger(
            np.int64(self.low), np.int64(self.high), float_dtype=jnp.float64
        )
        lp = dist.log_prob(self.fixture)
        # Mirror the float32 test: the output shape must match the input.
        self.assertEqual(lp.shape, (3, 3))
        self.assertEqual(lp.dtype, jnp.float64)
        np.testing.assert_almost_equal(
            jnp.sum(lp), self._expected_total_log_prob(), decimal=5
        )


if __name__ == "__main__":
    # Allow running this module directly as a script via absl's test runner.
    absltest.main()
