"""Higher-order functions to run MCMC"""

from collections.abc import Callable, Iterable
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

from .sampling_algorithm import SamplingAlgorithm

__all__ = ["mcmc"]


def _split_seed(seed, n):
    """Split `seed` into `n` seeds.

    Delegates to `tfp.random.split_seed` and returns its result unchanged.
    """
    splitter = tfp.random.split_seed
    return splitter(seed, n=n)


def _scan(fn, init, xs):
    """Scan

    This function is equivalent to

    ```
    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
    ```
    """
    return jax.lax.scan(fn, init, xs)


def mcmc(
    num_samples: int,
    sampling_algorithm: SamplingAlgorithm,
    target_density_fn: Callable[[Any, ...], float],
    initial_position: Iterable,
    seed: tuple[int, int],
    kernel_kwargs_fn=lambda _: {},
):
    """Runs an MCMC using `sampling_algorithm`

    Args:
      num_samples: integer giving the number of MCMC steps to run, i.e. the
        number of samples returned.
      sampling_algorithm: an instance of `SamplingAlgorithm`
      target_density_fn: Python callable which takes an argument like
        `current_state` and returns its (possibly unnormalized) log-density
        under the target distribution.
      initial_position: initial state structured tuple.  Each leaf is
        converted to a JAX array before use.
      seed: a pair of scalar ``int`` values.  It is split into `num_samples`
        seeds, one per MCMC step.
      kernel_kwargs_fn: a callable taking the chain position as an argument,
        and returning a dictionary of extra kwargs.  It is called with
        `initial_position` when initialising the sampler, and with the
        current chain position before every step.  The returned kwargs are
        forwarded to `sampling_algorithm.init` and `sampling_algorithm.step`
        respectively.

    Returns:
      A tuple containing samples of the Markov chain and information about the
      behaviour of the sampler(s) (e.g. whether kernels accepted or rejected,
      adaptive covariance matrices, etc).
    """

    initial_position = jax.tree.map(jnp.asarray, initial_position)
    initial_state = sampling_algorithm.init(
        target_density_fn,
        initial_position,
        **kernel_kwargs_fn(initial_position),
    )
    kernel_step_fn = partial(sampling_algorithm.step, target_density_fn)

    def one_step(state, rng_key):
        # `kernel_kwargs_fn` is documented to take the chain *position*, and
        # that is what it receives at initialisation above; pass the position
        # here too rather than the whole sampler state.
        new_state, info = kernel_step_fn(
            state, rng_key, **kernel_kwargs_fn(state[0].position)
        )
        # Carry the full state forward; emit only the position and step info.
        return new_state, (new_state[0].position, info)

    keys = _split_seed(seed, num_samples)

    # The final carried state is discarded; only the stacked per-step
    # outputs are returned.
    _, trace = _scan(one_step, initial_state, keys)

    return trace
