import pytest
from pigeon_transitions.client import ClientMachine as Machine
from pigeon_transitions.client import (
    BaseClient,
    PigeonClient,
    NotCollectedError,
    NoClientError,
    AccessDeniedError,
    subscribe,
    SubscriptionData,
)
from time import time, sleep
from threading import Thread


@pytest.fixture
def client(mocker):
    """Provide a ``BaseClient`` subclass whose transport methods are mocks.

    ``subscribe`` and ``send`` forward their arguments to the
    ``mock_subscribe`` / ``mock_send`` ``MagicMock`` attributes, so tests can
    assert on how the client was used.
    """

    class MockClient(BaseClient):
        def __init__(self, *init_args, **init_kwargs):
            super().__init__(*init_args, **init_kwargs)
            # Recorders for the two transport operations.
            self.mock_send = mocker.MagicMock()
            self.mock_subscribe = mocker.MagicMock()

        def send(self, *args, **kwargs):
            return self.mock_send(*args, **kwargs)

        def subscribe(self, *args, **kwargs):
            return self.mock_subscribe(*args, **kwargs)

    mock_client = MockClient()
    return mock_client


def test_add_client(client, mocker):
    """``add_client`` stores the client, sets its callback and subscribes to topics."""
    topics = ("topic1", "topic2")
    expected_calls = [mocker.call(topic) for topic in topics]

    machine = Machine()
    # Stub topic discovery so the expected subscriptions are known up front.
    machine._gather_topics = mocker.MagicMock(return_value=list(topics))

    machine.add_client(client)

    assert machine._client == client
    assert client._callback == machine._message_callback
    client.mock_subscribe.assert_has_calls(expected_calls)


def test_machine_get_collected(client, mocker):
    """``Machine.get_collected`` forwards its arguments to the attached client."""
    fsm = Machine()
    fsm.add_client(client)
    # Swap in the spy only after the client has been attached.
    spy = mocker.MagicMock()
    client.get_collected = spy

    fsm.get_collected("topic", timeout=102)

    spy.assert_called_with("topic", timeout=102)


def test_get_collected_timeout_missing(client):
    """``get_collected`` raises ``NotCollectedError`` after roughly ``timeout`` seconds.

    The topic is never collected, so the call is expected to give up after the
    requested 0.2 s (checked to within 10 ms).
    """
    # Local import: measure the interval with a monotonic, high-resolution
    # clock so a wall-clock adjustment during the test cannot skew the result.
    from time import perf_counter

    start = perf_counter()
    with pytest.raises(NotCollectedError):
        client.get_collected("missing", timeout=0.2)
    elapsed = perf_counter() - start

    assert abs(elapsed - 0.2) < 0.01


@pytest.mark.parametrize("timeout", [0, 0.5])
def test_get_collected_timeout(client, timeout):
    """``get_collected`` returns a message that arrives while it is waiting.

    A background thread stores the message after 0.25 s; the call is expected
    to return it for both parametrized ``timeout`` values.
    """

    def collect():
        sleep(0.25)
        client._collected["topic"] = "a message!"

    thread = Thread(target=collect)
    thread.start()

    try:
        msg = client.get_collected("topic", timeout=timeout)
    finally:
        # Always reap the helper thread so it cannot outlive this test, even
        # when get_collected raises.
        thread.join()

    assert msg == "a message!"


def test_client_property(mocker):
    """Check which machines in a nested hierarchy may read the ``client`` property.

    Three machines are nested (``machine`` > ``child1`` > ``child2``). The test
    expects ``NoClientError`` from the root while no client is set,
    ``AccessDeniedError`` from a child whose ``_current_machine()`` is falsy,
    and the root's ``_client`` from every machine whose ``_current_machine()``
    is truthy.
    """
    # Innermost machine: a single state, nested under child1's "four".
    child2 = Machine(
        states=[
            "five",
        ],
        initial="five",
    )

    # Middle machine: "three" -> "four" on "next"; "four" hosts child2.
    child1 = Machine(
        states=[
            "three",
            {
                "name": "four",
                "children": child2,
            },
        ],
        initial="three",
        transitions=[
            {
                "source": "three",
                "dest": "four",
                "trigger": "next",
            },
        ],
    )

    # Root machine: "one" -> "two" on "next"; "two" hosts child1.
    machine = Machine(
        states=[
            "one",
            {
                "name": "two",
                "children": child1,
            },
        ],
        initial="one",
        transitions=[
            {
                "source": "one",
                "dest": "two",
                "trigger": "next",
            },
        ],
    )
    machine._start()

    assert machine.state == "one"

    # No client has been set yet, so reading the property must fail.
    with pytest.raises(NoClientError):
        machine.client

    # Any object works as a stand-in; the test only compares it for equality.
    machine._client = "the_client"

    assert machine.client == machine._client
    assert not child1._current_machine()

    # In state "one" neither child may access the client.
    with pytest.raises(AccessDeniedError):
        child1.client

    with pytest.raises(AccessDeniedError):
        child2.client

    assert machine.next()

    # State "two_three": child1 gains access, child2 is still denied.
    assert machine.state == "two_three"
    assert machine.client == machine._client
    assert child1._current_machine()
    assert child1.client == machine._client
    assert not child2._current_machine()
    with pytest.raises(AccessDeniedError):
        child2.client

    assert machine.next()

    # State "two_four_five": all three machines can read the root's client.
    assert machine.state == "two_four_five"
    assert machine.client == machine._client
    assert child1._current_machine()
    assert child1.client == machine._client
    assert child2._current_machine()
    assert child2.client == machine._client


