"""Configuration class for variables."""

from __future__ import annotations

from typing import Self

import numpy as np
from pydantic import (
    BaseModel,
    ConfigDict,
    ValidationInfo,
    field_validator,
    model_validator,
)

from ropt.config.utils import (
    broadcast_1d_array,
    check_enum_values,
    immutable_array,
)
from ropt.config.validated_types import (  # noqa: TC001
    Array1D,
    Array1DBool,
    Array1DInt,
    ArrayEnum,
    ItemOrTuple,
)
from ropt.enums import BoundaryType, PerturbationType, VariableType

from .constants import (
    DEFAULT_PERTURBATION_BOUNDARY_TYPE,
    DEFAULT_PERTURBATION_MAGNITUDE,
    DEFAULT_PERTURBATION_TYPE,
    DEFAULT_SEED,
)


class VariablesConfig(BaseModel):
    r"""Configuration class for optimization variables.

    This class, `VariablesConfig`, defines the configuration for optimization
    variables. It is used in an [`EnOptConfig`][ropt.config.EnOptConfig] object
    to specify the number of variables, their bounds, types, and an optional
    mask.

    The `variable_count` field is required and specifies the number of
    variables, including both free and fixed variables. All array-valued
    fields are broadcasted to this length during validation.

    The `lower_bounds` and `upper_bounds` fields define the bounds for each
    variable. These are `numpy` arrays and are broadcasted to match the number
    of variables. By default, they are set to negative and positive infinity,
    respectively. `numpy.nan` values in these arrays indicate unbounded
    variables and are converted to `numpy.inf` with the appropriate sign.

    The optional `types` field allows assigning a
    [`VariableType`][ropt.enums.VariableType] to each variable. If not provided,
    all variables are assumed to be continuous real-valued
    ([`VariableType.REAL`][ropt.enums.VariableType.REAL]).

    The optional `mask` field is a boolean `numpy` array that indicates which
    variables are free to change during optimization. `True` values in the mask
    indicate that the corresponding variable is free, while `False` indicates a
    fixed variable.

    **Variable perturbations**

    The `VariablesConfig` class also stores information that is needed to
    generate perturbed variables, for instance to calculate stochastic
    gradients.

    Perturbations are generated by sampler objects that are configured
    separately as a tuple of [`SamplerConfig`][ropt.config.SamplerConfig]
    objects in the configuration object used by a plan step. For instance,
    [`EnOptConfig`][ropt.config.EnOptConfig] object defines the available
    samplers in its `samplers` field. The `samplers` field of the
    `VariablesConfig` object specifies, for each variable, the index of the
    sampler to use. A random number generator is created to support samplers
    that require random numbers.

    The generated perturbation values are scaled by the values of the
    `perturbation_magnitudes` field and can be modified based on the
    `perturbation_types`. See [`PerturbationType`][ropt.enums.PerturbationType]
    for details on available perturbation types.

    Perturbed variables may violate the defined variable bounds. The
    `boundary_types` field specifies how to handle such violations. See
    [`BoundaryType`][ropt.enums.BoundaryType] for details on available boundary
    handling methods.

    The `perturbation_types` and `boundary_types` fields use values from the
    [`PerturbationType`][ropt.enums.PerturbationType] and
    [`BoundaryType`][ropt.enums.BoundaryType] enumerations, respectively.

    **Validation**

    If a validation context with a `variables` transform is passed, the bounds
    and the magnitudes of absolute perturbations are transformed to the
    optimizer domain. Validation fails if any lower bound exceeds its upper
    bound, or if a variable that uses relative perturbations does not have
    finite bounds.

    Note: Seed for Samplers
        The `seed` value ensures consistent results across repeated runs with
        the same configuration. To obtain unique results for each optimization
        run, modify the seed. A common approach is to use a tuple with a unique
        ID as the first element, ensuring reproducibility across nested and
        parallel plan evaluations.

    Attributes:
        variable_count:           Number of variables (required).
        types:                    Optional variable types.
        lower_bounds:             Lower bounds for the variables (default: $-\infty$).
        upper_bounds:             Upper bounds for the variables (default: $+\infty$).
        mask:                     Optional boolean mask indicating free variables.
        perturbation_magnitudes:  Magnitudes of the perturbations for each variable
            (default:
            [`DEFAULT_PERTURBATION_MAGNITUDE`][ropt.config.constants.DEFAULT_PERTURBATION_MAGNITUDE]).
        perturbation_types:       Type of perturbation for each variable (see
            [`PerturbationType`][ropt.enums.PerturbationType], default:
            [`DEFAULT_PERTURBATION_TYPE`][ropt.config.constants.DEFAULT_PERTURBATION_TYPE]).
        boundary_types:           How to handle perturbations that violate boundary
            conditions (see [`BoundaryType`][ropt.enums.BoundaryType], default:
            [`DEFAULT_PERTURBATION_BOUNDARY_TYPE`][ropt.config.constants.DEFAULT_PERTURBATION_BOUNDARY_TYPE]).
        samplers:                 Indices of the samplers to use for each variable.
        seed:                     Seed for the random number generator used by the samplers.
    """

    variable_count: int
    types: ArrayEnum = np.array(VariableType.REAL)
    lower_bounds: Array1D = np.array(-np.inf)
    upper_bounds: Array1D = np.array(np.inf)
    mask: Array1DBool = np.array(1)
    perturbation_magnitudes: Array1D = np.array(DEFAULT_PERTURBATION_MAGNITUDE)
    perturbation_types: ArrayEnum = np.array(DEFAULT_PERTURBATION_TYPE)
    boundary_types: ArrayEnum = np.array(DEFAULT_PERTURBATION_BOUNDARY_TYPE)
    samplers: Array1DInt = np.array(0)
    seed: ItemOrTuple[int] = (DEFAULT_SEED,)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        validate_default=True,
        frozen=True,
    )

    @field_validator("types", mode="after")
    @classmethod
    def _check_variable_types(cls, value: ArrayEnum) -> ArrayEnum:
        """Check the `types` entries against `VariableType`."""
        check_enum_values(value, VariableType)
        return value

    @field_validator("perturbation_types", mode="after")
    @classmethod
    def _check_perturbation_types(cls, value: ArrayEnum) -> ArrayEnum:
        """Check the `perturbation_types` entries against `PerturbationType`."""
        check_enum_values(value, PerturbationType)
        return value

    @field_validator("boundary_types", mode="after")
    @classmethod
    def _check_boundary_types(cls, value: ArrayEnum) -> ArrayEnum:
        """Check the `boundary_types` entries against `BoundaryType`."""
        check_enum_values(value, BoundaryType)
        return value

    @model_validator(mode="after")
    def _broadcast_and_transform(self, info: ValidationInfo) -> Self:
        """Broadcast the array fields, transform them, and check the bounds.

        All array fields are broadcasted to `variable_count` entries and `NaN`
        bounds are replaced by infinities of the appropriate sign. If the
        validation context provides a `variables` transform, the bounds and
        the magnitudes of absolute perturbations are mapped to the optimizer
        domain.

        The instance is updated in place and returned: an "after" model
        validator must return the instance it validated, otherwise the result
        is discarded when the model is constructed via `__init__`.

        Args:
            info: Validation info; its `context`, when set, may provide a
                  `variables` transform.

        Returns:
            The validated instance, with broadcasted and transformed fields.

        Raises:
            ValueError: If a lower bound is larger than the corresponding
                upper bound, or if the bounds of a variable with a relative
                perturbation are not finite.
        """
        dim = self.variable_count

        # The order is significant: it determines which field is reported
        # first if several of them cannot be broadcasted.
        broadcast_fields = (
            "lower_bounds",
            "upper_bounds",
            "types",
            "mask",
            "perturbation_magnitudes",
            "perturbation_types",
            "boundary_types",
            "samplers",
        )
        updates = {
            name: broadcast_1d_array(getattr(self, name), name, dim)
            for name in broadcast_fields
        }

        # NaN bounds denote unbounded variables: replace them by infinities
        # before they can slip through the comparisons below (comparisons
        # with NaN are always False).
        lower_bounds = immutable_array(
            np.where(
                np.isnan(updates["lower_bounds"]), -np.inf, updates["lower_bounds"]
            )
        )
        upper_bounds = immutable_array(
            np.where(
                np.isnan(updates["upper_bounds"]), np.inf, updates["upper_bounds"]
            )
        )
        perturbation_magnitudes = updates["perturbation_magnitudes"]
        perturbation_types = updates["perturbation_types"]

        if info.context is not None and info.context.variables is not None:
            transform = info.context.variables
            lower_bounds = immutable_array(transform.to_optimizer(lower_bounds))
            upper_bounds = immutable_array(transform.to_optimizer(upper_bounds))
            # Only absolute perturbation magnitudes are transformed, all other
            # magnitudes are kept as they are.
            absolute = perturbation_types == PerturbationType.ABSOLUTE
            transformed = transform.magnitudes_to_optimizer(perturbation_magnitudes)
            perturbation_magnitudes = immutable_array(
                np.where(absolute, transformed, perturbation_magnitudes)
            )

        if np.any(lower_bounds > upper_bounds):
            msg = "The lower bounds are larger than the upper bounds."
            raise ValueError(msg)

        relative = perturbation_types == PerturbationType.RELATIVE
        if not np.all(
            np.logical_and(
                np.isfinite(lower_bounds[relative]), np.isfinite(upper_bounds[relative])
            ),
        ):
            msg = "The variable bounds must be finite to use relative perturbations"
            raise ValueError(msg)

        updates["lower_bounds"] = lower_bounds
        updates["upper_bounds"] = upper_bounds
        updates["perturbation_magnitudes"] = perturbation_magnitudes

        # The model is frozen, so regular attribute assignment is rejected.
        # Write to the instance dictionary instead and mark the fields as set,
        # mirroring what `model_copy(update=...)` does to the copy it makes.
        self.__dict__.update(updates)
        self.model_fields_set.update(updates)
        return self
