"""Brownian motion as a distribution"""

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax.internal import (
    distribution_util as dist_util,
)
from tensorflow_probability.substrates.jax.internal import (
    dtype_util,
)
from tensorflow_probability.substrates.jax.internal.tensor_util import (
    convert_nonref_to_tensor,
)

# Short alias for the JAX-substrate TFP distributions namespace used below.
tfd = tfp.distributions


class BrownianMotion(tfd.Distribution):
    """Brownian motion observed on a fixed grid of times.

    The process starts at ``x0`` at time ``index_points[..., 0]``. The event is
    its value at every later index point, so the event shape is ``[N - 1]``
    for ``N = index_points.shape[-1]``. Increments between consecutive index
    points are independent Gaussians with mean 0 and standard deviation
    ``scale * sqrt(dt)``.

    The batch shape is the broadcast of ``x0.shape``, ``scale.shape`` and
    ``index_points.shape[:-1]``.

    Args:
        index_points: Observation times, shape ``[..., N]``. Assumed to be
            increasing; this is not validated (a decreasing pair makes
            ``sqrt(dt)`` NaN).
        x0: Value of the path at ``index_points[..., 0]``.
        scale: Diffusion coefficient (increment std per unit sqrt-time).
        validate_args: Forwarded to the base class and the increment
            distribution.
        allow_nan_stats: Forwarded to the base class and the increment
            distribution.
        name: Name of the distribution.
    """

    def __init__(
        self,
        index_points,
        x0=0.0,
        scale=1.0,
        validate_args=False,
        allow_nan_stats=True,
        name="BrownianMotion",
    ):
        parameters = dict(locals())
        dtype = dtype_util.common_dtype(
            [index_points, x0, scale],
            dtype_hint=jnp.asarray(index_points).dtype,
        )
        self._x0 = convert_nonref_to_tensor(x0, dtype_hint=dtype)
        self._index_points = convert_nonref_to_tensor(
            index_points, dtype_hint=dtype
        )
        self._scale = convert_nonref_to_tensor(scale, dtype_hint=dtype)

        # Std of each increment. `scale` gets a trailing axis so that its
        # batch dims align with the batch dims of `index_points`, not with
        # the time axis.
        dt = self._index_points[..., 1:] - self._index_points[..., :-1]
        scale_diag = jnp.sqrt(dt).astype(dtype) * self._scale[..., jnp.newaxis]
        # A zero `loc` of the fully broadcast shape gives the increment
        # distribution the complete batch shape, including that of `x0`, so
        # samples come out as [n] + batch_shape + event_shape.
        loc = jnp.zeros(
            jnp.broadcast_shapes(scale_diag.shape, jnp.shape(self._x0) + (1,)),
            dtype=dtype,
        )
        self._increments = tfd.MultivariateNormalDiag(
            loc=loc,
            scale_diag=scale_diag,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="bm_increments",
        )  # independent increments

        with jax.named_scope(name):
            super().__init__(
                dtype=dtype,
                reparameterization_type=tfd.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    def _batch_shape(self):
        # Broadcast of x0, scale and index_points batch dims (see __init__).
        return self._increments.batch_shape

    def _event_shape(self):
        return (self._index_points.shape[-1] - 1,)

    def _sample_n(self, n, seed=None):
        """Draw `n` paths as ``x0`` plus the running sum of the increments."""
        increments = self._increments.sample(n, seed=seed)
        # Trailing axis on x0 so it broadcasts over time, not over batch.
        return self._x0[..., jnp.newaxis] + jnp.cumsum(increments, axis=-1)

    def _log_prob(self, x):
        """Log-density of a path: sum of independent increment log-densities."""
        x0 = jnp.asarray(self._x0, dtype=x.dtype)[..., jnp.newaxis]
        # Broadcast `x` against x0's batch dims so a batched x0 can be
        # prepended (dist_util.pad only accepts a scalar fill value).
        x = jnp.broadcast_to(x, jnp.broadcast_shapes(x.shape, x0.shape))
        start = jnp.broadcast_to(x0, x.shape[:-1] + (1,))
        path = jnp.concatenate([start, x], axis=-1)
        diff = path[..., 1:] - path[..., :-1]
        return self._increments.log_prob(diff)


class BrownianBridge(tfd.Distribution):
    """Brownian bridge pinned at both ends of a fixed grid of times.

    This is a Brownian motion with diffusion ``scale`` that starts at ``x0``
    at time ``index_points[..., 0]`` and is conditioned to equal ``x1`` at
    ``index_points[..., -1]``. The event is the path at the interior index
    points, so the event shape is ``[N - 2]`` for
    ``N = index_points.shape[-1]``.

    The batch shape is the broadcast of ``x0.shape``, ``x1.shape``,
    ``scale.shape`` and ``index_points.shape[:-1]``.

    Args:
        index_points: Observation times, shape ``[..., N]`` with ``N >= 2``.
            Assumed to be increasing; this is not validated.
        x0: Value of the path at ``index_points[..., 0]``.
        x1: Value the path is pinned to at ``index_points[..., -1]``.
        scale: Diffusion coefficient (increment std per unit sqrt-time).
        validate_args: Forwarded to the base class and the internal
            distributions.
        allow_nan_stats: Forwarded to the base class and the internal
            distributions.
        name: Name of the distribution.
    """

    def __init__(
        self,
        index_points,
        x0=0.0,
        x1=0.0,
        scale=1.0,
        validate_args=False,
        allow_nan_stats=True,
        name="BrownianBridge",
    ):
        parameters = dict(locals())
        # Same dtype resolution as BrownianMotion: one common dtype for all
        # inputs, rather than whatever x0 happens to be.
        dtype = dtype_util.common_dtype(
            [index_points, x0, x1, scale],
            dtype_hint=jnp.asarray(index_points).dtype,
        )
        self._index_points = convert_nonref_to_tensor(
            index_points, dtype_hint=dtype
        )
        self._x0 = convert_nonref_to_tensor(x0, dtype_hint=dtype)
        self._x1 = convert_nonref_to_tensor(x1, dtype_hint=dtype)
        self._scale = convert_nonref_to_tensor(scale, dtype_hint=dtype)

        t = self._index_points
        # Std of each free-motion increment. `scale` gets a trailing axis so
        # its batch dims align with the batch, not the time axis.
        scale_diag = (
            jnp.sqrt(t[..., 1:] - t[..., :-1]).astype(dtype)
            * self._scale[..., jnp.newaxis]
        )
        # A zero loc of the fully broadcast shape carries the complete batch
        # shape (including the batch dims of x0 and x1).
        loc = jnp.zeros(
            jnp.broadcast_shapes(
                scale_diag.shape,
                jnp.shape(self._x0) + (1,),
                jnp.shape(self._x1) + (1,),
            ),
            dtype=dtype,
        )
        # Increments of the *unpinned* motion over every interval, including
        # the last one that ends at the pinned value x1.
        self._increments = tfd.MultivariateNormalDiag(
            loc=loc,
            scale_diag=scale_diag,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="bb_increments",
        )
        # Marginal law of the unpinned motion at the final time. Dividing the
        # joint path density by it yields the bridge (conditional) density.
        self._endpoint = tfd.Normal(
            loc=self._x0,
            scale=self._scale * jnp.sqrt(t[..., -1] - t[..., 0]).astype(dtype),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="bb_endpoint",
        )

        with jax.named_scope(name):
            super().__init__(
                dtype=dtype,
                reparameterization_type=tfd.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )

    def _batch_shape(self):
        # Broadcast of x0, x1, scale and index_points batch dims.
        return self._increments.batch_shape

    def _event_shape(self):
        return (self._index_points.shape[-1] - 2,)

    def _sample_n(self, n, seed=None):
        """Sample by re-levelling free Brownian paths.

        A free path ``W`` with ``W(t_0) = 0`` is drawn. At each interior time
        ``t`` the bridge value is ``x0 + W_t + s_t * (x1 - x0 - W_T)`` with
        ``s_t = (t - t_0) / (T - t_0)``: the linear correction that moves W's
        endpoints onto ``(x0, x1)``.
        """
        w = jnp.cumsum(self._increments.sample(n, seed=seed), axis=-1)
        t = self._index_points
        # Closed-form linear interpolation; unlike interp_regular_1d_grid it
        # does not require scalar grid endpoints, so batched index_points work.
        frac = (t[..., 1:-1] - t[..., :1]) / (t[..., -1:] - t[..., :1])
        x0 = self._x0[..., jnp.newaxis]
        x1 = self._x1[..., jnp.newaxis]
        return x0 + w[..., :-1] + frac * (x1 - x0 - w[..., -1:])

    def _log_prob(self, x):
        """Conditional log-density ``log p(x, X_T = x1) - log p(X_T = x1)``."""
        x0 = jnp.asarray(self._x0, dtype=x.dtype)[..., jnp.newaxis]
        x1 = jnp.asarray(self._x1, dtype=x.dtype)[..., jnp.newaxis]
        lead = jnp.broadcast_shapes(x.shape[:-1], x0.shape[:-1], x1.shape[:-1])
        # Pin both ends explicitly; dist_util.pad only supports scalar fills.
        path = jnp.concatenate(
            [
                jnp.broadcast_to(x0, lead + (1,)),
                jnp.broadcast_to(x, lead + x.shape[-1:]),
                jnp.broadcast_to(x1, lead + (1,)),
            ],
            axis=-1,
        )
        joint = self._increments.log_prob(path[..., 1:] - path[..., :-1])
        # Normalize by the free endpoint density so the result integrates to 1.
        return joint - self._endpoint.log_prob(self._x1)
