#!/usr/bin/python

import math
from scipy import stats
import numpy as np

from io import TextIOWrapper
from typing import Union, Optional

from idm_test import plot as my_plot
from idm_test.dtk_test import sft

"""
This module centralizes some statistical test methods.
"""



def test_binomial_95ci(num_success, num_trials, prob, report_file, category):
    """
    This function test if a binomial distribution falls within the 95% confidence interval
    Args:
        num_success: The number of successes of all trials(binomial tests)
        num_trials: The total number of trials(binomial tests)
        prob: The probability of success for one trial(binomial test)
        report_file: For error reporting
        category: For error reporting

    Returns: True, False

    """
    # calculate the mean and  standard deviation for binomial distribution
    mean = num_trials * prob
    standard_deviation = math.sqrt(prob * (1 - prob) * num_trials)
    # 95% confidence interval
    lower_bound = mean - 2 * standard_deviation
    upper_bound = mean + 2 * standard_deviation
    success = True
    result_message = f"For category {category}, the success cases is {num_success}," \
                     f"expected 95% confidence interval ( {lower_bound}, {upper_bound}).\n"
    if mean < 5 or num_trials * (1 - prob) < 5:
        # The general rule of thumb for normal approximation method is that
        # the sample size n is "sufficiently large" if np >= 5 and n(1-p) >= 5
        # for cases that binomial confidence interval will not work
        success = False
        report_file.write(f"There is not enough sample size in group {category}:" \
                          f"mean = {mean}, sample size - mean = {num_trials * (1 - prob)}.\n")
    elif num_success < lower_bound or num_success > upper_bound:
        success = False
        report_file.write(f"BAD:  {result_message}")
    else:
        report_file.write(f"GOOD: {result_message}")
    return success


def test_binomial_99ci(num_success, num_trials, prob, report_file, category):
    """
    This function test if a binomial distribution falls within the 99.73% confidence interval
    Args:
        num_success: The number of successes of all trials(binomial tests)
        num_trials: The total number of trials(binomial tests)
        prob: The probability of success for one trial(binomial test)
        report_file: For error reporting
        category: For error reporting

    Returns: True, False

    """
    # calculate the mean and  standard deviation for binomial distribution
    mean = num_trials * prob
    standard_deviation = math.sqrt(prob * (1 - prob) * num_trials)
    # 99.73% confidence interval
    lower_bound = mean - 3 * standard_deviation
    upper_bound = mean + 3 * standard_deviation
    success = True
    result_message = f"For category {category}, the success cases is {num_success}," \
                     f"expected 99.75% confidence interval ( {lower_bound}, {upper_bound}).\n"
    if mean < 5 or num_trials * (1 - prob) < 5:
        # The general rule of thumb for normal approximation method is that
        # the sample size n is "sufficiently large" if np >= 5 and n(1-p) >= 5
        success = False
        report_file.write(
            f"There is not enough sample size in group {category}: mean = {mean}, sample size - mean = "
            f"{num_trials * (1 - prob)}.\n")
    elif num_success < lower_bound or num_success > upper_bound:
        success = False
        report_file.write(f"BAD:  {result_message}")
    else:
        report_file.write(f"GOOD: {result_message}")
    return success


def calc_poisson_binomial(prob):
    """
    By definition, a Poisson binomial distribution is a sum of n independent Bernoulli distributions. This function
    calculates the mean, variance and standard deviation from the success probabilities of those n independent
    Bernoulli distributions: mean = sum(p), variance = sum(p * (1 - p)), standard deviation = sqrt(variance).
    Args:
        prob: List of probabilities from n independent Bernoulli distribution

    Returns: Dictionary with keys 'mean', 'standard_deviation' and 'variance'

    """
    mean = 0
    variance = 0
    for p in prob:
        mean += p
        variance += (1 - p) * p
    # the square root only depends on the final variance, so compute it once after the loop
    standard_deviation = math.sqrt(variance)
    return {'mean': mean, 'standard_deviation': standard_deviation, 'variance': variance}


def calc_ks_critical_value(num_trials):
    """
    This function returns the critical values for kstest Statistic test based on KS table.
    (reference: http://www.cas.usf.edu/~cconnor/colima/Kolmogorov_Smirnov.htm)
    Args:
        num_trials: Number of trials = length of distribution, must be at least 1

    Returns: Critical values assuming confidence_interval = 0.05

    Raises: ValueError if num_trials is smaller than 1

    """
    if num_trials < 1:
        # without this guard a non-positive size would silently wrap around to the end of the table
        raise ValueError(f"num_trials must be at least 1, got {num_trials}.")

    # KS table: entries 0-19 are for sizes 1-20, the last three entries cover sizes up to 25, 30 and 35
    ks_table = (0.975, 0.842, 0.708, 0.624, 0.565, 0.521, 0.486, 0.457, 0.432, 0.410, 0.391, 0.375, 0.361, 0.349,
                0.338, 0.328, 0.318, 0.309, 0.301, 0.294, 0.270, 0.240, 0.230)
    if num_trials <= 20:
        return ks_table[num_trials - 1]
    if num_trials <= 25:
        return ks_table[20]
    if num_trials <= 30:
        return ks_table[21]
    if num_trials <= 35:
        return ks_table[22]
    # beyond the table the large-sample formula is used
    return 1.36 / math.sqrt(num_trials)


