#!/usr/bin/python

import math
import os
from scipy import stats
import numpy as np
import matplotlib
from sys import platform

# Select the non-interactive Agg backend on Linux; this runs before `import matplotlib.pyplot` below.
# NOTE(review): this keys off the OS rather than whether a display is available (e.g. the DISPLAY env var), so
# Linux desktops never get interactive windows either -- confirm that is intended.
if platform == "linux" or platform == "linux2":
    print('Linux OS. Using non-interactive Agg backend')
    matplotlib.use('Agg')

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns  # necessary for three_plots() method below
import warnings
from scipy.stats import gaussian_kde


"""
This module centralizes some functionality for plotting methods.
"""


def check_for_plotting():
    """
    Decide whether figures should be displayed interactively.

    Display is opt-in: it is enabled only when a marker file named ``.rt_show.sft`` exists in the user's home
    directory. The home directory is taken from ``HOME`` or, failing that, from ``HOMEDRIVE`` + ``HOMEPATH``. When
    none of those variables is set (the HPC case), display is disabled.

    Returns: True if the marker file exists, False otherwise.

    """
    if os.environ.get("HOME"):
        homepath = os.environ["HOME"]
    elif os.environ.get("HOMEDRIVE"):
        # Default HOMEPATH to "" so a missing value cannot make os.path.join() raise TypeError on None.
        homepath = os.path.join(os.environ["HOMEDRIVE"], os.environ.get("HOMEPATH", ""))
    else:
        # HPC case: if none of the env vars are present, we're on the HPC and won't plot anything.
        return False

    return os.path.exists(os.path.join(homepath, ".rt_show.sft"))


def plot_poisson_probability(rate, num, file, category='k vs. probability', xlabel='k', ylabel='probability',
                             label1="expected Poisson distribution", label2="test data",
                             title="Poisson Probability Mass Funciton",
                             show=True, tolerance=0.001):
    """
    Plot the observed frequency of each distinct value in ``num`` against the Poisson pmf for ``rate`` and save the
    figure. Every value whose observed probability differs from the expected one by more than ``tolerance`` is
    reported to ``file``.

    Args:
        rate: Poisson rate passed to scipy.stats.poisson.pmf
        num: An array of observed values
        file: Writable file-like object; mismatches are reported via file.write()
        category: Prefix for the report lines and for the saved file name "<category>_rate<rate>.png"
        xlabel: x-axis label
        ylabel: y-axis label
        label1: Legend label for the expected Poisson pmf (red dashed line)
        label2: Legend label for the observed frequencies (blue squares)
        title: Plot title
        show: True to display the plot; forced to False when check_for_plotting() returns False
        tolerance: Largest absolute probability difference that is not reported (default 0.001)

    Returns: None

    """
    from collections import Counter

    if not check_for_plotting():
        show = False

    total = float(len(num))
    counts = Counter(num)
    x = sorted(counts)  # distinct observed values, ascending
    y = []  # expected Poisson probabilities
    z = []  # observed relative frequencies
    for n in x:
        p = stats.poisson.pmf(n, rate)
        observed = counts[n] / total
        y.append(p)
        z.append(observed)
        if math.fabs(p - observed) > tolerance:
            file.write("{3} : Probability for NumOfExposures {0} is {1}, expected {2}."
                       "\n".format(n, observed, p, category))
            file.write("{3} : Count for NumOfExposures {0} is {1}, expected {2}."
                       "\n".format(n, counts[n], p * total, category))

    fig = plt.figure()
    plt.plot(x, y, 'r--', x, z, 'bs')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    red_patch = mpatches.Patch(color='red', label=label1)
    blue_patch = mpatches.Patch(color='blue', label=label2)
    plt.legend(handles=[red_patch, blue_patch])
    # Save before showing (as the other helpers do), so the saved file is written independently of the
    # interactive window's lifetime.
    fig.savefig(str(category) + "_rate" + str(rate) + ".png")
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_dataframe_line(dataframe, filename="line_chart.png"):
    """
    Plot every column of a dataframe as a line in one figure and save it.

    Args:
        dataframe: pandas DataFrame; plotted with DataFrame.plot.line()
        filename: Output image path (default "line_chart.png")

    Returns: None

    """
    fig = dataframe.plot.line().get_figure()
    fig.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures, and so helpers that draw on the
    # current pyplot figure do not draw onto this one.
    plt.close(fig)


