from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import contextlib
import jwt
import logging
import unittest

import gevent
import gevent.monkey

from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.Thrift import TApplicationException
from thrift.transport.TTransport import TMemoryBuffer, TTransportException

from baseplate import config
from baseplate.core import (
    Baseplate,
    BaseplateObserver,
    EdgeRequestContext,
    EdgeRequestContextFactory,
    NoAuthenticationError,
    ServerSpan,
    ServerSpanObserver,
    SpanObserver,
)
from baseplate.context.thrift import ThriftContextFactory
from baseplate.file_watcher import FileWatcher
from baseplate.integration.thrift import baseplateify_processor, RequestContext
from baseplate.secrets import store
from baseplate.server import make_listener
from baseplate.server.thrift import make_server
from baseplate.thrift_pool import ThriftConnectionPool

from .test_thrift import TestService, ttypes
from .. import (
    mock,
    AUTH_TOKEN_PUBLIC_KEY,
    SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH,
)


# The edge-context tests skip themselves when the optional ``cryptography``
# package is unavailable; probe for it once at import time.  Only a failed
# import means "not installed" -- any other error should surface.
try:
    import cryptography
except ImportError:
    cryptography_installed = False
else:
    cryptography_installed = True
    # We only needed to know it is importable; don't leave it in the namespace.
    del cryptography


# ``reload`` lives in importlib on Python 3 and is a builtin on Python 2.
try:
    from importlib import reload
except ImportError:
    pass


def make_edge_context_factory():
    """Return an EdgeRequestContextFactory backed by a mocked secrets store.

    The store's file watcher is swapped for a mock whose ``get_data`` returns
    the versioned authentication public key plus dummy vault settings, so
    secret lookups are served from that in-memory data.
    """
    secrets_payload = {
        "secrets": {
            "secret/authentication/public-key": {
                "type": "versioned",
                "current": AUTH_TOKEN_PUBLIC_KEY,
            },
        },
        "vault": {
            "token": "test",
            "url": "http://vault.example.com:8200/",
        },
    }
    fake_watcher = mock.Mock(spec=FileWatcher)
    fake_watcher.get_data.return_value = secrets_payload

    secrets_store = store.SecretsStore("/secrets")
    # Replace the real watcher so the store never depends on a file on disk.
    secrets_store._filewatcher = fake_watcher
    return EdgeRequestContextFactory(secrets_store)


@contextlib.contextmanager
def serve_thrift(handler, server_span_observer=None):
    """Serve ``handler`` as a baseplate-instrumented TestService on localhost.

    The server binds to an OS-chosen port and runs in a background greenlet
    until the with-block exits.  The yielded server object gets an
    ``endpoint`` attribute describing where it is listening.

    :param handler: a ``TestService.Iface`` implementation.
    :param server_span_observer: optional observer registered on every server
        span baseplate creates for incoming requests.
    """
    baseplate = Baseplate()
    if server_span_observer:
        class TestBaseplateObserver(BaseplateObserver):
            def on_server_span_created(self, context, server_span):
                server_span.register(server_span_observer)
        baseplate.register(TestBaseplateObserver())

    # Wrap the generated processor so each request gets baseplate context.
    processor = baseplateify_processor(
        TestService.Processor(handler),
        mock.Mock(spec=logging.Logger),
        baseplate,
        make_edge_context_factory(),
    )

    bind_endpoint = config.Endpoint("127.0.0.1:0")
    listener = make_listener(bind_endpoint)
    server = make_server({"max_concurrency": "100"}, listener, processor)

    # Port 0 lets the OS pick; read back the address actually bound.
    server.endpoint = config.EndpointConfiguration(
        family=bind_endpoint.family,
        address=listener.getsockname(),
    )

    serving = gevent.spawn(server.serve_forever)
    try:
        yield server
    finally:
        serving.kill()


@contextlib.contextmanager
def raw_thrift_client(endpoint):
    """Yield an uninstrumented TestService client connected to ``endpoint``.

    The client uses one connection borrowed from a fresh pool; the connection
    is handed back to the pool when the with-block exits.
    """
    connection_pool = ThriftConnectionPool(endpoint)
    with connection_pool.connection() as protocol:
        client = TestService.Client(protocol)
        yield client