def get_p_s_from_ksresult(result):
    """
    Extract the p-value and the statistic from a ks test result.
    Args:
        result: Result from ks test. If it contains "pvalue" (e.g. a result string), the values are parsed from
                its string form; otherwise it is indexed as (statistic, pvalue)

    Returns: {'p': p, 's': s}

    """
    # NOTE: different versions of kstest seem to produce different output.
    if "pvalue" not in result:
        # indexable result: statistic comes first, p-value second
        statistic = result[0]
        p_value = result[1]
        return {'p': p_value, 's': statistic}

    # textual result: parse the values that follow "pvalue=" and "statistic="
    p_value = float(sft.get_val("pvalue=", str(result)))
    statistic = float(sft.get_val("statistic=", str(result)))
    return {'p': p_value, 's': statistic}


def create_geometric_dist(rate, scale, size, test_decay=True):
    """
    Build a Geometric like series, either decaying or growing.

    Starting from scale, every step removes floor(current count * rate) from the current count.
    Args:
        rate: Geometric rate, the fraction of the current count removed at each step
        scale: Initial count
        size: Number of values to generate
        test_decay: True to record the remaining count at each step, False to record scale - remaining count

    Returns: An array of values

    """
    remaining = scale
    step_loss = 0
    values = []
    for _ in range(size):
        # apply the loss computed in the previous step, then compute the loss for the next one
        remaining -= step_loss
        step_loss = math.floor(remaining * rate)
        values.append(remaining if test_decay else scale - remaining)
    return values


# Ye: I think this does not test a real Geometric distribution; see test_geometric() for that. This function is for
# a distribution that looks like the one created by create_geometric_dist().
def test_geometric_decay(distribution_under_test, rate, scale, test_decay=True, report_file=None, debug=False):
    """
    Tests if the given distribution matches the series produced by create_geometric_dist() with the given rate
    and scale, using a two sample kstest.
    Args:
        distribution_under_test: Array of integers to test against
        rate: Rate per timestep at which the value decays
        scale: Number of things to decay
        test_decay: True assumes that your distribution is being decayed, False assumes it is growing
        report_file: Optional, file handle for reporting. None disables all reporting.
        debug: True to also write the raw kstest result to report_file

    Returns: True, False

    """
    size = len(distribution_under_test)
    series = create_geometric_dist(rate, scale, size, test_decay)

    result = stats.ks_2samp(series, distribution_under_test)
    ks_values = get_p_s_from_ksresult(result)
    p = ks_values['p']
    s = ks_values['s']
    critical_value_s = calc_ks_critical_value(size)

    # report_file is optional, so every write has to be guarded
    if report_file is not None:
        if debug:
            report_file.write(str(result) + "\n")
        report_file.write("distribution under test\n")
        report_file.write(str(distribution_under_test) + "\n")
        report_file.write("series I made\n")
        report_file.write(str(series) + "\n")

    # the test passes when either the p-value or the statistic criterion is met
    if p >= 5e-2 or s <= critical_value_s:
        return True

    if report_file is not None:
        report_file.write(
            f"BAD: Two sample kstest result for geometric decay is: statistic={s}, pvalue={p}, expected s less than"
            f" {critical_value_s} and p larger than 0.05.\n")
    return False


def convert_remains_to_binomial_chain(remains):
    """
    Turn the remaining # of binomial trials at each time step into the # of success cases at each time step,
    i.e. the difference between each value and the one that follows it.
    Example:
        with binomial prob = 0.1, initial trials = 100 and total time step = 10, we have:
        remains = <class 'list'>: [100, 90, 81, 73, 66, 60, 54, 49, 45, 41]
        binomial_chain = <class 'list'>: [10, 9, 8, 7, 6, 6, 5, 4, 4]
    Args:
        remains: The remaining # of binomial trials at each time step

    Returns: binomial_chain, one element shorter than remains

    """
    # pair every value with its successor and keep the drop between them
    return [current - following for current, following in zip(remains, remains[1:])]


def convert_binomial_chain_to_geometric(binomial_chain):
    """
    Expand an array of binomial trials results into the # of trials before first success for every success case,
    which is a Geometric distribution.
    Example:
        binomial_chain = <class 'list'>: [10, 9, 8]
        geom = <class 'list'>: [1] * 10 + [2] * 9 + [3] * 8
    Args:
        binomial_chain: Array of binomial trials results(# of success at each timestep)

    Returns: Geom

    """
    geom = []
    # time steps are numbered from 1 because Geometric distribution values are from the set {1, 2, 3, ...}
    # (shifted geometric distribution)
    for time_step, success_count in enumerate(binomial_chain, start=1):
        geom.extend(time_step for _ in range(success_count))
    return geom


