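"""Apply an inner kernel's step a fixed number of times per step.

`multi_scan` wraps a `SamplingAlgorithm` (the "MultiScanKernel") so that
each call of the resulting `step` applies the inner algorithm's `step`
``num_updates`` times in a row.
"""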

from functools import partial
from typing import NamedTuple

import jax

from .sampling_algorithm import LogProbFnType, Position, SamplingAlgorithm

# Public API: only the ``multi_scan`` factory is exported by ``import *``.
__all__ = ["multi_scan"]


class MultiScanKernelState(NamedTuple):
    """Kernel state of an algorithm built by :func:`multi_scan`.

    Wraps the inner algorithm's kernel state so it can be threaded through
    successive calls of the wrapping ``step``.
    """

    # Inner algorithm's kernel state: as returned by its ``init`` at
    # start-up, and as left after the last inner update of each ``step``.
    last_results: NamedTuple


class MultiScanKernelInfo(NamedTuple):
    """Info returned by each ``step`` of an algorithm built by :func:`multi_scan`."""

    # Info produced by the inner algorithm's final ``step`` call within one
    # outer step; info from earlier inner updates is not retained.
    last_results: NamedTuple


def multi_scan(
    num_updates: int, sampling_algorithm: SamplingAlgorithm
) -> SamplingAlgorithm:
    """Performs multiple applications of a kernel

    :obj:`sampling_algorithm` is invoked :obj:`num_updates` times
    returning the state and info after the last step.

    Each of the ``num_updates`` inner steps receives its own subkey from
    ``jax.random.split(seed, num_updates)``; the caller's ``seed`` itself
    is never passed to the inner kernel.

    Args:
      num_updates: integer giving the number of updates; must be a static
        Python integer >= 1
      sampling_algorithm: an instance of :obj:`SamplingAlgorithm`

    Returns:
      An instance of :obj:`SamplingAlgorithm` whose kernel state is a
      :obj:`MultiScanKernelState` and whose step info is a
      :obj:`MultiScanKernelInfo`, each wrapping the inner algorithm's
      state/info from the last update.

    Raises:
      ValueError: if ``num_updates`` is less than 1.
    """
    # TODO: a DeprecationWarning recommending ``SamplingAlgorithm.__mul__``
    # instead of ``multi_scan`` was drafted here but left disabled; decide
    # whether to enable it.
    if num_updates < 1:
        # With 0 (or fewer) updates the unrolled first step below would
        # still run, silently performing one update instead of none.
        raise ValueError(f"num_updates must be >= 1, got {num_updates}")

    def init_fn(target_log_prob_fn, position):
        # Delegate to the inner algorithm and wrap its kernel state.
        chain_state, inner_kernel_state = sampling_algorithm.init(
            target_log_prob_fn, position
        )
        return chain_state, MultiScanKernelState(inner_kernel_state)

    def step_fn(
        target_log_prob_fn: LogProbFnType,
        current_state: tuple[Position, MultiScanKernelState],
        seed=None,
    ):
        # One independent subkey per update. The parent ``seed`` is consumed
        # by the split and must not be reused alongside its children.
        seeds = jax.random.split(seed, num=num_updates)
        inner_step = partial(sampling_algorithm.step, target_log_prob_fn)

        chain_state, kernel_state = current_state

        # The first update is unrolled so that the loop carry starts with an
        # info value of the concrete structure the inner kernel produces.
        first_state, first_info = inner_step(
            (chain_state, kernel_state.last_results), seeds[0]
        )

        def body(i, carry):
            # Only the latest state/info is kept; the previous info is
            # discarded.
            state, _ = carry
            new_state, new_info = inner_step(state, seeds[i])
            return new_state, new_info

        # Static Python bounds let fori_loop lower to ``lax.scan``, which
        # (unlike ``while_loop``) supports reverse-mode differentiation.
        # When num_updates == 1 the loop runs zero iterations.
        last_state, last_info = jax.lax.fori_loop(
            1, num_updates, body, (first_state, first_info)
        )

        return (
            (last_state[0], MultiScanKernelState(last_state[1])),
            MultiScanKernelInfo(last_info),
        )

    return SamplingAlgorithm(init_fn, step_fn)
