"""
peakrdl-python is a tool to generate Python Register Access Layer (RAL) from SystemRDL
Copyright (C) 2021 - 2025

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.

This module is intended to be distributed as part of automatically generated code by the
peakrdl-python tool. It provides a set of base classes used by the autogenerated code
"""
from __future__ import annotations
import logging
from typing import Optional, Union, TypeVar, Any
from collections.abc import Iterator, Sequence
from abc import ABC, abstractmethod
from itertools import product
from enum import IntEnum, Enum, auto
import math
from types import MemberDescriptorType

from .callbacks import CallbackSet, CallbackSetLegacy

# Value of a struct-typed user defined property (UDP): a dictionary keyed by str whose values
# may be any UDP value, including nested structs. 'UDPType' is written as a string forward
# reference because it is only defined on the following line
UDPStruct = dict[str, 'UDPType']
# Any value a user defined property may take
UDPType = Union[str, int, bool, IntEnum, UDPStruct]

class Base(ABC):
    """
    Common root of every object in the generated register access layer (nodes, node arrays
    and fields). It holds the instance name, a reference to the parent and the logger used
    by the instance.
    """
    __slots__: list[str] = ['__logger', '__inst_name', '__parent']

    def __init__(self, *,
                 logger_handle: str, inst_name: str, parent: Optional[Union['Node', 'NodeArray']]):
        """
        Args:
            logger_handle: name passed to logging.getLogger to obtain this instance's logger
            inst_name: systemRDL name of this instance within its parent
            parent: owning Node or NodeArray, or None when there is no parent

        Raises:
            TypeError: if any argument has the wrong type
        """
        if not isinstance(logger_handle, str):
            raise TypeError(f'logger_handle should be str but got {type(logger_handle)}')
        self.__logger = logging.getLogger(logger_handle)
        # the creation message is logged before inst_name and parent are validated
        self._logger.debug(f'creating instance of {self.__class__}')

        if not isinstance(inst_name, str):
            raise TypeError(f'inst_name should be str but got {type(inst_name)}')
        self.__inst_name = inst_name

        parent_is_valid = parent is None or isinstance(parent, (Node, NodeArray))
        if not parent_is_valid:
            raise TypeError(f'parent should be Node or Node Array but got {type(parent)}')
        self.__parent = parent

    @property
    def _logger(self) -> logging.Logger:
        # logger obtained from the logger_handle given at construction
        return self.__logger

    @property
    def inst_name(self) -> str:
        """
        Name of this instance within its parent, as given in the systemRDL
        """
        return self.__inst_name

    @property
    def parent(self) -> Optional[Union['Node', 'NodeArray']]:
        """
        The node or node array that owns this node or field, None when there is no parent
        """
        return self.__parent

    @property
    def full_inst_name(self) -> str:
        """
        Dot separated hierarchical systemRDL name of this instance
        """
        parent = self.parent
        if parent is None:
            return self.inst_name

        if isinstance(parent, NodeArray):
            # an array element takes its prefix from the node that owns the array, its own
            # inst_name already carries the array name and index
            owner = parent.parent
            if owner is None:
                raise RuntimeError('A node array must have a parent it can not be the '
                                   'top level')
            return owner.full_inst_name + "." + self.inst_name

        return parent.full_inst_name + "." + self.inst_name

    @property
    def udp(self) -> UDPStruct:
        """
        User defined properties of the item, this base implementation has none so returns an
        empty dictionary
        """
        return {}

    @property
    def rdl_name(self) -> Optional[str]:
        """
        systemRDL name property of the item, this base implementation returns None
        """
        return None

    @property
    def rdl_desc(self) -> Optional[str]:
        """
        systemRDL desc property of the item, this base implementation returns None
        """
        return None

class IterationClassification(Enum):
    """
    Enumeration classifying node types (register, memory or section) so that nodes can be
    filtered by kind when iterating
    """
    # values are assigned by auto(); code should compare members, not rely on the values
    REGISTER = auto()
    MEMORY = auto()
    SECTION = auto()

