from __future__ import annotations

import abc
import struct as _struct
import typing as ty

from .queue import SolidQueue, BytesQueue


class UnexpectedEOFError(ValueError):
    """Raised when end of input has been fed but the parsing generator can make no further progress."""


def _complain(p):
    """
    Placeholder parsing generator used when no generator factory was supplied.

    Being a generator function, calling it does nothing; the :exc:`AssertionError` is raised
    lazily, the first time the returned generator is advanced.
    """
    owner_name = type(p).__name__
    raise AssertionError(f"you must set `{owner_name}.generator`")
    yield  # unreachable, but turns this function into a generator function


# Type of the data chunks fed to a Parser (see `Parser.feed`).
T = ty.TypeVar("T")


class Parser(ty.Generic[T]):
    """
    Incremental push parser driven by a generator.

    Data is pushed in with :meth:`feed` and buffered in :attr:`queue`. The parsing generator
    (created from the factory passed to ``__init__``) pulls data out of :attr:`queue` and
    yields whenever it needs to hand control back. :meth:`advance` keeps stepping the generator
    for as long as it makes progress.

    Subclasses must implement :meth:`_init_queue` to assign :attr:`queue`.
    """

    # Buffer of fed-but-not-yet-consumed data; assigned by `_init_queue`.
    queue: SolidQueue[T]
    # Set to True once a falsy chunk has been fed (end of input).
    eof: bool = False
    # Set to True once the current parsing generator has returned.
    generator_exited = False

    def __init__(self, generator: ty.Callable[[Parser[T]], ty.Iterator[ty.Any]] | None = None):
        """
        :param generator: factory called with this parser that returns the parsing iterator.
            If omitted, a placeholder is used that raises :exc:`AssertionError` the first time
            the parser is advanced.
        """
        # Create the queue before calling the factory, so a factory that touches `self.queue`
        # eagerly (e.g. a plain function returning an iterator) sees an initialized queue.
        self._init_queue()
        if generator is None:
            generator = _complain
        self.generator = generator(self)

    @abc.abstractmethod
    def _init_queue(self):
        """
        Assign :attr:`queue`. Must be implemented by subclasses.

        NOTE: ``Parser`` does not derive from :class:`abc.ABC`, so this abstract method is not
        enforced at instantiation time.
        """
        ...

    @property
    def generator(self):
        """The parsing iterator currently driving this parser."""
        return self._generator

    @generator.setter
    def generator(self, value):
        # A freshly installed generator has not exited yet. Without this reset, a replacement
        # would inherit the previous generator's exit flag and UnexpectedEOFError would be
        # silently suppressed for it.
        self.generator_exited = False
        self._generator = value
        self._generator_wrapped = self._wrap_generator(value)

    def feed_without_parsing(self, data: T) -> None:
        """
        Append *data* to :attr:`queue` without advancing the parser.

        A falsy *data* is not queued; instead it marks end of input by setting :attr:`eof`.
        """
        if data:
            self.queue.append(data)
        else:
            self.eof = True

    def feed(self, data: T) -> None:
        """
        Add *data* to the internal buffer :attr:`queue`, then call :meth:`advance` to parse as much as possible.

        If *data* is falsy (empty bytestring for example), then set :attr:`eof` to True.
        """
        self.feed_without_parsing(data)
        self.advance()

    def _wrap_generator(self, gen):
        """
        Drive *gen* to completion, record that it exited, then yield ``None`` forever.

        The trailing infinite loop lets :meth:`advance_one` keep calling :func:`next` after the
        parsing generator has finished without hitting :exc:`StopIteration`.
        """
        yield from gen
        # Only flag an exit if *gen* is still the active generator: it may have installed a
        # replacement through the `generator` setter just before returning, and that new
        # generator has not exited.
        if self._generator is gen:
            self.generator_exited = True
        while True:
            yield

    def advance_one(self) -> bool:
        """
        Step the parsing generator once and return whether any progress was made.

        We consider that progress was made if the generator yielded a truthy value or if the queue became smaller.

        :raises UnexpectedEOFError: if no progress was made, :attr:`eof` is True and the
            generator has not yet exited.
        """
        q = self.queue
        n = len(q)
        result = next(self._generator_wrapped)
        progress = bool(result) or len(q) < n
        if not progress and self.eof and not self.generator_exited:
            raise UnexpectedEOFError
        return progress

    def advance(self) -> None:
        """
        Call :meth:`advance_one` to progress the parsing until no more progress can be made.
        """
        while self.advance_one():
            pass


