from __future__ import annotations

import socket
import ssl
from pathlib import Path
from struct import unpack
from typing import Callable
from unittest.mock import call, patch

import pytest

from dissect.target.loader import open as loader_open
from dissect.target.loaders.remote import RemoteLoader, RemoteStream, RemoteStreamConnection
from dissect.target.target import Target


@pytest.mark.parametrize(
    "opener",
    [
        pytest.param(Target.open, id="target-open"),
        pytest.param(lambda uri: next(Target.open_all([uri])), id="target-open-all"),
    ],
)
def test_target_open(opener: Callable[[str | Path], Target]) -> None:
    """Opening a ``remote://`` URI must produce a ``Target`` whose loader is a ``RemoteLoader``.

    Both ``Target.open`` and ``Target.open_all`` are exercised via the ``opener`` parameter.
    ``ssl.SSLContext``, ``RemoteLoader.map`` and ``Target.apply`` are mocked out to keep the
    test self-contained.
    """
    uri = "remote://127.0.0.1:1337"

    ssl_patch = patch.object(ssl, "SSLContext", autospec=True)
    map_patch = patch.object(RemoteLoader, "map")
    apply_patch = patch.object(Target, "apply")

    with ssl_patch, map_patch, apply_patch:
        opened = opener(uri)
        loader = opened._loader
        assert isinstance(loader, RemoteLoader)
        # The scheme is stripped: the target path is just "host:port".
        assert opened.path == Path("127.0.0.1:1337")


def test_loader() -> None:
    """``dissect.target.loader.open`` must select ``RemoteLoader`` for a ``remote://`` URI."""
    ssl_context_patch = patch.object(ssl, "SSLContext", autospec=True)
    with ssl_context_patch:
        selected = loader_open("remote://127.0.0.1:1337")
        assert isinstance(selected, RemoteLoader)


def test_stream() -> None:
    """Test that we can create a ``RemoteStream`` with a ``RemoteStreamConnection``.

    ``ssl.SSLContext`` and ``socket.socket`` are mocked. After connecting, the connection's
    TLS socket ``send``/``recv`` are replaced by a small in-memory fake that serves slices of
    ``remote_disk_data`` based on the request bytes ``RemoteStream`` sends.
    """

    with (
        patch.object(ssl, "SSLContext", autospec=True) as mock_ssl_context,
        patch.object(socket, "socket", autospec=True) as mock_socket,
    ):
        rsc = RemoteStreamConnection("remote://127.0.0.1", 9001, options={"ca": "A", "key": "B", "crt": "C"})
        assert rsc.is_connected() is False
        rsc.connect()

        # Contents of the fake remote disk, and the reply pending for the most recent request.
        rsc.remote_disk_data = b"ABC"
        rsc.remote_disk_response = None

        def send(data: bytes) -> int:
            # The request is unpacked as big-endian ">BQQ": a 1-byte value (ignored here),
            # an 8-byte offset and an 8-byte length to read.
            _, offset, read = unpack(">BQQ", data)
            rsc.remote_disk_response = rsc.remote_disk_data[offset : offset + read]
            # Like socket.send(), report how many bytes were "written".
            return len(data)

        def receive(num: int) -> bytes:
            # NOTE: returns the whole pending reply regardless of ``num``; sufficient for the
            # tiny reads performed in this test.
            return rsc.remote_disk_response

        rsc._ssl_sock.send = send
        rsc._ssl_sock.recv = receive
        rs = RemoteStream(rsc, 0, 3)
        # presumably disables read alignment so the request starts exactly at the seek offset
        rs.align = 1
        rs.seek(1)
        data = rs.read(2)
        rs.close()

        mock_socket.assert_called_with(socket.AddressFamily.AF_INET, socket.SocketKind.SOCK_STREAM)

        # Expected TLS setup: TLS 1.2 context, default certs, client cert/key from the
        # "crt"/"key" options, CA from the "ca" option, then wrap the raw socket and connect.
        expected = [
            call(ssl.PROTOCOL_TLSv1_2),
            call().load_default_certs(),
            call().load_cert_chain(certfile="C", keyfile="B"),
            call().load_verify_locations("A"),
            call().wrap_socket(rsc._socket, server_hostname="remote://127.0.0.1"),
            call().wrap_socket().connect(("remote://127.0.0.1", 9001)),
        ]

        assert data == b"BC"
        assert mock_ssl_context.mock_calls == expected
        assert rs.tell() == 3
        assert rsc.is_connected() is True


def test_stream_embedded() -> None:
    """Test that key/cert set via ``RemoteStreamConnection.configure`` reach ``load_cert_chain_str``.

    ``configure`` is called as ``(key, crt)``; the resulting SSL context call must receive them
    as ``certfile``/``keyfile`` keyword arguments.

    ``autospec`` stays off for ``SSLContext`` because ``load_cert_chain_str`` is not part of the
    standard ``ssl.SSLContext`` API, so an autospecced mock would presumably reject the call.

    NOTE(review): ``configure`` is invoked on the class and appears to set class-level state;
    confirm it is reset so it cannot leak into other tests.
    """
    with (
        patch.object(ssl, "SSLContext", autospec=False) as ssl_context_cls,
        patch.object(socket, "socket", autospec=True),
    ):
        RemoteStreamConnection.configure("K", "C")
        RemoteStreamConnection("remote://127.0.0.1", 9001)

        expected_call = call().load_cert_chain_str(certfile="C", keyfile="K")
        ssl_context_cls.assert_has_calls([expected_call])
