from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import Sequence

from ndice import d, Dice, min_roll, RNG, roll

from .row import RangedValue, Row


@dataclass(frozen=True)
class Table[T]:
    rows: list[Row[T]]
    dice: Dice | None = None

    @cached_property
    def total_weight(self) -> int:
        return sum(row.weight for row in self.rows)

    def __post_init__(self) -> None:
        assert self.rows

    @classmethod
    def with_dice_and_ranges(
        cls, dice: Dice, *ranged_values: RangedValue[T]
    ) -> Table[T]:
        assert ranged_values
        return cls(
            rows=[
                Row.from_ranged_value(ranged_value)
                for ranged_value in ranged_values
            ],
            dice=dice,
        )

    @classmethod
    def with_ranges(cls, *ranged_values: RangedValue[T]) -> Table[T]:
        assert ranged_values
        return cls(
            rows=[
                Row.from_ranged_value(ranged_value)
                for ranged_value in ranged_values
            ]
        )

    @classmethod
    def with_values(cls, *values: T) -> Table[T]:
        assert values
        return cls(rows=[Row(value, 1) for value in values])

    def pick(
        self, rng: RNG, count: int, *, skip: Sequence[T] = tuple()
    ) -> list[T]:
        table = self.without_values(*skip)
        assert table
        assert 0 <= count <= len(table.rows)

        choices: list[T] = []
        for _ in range(count):
            assert table
            choice = table.roll(rng)
            choices.append(choice)
            table = table.without_values(choice)
        return choices

    def roll(self, rng: RNG) -> T:
        dice = self.dice or d(self.total_weight)
        die_roll = roll(rng, dice) - min_roll(dice) + 1

        for row in self.rows:
            die_roll -= row.weight
            if die_roll <= 0:
                return row.value
        assert False, 'Unreachable'

    def without_values(self, *values: T) -> Table[T] | None:
        if values:
            rows = [row for row in self.rows if row.value not in values]
            return self.__class__(rows=rows) if rows else None
        else:
            return self
