"""Base classes used by all script types."""

from typing import (
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
    Type,
    cast,
    TYPE_CHECKING,
)
import uuid
from copy import deepcopy

if TYPE_CHECKING:
    from smrpgpatchbuilder.datatypes.items.classes import Item

from smrpgpatchbuilder.datatypes.overworld_scripts.arguments.types.short_var import (
    ShortVar,
    TimerVar,
)
from smrpgpatchbuilder.datatypes.overworld_scripts.arguments.types.byte_var import (
    ByteVar,
)
from smrpgpatchbuilder.datatypes.numbers.classes import (
    Int8,
    Int16,
    UInt16,
    UInt8,
    UInt4,
)


class IdentifierException(Exception):
    """Raised when an operation on command identifiers fails, e.g. a label that cannot be found."""


class ScriptBankTooLongException(Exception):
    """Raised when the bytes generated for a script bank do not fit in the space
    allotted to that bank."""


class InvalidCommandArgumentException(Exception):
    """Raised when a command class is given an argument it cannot accept or render."""


class InvalidOpcodeException(Exception):
    """Raised when an opcode is not valid for the command class it was given to."""


class RenderException(Exception):
    """Raised for miscellaneous failures while rendering a script, such as a command
    producing a different number of bytes than it declared."""


class TransformableIdentifier:
    """Placeholder for a jump target whose ROM address is only known at compile time.

    SMRPG scripts use many gotos, but the address a goto points at cannot be known
    until every script has been laid out. Each command instance owns one of these
    objects; commands that jump elsewhere refer to their target by label, and the
    real address is filled in through ``set_address`` when the script is compiled.
    """

    _label: str = ""
    # Deliberately unassigned until set_address() runs; reading it earlier raises
    # AttributeError.
    _address: int

    def __init__(self, identifier: str) -> None:
        # Labels must be non-empty (validated with an assert).
        assert identifier is not None and len(identifier) > 0
        self._label = identifier

    def __str__(self):
        return self._label

    @property
    def label(self) -> str:
        """Name used to look this identifier up."""
        return self._label

    @property
    def address(self) -> int:
        """Address assigned by ``set_address``; raises AttributeError if never set."""
        return self._address

    def set_address(self, address: int) -> None:
        """Store the resolved address for this identifier.

        Intended to be called only by script rendering code. The value is stored
        unchecked.
        """
        # NOTE(review): a range check against known script-bank address ranges was
        # sketched here but left disabled, so no validation happens.
        self._address = address

    def render_address(self) -> bytearray:
        """Return the low 16 bits of the resolved address as a little-endian short."""
        low_word = self._address & 0xFFFF
        return UInt16(low_word).little_endian()