def _rounded_frequencies(dist, precision):
    """Return (sorted distinct values, relative frequencies) of ``dist`` after rounding to ``precision`` decimals."""
    from collections import Counter

    counts = Counter(round(n, precision) for n in dist)
    total = float(len(dist))
    values = sorted(counts)
    return values, [counts[v] / total for v in values]


def plot_probability(dist1, dist2=None, precision=1, label1="test data", label2="scipy data",
                     title='probability mass function', xlabel='k', ylabel='probability', category='test', show=True,
                     line=False):
    """
    Plot the empirical probability mass function of one or two distributions and save it as "<category>.png".

    Each value is rounded to ``precision`` decimals before counting, so nearby floats share one bin.

    Args:
        dist1: array like of values (red)
        dist2: optional array like of values (blue)
        precision: Number of decimal values
        label1: legend label for dist1
        label2: legend label for dist2; only shown when dist2 is given
        title: plot title
        xlabel: x-axis label
        ylabel: y-axis label
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False
        line: True to plot the values as line, False to plot them as dots

    Returns: None

    """
    if not check_for_plotting():
        show = False

    # A colour-only format string draws a line; the 's' suffix draws square markers only.
    color1, color2 = ('r', 'b') if line else ('rs', 'bs')

    x, y = _rounded_frequencies(dist1, precision)
    fig = plt.figure()
    legend_handles = [mpatches.Patch(color='red', label=label1)]
    if dist2 is not None:  # "if dist2:" is ambiguous for numpy arrays / dataframes
        x2, y2 = _rounded_frequencies(dist2, precision)
        plt.plot(x, y, color1, x2, y2, color2)
        legend_handles.append(mpatches.Patch(color='blue', label=label2))
    else:
        plt.plot(x, y, color1)
    plt.title("{0}".format(title))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(handles=legend_handles)
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_hist_df(frame, x_column, y_column,
                 title=None,
                 show=True):
    """
    This plots a histogram from a dataframe that knows the count for each item
    Args:
        frame: pandas Dataframe to plot
        x_column: label of the category
        y_column: label of the count (used as histogram weights)
        title: title of plot NOTE don't put spaces in here, also a filename that gets saved
            (default "Frequency_of_<x_column>")
        show: True to display the plot; forced to False when check_for_plotting() returns False

    Returns: None

    """
    # Honour the display opt-in like the other helpers in this module.
    if not check_for_plotting():
        show = False

    if not title:
        title = f'Frequency_of_{x_column}'
    fig = plt.figure()
    plt.hist(frame[x_column],
             weights=frame[y_column],
             bins=len(frame),  # one bin per row
             histtype='bar',
             ec='black')
    plt.title(title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)
    fig.savefig(title + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_hist(dist1, dist2=None, label1="test data 1", label2="test data 2", title=None, xlabel=None, ylabel=None,
              category='histogram', show=True):
    """
    Plot a histogram of one distribution, or side-by-side histograms of two, and save it as "<category>.png".

    Args:
        dist1: array like of values (red)
        dist2: optional array like of values (blue)
        label1: legend label for dist1
        label2: legend label for dist2; only shown when dist2 is given
        title: optional plot title
        xlabel: optional x-axis label
        ylabel: optional y-axis label
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False

    Returns: None

    """
    if not check_for_plotting():
        show = False

    fig = plt.figure()
    if title:
        plt.title(title)
    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)
    legend_handles = [mpatches.Patch(color='red', label=label1)]
    if dist2 is not None:
        plt.hist([dist1, dist2], color=['r', 'b'], alpha=0.8)
        legend_handles.append(mpatches.Patch(color='blue', label=label2))
    else:
        plt.hist(dist1, color='r', alpha=0.8)
    plt.legend(handles=legend_handles)
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def calc_cdf(dist, num_bin=20):
    """
    Calculate a binned cdf of array-like values using ``num_bin`` equal-width bins spanning [min, max].

    Args:
        dist: An array like of numeric values (must not be empty)
        num_bin: number of bins

    Returns: cdf, bin_edges
        cdf: list of ``num_bin`` floats; cdf[i] is the fraction of values falling in bins 0..i (last entry is 1.0)
        bin_edges: numpy array of the ``num_bin`` right-hand bin edges

    Raises:
        ValueError: if ``dist`` is empty

    """
    values = np.asarray(dist)
    if values.size == 0:
        raise ValueError("calc_cdf() requires at least one value")
    # np.histogram builds exactly num_bin equal-width edges over [min, max] (no np.arange float drift adding a
    # spurious bin) and widens a zero-width range by 0.5 on each side, so constant data works too.
    counts, bin_edges = np.histogram(values, bins=num_bin)
    # Normalise the cumulative counts so the final cdf value is exactly 1.
    cdf = np.cumsum(counts) / float(counts.sum())
    return cdf.tolist(), bin_edges[1:]


