"""Hypergeometric sampling algorithm"""

# ruff: noqa: N803
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax.internal import (
    dtype_util,
    samplers,
)

tfd = tfp.distributions
ps = tfp.internal.prefer_static
brs = tfp.internal.batched_rejection_sampler


def sample_hypergeometric(num_samples, N, K, n, seed=None):
    """Draw hypergeometric samples with a batched rejection sampler.

    The target distribution counts successes in `n` draws without
    replacement from a population of size `N` containing `K` success states
    (pmf proportional to `1 / (x! (K-x)! (n-x)! (N-K-n+x)!)`).  Proposals are
    generated as `a + h * (V - 0.5) / (1 - U)` from two uniforms and accepted
    with a log-density test.
    NOTE(review): this looks like Stadlober's HRUA ratio-of-uniforms scheme
    -- confirm against the reference before relying on the name.

    Args:
        num_samples: Number of independent draws per parameter set; becomes
            the leading dimension of the result.
        N: Population size.  Broadcast against `K` and `n`.
        K: Number of success states in the population.
        n: Number of draws.
        seed: Seed forwarded to the TFP samplers / rejection loop.

    Returns:
        Array of shape `[num_samples] + broadcast_shape(N, K, n)` holding the
        sampled counts in the common dtype of the parameters (float32 when no
        dtype can be inferred).  Entries whose parameters fail validation
        (`N < 1`, `K` or `n` outside `[0, N]`, or NaN) are NaN.

    Note:
        Parameters are not checked for being integer-valued, and are assumed
        to be floating point -- TODO confirm behaviour for integer dtypes.
    """
    dtype = dtype_util.common_dtype([N, K, n], jnp.float32)
    N = jnp.asarray(N, dtype)
    K = jnp.asarray(K, dtype)
    n = jnp.asarray(n, dtype)

    # Negative K or n must be rejected here as well: they can turn the setup
    # constants below into NaN (e.g. sqrt of a negative number in `c`), after
    # which no proposal is ever accepted and the rejection loop never ends.
    good_params_mask = (
        (N >= 1.0) & (K >= 0.0) & (N >= K) & (n >= 0.0) & (n <= N)
    )
    # Substitute benign placeholder parameters for invalid entries so the
    # sampler terminates for them; their outputs are replaced by NaN below.
    N = jnp.where(good_params_mask, N, 100.0)
    K = jnp.where(good_params_mask, K, 50.0)
    n = jnp.where(good_params_mask, n, 50.0)
    sample_shape = ps.concat(
        [
            [num_samples],
            ps.broadcast_shape(
                ps.broadcast_shape(ps.shape(N), ps.shape(K)), ps.shape(n)
            ),
        ],
        axis=0,
    )

    # First transform N, K, n such that
    # N / 2 >= K, N / 2 >= n
    # The samples are mapped back to the original parametrisation at the end.
    is_k_small = 0.5 * N >= K
    is_n_small = n <= 0.5 * N
    previous_K = K
    previous_n = n
    K = jnp.where(is_k_small, K, N - K)
    n = jnp.where(is_n_small, n, N - n)

    # TODO: Can we write this in a more numerically stable way?
    def _log_hypergeometric_coeff(x):
        """Return log(x! (K-x)! (n-x)! (N-K-n+x)!), i.e. -log pmf(x) + const."""
        return (
            jax.lax.lgamma(x + 1.0)
            + jax.lax.lgamma(K - x + 1.0)
            + jax.lax.lgamma(n - x + 1.0)
            + jax.lax.lgamma(N - K - n + x + 1.0)
        )

    # Setup constants of the proposal.
    p = K / N
    q = 1 - p
    # Centre of the proposal (mean + 1/2) and its scale.
    a = n * p + 0.5
    c = jnp.sqrt(2.0 * a * q * (1.0 - n / N))
    # Mode of the (reduced) distribution and its log coefficient.
    k = jnp.floor((n + 1) * (K + 1) / (N + 2))
    g = _log_hypergeometric_coeff(k)
    # Candidate support point used to size the proposal; bumped by one when
    # the inequality below holds.
    diff = jnp.floor(a - c)
    x = (a - diff - 1) / (a - diff)
    diff = jnp.where(
        (n - diff) * (p - diff / N) * jnp.square(x)
        > (diff + 1.0) * (q - (n - diff - 1) / N),
        diff + 1.0,
        diff,
    )
    # TODO: Can we write this difference of lgammas more numerically stably?
    h = (a - diff) * jnp.exp(
        0.5 * (g - _log_hypergeometric_coeff(diff)) + jnp.log(2.0)
    )
    # Exclusive upper bound on accepted proposals.
    b = jnp.minimum(jnp.minimum(n, K) + 1, jnp.floor(a + 5 * c))

    def generate_and_test_samples(seed):
        """Propose one batch of samples and flag which ones are accepted."""
        v_seed, u_seed = samplers.split_seed(seed)
        U = samplers.uniform(sample_shape, dtype=dtype, seed=u_seed)
        V = samplers.uniform(sample_shape, dtype=dtype, seed=v_seed)
        # Guard against 0: `1 - U` (and `log1p(-U)` below) is used instead of
        # `U` so that a uniform draw of exactly 0 neither divides by zero nor
        # takes log(0).
        X = a + h * (V - 0.5) / (1.0 - U)
        samples = jnp.floor(X)
        good_sample_mask = (samples >= 0.0) & (samples < b)
        # log(pmf(samples) / pmf(mode)).
        T = g - _log_hypergeometric_coeff(samples)
        # Uses slow pass since we are trying to do this in a vectorized way.
        good_sample_mask = good_sample_mask & (2 * jnp.log1p(-U) <= T)
        return samples, good_sample_mask

    samples = brs.batched_las_vegas_algorithm(
        generate_and_test_samples, seed=seed
    )[0]

    # Now transform the samples depending on if we constrained K and / or n:
    #   K reflected only:  x = n - x'
    #   n reflected only:  x = K - x'
    #   both reflected:    x = x' + K + n - N
    samples = jnp.where(
        ~is_k_small & ~is_n_small,
        samples + previous_K + previous_n - N,
        jnp.where(
            ~is_k_small,
            previous_n - samples,
            jnp.where(~is_n_small, previous_K - samples, samples),
        ),
    )
    # Entries with invalid parameters were sampled from placeholders; mark
    # them as NaN.
    samples = jnp.where(good_params_mask, samples, np.nan)
    return samples