def test_subscribe(client, mocker):
    """``@subscribe`` registers one subscription that forwards to the method."""

    class TestMachine(Machine):
        def __init__(self, *args, **kwargs):
            # The mock is created before the base initialiser runs.
            self.mock_callback = mocker.MagicMock()
            super().__init__(*args, **kwargs)

        @subscribe("test_topic", some_option=True)
        def callback(self, msg):
            self.mock_callback(msg)

    machine = TestMachine()
    assert len(machine._subscriptions) == 1

    topic_subs = machine._subscriptions["test_topic"]
    assert len(topic_subs) == 1

    entry = topic_subs[0]
    assert entry.topic == "test_topic"
    assert entry.options == {"some_option": True}

    # Calling the subscription object must reach the decorated method.
    entry("test_msg")
    machine.mock_callback.assert_called_once_with("test_msg")


def test_subscribe_nested(client, mocker):
    """Stacked ``@subscribe`` decorators register one subscription per topic."""

    class TestMachine(Machine):
        def __init__(self, *args, **kwargs):
            # The mock is created before the base initialiser runs.
            self.mock_callback = mocker.MagicMock()
            super().__init__(*args, **kwargs)

        @subscribe("test_topic", some_option=True)
        @subscribe("another_topic")
        def callback(self, msg):
            self.mock_callback(msg)

    machine = TestMachine()
    assert len(machine._subscriptions) == 2

    # Outer decorator: carries the extra option.
    outer_subs = machine._subscriptions["test_topic"]
    assert len(outer_subs) == 1
    outer = outer_subs[0]
    assert outer.topic == "test_topic"
    assert outer.options == {"some_option": True}
    outer("test_msg")
    machine.mock_callback.assert_called_once_with("test_msg")

    # Inner decorator: no options were given.
    inner_subs = machine._subscriptions["another_topic"]
    assert len(inner_subs) == 1
    inner = inner_subs[0]
    assert inner.topic == "another_topic"
    assert inner.options == {}
    inner("test_msg_2")
    machine.mock_callback.assert_called_with("test_msg_2")


def test_subscribe_call(client):
    """A method decorated with ``@subscribe`` is still callable on the instance."""

    class TestMachine(Machine):
        prop = 1

        @subscribe("test_topic")
        def callback(self, msg):
            return self.prop

    machine = TestMachine()
    result = machine.callback(None)

    assert result == 1


def test_subscribe_nested_call(client):
    """A method with stacked ``@subscribe`` decorators is still callable directly."""

    class TestMachine(Machine):
        prop = "test"

        @subscribe("test_topic")
        @subscribe("something")
        def callback(self, msg):
            return self.prop

    machine = TestMachine()
    result = machine.callback(None)

    assert result == "test"


def test_on_msg_error(client):
    """``on_msg`` raises ``AssertionError`` on a client that never had ``set_callback`` called."""
    topic, payload = "test_topic", "test_msg"

    with pytest.raises(AssertionError):
        client.on_msg(topic, payload)


def test_on_msg(client, mocker):
    """``on_msg`` records the data under its topic and forwards every argument."""
    handler = mocker.MagicMock()
    client.set_callback(handler)

    message = ("test_topic", "test_data", "extra_stuff")
    client.on_msg(*message)

    assert client._collected == {"test_topic": "test_data"}
    handler.assert_called_once_with(*message)


@pytest.fixture
def pigeon(mocker):
    """Patch ``pigeon_transitions.client.Pigeon`` and return the resulting mock."""
    patched = mocker.patch("pigeon_transitions.client.Pigeon")
    return patched


@pytest.mark.parametrize(
    "service, service_name",
    [(None, "pigeon-transitions"), ("test_service", "test_service")],
)
def test_pigeon_connect(pigeon, service, service_name):
    """Constructing ``PigeonClient`` creates and connects the patched ``Pigeon``.

    With ``service=None`` the expected service name is
    ``"pigeon-transitions"``; otherwise the given name is expected.
    """
    # Only the constructor's side effects are asserted, so the instance is
    # deliberately not bound to a name.
    PigeonClient(
        service=service,
        host="my.server.com",
        port=61613,
        username="user",
        password="pass",
    )

    pigeon.assert_called_with(service_name, host="my.server.com", port=61613)
    pigeon().connect.assert_called_with(username="user", password="pass")