def plot_cdf(dist1, dist2=None, label1="test data", label2="scipy data", title='Cumulative distribution function',
             xlabel='k', ylabel='probability', category='Cumulative_distribution_function', show=True, line=False):
    """
    Plot the binned cumulative distribution function (calc_cdf with 20 bins) of one or two distributions and save
    it as "<category>.png".

    Args:
        dist1: An array like values (red)
        dist2: Optional, an array like values (blue)
        label1: label for dist1
        label2: label for dist2; only shown when dist2 is given
        title: title of plot
        xlabel: x-axis label
        ylabel: y-axis label
        category: plot name (output file name without extension)
        show: True to display plot; forced to False when check_for_plotting() returns False
        line: True to plot as line, False to plot as dots

    Returns: None

    """
    if not check_for_plotting():
        show = False

    fig = plt.figure()
    fmt1, fmt2 = ('r', 'b') if line else ('ro', 'bo')

    num_bin = 20
    cdf, right_edges = calc_cdf(dist1, num_bin)
    legend_handles = [mpatches.Patch(color='red', label=label1)]

    if dist2 is None:
        plt.plot(right_edges, cdf, fmt1)
    else:
        cdf_2, right_edges_2 = calc_cdf(dist2, num_bin)
        plt.plot(right_edges, cdf, fmt1, right_edges_2, cdf_2, fmt2)
        legend_handles.append(mpatches.Patch(color='blue', label=label2))

    plt.title("{0}".format(title))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(handles=legend_handles)
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_data_3series(dist1, dist2, dist3, label1="test data 1", label2="test data 2", label3="test data 3",
                      title=None, xlabel=None, ylabel=None, category='plot_data', show=True, line=False,
                      sort=True):
    """
    Draw three array-like series in red, blue and magenta against their index and save the figure as
    "<category>.png".

    When check_for_plotting() returns False this returns immediately: nothing is displayed AND nothing is saved.

    Args:
        dist1: first series (red)
        dist2: second series (blue)
        dist3: third series (magenta)
        label1: legend label for dist1
        label2: legend label for dist2
        label3: legend label for dist3
        title: optional plot title
        xlabel: optional x-axis label
        ylabel: optional y-axis label
        category: output file name without extension
        show: True to display the plot
        line: unused
        sort: Whether to plot sorted copies of the values

    Returns: None

    """
    if not check_for_plotting():
        return None

    fig = plt.figure()
    for set_text, text in ((plt.title, title), (plt.xlabel, xlabel), (plt.ylabel, ylabel)):
        if text:
            set_text(text)

    series = [dist1, dist2, dist3]
    if sort:
        series = [sorted(values) for values in series]
    plot_args = []
    for values, colour in zip(series, ('r', 'b', 'm')):
        plot_args.extend((values, colour))
    plt.plot(*plot_args)

    legend_entries = (('red', label1), ('blue', label2), ('magenta', label3))
    plt.legend(handles=[mpatches.Patch(color=colour, label=text) for colour, text in legend_entries])
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_data_unsorted(dist1, dist2=None, label1="test data 1", label2="test data 2", title=None, xlabel=None,
                       ylabel=None,
                       category='plot_data_unsorted', show=True, line=False, alpha=1, overlap=False):
    """
    Deprecated: plot the data of one or two distributions in their original (unsorted) order.

    Use ``plot_data(..., sort=False)`` instead. This wrapper emits a FutureWarning and forwards its arguments to
    plot_data with sort=False.

    Args:
        dist1: array like of values
        dist2: optional second array like of values
        label1: legend label for dist1
        label2: legend label for dist2
        title: optional plot title
        xlabel: optional x-axis label
        ylabel: optional y-axis label
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False
        line: True to connect the markers with lines
        alpha: opacity of both series
        overlap: see plot_data

    Returns: None

    """
    exception = "plot_data_unsorted is deprecated, please use plot_data() with sort argument instead."
    # stacklevel=2 attributes the warning to the caller's line instead of to this wrapper.
    warnings.warn(exception, FutureWarning, stacklevel=2)

    if not check_for_plotting():
        show = False

    plot_data(dist1=dist1, dist2=dist2, label1=label1, label2=label2, title=title,
              xlabel=xlabel, ylabel=ylabel, category=category, show=show, line=line, alpha=alpha, overlap=overlap,
              sort=False)


