import abc
import copy
import json
import logging
from collections import OrderedDict
from importlib.metadata import version
from itertools import product
from pathlib import Path

import entitysdk
from pydantic import PrivateAttr, ValidationError

from obi_one.core.block import Block
from obi_one.core.exception import OBIONEError
from obi_one.core.param import MultiValueScanParam, SingleValueScanParam
from obi_one.core.single import SingleConfigMixin, SingleCoordinateScanParams
from obi_one.core.task import Task
from obi_one.scientific.unions.unions_scan_configs import ScanConfigsUnion

# Module-level logger, named after this module; used by the scan tasks below.
L = logging.getLogger(__name__)


class ScanGenerationTask(Task, abc.ABC):
    """Task for creating multiple SingleConfigs where lists with multiple parameters are found.

    ``form`` is a scan config whose Block parameters may hold lists of values
    ("multi value parameters"). Subclasses implement ``coordinate_parameters()`` to
    decide how those lists are combined into coordinates. This base class then turns
    each coordinate into a single config (``create_single_configs``) and writes the
    scan and every coordinate to JSON under ``output_root`` (``execute``).
    """

    form: ScanConfigsUnion  # REFACTORING NOTE: Should be renamed to scan_config
    output_root: Path = Path()
    coordinate_directory_option: str = "NAME_EQUALS_VALUE"
    obi_one_version: str | None = None
    # Result of the most recent multiple_value_parameters() call.
    _multiple_value_parameters: list | None = None
    # Set by the subclass implementations of coordinate_parameters().
    _coordinate_parameters: list = PrivateAttr(default_factory=list)
    # Set by execute(); exposed through the single_configs property.
    _single_configs: list[SingleConfigMixin] = PrivateAttr(default_factory=list)

    @property
    def output_root_absolute(self) -> Path:
        """Returns the absolute path of the output_root (also logged at INFO level)."""
        absolute_root = self.output_root.resolve()
        L.info(absolute_root)
        return absolute_root

    @property
    def single_configs(self) -> list[SingleConfigMixin]:
        """Returns the list of single_configs generated by the scan.

        Raises:
            OBIONEError: if execute() has not produced any single configs yet.
        """
        if not self._single_configs:
            msg = "No single_configs have been generated. Please run the execute() method first."
            raise OBIONEError(msg)
        return self._single_configs

    def multiple_value_parameters(self, *, display: bool = False) -> list[MultiValueScanParam]:
        """Iterates through Blocks of self.form to find "multi value parameters".

        (i.e. parameters with list values of length greater than 1)

        Two kinds of form attributes are inspected:
        - a dict whose values are all Blocks (a "category" of blocks keyed by block key)
        - a Block stored directly on the form

        Args:
            display: if True, log every multi value parameter found.

        Returns:
            A list of MultiValueScanParam objects (also cached on
            ``self._multiple_value_parameters``).
        """
        self._multiple_value_parameters = []

        for attr_name, attr_value in self.form.__dict__.items():
            # Category dictionary: every value must be a Block instance.
            if isinstance(attr_value, dict) and all(
                isinstance(dict_val, Block) for dict_val in attr_value.values()
            ):
                for block_key, block in attr_value.items():
                    self._multiple_value_parameters.extend(
                        block.multiple_value_parameters(
                            category_name=attr_name, block_key=block_key
                        )
                    )
            # A single Block attached directly to the form.
            elif isinstance(attr_value, Block):
                self._multiple_value_parameters.extend(
                    attr_value.multiple_value_parameters(category_name=attr_name)
                )

        if display:
            L.info("\nMULTIPLE VALUE PARAMETERS")
            if not self._multiple_value_parameters:
                L.info("No multiple value parameters found.")
            for multi_value in self._multiple_value_parameters:
                L.info("%s: %s", multi_value.location_str, multi_value.values)

        return self._multiple_value_parameters

    @property
    def multiple_value_parameters_dictionary(self) -> dict:
        """Map each multi value parameter's ``location_str`` to its list of values."""
        return {
            multi_value.location_str: multi_value.values
            for multi_value in self.multiple_value_parameters()
        }

    def coordinate_parameters(self) -> list[SingleCoordinateScanParams]:
        """Must be implemented by a subclass of Scan.

        Raises:
            NotImplementedError: always, in this base class.
        """
        msg = "coordinate_parameters() must be implemented by a subclass of Scan."
        raise NotImplementedError(msg)

    def create_single_configs(self) -> list[SingleConfigMixin]:
        """Coordinate instance.

        - Returns a list of "coordinate instances" by:
            - Iterating through self.coordinate_parameters()
            - Creating a single "coordinate instance" for each single coordinate parameter

        - Each "coordinate instance" is created by:
            - Making a deep copy of the form
            - Editing the multi value parameters (lists) to have the values of the single
                coordinate parameters
                (i.e. timestamps.timestamps_1.interval = [1.0, 5.0] ->
                    timestamps.timestamps_1.interval = 1.0)
            - Casting the form to its single_config_class_name type
                (i.e. CircuitSimulationScanConfig -> CircuitSimulationSingleConfig)

        Raises:
            TypeError: if a category dictionary on the form holds a non-Block value at the
                location of a scan parameter.
            pydantic.ValidationError: propagated unchanged if casting a coordinate's form
                fails validation.
        """
        single_configs = []

        for idx, single_coordinate_scan_params in enumerate(self.coordinate_parameters()):
            # Deep copy so the scan form itself keeps its list-valued parameters.
            single_coord_config = copy.deepcopy(self.form)

            # Replace each multi value (list) parameter with this coordinate's single value.
            # Values are written through __dict__, which skips pydantic's __setattr__
            # (presumably so list-typed fields can hold a scalar until the cast below).
            for scan_param in single_coordinate_scan_params.scan_params:
                location = scan_param.location_list
                level_0_val = single_coord_config.__dict__[location[0]]

                if isinstance(level_0_val, Block):
                    # location: [block_name, parameter_name]
                    level_0_val.__dict__[location[1]] = scan_param.value
                elif isinstance(level_0_val, dict):
                    # location: [category_name, block_key, parameter_name]
                    level_1_val = level_0_val[location[1]]
                    if not isinstance(level_1_val, Block):
                        msg = (
                            f"Non Block parameter {level_1_val} found in Form dictionary: "
                            f"{level_0_val}"
                        )
                        raise TypeError(msg)
                    level_1_val.__dict__[location[2]] = scan_param.value

            # Cast the form to its single config type. A pydantic ValidationError is left to
            # propagate as-is: pydantic v2's ValidationError cannot be re-constructed from
            # another exception (doing so raises TypeError and hides the real error).
            single_coord_config = single_coord_config.cast_to_single_coord()

            # Record which coordinate of the scan this single config represents.
            single_coord_config.idx = idx
            single_coord_config.single_coordinate_scan_params = single_coordinate_scan_params

            single_configs.append(single_coord_config)

        return single_configs

    def serialize(self, output_path: Path | None) -> dict:
        """Serialize a Scan object.

        - type name added to each subobject of type
            inheriting from OBIBaseModel for future deserialization
        - ``obi_one_version`` is set to the installed ``obi-one`` package version
        - keys are ordered ``obi_one_version``, ``type``, ``output_root``, then the rest;
            ``type`` is also moved to the front of the ``form`` sub-dict

        Args:
            output_path: JSON file to write the result to; when falsy (e.g. None) nothing
                is written.

        Returns:
            The serialized, key-ordered dictionary.
        """
        # Important to use model_dump_json() instead of model_dump()
        # so OBIBaseModel's custom encoder is used to serialize
        # PosixPaths as strings
        model_dump = self.model_dump_json()

        # Now load it back into an ordered dict to do some additional modifications
        model_dump = OrderedDict(json.loads(model_dump))

        # Add the obi_one version to the model_dump
        model_dump["obi_one_version"] = version("obi-one")

        # Order keys in dict (each move_to_end(last=False) puts the key first, so the
        # last one moved ends up at the very front)
        model_dump.move_to_end("output_root", last=False)
        model_dump.move_to_end("type", last=False)
        model_dump.move_to_end("obi_one_version", last=False)

        # Order the keys in subdict "form"
        model_dump["form"] = OrderedDict(model_dump["form"])
        model_dump["form"].move_to_end("type", last=False)

        # Write dict to json file (the parent directory must already exist)
        if output_path:
            with output_path.open("w", encoding="utf-8") as json_file:
                json.dump(model_dump, json_file, indent=4)

        return model_dump

    def display_coordinate_parameters(self) -> None:
        """Log the coordinates computed by the last coordinate_parameters() call."""
        L.info("\nCOORDINATE PARAMETERS")
        for single_coordinate_parameters in self._coordinate_parameters:
            single_coordinate_parameters.display_parameters()

    def execute(
        self,
        db_client: entitysdk.client.Client | None = None,
    ) -> None:
        """Generate the scan and write it, plus every single config, under ``output_root``.

        Steps:
            1. Create ``output_root`` and write ``obi_one_scan.json`` into it.
            2. If ``db_client`` is given and the form supports it, create a campaign entity.
            3. Create the single configs. For each one, initialize its coordinate output root
               and write ``obi_one_coordinate.json`` there. If ``db_client`` is given and the
               config supports it, also create a single entity.
            4. If ``db_client`` is given and the form supports it, create the campaign
               generation entity from the single entities.

        Args:
            db_client: optional entitysdk client; when None, no entities are created.
        """
        self.output_root.mkdir(parents=True, exist_ok=True)

        # Serialize the scan
        self.serialize(self.output_root / "obi_one_scan.json")

        # Create the campaign entity
        campaign = None
        if db_client and hasattr(self.form, "create_campaign_entity_with_config"):
            campaign = self.form.create_campaign_entity_with_config(
                output_root=self.output_root,
                multiple_value_parameters_dictionary=self.multiple_value_parameters_dictionary,
                db_client=db_client,
            )

        # Create the single_configs
        self._single_configs = self.create_single_configs()

        for single_coord_config in self._single_configs:
            single_coord_config.initialize_coordinate_output_root(
                self.output_root, self.coordinate_directory_option
            )

            # Serialize the coordinate instance
            single_coord_config.serialize(
                single_coord_config.coordinate_output_root / "obi_one_coordinate.json"
            )

            # Create the single coordinate entity
            if db_client and hasattr(single_coord_config, "create_single_entity_with_config"):
                single_coord_config.create_single_entity_with_config(
                    campaign=campaign, db_client=db_client
                )

        # Create the campaign generation entity
        if db_client and hasattr(self.form, "create_campaign_generation_entity"):
            single_entities = [sc.single_entity for sc in self._single_configs]
            self.form.create_campaign_generation_entity(single_entities, db_client=db_client)