def test_pigeon_send(pigeon):
    """``PigeonClient.send`` forwards topic and data to the patched ``Pigeon`` instance."""
    pigeon_client = PigeonClient()

    pigeon_client.send("test_topic", some="data")

    # ``pigeon.return_value`` is the object a call to the patched class returns.
    connection = pigeon.return_value
    connection.send.assert_called_with("test_topic", some="data")


def test_pigeon_msg_callback(pigeon, mocker):
    """``msg_callback(msg, topic, headers)`` calls ``on_msg(topic, msg, extra=headers)``."""
    pigeon_client = PigeonClient()
    spy = mocker.MagicMock()
    pigeon_client.on_msg = spy

    pigeon_client.msg_callback("test_msg", "test_topic", "headers")

    spy.assert_called_with("test_topic", "test_msg", extra="headers")


def test_pigeon_subscribe(pigeon):
    """``PigeonClient.subscribe`` registers ``msg_callback`` on the patched ``Pigeon`` instance."""
    pigeon_client = PigeonClient()

    pigeon_client.subscribe("test_topic")

    # ``pigeon.return_value`` is the object a call to the patched class returns.
    pigeon.return_value.subscribe.assert_called_with(
        "test_topic",
        pigeon_client.msg_callback,
        True,
        True,
    )


@pytest.mark.parametrize("include_topic", [False, True])
@pytest.mark.parametrize("include_headers", [False, True])
def test_pigeon_run_callback(pigeon, include_topic, include_headers, mocker):
    """``run_callback`` passes the topic and headers only when the subscription asks for them."""
    pigeon_client = PigeonClient()
    handler = mocker.MagicMock()
    sub = SubscriptionData(
        handler,
        "test_topic",
        include_topic=include_topic,
        include_headers=include_headers,
    )

    pigeon_client.run_callback(sub, "test_topic", "test_msg", "some_headers")

    # The message always comes first; optional extras follow in this order.
    optional = (
        (include_topic, "test_topic"),
        (include_headers, "some_headers"),
    )
    expected = ["test_msg"] + [value for wanted, value in optional if wanted]
    handler.assert_called_with(*expected)


def test_pigeon_integration(pigeon, mocker):
    """Drive messages through ``PigeonClient`` into a two-level machine hierarchy.

    Both the parent and the child machine subscribe ``callback`` to
    ``test_topic`` and ``callback2`` to ``test_topic`` (with
    ``include_topic=True``) and ``test_topic_2``. The test expects only the
    parent's callbacks to run in state ``one``, and both machines' callbacks
    to run once the state is ``two_three``.
    """

    class TestMachine(Machine):
        def __init__(self, *args, **kwargs):
            # One mock per decorated method, so calls can be told apart.
            self.mock_callback = mocker.MagicMock()
            self.mock_callback2 = mocker.MagicMock()
            super().__init__(*args, **kwargs)

        @subscribe("test_topic")
        def callback(self, msg):
            self.mock_callback(msg)

        # The topic argument is only requested for "test_topic"; for
        # "test_topic_2" the default of None is expected to be used.
        @subscribe("test_topic", include_topic=True)
        @subscribe("test_topic_2")
        def callback2(self, msg, topic=None):
            self.mock_callback2(msg, topic)

    # Child machine, nested under the parent's state "two".
    child = TestMachine(
        states=[
            "three",
        ],
        initial="three",
    )

    machine = TestMachine(
        states=[
            "one",
            {
                "name": "two",
                "children": child,
            },
        ],
        initial="one",
        transitions=[
            {
                "source": "one",
                "dest": "two",
                "trigger": "next",
            },
        ],
    )

    client = PigeonClient()
    machine.add_client(client)
    machine._start()

    assert machine.state == "one"

    # State "one", topic "test_topic": only the parent's callbacks run.
    client.msg_callback("test_data", "test_topic", "headers")

    machine.mock_callback.assert_called_with("test_data")
    machine.mock_callback2.assert_called_with("test_data", "test_topic")
    child.mock_callback.assert_not_called()
    child.mock_callback2.assert_not_called()

    # Clear recorded calls so the next step's assertions start clean.
    machine.mock_callback.reset_mock()
    machine.mock_callback2.reset_mock()
    child.mock_callback.reset_mock()
    child.mock_callback2.reset_mock()

    # State "one", topic "test_topic_2": only the parent's callback2 runs,
    # without a topic argument.
    client.msg_callback("test_data", "test_topic_2", "headers")

    machine.mock_callback.assert_not_called()
    machine.mock_callback2.assert_called_with("test_data", None)
    child.mock_callback.assert_not_called()
    child.mock_callback2.assert_not_called()

    machine.mock_callback.reset_mock()
    machine.mock_callback2.reset_mock()
    child.mock_callback.reset_mock()
    child.mock_callback2.reset_mock()

    assert machine.next()
    assert machine.state == "two_three"

    # State "two_three", topic "test_topic": parent and child callbacks run.
    client.msg_callback("test_data", "test_topic", "headers")

    machine.mock_callback.assert_called_with("test_data")
    machine.mock_callback2.assert_called_with("test_data", "test_topic")
    child.mock_callback.assert_called_with("test_data")
    child.mock_callback2.assert_called_with("test_data", "test_topic")

    machine.mock_callback.reset_mock()
    machine.mock_callback2.reset_mock()
    child.mock_callback.reset_mock()
    child.mock_callback2.reset_mock()

    # State "two_three", topic "test_topic_2": callback2 runs on both machines.
    client.msg_callback("test_data", "test_topic_2", "headers")

    machine.mock_callback.assert_not_called()
    machine.mock_callback2.assert_called_with("test_data", None)
    child.mock_callback.assert_not_called()
    child.mock_callback2.assert_called_with("test_data", None)