def test_geometric(distribution_under_test, prob, report_file=None, category=None):
    """
    Two sample kstest of the data against a Geometric sample drawn with numpy. (note this will have a lower type 1
    error, please see "Type 1 Error" section for more details)

    In the egg hatching scenario, the distribution_under_test should be an array that contains the
    duration(timesteps) until hatch for each egg. If simulation starts at time 0, then distribution_under_test
    = [1] * # of eggs hatch at day 1 + [2] * # of eggs hatch at day 2 + ... + [i] * # of eggs hatch at day i.

    Type 1 Error:
        The kstest assumes the distribution under test is continuous. Since Geometric distribution is discrete, the
        true type 1 error is lower than the chosen one(5%), although it may be close to 5% depending on "how
        discrete" the distribution is. The following runs give an idea of the accuracy of this test:
            Test 1:
                distribution_under_test = stats.geom.rvs(p=0.1, size=100000)
                test_geometric(distribution_under_test, prob=0.1)
                Ran this test 1000 times and it failed 28 times.(Type 1 error is ~3% in this case.)
            Test 2:
                distribution_under_test = stats.geom.rvs(p=0.11, size=100000)
                test_geometric(distribution_under_test, prob=0.1)
                Ran this test 1000 times and it failed 1000 times.
            Test 3:
                distribution_under_test = stats.geom.rvs(p=0.101, size=100000)
                test_geometric(distribution_under_test, prob=0.1)
                Ran this test 1000 times and it failed 327 times.

    Args:
        distribution_under_test: The distribution to be tested
        prob: Success probability for each trial
        report_file: File to write the GOOD/BAD result message, None to skip reporting
        category: Name of test for error reporting
    Returns: True, False
    """
    sample_size = len(distribution_under_test)
    # reference sample of the same size as the data under test
    reference = np.random.geometric(p=prob, size=sample_size)
    ks_values = get_p_s_from_ksresult(stats.ks_2samp(distribution_under_test, reference))
    p = ks_values['p']
    s = ks_values['s']
    critical_value_s = calc_ks_critical_value(sample_size)

    # the test passes when either the p-value or the statistic criterion is met
    passed = bool(p >= 5e-2 or s <= critical_value_s)
    if report_file is not None:
        verdict = "GOOD: " if passed else "BAD: "
        report_file.write(
            verdict +
            f"Geometric kstest result for {category} is: statistic={s}, p_value={p}, "
            f"expected statistic less than {critical_value_s} and p_value larger than 0.05.\n")
    return passed


def test_poisson(trials, rate, report_file=None, route=None, normal_approximation=False,
                 plot=False, plot_name="plot_data_poisson", confidence_interval=0.05):
    """
    This function test if a distribution is a Poisson distribution with given rate, using a two sample kstest
    against 1000 values drawn from scipy's Poisson distribution.
    Args:
        trials: Distribution to test, it contains values 0, 1, 2, ...
        rate: The average number of events per interval
        report_file: Optional, file handle for reporting the GOOD/BAD result message
        route: Optional, label of the data under test, only used in the result message
        normal_approximation: No longer used in emod.
        plot: True to plot the test data vs. the reference sample
        plot_name: Category name passed to the plot function
        confidence_interval: Rejection threshold for K-S test; smaller values allow larger differences

    Returns: True if the p-value of the kstest is at least confidence_interval, False otherwise

    """
    ref = stats.poisson.rvs(mu=rate, size=1000)
    result = stats.ks_2samp(trials, ref)

    p = get_p_s_from_ksresult(result)['p']
    success = (p >= confidence_interval)

    if report_file is not None:
        result_string = "GOOD" if success else "BAD"
        # route is formatted without a string-only format spec so that the default (None) does not raise
        report_file.write(
            f"{result_string}: Poisson two sample kstest result for {route} is: pvalue={p:.3f}, expected "
            f"p larger than {confidence_interval:.3f}.\n"
        )
    if plot:
        my_plot.plot_data(trials, dist2=ref, label1="test data", label2="numpy_poisson",
                          title="Test data vs. numpy poisson", category=plot_name, sort=True)

    return success


def test_lognorm(timers, mu, sigma, report_file=None, category=None, round=False, plot=False,
                 plot_name="plot_data_lognorm", msg=None):
    """
    Kstest for lognormal distribution. Without rounding, a one sample kstest against the lognormal cdf is used;
    with rounding, a two sample kstest against rounded scipy data is used.
    Args:
        timers: The distribution to test
        mu: Mean of the log-normal distribution's natural logarithm, which is equal to math.log(scale)
        sigma: The standard deviation of the log-normal distribution's natural logarithm
        report_file: For error reporting
        category: For error reporting
        round: Whether or not to round scipy data with 7 significant digits
        plot: Whether or not to plot the test data vs. scipy data or cdf function
        plot_name: Name/category passed to the plot functions
        msg: a list to append the details result message

    Returns: True, False

    """
    median = math.exp(mu)  # scale parameter of the scipy lognormal distribution
    sample_size = len(timers)

    if not round:
        ks_result = stats.kstest(timers, 'lognorm', args=(sigma, 0, median))
        if plot:
            my_plot.plot_cdf_w_fun(timers, name=plot_name, cdf_function=stats.lognorm.cdf,
                                   args=(sigma, 0, median), show=False)
    else:
        # two sample kstest is used here so that the scipy data can be rounded as well
        reference = [round_to_n_digit(value, 7) for value in stats.lognorm.rvs(sigma, 0, median, sample_size)]
        ks_result = stats.ks_2samp(reference, timers)
        if plot:
            my_plot.plot_data(timers, dist2=reference, label1="test data", label2="scipy_logormal",
                              title="Test data vs. scipy log normal", category=plot_name, sort=True)

    ks_values = get_p_s_from_ksresult(ks_result)
    p = ks_values['p']
    s = ks_values['s']
    critical_value_s = calc_ks_critical_value(sample_size)

    # the test passes when either the p-value or the statistic criterion is met
    passed = bool(p >= 5e-2 or s <= critical_value_s)
    verdict = "GOOD" if passed else "BAD"
    message = f"{verdict}: log normal kstest result for {category} is: statistic={s}, pvalue={p}, " \
              f"expected s less than {critical_value_s} and p larger than 0.05.\n"
    if report_file is not None:
        report_file.write(message)
    if isinstance(msg, list):
        msg.append(message)
    return passed