class ScriptCommand:
    """Base class for every command that can appear in an internal SMRPG script.

    A script is an ordered sequence of these commands; the model mirrors the way
    scripts are shown in Lazy Shell. Subclasses provide ``_opcode`` and ``_size``.
    """

    _identifier: TransformableIdentifier
    _opcode: Union[bytearray, int]
    _size: int

    def __init__(self, identifier: Optional[str] = None) -> None:
        # A missing or empty name falls back to a random UUID label.
        needs_generated = identifier is None or len(identifier) == 0
        label = self._generate_identifier() if needs_generated else identifier
        self._identifier = TransformableIdentifier(label)

    @property
    def opcode(self) -> Union[int, bytearray]:
        """Header byte(s) telling SMRPG which command is being run."""
        return self._opcode

    @property
    def identifier(self) -> TransformableIdentifier:
        """Unique identifier of this instance, usable as a jump target by other commands."""
        return self._identifier

    @property
    def size(self) -> int:
        """Number of bytes this command is expected to occupy once rendered."""
        return self._size

    def _generate_identifier(self) -> str:
        """Produce a random (uuid4) label for commands created without one."""
        return str(uuid.uuid4())

    def render(self, *args) -> bytearray:
        """Serialize the opcode followed by ``args`` into ROM-ready bytes.

        Raises:
            IdentifierException: if ``_opcode`` is neither a bytearray nor an int in 0..255.
            InvalidCommandArgumentException: if an argument has an unsupported type.
            RenderException: if the output length differs from ``size``.
        """
        rendered = bytearray()

        # The opcode is optional: subclasses without ``_opcode`` emit only arguments.
        if hasattr(self, "_opcode"):
            header = self._opcode
            if isinstance(header, bytearray):
                rendered.extend(header)
            elif isinstance(header, int) and 0 <= header <= 0xFF:
                rendered.append(header)
            else:
                # NOTE(review): InvalidOpcodeException looks like the intended type;
                # IdentifierException is kept for existing callers.
                raise IdentifierException(
                    f"illegal opcode in {self.identifier}: {self._opcode}"
                )

        # Branch order matters: the project numeric/variable types are checked
        # before the plain-int fallback.
        for arg in args:
            if isinstance(arg, (ShortVar, ByteVar, TimerVar)):
                rendered.append(arg.to_byte())
            elif isinstance(arg, (UInt16, Int16)):
                rendered.extend(arg.little_endian())
            elif isinstance(arg, (UInt8, Int8, UInt4)):
                rendered.append(arg.to_byte())
            elif isinstance(arg, int) and 0 <= arg <= 0xFF:
                rendered.append(arg)
            elif isinstance(arg, bytearray):
                rendered.extend(arg)
            elif isinstance(arg, TransformableIdentifier):
                rendered.extend(arg.render_address())
            else:
                raise InvalidCommandArgumentException(
                    f"unknown {self.__str__()} argument type in {self.identifier}: {type(arg)} ({arg})"
                )

        if len(rendered) == self.size:
            return rendered
        raise RenderException(
            (
                f"{self.identifier} of type {self} output wrong length: "
                f"{self.__class__.__name__}{args} length {len(rendered)}, expected {self.size}"
            )
        )


class ScriptCommandWithJmps(ScriptCommand):
    """Base class for commands that perform a GOTO to one or more other commands."""

    _destinations: List[TransformableIdentifier] = []

    def __init__(
        self, destinations: List[str], identifier: Optional[str] = None
    ) -> None:
        super().__init__(identifier)
        self.set_destinations(destinations)

    @property
    def destinations(self) -> List[TransformableIdentifier]:
        """Identifiers of the command(s) this command jumps to."""
        return self._destinations

    def set_destinations(self, destinations: List[str]) -> None:
        """Replace every jump target, wrapping each label in a new identifier."""
        self._destinations = list(map(TransformableIdentifier, destinations))

    def set_destination(self, destination: str, index: int) -> None:
        """Replace the jump target at list position ``index`` with a new label."""
        replacement = TransformableIdentifier(destination)
        self._destinations[index] = replacement


class ScriptCommandNoArgs(ScriptCommand):
    """Base class for commands that take no arguments.

    The base renderer emits only the opcode for these, so the size is the opcode's
    length: the byte count of a bytearray opcode, otherwise a single byte.
    """

    def __init__(self, identifier: Optional[str] = None) -> None:
        super().__init__(identifier)
        opcode = self.opcode
        self._size = len(opcode) if isinstance(opcode, bytearray) else 1


class ScriptCommandAnySizeMem(ScriptCommand):
    """Base class for commands whose argument may be either a 16-bit (ShortVar) or an
    8-bit (ByteVar) variable."""

    _address: Union[ShortVar, ByteVar]

    def __init__(
        self, address: Union[ByteVar, ShortVar], identifier: Optional[str] = None
    ) -> None:
        super().__init__(identifier)
        self.set_address(address)

    @property
    def address(self) -> Union[ShortVar, ByteVar]:
        """SMRPG internal variable this command operates on."""
        return self._address

    def set_address(self, address: Union[ShortVar, ByteVar]) -> None:
        """Choose the SMRPG internal variable this command operates on."""
        self._address = address


class ScriptCommandShortMem(ScriptCommand):
    """Base class for commands whose only argument is a 16-bit (ShortVar) variable."""

    _address: ShortVar

    def __init__(self, address: ShortVar, identifier: Optional[str] = None) -> None:
        super().__init__(identifier)
        self.set_address(address)

    @property
    def address(self) -> ShortVar:
        """SMRPG internal variable this command operates on."""
        return self._address

    def set_address(self, address: ShortVar) -> None:
        """Choose the SMRPG internal variable this command operates on."""
        self._address = address

    def render(self, *args) -> bytearray:
        # Any caller-supplied args are ignored; the variable's byte is the sole argument.
        variable_byte = self.address.to_byte()
        return super().render(variable_byte)