class GridScanGenerationTask(ScanGenerationTask):
    """Scan generation task that combines multi value parameters as a full grid.

    Every combination (Cartesian product) of the values of the multi value
    parameters becomes one coordinate, and therefore one single config.
    """

    def coordinate_parameters(self, *, display: bool = False) -> list[SingleCoordinateScanParams]:
        """Build one SingleCoordinateScanParams per element of the Cartesian product.

        Each multi value parameter is expanded into one SingleValueScanParam per value.
        ``itertools.product`` over these per-parameter lists yields the coordinates.
        When the form has no multi value parameters, a single coordinate is produced
        using ``self.form.single_coord_scan_default_subpath``.

        Args:
            display: if True, log the resulting coordinates.

        Returns:
            The coordinate list, which is also stored on ``self._coordinate_parameters``.
        """
        multi_values = self.multiple_value_parameters()

        if not multi_values:
            self._coordinate_parameters = [
                SingleCoordinateScanParams(
                    nested_coordinate_subpath_str=self.form.single_coord_scan_default_subpath
                )
            ]
        else:
            # One axis per multi value parameter; each axis lists its single values.
            axes = [
                [
                    SingleValueScanParam(location_list=mv.location_list, value=val)
                    for val in mv.values
                ]
                for mv in multi_values
            ]
            self._coordinate_parameters = []
            self._coordinate_parameters.extend(
                SingleCoordinateScanParams(scan_params=combination)
                for combination in product(*axes)
            )

        if display:
            self.display_coordinate_parameters()

        return self._coordinate_parameters