def test_uniform(dist, p1=None, p2=None, report_file=None, round=False, significant_digits=7,
                 plot=False, plot_name="plot_data_uniform", msg=None):
    """
    Kstest or Chisquare(if p1 and p2 are not provided) for uniform distribution .If p1 and p2 are not provided, this
    only test if a given distribution is equally distributed.
    Args:
        dist: The distribution to be tested
        p1: Loc, min value
        p2: Loc + scale, max value
        report_file: For error reporting
        round: Whether or not to round scipy data with n significant digits
        significant_digits: N, number of significant digits
        plot: Whether or not to plot the test data vs. scipy data
        plot_name:
        msg: a list to append the details result message

    Returns: True, False

    """
    if not p1 and not p2:
        s, p = stats.chisquare(dist)
        critical_value_s = None
        m_template = "This distribution is equally distributed. Chisquare test passed with statistic={0}, pvalue={1}" \
                     ", expected  and p larger than 0.05.\n"
        suffix = m_template.format(s, p)
    else:
        if p1 > p2:  # swap p1 an p2 to make sure p1 is min and p2 is max
            p1, p2 = p2, p1
        loc = p1
        scale = p2 - p1
        size = len(dist)
        dist_uniform_scipy = stats.uniform.rvs(loc, scale, size)
        # dist_uniform_np = np.random.uniform(p1, p2, size)
        if round:
            dist_uniform_scipy_r = []
            for s in dist_uniform_scipy:
                dist_uniform_scipy_r.append(round_to_n_digit(s, significant_digits))
            dist_uniform_scipy = dist_uniform_scipy_r
        
        result = stats.ks_2samp(dist_uniform_scipy,dist)
        p = get_p_s_from_ksresult(result)['p']
        s = get_p_s_from_ksresult(result)['s']
        critical_value_s = calc_ks_critical_value(size)
        m_template="({0},{1})passed with statistic={2}, pvalue={3}, expected s less than {4} and p larger than 0.05.\n"
        suffix = m_template.format(p1, p2, s, p, critical_value_s)
        
        if plot:
            my_plot.plot_data(dist, dist2=dist_uniform_scipy, label1="test data", label2="scipy_uniform",
                      title="Test data vs. scipy uniform", category=plot_name, sort=True)

    # return p >= 5e-2 or s <= critical_value_s
    success = (p >= 5e-2)
    if not success and critical_value_s:
        success = (s <= critical_value_s)
    message = f"GOOD: {suffix}" if success else f"BAD: {suffix}"
    if report_file is not None:
        report_file.write(message)
    if isinstance(msg, list):
        msg.append(message)

    return success


def test_gaussian(dist, p1, p2, allow_negative=True, report_file=None, round=False,
                  plot=False, plot_name="plot_data_gaussian"):
    """
    Kstest for gaussian distribution. A one sample kstest against the normal cdf is used unless the reference data
    has to be rounded or clipped at 0, in which case a two sample kstest against scipy data is used.
    Args:
        dist: The distribution to be tested
        p1: Mean, loc
        p2: Width(standard deviation), scale
        allow_negative: Allow negative value in normal distribution, if False, turn all negative value to 0.0
        report_file: For error reporting, only written to when the test fails
        round: True to round the theoretical distribution to 7 significant digits.
        plot: True to plot the test data vs. scipy data or cdf function
        plot_name: Name/category passed to the plot functions

    Returns: True, False

    """
    sample_size = len(dist)
    if allow_negative and not round:
        ks_result = stats.kstest(dist, "norm", args=(p1, p2))
        if plot:
            my_plot.plot_cdf_w_fun(dist, name=plot_name, cdf_function=stats.norm.cdf, args=(p1, p2), show=False)
    else:
        reference = []
        for value in stats.norm.rvs(p1, p2, sample_size):
            if value < 0 and not allow_negative:
                value = 0
            reference.append(round_to_n_digit(value, 7) if round else value)
        ks_result = stats.ks_2samp(reference, dist)
        if plot:
            my_plot.plot_data(dist, dist2=reference, label1="test data", label2="scipy_gaussian",
                              title="Test data vs. scipy gaussion", category=plot_name, sort=True)

    ks_values = get_p_s_from_ksresult(ks_result)
    p = ks_values['p']
    s = ks_values['s']
    critical_value_s = calc_ks_critical_value(sample_size)

    # the test passes when either the p-value or the statistic criterion is met
    if p >= 5e-2 or s <= critical_value_s:
        return True

    if report_file is not None:
        report_file.write(
            f"BAD: ({p1},{p2})failed with statistic={s}, pvalue={p}, expected s less than {critical_value_s} "
            f"and p larger than 0.05.\n")
    return False


