"""Galilean coordinate transformations."""

__all__ = ["VelocityBoost"]


from dataclasses import replace
from typing import Any, Literal, cast, final

import equinox as eqx
import jax.numpy as jnp
import wadler_lindig as wl
from plum import convert, dispatch

import quaxed.numpy as jnp
import unxt as u
from dataclassish.converters import Unless

from .base import AbstractOperator
from .identity import Identity
from coordinax._src.vectors.base_pos import AbstractPos
from coordinax._src.vectors.base_vel import AbstractVel
from coordinax._src.vectors.d3 import CartesianVel3D


@final
class VelocityBoost(AbstractOperator):
    r"""Operator for an instantaneous velocity boost.

    The operation is given by:

    .. math::

        (\mathbf{v}) \mapsto (\mathbf{v} + \Delta\mathbf{v})

    where :math:`\Delta\mathbf{v}` is the boost velocity. Positions are
    returned unchanged by this operator.

    Parameters
    ----------
    velocity
        The boost velocity. Unless given a `coordinax.AbstractVel`, this
        parameter uses :meth:`coordinax.vecs.CartesianVel3D.from_` to enable a
        variety of more convenient input types.

    Examples
    --------
    >>> import coordinax as cx

    >>> op = cx.ops.VelocityBoost.from_([1.0, 2.0, 3.0], "m/s")
    >>> op
    VelocityBoost(CartesianVel3D( ... ))

    >>> q = cx.CartesianPos3D.from_([0.0, 0.0, 0.0], "m")
    >>> op(q) is q
    True

    >>> p = cx.CartesianVel3D.from_([0., 0, 0], "m/s")
    >>> op(p) == p + op.velocity
    Array(True, dtype=bool)

    """

    # `Unless` leaves an existing `AbstractVel` untouched; anything else is
    # routed through `CartesianVel3D.from_`.
    velocity: AbstractVel = eqx.field(
        converter=Unless(AbstractVel, CartesianVel3D.from_)
    )
    """The boost velocity.

    Unless given a `coordinax.AbstractVel`, this parameter uses
    :meth:`coordinax.CartesianVel3D.from_` to enable a variety of more
    convenient input types.
    """

    # -----------------------------------------------------

    @property
    def is_inertial(self) -> Literal[True]:
        """Galilean boost is an inertial-frame preserving transform.

        Returns
        -------
        Literal[True]
            Always `True`.

        Examples
        --------
        >>> import coordinax as cx

        >>> op = cx.ops.VelocityBoost.from_([1, 2, 3], "m/s")
        >>> op.is_inertial
        True

        """
        return True

    @property
    def inverse(self) -> "VelocityBoost":
        """The inverse of the operator: a boost by the negated velocity.

        Returns
        -------
        VelocityBoost
            A new operator whose ``velocity`` is ``-self.velocity``.

        Examples
        --------
        >>> import coordinax as cx

        >>> op = cx.ops.VelocityBoost.from_([1, 2, 3], "m/s")
        >>> op.inverse
        VelocityBoost(CartesianVel3D( ... ))

        >>> print(op.inverse.velocity)
        <CartesianVel3D: (x, y, z) [m / s]
            [-1 -2 -3]>

        """
        # Delegate to ``__neg__`` so the negation logic lives in one place.
        return -self

    # -----------------------------------------------------

    @AbstractOperator.__call__.dispatch  # type: ignore[misc]
    def __call__(self: "VelocityBoost", p: AbstractVel, /) -> AbstractVel:
        """Apply the boost to a velocity.

        The boost velocity is added to ``p`` with ``+``; no position is
        involved in this overload and no representation conversion is done
        here.

        Parameters
        ----------
        p
            The velocity to boost.
            NOTE(review): presumably ``p`` must be directly addable to
            ``self.velocity`` (e.g. same representation) -- confirm.

        Returns
        -------
        AbstractVel
            The result of ``p + self.velocity``.

        Examples
        --------
        >>> import coordinax as cx

        >>> op = cx.ops.VelocityBoost.from_([1, 2, 3], "m/s")

        >>> p = cx.CartesianVel3D.from_([0, 0, 0], "m/s")
        >>> print(op(p))
        <CartesianVel3D: (x, y, z) [m / s]
            [1 2 3]>

        """
        # `cast` only informs the static type checker; it is a no-op at runtime.
        return cast(AbstractVel, p + self.velocity)

    # -------------------------------------------
    # Arithmetic operations

    def __neg__(self: "VelocityBoost") -> "VelocityBoost":
        """Negate the boost.

        Returns
        -------
        VelocityBoost
            A copy of this operator with ``velocity`` replaced by
            ``-self.velocity``.

        Examples
        --------
        >>> import coordinax as cx

        >>> op = cx.ops.VelocityBoost.from_([1, 0, 0], "m/s")
        >>> print((-op).velocity)
        <CartesianVel3D: (x, y, z) [m / s]
            [-1 0 0]>

        """
        return replace(self, velocity=-self.velocity)

    # -----------------------------------------------------

    def __pdoc__(self, **kwargs: Any) -> wl.AbstractDoc:
        """Return the Wadler-Lindig representation.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``wl.pdoc`` when rendering the velocity.

        Returns
        -------
        wl.AbstractDoc
            The document ``ClassName(<velocity doc>)``.

        """
        return (
            wl.TextDoc(f"{self.__class__.__name__}(")
            + wl.pdoc(self.velocity, **kwargs)
            + wl.TextDoc(")")
        )


