# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

import warnings
from typing import Dict, Tuple

import plotly.express as px
import plotly.graph_objects as go
from datasets import Dataset

from validmind import RawData, tags, tasks
from validmind.errors import MissingDependencyError
from validmind.vm_models import VMDataset

from .utils import get_ragas_config, get_renamed_columns

try:
    from ragas import evaluate
    from ragas.metrics import Faithfulness as faithfulness
except ImportError as e:
    if "ragas" in str(e):
        raise MissingDependencyError(
            "Missing required package `ragas` for Faithfulness. "
            "Please run `pip install validmind[llm]` to use LLM tests",
            required_dependencies=["ragas"],
            extra="llm",
        ) from e

    raise e


@tags("ragas", "llm", "rag_performance")
@tasks("text_qa", "text_generation", "text_summarization")
def Faithfulness(
    dataset: VMDataset,
    user_input_column: str = "user_input",
    response_column: str = "response",
    retrieved_contexts_column: str = "retrieved_contexts",
    judge_llm=None,
    judge_embeddings=None,
) -> Tuple[Dict[str, list], go.Figure, go.Figure, RawData]:
    """
    Evaluates the faithfulness of the generated answers with respect to retrieved contexts.

    This metric uses a judge LLM to measure the factual consistency of the generated
    response against the given context(s). It is calculated using the generated
    `response` from the LLM and the `retrieved_contexts` which come from some RAG
    process. The score is a value between 0 and 1, where a higher score indicates that
    the generated response is more faithful to the given context(s).

    The generated response is regarded as faithful if all the claims that are made in
    the response can be inferred from the given context. To calculate this, a set of
    claims is first identified in the generated response. Each of these claims is then
    cross-checked against the given context to determine whether it can be inferred
    from that context. The faithfulness score formula is as follows:

    $$
    \\text{Faithfulness score} = {|\\text{Number of claims in the generated answer that can be inferred from given context}| \\over |\\text{Total number of claims in the generated answer}|}
    $$

    ### Configuring Columns

    This metric requires the following columns in your dataset:

    - `user_input` (str): The user input that the model is responding to.
    - `retrieved_contexts` (List[str]): A list of text contexts which are retrieved to generate
    the answer.
    - `response` (str): The response generated by the model which will be evaluated for
    faithfulness against the given contexts.

    If the above data is not in the appropriate column, you can specify different column
    names for these fields using the parameters `user_input_column`,
    `retrieved_contexts_column`, and `response_column`.

    For example, if your dataset has this data stored in different columns, you can
    pass the following parameters:
    ```python
    {
        "retrieved_contexts_column": "context_info",
        "response_column": "my_answer_col",
        "user_input_column": "user_input",
    }
    ```

    If the data is stored as a dictionary in another column, specify the column and key
    like this:
    ```python
    pred_col = dataset.prediction_column(model)
    params = {
        "retrieved_contexts_column": f"{pred_col}.retrieved_contexts",
        "response_column": f"{pred_col}.response",
        "user_input_column": "user_input",
    }
    ```

    For more complex situations, you can use a function to extract the data:
    ```python
    pred_col = dataset.prediction_column(model)
    params = {
        "retrieved_contexts_column": lambda row: [row[pred_col]["context_message"]],
        "response_column": lambda row: "\\n\\n".join(row[pred_col]["messages"]),
        "user_input_column": "user_input",
    }
    ```

    ### Judge Models

    `judge_llm` and `judge_embeddings` are optional and are forwarded to the ragas
    configuration used for the evaluation. Leave them as `None` to use whatever that
    configuration provides by default.

    ### Returns

    - A dictionary with an `"Aggregate Scores"` table: a single row holding the mean,
    median, max, min and standard deviation of the faithfulness scores, plus the number
    of evaluated rows.
    - A histogram of the faithfulness scores.
    - A box plot of the faithfulness scores.
    - Raw data containing the per-row ragas evaluation results.
    """
    # Canonical ragas column name -> user-configured column (name, path or callable).
    required_columns = {
        "response": response_column,
        "retrieved_contexts": retrieved_contexts_column,
        "user_input": user_input_column,
    }

    # Suppress the noisy pyarrow/datasets FutureWarning only for the duration of this
    # evaluation, so running the test does not permanently change the caller's global
    # warning filters.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            category=FutureWarning,
            message="promote has been superseded by promote_options='default'.",
        )

        df = get_renamed_columns(dataset._df, required_columns)

        # Keep only the columns ragas needs; an explicit list is the documented form
        # of a multi-column selection in pandas (rather than a dict_keys view).
        df = df[list(required_columns)]

        result_df = evaluate(
            Dataset.from_pandas(df),
            metrics=[faithfulness()],
            **get_ragas_config(judge_llm, judge_embeddings)
        ).to_pandas()

    score_column = "faithfulness"
    scores = result_df[score_column]

    fig_histogram = px.histogram(x=scores.to_list(), nbins=10, title="Faithfulness")
    fig_box = px.box(x=scores.to_list(), title="Faithfulness")

    return (
        {
            "Aggregate Scores": [
                {
                    "Mean Score": scores.mean(),
                    "Median Score": scores.median(),
                    "Max Score": scores.max(),
                    "Min Score": scores.min(),
                    "Standard Deviation": scores.std(),
                    # NOTE: counts every evaluated row; the pandas reductions above
                    # skip NaN scores by default, so this may exceed the number of
                    # rows that actually received a score.
                    "Count": result_df.shape[0],
                }
            ],
        },
        fig_histogram,
        fig_box,
        RawData(evaluation_results=result_df, dataset=dataset.input_id),
    )