@contextlib.contextmanager
def baseplate_thrift_client(endpoint, client_span_observer=None):
    """Yield a RequestContext wired up like a real baseplate service's.

    The context carries the test edge context (built from
    ``SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH``).  It also carries an
    ``example_service`` client created for a fixed parent ServerSpan.

    :param endpoint: where the TestService server is listening.
    :param client_span_observer: optional observer registered on each child
        span the parent server span creates.
    """
    pool = ThriftConnectionPool(endpoint)

    request_context = RequestContext()
    parent_span = ServerSpan(
        trace_id=1234,
        parent_id=2345,
        span_id=3456,
        flags=4567,
        sampled=1,
        name="example_service.example",
        context=request_context,
    )

    if client_span_observer:
        class TestServerSpanObserver(ServerSpanObserver):
            def on_child_span_created(self, span):
                span.register(client_span_observer)
        parent_span.register(TestServerSpanObserver())

    edge_context = make_edge_context_factory().from_upstream(
        SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)
    edge_context.attach_context(request_context)

    client_factory = ThriftContextFactory(pool, TestService.Client)
    request_context.example_service = client_factory.make_object_for_context(
        "example_service", parent_span)

    yield request_context


class GeventPatchedTestCase(unittest.TestCase):
    """TestCase base that monkey-patches ``socket`` with gevent per test.

    ``setUp`` applies gevent's socket patch.  ``tearDown`` reloads the stdlib
    ``socket`` module, presumably to restore its original definitions for
    whatever runs next.
    """

    def setUp(self):
        gevent.monkey.patch_socket()

    def tearDown(self):
        import socket as stdlib_socket
        reload(stdlib_socket)


class ThriftTraceHeaderTests(GeventPatchedTestCase):
    """Check that the server builds its trace span from incoming Thrift headers."""

    # Values the client sends in the header-propagation tests.
    TRACE_ID = 1234
    PARENT_ID = 2345
    SPAN_ID = 3456
    FLAGS = 4567
    SAMPLED = 1

    @staticmethod
    def _make_handler():
        """Return a handler that records ``context.trace`` for the request it serves."""

        class Handler(TestService.Iface):
            def __init__(self):
                self.server_span = None

            def example(self, context):
                self.server_span = context.trace
                return True

        return Handler()

    def _trace_headers(self, names):
        """Pair header names positionally with the test trace values.

        ``names`` holds the header names for trace id, parent id, span id,
        flags and sampled, in that order.  Passing fewer names sends fewer
        headers; the trailing values are dropped.
        """
        values = (self.TRACE_ID, self.PARENT_ID, self.SPAN_ID, self.FLAGS, self.SAMPLED)
        return list(zip(names, values))

    def _call_with_headers(self, headers):
        """Call ``example()`` once with the given headers set on the transport.

        ``headers`` is a sequence of ``(name, value)`` pairs.  Each value is
        sent as its encoded ``str()`` form.  Returns
        ``(handler, client_result)``.
        """
        handler = self._make_handler()
        with serve_thrift(handler) as server:
            with raw_thrift_client(server.endpoint) as client:
                transport = client._oprot.trans
                for name, value in headers:
                    transport.set_header(name, str(value).encode())
                client_result = client.example()
        return handler, client_result

    def _assert_full_trace(self, handler, client_result):
        """Assert the handler saw a span carrying all five trace values."""
        span = handler.server_span
        self.assertIsNotNone(span)
        self.assertEqual(span.trace_id, self.TRACE_ID)
        self.assertEqual(span.parent_id, self.PARENT_ID)
        self.assertEqual(span.id, self.SPAN_ID)
        self.assertEqual(span.flags, self.FLAGS)
        self.assertEqual(span.sampled, self.SAMPLED)
        self.assertTrue(client_result)

    def test_no_headers(self):
        """We should accept requests without headers and generate a trace."""
        handler, client_result = self._call_with_headers([])

        self.assertIsNotNone(handler.server_span)
        self.assertGreaterEqual(handler.server_span.id, 0)
        self.assertTrue(client_result)

    def test_header_propagation(self):
        """If the client sends headers, we should set the trace up accordingly."""
        headers = self._trace_headers(
            (b"Trace", b"Parent", b"Span", b"Flags", b"Sampled"))
        handler, client_result = self._call_with_headers(headers)
        self._assert_full_trace(handler, client_result)

    def test_b3_header_propagation(self):
        """If the client sends B3-style headers, we should accept them."""
        headers = self._trace_headers(
            (b"B3-TraceId", b"B3-ParentSpanId", b"B3-SpanId", b"B3-Flags", b"B3-Sampled"))
        handler, client_result = self._call_with_headers(headers)
        self._assert_full_trace(handler, client_result)

    def test_b3_header_propagation_case_insensitive(self):
        """Be case-insensitive to Trace headers."""
        headers = self._trace_headers(
            (b"b3-traceid", b"b3-parentspanid", b"b3-spanid", b"b3-flags", b"b3-sampled"))
        handler, client_result = self._call_with_headers(headers)
        self._assert_full_trace(handler, client_result)

    def test_optional_headers_optional(self):
        """Test that we accept traces from clients that don't include all headers."""
        headers = self._trace_headers((b"Trace", b"Parent", b"Span"))
        handler, client_result = self._call_with_headers(headers)

        span = handler.server_span
        self.assertIsNotNone(span)
        self.assertEqual(span.trace_id, self.TRACE_ID)
        self.assertEqual(span.parent_id, self.PARENT_ID)
        self.assertEqual(span.id, self.SPAN_ID)
        # Flags and Sampled were not sent; the span must use these defaults.
        self.assertIsNone(span.flags)
        self.assertEqual(span.sampled, False)
        self.assertTrue(client_result)



