import sys, os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from file_utils import prepare_file_name_saving

import logging

# Import-time side effect: configure the root logger so that records of level
# INFO and above are appended to "chembee_plotting.log" (a relative path, so it
# resolves against the current working directory). Each record is formatted as
# "<LEVEL>:<MM/DD/YYYY hh:mm:ss AM|PM> <message>".
# Note that logging.basicConfig() is a no-op when the root logger already has
# handlers configured by the importing application.
logging.basicConfig(
    format="%(levelname)s:%(asctime)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
    level=logging.INFO,
    filename="chembee_plotting.log",
)


def plot_collection(metrics_json: dict, file_name: str, prefix: str):
    """
    The plot_collection function plots every metric family contained in metrics_json.
    The "scalar", "array" and "matrix" entries are all looked up first and are then handed, in this order,
    to the bar chart, ROC chart and heat map collection plotters.

    :param metrics_json:dict: Must provide the keys "scalar", "array" and "matrix".
    :param file_name:str: Forwarded unchanged to every collection plotter.
    :param prefix:str: Forwarded unchanged to every collection plotter.
    :return: None.
    """
    # Resolve all three sections before plotting anything, so a missing key
    # fails before the first plot is produced.
    sections = [metrics_json[key] for key in ("scalar", "array", "matrix")]
    plotters = (
        plot_bar_chart_collection,
        plot_roc_chart_collection,
        plot_heat_map_collection,
    )
    for plotter, section in zip(plotters, sections):
        plotter(section, file_name, prefix)


def plot_heat_map_collection(metrics_json: dict, file_name: str, prefix: str):
    """
    The plot_heat_map_collection function gathers, for every metric, the values of all algorithms in metrics_json.
    The gathered values are currently discarded: no heat map is drawn or saved yet, and file_name and prefix
    are not used.

    :param metrics_json:dict: Maps algorithm names to dictionaries that are indexed by metric name.
    :param file_name:str: Currently unused.
    :param prefix:str: Currently unused.
    :return: None.
    """
    algorithm_names, metric_names, _ = init_collection_plot(
        metrics_json, metric_type="scalar"
    )
    for metric_name in metric_names:
        # Collected per metric but not consumed yet; the lookup still fails
        # if an algorithm lacks this metric.
        collected = [metrics_json[name][metric_name] for name in algorithm_names]


def plot_roc_chart_collection(metrics_json: dict, file_name: str, prefix: str):
    """
    The plot_roc_chart_collection function draws one ROC chart per algorithm found in metrics_json.
    For each algorithm the entries "fpr", "tpr" and "roc_auc" are read and passed to plot_roc_chart; the chart
    is saved under the name "roc_auc_<algorithm>_<file_name>".

    :param metrics_json:dict: Maps algorithm names to dictionaries holding "fpr", "tpr" and "roc_auc".
    :param file_name:str: Suffix of the file name of every generated chart.
    :param prefix:str: Forwarded unchanged to plot_roc_chart.
    :return: None.
    """
    # TODO: Need to do this for multi_class, too
    # Only the algorithm names are needed here; the metric names and the
    # storage array returned by the helper are ignored.
    algorithm_names, _, _ = init_collection_plot(metrics_json, metric_type="scalar")
    for name in algorithm_names:
        entry = metrics_json[name]
        plot_roc_chart(
            entry["fpr"],
            entry["tpr"],
            entry["roc_auc"],
            file_name="roc_auc_" + name + "_" + file_name,
            prefix=prefix,
        )


def plot_bar_chart_collection(metrics_json: dict, file_name: str, prefix: str):
    """
    The plot_bar_chart_collection function plots a collection of bar charts, one for each metric in the metrics_json.
    The metrics_json is expected to be a dictionary with keys corresponding to algorithm names and values being
    dictionaries of scalar values (one value per metric). Each chart is saved under the name "<metric>_<file_name>".

    :param metrics_json:dict: Maps algorithm names to dictionaries of scalar metric values.
    :param file_name:str: Suffix of the file name of every generated chart.
    :param prefix:str: Forwarded unchanged to plot_bar_chart.
    :return: None.
    :raises RuntimeError: If a metric value of an algorithm is missing or cannot be stored as a number.
    """

    algs, metrics, metrics_storage = init_collection_plot(
        metrics_json, metric_type="scalar"
    )
    for i, metric in enumerate(metrics):
        for j, alg in enumerate(algs):
            try:
                metrics_storage[i, j] = metrics_json[alg][metric]
            except Exception as err:
                # `Exception` rather than a bare `except` so that
                # KeyboardInterrupt/SystemExit are not swallowed; the original
                # error is chained to keep the root cause visible.
                raise RuntimeError(
                    f"Either value for alg:{alg} or metric:{metric}"
                    " is not a valid value"
                ) from err
        # Row i now holds the values of this metric for all algorithms.
        plot_bar_chart(
            algs,
            metrics_storage[i, :],
            y_label=metric,
            file_name=metric + "_" + file_name,
            prefix=prefix,
        )


