import logging
import traceback
from copy import copy
from functools import partial, wraps
from pathlib import Path
from threading import Timer
from time import sleep

import transitions
from transitions.core import listify, EventData, Event
from transitions.extensions import LockedHierarchicalGraphMachine as Machine
from transitions.extensions.states import add_state_features, Timeout

from .config import MachineConfig


# Directories whose frames are dropped from truncated callback tracebacks
# (see BaseMachine._get_traceback): this package and the transitions library.
PIGEON_TRANSITIONS_MODULE = Path(__file__).parents[0]
TRANSITIONS_MODULE = Path(transitions.__file__).parents[0]


def root(func):
    """Decorate a machine method so that it always runs on the root machine.

    When the decorated method is called on a nested machine, the call (with the
    same arguments) is forwarded to ``self._root``. ``functools.wraps`` keeps the
    original ``__name__``/``__doc__``, which are shown in logs and diagrams.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        target = self if self.is_root else self._root
        return func(target, *args, **kwargs)

    return wrapper


class State(Timeout, Machine.state_cls):
    """Timeout-capable nested state that also runs machine-level callbacks.

    When a state whose full name matches a nested machine registered in the
    hierarchy is entered or exited, that machine's ``_on_enter`` / ``_on_exit``
    callbacks run in addition to the state's own callbacks.
    """

    def __init__(self, *args, **kwargs):
        # Not read within this class; presumably used by callers/subclasses (TODO confirm).
        self._span = None
        super().__init__(*args, **kwargs)

    def enter(self, event_data):
        # Machine-level on_enter callbacks run before the state's own on_enter.
        self.machine_enter(event_data)
        super().enter(event_data)

    def exit(self, event_data):
        # Machine-level on_exit callbacks run after the state's own on_exit.
        super().exit(event_data)
        self.machine_exit(event_data)

    def get_machine(self, event_data):
        """Return the nested machine registered under this state's name path.

        Raises:
            ValueError: if no machine is registered under that path.
        """
        return event_data.machine._get_from_heirarchy(
            self.name.split(event_data.machine.separator)
        )

    def _run_machine_callbacks(self, event_data, attr):
        """Run the callback list ``attr`` of the machine this state represents, if any."""
        try:
            machine = self.get_machine(event_data)
        except ValueError:
            # A plain state with no nested machine behind it: nothing to run.
            # Only the lookup is guarded so errors from the callbacks are not hidden.
            return
        event_data.machine.callbacks(getattr(machine, attr), event_data)

    def machine_enter(self, event_data):
        """Run the ``on_enter`` callbacks of the nested machine this state represents."""
        self._run_machine_callbacks(event_data, "_on_enter")

    def machine_exit(self, event_data):
        """Run the ``on_exit`` callbacks of the nested machine this state represents."""
        self._run_machine_callbacks(event_data, "_on_exit")


class Transition(Machine.transition_cls):
    """Transition that logs the resulting state after a successful execution."""

    def execute(self, event_data):
        """Execute the transition and log the new state when it succeeds.

        Returns the base class result unchanged.
        """
        success = super().execute(event_data)
        if success:
            # Lazy %-style arguments: the message is only formatted if INFO is enabled.
            event_data.machine._logger.info(
                "Transitioned to state: %s", event_data.model.state
            )
        return success


class Model:
    """Empty attribute container used as a BaseMachine's private model.

    Attribute lookups that fail on the machine fall back to this object.
    """


class BaseMachine(Machine):
    """Hierarchical state machine that can itself be nested inside other BaseMachines.

    Differences from the underlying transitions machine:

    * Nested machines keep a link to their parent (``_parent``) and each parent keeps
      its children (``_children``), so any machine can reach the root (``_root``).
    * Attribute lookups that fail on a machine fall through to the root machine and
      then to the root's private ``Model`` instance (see ``__getattr__``).
    * String callbacks are resolved to members of the machine where possible
      (see ``_get_callable``). Exceptions raised by callbacks are logged instead of
      propagated (see ``callback``).
    """

    state_cls = State
    transition_cls = Transition

    def __init__(
        self,
        states=None,
        transitions=None,
        on_enter=None,
        on_exit=None,
        before_state_change=None,
        after_state_change=None,
        prepare_event=None,
        finalize_event=None,
        on_exception=None,
        on_final=None,
        auto_transitions=False,
        initial="initial",
        truncate_tb=True,
        **kwargs,
    ):
        """The standard transitions Machine constructor with the following changes:

        * No model is passed to the base constructor. A private ``Model`` instance is
          created here and attached by ``_add_model`` (called from ``_start``).
        * ``on_enter`` / ``on_exit`` callbacks are stored on the machine itself. They
          run when the state representing this machine in its parent is entered or
          exited (see ``State.machine_enter`` / ``State.machine_exit``).
        * All callback arguments are transformed using ``_get_callables`` and
          ``_get_callable``.
        * ``truncate_tb`` controls whether tracebacks logged for failing callbacks
          omit frames from this package and from ``transitions``.
        """
        # ``__getattr__`` needs these two attributes. They must exist before anything
        # can trigger it (e.g. the ``hasattr`` probe in ``_get_callable`` below).
        self._parent = None
        self._model = Model()
        self._logger = logging.getLogger(__name__)
        self.state_name = None
        self._children = {}
        self._on_enter = self._get_callables(on_enter)
        self._on_exit = self._get_callables(on_exit)
        self._truncate_tb = truncate_tb
        super().__init__(
            states=states,
            transitions=transitions,
            model=None,
            before_state_change=self._get_callables(before_state_change),
            after_state_change=self._get_callables(after_state_change),
            prepare_event=self._get_callables(prepare_event),
            finalize_event=self._get_callables(finalize_event),
            on_exception=self._get_callables(on_exception),
            on_final=self._get_callables(on_final),
            auto_transitions=auto_transitions,
            initial=initial,
            **kwargs,
        )
        self._rename_loggers()

    def _add_machine_states(self, state, remap):
        """Register the nested machine ``state`` as a child of this machine.

        Overridden to build the parent/child relationships between the machines of
        the hierarchy, then defer to the base implementation.
        """
        name = self.get_global_name()
        assert name != "parent", "The state name 'parent' is reserved."
        state._parent = self
        state.state_name = name
        self._children[name] = state
        super()._add_machine_states(state, remap)

    def _remap_state(self, state, remaps):
        """Remap states of a nested machine onto destinations in this machine.

        Overrides the base ``_remap_state`` to add the following:

        * The remapped (``orig``) states are removed from ``self.states`` so they do
          not appear in the diagram.
        * The ``on_enter`` callbacks of each remapped state are appended to the
          ``before`` callbacks of the generated transitions.
        * ``remaps`` may be a list of dicts. Each dict gives the original state
          (``orig``), the destination (``dest``), and optionally extra callback lists
          (e.g. ``conditions``, ``unless``, ``before``) to append to the generated
          transitions. The same ``orig`` may appear in several dicts, e.g. with
          different conditions.
        * A plain ``{orig: dest}`` dict is also accepted.

        Every transition of the nested machine that leads to ``orig`` is remapped.

        Returns:
            The list of transition dicts to add to this machine.
        """
        if isinstance(remaps, dict):
            return self._remap_state(
                state, [{"orig": orig, "dest": new} for orig, new in remaps.items()]
            )
        # The base implementation first remaps each distinct ``orig`` onto a unique
        # placeholder destination ("0", "1", ...). The placeholders are then replaced
        # by the real destinations, once per remap entry.
        dummy_remaps = {}
        for remap in remaps:
            if remap["orig"] not in dummy_remaps:
                dummy_remaps[remap["orig"]] = str(len(dummy_remaps))
        dummy_transitions = super()._remap_state(state, dummy_remaps)
        remapped_transitions = []
        for remap in remaps:
            placeholder = dummy_remaps[remap["orig"]]
            matches = [
                dummy for dummy in dummy_transitions if dummy["dest"] == placeholder
            ]
            assert matches, f"No transition leads to remapped state {remap['orig']!r}."
            on_enter = self.states[remap["orig"]].on_enter
            for dummy in matches:
                # Copy the values so the in-place += below cannot mutate the dummy.
                transition = {key: copy(val) for key, val in dummy.items()}
                transition["dest"] = remap["dest"]
                for key, val in remap.items():
                    if key not in ("orig", "dest"):
                        transition[key] += listify(val)
                transition["before"] += on_enter
                remapped_transitions.append(transition)
        for remap in remaps:
            old_state = remap["orig"]
            if old_state in self.states:
                del self.states[old_state]
        return remapped_transitions

    def get_state_path(self, join=True):
        """Return the hierarchical state that leads to this machine.

        If ``join`` is False, return the list of state names (outermost first)
        instead of the separator-joined string. The root machine returns an empty
        path.
        """
        parent = self
        states = []
        while parent._parent is not None:
            states.insert(0, parent.state_name)
            parent = parent._parent
        if join:
            return self.separator.join(states)
        return states

    def get_machine_state(self):
        """Return the name, at this machine's level, of the current state.

        Returns None if the current state is neither a state of this machine nor a
        substate of one.
        """
        state_path = self.get_state_path(join=False)
        state = self.state.split(self.separator)
        depth = len(state_path)
        # The current state must lie strictly below this machine's path;
        # otherwise state[depth] would not exist.
        if len(state) <= depth or state[:depth] != state_path:
            return None
        return state[depth]

    def _current_machine(self):
        """Return True if the current state is a state of this machine, or a substate."""
        return self.get_machine_state() is not None

    def current_machine(self):
        """Return True if the current state is strictly a state of this machine."""
        if not self._current_machine():
            return False
        state = self.state.split(self.separator)
        state_path = self.get_state_path(join=False)
        return len(state_path) + 1 == len(state)

    def _create_state(
        self, *args, on_enter=None, on_exit=None, on_timeout=None, **kwargs
    ):
        """Create a state after transforming its callbacks with ``_get_callables``."""
        return super()._create_state(
            *args,
            on_enter=self._get_callables(on_enter),
            on_exit=self._get_callables(on_exit),
            on_timeout=self._get_callables(on_timeout),
            **kwargs,
        )

    def add_transition(
        self,
        *args,
        conditions=None,
        unless=None,
        before=None,
        after=None,
        prepare=None,
        **kwargs,
    ):
        """Add a transition after transforming its callbacks with ``_get_callables``."""
        return super().add_transition(
            *args,
            conditions=self._get_callables(conditions),
            unless=self._get_callables(unless),
            before=self._get_callables(before),
            after=self._get_callables(after),
            prepare=self._get_callables(prepare),
            **kwargs,
        )

    def _get_callable(self, func):
        """Resolve one callback specification.

        * A string naming a callable attribute of this machine resolves to that
          attribute.
        * A string naming a non-callable attribute resolves to a function returning
          the attribute's *current* value.
        * Any other string is returned unchanged (left to transitions to resolve).
        * A non-callable value resolves to a function returning that value.
        * Callables are returned as is.
        """
        if isinstance(func, str):
            if hasattr(self, func):
                tmp = getattr(self, func)
                if callable(tmp):
                    return tmp
                getter = lambda: getattr(self, func)
                # Setting the __name__ attribute shows the function name on the graph
                getter.__name__ = func
                return getter
            return func
        if not callable(func):
            return lambda: func
        return func

    def _get_callables(self, funcs):
        """Return ``funcs`` (None, a single item or a list) as a list of resolved callbacks."""
        if funcs is None:
            return []
        return [self._get_callable(func) for func in listify(funcs)]

    def callback(self, func, event_data):
        """Run a single callback, logging instead of raising any exception it throws.

        The warning is emitted on this machine's logger and includes a traceback,
        truncated according to the root machine's ``truncate_tb`` setting.
        """
        try:
            super().callback(func, event_data)
        except Exception as e:
            name = func
            # Locked machines wrap callbacks in partial(self._locked_method, callback);
            # report the wrapped callback rather than the partial.
            if isinstance(func, partial) and func.func == self._locked_method:
                name = func.args[0]
            if not isinstance(name, str):
                if hasattr(name, "__name__"):
                    name = name.__name__
                else:
                    name = repr(name)
            self._logger.warning(
                f'An error was encountered while running callback "{name}":\n{self._get_traceback(e)}'
            )

    def _get_traceback(self, exception):
        """Format the traceback of ``exception`` as a string.

        If the root machine has ``truncate_tb`` enabled, only the innermost frames
        after the last frame located in this package or in ``transitions`` are kept.
        """
        raw_tb = traceback.extract_tb(exception.__traceback__)
        if self._root._truncate_tb:
            tb = []
            for frame in reversed(raw_tb):
                parents = Path(frame.filename).parents
                if (
                    PIGEON_TRANSITIONS_MODULE in parents
                    or TRANSITIONS_MODULE in parents
                ):
                    break
                tb.insert(0, frame)
        else:
            tb = raw_tb
        return "\n".join(traceback.format_list(tb))

    @property
    def _root(self):
        """Traverse the tree of hierarchical machines to the root and return it."""
        root = self
        while root._parent is not None:
            root = root._parent
        return root

    @property
    def is_root(self):
        """True if this machine has no parent machine."""
        return self._parent is None

    @root
    def _get_from_heirarchy(self, state_list):
        """Return the nested machine reached by following ``state_list`` from the root.

        Raises:
            ValueError: if a name in ``state_list`` is not a registered child machine.
        """
        child = self
        for state in state_list:
            if state not in child._children:
                raise ValueError(f"Machine {child} has no child {state}.")
            child = child._children[state]
        return child

    @root
    def _get_machine(self, state):
        """Return the machine instance which the given (full) state is part of."""
        return self._get_from_heirarchy(state.split(self.separator)[:-1])

    @root
    def _get_current_machine(self):
        """Return the machine instance which the current state is part of."""
        return self._get_machine(self.state)

    @root
    def _get_current_machines(self):
        """Yield the machine owning the current state, then each enclosing machine
        in turn, ending with the root machine.

        Raises ValueError (from ``_get_from_heirarchy``) if an enclosing state is
        not a registered nested machine.
        """
        state_list = self.state.split(self.separator)
        yield self._get_current_machine()
        for i in range(1, len(state_list)):
            yield self._get_machine(self.separator.join(state_list[:-i]))

    @root
    def _collect_states(self, states=None):
        """Yield every state of the hierarchy depth-first (parents before children)."""
        for state in self.states.values() if states is None else states.values():
            yield state
            yield from self._collect_states(state.states)

    @root
    def _get_attr(self, name):
        """Look ``name`` up on the root machine, then on the root's model."""
        try:
            return super().__getattribute__(name)
        except AttributeError:
            pass
        try:
            return getattr(self._model, name)
        except AttributeError:
            pass
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    def __getattr__(self, name):
        """If an attribute is not available on this machine, try to get it from the
        root machine and then from the root machine's model.

        A partially initialised instance (``_parent``/``_model`` not yet set, e.g.
        early in ``__init__`` or while being copied) raises AttributeError right
        away. Delegation itself needs those attributes, so it would otherwise
        recurse without end.
        """
        attrs = vars(self)
        if "_parent" not in attrs or "_model" not in attrs:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return self._get_attr(name)

    def _add_model(self):
        """Attach the private model if no model has been added yet."""
        if not self.models:
            self.add_model(self._model)

    @root
    def _start(self):
        """Start the state machine; call this at the beginning of execution.

        Attaches the model, runs the root machine's ``on_enter`` callbacks, then for
        each initial state (outermost first) runs the ``on_enter`` callbacks of the
        nested machine it represents (if any) followed by the state's own
        ``on_enter`` callbacks. Finally arms the leaf state's timeout.
        """
        self._add_model()
        self.callbacks(
            self._on_enter, EventData(None, Event("_start", self), self, self, [], {})
        )
        path = []
        for state in self._get_initial_states():
            event = EventData(state, Event("_start", self), self, self, [], {})
            path.append(state.name)
            try:
                machine = self._get_from_heirarchy(path)
            except ValueError:
                machine = None  # this initial state is not a nested machine
            if machine is not None:
                self.callbacks(machine._on_enter, event)
            self.callbacks(state.on_enter, event)
        self._start_timer()

    def _get_initial_states(self):
        """Return the list of initial states, from the outermost to the leaf."""
        states = [self.states[self.initial]]
        while len(states[-1].states):
            states.append(states[-1].states[states[-1].initial])
        return states

    @root
    def _start_timer(self):
        """Arm the timeout of the current state, if it defines a positive timeout.

        ``_start`` enters the initial states by running callbacks directly rather
        than through a transition. The timer is therefore started here (presumably
        mirroring what ``Timeout.enter`` does on a regular transition).
        """
        state = self.get_state(self.state)
        if state.timeout > 0:
            # Pass the State object and an Event, as ``_start`` does, so timeout
            # callbacks receive a well-formed EventData.
            event_data = EventData(state, Event("_start", self), self, self, (), {})
            timer = Timer(state.timeout, state._process_timeout, args=(event_data,))
            timer.daemon = True
            timer.start()
            # Keyed by the model's id, like the timers the state starts itself.
            state.runner[id(self._model)] = timer

    def _run_once(self):
        """One iteration of the main loop; by default just sleeps for one second."""
        sleep(1)

    def _loop(self):
        """Call ``_run_once`` forever."""
        while True:
            self._run_once()

    def _run(self):
        """Run the ``_start`` routine, then enter an infinite loop."""
        self._start()
        self._loop()

    def _rename_loggers(self):
        """Name each nested machine's logger as a child of its parent's logger."""
        if self._parent is not None:
            self._logger = self._parent._logger.getChild(self.state_name)
        for child in self._children.values():
            child._rename_loggers()