class ThriftEdgeRequestHeaderTests(GeventPatchedTestCase):
    """Check that the server parses the Edge-Request header into ``request_context``."""

    @staticmethod
    def _make_handler():
        """Return a handler that records ``context.request_context`` for its request."""

        class Handler(TestService.Iface):
            def __init__(self):
                self.request_context = None

            def example(self, context):
                self.request_context = context.request_context
                return True

        return Handler()

    def _check_edge_request_header(self, header_name):
        """Send the serialized test edge context under ``header_name`` and verify it.

        The handler must see the user, OAuth client and session encoded in
        ``SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH``.
        """
        handler = self._make_handler()
        with serve_thrift(handler) as server:
            with raw_thrift_client(server.endpoint) as client:
                transport = client._oprot.trans
                transport.set_header(header_name, SERIALIZED_EDGECONTEXT_WITH_VALID_AUTH)
                client_result = client.example()

        request_context = handler.request_context
        self.assertIsNotNone(request_context)
        self.assertEqual(request_context.user.id, "t2_example")
        self.assertEqual(request_context.user.roles, set())
        self.assertEqual(request_context.user.is_logged_in, True)
        self.assertEqual(request_context.user.loid, "t2_deadbeef")
        self.assertEqual(request_context.user.cookie_created_ms, 100000)
        self.assertEqual(request_context.oauth_client.id, None)
        self.assertFalse(request_context.oauth_client.is_type("third_party"))
        self.assertEqual(request_context.session.id, "beefdead")
        self.assertTrue(client_result)

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_edge_request_context(self):
        """If the client sends an edge-request header we should parse it."""
        self._check_edge_request_header(b"Edge-Request")

    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_edge_request_context_case_insensitive(self):
        """We should be case-insensitive to edge-request headers."""
        self._check_edge_request_header(b"edge-request")


class ThriftServerSpanTests(GeventPatchedTestCase):
    """Check the lifecycle events the server span reports to its observers."""

    def _run_example(self, handler, expected_exception=None):
        """Call ``example()`` once via a raw client; return the span observer mock.

        When ``expected_exception`` is given, the call must raise it.
        """
        observer = mock.Mock(spec=ServerSpanObserver)
        with serve_thrift(handler, observer) as server:
            with raw_thrift_client(server.endpoint) as client:
                if expected_exception is None:
                    client.example()
                else:
                    with self.assertRaises(expected_exception):
                        client.example()
        return observer

    def test_server_span_starts_and_stops(self):
        """The server span should start/stop appropriately."""
        class Handler(TestService.Iface):
            def example(self, context):
                return True

        observer = self._run_example(Handler())

        observer.on_start.assert_called_once_with()
        observer.on_finish.assert_called_once_with(None)

    def test_expected_exception_not_passed_to_server_span_finish(self):
        """If the server returns an expected exception, don't count it as failure."""
        class Handler(TestService.Iface):
            def example(self, context):
                raise TestService.ExpectedException()

        observer = self._run_example(Handler(), TestService.ExpectedException)

        observer.on_start.assert_called_once_with()
        observer.on_finish.assert_called_once_with(None)

    def test_unexpected_exception_passed_to_server_span_finish(self):
        """If the server returns an unexpected exception, mark a failure."""
        class UnexpectedException(Exception):
            pass

        class Handler(TestService.Iface):
            def example(self, context):
                raise UnexpectedException

        # Undeclared exceptions reach the client as a TApplicationException.
        observer = self._run_example(Handler(), TApplicationException)

        observer.on_start.assert_called_once_with()
        self.assertEqual(observer.on_finish.call_count, 1)
        finish_args, _finish_kwargs = observer.on_finish.call_args
        _exc_type, exc_value, _traceback = finish_args[0]
        self.assertIsInstance(exc_value, UnexpectedException)