# ======================================================================
# More call dispatch


@AbstractOperator.__call__.dispatch
def call(
    self: VelocityBoost, q: AbstractPos, p: AbstractVel, /
) -> tuple[AbstractPos, AbstractVel]:
    r"""Boost a velocity given at a position; the position is untouched.

    The velocity ``p`` is converted to the representation and units of the
    boost velocity (using ``q`` as the base point of the conversion), the
    boost is added, and the sum is converted back to the original
    representation of ``p``.

    Parameters
    ----------
    q
        The position at which ``p`` is defined. Returned as-is.
    p
        The velocity to boost.

    Returns
    -------
    tuple[AbstractPos, AbstractVel]
        ``q`` itself and the boosted velocity, in the representation type of
        the input ``p``.

    Examples
    --------
    >>> import coordinax as cx

    >>> op = cx.ops.VelocityBoost.from_([1, 2, 3], "m/s")

    >>> q = cx.CartesianPos3D.from_([0, 0, 0], "m")
    >>> p = cx.CartesianVel3D.from_([0, 0, 0], "m/s")
    >>> newq, newp = op(q, p)
    >>> print(newq, newp, sep="\n")
    <CartesianPos3D: (x, y, z) [m]
        [0 0 0]>
    <CartesianVel3D: (x, y, z) [m / s]
        [1 2 3]>

    """
    boost = self.velocity
    # Bring ``p`` into the boost's representation and units so they can be added.
    p_like_boost = p.vconvert(type(boost), q).uconvert(boost.units)
    boosted = p_like_boost + boost
    # Hand the result back in the caller's original velocity representation.
    return q, boosted.vconvert(type(p), q)


@AbstractOperator.__call__.dispatch
def call(
    self: VelocityBoost, q: u.Quantity["length"], p: u.Quantity["speed"], /
) -> tuple[u.Quantity["length"], u.Quantity["speed"]]:
    r"""Boost a velocity given as a plain quantity; the position is untouched.

    ``p`` is wrapped with ``CartesianVel3D.from_``, the boost velocity is
    added, and the sum is converted back to a `unxt.Quantity`.

    Parameters
    ----------
    q
        The position. Returned as-is.
    p
        The velocity to boost, passed to ``CartesianVel3D.from_``.

    Returns
    -------
    tuple[Quantity, Quantity]
        ``q`` itself and the boosted velocity as a quantity.

    Examples
    --------
    >>> import unxt as u
    >>> import coordinax as cx

    >>> op = cx.ops.VelocityBoost.from_([1, 2, 3], "m/s")

    >>> q = u.Quantity([0., 0, 0], "m")
    >>> p = u.Quantity([0., 0, 0], "m/s")
    >>> newq, newp = op(q, p)
    >>> (newq, newp)
    (Quantity(Array([0., 0., 0.], dtype=float32), unit='m'),
     Quantity(Array([1., 2., 3.], dtype=float32), unit='m / s'))

    """
    boosted = CartesianVel3D.from_(p) + self.velocity
    return q, convert(boosted, u.Quantity)