class ScriptCommandShortAddrAndValueOnly(ScriptCommand):
    """Base class for commands taking a 16-bit variable plus a constant number."""

    _address: ShortVar
    _value: UInt16

    def __init__(
        self,
        address: ShortVar,
        value: Union[int, Type["Item"]],
        identifier: Optional[str] = None,
    ) -> None:
        super().__init__(identifier)
        self.set_address(address)
        self.set_value(value)

    @property
    def address(self) -> ShortVar:
        """SMRPG internal variable this command operates on."""
        return self._address

    def set_address(self, address: ShortVar) -> None:
        """Choose the SMRPG internal variable, normalizing it through ``ShortVar``."""
        self._address = ShortVar(address)

    @property
    def value(self) -> UInt16:
        """Constant number argument of this command."""
        return self._value

    def set_value(self, value: Union[int, Type["Item"]]) -> None:
        """Set the constant argument.

        Accepts a plain int, or an item class, in which case the class is instantiated
        and its ``item_id`` is used.
        """
        if isinstance(value, int) or not hasattr(value, "item_id"):
            self._value = UInt16(value)
            return
        # Item subclass: instantiate it to read the numeric ID.
        self._value = UInt16(value().item_id)


class ScriptCommandBasicShortOperation(ScriptCommand):
    """Base class for a math operation between a fixed internal variable and a constant.

    The variable is implied by the concrete command and cannot be chosen by the
    developer; only the 16-bit operand is configurable. Rendered size is 4 bytes.
    """

    _value: UInt16
    _size: int = 4

    def __init__(self, value: int, identifier: Optional[str] = None) -> None:
        super().__init__(identifier)
        self.set_value(value)

    @property
    def value(self) -> UInt16:
        """Operand of the math operation."""
        return self._value

    def set_value(self, value: int) -> None:
        """Set the operand, stored as a UInt16."""
        self._value = UInt16(value)

    def render(self, *args) -> bytearray:
        # Caller-supplied args are ignored; the operand is the only argument.
        operand = self.value
        return super().render(operand)


# Type variable for the command type held by a Script; bound to ScriptCommand.
ScriptCommandT = TypeVar("ScriptCommandT", bound=ScriptCommand)


