"""Implementation of the Metropolis-within-Gibbs framework."""

from __future__ import annotations

from collections import namedtuple
from collections.abc import Callable

import tensorflow_probability.substrates.jax as tfp

from gemlib.mcmc.mcmc_util import is_list_like
from gemlib.mcmc.sampling_algorithm import (
    ChainAndKernelState,
    ChainState,
    KernelInfo,
    LogProbFnType,
    Position,
    SamplingAlgorithm,
    SeedType,
)

__all__ = ["MwgStep"]

# Module-level alias for TFP's seed-splitting helper. It is not referenced
# elsewhere in this module as shown; presumably kept for importers -- confirm
# before removing.
split_seed = tfp.random.split_seed


def as_list(x):
    """Wrap ``x`` in a one-element list unless it is already list-like.

    List-like inputs (as judged by ``is_list_like``) are returned as-is,
    without copying.
    """
    return x if is_list_like(x) else [x]


def _make_target_type(target_names):
    """Build the container used to present target values to the base kernel.

    If ``target_names`` is list-like, returns a ``namedtuple`` class called
    ``_Target`` with those names as fields. Otherwise returns an identity
    function, so a single target value is passed through unwrapped.
    """
    if not is_list_like(target_names):
        # Single target: no wrapping, the value passes straight through.
        return lambda x: x
    return namedtuple("_Target", target_names)


def _make_position_projector(target_names: str | list[str]):
    """Build a function that splits a position into target and complement.

    Args:
      target_names: a single variable name, or a list of names, to extract.
        A bare string is accepted and treated as a one-element list.

    Returns:
      A function ``fn(position)``. ``position`` must provide ``_asdict()``,
      as a namedtuple does. The function returns ``(target_tuple,
      target_compl_dict)``:

      - ``target_tuple`` holds the values of ``target_names``, in the order
        given by ``target_names``.
      - ``target_compl_dict`` maps every remaining field name to its value,
        in the position's own field order.

      ``fn`` raises ``ValueError`` if any of ``target_names`` is not a field
      of ``position``.
    """
    target_names = as_list(target_names)
    # Build the set once so the complement filter below does O(1) lookups.
    target_name_set = frozenset(target_names)

    def fn(position: Position) -> tuple[tuple, dict]:
        position_dict = position._asdict()

        for name in target_names:
            if name not in position_dict:
                raise ValueError(f"`{name}` is not present in `position`")

        target_tuple = tuple(position_dict[k] for k in target_names)
        target_compl_dict = {
            k: v for k, v in position_dict.items() if k not in target_name_set
        }

        return target_tuple, target_compl_dict

    return fn


class MwgStep:  # pylint: disable=too-few-public-methods
    """A Metropolis-within-Gibbs step.

    Transforms a base kernel to operate on a substate of a Markov chain.
    The base kernel only sees (and updates) the variables named in
    ``target_names``. All other fields of the chain position are held
    fixed while it runs.

    Args:
      sampling_algorithm: a named tuple containing the generic kernel `init`
                        and `step` function.
      target_names: a variable name, or a list of variable names, on which
                        the Metropolis-within-Gibbs step is to operate. A
                        bare string presents the base kernel with the bare
                        value. A list presents it with a namedtuple of the
                        target values.
      kernel_kwargs_fn: a callable taking the full chain position as an
                    argument, and returning a dictionary of extra kwargs to
                    `sampling_algorithm.init` and `sampling_algorithm.step`.

    Returns:
      An instance of SamplingAlgorithm. Because ``__new__`` returns a
      different type, no ``MwgStep`` instance is ever created.

    """

    def __new__(
        cls,
        sampling_algorithm: SamplingAlgorithm,
        target_names: str | list[str],
        kernel_kwargs_fn: Callable[[Position], dict] = lambda _: {},
    ):
        """Create a new Metropolis-within-Gibbs step.

        Returns a ``SamplingAlgorithm(init, step)`` whose functions take
        and return state over the *full* chain position, while delegating
        the actual update to ``sampling_algorithm`` on the target substate.
        """

        target_names_list = as_list(target_names)
        _project_position = _make_position_projector(target_names_list)

        # Container used to present target values to the base kernel:
        # a namedtuple for list-like names, identity for a single name.
        TargetType = _make_target_type(target_names)

        def _make_conditional_tlp(target_log_prob_fn, target_compl):
            """Return the log density as a function of the target vars only.

            The complement variables are held fixed at ``target_compl``.
            The returned function still evaluates the full
            ``target_log_prob_fn``, so its value is the joint log density
            at the stitched-together position.
            """

            def conditional_tlp(*args):
                tlp_kwargs = (
                    dict(zip(target_names_list, args, strict=True))
                    | target_compl
                )
                return target_log_prob_fn(**tlp_kwargs)

            return conditional_tlp

        def init(
            target_log_prob_fn: LogProbFnType,
            initial_position: Position,
        ):
            target, target_compl = _project_position(initial_position)
            conditional_tlp = _make_conditional_tlp(
                target_log_prob_fn, target_compl
            )

            # The base init returns a (chain_state, kernel_state) pair
            # defined over the target substate only.
            init_result = sampling_algorithm.init(
                conditional_tlp,
                TargetType(*target),
                **kernel_kwargs_fn(initial_position),
            )
            sub_chain_state, kernel_state = init_result[0], init_result[1]

            # Re-attach the full position; the density values carry over
            # because conditional_tlp evaluates the joint density.
            chain_state = ChainState(
                position=initial_position,
                log_density=sub_chain_state.log_density,
                log_density_grad=sub_chain_state.log_density_grad,
            )

            return chain_state, kernel_state

        def step(
            target_log_prob_fn: LogProbFnType,
            chain_and_kernel_state: ChainAndKernelState,
            seed: SeedType,
        ) -> tuple[ChainAndKernelState, KernelInfo]:
            chain_state, kernel_state = chain_and_kernel_state

            # Split global state and generate conditional density
            target, target_compl = _project_position(chain_state.position)
            conditional_tlp = _make_conditional_tlp(
                target_log_prob_fn, target_compl
            )

            # Present only the target variables to the base kernel; the
            # cached log_density/grad fields are carried over unchanged.
            chain_substate = chain_state._replace(position=TargetType(*target))

            # Invoke the kernel on the target state
            (new_chain_substate, new_kernel_state), info = (
                sampling_algorithm.step(
                    conditional_tlp,
                    (chain_substate, kernel_state),
                    seed,
                    **kernel_kwargs_fn(chain_state.position),
                )
            )

            # Stitch the global position back together, preserving the
            # original position type and its field order.
            new_target_dict = dict(
                zip(
                    target_names_list,
                    as_list(new_chain_substate.position),
                    strict=True,
                )
            )
            new_position = type(chain_state.position)(
                **(new_target_dict | target_compl)
            )
            new_global_state = new_chain_substate._replace(
                position=new_position
            )

            return (new_global_state, new_kernel_state), info

        return SamplingAlgorithm(init, step)
