# mypy: disable-error-code="operator,dict-item"
from typing import Optional, Union, Annotated

from torch import Tensor
from caskade import forward, Param

from .base import ThinLens, CosmologyType, NameType, ZType
from . import func

# Public API of this module: only the MassSheet lens is exported.
__all__ = ("MassSheet",)


class MassSheet(ThinLens):
    """
    Represents a uniform mass sheet (constant convergence) in a
    gravitational lensing system.

    The sheet is described by a single convergence value ``kappa`` and a
    reference point ``(x0, y0)``. The reference point is used by the
    deflection-angle and potential computations; the convergence does not
    depend on it.

    Attributes
    ----------
    name: string
        Identifier for the lens instance.

    cosmology: Cosmology
        The cosmological model used for lensing calculations.

    z_l: Optional[Union[Tensor, float]]
        The redshift of the lens.

        *Unit: unitless*

    z_s : Optional[Union[Tensor, float]]
        The redshift of the source.

        *Unit: unitless*

    x0: Optional[Union[Tensor, float]]
        x-coordinate of the center of the mass sheet in the lens plane.

        *Unit: arcsec*

    y0: Optional[Union[Tensor, float]]
        y-coordinate of the center of the mass sheet in the lens plane.

        *Unit: arcsec*

    kappa: Optional[Union[Tensor, float]]
        Convergence. Surface density normalized by the critical surface density.

        *Unit: unitless*
    """

    # Placeholder parameter values for this lens type; presumably consumed by
    # the ThinLens base class / tooling when parameters are left unset —
    # NOTE(review): confirm against ThinLens.
    _null_params = {
        "x0": 0.0,
        "y0": 0.0,
        "kappa": 0.1,
    }

    def __init__(
        self,
        cosmology: CosmologyType,
        z_l: ZType = None,
        z_s: ZType = None,
        x0: Annotated[
            Optional[Union[Tensor, float]],
            "x-coordinate of the center of the mass sheet in the lens plane",
            True,
        ] = None,
        y0: Annotated[
            Optional[Union[Tensor, float]],
            "y-coordinate of the center of the mass sheet in the lens plane",
            True,
        ] = None,
        kappa: Annotated[
            Optional[Union[Tensor, float]],
            "Convergence (surface density normalized by the critical density)",
            True,
        ] = None,
        name: NameType = None,
    ):
        """
        Build a mass-sheet lens.

        Parameters
        ----------
        cosmology: CosmologyType
            Cosmology forwarded to the ThinLens base class.
        z_l: ZType
            Lens redshift, forwarded to the ThinLens base class.
        z_s: ZType
            Source redshift, forwarded to the ThinLens base class.
        x0, y0: Optional[Union[Tensor, float]]
            Center of the mass sheet, registered as ``arcsec`` parameters.
        kappa: Optional[Union[Tensor, float]]
            Convergence of the sheet, registered as a ``unitless`` parameter.
        name: NameType
            Identifier for the lens instance.
        """
        super().__init__(cosmology, z_l, name=name, z_s=z_s)

        # Each value is wrapped in a caskade Param so it can be supplied at
        # call time by @forward methods (a ``None`` value is passed through
        # to Param unchanged).
        self.x0 = Param("x0", x0, units="arcsec")
        self.y0 = Param("y0", y0, units="arcsec")
        self.kappa = Param("kappa", kappa, units="unitless")

    @forward
    def reduced_deflection_angle(
        self,
        x: Tensor,
        y: Tensor,
        x0: Annotated[Tensor, "Param"],
        y0: Annotated[Tensor, "Param"],
        kappa: Annotated[Tensor, "Param"],
    ) -> tuple[Tensor, Tensor]:
        """
        Calculates the reduced deflection angle of the mass sheet.

        Parameters
        ----------
        x: Tensor
            x-coordinates in the lens plane.

            *Unit: arcsec*

        y: Tensor
            y-coordinates in the lens plane.

            *Unit: arcsec*

        x0, y0, kappa: Tensor
            Sheet center and convergence, injected by ``@forward`` from the
            registered parameters.

        Returns
        -------
        x_component: Tensor
            Deflection Angle in x-direction.

            *Unit: arcsec*

        y_component: Tensor
            Deflection Angle in y-direction.

            *Unit: arcsec*

        """
        return func.reduced_deflection_angle_mass_sheet(x0, y0, kappa, x, y)

    @forward
    def potential(
        self,
        x: Tensor,
        y: Tensor,
        x0: Annotated[Tensor, "Param"],
        y0: Annotated[Tensor, "Param"],
        kappa: Annotated[Tensor, "Param"],
    ) -> Tensor:
        """
        Calculates the lensing potential of the mass sheet.

        Parameters
        ----------
        x: Tensor
            x-coordinates in the lens plane.

            *Unit: arcsec*

        y: Tensor
            y-coordinates in the lens plane.

            *Unit: arcsec*

        x0, y0, kappa: Tensor
            Sheet center and convergence, injected by ``@forward`` from the
            registered parameters.

        Returns
        -------
        potential: Tensor
            Lensing potential at the given positions.

            *Unit: arcsec^2*
        """
        # Meneghetti eq 3.81
        return func.potential_mass_sheet(x0, y0, kappa, x, y)

    @forward
    def convergence(
        self,
        x: Tensor,
        y: Tensor,
        kappa: Annotated[Tensor, "Param"],
    ) -> Tensor:
        """
        Calculates the convergence of the mass sheet.

        Parameters
        ----------
        x: Tensor
            x-coordinates in the lens plane.

            *Unit: arcsec*

        y: Tensor
            y-coordinates in the lens plane. Not used: a mass sheet has the
            same convergence everywhere.

            *Unit: arcsec*

        kappa: Tensor
            Convergence of the sheet, injected by ``@forward``.

        Returns
        -------
        convergence: Tensor
            Convergence at the given positions.

            *Unit: unitless*
        """
        # Essentially by definition. Only ``x`` is passed on — presumably so
        # the helper can shape ``kappa`` like the input grid; NOTE(review):
        # confirm in func.convergence_mass_sheet.
        return func.convergence_mass_sheet(kappa, x)
