"""Test the Hypergeometric random variable"""

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import tensorflow_probability.substrates.jax as tfp
from absl.testing import absltest

from gemlib.distributions.hypergeometric import Hypergeometric

tfd = tfp.distributions


class TestHypergeometric(absltest.TestCase):
    """Tests for ``gemlib.distributions.hypergeometric.Hypergeometric``."""

    def setUp(self):
        # Run the base-class setUp first so framework fixtures exist before
        # any per-test state is created.
        super().setUp()
        # NOTE(review): no test in this class reads this RNG; kept in case
        # subclasses or other code rely on the attribute.
        self._rng = np.random.RandomState(5)

    def _assert_sample_mean_matches(self, dtype):
        """Check sample dtype and empirical mean for one floating dtype.

        Builds a ``Hypergeometric`` from the positional arguments
        ``(345, 35, 100)``, each cast to ``dtype``. It then draws a
        ``[1000, 1000]`` batch with a fixed seed and asserts two things:
        the samples keep ``dtype``, and their empirical mean matches the
        distribution's analytic ``mean()``.

        Not prefixed with ``test`` so the runner does not collect it.

        Args:
            dtype: a JAX floating scalar type, e.g. ``jnp.float32``.
        """
        dist = Hypergeometric(dtype(345.0), dtype(35.0), dtype(100.0))
        # A fixed PRNG key keeps this statistical check deterministic.
        samples = dist.sample([1000, 1000], seed=jax.random.key(1))

        self.assertEqual(samples.dtype, dtype)
        np.testing.assert_allclose(
            jnp.mean(samples), dist.mean(), atol=1e-3, rtol=1e-3
        )

    def test_neg_args(self):
        """A negative N with validate_args=True is rejected at construction."""
        # NOTE(review): the concrete exception type raised by the validation
        # is not visible here, so match on the message only.
        with self.assertRaisesRegex(Exception, "N must be non-negative"):
            _ = Hypergeometric(N=-3, K=1, n=2, validate_args=True)

    def test_sample_n_float32(self):
        """float32 arguments yield float32 samples with the expected mean."""
        self._assert_sample_mean_matches(jnp.float32)

    def test_sample_n_float64(self):
        """float64 arguments yield float64 samples with the expected mean."""
        # Unless jax_enable_x64 is set, JAX canonicalizes float64 to float32,
        # so the dtype assertion could never hold. Skip instead of failing.
        if jax.dtypes.canonicalize_dtype(np.float64) != np.dtype(np.float64):
            self.skipTest("float64 requires jax_enable_x64")
        self._assert_sample_mean_matches(jnp.float64)


if __name__ == "__main__":
    # Run this module's tests through absltest when executed directly.
    absltest.main()
