from __future__ import annotations

import dataclasses
import typing
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, call

import pytest

from randovania.bitpacking import bitpacking

if TYPE_CHECKING:
    from collections.abc import Iterator

    from randovania.bitpacking.bitpacking import BitPackDecoder
    from randovania.lib.json_lib import JsonObject_RO


@pytest.mark.parametrize("value", [10, 65, 0, 1, 2, 134])
@pytest.mark.parametrize("limits", [(5, 20, 500), (3, 50, 150)])
def test_encode_int_with_limits_round_trip(
    value: int,
    limits: tuple[int, ...],
):
    """Encoding an int against chunk limits and decoding it again yields the same int."""
    chunks = list(bitpacking.encode_int_with_limits(value, limits))
    packed = bitpacking._pack_encode_results(chunks)

    decoder = bitpacking.BitPackDecoder(packed)
    assert bitpacking.decode_int_with_limits(decoder, limits) == value


@pytest.fixture(
    params=[
        (0, (1, 4), [(0, 2)]),
        (1, (1, 4), [(1, 2), (0, 4)]),
        (5, (10,), [(5, 11)]),
        (50, (10,), [(50, 11)]),
        (5, (10, 100), [(5, 11)]),
        (50, (10, 100), [(10, 11), (40, 91)]),
        (5, (10, 100, 500), [(5, 11)]),
        (50, (10, 100, 500), [(10, 11), (40, 91)]),
        (500, (10, 100, 500), [(10, 11), (90, 91), (400, 401)]),
    ],
)
def limits_fixture(request):
    """Provide a ``(value, limits, encoded)`` triple.

    ``encoded`` is the list of ``(part, bound)`` pairs that
    ``encode_int_with_limits(value, limits)`` is expected to yield.
    """
    value, limits, encoded = request.param
    return value, limits, encoded


def test_encode_int_with_limits(limits_fixture):
    """encode_int_with_limits yields exactly the expected ``(part, bound)`` pairs."""
    value, limits, expected_chunks = limits_fixture

    chunks = list(bitpacking.encode_int_with_limits(value, limits))

    assert chunks == expected_chunks


def test_decode_int_with_limits(limits_fixture):
    """decode_int_with_limits reassembles the value from the encoded parts."""
    value, limits, encoded = limits_fixture
    parts = [part for part, _ in encoded]
    bounds = [bound for _, bound in encoded]

    decoder = MagicMock()
    # Each decode_single call hands back the next encoded part, in order.
    decoder.decode_single.side_effect = parts

    decoded = bitpacking.decode_int_with_limits(decoder, limits)

    decoder.decode_single.assert_has_calls([call(bound) for bound in bounds])
    assert decoded == value


@pytest.mark.parametrize(
    ("value", "limits", "expected"),
    [
        (0, (1, 4), "u1"),
        (1, (1, 4), "u1u2"),
        (2, (1, 4), "u1u2"),
        (3, (1, 4), "u1u2"),
        (4, (1, 4), "u1u2"),
    ],
)
def test_encode_int_with_limits_bitstring(value, limits, expected):
    """The format string of the encoded chunks matches the expected layout."""
    encoded = list(bitpacking.encode_int_with_limits(value, limits))
    assert bitpacking._format_string_for(encoded) == expected


@pytest.fixture(
    params=[
        (False, (0, 2)),
        (True, (1, 2)),
    ],
)
def bool_fixture(request):
    """Provide ``(value, encoded)``, where ``encoded`` is the single ``(bit, bound)`` pair for the bool."""
    value, encoded = request.param
    return value, encoded


def test_encode_bool(bool_fixture):
    """encode_bool yields exactly one ``(bit, bound)`` pair."""
    value, expected_pair = bool_fixture
    assert list(bitpacking.encode_bool(value)) == [expected_pair]


def test_decode_bool(bool_fixture):
    """decode_bool reads one value with the bool bound and maps it back to a bool."""
    expected, (raw, bound) = bool_fixture
    decoder = MagicMock()
    decoder.decode_single.return_value = raw

    decoded = bitpacking.decode_bool(decoder)

    decoder.decode_single.assert_called_once_with(bound)
    assert decoded == expected