class BinaryParser(Parser[bytes | bytearray | memoryview]):
    """
    :class:`Parser` for binary input, buffered in a :class:`BytesQueue`.

    The ``read_*`` / ``wait`` helpers are generators intended to be used with ``yield from``
    inside the parsing generator: they yield until enough bytes are buffered, then consume them
    and return the decoded value.
    """

    def _init_queue(self):
        self.queue = BytesQueue()

    def _read_bytes(self, nbytes: int) -> ty.Generator[None, None, bytes | memoryview | bytearray]:
        """Wait for *nbytes* bytes, then pop and return them exactly as ``queue.popleft`` gives them."""
        yield from self.wait(nbytes)
        return self.queue.popleft(nbytes)

    def read_bytes(self, nbytes: int) -> ty.Generator[None, None, bytes]:
        """Wait for *nbytes* bytes, pop them and return them as an immutable :class:`bytes` copy."""
        return bytes((yield from self._read_bytes(nbytes)))

    def wait(self, nbytes: int) -> ty.Generator[None, None, None]:
        """Yield (without consuming anything) until at least *nbytes* bytes are buffered."""
        q = self.queue
        while len(q) < nbytes:
            yield

    def read_int(self, nbytes: int, byteorder: str, signed: bool) -> ty.Generator[None, None, int]:
        """Read an *nbytes*-byte integer, decoded with :meth:`int.from_bytes`."""
        yield from self.wait(nbytes)
        with self.queue.popleft_after(nbytes) as b:
            return int.from_bytes(b, byteorder=byteorder, signed=signed)

    def read_struct(self, struct: _struct.Struct) -> ty.Generator[None, None, tuple]:
        """Read ``struct.size`` bytes and return the tuple produced by ``struct.unpack``."""
        yield from self.wait(nbytes := struct.size)
        with self.queue.popleft_after(nbytes) as b:
            return struct.unpack(b)

    def read_variable_length_int_7bit(
        self, maximum_length: int, byteorder: str, continuation_bit_value: bool, require_canonical: bool
    ) -> ty.Generator[None, None, int]:
        """
        Read an unsigned integer stored as 7-bit groups, one per byte, whose high bit is a
        continuation flag (LEB128-style for ``byteorder="little"``, ``continuation_bit_value=True``).

        :param maximum_length: maximum number of bytes the encoding may occupy.
        :param byteorder: ``"big"`` if the first byte carries the most significant group,
            ``"little"`` if it carries the least significant one.
        :param continuation_bit_value: high-bit value meaning "more bytes follow"; the first byte
            whose high bit differs from it terminates the integer.
        :param require_canonical: if true, reject encodings whose most significant group is a
            redundant zero.
        :raises ValueError: for an invalid *byteorder*, an encoding longer than
            *maximum_length*, or a non-canonical encoding when *require_canonical* is set.
        """
        # Validate arguments before touching the queue, so a bad call fails immediately instead
        # of after waiting for (and consuming) the whole encoded integer.
        if byteorder not in ("big", "little"):
            raise ValueError('byteorder must be "big" or "little"')

        with (q := self.queue).temporary_left() as tmp:
            groups = []
            remaining = maximum_length
            need_more = True
            while need_more:
                yield from self.wait(1)
                # Move whatever is currently buffered from q into tmp.
                # NOTE(review): presumably temporary_left() restores tmp's contents to q if we
                # leave the block without clearing it (e.g. on an exception) — confirm.
                chunk = q.popleft_any_to(tmp)
                for i, c in enumerate(chunk):
                    groups.append(c & 0x7F)
                    if (c < 128) == continuation_bit_value:
                        # Terminating byte: hand the unconsumed tail of this chunk back to q.
                        tmp.pop_to(len(chunk) - i - 1, q)
                        need_more = False
                        break
                    remaining -= 1
                    if remaining <= 0:
                        raise ValueError("integer too long")

            # Reassemble: `top` is the most significant group; iterate least significant first.
            if byteorder == "big":
                top = groups[0]
                groups_lsb_first = reversed(groups)
            else:
                top = groups[-1]
                groups_lsb_first = groups

            if require_canonical and top == 0 and len(groups) > 1:
                raise ValueError("non-canonical encoding")

            result = sum(x << (i * 7) for i, x in enumerate(groups_lsb_first))

            # The bytes are consumed for good; presumably clearing tmp keeps temporary_left()
            # from putting them back into q on exit — verify against BytesQueue.
            tmp.clear()
            return result
