"""The UniformInteger distribution class"""

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax.internal import (
    dtype_util,
    parameter_properties,
    samplers,
)

from gemlib.prng_util import sanitize_key

# Short alias for the TFP (JAX substrate) distributions namespace; the class
# below uses it for its base class and reparameterization constant.
tfd = tfp.distributions


class UniformInteger(tfd.Distribution):
    """Discrete uniform distribution over the integers `low, ..., high - 1`.

    The support is the half-open interval `[low, high)`: `high` itself is
    never sampled and has probability zero.  Every integer in the support has
    probability `1 / (high - low)`.
    """

    def __init__(
        self,
        low,
        high,
        validate_args=False,
        allow_nan_stats=True,
        float_dtype=jnp.float32,
        name="UniformInteger",
    ):
        """Integer uniform distribution.

        The integer dtype of samples is not a constructor argument: it is
        inferred from `low` and `high`, falling back to `jnp.int32`.

        Args:
        ----
          low: Integer tensor, inclusive lower boundary of the output
            interval.  Must have `low < high`.
          high: Integer tensor, _exclusive_ upper boundary of the output
            interval.  Must have `low < high`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
           statistics (e.g., mean, mode, variance) use the value "`NaN`" to
           indicate the result is undefined. When `False`, an exception is
           raised if one or more of the statistic's batch members are undefined.
          float_dtype: returned float dtype of (log) probability.
          name: Python `str` name prefixed to Ops created by this class.

        Example 1: sampling
        ```python
        import jax.numpy as jnp
        import jax
        from gemlib.distributions.uniform_integer import UniformInteger

        key = jax.random.key(10402302)
        X = UniformInteger(0, 10)
        x = X.sample([3, 3], seed=key)
        print("samples:", x)  # a 3x3 array of integers in {0, ..., 9}
        ```

        Example 2: log probability
        ```python
        import jax.numpy as jnp
        from gemlib.distributions.uniform_integer import UniformInteger

        X = UniformInteger(0, 10)
        lp = X.log_prob(jnp.array([[8, 4, 8], [2, 7, 9], [6, 0, 9]]))
        total_lp = jnp.round(jnp.sum(lp) * 1e5) / 1e5
        print("total lp:", total_lp, "= -20.72327")
        ```

        Raises:
        ------
          AssertionError: if `validate_args` is truthy and `low < high` does
            not hold for every batch member.

        """
        # Must stay the first statement: it snapshots exactly the constructor
        # arguments, which are handed to the base class as `parameters`.
        parameters = dict(locals())
        dtype = dtype_util.common_dtype([low, high], jnp.int32)
        self._float_dtype = float_dtype
        self._low = jnp.asarray(low, dtype=dtype)
        self._high = jnp.asarray(high, dtype=dtype)

        # NOTE(review): samples are integer-valued, so it is worth confirming
        # that FULLY_REPARAMETERIZED (rather than NOT_REPARAMETERIZED) is
        # intended.  Left unchanged because callers may inspect the attribute.
        super().__init__(
            dtype=dtype,
            reparameterization_type=tfd.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
        # Raise explicitly instead of using a bare `assert`, which would be
        # stripped under `python -O`.  The exception type is kept for callers
        # that already catch AssertionError.
        if validate_args and not jnp.all(self._low < self._high):
            raise AssertionError("Condition low < high failed")

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):  # noqa: ARG003
        """Describe `low` and `high`; neither has a constraining bijector."""
        return {
            "low": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
            ),
            "high": parameter_properties.ParameterProperties(
                default_constraining_bijector_fn=parameter_properties.BIJECTOR_NOT_IMPLEMENTED,
            ),
        }

    @property
    def low(self):
        """Inclusive lower boundary of the output interval."""
        return self._low

    @property
    def high(self):
        """Exclusive upper boundary of the output interval."""
        return self._high

    @property
    def float_dtype(self):
        """Floating-point dtype used for probabilities and log probabilities."""
        return self._float_dtype

    def _event_shape_tensor(self):
        """Return the (empty) event shape as an int32 array: events are scalar."""
        return jnp.array([], dtype=jnp.int32)

    def _event_shape(self):
        """Return the static event shape, `()`, since events are scalar."""
        return ()

    def _sample_n(self, n, seed=None):
        """Draw `n` samples for every batch member.

        Args:
        ----
          n: number of independent draws per batch member.
          seed: PRNG seed, normalised by `sanitize_key` before use.

        Returns:
        -------
          Array of shape `(n, *batch_shape)` with the dtype of `low`.

        """
        seed = sanitize_key(seed)
        low = self.low
        high = self.high
        # Unpack rather than `(n,) + batch_shape`: tuple `+` only concatenates
        # when the right operand is itself a tuple, whereas an ndarray-valued
        # batch shape would be added element-wise and give the wrong shape.
        batch_shape = self._batch_shape_tensor(low=low, high=high)
        shape = (n, *batch_shape)

        samples = samplers.uniform(shape=shape, dtype=jnp.float32, seed=seed)
        # Scale the float draws by the width of the interval and floor them to
        # obtain integer offsets from `low`.
        return low + jnp.floor((high - low) * samples).astype(low.dtype)

    def _prob(self, x):
        """Probability mass at `x`.

        Returns `1 / (high - low)` for `low <= x < high`, zero outside that
        interval, and NaN wherever `x` is NaN.  All arithmetic is carried out
        in `float_dtype`.
        """
        low = jnp.asarray(self.low, self.float_dtype)
        high = jnp.asarray(self.high, self.float_dtype)
        x = jnp.asarray(x, dtype=self.float_dtype)

        outside_support = (x < low) | (x >= high)
        mass = jnp.where(
            outside_support,
            jnp.zeros_like(x),
            jnp.ones_like(x) / (high - low),
        )
        # Propagate NaN inputs rather than reporting a probability for them.
        return jnp.where(jnp.isnan(x), x, mass)

    def _log_prob(self, x):
        """Log probability mass at `x` (`-inf` outside the support)."""
        return jnp.log(self._prob(x))