@pytest.mark.parametrize(
    ("value", "metadata"),
    [
        (0.0, {"min": 0.0, "max": 1.0, "precision": 1}),
        (0.0, {"min": -1.0, "max": 1.0, "precision": 1}),
        (-0.5, {"min": -1.0, "max": 1.0, "precision": 1}),
        (1.0, {"min": 0.0, "max": 1.0, "precision": 1}),
        (1.0, {"min": 0.0, "max": 1.0, "precision": 2}),
    ],
)
def test_round_trip_float(value: float, metadata: dict):
    """A float packed with the given min/max/precision metadata decodes to the same value."""
    packable = bitpacking.BitPackFloat(value)
    assert bitpacking.round_trip(packable, metadata) == value


@pytest.mark.parametrize(
    ("elements", "array"),
    [
        ([], [10, 20]),
        ([10], [10, 20]),
        ([10, 20], [10, 20]),
        ([10, 20], [10, 20, 30]),
        ([10, 20], [10, 20, 30, 50]),
        (list(range(15)), list(range(100))),
        ([x * 2 for x in range(150)], list(range(300))),
    ],
)
def test_sorted_array_elements_round_trip(elements, array):
    """A sorted subset of ``array`` survives packing and decoding unchanged."""
    chunks = list(bitpacking.pack_sorted_array_elements(elements, array))
    decoder = bitpacking.BitPackDecoder(bitpacking._pack_encode_results(chunks))

    assert elements == bitpacking.decode_sorted_array_elements(decoder, array)


@pytest.mark.parametrize(
    ("elements", "array", "expected_size"),
    [
        ([], [], 0),
        ([], range(100), 8),
        ([90], range(100), 18),
        (range(100), range(100), 8),
        (range(100), range(300), 219),
        (list(range(100)) + list(range(200, 300)), range(300), 318),
        (range(200), range(300), 120),
        (range(200, 300), range(300), 120),
        ([x * 2 for x in range(150)], range(300), 458),
        ([x * 3 for x in range(100)], range(300), 310),
    ],
)
def test_sorted_array_elements_size(elements, array, expected_size):
    """Summing the bit width of every emitted chunk gives ``expected_size``."""
    chunks = bitpacking.pack_sorted_array_elements(list(elements), list(array))
    total_bits = sum(bitpacking._bits_for_number(bound) for _, bound in chunks)
    assert total_bits == expected_size


def test_pack_array_element_missing():
    """Packing a value that is absent from the array raises ValueError ("... is not in list")."""
    # The call stays inside the context manager in case the error is raised eagerly.
    with pytest.raises(ValueError, match="5 is not in list"):
        tuple(bitpacking.pack_array_element(5, [10, 25]))


def test_pack_array_element_single():
    """Packing the only element of a one-element array emits no chunks at all."""
    # Compare against [] rather than checking len() so a failure shows what was yielded.
    assert list(bitpacking.pack_array_element("x", ["x"])) == []


@pytest.mark.parametrize(
    ("element", "array"),
    [
        (10, [10, 20]),
        ("x", [10, "x", 20]),
        ("x", ["x"]),
    ],
)
def test_array_elements_round_trip(element, array):
    """An element packed by its position in ``array`` decodes back to itself."""
    chunks = list(bitpacking.pack_array_element(element, array))
    packed = bitpacking._pack_encode_results(chunks)

    assert element == bitpacking.BitPackDecoder(packed).decode_element(array)


class BitPackValueUsingReference(bitpacking.BitPackValue):
    """Test value that is encoded relative to a reference instance.

    Only ``self.value - reference.value`` is written, so both encoding and decoding
    require ``metadata["reference"]`` to be another ``BitPackValueUsingReference``.
    """

    value: int

    def __init__(self, x):
        self.value = x

    def bit_pack_encode(self, metadata) -> Iterator[tuple[int, int]]:
        reference: BitPackValueUsingReference = metadata["reference"]
        # A single chunk holding the delta to the reference, with a bound of 128.
        yield self.value - reference.value, 128

    @classmethod
    def bit_pack_unpack(cls, decoder: BitPackDecoder, metadata):
        reference: BitPackValueUsingReference = metadata["reference"]
        # Undo the delta applied in bit_pack_encode.
        value = decoder.decode_single(128) + reference.value
        return cls(value)

    def __hash__(self) -> int:
        return hash(self.value)

    def __eq__(self, other: object) -> bool:
        # Let Python fall back to its default handling for unrelated types instead of
        # raising AttributeError on ``other.value``.
        if not isinstance(other, BitPackValueUsingReference):
            return NotImplemented
        return self.value == other.value

    def __repr__(self) -> str:
        # Makes failing equality assertions on containing dataclasses readable.
        return f"{type(self).__name__}({self.value!r})"


