from contextlib import nullcontext
import pytest

from sansio_tools.parser import BinaryParser, UnexpectedEOFError


def test_parse_no_generator():
    """A parser built without a generator function must reject use.

    Construction and the first ``feed`` both sit inside the ``raises``
    block, so the AssertionError is accepted from either call.
    """
    with pytest.raises(AssertionError):
        BinaryParser().feed(b"")


def test_parse_simple():
    """Drive a two-step parser with input split across several ``feed`` calls.

    The parser reads 6 raw bytes, then a 2-byte little-endian unsigned int.
    Bytes fed after the generator has exited must remain in the queue, and
    an empty ``feed`` marks EOF without consuming them.
    """
    results = []

    def gen(p: BinaryParser):
        b = yield from p.read_bytes(6)
        results.append(b)

        n = yield from p.read_int(2, "little", False)
        results.append(n)

    p = BinaryParser(gen)
    assert len(results) == 0

    # 5 bytes are not enough to satisfy the 6-byte read.
    p.feed(b"01234")
    assert len(results) == 0

    # Completes the 6-byte read; b"\x07" is only the first half of the int.
    p.feed(b"5\x07")
    assert len(results) == 1
    assert results[0] == b"012345"
    assert not p.generator_exited

    # b"\x07\xff" read little-endian and unsigned is 0xFF07; b"ex" is surplus.
    p.feed(b"\xffex")
    assert len(results) == 2
    assert results[1] == int.from_bytes(b"\x07\xff", "little") == 0xFF07
    assert p.generator_exited
    assert len(p.queue) == 2
    assert b"".join(p.queue.data) == b"ex"

    # After the generator exits, further input just accumulates in the queue.
    p.feed(b"tra")
    assert len(p.queue) == 5
    assert b"".join(p.queue.data) == b"extra"
    assert not p.eof

    # An empty feed signals EOF and leaves the queued bytes untouched.
    p.feed(b"")
    assert p.eof
    assert len(p.queue) == 5
    assert b"".join(p.queue.data) == b"extra"


def test_parse_varint_special():
    """Edge cases for 7-bit varints.

    Covers an inverted continuation bit (``continuation_bit_value=False``),
    and redundant zero groups that are accepted when canonical form is not
    required but rejected with a "canonical" ValueError when it is.
    """

    def _f(data, bo, cbit, canon):
        # Parse a single varint (up to 9 bytes) from ``data`` and return it.
        decoded = []

        def parser(p: BinaryParser):
            value = yield from p.read_variable_length_int_7bit(
                9, byteorder=bo, continuation_bit_value=cbit, require_canonical=canon
            )
            decoded.append(value)

        bp = BinaryParser(parser)
        bp.feed(data)
        return decoded[0]

    # Same value (2 * 128 + 3) with the continuation bit meaning 1 vs. 0.
    assert _f(b"\x82\x03", bo="big", cbit=True, canon=True) == 256 + 3
    assert _f(b"\x02\x83", bo="big", cbit=False, canon=True) == 256 + 3

    # Big endian: a leading all-zero group is non-canonical.
    assert _f(b"\x80\x03", bo="big", cbit=True, canon=False) == 3
    with pytest.raises(ValueError, match="canonical"):
        _f(b"\x80\x03", bo="big", cbit=True, canon=True)

    # Little endian: a trailing all-zero group is non-canonical.
    assert _f(b"\x83\x00", bo="little", cbit=True, canon=False) == 3
    with pytest.raises(ValueError, match="canonical"):
        _f(b"\x83\x00", bo="little", cbit=True, canon=True)


@pytest.mark.parametrize("n", {o + 2**k for o in {-1, 0, 1} for k in (0, 7, 14, 21)})
@pytest.mark.parametrize("byteorder", ["little", "big"])
@pytest.mark.parametrize("split", [True, False])
def test_parse_varint(n, byteorder, split):
    """Decode canonical 7-bit varints with a 2-byte limit passed to the reader.

    Values that need more than two 7-bit groups (``n >= 2**14``) must raise
    ValueError without producing a result. A trailing ``b"\\xff"`` sentinel
    checks which bytes are left in the queue afterwards. ``split`` feeds the
    input one byte at a time instead of in a single call.
    """
    # Break ``n`` into 7-bit groups, least significant group first.
    groups = [n & 0x7F]
    rest = n >> 7
    while rest:
        groups.append(rest & 0x7F)
        rest >>= 7
    if byteorder == "big":
        groups.reverse()

    # Every group except the last carries the continuation bit.
    encoded = bytes([g | 0x80 for g in groups[:-1]] + [groups[-1]])
    feed_input = encoded + b"\xff"

    parser_results = []

    def parser(p: BinaryParser):
        decoded = yield from p.read_variable_length_int_7bit(
            2, byteorder, continuation_bit_value=True, require_canonical=True
        )
        parser_results.append(decoded)

    p = BinaryParser(parser)

    def feed():
        if not split:
            p.feed(feed_input)
            return
        for start in range(len(feed_input)):
            p.feed(feed_input[start : start + 1])

    if n >= 2**14:
        with pytest.raises(ValueError):
            feed()
        assert not parser_results
        # On rejection nothing is consumed: byte-wise feeding stops after the
        # second byte, a single feed leaves the whole input queued.
        expected_queue = feed_input[:2] if split else feed_input
    else:
        feed()
        assert parser_results[0] == n
        # Only the sentinel byte should remain.
        expected_queue = feed_input[-1:]

    assert b"".join(p.queue.data) == expected_queue


@pytest.mark.parametrize("i,j", [(i, j) for i in range(1, 8) for j in range(i + 1, 8 + 1)])
@pytest.mark.parametrize("exception", [False, True])
def test_parse_one(i, j, exception):
    """Feed an 8-byte payload as two chunks, ``[:i]`` and ``[i:j]``.

    The parser reads 5 bytes, then 3 bytes. With ``exception`` set, the first
    byte is zeroed so the parser raises right after the 5-byte read
    completes. That error must propagate out of whichever ``feed`` call
    supplied the fifth byte. Afterwards the number of completed reads is
    checked. Without an exception, EOF must succeed only if all 8 bytes
    arrived and must otherwise raise UnexpectedEOFError.
    """
    data = memoryview(bytearray(b"01234567"))
    if exception:
        data[0] = 0
    results = []

    def parser(p: BinaryParser):
        results.append(x := (yield from p.read_bytes(5)))
        if x[0] == 0:
            # Simulated failure inside the parser generator.
            raise ZeroDivisionError("simulated parser failure")
        results.append((yield from p.read_bytes(3)))

    p = BinaryParser(parser)
    # The first read (and thus the simulated failure) needs at least 5 bytes.
    raises_expected = exception and j >= 5
    with pytest.raises(ZeroDivisionError) if raises_expected else nullcontext():
        p.feed(data[:i])
        p.feed(data[i:j])

    # Reads completed with the j bytes fed so far. When the parser raised,
    # the first result was already appended before the raise.
    if j < 5:
        expected_reads = 0
    elif exception or j < 8:
        expected_reads = 1
    else:
        expected_reads = 2
    assert len(results) == expected_reads

    if not exception:
        if j == 8:
            p.feed(b"")
            assert len(results) == 2
        else:
            with pytest.raises(UnexpectedEOFError):
                p.feed(b"")
