# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

import warnings
from typing import Dict, Optional, Tuple

import plotly.express as px
import plotly.graph_objects as go
from datasets import Dataset

from validmind import RawData, tags, tasks
from validmind.errors import MissingDependencyError
from validmind.vm_models import VMDataset

from .utils import get_ragas_config, get_renamed_columns

# `ragas` is an optional dependency (shipped with the `llm` extra). If it is
# missing, surface an actionable error instead of a bare ImportError.
try:
    from ragas import evaluate
    from ragas.metrics import ResponseRelevancy as response_relevancy
except ImportError as e:
    if "ragas" in str(e):
        raise MissingDependencyError(
            "Missing required package `ragas` for ResponseRelevancy. "
            "Please run `pip install validmind[llm]` to use LLM tests",
            required_dependencies=["ragas"],
            extra="llm",
        ) from e

    # The import failed for a reason unrelated to ragas itself being absent:
    # propagate the original exception untouched.
    raise


@tags("ragas", "llm", "rag_performance")
@tasks("text_qa", "text_generation", "text_summarization")
def ResponseRelevancy(
    dataset: VMDataset,
    user_input_column: str = "user_input",
    retrieved_contexts_column: Optional[str] = None,
    response_column: str = "response",
    judge_llm=None,
    judge_embeddings=None,
) -> Tuple[Dict[str, list], go.Figure, go.Figure, RawData]:
    """
    Assesses how pertinent the generated answer is to the given prompt.

    The evaluation metric, Response Relevancy, focuses on assessing how pertinent the
    generated answer is to the given prompt. A lower score is assigned to answers that
    are incomplete or contain redundant information and higher scores indicate better
    relevancy. This metric is computed using the `user_input`, the `response` and,
    when provided, the `retrieved_contexts`.

    The Response Relevancy is defined as the mean cosine similarity of the original
    `user_input` to a number of artificial questions, which are generated (reverse-engineered)
    based on the `response`:

    $$
    \\text{answer relevancy} = \\frac{1}{N} \\sum_{i=1}^{N} cos(E_{g_i}, E_o)
    $$
    $$
    \\text{answer relevancy} = \\frac{1}{N} \\sum_{i=1}^{N} \\frac{E_{g_i} \\cdot E_o}{\\|E_{g_i}\\|\\|E_o\\|}
    $$

    Where:
    - $E_{g_i}$ is the embedding of the generated question $i$.
    - $E_o$ is the embedding of the original question.
    - $N$ is the number of generated questions - 3 by default.

    **Note**: *This is a reference-free metric, meaning that it does not require a
    `ground_truth` answer to compare against. A similar metric that does evaluate the
    correctness of a generated answer with respect to a `ground_truth` answer is
    `validmind.model_validation.ragas.AnswerCorrectness`.*

    ### Configuring Columns

    This metric uses the following columns in your dataset:

    - `user_input` (str): The text query that was input into the model.
    - `retrieved_contexts` (List[str], optional): Any contextual information retrieved
    by the model before generating an answer. Only used when
    `retrieved_contexts_column` is set.
    - `response` (str): The response generated by the model.

    If the above data is not in the appropriate column, you can specify different column
    names for these fields using the parameters `user_input_column`, `response_column`,
    and `retrieved_contexts_column`.

    For example, if your dataset has this data stored in different columns, you can
    pass the following parameters:
    ```python
    params = {
        "user_input_column": "input_text",
        "response_column": "output_text",
        "retrieved_contexts_column": "context_info"
    }
    ```

    If answer and contexts are stored as a dictionary in another column, specify the
    column and key like this:
    ```python
    pred_col = dataset.prediction_column(model)
    params = {
        "response_column": f"{pred_col}.generated_answer",
        "retrieved_contexts_column": f"{pred_col}.contexts",
    }
    ```

    For more complex data structures, you can use a function to extract the answers:
    ```python
    pred_col = dataset.prediction_column(model)
    params = {
        "response_column": lambda row: "\\n\\n".join(row[pred_col]["messages"]),
        "retrieved_contexts_column": lambda row: [row[pred_col]["context_message"]],
    }
    ```

    Args:
        dataset: Dataset whose underlying DataFrame holds the inputs and responses.
        user_input_column: Source of the `user_input` field.
        retrieved_contexts_column: Source of the `retrieved_contexts` field. When
            left unset, no contexts column is passed to the evaluation.
        response_column: Source of the `response` field.
        judge_llm: Forwarded to `get_ragas_config` to configure the evaluation.
        judge_embeddings: Forwarded to `get_ragas_config` to configure the evaluation.

    Returns:
        A tuple of: a dictionary with a single-row "Aggregate Scores" table (mean,
        median, max, min, standard deviation and count of the scores), a histogram
        of the scores, a box plot of the scores, and a `RawData` object carrying the
        full per-row evaluation results together with the dataset's input id.
    """
    # NOTE: this appends to the process-wide warnings filter list; the filter stays
    # active after the test returns.
    warnings.filterwarnings(
        "ignore",
        category=FutureWarning,
        message="promote has been superseded by promote_options='default'.",
    )

    # Maps the field names used for the evaluation to where each lives in the dataset.
    required_columns = {
        "user_input": user_input_column,
        "response": response_column,
    }

    # Contexts are optional for this metric: include them only when requested.
    if retrieved_contexts_column:
        required_columns["retrieved_contexts"] = retrieved_contexts_column

    # Keep only the fields that will be evaluated, selected with an explicit list
    # of column names rather than a dict view.
    df = get_renamed_columns(dataset._df, required_columns)
    df = df[list(required_columns)]

    result_df = evaluate(
        Dataset.from_pandas(df),
        metrics=[response_relevancy()],
        **get_ragas_config(judge_llm, judge_embeddings),
    ).to_pandas()

    # Column of the evaluation result that is read as this metric's per-row score.
    score_column = "answer_relevancy"
    scores = result_df[score_column]
    score_values = scores.to_list()

    fig_histogram = px.histogram(x=score_values, nbins=10, title="Response Relevancy")
    fig_box = px.box(x=score_values, title="Response Relevancy")

    # Only aggregates go in the result table; the per-row scores are available
    # through the returned RawData (`evaluation_results`).
    aggregate_scores = {
        "Mean Score": scores.mean(),
        "Median Score": scores.median(),
        "Max Score": scores.max(),
        "Min Score": scores.min(),
        "Standard Deviation": scores.std(),
        "Count": result_df.shape[0],
    }

    return (
        {"Aggregate Scores": [aggregate_scores]},
        fig_histogram,
        fig_box,
        RawData(evaluation_results=result_df, dataset=dataset.input_id),
    )