@dataclasses.dataclass(frozen=True)
class DataclassForTest(bitpacking.BitPackDataclass):
    """Small frozen dataclass used to exercise BitPackDataclass encoding.

    Combines an optional int that has range metadata with a field whose encoding
    depends on a reference value. Field order matters: tests construct instances
    positionally as ``DataclassForTest(optional_int, uses_reference)``.
    """

    # Optional int; the field metadata declares its range (min 0, max 15), presumably
    # consumed by BitPackDataclass when packing — confirm against bitpacking.
    optional_int: int | None = dataclasses.field(metadata={"min": 0, "max": 15})
    # Encoded as a delta against metadata["reference"]; presumably BitPackDataclass
    # passes the reference instance's own `uses_reference` down — verify.
    uses_reference: BitPackValueUsingReference


@pytest.mark.parametrize(
    ("data_value", "data_reference"),
    [
        (5, 5),
        (20, 5),
        (50, 5),
    ],
)
@pytest.mark.parametrize(
    ("int_v", "int_reference"),
    [
        (15, 5),
        (12, None),
        (None, 20),
    ],
)
def test_round_trip_dataclass_for_test(int_v, data_value, int_reference, data_reference):
    """A DataclassForTest packed against a reference instance decodes back to itself."""
    # NOTE(review): `int_reference` is parametrized but never used; the reference's
    # optional_int is built from `data_value` instead. `int_reference` may have been
    # intended here — confirm before changing, since it alters the tested input.
    reference = DataclassForTest(data_value, BitPackValueUsingReference(data_reference))
    original = DataclassForTest(int_v, BitPackValueUsingReference(data_value))

    assert bitpacking.round_trip(original, {"reference": reference}) == original


@pytest.mark.parametrize(
    ("int_v", "int_reference", "data_value", "data_reference", "expected"),
    [
        (None, 1, 5, 5, b"\x00"),
        (2, None, 5, 5, b"\xc8"),
        (None, None, 5, 5, b"\x00"),
        (1, None, 20, 5, b"\xc6<"),
        (5, 30, 50, 5, b"\xd6\xb4"),
    ],
)
def test_encode_dataclass_for_test(int_v, int_reference, data_value, data_reference, expected):
    """Packing a DataclassForTest against a reference yields the recorded bytes."""
    # NOTE(review): `int_reference` is unused; the reference's optional_int comes from
    # `data_value`. The expected bytes were recorded with this setup, so switching to
    # `int_reference` would require regenerating them.
    reference = DataclassForTest(data_value, BitPackValueUsingReference(data_reference))
    original = DataclassForTest(int_v, BitPackValueUsingReference(data_value))

    assert bitpacking.pack_value(original, {"reference": reference}) == expected


@pytest.fixture
def json_fixture() -> tuple[JsonObject_RO, bytes]:
    """Provide a JSON object covering every JSON value kind, plus its expected packed bytes.

    The object contains a string, bool, null, integer, float, array and nested object.
    The bytes are what ``pack_value(BitPackJson(value))`` produced when recorded; any
    change to the JSON bit-packing format will change them.
    """
    value = {
        "string": "foo",
        "bool": True,
        "null": None,
        "integer": 1000,
        "number": 24.03,
        "array": [False, None],
        "object": {
            "key": None,
        },
    }
    encoded = (
        b"\x87\t\xcd\xd1\xc9\xa5\xb9\x9f6f\xf6\xf8"
        b"\x0cM\xed\xed\x87\x01\xb9\xd5\xb1\xb1C"
        b"integers8m\x10\x9b\x9d[X\x99\\\xa4D\x8e"
        b"\x17\xa1J\xe0s\x84\x08,.NL/!\x12\x84\xde"
        b"\xc4\xd4\xca\xc6\xe9Kkey@"
    )
    return typing.cast("JsonObject_RO", value), encoded


def test_encode_json(json_fixture):
    """Packing the fixture's JSON object reproduces the recorded bytes exactly."""
    value, expected = json_fixture
    assert bitpacking.pack_value(bitpacking.BitPackJson(value)) == expected


def test_round_trip_json(json_fixture):
    """A JSON object survives a pack/unpack round trip unchanged."""
    original, _ = json_fixture
    decoded = bitpacking.round_trip(bitpacking.BitPackJson(original))
    assert decoded == original
