from abc import ABC, abstractmethod
from .base import BaseMachine, root
from time import time, sleep
from functools import partial
from pigeon import Pigeon


class NotCollectedError(Exception):
    """Raised by ``BaseClient.get_collected`` when no message has arrived on
    the requested topic before the timeout expired."""


class AccessDeniedError(Exception):
    """Raised when a ``ClientMachine``'s ``client`` is accessed while that
    machine is not currently active."""


class NoClientError(Exception):
    """Raised when a ``ClientMachine`` needs its root machine's client but no
    client has been added yet."""


class SubscriptionData:
    """Unbound record of a ``@subscribe``-decorated callback.

    Instances live as class attributes and act as descriptors: reading the
    attribute through an instance yields a ``Subscription`` bound to that
    instance, while reading it through the class yields this object itself.

    Args:
        callback: The decorated callable. This may itself be another
            ``SubscriptionData`` when ``@subscribe`` decorators are stacked.
        topic: Topic the callback subscribes to.
        **options: Extra subscription options, stored verbatim in ``options``.
    """

    def __init__(self, callback, topic, **options):
        self._callback = callback
        self.topic = topic
        self.options = options

    def __get__(self, obj, objtype=None):
        # Follow the descriptor convention used by plain functions: class-level
        # access (obj is None) returns the descriptor itself rather than a
        # Subscription bound to ``None``, which would pass None as the machine.
        if obj is None:
            return self
        return Subscription(obj, self._callback, self.topic, **self.options)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped callback directly with the given arguments."""
        return self._callback(*args, **kwargs)


class Subscription:
    """A subscription callback bound to a particular machine instance.

    Produced by ``SubscriptionData.__get__``. Calling it invokes the wrapped
    callback with the bound machine prepended to the arguments. If the wrapped
    callback is itself a ``Subscription`` (stacked ``@subscribe`` decorators),
    the call is forwarded unchanged so the inner one supplies the machine.
    """

    def __init__(self, machine, callback, topic, **options):
        self._machine = machine
        self._callback = callback
        self.topic = topic
        self.options = options

    def __call__(self, *args, **kwargs):
        target = self._callback
        if not isinstance(target, Subscription):
            # Plain callback: behave like a bound method.
            args = (self._machine,) + args
        return target(*args, **kwargs)


def subscribe(topic, **options):
    """Decorator factory that marks a method as a handler for ``topic``.

    The decorated function is wrapped in a ``SubscriptionData`` carrying the
    topic and any extra ``options``.
    """

    def decorator(callback):
        return SubscriptionData(callback, topic, **options)

    return decorator


class BaseClient(ABC):
    """Abstract messaging client that remembers the latest message per topic.

    Subclasses implement ``subscribe`` and ``send`` and feed every incoming
    message into ``on_msg``. ``on_msg`` records the message and forwards it to
    the callback installed with ``set_callback``.
    """

    def __init__(self):
        self._callback = None
        # Most recent message received on each topic, keyed by topic.
        self._collected = {}

    def set_callback(self, callback):
        """Install the function called as ``callback(topic, msg, extra)`` for
        every incoming message."""
        self._callback = callback

    @abstractmethod
    def subscribe(self, topic):
        """Start receiving messages published on ``topic``."""

    @abstractmethod
    def send(self, topic, **data):
        """Publish a message built from ``data`` on ``topic``."""

    def on_msg(self, topic, msg, extra=None):
        """Record ``msg`` as the latest message on ``topic`` and forward it.

        The message is stored before the callback check, so it is collected
        even if no callback has been installed.
        """
        self._collected[topic] = msg
        assert self._callback is not None
        self._callback(topic, msg, extra)

    def get_collected(self, topic, timeout=0):
        """Return the most recent message received on a given topic.

        Args:
            topic (str): The topic to get the latest message from.
            timeout (float | None): Number of seconds to wait for a message on
                the topic. If None, return immediately, giving None when no
                message has been received on the topic. If 0, wait
                indefinitely.

        Raises:
            NotCollectedError: If no message arrived within a positive
                ``timeout``.
        """
        # Local import keeps the module's top-level imports unchanged.
        from time import monotonic

        if timeout is None:
            return self._collected.get(topic, None)
        # monotonic() is immune to wall-clock adjustments, so the timeout
        # measures real elapsed time.
        start = monotonic()
        while topic not in self._collected and (
            timeout == 0 or monotonic() - start <= timeout
        ):
            sleep(0.1)
        if topic not in self._collected:
            raise NotCollectedError(
                f"A message on topic {topic} has not been received after {timeout} seconds."
            )
        return self._collected[topic]

    def run_callback(self, subscription, topic, msg, extra):
        """Invoke ``subscription`` for a message; the base client passes only
        the message."""
        subscription(msg)


class PigeonClient(BaseClient):
    """``BaseClient`` implementation backed by a connected ``Pigeon`` client.

    Args:
        service: Service name passed to ``Pigeon``; "pigeon-transitions" is
            used when None.
        host: Host forwarded to ``Pigeon``.
        port: Port forwarded to ``Pigeon``.
        username: Username passed to ``Pigeon.connect``.
        password: Password passed to ``Pigeon.connect``.
    """

    def __init__(
        self,
        service=None,
        host="127.0.0.1",
        port=61616,
        username=None,
        password=None,
    ):
        super().__init__()
        service_name = "pigeon-transitions" if service is None else service
        self._client = Pigeon(service_name, host=host, port=port)
        self._client.connect(username=username, password=password)

    def send(self, topic, **data):
        """Publish ``data`` on ``topic`` through the Pigeon client."""
        self._client.send(topic, **data)

    def msg_callback(self, msg, topic, headers):
        """Handler registered with Pigeon; hands the message to ``on_msg``
        with the headers as ``extra``."""
        self.on_msg(topic, msg, extra=headers)

    def subscribe(self, topic):
        """Subscribe the Pigeon client to ``topic`` with ``msg_callback``."""
        self._client.subscribe(topic, self.msg_callback, True, True)

    def run_callback(self, subscription, topic, msg, extra):
        """Call ``subscription`` with the message, optionally followed by the
        topic and/or headers depending on its ``include_topic`` and
        ``include_headers`` options."""
        opts = subscription.options
        optional = (("include_topic", topic), ("include_headers", extra))
        extras = [value for flag, value in optional if opts.get(flag, False)]
        subscription(msg, *extras)


class ClientMachine(BaseMachine):
    """State machine that routes client messages to ``@subscribe`` handlers.

    Bound subscriptions are collected when the machine is constructed. A
    client is attached through ``add_client``. Each incoming message is
    dispatched only to the handlers of the machines returned by
    ``_get_current_machines``.
    """

    def __init__(self, *args, **kwargs):
        self._client = None
        super().__init__(*args, **kwargs)
        self._gather_subscriptions()

    def _gather_subscriptions(self):
        """Index every ``Subscription`` attribute of this machine by topic."""
        self._subscriptions = {}
        for name in dir(self):
            # Some attributes (e.g. the ``client`` property) raise on access,
            # so a failing probe is logged and skipped rather than aborting.
            try:
                attr = getattr(self, name)
            except Exception as e:
                self._logger.debug(
                    f"Was not able to check if class attribute {name} of {self} is a subscription callback due to error {e}"
                )
                continue
            if isinstance(attr, Subscription):
                self._process_subscription(attr)

    def _process_subscription(self, subscription):
        """Register ``subscription`` under its topic.

        For stacked ``@subscribe`` decorators, the inner ``SubscriptionData``
        is bound to this machine and registered recursively as well.
        """
        self._subscriptions.setdefault(subscription.topic, []).append(subscription)
        inner = subscription._callback
        if isinstance(inner, SubscriptionData):
            inner = inner.__get__(self, type(self))
            subscription._callback = inner
        if isinstance(inner, Subscription):
            self._process_subscription(inner)

    @property
    def client(self):
        """The client attached to the root machine.

        Raises:
            AccessDeniedError: If this machine has a ``state`` attribute and
                ``_current_machine()`` reports it is not currently active.
            NoClientError: If no client has been added to the root machine.
        """
        if hasattr(self, "state") and not self._current_machine():
            raise AccessDeniedError("Machine not currently active.")
        if self._root._client is None:
            raise NoClientError("No client added.")
        return self._root._client

    def get_collected(self, topic, timeout=0):
        """Return the latest message on ``topic`` from the root's client.

        Unlike ``client``, this does not check whether the machine is active.
        ``timeout`` is forwarded unchanged to ``BaseClient.get_collected``.

        Raises:
            NoClientError: If no client has been added to the root machine.
        """
        if self._root._client is None:
            raise NoClientError("No client added.")
        return self._root._client.get_collected(topic, timeout=timeout)

    @root
    def add_client(self, client):
        """Attach ``client`` and subscribe it to every topic in this machine tree.

        NOTE(review): ``@root`` presumably redirects the call to the root
        machine; confirm in ``.base``.
        """
        assert isinstance(client, BaseClient)
        self._client = client
        client.set_callback(self._message_callback)
        for topic in self._gather_topics():
            client.subscribe(topic)

    def _gather_topics(self):
        """Return the unique topics used by this machine and its descendants.

        Topics are returned in first-seen order: this machine's own topics
        come first, then each child's topics in turn.
        """
        # dict keys give order-preserving de-duplication without the
        # quadratic list-membership scan.
        topics = dict.fromkeys(self._subscriptions)
        for child in self._children.values():
            topics.update(dict.fromkeys(child._gather_topics()))
        return list(topics)

    @root
    def _message_callback(self, topic, msg, extra=None):
        """Dispatch a message to matching subscriptions on the active machines.

        An exception raised by a handler is logged as a warning, so it does
        not stop delivery to the remaining handlers.
        """
        for machine in self._get_current_machines():
            for subscription in machine._subscriptions.get(topic, []):
                try:
                    self._client.run_callback(subscription, topic, msg, extra)
                except Exception as e:
                    self._logger.warning(
                        f"Callback for a message on topic '{topic}' with data '{msg}' resulted in an exception:\n{self._get_traceback(e)}"
                    )