class ThriftClientSpanTests(GeventPatchedTestCase):
    """Check the lifecycle events the client span reports to its observers."""

    def _run_example(self, handler, expected_exception=None):
        """Call ``example()`` via the baseplate client; return the span observer mock.

        When ``expected_exception`` is given, the call must raise it.
        """
        observer = mock.Mock(spec=SpanObserver)
        with serve_thrift(handler) as server:
            with baseplate_thrift_client(server.endpoint, observer) as context:
                if expected_exception is None:
                    context.example_service.example()
                else:
                    with self.assertRaises(expected_exception):
                        context.example_service.example()
        return observer

    def test_client_span_starts_and_stops(self):
        """The client span should start/stop appropriately."""
        class Handler(TestService.Iface):
            def example(self, context):
                return True

        observer = self._run_example(Handler())

        observer.on_start.assert_called_once_with()
        observer.on_finish.assert_called_once_with(None)

    def test_expected_exception_not_passed_to_client_span_finish(self):
        """If the server returns an expected exception, don't count it as failure."""
        class Handler(TestService.Iface):
            def example(self, context):
                raise TestService.ExpectedException()

        observer = self._run_example(Handler(), TestService.ExpectedException)

        observer.on_start.assert_called_once_with()
        observer.on_finish.assert_called_once_with(None)

    def test_unexpected_exception_passed_to_client_span_finish(self):
        """If the server returns an unexpected exception, mark a failure."""
        class UnexpectedException(Exception):
            pass

        class Handler(TestService.Iface):
            def example(self, context):
                raise UnexpectedException

        observer = self._run_example(Handler(), TApplicationException)

        observer.on_start.assert_called_once_with()
        self.assertEqual(observer.on_finish.call_count, 1)
        finish_args, _finish_kwargs = observer.on_finish.call_args
        # The client only sees the TApplicationException the server sent back.
        _exc_type, exc_value, _traceback = finish_args[0]
        self.assertIsInstance(exc_value, TApplicationException)


class ThriftEndToEndTests(GeventPatchedTestCase):
    """Exercise a baseplate client talking to a baseplate server."""

    # Gate on the same flag as the other edge-context tests instead of turning
    # any jwt InvalidAlgorithmError into a skip, which would also hide genuine
    # token-validation failures when cryptography *is* installed.
    @unittest.skipIf(not cryptography_installed, "cryptography not installed")
    def test_end_to_end(self):
        """The edge context attached by the client should reach the server handler."""
        class Handler(TestService.Iface):
            def __init__(self):
                self.request_context = None

            def example(self, context):
                self.request_context = context.request_context
                return True
        handler = Handler()

        span_observer = mock.Mock(spec=SpanObserver)
        with serve_thrift(handler) as server:
            with baseplate_thrift_client(server.endpoint, span_observer) as context:
                client_result = context.example_service.example()

        self.assertTrue(client_result)
        request_context = handler.request_context
        self.assertIsNotNone(request_context)
        self.assertEqual(request_context.user.id, "t2_example")
        self.assertEqual(request_context.user.roles, set())
        self.assertEqual(request_context.user.is_logged_in, True)
        self.assertEqual(request_context.user.loid, "t2_deadbeef")
        self.assertEqual(request_context.user.cookie_created_ms, 100000)
        self.assertEqual(request_context.oauth_client.id, None)
        self.assertFalse(request_context.oauth_client.is_type("third_party"))
        self.assertEqual(request_context.session.id, "beefdead")