def round_down(num, precision):
    """
    Round a value down (towards negative infinity) keeping n decimal places
    Args:
        num: Value to be rounded down
        precision: N, number of decimal places to keep

    Returns: Float

    """
    scale_factor = math.pow(10.0, precision)
    scaled = num * scale_factor
    return math.floor(scaled) / scale_factor


def round_up(num, precision):
    """
    Round a value up (towards positive infinity) keeping n decimal places
    Args:
        num: Value to be rounded up
        precision: N, number of decimal places to keep

    Returns: Float
    """
    scale_factor = math.pow(10.0, precision)
    scaled = num * scale_factor
    return math.ceil(scaled) / scale_factor


def test_exponential(dist, p1, report_file=None, integers=False, roundup=False, round_nearest=False,
                     plot=False, plot_name="plot_data_exponential", msg=None):
    """
    Test for exponential distribution. Continuous data is checked with a one sample kstest against the exponential
    cdf; integer data is checked with a k-sample Anderson-Darling test against rounded numpy data.
    Args:
        dist: The distribution to be tested
        p1: Decay rate = 1 / decay length , lambda, >0
        report_file: Report file to which the GOOD/BAD result message is written
        integers: Indicates whether the distribution is rounded up or round nearest to integers or not
        roundup: Use with integers = True, round the reference data up
        round_nearest: Use with integers = True, round the reference data to the nearest integer. Takes precedence
                       over roundup. When neither is set the reference data is rounded down.
        plot: True to plot the test data vs. scipy data or cdf function
        plot_name: Name/category passed to the plot functions
        msg: a list to append the details result message

    Returns: True, False

    """
    reference_size = max(len(dist), 10000)
    decay_length = 1.0 / p1
    if not integers:
        ks_values = get_p_s_from_ksresult(stats.kstest(dist, "expon", args=(0, decay_length)))
        p = ks_values['p']
        s = ks_values['s']
        if plot:
            my_plot.plot_cdf_w_fun(dist, name=plot_name, cdf_function=stats.expon.cdf, args=(0, decay_length),
                                   show=True)
    else:
        draws = np.random.exponential(decay_length, reference_size)
        # bring the reference data to integers the same way the data under test was rounded
        if round_nearest:
            reference = [round(value) for value in draws]
        elif roundup:
            reference = [round_up(value, 0) for value in draws]
        else:
            reference = [round_down(value, 0) for value in draws]

        ad_result = stats.anderson_ksamp([dist, reference])
        p = ad_result.significance_level
        s = ad_result.statistic
        if plot:
            my_plot.plot_data(dist, dist2=reference, label1="test data", label2="numpy_exponential",
                              title="Test data vs. numpy exponential", category=plot_name, sort=True)

    passed = bool(p >= 5e-2)
    if passed:
        message = f"GOOD: ({p1})succeed with statistic={s}, pvalue={p}, expected p larger than 0.05.\n"
    else:
        message = f"BAD: ({p1})failed with statistic={s}, pvalue={p}, expected p larger than 0.05.\n"
    if report_file is not None:
        report_file.write(message)
    if isinstance(msg, list):
        msg.append(message)
    return passed


def test_bimodal(dist, p1, p2, report_file=None):
    """
    Test for bimodal distribution. This bimodal distribution is not a true bimodal distribution, which is defined as
    the overlap of two Gaussians. The definition of bimodal in DTK mathfunctions.cpp is a function that gives an output
    of either 1 or the value of param2. The param1 controls the fraction.

    Args:
        dist: The distribution to be tested
        p1: Faction of param 2
        p2: Multiplier
        report_file: File handle for error reporting

    Returns: True, False

    """
    size = len(dist)
    count1 = 0
    count2 = 0
    for n in dist:
        if n == 1.0:
            count1 += 1
        elif n == p2:
            count2 += 1
        else:
            if report_file is not None:
                report_file.write(
                    "BAD: Binomal distribution contains value = {0}, expected 1.0 or {1}.\n".format(n, p2))
            return False
    actual_faction = count2 / float(size)
    if math.fabs(p1 - actual_faction) <= 5e-2:
        return True
    else:
        if report_file is not None:
            report_file.write(
                "BAD: test Binomal failed with actual fraction = {0}, expected {1}.\n".format(actual_faction, p1))
        return False


