# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

import itertools
from typing import Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

from validmind import RawData, tags, tasks
from validmind.vm_models import VMDataset, VMModel


@tags("visualization", "dimensionality_reduction", "embeddings")
@tasks("text_qa", "text_generation", "text_summarization")
def TSNEComponentsPairwisePlots(
    dataset: VMDataset,
    model: VMModel,
    n_components: int = 2,
    perplexity: int = 30,
    title: str = "t-SNE",
) -> Tuple[go.Figure, RawData]:
    """
    Creates scatter plots for pairwise combinations of t-SNE components to visualize embeddings and highlight potential
    clustering structures.

    ### Purpose

    This function creates scatter plots for each pairwise combination of t-SNE components derived from model
    embeddings. t-SNE (t-Distributed Stochastic Neighbor Embedding) is a machine learning algorithm for dimensionality
    reduction that is particularly well-suited for the visualization of high-dimensional datasets.

    ### Test Mechanism

    The function begins by extracting embeddings from the provided dataset using the specified model. These embeddings
    are then standardized to ensure that each dimension contributes equally to the distance computation. Following
    this, the t-SNE algorithm is applied to reduce the dimensionality of the data, with the number of components
    specified by the user. The results are plotted using Plotly, creating one scatter plot for each unique pair of
    components if more than one component is specified. When a single component is requested, that component is
    plotted against itself.

    If the requested perplexity is not smaller than the number of samples, it is reduced to `n_samples - 1` and a
    warning is emitted, because scikit-learn's t-SNE requires the perplexity to be less than the number of samples.

    ### Signs of High Risk

    - If the scatter plots show overlapping clusters or indistinct groupings, it might suggest that the t-SNE
    parameters (such as perplexity) are not optimally set for the given data, or the data itself does not exhibit
    clear, separable clusters.
    - Similar plots across different pairs of components could indicate redundancy in the components generated by
    t-SNE, suggesting that fewer dimensions might be sufficient to represent the data's structure.

    ### Strengths

    - Provides a visual exploration tool for high-dimensional data, simplifying the detection of patterns and clusters
    which are not apparent in higher dimensions.
    - Interactive plots generated by Plotly enhance user engagement and allow for a deeper dive into specific areas of
    the plot, aiding in detailed data analysis.

    ### Limitations

    - The effectiveness of t-SNE is highly dependent on the choice of parameters like perplexity and the number of
    components, which might require tuning and experimentation for optimal results.
    - t-SNE visualizations can be misleading if interpreted without considering the stochastic nature of the algorithm;
    two runs with the same parameters might yield different visual outputs, necessitating multiple runs for a
    consistent interpretation.
    - On very small datasets the perplexity is automatically lowered (see Test Mechanism), so the resulting plots are
    not directly comparable with runs that used the requested perplexity.

    Args:
        dataset: Dataset whose `y_pred(model)` values are stacked into the embedding matrix.
        model: Model used to look up the embeddings; its `input_id` is shown in the plot titles and stored in the
            raw data.
        n_components: Number of t-SNE output dimensions. Must be at least 1.
        perplexity: Requested t-SNE perplexity. Lowered to `n_samples - 1` when it is not smaller than the number of
            samples.
        title: Prefix used for every plot title.

    Returns:
        A tuple containing one Plotly figure per pair of components (a single figure when `n_components` is 1),
        followed by a `RawData` object holding the scaled embeddings, the t-SNE coordinates and the model and
        dataset input ids. Note that the number of figures varies with `n_components`, which the return annotation
        does not express.

    Raises:
        ValueError: If `n_components` is smaller than 1 or fewer than two samples are available.
    """
    # Imported locally so this block does not depend on a module-level import.
    import warnings

    # Fail fast, before any expensive work, on a component count that cannot be plotted.
    if n_components < 1:
        raise ValueError(f"n_components must be at least 1, got {n_components}")

    # Get embeddings from the dataset using the model
    embeddings = np.stack(dataset.y_pred(model))

    n_samples = embeddings.shape[0]
    if n_samples < 2:
        raise ValueError(
            f"t-SNE requires at least 2 samples to compute an embedding, got {n_samples}"
        )

    # scikit-learn rejects a perplexity that is not strictly smaller than the
    # number of samples; fall back to the largest admissible value instead of
    # failing on small datasets, and tell the user that we did so.
    effective_perplexity = perplexity
    if perplexity >= n_samples:
        effective_perplexity = n_samples - 1
        warnings.warn(
            f"perplexity={perplexity} is not smaller than the number of samples "
            f"({n_samples}); using perplexity={effective_perplexity} instead.",
            stacklevel=2,
        )

    # Standardize the embeddings
    embeddings_scaled = StandardScaler().fit_transform(embeddings)

    # Perform t-SNE
    tsne = TSNE(n_components=n_components, perplexity=effective_perplexity)
    tsne_results = tsne.fit_transform(embeddings_scaled)

    # Prepare DataFrame for Plotly
    component_names = [f"Component {i}" for i in range(1, n_components + 1)]
    tsne_df = pd.DataFrame(tsne_results, columns=component_names)

    # Every unordered pair of components gets its own plot; with a single
    # component there are no pairs, so it is plotted against itself.
    if n_components > 1:
        axis_pairs = list(itertools.combinations(component_names, 2))
    else:
        axis_pairs = [(component_names[0], component_names[0])]

    plot_title = f"{title} - {getattr(model, 'input_id', 'Unknown Model')}"
    figures = [
        px.scatter(tsne_df, x=x_name, y=y_name, title=plot_title)
        for x_name, y_name in axis_pairs
    ]

    return (
        *figures,
        RawData(
            embeddings_scaled=embeddings_scaled,
            tsne_results=tsne_results,
            model=model.input_id,
            dataset=dataset.input_id,
        ),
    )