class Script(Generic[ScriptCommandT]):
    """A base class for any script, which is mostly a list of script commands.

    Commands can be addressed by list index, by the nth occurrence of a command
    class, or by the label of their unique identifier. Indices are zero-based
    positions in ``contents``.
    """

    # Class-level fallback only: Script.__init__ assigns a per-instance list via
    # set_contents, so instances built through it never share this object.
    _contents: List[ScriptCommandT] = []

    @property
    def contents(self) -> List[ScriptCommandT]:
        """The live list of script commands (not a copy)."""
        return self._contents

    def append(self, command: ScriptCommandT) -> None:
        """Append a single command to the end of the script."""
        self._contents.append(command)

    def extend(self, commands: List[ScriptCommandT]) -> None:
        """Extend the script by a list of commands."""
        self._contents.extend(commands)

    def _insert(self, index: int, command: ScriptCommandT) -> None:
        """Insert ``command`` so that it lands at list position ``index``.

        ``index`` may equal ``len(contents)``, which appends. That case is required
        for inserting after the final command or into an empty script.

        Raises:
            IndexError: if ``index`` is negative or greater than the script length.
        """
        # Explicit check instead of ``assert``: it must survive ``python -O``, and
        # list.insert would otherwise silently clamp an out-of-range index.
        if not 0 <= index <= len(self._contents):
            raise IndexError(
                f"insertion index {index} out of range for script of length "
                f"{len(self._contents)}"
            )
        self._contents.insert(index, command)

    def _get_index_of_nth_command_of_type(
        self, ordinality: int, cls: Type[ScriptCommand]
    ) -> int:
        """Return the list index of the nth command that is an instance of ``cls``.

        ``ordinality`` is zero-based among the matching commands (0 = first match).
        Negative values count back from the last match, as with list indexing.

        Raises:
            ValueError: if there are not enough matching commands.
        """
        matches = [i for i, cmd in enumerate(self._contents) if isinstance(cmd, cls)]
        try:
            return matches[ordinality]
        except IndexError as exc:
            raise ValueError(f"could not find {ordinality} instances of {cls}") from exc

    def get_index_of_identifier(self, identifier: str) -> int:
        """Return the script index of the command whose identifier label matches.

        Raises:
            IdentifierException: if no command in the script has that label.
        """
        for index, cmd in enumerate(self._contents):
            if cmd.identifier.label == identifier:
                return index
        raise IdentifierException(f"{identifier} not found")

    def set_contents(self, script: Optional[List[ScriptCommandT]] = None) -> None:
        """Overwrite this script's command list with a deep copy of ``script``.

        ``None`` (the default) clears the script.
        """
        if script is None:
            script = []
        self._contents = deepcopy(script)

    def __init__(self, script: Optional[List[ScriptCommandT]] = None) -> None:
        """Create a script holding a deep copy of ``script`` (empty if omitted)."""
        if script is None:
            script = []
        self.set_contents(script)

    @property
    def length(self) -> int:
        """The expected length of this script in bytes (sum of command sizes)."""
        return sum(command.size for command in self._contents)

    def insert_before_nth_command(self, index: int, command: ScriptCommandT) -> None:
        """Insert a command before the command at list index ``index``.

        ``index == len(contents)`` appends the command.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        self._insert(index, command)

    def insert_after_nth_command(self, index: int, command: ScriptCommandT) -> None:
        """Insert a command after the command at list index ``index``.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        self._insert(index + 1, command)

    def insert_before_nth_command_of_type(
        self,
        ordinality: int,
        cls: Type[ScriptCommandT],
        command: ScriptCommandT,
    ) -> None:
        """Insert a command before the nth match (zero-based) of a command type.

        e.g. insert a new command right before the 3rd "show dialog".

        Raises:
            ValueError: if there are not enough commands of type ``cls``.
        """
        index: int = self._get_index_of_nth_command_of_type(ordinality, cls)
        self._insert(index, command)

    def insert_after_nth_command_of_type(
        self,
        ordinality: int,
        cls: Type[ScriptCommandT],
        command: ScriptCommandT,
    ) -> None:
        """Insert a command after the nth match (zero-based) of a command type.

        e.g. insert a new command right after the 3rd "show dialog".

        Raises:
            ValueError: if there are not enough commands of type ``cls``.
        """
        index: int = self._get_index_of_nth_command_of_type(ordinality, cls)
        self._insert(index + 1, command)

    def insert_before_identifier(
        self, identifier: str, command: ScriptCommandT
    ) -> None:
        """Insert a command immediately before the command with the given label.

        Raises:
            IdentifierException: if no command has that label.
        """
        index: int = self.get_index_of_identifier(identifier)
        self._insert(index, command)

    def insert_after_identifier(self, identifier: str, command: ScriptCommandT) -> None:
        """Insert a command immediately after the command with the given label.

        Raises:
            IdentifierException: if no command has that label.
        """
        index: int = self.get_index_of_identifier(identifier)
        self._insert(index + 1, command)

    def replace_at_index(self, index: int, content: ScriptCommandT) -> None:
        """Replace the command at the specified list index within the script."""
        self._contents[index] = content

    def replace_by_name(self, identifier: str, content: ScriptCommandT) -> None:
        """Replace the command that matches the specified unique identifier.

        Raises:
            IdentifierException: if no command has that label.
        """
        index: int = self.get_index_of_identifier(identifier)
        self._contents[index] = content

    def delete_at_index(self, index: int) -> None:
        """Delete the command at the specified list index within the script."""
        del self._contents[index]

    def get_command_by_name(self, identifier: str) -> Tuple[int, ScriptCommandT]:
        """Return ``(index, command)`` for the command with the given label.

        Raises:
            IdentifierException: if no command has that label.
        """
        index = self.get_index_of_identifier(identifier)
        return index, self._contents[index]

    def render(self) -> bytearray:
        """Convert this script into bytes by concatenating each command's render()."""
        output = bytearray()
        for command in self._contents:
            output += command.render()
        return output