def test_weibull(dist, p1, p2, report_file=None, round=False, plot=True, plot_name="plot_data_weibull"):
    """
    Test for weibull distribution. Without rounding, a one sample kstest against the weibull cdf is used; with
    rounding, a k-sample Anderson-Darling test against rounded scipy data is used.
    Args:
        dist: The distribution to be tested
        p1: Scale, lambda > 0
        p2: Shape, kappa > 0
        report_file: File handle for reporting the GOOD/BAD result message
        round: True to round the theoretical distribution to 7 significant digits.
        plot: True to plot the test data vs. scipy data or cdf function
        plot_name: Name/category passed to the plot functions

    Returns: True, False

    """
    reference_size = max(len(dist), 10000)

    if not round:
        # one sample ks test with weibull cdf function.
        ks_values = get_p_s_from_ksresult(stats.kstest(dist, "weibull_min", (p2, 0, p1)))
        p = ks_values['p']
        s = ks_values['s']
        if plot:
            my_plot.plot_cdf_w_fun(dist, name=plot_name, cdf_function=stats.weibull_min.cdf, args=(p2, 0, p1),
                                   show=True)
    else:
        draws = stats.weibull_min.rvs(c=p2, loc=0, scale=p1, size=reference_size)
        reference = [round_to_n_digit(value, 7) for value in draws]
        ad_result = stats.anderson_ksamp([dist, reference])
        p = ad_result.significance_level
        s = ad_result.statistic
        if plot:
            my_plot.plot_data(dist, dist2=reference, label1="test data", label2="scipy_weibull",
                              title="Test data vs. scipy weibull", category=plot_name, sort=True)

    passed = bool(p >= 5e-2)
    if report_file is not None:
        verdict = "GOOD: " if passed else "BAD: "
        report_file.write(
            verdict + f"({p1},{p2}) test returns statistic={s} and p value={p}, expected p larger than 0.05.\n")
    return passed


def test_multinomial(dist, proportions, report_file=None, prob_flag=True):
    """
    Chi-squared test for multinomial data
    Args:
        dist: Array_like, number in each categories
        proportions: Array_like, proportions in each categories
        report_file: File handle for error reporting
        prob_flag: Flag that indicates whether p are proportions or the expected values

    Returns: True or False for test result

    """
    if prob_flag:
        n = sum(dist)
        prob = sum(proportions)
        total = int(n/prob)
        result = stats.chisquare(dist, np.array(proportions) * total, ddof=0 ) #returns chi-square statistic and p value
    else:
        result = stats.chisquare(dist, proportions, ddof=0)
    p = get_p_s_from_ksresult(result)['p']
    s = get_p_s_from_ksresult(result)['s']
    if p >= 5e-2:
        if report_file is not None:
            report_file.write(
                "GOOD: Chi-squared test for multinomial data passed with statistic={0}, pvalue={1}, expected p larger"
                " than 0.05.\ndata for test is {2} and proportion is {3}.\n".format(s, p, dist, proportions))
        return True
    else:
        if report_file is not None:
            report_file.write(
                "BAD: Chi-squared test for multinomial data failed with statistic={0}, pvalue={1}, expected p larger"
                " than 0.05.\ndata for test is {2} and proportion is {3}.\n".format(s, p, dist, proportions))
        return False


def test_gamma(dist, p1, p2, report_file: Union[list, TextIOWrapper, None] = None):
    """
    Kstest for gamma distribution
    Args:
        dist: The distribution to be tested
        p1: K, shape,> 0
        p2: Theta, scale, > 0
        report_file: Where to report the GOOD/BAD result message: a list (the message is appended) or a file
                     handle / any object with a write() method (the message is written). None disables reporting.

    Returns: True, False

    """
    loc = 0
    shape = p1
    scale = p2
    size = len(dist)

    result = stats.kstest(dist, cdf='gamma', args=(shape, loc, scale))
    ks_values = get_p_s_from_ksresult(result)
    p = ks_values['p']
    s = ks_values['s']
    critical_value_s = calc_ks_critical_value(size)

    # the test passes when either the p-value or the statistic criterion is met
    if p >= 5e-2 or s <= critical_value_s:
        msg = f"GOOD: (shape={p1},scale={p2})passed with statistic={s}, pvalue={p}, expected s less than " \
              f"{critical_value_s} and p larger than 0.05.\n"
        succeed = True
    else:
        msg = f"BAD: (shape={p1},scale={p2})failed with statistic={s}, pvalue={p}, expected s less than " \
              f"{critical_value_s} and p larger than 0.05.\n"
        succeed = False

    if isinstance(report_file, list):
        report_file.append(msg)
    elif report_file is not None and hasattr(report_file, "write"):
        # accept every file-like object (e.g. StringIO), not only TextIOWrapper
        report_file.write(msg)
    return succeed


def is_stats_test_pass(fail_count, pass_count, report_file=None):
    """
    Decide whether a set of statistic tests(kstest basically) passes as a whole. The threshold for p value to
    determine whether a single ks test passes is hardcoded as 0.05, so on average 5% of the tests are expected to
    fail. More failures than that are tolerated as long as they are not a small probability event (0.001) under a
    Poisson distribution with that expected failing count as its mean.
    Args:
        fail_count: Total number of failing statistic tests
        pass_count: Total number of passing statistic tests
        report_file: Optional, file handle for error reporting

    Returns: True, False

    """
    threshold_pvalue = 5e-2  # hardcoded threshold for p value
    small_prob_event = 1e-3  # probability for small probability event
    total_count = fail_count + pass_count
    # this is the theoretical average failing count
    expected_fail_count = total_count * threshold_pvalue

    if not fail_count < expected_fail_count:
        # could use Normal Approximation with continuity correction (mean > 25) and 6 sigma rule.
        # cumulative probability of seeing fewer failures than observed
        prob = stats.poisson.cdf(fail_count - 1, expected_fail_count)
        if report_file is not None:
            report_file.write(
                f"prob = {prob}, mean_pvalue = {expected_fail_count}, fail_count = {fail_count}.\n")
        # test passes when the survival function (1 - cdf) >= probability for small probability event
        # (0.001, 0.01 or 0.05), higher value is stricter than lower value.
        # Otherwise a small probability event happened in real life, so we determine the test fails.
        return (1.0 - prob) >= small_prob_event

    if report_file is not None:
        report_file.write(f"mean_pvalue = {expected_fail_count}, fail_count = {fail_count}.\n")
    return True


