import unittest
import numpy as np

import idm_test.stats_test as stats_test


class StatsTest(unittest.TestCase):
    def setUp(self) -> None:
        print(f"\n{self._testMethodName} started...")
        self.sample_size = 1000
        self.total_number_of_tests = 1000
        self.failed_test = 0
        self.msg = list()
        self.happy_path = True
        self.success = True
        pass

    def tearDown(self) -> None:
        pass

    def verify_result(self):
        if self.happy_path:
            if self.failed_test / self.total_number_of_tests > 0.075:
                self.success = False
        else:
            if self.failed_test / self.total_number_of_tests < 0.75:  # assume expected sensitivity is 75%
                self.success = False
        message = f"{self._testMethodName} failed {self.failed_test} times in " \
                  f"{self.total_number_of_tests} total tests.\n Happy path test: {self.happy_path}.\n"
        if self.success:
            self.msg.append(f"GOOD: {message}")
        else:
            self.msg.append(f"BAD: {message}")

    def test_uniform(self):
        p1 = 1
        p2 = 3
        for _ in range(self.total_number_of_tests):
            np_uniform = np.random.uniform(p1, p2, self.sample_size)
            result = stats_test.test_uniform(np_uniform, p1=p1, p2=p2, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_uniform_swap_values(self):
        p1 = 1
        p2 = 0
        for _ in range(self.total_number_of_tests):
            np_uniform = np.random.uniform(p2, p1, self.sample_size)
            result = stats_test.test_uniform(np_uniform, p1=p1, p2=p2, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_uniform_larger_value(self):
        p1 = 100
        p2 = 300
        for _ in range(self.total_number_of_tests):
            np_uniform = np.random.uniform(p1, p2, self.sample_size)
            result = stats_test.test_uniform(np_uniform, p1=p1, p2=p2, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_uniform_round_data(self):
        p1 = 100
        p2 = 300
        for _ in range(self.total_number_of_tests):
            np_uniform = np.random.uniform(p1, p2, self.sample_size)
            np_uniform = [stats_test.round_to_n_digit(val, 4) for val in np_uniform]
            result = stats_test.test_uniform(np_uniform, p1=p1, p2=p2, report_file=None, round=True,
                                             significant_digits=4, plot=False, msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_uniform_equally_distributed(self):
        p1 = 0.1
        p2 = 1.1
        self.total_number_of_tests = 100
        for _ in range(self.total_number_of_tests):
            np_uniform = np.random.uniform(p1, p2, self.sample_size)
            result = stats_test.test_uniform(np_uniform, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_uniform_wrong_input(self):
        self.happy_path = False
        p1 = 50
        p2 = 100
        self.sample_size = 10000
        self.total_number_of_tests = 100
        for _ in range(self.total_number_of_tests):
            np_uniform = np.random.uniform(p1, p2, self.sample_size)
            result = stats_test.test_uniform(np_uniform, p1=p1-1, p2=p2-1, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_uniform_with_flat_normal(self):
        self.happy_path = False
        p1 = 50
        self.total_number_of_tests = 100
        for _ in range(self.total_number_of_tests):
            np_normal = np.random.normal(0, p1, self.sample_size)
            result = stats_test.test_uniform(np_normal, p1=-200, p2=200, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_lognorm(self):
        mu = 2
        sigma = 0.4
        for _ in range(self.total_number_of_tests):
            np_lognorm = np.random.lognormal(mu, sigma, self.sample_size)
            result = stats_test.test_lognorm(np_lognorm, mu, sigma, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_lognorm_round_data(self):
        mu = 2
        sigma = 0.4
        for _ in range(self.total_number_of_tests):
            np_lognorm = np.random.lognormal(mu, sigma, self.sample_size)
            np_lognorm = [stats_test.round_to_n_digit(val, 7) for val in np_lognorm]
            result = stats_test.test_lognorm(np_lognorm, mu, sigma, report_file=None, round=True, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_lognorm_with_weibull_data(self):
        self.happy_path = False
        mu = 0
        sigma = 0.5

        # lambda_value = 1, this by default
        k = 1.5
        self.total_number_of_tests = 100
        for _ in range(self.total_number_of_tests):
            np_weibull = np.random.weibull(k, self.sample_size)
            result = stats_test.test_lognorm(np_weibull, mu, sigma, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])

    def test_lognorm_wrong_inputs(self):
        self.happy_path = False
        mu = 2
        sigma = 0.4
        self.total_number_of_tests = 100
        for _ in range(self.total_number_of_tests):
            np_lognorm = np.random.lognormal(mu-0.1, sigma-0.01, self.sample_size)
            result = stats_test.test_lognorm(np_lognorm, mu, sigma, report_file=None, round=False, plot=False,
                                             msg=self.msg)
            if not result:
                self.failed_test += 1
        self.verify_result()
        self.assertTrue(self.success, msg=self.msg[-1])
        print(self.msg[-1])


# Run every test in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