def plot_data_sorted(dist1, dist2=None, label1="test data 1", label2="test data 2", title=None, xlabel=None,
                     ylabel=None,
                     category='plot_data_sorted', show=True, line=False, alpha=1, overlap=False):
    """
    Deprecated: sort and plot the data of one or two distributions.

    Use ``plot_data(..., sort=True)`` instead. This wrapper emits a FutureWarning and forwards its arguments to
    plot_data with sort=True.

    Args:
        dist1: array like of values
        dist2: optional second array like of values
        label1: legend label for dist1
        label2: legend label for dist2
        title: optional plot title
        xlabel: optional x-axis label
        ylabel: optional y-axis label
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False
        line: True to connect the markers with lines
        alpha: opacity of both series
        overlap: see plot_data

    Returns: None

    """
    exception = "plot_data_sorted is deprecated, please use plot_data() with sort argument instead."
    # stacklevel=2 attributes the warning to the caller's line instead of to this wrapper.
    warnings.warn(exception, FutureWarning, stacklevel=2)

    if not check_for_plotting():
        show = False

    plot_data(dist1=dist1, dist2=dist2, label1=label1, label2=label2, title=title, xlabel=xlabel, ylabel=ylabel,
              category=category, show=show, line=line, alpha=alpha, overlap=overlap, sort=True)