class CoupledScanGenerationTask(ScanGenerationTask):
    """Scan generation task that varies all multi value parameters together.

    All multi value parameters must have the same number of values. Coordinate ``i``
    takes the ``i``-th value of every multi value parameter.
    """

    def coordinate_parameters(self, *, display: bool = False) -> list[SingleCoordinateScanParams]:
        """Build one SingleCoordinateScanParams per index of the (shared) value lists.

        When the form has no multi value parameters, a single coordinate is produced
        using ``self.form.single_coord_scan_default_subpath``.

        Args:
            display: if True, log the resulting coordinates.

        Returns:
            The coordinate list, which is also stored on ``self._coordinate_parameters``.

        Raises:
            ValueError: if the multi value parameters do not all have the same length.
        """
        multi_value_parameters = self.multiple_value_parameters()

        if multi_value_parameters:
            # Every parameter must match the first one's length; report the first mismatch.
            n_coords = len(multi_value_parameters[0].values)
            for multi_value in multi_value_parameters[1:]:
                current_len = len(multi_value.values)
                if current_len != n_coords:
                    msg = (
                        f"Multi value parameters have different lengths: {n_coords} and "
                        f"{current_len}"
                    )
                    raise ValueError(msg)

            self._coordinate_parameters = []
            for coord_i in range(n_coords):
                scan_params = [
                    SingleValueScanParam(
                        location_list=multi_value.location_list,
                        value=multi_value.values[coord_i],
                    )
                    for multi_value in multi_value_parameters
                ]
                self._coordinate_parameters.append(
                    SingleCoordinateScanParams(scan_params=scan_params)
                )

        else:
            self._coordinate_parameters = [
                SingleCoordinateScanParams(
                    nested_coordinate_subpath_str=self.form.single_coord_scan_default_subpath
                )
            ]

        if display:
            self.display_coordinate_parameters()

        return self._coordinate_parameters
