# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

from typing import Tuple, Union

import plotly.express as px
import plotly.graph_objects as go
from sklearn.manifold import TSNE

from validmind import RawData, tags, tasks
from validmind.logging import get_logger
from validmind.vm_models import VMDataset, VMModel

logger = get_logger(__name__)


@tags("llm", "text_data", "embeddings", "visualization")
@tasks("feature_extraction")
def EmbeddingsVisualization2D(
    dataset: VMDataset,
    model: VMModel,
    cluster_column: Union[str, None] = None,
    perplexity: int = 30,
) -> Tuple[go.Figure, RawData]:
    """
    Visualizes 2D representation of text embeddings generated by a model using t-SNE technique.

    ### Purpose

    The objective of this metric is to provide a visual 2D representation of the embeddings created by a text embedding
    machine learning model. By doing so, it aids in analyzing the embedding space created by the model and helps in
    understanding how the learned embeddings are distributed and how they relate to each other.

    ### Test Mechanism

    This metric uses the t-Distributed Stochastic Neighbor Embedding (t-SNE) technique, which is a tool for visualizing
    high-dimensional data by reducing the dimensionality to 2. The perplexity parameter for t-SNE is set to the value
    provided by the user. If the input perplexity value is greater than or equal to the number of samples, the
    perplexity is adjusted to be one less than the number of samples. Following the reduction of dimensionality, a
    scatter plot is produced depicting each embedding as a data point in the visualized 2D plane.

    Points are colored by `cluster_column` when it is given. If it is not given and the dataset has exactly one
    categorical feature column, that column is used instead; otherwise a warning is logged and the points are drawn
    without color coding.

    ### Signs of High Risk

    - If the embeddings are highly concentrated in a specific region of the plane, it might indicate that the model is
    not learning diverse representations of the text.
    - Wide gaps or partitions in the visualization could suggest that the model is over-segmenting in the embedding
    space and may lead to poor generalization.

    ### Strengths

    - Offers a powerful visual tool that can assist in understanding and interpreting high-dimensional embeddings,
    which could otherwise be difficult to visualize.
    - It is model-agnostic and can be used with any machine learning model that produces text embeddings.
    - t-SNE visualization helps in focusing on local structures and preserves the proximity of points that are close
    together in the original high-dimensional space.

    ### Limitations

    - The reduction of high-dimensional data to 2D can result in loss of some information, which may lead to
    misinterpretation.
    - Due to its stochastic nature, t-SNE can produce different results when run multiple times with the same
    parameters, leading to potential inconsistency in interpretation.
    - It is designed for visual exploration and not for downstream tasks; that is, the 2D embeddings generated should
    not be directly used for further training or analysis.

    ### Returns

    A tuple of the plotly scatter figure and a `RawData` holding the 2D t-SNE coordinates (`tsne_embeddings`) plus
    the `input_id`s of the model and dataset used.
    """
    y_pred = dataset.y_pred(model)

    # scikit-learn's TSNE rejects perplexity >= n_samples, so clamp it just below
    # the number of embeddings being projected.
    num_samples = len(y_pred)
    perplexity = perplexity if perplexity < num_samples else num_samples - 1

    reduced_embeddings = TSNE(
        n_components=2,
        perplexity=perplexity,
    ).fit_transform(y_pred)

    # Only try to infer a coloring column (and only warn) when the caller did not
    # supply one; an explicit ``cluster_column`` is used as-is.
    if not cluster_column:
        categorical_columns = dataset.feature_columns_categorical
        if len(categorical_columns) == 1:
            cluster_column = categorical_columns[0]
        else:
            logger.warning(
                "Cannot color code embeddings without a 'cluster_column' param."
            )

    scatter_kwargs = {
        "x": reduced_embeddings[:, 0],
        "y": reduced_embeddings[:, 1],
        "title": "2D Visualization of Text Embeddings",
    }
    if cluster_column:
        scatter_kwargs["color"] = dataset.df[cluster_column]

    fig = px.scatter(**scatter_kwargs)
    fig.update_layout(width=500, height=500)

    return fig, RawData(
        tsne_embeddings=reduced_embeddings,
        model=model.input_id,
        dataset=dataset.input_id,
    )