# FieldType = TypeVar('NodeElementType', bound=Node|NodeArray)
# However, python 3.9 does not support the combination so the binding was removed
# pylint: disable-next=invalid-name
class Node(Base, ABC):
    """
    Base class for all types with an address, i.e. everything except fields

    Note:
        It is not expected that this class will be instantiated under normal
        circumstances however, it is useful for type checking
    """

    # in order to avoid circular import loops, the node class needs to keep track of its type
    # so that this can be used in the iteration filters
    _iteration_classification: IterationClassification

    __slots__ = ['__address', '_iteration_classification']

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        # python does not have the concept of an abstract class attribute, this is the
        # recommended design pattern, to check that the subclass has set a class attribute.
        # '_iteration_classification' is also listed in __slots__, so the class always exposes
        # a slot member descriptor under that name and hasattr() on its own is always true.
        # Finding only that descriptor means no class in the hierarchy assigned a value.
        if not hasattr(cls, "_iteration_classification") or \
                isinstance(getattr(cls, "_iteration_classification"), MemberDescriptorType):
            raise AttributeError(f"{cls.__name__} must define class attribute "
                                 "'_iteration_classification'")
        classification = getattr(cls, "_iteration_classification")
        if not isinstance(classification, IterationClassification):
            raise TypeError('_iteration_classification should be a IterationClassification '
                            f'but got {type(classification)}')

    def __init__(self, *,
                 address: int,
                 logger_handle: str,
                 inst_name: str,
                 parent: Optional[Union['Node', 'NodeArray']]):
        """
        Args:
            address: address of the node
            logger_handle: name of the logger for the node
            inst_name: systemRDL name of the node within its parent
            parent: node or node array containing this node, None if it has no parent

        Raises:
            TypeError: if address is not an int (the other arguments are checked by Base)
        """
        super().__init__(logger_handle=logger_handle, inst_name=inst_name, parent=parent)

        if not isinstance(address, int):
            raise TypeError(f'address should be int but got {type(address)}')

        self.__address = address

    @property
    def address(self) -> int:
        """
        address of the node
        """
        return self.__address

    @property
    @abstractmethod
    def _callbacks(self) -> Union[CallbackSet, CallbackSetLegacy]:
        ...

    @property
    @abstractmethod
    def systemrdl_python_child_name_map(self) -> dict[str, str]:
        """
        In some cases systemRDL names need to be converted to make them python safe, this
        dictionary is used to map the original systemRDL names to the names of the python
        attributes of this class

        Returns: dictionary whose key is the systemRDL name and value is the property name
        """

    def get_child_by_system_rdl_name(self, name: Any) -> Any:
        """
        returns a child node by its systemRDL name

        Args:
            name: name of the node in the systemRDL

        Returns: the attribute of this object that systemrdl_python_child_name_map maps
                 the name to

        Raises:
            TypeError: if name is not a str
            KeyError: (from the name map lookup) if name is not a known child
        """
        if not isinstance(name, str):
            raise TypeError(f'name must be a string got {type(name)}')
        return getattr(self, self.systemrdl_python_child_name_map[name])

    @property
    @abstractmethod
    def size(self) -> int:
        """
        Total Number of bytes of address the node occupies
        """


# Type variable for the kind of node held in a NodeArray; for static type checking the bound
# restricts it to Node and its subclasses
# pylint: disable-next=invalid-name
NodeArrayElementType = TypeVar('NodeArrayElementType', bound=Node)