def test_message_callback(client, mocker):
    """Check which machines in a three-level hierarchy receive a message.

    ``machine``, ``child1`` and ``child2`` all subscribe to ``test_topic``.
    After each transition the same message is delivered via
    ``client.on_msg``, and the test expects one more level of the hierarchy
    to have its callback invoked: the root only in ``one``, root and
    ``child1`` in ``two_three``, and all three in ``two_four_five``.
    """

    class TestMachine(Machine):
        def __init__(self, *args, **kwargs):
            # Per-instance mock, so each machine's deliveries are tracked separately.
            self.mock_callback = mocker.MagicMock()
            super().__init__(*args, **kwargs)

        @subscribe("test_topic")
        def callback(self, msg):
            self.mock_callback(msg)

    # Innermost machine, nested under child1's state "four".
    child2 = TestMachine(
        states=[
            "five",
        ],
        initial="five",
    )

    # Middle machine, nested under the root's state "two".
    child1 = TestMachine(
        states=[
            "three",
            {
                "name": "four",
                "children": child2,
            },
        ],
        initial="three",
        transitions=[
            {
                "source": "three",
                "dest": "four",
                "trigger": "next",
            },
        ],
    )

    machine = TestMachine(
        states=[
            "one",
            {
                "name": "two",
                "children": child1,
            },
        ],
        initial="one",
        transitions=[
            {
                "source": "one",
                "dest": "two",
                "trigger": "next",
            },
        ],
    )

    machine.add_client(client)
    machine._start()

    assert machine.state == "one"

    # State "one": only the root machine's callback runs.
    client.on_msg("test_topic", "test_data")

    machine.mock_callback.assert_called_with("test_data")
    child1.mock_callback.assert_not_called()
    child2.mock_callback.assert_not_called()

    # Clear recorded calls so the next step's assertions start clean.
    machine.mock_callback.reset_mock()
    child1.mock_callback.reset_mock()
    child2.mock_callback.reset_mock()

    assert machine.next()
    assert machine.state == "two_three"

    # State "two_three": root and child1 receive the message, child2 does not.
    client.on_msg("test_topic", "test_data")

    machine.mock_callback.assert_called_with("test_data")
    child1.mock_callback.assert_called_with("test_data")
    child2.mock_callback.assert_not_called()

    machine.mock_callback.reset_mock()
    child1.mock_callback.reset_mock()
    child2.mock_callback.reset_mock()

    assert machine.next()
    assert machine.state == "two_four_five"

    # State "two_four_five": all three machines receive the message.
    client.on_msg("test_topic", "test_data")

    machine.mock_callback.assert_called_with("test_data")
    child1.mock_callback.assert_called_with("test_data")
    child2.mock_callback.assert_called_with("test_data")


def test_message_callback_error(client, mocker):
    """An exception raised inside a subscriber is logged as a warning, not propagated."""

    class TestMachine(Machine):
        def __init__(self, *args, **kwargs):
            # Every call to this mock raises, simulating a failing subscriber.
            self.mock_callback = mocker.MagicMock(side_effect=Exception)
            super().__init__(*args, **kwargs)

        @subscribe("test_topic")
        def callback(self, msg):
            self.mock_callback(msg)

    machine = TestMachine(states=["one"], initial="one")

    # Replace the logger before attaching the client so the warning is captured.
    fake_logger = mocker.MagicMock()
    machine._logger = fake_logger
    machine.add_client(client)
    machine._start()
    assert machine.state == "one"

    # Delivery must not raise here, or the test itself would fail.
    client.on_msg("test_topic", "test_data")

    machine._logger.warning.assert_called()