def round_to_1_digit(x):
    """
    Round number x to 1 significant digit.

    Zero, positive/negative infinity and NaN have no order of magnitude to round to, so they are returned
    unchanged.
    Args:
        x: Number to be rounded

    Returns: x rounded to 1 significant digit, as produced by the built-in round()

    """
    # x != x is only true for NaN; without this guard math.floor(math.log10(nan)) raises ValueError
    if x == 0 or x != x or abs(x) == math.inf:
        return x
    # position of the leading digit: 0 for [1, 10), 3 for [1000, 10000), -2 for [0.01, 0.1), ...
    magnitude = int(math.floor(math.log10(abs(x))))
    return round(x, -magnitude)


def round_to_n_digit(x, n):
    """
    Round number x to n significant digits.

    Zero, positive/negative infinity and NaN have no order of magnitude to round to, so they are returned
    unchanged.
    Args:
        x: Number to be rounded
        n: # of significant digits

    Returns: x rounded to n significant digits, as produced by the built-in round()

    """
    # x != x is only true for NaN; without this guard math.floor(math.log10(nan)) raises ValueError
    if x == 0 or x != x or abs(x) == math.inf:
        return x
    # position of the leading digit: 0 for [1, 10), 3 for [1000, 10000), -2 for [0.01, 0.1), ...
    magnitude = int(math.floor(math.log10(abs(x))))
    # keep n - 1 digits after the leading one
    return round(x, -magnitude + (n - 1))


def convert_barchart_to_interpolation(population_groups, result_values):
    """
    Convert a barchart to interpolation by doubling every entry, in place. Both arguments are modified and
    nothing is returned.

    Each value v in every list of population_groups is followed by a new value v + 0.9999999. Every row of
    result_values is followed by a copy of itself, and afterwards every value in every row (copies included) is
    followed by a duplicate of itself.
    Args:
        population_groups: List of lists of numbers (ages or years, judging by the original variable names);
                           modified in place
        result_values: List of lists of values; modified in place

    Returns: None

    """
    # Inserting from the last index backwards keeps the indexes of the not-yet-visited entries valid.
    for group in population_groups:
        for position in range(len(group) - 1, -1, -1):
            group.insert(position + 1, group[position] + 0.9999999)

    # duplicate each row; the duplicate is a new list so it can be expanded independently below
    for position in range(len(result_values) - 1, -1, -1):
        result_values.insert(position + 1, list(result_values[position]))

    # duplicate each value inside every row
    for row in result_values:
        for position in range(len(row) - 1, -1, -1):
            row.insert(position + 1, row[position])


def cal_tolerance_poisson(expected_value, prob=0.05):
    """
    Calculate the tolerance for the expected mean of N draws from Poisson or Binomial distributions. The
    probability that a test value exceeds the tolerance is prob (default value is 5%). The tolerance is expressed
    as a fraction of expected_value.

    The equation is based on the cumulative distribution function and error function of a normal distribution,
    using a normal distribution with mean = expected_value and variance = expected_value.

    Applications: The sum of N draws from N independent Poisson distributions is Poisson distributed. When a Poisson
    has rate lambda greater than 10, then normal distribution with mean = lambda and variance = lambda is a good
    approximation to the Poisson distribution. In this case we can use this method for N draws from Poisson
    distributions.
    Args:
        expected_value: Expected mean, must be >= 10
        prob: The probability of test value will exceed the tolerance.

    Returns: Tolerance

    Raises:
        ValueError: If expected_value is less than 10.

    """
    if expected_value < 10:
        raise ValueError("This method only valid with an expected_value >= 10.")

    from scipy.special import erfinv

    # sqrt(2) * standard deviation; the leading minus offsets the negative erfinv value obtained for prob < 1
    scaled_sigma = -math.sqrt(2) * math.sqrt(expected_value)
    return scaled_sigma * erfinv(prob - 1) / expected_value


def cal_tolerance_binomial(expected_value, binomial_p, prob=0.05):
    """
    Calculate the tolerance for the expected mean of a Binomial distribution. The probability that a test value
    exceeds the tolerance is prob (default value is 5%). The tolerance is expressed as a fraction of
    expected_value.

    The equation is based on the cumulative distribution function and error function of a normal distribution,
    using a normal distribution with mean = expected_value and variance = expected_value * (1 - binomial_p).

    Applications: When a Binomial distribution has a trial n greater than 20 and probability binomial_p which is not
    near 0 or 1, then normal distribution with mean = n*p and variance = n*p*(1-p) is a reasonable approximation for
    the Binomial distribution. In this case, we can also use this method for Binomial distribution.
    Args:
        expected_value: Expected mean from a Binomial distribution.
        binomial_p: P from a Binomial distribution.
        prob: The probability of test value will exceed the tolerance.

    Returns: Tolerance

    Raises:
        ValueError: If expected_value or expected_value * (1 - binomial_p) is less than 10.

    """
    if expected_value < 10 or expected_value * (1 - binomial_p) < 10:
        raise ValueError("This method only valid with an expected_value >= 10 and expected_value * (1 - binomial_p) >= 10.")

    from scipy.special import erfinv

    # sqrt(2) * standard deviation; the leading minus offsets the negative erfinv value obtained for prob < 1
    scaled_sigma = -math.sqrt(2) * math.sqrt(expected_value * (1 - binomial_p))
    return scaled_sigma * erfinv(prob - 1) / expected_value