@AbstractOperator.__call__.dispatch(precedence=-1)
def call(self: VelocityBoost, q: AbstractPos, /) -> AbstractPos:
    """Apply the boost to a position, which is returned unchanged.

    A velocity boost only shifts velocities, so the position ``q`` is handed
    back as the very same object (no copy is made).

    NOTE(review): this overload is registered with ``precedence=-1``;
    presumably so that other overloads win when dispatch would otherwise be
    ambiguous -- confirm against the other ``AbstractOperator.__call__``
    methods.

    Parameters
    ----------
    q
        The position.

    Returns
    -------
    AbstractPos
        The input ``q`` itself.

    Examples
    --------
    >>> import coordinax as cx

    >>> op = cx.ops.VelocityBoost.from_([1, 2, 3], "m/s")

    >>> q = cx.CartesianPos3D.from_([0, 0, 0], "m")
    >>> op(q) is q
    True

    """
    return q


@AbstractOperator.__call__.dispatch
def call(
    self: VelocityBoost, t: u.Quantity["time"], q: AbstractPos, /
) -> tuple[u.Quantity["time"], AbstractPos]:
    """Apply the boost to a time and position, both returned unchanged.

    A velocity boost only shifts velocities, so neither the time ``t`` nor
    the position ``q`` is modified; the same objects are handed back.

    Parameters
    ----------
    t
        The time.
    q
        The position.

    Returns
    -------
    tuple[Quantity, AbstractPos]
        The inputs ``t`` and ``q`` themselves.

    Examples
    --------
    >>> import unxt as u
    >>> import coordinax as cx

    >>> op = cx.ops.VelocityBoost.from_([1, 2, 3], "m/s")

    >>> q = cx.CartesianPos3D.from_([0, 0, 0], "m")
    >>> t = u.Quantity(1, "s")

    >>> newt, newq = op(t, q)
    >>> newt is t, newq is q
    (True, True)

    """
    return t, q


# ======================================================================
# Simplification


@dispatch
def simplify_op(op: VelocityBoost, /, **kwargs: Any) -> VelocityBoost | Identity:
    """Simplify a boost operator.

    A boost whose velocity components are all (numerically) zero has no
    effect and is replaced by an `Identity`; any other boost is returned
    unchanged.

    Parameters
    ----------
    op
        The boost operator to simplify.
    **kwargs
        Forwarded unchanged to ``jnp.allclose`` for the zero check.

    Returns
    -------
    VelocityBoost | Identity
        ``Identity()`` if every velocity component is close to zero,
        otherwise ``op`` itself.

    Examples
    --------
    >>> import unxt as u
    >>> import coordinax as cx

    An operator with real effect cannot be simplified:

    >>> op = cx.ops.VelocityBoost.from_([1, 0, 0], "m/s")
    >>> cx.ops.simplify_op(op)
    VelocityBoost(CartesianVel3D( ... ))

    An operator with no effect can be simplified:

    >>> op = cx.ops.VelocityBoost.from_([0, 0, 0], "m/s")
    >>> cx.ops.simplify_op(op)
    Identity()

    """
    # Unit-stripped component values of the boost velocity.
    components = convert(op.velocity, u.Quantity).value
    # Compare against zeros shaped like the components themselves rather than
    # a hard-coded 3-vector: ``velocity`` may be any ``AbstractVel``, and a
    # component array that does not broadcast against shape (3,) would
    # otherwise fail inside ``allclose``.
    # NOTE: ``jnp`` here is ``quaxed.numpy``; its import rebinds the earlier
    # ``jax.numpy`` alias of the same name.
    if jnp.allclose(components, jnp.zeros_like(components), **kwargs):
        return Identity()
    return op


@dispatch
def simplify_op(op1: VelocityBoost, op2: VelocityBoost) -> VelocityBoost:
    """Merge two consecutive boosts into one equivalent boost.

    The boost velocities are summed with ``+`` and wrapped in a new
    `VelocityBoost`.

    Parameters
    ----------
    op1, op2
        The boosts to combine.

    Returns
    -------
    VelocityBoost
        A boost by ``op1.velocity + op2.velocity``.

    Examples
    --------
    >>> import unxt as u
    >>> import coordinax.ops as cxo

    >>> op1 = cxo.VelocityBoost.from_([1, 0, 0], "m/s")
    >>> op2 = cxo.VelocityBoost.from_([0, 1, 0], "m/s")

    >>> op3 = cxo.simplify_op(op1, op2)
    >>> op3
    VelocityBoost(CartesianVel3D( ... ))

    >>> op3.velocity == op1.velocity + op2.velocity
    Array(True, dtype=bool)

    """
    total_velocity = op1.velocity + op2.velocity
    return VelocityBoost(total_velocity)