def plot_roc_chart(
    fpr: list, tpr: list, roc_auc: list, prefix: str, file_name: str = "roc_auc_curve"
):
    """
    The plot_roc_chart function plots a ROC-style curve together with the chance diagonal and saves it as a PNG file.
    The rates are converted to float arrays first, so plain Python lists are accepted as well as numpy arrays.
    As a side effect the global matplotlib font size is set to 32.

    :param fpr:list: False positive rates; plotted as 1 - fpr on the x axis.
    :param tpr:list: True positive rates; plotted as 1 - tpr on the y axis.
    :param roc_auc:list: Area under the curve; only used (via str) in the legend label.
    :param prefix:str: Passed to prepare_file_name_saving as the prefix of the output file.
    :param file_name:str="roc_auc_curve": Passed to prepare_file_name_saving as the base name of the output file.
    :return: None. The figure is saved to the prepared file name.
    """
    # `1 - x` is only defined element-wise for arrays; a plain list would
    # raise TypeError, although the signature advertises lists.
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)

    matplotlib.rcParams.update({"font.size": 32})
    fig = plt.figure(figsize=(15, 15))
    lw = 4
    try:
        # NOTE(review): both rates are complemented before plotting although
        # the axes are labelled FPR/TPR -- kept as in the original ("compare
        # test with standard function"); confirm this is intended.
        plt.plot(
            1 - fpr,
            1 - tpr,
            lw=lw,
            label="ROC curve %s" % str(roc_auc),
        )
        # Chance-level diagonal for reference.
        plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend(loc="lower right")
        fig.tight_layout()
        file_name = prepare_file_name_saving(
            prefix=prefix, file_name=file_name, ending=".png"
        )
        fig.savefig(file_name)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def plot_heat_map():
    """
    The plot_heat_map function is a placeholder: heat map plotting is not implemented yet and the call does nothing.

    :return: None.
    """
    return None


