import unittest
import os

from emodpy.emod_task import EMODTask
from emod_api.config.from_overrides import flattenConfig

from idmtools_platform_comps.utils.general import update_item
from idmtools.core.platform_factory import Platform
from idmtools.core import ItemType
from idmtools.builders import SimulationBuilder
from idmtools.entities.experiment import Experiment
from idmtools.assets import Asset

from idm_test.dtk_test.sft import sft_output_filename, format_success_msg
from idm_test.dtk_test.integration.download_eradication import download_model_files
from idm_test.dtk_test.integration.create_singularity import build_sft_singularity


class IntegrationTest(unittest.TestCase):
    """
    A base class for running SFTs with multiple IDM products (emodpy/idmtools/COMPS/idm-test).

    Example:
        Example of using this class to create a test:

        from idm_test.dtk_test.integration.integration_test import IntegrationTest
        from idm_test.dtk_test.integration import manifest, bootstrap

        def update_sim_random_bic(simulation, alpha, beta):
            # define a sweep_fn_cb
        ...


        def set_param_fn(config):
            # define a set_param_fn
        ...


        def build_camp():
            # define a camp_fn_cb
        ...


        def build_demog():
            # define a demog_fn_cb
        ...


        class TestNonuniformShedding(IntegrationTest):
            def setUp(self):
                self.test_name = 'Test_Enable_Nonuniform_Shedding'
                bootstrap.setup()

            def test_nonuniform_shedding_sft(self):
                self.run_test(camp_fn_cb=build_camp, demog_fn_cb=build_demog, sweep_fn_cb=update_sim_random_bic,
                              sweep_values=([1, 2, 3], [1, 1.5, 4]), manifest=manifest, set_param_fn=set_param_fn,
                              force_build_exe=False)

    """
    # NOTE(review): not used by this class (the experiment is stored in ``experiment``); kept in case
    # subclasses reference it.
    exp = None
    # Experiment created and run by _run(); None until run_test() has been called.
    experiment = None
    # Platform created by _prepare() from manifest.platform.
    platform = None
    # Used as the experiment name.
    test_name = 'SFT integration test'

    def run_test(self, camp_fn_cb=None, demog_fn_cb=None, sweep_fn_cb=None, sweep_values=None,
                 manifest=None, set_param_fn=None, force_build_exe=False, suite_id=None):
        """
        Base test routine: get the model files, run the SFTs in COMPS, assert on the SFT results and
        update the simulation/experiment tags with the outcome.

        Args:
            camp_fn_cb: The campaign build function, which builds a campaign object with emod_api.campaign and
                emod_api.interventions or emodpy-**disease**.interventions.
            demog_fn_cb: The demographics build function, which builds a demographics object with
                emod_api.demographics or emodpy-**disease**.demographics.
            sweep_fn_cb: The sweep function, which must take a **simulation** parameter plus one or more config
                parameter(s). If undefined, self._update_sim_random_seed() is used to sweep on Run_Number.
            sweep_values: The values to call sweep_fn_cb with. A tuple of lists or a nested list gives a
                multi-parameter sweep (one list per parameter after ``simulation``); a flat list gives a
                single-parameter sweep.
            manifest: The module that defines all input paths and common variables for the test environment. Each
                test should import idm_test.dtk_test.integration.manifest as manifest unless it needs a custom one.
            set_param_fn: The function that updates config parameters from their default values. If undefined,
                parameters are loaded from "param_overrides.json" and override the default config file; using
                "param_overrides.json" is not recommended.
            force_build_exe: If True, build a test-sugar-enabled Eradication in COMPS with the user-defined
                singularity container, even if the same container is already in COMPS. Use
                force_build_exe=True when the target branch has code changes you want to test.
                If False and the container is in COMPS or the Eradication asset collection id is found on disk,
                skip building a new singularity image.
            suite_id: Optional suite id. If passed in, the experiment is added to that suite.

        Returns: None

        """
        self._prepare(manifest, force_build_exe)
        self._run(camp_fn_cb, demog_fn_cb, sweep_fn_cb, sweep_values, manifest, set_param_fn, suite_id)
        self._check_result()

    def _prepare(self, manifest, force_build_exe):
        """
        Create the platform and make sure the Eradication model files are available on disk.

        Args:
            manifest: Test manifest; ``platform`` and ``eradication_path`` are read here.
            force_build_exe: If True, always call download_model_files() so a fresh Eradication is built, even
                when ``manifest.eradication_path`` already exists from a previous run.

        Returns: None

        Raises:
            AssertionError: If ``manifest.eradication_path`` does not exist afterwards.
        """
        self.platform = Platform(manifest.platform)
        # If force_build_exe is True, download_model_files() builds a test-sugar-enabled Eradication in COMPS
        # with the user-defined singularity container; otherwise it tries to download the model files (from
        # bamboo). force_build_exe must be honored even when files are already on disk, or a stale executable
        # would be tested.
        if force_build_exe or not os.path.isfile(manifest.eradication_path):
            download_model_files(self.platform, manifest, force_build_exe)
        self.assertTrue(os.path.isfile(manifest.eradication_path),
                        f"Eradication model file not found: {manifest.eradication_path}")

    def _run(self, camp_fn_cb, demog_fn_cb, sweep_fn_cb, sweep_values, manifest, set_param_fn, suite_id):
        """
        Generate the model inputs and run the simulations (with the parameter sweep and dtk_post_process.py) in COMPS.

        Args:
            camp_fn_cb: Campaign builder passed to EMODTask.from_default2.
            demog_fn_cb: Demographics builder passed to EMODTask.from_default2.
            sweep_fn_cb: Sweep callback, or None to sweep Run_Number (see _add_sweeps).
            sweep_values: Values for sweep_fn_cb (see _add_sweeps).
            manifest: Test manifest; reads config_path, eradication_path, schema_file, sft_id, exp_id and,
                when no sweep callback is given, n_sims.
            set_param_fn: Config-parameter callback; defaults to self._set_param_fn (param_overrides.json).
            suite_id: Optional suite id the experiment is added to.

        Returns: None

        Raises:
            ValueError: If sweep_values is missing or has an unsupported type.
            Exception: If the experiment does not succeed.
        """
        if not set_param_fn:
            # update config with param_overrides.json
            set_param_fn = self._set_param_fn

        # create an emod task
        task = EMODTask.from_default2(
            config_path=manifest.config_path,
            eradication_path=manifest.eradication_path,
            campaign_builder=camp_fn_cb,
            demog_builder=demog_fn_cb,
            schema_path=manifest.schema_file,
            param_custom_cb=set_param_fn,
            ep4_custom_cb=self._add_ep4
        )

        # run with sft singularity image
        if not os.path.isfile(manifest.sft_id):
            # if the singularity image is not built yet, build it with the default sft singularity definition
            build_sft_singularity(self.platform, manifest)
        task.set_sif(manifest.sft_id)

        builder = SimulationBuilder()
        self._add_sweeps(builder, sweep_fn_cb, sweep_values, manifest)

        self.experiment = Experiment.from_builder(builder, task, name=self.test_name)
        if suite_id:
            suite = self.platform.get_item(item_id=suite_id, item_type=ItemType.SUITE)
            suite.add_experiment(self.experiment)
        self.experiment.run(wait_until_done=True, platform=self.platform)

        if not self.experiment.succeeded:
            raise Exception("Experiment Failed")
        # Record the id of each successful experiment, one per line (the file is appended to, not overwritten).
        with open(manifest.exp_id, 'a', encoding='utf-8') as file:
            file.write(f"{self.experiment.id}\n")

    def _add_sweeps(self, builder, sweep_fn_cb, sweep_values, manifest):
        """
        Register the sweep definition(s) on ``builder``.

        Args:
            builder: SimulationBuilder to add the sweep definition(s) to.
            sweep_fn_cb: Sweep callback, or None to sweep Run_Number over range(manifest.n_sims).
            sweep_values: Tuple of lists or nested list -> multi-parameter sweep; flat list -> single-parameter
                sweep. Ignored when sweep_fn_cb is None.
            manifest: Test manifest; ``n_sims`` is read only when sweep_fn_cb is None.

        Raises:
            ValueError: If sweep_fn_cb is given and sweep_values is empty/None or not a tuple/list.
        """
        if not sweep_fn_cb:
            builder.add_sweep_definition(self._update_sim_random_seed, range(manifest.n_sims))
            return
        if not sweep_values:
            raise ValueError("sweep_values is not defined.")
        is_list = isinstance(sweep_values, list)
        if isinstance(sweep_values, tuple) or (is_list and any(isinstance(i, list) for i in sweep_values)):
            # each element is the list of values for one sweep_fn_cb parameter
            builder.add_multiple_parameter_sweep_definition(sweep_fn_cb, *sweep_values)
        elif is_list:
            builder.add_sweep_definition(sweep_fn_cb, sweep_values)
        else:
            raise ValueError(f"sweep_values should be a tuple or list/nested list, get a {type(sweep_values)}.")

    def _check_result(self):
        """
        Check the SFT report of every simulation and update the sim/exp tags with the result.

        For each simulation whose report lacks the success message, the report is written to ``<sim id>.txt``
        in the current directory.

        Raises:
            AssertionError: If any simulation failed the SFT.
        """
        success_msg = format_success_msg(True)
        failed_sim_ids = []
        for sim in self.experiment.simulations:
            # get the sft report text from COMPS
            report = self.experiment.platform.get_files(sim, [sft_output_filename])
            sft_result = report[sft_output_filename].decode("utf-8")

            sim_result = success_msg in sft_result
            if not sim_result:
                print(f'Simulation: {sim.id} failed the SFT, please see details in {sim.id}.txt.')
                # the report was decoded as UTF-8, so write it back as UTF-8 regardless of the locale
                with open(f'{sim.id}.txt', 'w', encoding='utf-8') as result_file:
                    result_file.write(sft_result)
                failed_sim_ids.append(sim.id)
            self._update_sim_tags(sim, sim_result)

        exp_result = not failed_sim_ids
        # update the experiment tag based on all simulation results
        self._update_exp_tags(self.experiment, exp_result)
        self.assertTrue(exp_result, f"Simulation(s) failed the SFT: {failed_sim_ids}")

    def _set_param_fn(self, config):
        """
        Default config callback: override config parameters with the ones in "param_overrides.json".

        Raises:
            ValueError: If "param_overrides.json" cannot be flattened or has no "parameters" section.
        """
        try:
            flattened_config = flattenConfig("param_overrides.json")
            overrides = flattened_config["parameters"]
        except Exception as e:
            raise ValueError(f"Can't flatten 'param_overrides.json', got exception: {e}.") from e
        for key, val in overrides.items():
            config.parameters[key] = val
        return config

    def _add_ep4(self, task):
        """
        Add every ``dtk_*.py`` file in the local "ep4_dir" directory to the task's common assets under "python".
        """
        ep4_dir = "ep4_dir"
        for entry_name in os.listdir(ep4_dir):
            full_path = os.path.join(ep4_dir, entry_name)
            if os.path.isfile(full_path) and entry_name.endswith(".py") and entry_name.startswith("dtk_"):
                py_file_asset = Asset(full_path, relative_path="python")
                task.common_assets.add_asset(py_file_asset)
        return task

    def _update_sim_random_seed(self, simulation, value):
        """Default sweep callback: set Run_Number to ``value`` and return it as the simulation tag."""
        simulation.task.config.parameters.Run_Number = value
        return {"Run_Number": value}

    def _update_tags(self, item, item_type, result):
        """Set the 'SFT result' tag on ``item`` (mutating item.tags) and push the tags to the platform."""
        tags = item.tags
        tags['SFT result'] = result
        update_item(self.platform, item.id, item_type, tags)

    def _update_sim_tags(self, sim, result):
        """Tag a simulation with its SFT result."""
        self._update_tags(sim, ItemType.SIMULATION, result)

    def _update_exp_tags(self, exp, result):
        """Tag an experiment with its overall SFT result."""
        self._update_tags(exp, ItemType.EXPERIMENT, result)