# Type variable for the script type held by a ScriptBank; bound to Script.
ScriptT = TypeVar("ScriptT", bound=Script)


class ScriptBank(Generic[ScriptT]):
    """a collection of scripts that comprise a contiguous and well-defined block of
    ROM bytes."""

    _scripts: List[ScriptT] = []
    _pointer_table_start: int
    _start: int
    _end: int

    _addresses: Dict[str, int]
    _pointer_bytes: bytearray
    _script_bytes: bytearray

    @property
    def pointer_table_start(self) -> int:
        """ROM address at which this bank's pointer table begins.

        The pointer table is a list of relative addresses, one per script: entry N
        says where script N of this bank starts.
        """
        return self._pointer_table_start

    @property
    def start(self) -> int:
        """ROM address where script contents (not the pointer table) should begin."""
        return self._start

    @property
    def end(self) -> int:
        """ROM address where script contents should end."""
        return self._end

    @property
    def scripts(self) -> List[ScriptT]:
        """Scripts that make up this bank (the live list, not a copy)."""
        return self._scripts

    def set_contents(self, scripts: Optional[List[ScriptT]] = None) -> None:
        """Replace this bank's scripts with deep copies of ``scripts`` (empty if None)."""
        self._scripts = deepcopy([] if scripts is None else scripts)

    def replace_script(self, index: int, script: ScriptT) -> None:
        """Swap the script at list position ``index`` for a deep copy of ``script``."""
        copied = deepcopy(script)
        self._scripts[index] = copied

    def __init__(self, scripts: Optional[List[ScriptT]] = None) -> None:
        """Create a bank holding deep copies of ``scripts``.

        ``scripts`` now defaults to ``None`` (an empty bank). Previously it was
        Optional but had no default, so callers had to pass ``None`` explicitly.
        The address map and the pointer/script byte buffers start out empty.
        """
        if scripts is None:
            scripts = []
        self.set_contents(scripts)
        self._addresses = {}
        self._pointer_bytes = bytearray()
        self._script_bytes = bytearray()

    @property
    def addresses(self) -> Dict[str, int]:
        """Map from identifier label to calculated address.

        Used to fill in the destination of every goto in the bank.
        """
        return self._addresses

    @property
    def pointer_bytes(self) -> bytearray:
        """Pointer table bytes, ready to be patched into the ROM."""
        return self._pointer_bytes

    @property
    def script_bytes(self) -> bytearray:
        """Bytes for all scripts in the bank, ready to be patched into the ROM."""
        return self._script_bytes

    def _set_identifier_addresses(
        self, identifiers: List[TransformableIdentifier]
    ) -> None:
        """Resolve the address of every identifier in ``identifiers``.

        Each label is looked up in ``self.addresses`` and the low 16 bits of the
        result are stored on the identifier. A label missing from the map but
        containing ``ILLEGAL_JUMP_`` is treated as a raw jump: its last four
        characters are parsed as a hex address. Every identifier in the list is
        resolved, including ones that follow a raw jump.

        Raises:
            IdentifierException: if a label is neither in the address map nor a raw jump.
            ValueError: presumably, if a raw-jump label's last four characters are not
                hex (from ``int(..., 16)``).
        """
        destination: TransformableIdentifier
        for destination in identifiers:
            key: str = destination.label
            if key in self.addresses:
                destination.set_address(self.addresses[key] & 0xFFFF)
            elif "ILLEGAL_JUMP_" in key:
                # Raw jump target taken from the label itself. Keep looping (the old
                # code returned here) so later destinations of the same command,
                # e.g. further jump-table entries, are still resolved.
                destination.set_address(int(key[-4:], 16) & 0xFFFF)
            else:
                raise IdentifierException(f"couldn't find destination {key}")

    def _populate_jumps(self, script: Script) -> None:
        """Resolve goto targets for every jump-capable command in ``script``, in order."""
        for command in script.contents:
            if isinstance(command, ScriptCommandWithJmps):
                self._set_identifier_addresses(command.destinations)