def plot_grouped_bar_chart(
    labels: list,
    data_names: list,
    data: np.ndarray,
    file_name: str,
    prefix: str,
    y_label: str = "Score",
):
    """
    The plot_grouped_bar_chart function creates a grouped bar chart and saves it as a PNG file.
    Every entry of data_names gets one tick on the x axis; around each tick one bar per entry of labels is drawn,
    so that the bars of different series stand side by side instead of on top of each other.
    As a side effect the global matplotlib font size is set to 22.

    :param labels:list: Names of the data series; one legend entry and one bar per tick for each of them.
    :param data_names:list: Names shown as tick labels on the x axis.
    :param data:np.ndarray: Values to plot, with shape (len(labels), len(data_names)).
    :param file_name:str: Passed to prepare_file_name_saving as the base name of the output file.
    :param prefix:str: Passed to prepare_file_name_saving as the prefix of the output file.
    :param y_label:str="Score": Label of the y axis.
    :return: None. The figure is saved to the prepared file name.
    :raises ValueError: If labels or data_names is empty, or data does not have the expected shape.
    """
    n_series = len(labels)
    n_ticks = len(data_names)
    if n_series == 0 or n_ticks == 0:
        raise ValueError("labels and data_names must not be empty")
    data = np.asarray(data)
    if data.shape != (n_series, n_ticks):
        raise ValueError(
            "data must have shape (len(labels), len(data_names)) = "
            + str((n_series, n_ticks))
            + ", got "
            + str(data.shape)
        )

    # Same helper and keyword arguments as the other plot functions here.
    file_name = prepare_file_name_saving(
        prefix=prefix, file_name=file_name, ending=".png"
    )
    matplotlib.rcParams.update({"font.size": 22})
    x = np.arange(n_ticks)  # the tick locations
    width = 0.8 / n_series  # all bars of one tick share 80% of the slot

    fig, ax = plt.subplots(figsize=(15, 15))
    try:
        for i, label in enumerate(labels):
            # Centre the block of n_series bars on the tick position.
            offset = (i - (n_series - 1) / 2) * width
            rects = ax.bar(x + offset, data[i, :], width, label=label)
            ax.bar_label(rects, padding=n_series + 1)

        ax.set_ylabel(y_label)
        ax.set_xticks(x)
        ax.set_xticklabels(data_names)
        ax.legend()
        fig.tight_layout()
        fig.savefig(file_name)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def plot_bar_chart(
    algs: list, metrics: list, file_name: str, prefix: str, y_label: str, **kwargs
):
    """
    The plot_bar_chart function draws one bar per algorithm, saves the chart as a PNG file and logs the saved name.
    As a side effect the global matplotlib font size is set to 22.

    :param algs:list: Bar positions/names on the x axis (passed to plt.bar as x).
    :param metrics:list: Bar heights, one per entry of algs.
    :param file_name:str: Passed to prepare_file_name_saving as the base name of the output file.
    :param prefix:str: Passed to prepare_file_name_saving as the prefix of the output file.
    :param y_label:str: Label of the y axis.
    :param kwargs: Accepted but currently not used.
    :return: None.
    """
    matplotlib.rcParams.update({"font.size": 22})
    figure = plt.figure(figsize=(15, 15))
    plt.bar(algs, metrics)
    target = prepare_file_name_saving(
        prefix=prefix, file_name=file_name, ending=".png"
    )
    plt.ylabel(y_label)
    figure.tight_layout()
    figure.savefig(target)
    # Reset pyplot's current figure/axes and close it so state does not leak
    # into the next plot.
    plt.clf()
    plt.cla()
    plt.close()
    logging.info("Plotted " + str(target))


def init_collection_plot(metrics_json, metric_type=""):
    """
    The init_collection_plot function prepares the data structures for plotting a collection of algorithms.
    It extracts the algorithm and metric names via parse_metrics_output and allocates a zero-filled array
    with one row per metric and one column per algorithm.

    :param metrics_json: Used to Store the metrics of each algorithm.
    :param metric_type="": Currently unused.
    :return: A tuple (algs, metrics, metrics_storage).
    """
    algorithm_names, metric_names = parse_metrics_output(metrics_json)
    storage_shape = (len(metric_names), len(algorithm_names))
    return algorithm_names, metric_names, np.zeros(storage_shape)


def parse_metrics_output(metrics_collection: dict) -> tuple:
    """
    The parse_metrics_output function extracts the algorithm names and the metric names from a metrics collection.
    The algorithm names are the keys of metrics_collection; the metric names are the keys of the entry of the
    first algorithm, so all algorithms are assumed to provide the same metrics.

    :param metrics_collection:dict: Maps algorithm names to dictionaries that are keyed by metric name.
    :return: A tuple (algs, metrics) of two lists; both are empty if metrics_collection is empty.
    """

    algs = list(metrics_collection)
    if not algs:
        # Nothing to inspect; avoid the IndexError of algs[0] on empty input.
        return [], []
    metrics = list(metrics_collection[algs[0]].keys())
    return algs, metrics


if __name__ == "__main__":
    # Make the parent directory importable so that the project packages
    # imported below resolve when this file is run as a script.
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

    from datasets.BreastCancer import BreastCancerDataset
    from actions.evaluation import screen_classifier_for_metrics

    dataset = BreastCancerDataset(split_ratio=0.8)
    # NOTE(review): assumes the result maps algorithm names to dictionaries of
    # scalar metrics, as plot_bar_chart_collection expects -- confirm.
    metrics = screen_classifier_for_metrics(
        X_train=dataset.X_train,
        X_test=dataset.X_test,
        y_train=dataset.y_train,
        y_test=dataset.y_test,
    )
    # Call the plotter defined in this module directly; this module defines
    # `plot_bar_chart_collection`, not `plot_barchart_collection`.
    plot_bar_chart_collection(
        metrics, file_name=dataset.name + "_evaluation", prefix="plots/evaluation/"
    )
