from collections.abc import Callable, Iterable, Iterator
from fractions import Fraction
from itertools import cycle
from typing import Final

import eth_abi.abi
from eth_typing import ChecksumAddress
from eth_utils.crypto import keccak
from hexbytes import HexBytes

from degenbot.checksum_cache import get_checksum_address
from degenbot.exceptions import DegenbotValueError
from degenbot.functions import create2_address, evm_divide
from degenbot.uniswap.v3_libraries import tick_bitmap
from degenbot.uniswap.v3_types import Pip


def decode_v3_path(path: bytes) -> list[ChecksumAddress | Pip]:
    """
    Decode a Uniswap V3 Router/Router2 `path` into its component parts.

    `path` is a tightly packed byte string of 20-byte token addresses interleaved with
    3-byte fees. It always starts and ends with an address:
    `address | fee | address [| fee | address ...]`.

    Returns a flat list that alternates checksummed addresses and integer fees, in the
    order they appear in `path`.

    Raises DegenbotValueError if the length of `path` does not fit that layout.
    """
    address_length: Final = 20
    fee_length: Final = 3
    hop_length = address_length + fee_length
    total_length = len(path)

    # A valid path holds at least one hop (address + fee + address). Each further hop adds
    # exactly one fee plus one address, so the length must be `address_length` past a
    # multiple of `hop_length`.
    too_short = total_length < hop_length + address_length
    misaligned = total_length % hop_length != address_length
    if too_short or misaligned:  # pragma: no cover
        raise DegenbotValueError(message="Invalid path.")

    decoded_path: list[ChecksumAddress | Pip] = []
    cursor = 0
    while True:
        address_chunk = HexBytes(path[cursor : cursor + address_length])
        decoded_path.append(get_checksum_address(address_chunk))
        cursor += address_length

        # The length check above guarantees the path ends right after an address.
        if cursor == total_length:
            return decoded_path

        fee_chunk = HexBytes(path[cursor : cursor + fee_length])
        decoded_path.append(int.from_bytes(fee_chunk, byteorder="big"))
        cursor += fee_length


def exchange_rate_from_sqrt_price_x96(sqrt_price_x96: int) -> Fraction:
    """
    Convert a square-root price scaled by 2**96 into an exact rational price.

    `sqrt_price_x96` holds sqrt(price) * 2**96. The price is recovered by squaring the value
    and dividing out the resulting 2**192 scale factor. The result is exact, and no
    token-decimal adjustment is applied.

    ref: https://blog.uniswap.org/uniswap-v3-math-primer
    """
    # (sqrt_price_x96 / 2**96)**2 == sqrt_price_x96**2 / 2**192; Fraction keeps it exact.
    q192_scale = 1 << 192
    return Fraction(sqrt_price_x96**2, q192_scale)


def generate_v3_pool_address(
    deployer_address: str | bytes,
    token_addresses: Iterable[str | bytes],
    fee: Pip,
    init_hash: str | bytes,
) -> ChecksumAddress:
    """
    Compute the deterministic CREATE2 address of a V3 pool.

    The salt is the keccak hash of the ABI-encoded `(address, address, uint24)` tuple of the
    two tokens and the fee. The tokens are converted to bytes and sorted before encoding, so
    `token_addresses` may be supplied in either order, as hex strings or raw bytes.

    Adapted from https://github.com/Uniswap/v3-periphery/blob/0682387198a24c7cd63566a2c58398533860a5d1/contracts/libraries/PoolAddress.sol#L33
    """
    # Sorting on the raw bytes reproduces the contract's token0 < token1 ordering.
    # NOTE(review): exactly two tokens are expected; other counts are not checked here and
    # presumably fail inside the ABI encoder.
    sorted_tokens = sorted(map(HexBytes, token_addresses))

    salt_preimage = eth_abi.abi.encode(
        ("address", "address", "uint24"),
        (*sorted_tokens, fee),
    )

    return create2_address(
        deployer=deployer_address,
        salt=keccak(salt_preimage),
        init_code_hash=init_hash,
    )


def get_tick_word_and_bit_position(
    tick: int,
    tick_spacing: int,
) -> tuple[int, int]:
    """
    Locate a tick's word and bit position in the tick bitmap, accounting for tick spacing.

    The tick is first compressed by dividing it by `tick_spacing` with `evm_divide`. Then
    `tick_bitmap.position` maps the compressed tick to its (word, bit) pair.
    """
    # NOTE(review): evm_divide presumably truncates toward zero like the EVM's SDIV. For a
    # negative tick that is not a multiple of tick_spacing, this differs from the
    # round-toward-negative-infinity compression in Uniswap's
    # TickBitmap.nextInitializedTickWithinOneWord. Confirm that callers only pass
    # spacing-aligned ticks, or that they expect truncation.
    compressed_tick = evm_divide(tick, tick_spacing)
    return tick_bitmap.position(compressed_tick)
