# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

from typing import Tuple

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics.pairwise import euclidean_distances

from validmind import RawData, tags, tasks
from validmind.vm_models import VMDataset, VMModel


@tags("visualization", "dimensionality_reduction", "embeddings")
@tasks("text_qa", "text_generation", "text_summarization")
def EuclideanDistanceHeatmap(
    dataset: VMDataset,
    model: VMModel,
    title="Euclidean Distance Matrix",
    color="Euclidean Distance",
    xaxis_title="Index",
    yaxis_title="Index",
    color_scale="Blues",
) -> Tuple[go.Figure, RawData]:
    """
    Renders an interactive heatmap of the pairwise Euclidean distances between the embeddings a model produces for a
    dataset.

    ### Purpose

    This test shows how far apart the model's embeddings are from one another in absolute terms. Euclidean distance
    is the straight-line distance between two points in Euclidean space, which makes it a natural starting point for
    examining spatial relationships and clustering tendencies in high-dimensional data.

    ### Test Mechanism

    The model's predictions for the dataset are retrieved and stacked into a single array, one entry per dataset row.
    Pairwise Euclidean distances between all entries are then computed, yielding a square distance matrix. That matrix
    is drawn as an interactive Plotly heatmap in which each cell is annotated with its value and colored according to
    the distance between the corresponding pair of embeddings. The figure title is the supplied `title` followed by
    the model's input ID. The distance matrix, model input ID, and dataset input ID are also returned as raw data.

    ### Signs of High Risk

    - Uniformly low distances across the heatmap may point to a lack of variability in the data or to model
    overfitting, where distinct data points are not being told apart effectively.
    - Excessive variability in distances may indicate inconsistent data representation, which can lead to unreliable
    model predictions.

    ### Strengths

    - Gives a direct, intuitive picture of the distances between embeddings, which helps in spotting patterns or
    anomalies.
    - The heatmap's title, axis labels, color-bar label, and color scale can all be customized to suit different
    analytical needs.

    ### Limitations

    - Distances are sensitive to the scale of the data; normalization may be required before the values can be
    interpreted meaningfully.
    - Large datasets produce dense, cluttered heatmaps in which individual distances are hard to read; data sampling
    or dimensionality reduction may be needed for a clearer view.
    """

    # Combine the per-row model outputs into one array
    # (presumably one embedding vector per row -- TODO confirm against y_pred's contract).
    predictions = dataset.y_pred(model)
    vectors = np.stack(predictions)

    # Distances between every pair of rows in `vectors`
    pairwise = euclidean_distances(vectors)

    # Axis and color-bar captions shown on the heatmap
    axis_labels = {"x": xaxis_title, "y": yaxis_title, "color": color}

    heatmap = px.imshow(
        pairwise,
        labels=axis_labels,
        color_continuous_scale=color_scale,
        aspect="auto",
        text_auto=True,
    )
    heatmap.update_layout(
        title=f"{title} - {model.input_id}",
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
    )

    # Expose the underlying numbers alongside the figure
    raw = RawData(
        distance_matrix=pairwise,
        model=model.input_id,
        dataset=dataset.input_id,
    )
    return heatmap, raw