class NodeArray(Base, Sequence[NodeArrayElementType]):
    """
    Base class for all array types

    The array stores its elements together with their indices. When first created it builds
    one element for every combination of indices allowed by ``dimensions``. Slicing returns a
    new instance of the same class that holds only the selected elements; those elements keep
    the indices (and therefore the addresses) they had in the array they were taken from.
    """

    # pylint: disable=too-few-public-methods
    __slots__: list[str] = ['__elements', '__element_indices', '__element_lookup',
                            '__address', '__stride', '__dimensions',
                            '_iteration_classification']

    # in order to avoid circular import loops, the node class needs to keep track of its type
    # so that this can be used in the iteration filters
    _iteration_classification: IterationClassification

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        # python does not have the concept of an abstract class attribute, this is the
        # recommended design pattern, to check that the subclass has set a class attribute.
        # '_iteration_classification' is also listed in __slots__, so the class always exposes
        # a slot member descriptor under that name and hasattr() on its own is always true.
        # Finding only that descriptor means no class in the hierarchy assigned a value.
        if not hasattr(cls, "_iteration_classification") or \
                isinstance(getattr(cls, "_iteration_classification"), MemberDescriptorType):
            raise AttributeError(f"{cls.__name__} must define class attribute "
                                 "'_iteration_classification'")
        classification = getattr(cls, "_iteration_classification")
        if not isinstance(classification, IterationClassification):
            raise TypeError('_iteration_classification should be a IterationClassification '
                            f'but got {type(classification)}')

    # pylint: disable-next=too-many-arguments
    def __init__(self, *, logger_handle: str,
                 inst_name: str,
                 parent: Node,
                 address: int,
                 stride: int,
                 dimensions: tuple[int, ...],
                 elements: Optional[tuple[tuple[tuple[int, ...], ...],
                                          tuple[NodeArrayElementType, ...]]] = None):
        """
        Args:
            logger_handle: name of the logger for the array, element loggers are derived
                           from it
            inst_name: systemRDL name of the array within its parent
            parent: node that contains the array
            address: base address of the array (address of the element whose indices are
                     all 0)
            stride: address step between consecutive elements
            dimensions: size of each dimension of the array
            elements: normally None, in which case every element is built. When slicing,
                      the pair (indices, elements) holding the subset of elements to present

        Raises:
            TypeError: if an argument has the wrong type
            ValueError: if dimensions is empty or elements is inconsistent with the array
        """
        super().__init__(logger_handle=logger_handle, inst_name=inst_name, parent=parent)

        if not isinstance(address, int):
            raise TypeError(f'address should be a int but got {type(address)}')
        self.__address = address
        if not isinstance(stride, int):
            raise TypeError(f'stride should be a int but got {type(stride)}')
        self.__stride = stride

        if not isinstance(dimensions, tuple):
            raise TypeError(f'dimensions should be a tuple but got {type(dimensions)}')
        for dimension in dimensions:
            if not isinstance(dimension, int):
                raise TypeError(f'dimension should be a int but got {type(dimension)}')
        if len(dimensions) == 0:
            raise ValueError('dimensions should contain at least one entry')
        self.__dimensions = dimensions

        # There are two use cases for this class:
        # 1. Initial creation - elements is None in which case the data is populated
        # 2. Creating a recursive version of itself, this happens when it is sliced by the parent
        #    in which case a subset of the elements is presented and a new instance
        if elements is not None:
            self.__check_init_element(elements)
            self.__element_indices = elements[0]
            self.__elements = elements[1]
        else:
            # build the full set of indices, in row-major order
            self.__element_indices = tuple(product(*(range(dim) for dim in self.dimensions)))
            self.__elements = tuple(self.__build_element(indices=index)
                                    for index in self.__element_indices)

        # index -> element map so lookups do not need a linear scan of the index tuple
        self.__element_lookup = dict(zip(self.__element_indices, self.__elements))

    def __build_element(self, indices: tuple[int, ...]) -> NodeArrayElementType:
        """
        Create the element at ``indices`` with its own logger handle, instance name and
        address, and with this array as its parent
        """
        return self._element_datatype(
                    logger_handle=self.__build_element_logger_handle(indices=indices),
                    address=self.__address_calculator(indices),
                    inst_name=self.__build_element_inst_name(indices=indices),
                    parent=self)

    def __build_element_logger_handle(self, indices: tuple[int, ...]) -> str:
        """
        Build the logger handle for an element in the array

        Args:
            indices: element index

        Returns: the array's logger name followed by the comma separated indices in square
                 brackets, e.g. ``name[1,2]``
        """
        return self._logger.name + '[' + ','.join([str(item) for item in indices]) + ']'

    def __build_element_inst_name(self, indices: tuple[int, ...]) -> str:
        """
        Build the systemRDL instance name for an element in the array

        Args:
            indices: element index

        Returns: the array's inst_name followed by one bracketed index per dimension,
                 e.g. ``name[1][2]``
        """
        return self.inst_name + '[' + ']['.join([str(item) for item in indices]) + ']'

    def __check_init_element(self,
                             elements: tuple[tuple[tuple[int, ...], ...],
                                             tuple[NodeArrayElementType, ...]]) -> None:
        """
        Used in the __init__ to check that the elements passed in are valid

        Args:
            elements: proposed contents of the array as a pair of (indices, elements), where
                      entry n of the indices is the index of entry n of the elements

        Raises:
            TypeError: if the structure or an entry has the wrong type
            ValueError: if the two halves differ in length, an index is repeated or does not
                        fit the array dimensions, or an element's address does not match the
                        address calculated from its index
        """
        if not isinstance(elements, tuple):
            raise TypeError(f'elements should be a tuple but got {type(elements)}')
        if len(elements) != 2:
            raise ValueError('elements should be a pair of (indices, elements) '
                             f'but has {len(elements)} entries')

        indices, items = elements
        # zip() below silently stops at the shorter sequence, which would leave the indices
        # and the elements of the array out of step, so the lengths are checked explicitly
        if len(indices) != len(items):
            raise ValueError(f'number of element indices ({len(indices)}) does not match the '
                             f'number of elements ({len(items)})')

        seen_indices: set[tuple[int, ...]] = set()
        for index, item in zip(indices, items):
            if not isinstance(index, tuple):
                raise TypeError(f'element index should be a tuple but got {type(index)}')

            if len(index) != len(self.dimensions):
                raise ValueError(f'size of index does not match index length = {len(index)}')

            for index_pos, index_item in enumerate(index):
                if not isinstance(index_item, int):
                    raise TypeError(f'element index_item should be a int '
                                    f'but got {type(index_item)}')

                if not 0 <= index_item < self.dimensions[index_pos]:
                    raise ValueError('index outside of range of dimensions')

            # each index must identify exactly one element, otherwise lookups are ambiguous
            if index in seen_indices:
                raise ValueError(f'element index {index} appears more than once')
            seen_indices.add(index)

            if not isinstance(item, self._element_datatype):
                raise TypeError(f'elements should be a {self._element_datatype} '
                                f'but got {type(item)}')

            expected_address = self.__address_calculator(index)
            if expected_address != item.address:
                raise ValueError('Expected Address of the item and actual address differ')

    def __address_calculator(self, indices: tuple[int, ...]) -> int:
        """
        Calculate the address of the element at ``indices``

        Elements are laid out in row-major order (the last index varies fastest), each
        ``stride`` bytes after the previous one, starting at the array base address.

        Args:
            indices: element index, one entry per dimension

        Returns: address of the element
        """
        flat_index = 0
        for dimension, index in zip(self.dimensions, indices):
            flat_index = flat_index * dimension + index
        return self.address + flat_index * self.stride

    def __sub_instance(self, elements: tuple[tuple[tuple[int, ...],...],
                                             tuple[NodeArrayElementType, ...]]) -> \
            NodeArray[NodeArrayElementType]:
        """
        Create a new array of the same class and geometry presenting only ``elements``
        """
        if not isinstance(self.parent, Node):
            raise RuntimeError('Parent of a Node Array must be Node')
        return self.__class__(logger_handle=self._logger.name,
                              inst_name=self.inst_name,
                              parent=self.parent,
                              address=self.address,
                              stride=self.stride,
                              dimensions=self.dimensions,
                              elements=elements)

    def __getitem__(self, item):  # type: ignore[no-untyped-def]
        """
        Access elements of the array

        For a one dimensional array ``item`` may be an int, a 1-tuple holding an int, or a
        slice. For a multidimensional array ``item`` must be a tuple with one entry per
        dimension, each an int or a slice. An access made only of ints returns a single
        element, an access containing a slice returns a new array holding the selected
        elements. Indices are those of the full array, so elements of a sliced array keep the
        index they had in the array they were sliced from.

        Raises:
            IndexError: if the index is not in the array, or the kind of access does not suit
                        the number of dimensions
            ValueError: if a multidimensional access has the wrong number of indices
            TypeError: if the index type is not supported
        """
        if len(self.dimensions) > 1:
            return self.__getitem_nd(item)

        if isinstance(item, tuple):
            if len(item) == 1:
                return self.__getitem__(item[0])
            raise IndexError('attempting a multidimensional array access on a single dimension'
                             ' array')

        if isinstance(item, slice):
            valid_items = tuple((i,) for i in range(*item.indices(self.dimensions[0])))

            child_elements = (valid_items,
                              tuple(self.__getitem__(index) for index in valid_items))
            return self.__sub_instance(elements=child_elements)

        if isinstance(item, int):
            key = (item, )
            if key not in self.__element_lookup:
                raise IndexError(f'{item:d} not in the array')
            return self.__element_lookup[key]

        raise TypeError(f'Array index must either being an int or a slice, got {type(item)}')

    def __getitem_nd(self, item): # type: ignore[no-untyped-def]
        """
        Implementation of __getitem__ for arrays with more than one dimension
        """
        if not isinstance(item, tuple):
            raise IndexError('attempting a single dimensional array access on a multidimension'
                             ' array')

        if len(item) != len(self.dimensions):
            raise ValueError('When using a multidimensional access, the size must match the'
                             ' dimensions of the array, array dimensions '
                             f'are {len(self.dimensions)}')

        if all(isinstance(i, int) for i in item):
            # single item access
            if item not in self.__element_lookup:
                msg = 'index[' + ','.join([str(i) for i in item]) + '] not in array'
                raise IndexError(msg)
            return self.__element_lookup[item]

        # expand each axis to the indices it selects, then take every combination
        unpack_index_set = []
        for axis, sub_index in enumerate(item):
            if isinstance(sub_index, int):
                if not 0 <= sub_index < self.dimensions[axis]:
                    raise IndexError(f'{sub_index:d} out of range for dimension {axis}')
                unpack_index_set.append((sub_index,))
                continue

            if isinstance(sub_index, slice):
                unpack_index_set.append(range(*sub_index.indices(self.dimensions[axis])))
                continue

            raise TypeError(f'unhandle index of {type(sub_index)} in position {axis:d}')

        valid_items = tuple(product(*unpack_index_set))

        child_elements = (valid_items,
                          tuple(self[index] for index in valid_items))
        return self.__sub_instance(elements=child_elements)

    def __len__(self) -> int:
        return len(self.__elements)

    def __iter__(self) -> Iterator[NodeArrayElementType]:
        yield from self.__elements

    def __reversed__(self) -> Iterator[NodeArrayElementType]:
        # The Sequence mixin version indexes self[0..len-1] with plain ints. That does not
        # work for multidimensional arrays or for sliced arrays, whose indices are not
        # 0..len-1. Reverse the stored elements directly instead.
        yield from reversed(self.__elements)

    def items(self) -> Iterator[tuple[tuple[int, ...], NodeArrayElementType]]:
        """
        iterate through all the items in an array but also return the index of the array
        """
        yield from zip(self.__element_indices, self.__elements)

    @property
    def dimensions(self) -> Union[tuple[int, ...], tuple[int]]:
        """
        Dimensions of the array
        """
        return self.__dimensions

    @property
    @abstractmethod
    def _element_datatype(self) -> type[NodeArrayElementType]:
        ...

    @property
    def address(self) -> int:
        """
        address of the node
        """
        return self.__address

    @property
    def stride(self) -> int:
        """
        address stride of the array
        """
        return self.__stride

    @property
    def _callbacks(self) -> Union[CallbackSet, CallbackSetLegacy]:
        if self.parent is None:
            raise RuntimeError('Parent must be set')
        # pylint: disable-next=protected-access
        return self.parent._callbacks

    @property
    def size(self) -> int:
        """
        Total Number of bytes of address the array occupies
        """
        return math.prod(self.dimensions) * self.stride
