from numbers import Integral, Real
from typing import Final, Self, cast

import faiss  # type: ignore
import numpy as np
from scipy.linalg import eigh  # type: ignore
from scipy.linalg.blas import dgemm  # type: ignore
from sklearn.base import BaseEstimator, ClusterMixin  # type: ignore
from sklearn.utils._param_validation import Interval, validate_params  # type: ignore
from sklearn.utils.validation import check_is_fitted  # type: ignore

from ._defs import AffinityTransform, ExpQuantileTransform
from ._kmeans import KMeans

# Constants
# Transform applied to the affinity matrix when the caller does not supply one;
# it is the default value of the ``affinity_transform`` argument of
# ``SpectralBridges.__init__``.
# NOTE(review): the meaning of the three positional arguments is defined by
# ``ExpQuantileTransform`` in ``._defs`` -- confirm there before changing them.
DEFAULT_AFFINITY_TRANSFORM: Final = ExpQuantileTransform(0.1, 0.9, 1e4)


class SpectralBridges(BaseEstimator, ClusterMixin):
    """Spectral Bridges clustering algorithm.

    The data is first quantized into ``n_nodes`` k-means cells. A pairwise
    affinity between cells is computed from the projections of each cell's
    centered points onto the segments joining its centroid to the other
    centroids. The affinity is passed through ``affinity_transform``, the nodes
    are embedded with the leading eigenvectors of a row-mean-normalized matrix
    built from it, and (unless ``no_clustering`` is set) a second k-means on that
    embedding groups the nodes into ``n_clusters`` clusters.

    Attributes:
        n_clusters (int): The number of clusters to form.
        n_nodes (int): Number of nodes or initial clusters.
        p (int | float): Power of the alpha_i. Must be strictly positive.
        n_iter (int): Number of iterations to run the k-means algorithm.
        n_local_trials (int | None): Number of seeding trials for centroids
            initialization.
        random_state (int | None): Determines random number generation for centroid
            initialization.
        tol (float): Tolerance for the normalized eigengap.
        no_clustering (bool): If True, ``fit`` stops after computing the
            embedding and does not assign cluster labels.
        affinity_transform (AffinityTransform): Affinity transform to apply to the
            affinity matrix.
        cluster_centers_ (np.ndarray | None): Coordinates of cluster centers.
        cluster_labels_ (np.ndarray | None): Labels of each cluster.
        labels_ (np.ndarray | None): Labels of each data point.
        ngap_ (float | None): The normalized eigengap.
        embedding_ (np.ndarray | None): Embedding of the data.
    """

    n_clusters: int
    n_nodes: int
    p: int | float
    n_iter: int
    n_local_trials: int | None
    random_state: int | None
    tol: float
    no_clustering: bool
    affinity_transform: AffinityTransform
    cluster_centers_: np.ndarray | None
    cluster_labels_: np.ndarray | None
    labels_: np.ndarray | None
    ngap_: float | None
    embedding_: np.ndarray | None

    @validate_params(
        {
            "n_clusters": [Interval(Integral, 2, None, closed="left")],
            "n_nodes": [Interval(Integral, 2, None, closed="left")],
            # Open at 0: ``fit`` computes ``1 / p``, so p == 0 can never work.
            "p": [Interval(Real, 0, None, closed="neither")],
            "n_iter": [Interval(Integral, 1, None, closed="left")],
            "n_local_trials": [Interval(Integral, 1, None, closed="left"), None],
            "random_state": ["random_state"],
            "tol": [Interval(Real, 0, None, closed="left")],
            "no_clustering": [bool],
            "affinity_transform": [AffinityTransform],
        },
        prefer_skip_nested_validation=True,
    )
    def __init__(
        self,
        n_clusters: int,
        n_nodes: int,
        *,
        p: int | float = 2,
        n_iter: int = 20,
        n_local_trials: int | None = None,
        random_state: int | None = None,
        tol: float = 1e-8,
        no_clustering: bool = False,
        affinity_transform: AffinityTransform = DEFAULT_AFFINITY_TRANSFORM,
    ):
        """Initialize the Spectral Bridges model.

        Args:
            n_clusters (int): The number of clusters to form.
            n_nodes (int): Number of nodes or initial clusters. Must be strictly
                greater than ``n_clusters``.
            p (int | float, optional): Power of the alpha_i. Must be strictly
                positive. Defaults to 2.
            n_iter (int, optional): Number of iterations to run the k-means
                algorithm. Defaults to 20.
            n_local_trials (int | None, optional): Number of seeding trials for
                centroids initialization. Defaults to None.
            random_state (int | None, optional): Determines random number
                generation for centroid initialization. Defaults to None.
            tol (float, optional): Tolerance for the normalized eigengap.
                Defaults to 1e-8.
            no_clustering (bool, optional): Whether to return the embedding. If set
                to True, the model will not do clustering, and use of predict will raise
                an error. Defaults to False.
            affinity_transform (AffinityTransform, optional): Affinity transform
                to apply to the affinity matrix. Defaults to DEFAULT_AFFINITY_TRANSFORM.

        Raises:
            ValueError: If ``n_nodes`` is not strictly greater than ``n_clusters``,
                or if a parameter fails validation.
        """
        self.n_clusters = n_clusters
        self.n_nodes = n_nodes
        self.p = p
        self.n_iter = n_iter
        self.n_local_trials = n_local_trials
        self.random_state = random_state
        self.tol = tol
        self.no_clustering = no_clustering
        self.affinity_transform = affinity_transform
        # NOTE: ``cluster_labels_`` and ``labels_`` are not pre-set here;
        # ``predict`` relies on ``cluster_labels_`` being absent until ``fit``
        # has actually performed the clustering step.
        self.cluster_centers_ = None
        self.ngap_ = None
        self.embedding_ = None

        if self.n_nodes <= self.n_clusters:
            raise ValueError(
                f"n_nodes must be greater than n_clusters, got {self.n_nodes} <= "
                f"{self.n_clusters}"
            )

    @validate_params(
        {
            "X": ["array-like"],
            "y": [None],
        },
        prefer_skip_nested_validation=True,
    )
    def fit(self, X: np.typing.ArrayLike, y: None = None) -> Self:  # noqa: ARG002
        """Fit the Spectral Bridges model on the input data X.

        Args:
            X (np.typing.ArrayLike): Input data to cluster, one sample per row.
            y (None, optional): Placeholder for y.

        Raises:
            ValueError: If ``X`` is not two-dimensional.
            ValueError: If the number of samples is less than ``n_nodes``.

        Returns:
            Self: The fitted model.
        """
        X = np.asarray(X)  # type: ignore

        if X.ndim != 2:
            raise ValueError(f"X must be a 2D array, got {X.ndim}D.")

        if X.shape[0] < self.n_nodes:
            raise ValueError(
                f"n_samples={X.shape[0]} must be >= n_nodes={self.n_nodes}."
            )

        # Step 1: quantize the data into n_nodes cells.
        kmeans = KMeans(
            self.n_nodes,
            self.n_iter,
            self.n_local_trials,
            self.random_state,
        )
        kmeans.fit(X)
        self.cluster_centers_ = cast(np.ndarray, kmeans.cluster_centers_)

        affinity = np.empty((self.n_nodes, self.n_nodes), dtype=np.float64)

        # Points of each cell, expressed relative to that cell's centroid.
        # Fortran order / float64 matches what the BLAS dgemm call below consumes.
        X_centered = [
            np.array(
                X[kmeans.labels_ == i] - self.cluster_centers_[i],
                dtype=np.float64,
                order="F",
            )
            for i in range(self.n_nodes)
        ]

        # counts[i, j] = n_i + n_j, the number of points contributing to the
        # symmetrized affinity between cells i and j.
        counts = np.array([block.shape[0] for block in X_centered])
        counts = counts[None, :] + counts[:, None]

        # Step 2: directed affinity from cell i to every other cell.
        for i, block in enumerate(X_centered):
            # Row j is the segment from centroid i to centroid j.
            segments = np.asfortranarray(
                self.cluster_centers_ - self.cluster_centers_[i]
            )
            dists = np.einsum("ij,ij->i", segments, segments)
            # The segment from centroid i to itself has zero length; use 1 to
            # avoid dividing by zero (the matching projections are zero anyway).
            dists[i] = 1

            # block @ segments.T: projection of each point onto each segment.
            projs = cast(np.ndarray, dgemm(1.0, block, segments, trans_b=True))
            # Normalize by squared segment length and discard negative values.
            np.clip(projs / dists, 0, None, out=projs)
            projs = np.power(projs, self.p)

            affinity[i] = projs.sum(axis=0)

        # Symmetrize, average over the contributing points, undo the power p.
        affinity = np.power((affinity + affinity.T) / counts, 1 / self.p)

        affinity = self.affinity_transform(affinity)

        # Step 3: spectral embedding of the nodes.
        d = np.power(affinity.mean(axis=1), -0.5)
        L = -(d[:, None] * affinity * d[None, :])
        # NOTE(review): the diagonal shift presumably keeps the spectrum
        # positive so the normalized eigengap below is well defined -- confirm.
        np.fill_diagonal(L, self.n_nodes + self.tol)

        # Indices are inclusive: this yields the n_clusters + 1 smallest
        # eigenpairs; the extra one is only used for the eigengap.
        eigvals, eigvecs = cast(
            tuple[np.ndarray, np.ndarray],
            eigh(
                L,
                subset_by_index=[0, self.n_clusters],
            ),
        )

        self.embedding_ = eigvecs[:, :-1]
        # Scale every node's embedding row to unit Euclidean norm.
        self.embedding_ /= np.linalg.norm(self.embedding_, axis=1)[:, None]
        self.ngap_ = (eigvals[-1] - eigvals[-2]) / eigvals[-2]

        if self.no_clustering:
            # Drop labels left over from an earlier clustering fit so that
            # ``predict`` cannot combine them with the new centers.
            for stale in ("cluster_labels_", "labels_"):
                vars(self).pop(stale, None)
            return self

        # Step 4: cluster the nodes in the embedding space, then propagate each
        # node's label to the points assigned to it.
        self.cluster_labels_ = cast(
            np.ndarray,
            KMeans(self.n_clusters, self.n_iter, self.n_local_trials, self.random_state)
            .fit(self.embedding_)  # type: ignore
            .labels_,
        )
        self.labels_ = self.cluster_labels_[kmeans.labels_]

        return self

    @validate_params(
        {
            "X": ["array-like"],
        },
        prefer_skip_nested_validation=True,
    )
    def predict(self, X: np.typing.ArrayLike) -> np.ndarray:
        """Predict the nearest cluster index for each input data point x.

        Each point is assigned the cluster label of its nearest node centroid
        (squared Euclidean distance, computed in float32).

        Args:
            X (np.typing.ArrayLike): The input data, one sample per row.

        Raises:
            sklearn.exceptions.NotFittedError: If the model has not been fitted,
                or was fitted with ``no_clustering=True``.
            ValueError: If `X` is not 2D or its number of features differs from
                that of the fitted cluster centers.
            ValueError: If `X` contains inf values.
            ValueError: If `X` contains NaN values.

        Returns:
            np.ndarray: The predicted cluster indices.
        """
        check_is_fitted(self, ("cluster_centers_", "cluster_labels_"))

        # faiss operates on C-contiguous float32 arrays.
        centers = np.ascontiguousarray(self.cluster_centers_, dtype=np.float32)
        X_f32 = np.ascontiguousarray(X, dtype=np.float32)

        if X_f32.ndim != 2:
            raise ValueError(f"X must be a 2D array, got {X_f32.ndim}D.")
        if X_f32.shape[1] != centers.shape[1]:
            raise ValueError(
                f"X has {X_f32.shape[1]} features, but the model was fitted with "
                f"{centers.shape[1]} features."
            )

        # Checked after the float32 cast so that overflow to inf is caught too.
        if np.isinf(X_f32).any():
            raise ValueError("X must not contain inf values")
        if np.isnan(X_f32).any():
            raise ValueError("X must not contain NaN values")

        index = faiss.IndexFlatL2(centers.shape[1])
        index.add(centers)  # type: ignore

        # search(...)[1] holds the index of the nearest centroid for each row.
        nearest = cast(np.ndarray, index.search(X_f32, 1)[1]).ravel()  # type: ignore
        return cast(np.ndarray, self.cluster_labels_)[nearest]