def test_eGaussNonNeg(dist, p1, p2, round=False, report_file=None, plot=False, plot_name='plot_data_gaussian',
                      msg=None):
    """
    Two-sample kstest of dist against random samples drawn from a truncated normal distribution (the lower bound
    is hard-coded to zero and there is no upper bound). The reference samples are random, so the result can vary
    from call to call.

    The test passes when either the p value is at least 0.05 or the ks statistic does not exceed the critical
    value for the sample size.
    Args:
        dist: The distribution to be tested
        p1: Mean, loc
        p2: Width(standard deviation), scale, sig
        round: True to round the theoretical distribution to 7 significant digits.
        report_file: Report file to write the kstest result detail (only written on failure).
        plot: True to plot the test data vs. scipy data
        plot_name: Passed to the plot helper as the plot category
        msg: a list to append the test message (only appended on failure)
    Returns: True, False
    """
    sample_count = len(dist)
    # the zero lower bound expressed relative to loc and scale, as scipy's truncnorm expects
    lower_bound = -p1 / p2
    upper_bound = np.inf
    reference = stats.truncnorm.rvs(lower_bound, upper_bound, p1, p2, sample_count)
    if round:
        reference = [round_to_n_digit(value, 7) for value in reference]
    ks_result = stats.ks_2samp(reference, dist)

    if plot:
        my_plot.plot_data(dist, dist2=reference, label1="test data", label2="scipy_gaussian",
                          title="Test data vs. scipy gaussian", category=plot_name, sort=True)

    p_value = get_p_s_from_ksresult(ks_result)['p']
    statistic = get_p_s_from_ksresult(ks_result)['s']
    critical = calc_ks_critical_value(sample_count)

    if p_value >= 5e-2 or statistic <= critical:
        return True

    # failure: record the details wherever the caller asked for them
    failure_text = f"BAD: (mean = {p1}, sigma = {p2})failed with statistic={statistic}, p_value={p_value}, " \
                   f"expected s less than {critical} and p larger than 0.05.\n"
    if report_file is not None:
        report_file.write(failure_text)
    if isinstance(msg, list):
        msg.append(failure_text)
    return False


def test_dual_exponential(dist, m1, m2, p1, report_file=None, plot=False, plot_name="plot_data_exponential"):
    """
    Two-sample kstest for dual exponential distribution. The test data is compared against a reference sample of
    the same size built from two numpy exponential samples, the first one contributing round(size * p1) values.
    The reference samples are random, so the result can vary from call to call.

    The test passes when either the p value is at least 0.05 or the ks statistic does not exceed the critical
    value for the sample size.

    Args:
        dist: The distribution to be tested
        m1: Mean of the first exponential distribution, mean = decay length = 1 / rate, 1 / lambda, >0
        m2: Mean of the second exponential distribution, > 0
        p1: Proportion of the first exponential distribution, 0 < p1 < 1
        report_file: Report file to write the kstest result (GOOD or BAD) if provided
        plot: Plot the test and scipy/numpy data if True
        plot_name: Plot name

    Returns: True / False

    """
    total = len(dist)
    first_count = round(total * p1)
    first_part = np.random.exponential(m1, first_count)
    second_part = np.random.exponential(m2, total - first_count)
    reference = np.concatenate((first_part, second_part), axis=None)

    ks_result = stats.ks_2samp(reference, list(dist))

    if plot:
        my_plot.plot_data(dist, dist2=reference, label1="test data", label2="numpy_exponential",
                          title="Test data vs. numpy exponential", category=plot_name, sort=True)

    p_value = get_p_s_from_ksresult(ks_result)['p']
    statistic = get_p_s_from_ksresult(ks_result)['s']
    critical = calc_ks_critical_value(total)

    detail = f"{m1}_{m2}_{p1}: ks test return statistic={statistic}, pvalue={p_value}, expected s less than " \
             f"{critical} and p larger than 0.05.\n"
    passed = bool(p_value >= 5e-2 or statistic <= critical)
    if report_file is not None:
        label = "GOOD" if passed else "BAD"
        report_file.write(f"{label}: {detail}")
    return passed


def mean_f(nums: Optional[list] = None):
    """
    Arithmetic mean of an array: the sum of its values divided by the number of values.
    Args:
        nums: array like values; must be provided and non-empty (the None default is rejected by sum(), and an
              empty list leads to a division by zero)

    Returns: mean of this array

    """
    total = sum(nums)
    count = len(nums)
    return total / count


def variance_f(nums: Optional[list] = None):
    """
    Variance of an array, computed as the square of numpy's standard deviation (np.std with its default
    settings).
    Args:
        nums: array like values

    Returns: variance of this array

    """
    standard_deviation = np.std(nums)
    return standard_deviation ** 2