def plot_scatter_fit_line(dist1, dist2=None, label1="test data 1", label2=None, title=None,
                          xlabel=None, ylabel=None, xmin=None, xmax=None, ymin=None, ymax=None,
                          category='plot_scatter_fit_line', show=True, line=False, marker='s',
                          xticks=None, xtickslabel=None):
    """
    Scatter-plot dist1 against its index, coloured by the Gaussian kernel density of nearby points, and optionally
    overlay dist2 as a fit line. The figure is saved as "<category>.png".

    Args:
        dist1: array like of y values; x is the position index 0..len(dist1)-1
        dist2: optional array like drawn in black with ``marker`` as the fit for dist1
        label1: legend label for dist1
        label2: legend label for dist2
        title: optional plot title
        xlabel: optional x-axis label
        ylabel: optional y-axis label
        xmin: optional lower x limit (None leaves it unchanged; 0 is a valid limit)
        xmax: optional upper x limit
        ymin: optional lower y limit
        ymax: optional upper y limit
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False
        line: True to connect the dist2 markers with a line
        marker: matplotlib marker code for dist2
        xticks: optional x tick positions
        xtickslabel: optional labels for ``xticks``

    Returns: None

    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    if not check_for_plotting():
        show = False

    fig, ax = plt.subplots()

    if title:
        ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    # Only call set_xlim/set_ylim when a limit was requested: they also switch autoscaling off for that axis.
    if xmin is not None or xmax is not None:
        ax.set_xlim(left=xmin, right=xmax)
    if ymin is not None or ymax is not None:
        ax.set_ylim(bottom=ymin, top=ymax)

    fit_fmt = 'k' + marker
    if line:
        fit_fmt += '-'

    # Generate data for scatter plot; asarray so the fancy indexing below also works for list input.
    y = np.asarray(dist1)
    x = np.arange(len(y))

    # Calculate the point density
    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)

    # Sort the points by density, so that the densest points are plotted last
    idx = z.argsort()
    x, y, z = x[idx], y[idx], z[idx]

    scatter = ax.scatter(x, y, c=z, s=100, edgecolor='none', label=label1, cmap="rainbow", alpha=0.5)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cax.set_ylabel("spatial density of nearby points")

    fig.colorbar(scatter, cax=cax, orientation='vertical')

    if dist2 is not None:  # "if dist2:" will not work with numpy.ndarray
        # use plot instead of scatter for more efficiency here.
        ax.plot(dist2, fit_fmt, label=label2, lw=0.5, markersize=3)
    ax.legend(loc=0)

    if xticks is not None:
        ax.set_xticks(xticks)
        if xtickslabel is not None:
            ax.set_xticklabels(xtickslabel, fontsize=8, rotation=20, rotation_mode="anchor", ha='right')

    fig.tight_layout()
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_data(dist1, dist2=None, label1="test data 1", label2=None, title=None,
              xlabel=None, ylabel=None, xmin=None, xmax=None, ymin=None, ymax=None,
              category='plot_data', show=True, line=False, alpha=1, overlap=False, marker1='s', marker2='o',
              sort=False, xticks=None, xtickslabel=None):
    """
    Plot the values of one or two distributions against their index and save the figure as "<category>.png".

    Args:
        dist1: array like of values
        dist2: optional second array like of values
        label1: legend label for dist1
        label2: legend label for dist2
        title: optional plot title
        xlabel: optional x-axis label
        ylabel: optional y-axis label
        xmin: optional lower x limit (None leaves it unchanged; 0 is a valid limit)
        xmax: optional upper x limit
        ymin: optional lower y limit
        ymax: optional upper y limit
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False
        line: True to connect the markers with lines
        alpha: opacity of both series
        overlap: True to draw dist1/dist2 in red/blue, False for yellow/green
        marker1: matplotlib marker code for dist1
        marker2: matplotlib marker code for dist2
        sort: True to plot sorted copies of the inputs (the inputs are not modified)
        xticks: optional x tick positions
        xtickslabel: optional labels for ``xticks``

    Returns: None

    """
    if not check_for_plotting():
        show = False

    if sort:  # use not in-place method
        dist1 = sorted(dist1)
        if dist2 is not None:
            dist2 = sorted(dist2)

    fig = plt.figure()
    ax = fig.add_axes([0.12, 0.15, 0.76, 0.76])
    if title:
        ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    # Only call set_xlim/set_ylim when a limit was requested: they also switch autoscaling off for that axis.
    if xmin is not None or xmax is not None:
        ax.set_xlim(left=xmin, right=xmax)
    if ymin is not None or ymax is not None:
        ax.set_ylim(bottom=ymin, top=ymax)
    if overlap:
        color1 = 'r' + marker1
        color2 = 'b' + marker2
    else:
        color1 = 'y' + marker1
        color2 = 'g' + marker2
    if line:
        color1 += '-'
        color2 += '-'
    ax.plot(dist1, color1, alpha=alpha, label=label1, lw=0.5, markersize=4)
    if dist2 is not None:  # "if dist2:" will not work with numpy.ndarray
        ax.plot(dist2, color2, alpha=alpha, label=label2, lw=0.5, markersize=3)
    ax.legend(loc=0)

    if xticks is not None:
        ax.set_xticks(xticks)
        if xtickslabel is not None:
            ax.set_xticklabels(xtickslabel, fontsize=8, rotation=20, rotation_mode="anchor", ha='right')

    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_pie(sizes, labels, category='plot_pie', show=True):
    """
    Plot a pie chart of ``sizes``, with each wedge annotated by its percentage, and save it as "<category>.png".

    Args:
        sizes: wedge sizes
        labels: one label per wedge
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False

    Returns: None
    """
    if not check_for_plotting():
        show = False

    # Use a dedicated figure: plt.pie() draws onto whatever figure is current, possibly one left open elsewhere.
    fig, ax = plt.subplots()
    ax.pie(sizes, labels=labels, autopct='%1.1f%%', shadow=True, startangle=140)
    ax.axis('equal')  # equal aspect ratio so the pie is drawn as a circle
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_scatter_with_fit_lines(dataframe, xlabel, ylabel, fit_data_segments,
                                est_data_segments, fit_data_label="fits to data",
                                est_data_label="estimated fits", category="plot_data_fits",
                                show=True):
    """
    Plot a scatterplot of data together with fitted (blue) and estimated (red) line segments, and save it as
    "<category>.png". Written with Immunity Initialization in mind.

    When check_for_plotting() returns False this returns immediately: nothing is displayed AND nothing is saved.

    Args:
        dataframe: table indexable by column name (e.g. a pandas DataFrame) holding the scatter data
        xlabel: column to use for the x-axis, also its label. For example, 'age'
        ylabel: column to use for the y-axis, also its label. For example, 'mod_acquire'
        fit_data_segments: array of data-fit line segments in this format ([startx, endx],[starty, endy])
        est_data_segments: array of estimated (ideal) line segments in the same format
        fit_data_label: legend label for the fitted segments ("fits to data")
        est_data_label: legend label for the estimated segments ("estimated fits")
        category: name of file without extension ("plot_data_fits")
        show: True to display the plot

    Returns: None

    """
    if not check_for_plotting():
        return
    fig = plt.figure()
    plt.scatter(dataframe[xlabel], dataframe[ylabel], s=20, alpha=0.02, lw=0)
    for segment in est_data_segments:
        plt.plot(segment[0], segment[1], 'r')  # Red is reference, or expected data
    for segment in fit_data_segments:
        plt.plot(segment[0], segment[1], 'b')  # Blue is data under test
    # The label parameters were previously accepted but never drawn; show them via colour patches.
    plt.legend(handles=[mpatches.Patch(color='red', label=est_data_label),
                        mpatches.Patch(color='blue', label=fit_data_label)])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)


def plot_bar_graph(data, xticklabels, x_label, y_label, legends, plot_name, show=True, num_decimal=1):
    """
    Draw two groups of bars side by side (orange for data[0], seagreen for data[1]), label each bar with its height,
    and save the chart as "<plot_name>.png".

    Args:
        data: Two lists of data to plot
        xticklabels: List of string for xticks, length = length of data to plot
        x_label: String for x label
        y_label: String for y label
        legends: List of string for legends, length = 2
        plot_name: String for plot title and name
        show: True to display the plot; suppressed when check_for_plotting() returns False
        num_decimal: Number of decimal places for label

    Returns: None

    """
    plotting_enabled = check_for_plotting()

    fig = plt.figure()
    ax = fig.add_axes([0.12, 0.15, 0.76, 0.76])
    bar_width = 0.35
    positions = np.arange(len(xticklabels))
    bar_groups = [
        ax.bar(positions + offset, values, bar_width, color=colour)
        for offset, values, colour in ((0, data[0], "orange"), (bar_width, data[1], "seagreen"))
    ]
    # Centre each tick between the two bars of its group.
    ax.set_xticks(positions + bar_width / 2)
    ax.set_xticklabels(xticklabels)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(f'{plot_name}')
    ax.legend(tuple(group[0] for group in bar_groups), (legends[0], legends[1]))
    for group in bar_groups:
        autolabel(ax, group, num_decimal)

    fig.savefig(f'{plot_name}.png')
    if show and plotting_enabled:
        plt.show()
    plt.close(fig)


def autolabel(ax, rects, num_decimal=1):
    """
    Write each bar's height, rounded to ``num_decimal`` places, just above the bar.

    Each label is centred horizontally on its bar, placed at 1.005 x the bar height and rotated by 45 degrees.

    Args:
        ax: axes to draw the labels on
        rects: bars as returned by ax.bar (any iterable of rectangles)
        num_decimal: Number of decimal places for label

    Returns: None

    """
    for bar in rects:
        bar_height = bar.get_height()
        centre_x = bar.get_x() + bar.get_width() / 2.
        label = f'{round(float(bar_height), num_decimal)}'
        ax.text(centre_x, 1.005 * bar_height, label, ha='center', va='center', rotation=45)


def plot_cdf_w_fun(data, name="cdf", cdf_function=None, args=(), show=False):
    """
    Plot the empirical cdf of ``data`` (one bin per distinct value, via calculate_cdf) and save it as "<name>.png".

    If ``cdf_function`` is given it is evaluated as cdf_function(sorted distinct values, *args) and drawn as red
    points for comparison.

    Args:
        data: An array of values
        name: Name of plot (output file name without extension)
        cdf_function: Optional theoretical cdf
        args: Optional extra positional arguments for cdf_function
        show: True to display plot; forced to False when check_for_plotting() returns False

    Returns: None

    """
    if not check_for_plotting():
        show = False

    cdf, bin_edges = calculate_cdf(data)
    distinct_values = sorted(set(data))

    fig = plt.figure()
    ax = fig.add_axes([0.12, 0.12, 0.76, 0.76])
    # bin_edges[:-1] are the distinct values themselves, i.e. the x position of each cdf step.
    ax.plot(bin_edges[:-1], cdf, linestyle='--', marker="o", color='b', alpha=0.3, label="calculated with data bin",
            markersize=3)
    ax.set_ylim((-0.01, 1.05))
    ax.set_ylabel("Probability")
    ax.set_xlabel('X')
    ax.grid(True)

    if cdf_function is not None:
        ax.scatter(distinct_values, cdf_function(distinct_values, *args), color='r', alpha=0.3,
                   label=f"calculated used {cdf_function.__name__} function", s=20)
    ax.set_title("cumulative distribution function")
    ax.legend(loc=0)

    fig.savefig(f"{name}.png")
    if show:
        plt.show()
    plt.close(fig)


def calculate_cdf(data):
    """
    Calculate the empirical cdf of ``data`` with one bin per distinct value.

    The bin edges are the sorted distinct values followed by (largest value + 1), so bin i holds exactly the
    occurrences of the i-th smallest distinct value.

    Args:
        data: sized array like of numeric values

    Returns: cdf, bin_edges
        cdf: numpy array; cdf[i] is the fraction of values <= the i-th smallest distinct value
        bin_edges: numpy array with len(cdf) + 1 entries

    """
    total = len(data)
    distinct = sorted(set(data))
    edges = np.append(distinct, distinct[-1] + 1)
    counts, bin_edges = np.histogram(data, bins=edges, density=False)
    return np.cumsum(counts.astype(float) / total), bin_edges


def get_cmap(n, name='hsv'):
    """
    Return the colormap ``name`` resampled to ``n`` discrete colours.

    The returned Colormap is callable: calling it with an integer index 0..n-1 returns that entry as an RGBA tuple.

    Args:
        n: number of discrete colours
        name: registered matplotlib colormap name (default 'hsv')

    Returns: matplotlib Colormap with ``n`` entries

    Raises:
        ValueError: if ``name`` is not a registered colormap

    """
    try:
        # matplotlib.colormaps exists from 3.5 and Colormap.resampled from 3.6; plt.cm.get_cmap was deprecated in
        # 3.7 and removed in 3.9.
        resample = matplotlib.colormaps[name].resampled
    except AttributeError:
        # Older matplotlib: fall back to the legacy API.
        return plt.cm.get_cmap(name, n)
    except KeyError:
        # Keep the legacy get_cmap contract of raising ValueError for unknown names.
        raise ValueError(f"{name!r} is not a registered matplotlib colormap") from None
    return resample(n)


def three_plots(dist1, cdf_function=None, args=(), dist2=None,
                label1="data 1", label2=None, title=None, xlabel=None, ylabel=None,
                category='three_plots', show=True, line=False, alpha=1, color1='b', color2='r', sort=False):
    """
    Compare distributions side by side with a value plot, a seaborn density plot and a cdf plot, and save the
    figure as "<category>.png". Mainly used to compare an expected distribution with one from model data.

    Args:
        dist1: array like of values
        cdf_function: optional theoretical cdf, called as cdf_function(sorted distinct values of dist1, *args);
            when given it is drawn in the cdf panel instead of dist2's cdf
        args: extra positional arguments for cdf_function
        dist2: optional second array like of values
        label1: legend label for dist1
        label2: legend label for dist2
        title: optional figure title
        xlabel: optional x-axis label of the value plot
        ylabel: optional y-axis label of the value plot, also used as the x-axis label of the cdf panel
        category: output file name without extension
        show: True to display the plot; forced to False when check_for_plotting() returns False
        line: True to append '-' to the value-plot format strings
        alpha: opacity of the value-plot and cdf lines
        color1: matplotlib colour for dist1
        color2: matplotlib colour for dist2
        sort: True to plot sorted copies of the inputs

    Returns: None

    """
    if not check_for_plotting():
        show = False

    if sort:
        dist1 = sorted(dist1)
        if dist2 is not None:
            dist2 = sorted(dist2)
    fig, axarr = plt.subplots(1, 3)

    # 1st plot: simple plot with all data point.
    # The '-' suffix belongs only in the plot() format strings. color1/color2 stay bare colours because they are
    # also passed as color= below, where a value such as 'b-' is not a valid colour.
    fmt1 = color1 + '-' if line else color1
    fmt2 = color2 + '-' if line else color2
    axarr[0].plot(dist1, fmt1, marker='s', alpha=alpha, label=label1, lw=0.5, markersize=4)
    if dist2 is not None:  # "if dist2:" will not work with numpy.ndarray
        axarr[0].plot(dist2, fmt2, marker='o', alpha=alpha, label=label2, lw=0.5, markersize=3)
    axarr[0].set_title("plot data")
    if xlabel:
        axarr[0].set_xlabel(xlabel)
    if ylabel:
        axarr[0].set_ylabel(ylabel)

    # 2nd plot: density plot, drawn vertically so it shares the value plot's y range.
    # NOTE(review): sns.distplot (and its vertical= argument) is deprecated in recent seaborn releases; consider
    # sns.histplot / sns.kdeplot when upgrading.
    sns.distplot(dist1, ax=axarr[1], color=color1, vertical=True, label=label1)
    if dist2 is not None:  # "if dist2:" will not work with numpy.ndarray
        sns.distplot(dist2, ax=axarr[1], color=color2, vertical=True, label=label2)
    axarr[1].set_xlabel("Probability")
    axarr[1].set_title("distplot")
    axarr[1].set_ylim(axarr[0].get_ylim())

    # 3rd plot: cdf plot
    cdf, bin_edges = calculate_cdf(dist1)
    data_set = sorted(set(dist1))
    axarr[2].plot(bin_edges[:-1], cdf, linestyle='--', color=color1, alpha=alpha, label=label1,
                  markersize=3)
    if cdf_function is not None:
        cdf_theoretical = cdf_function(data_set, *args)
        axarr[2].scatter(data_set, cdf_theoretical, color='r', alpha=0.3,
                         label=f"{cdf_function.__name__}", s=20)
    elif dist2 is not None:
        cdf2, bin_edges2 = calculate_cdf(dist2)
        axarr[2].plot(bin_edges2[:-1], cdf2, linestyle='--', color=color2, alpha=alpha,
                      label=label2, markersize=3)
    axarr[2].set_ylim((-0.01, 1.05))
    axarr[2].set_ylabel("Probability")
    if ylabel:
        axarr[2].set_xlabel(ylabel)
    axarr[2].set_title("CDF")

    # formatting
    if title:
        fig.suptitle(title)
    for ax in axarr:
        ax.legend(loc=0)
        ax.grid(True)
    fig.tight_layout()
    fig.subplots_adjust(top=0.88)  # leave room for the suptitle
    fig.savefig(str(category) + '.png')
    if show:
        plt.show()
    plt.close(fig)
    return None


def plot_histogram(nums, name, mn, mx):
    """
    Plot a density-normalised histogram of ``nums`` with a Gaussian KDE overlay on [mn, mx], and save it as
    "<name>.png".

    Args:
        nums: non-empty array like of numeric values
        name: x-axis label and output file name without extension
        mn: lower x limit and start of the 61-point KDE evaluation grid
        mx: upper x limit and end of the KDE evaluation grid

    Returns: None

    """
    # One bin per unit up to max(nums), as before. bins must be a positive int, and max(nums) may be a float or <= 0.
    num_bins = max(1, int(math.ceil(max(nums))))
    # Dedicated figure so nothing is drawn onto a figure left open by another caller.
    fig, ax = plt.subplots()
    ax.hist(nums, density=True, bins=num_bins, label="Data")
    ax.set_xlim(mn, mx)
    kde_xs = np.linspace(mn, mx, 61)
    kde = stats.gaussian_kde(nums)
    ax.plot(kde_xs, kde.pdf(kde_xs), label="PDF")
    ax.legend(loc="upper left")
    ax.set_ylabel('Probability')
    ax.set_xlabel(name)
    ax.set_title("Histogram")
    fig.savefig(f'{name}.png')
    plt.close(fig)
