from dataclasses import dataclass

from ndice import Dice, nd

from sd.gear import Grip, Weapon, WeaponMode

from .talents import backstab, Talent


@dataclass(frozen=True, slots=True)
class DamageRoll:
    """Damage dice for one weapon attack, grouped by how the weapon is used.

    Each field is a list of ``Dice`` terms. An empty list means that group
    does not apply and is left out of ``str()``. At least one group must be
    non-empty.
    """

    # Dice for attacking with the weapon held in one hand.
    one_handed: list[Dice]
    # Dice for a one-handed backstab; empty when backstab does not apply.
    one_handed_backstab: list[Dice]
    # Dice for attacking with the weapon held in two hands.
    two_handed: list[Dice]

    def __post_init__(self) -> None:
        # Internal invariant: a damage roll with no dice at all is meaningless.
        assert any((self.one_handed, self.one_handed_backstab, self.two_handed))

    def __str__(self) -> str:
        """Render the non-empty groups, comma-separated, e.g. ``1h ..., 2h ...``."""
        labelled_groups = (
            ('1h', self.one_handed),
            ('backstab', self.one_handed_backstab),
            ('2h', self.two_handed),
        )
        return ', '.join(
            f'{label} {_dice_str(group)}'
            for label, group in labelled_groups
            if group
        )


def make_damage_roll(
    weapon: Weapon,
    mode: WeaponMode,
    mods: list[Dice],
    level: int,
    talents: list[Talent],
) -> DamageRoll:
    """Build the DamageRoll for an attack with *weapon*.

    Args:
        weapon: Weapon being used. Its ``damage`` is either a single ``Dice``
            or a ``(one_handed, two_handed)`` tuple of dice.
        mode: Weapon mode; only forwarded to the backstab calculation.
        mods: Extra dice terms appended after the weapon's own die in every
            group that applies.
        level: Level used to scale backstab dice.
        talents: Talents, checked for backstab.

    Returns:
        A DamageRoll with only the applicable groups filled in.

    Raises:
        RuntimeError: If a single-die weapon's grip is neither one- nor
            two-handed.
    """
    damage = weapon.damage

    if isinstance(damage, tuple):
        # Separate one-/two-handed dice; no backstab group for this case.
        die_one_hand, die_two_hands = damage
        return DamageRoll([die_one_hand] + mods, [], [die_two_hands] + mods)

    grip = weapon.grip
    if Grip.ONE_HANDED == grip:
        # Build the plain one-handed group before computing the backstab group.
        plain = [damage] + mods
        return DamageRoll(
            plain,
            _backstab(weapon, mode, mods, level, talents),
            [],
        )
    if Grip.TWO_HANDED == grip:
        return DamageRoll([], [], [damage] + mods)

    raise RuntimeError(f'Unexpected weapon grip {grip}')


def _backstab(
    weapon: Weapon,
    mode: WeaponMode,
    mods: list[Dice],
    level: int,
    talents: list[Talent],
) -> list[Dice]:
    """Return the backstab dice expression, or ``[]`` when backstab doesn't apply.

    Backstab applies only to melee attacks when ``backstab`` is among
    *talents*. It rolls the weapon's die with one extra die plus one more
    per two levels (rounded down), followed by *mods*.
    """
    if WeaponMode.MELEE != mode or backstab not in talents:
        return []

    # make_damage_roll only calls this for single-die (non-tuple) damage.
    assert isinstance(weapon.damage, Dice)
    base_die = weapon.damage
    count = base_die.number + 1 + level // 2
    boosted = nd(count, base_die.sides)
    return [boosted] + mods


def _dice_str(expression: list[Dice]) -> str:
    """Concatenate the string form of each Dice term, with no separator.

    NOTE(review): presumably ``Dice.__str__`` carries its own sign/operator
    so adjacent terms read correctly — confirm against ndice.
    """
    return ''.join(map(str, expression))
