"""This module provides async functionality for commit reveal in the Bittensor network."""

from typing import Optional, Union, TYPE_CHECKING

import numpy as np
from bittensor_drand import get_encrypted_commit
from numpy.typing import NDArray

from bittensor.core.settings import version_as_int
from bittensor.utils.btlogging import logging
from bittensor.utils.weight_utils import convert_and_normalize_weights_and_uids

if TYPE_CHECKING:
    from bittensor_wallet import Wallet
    from bittensor.core.async_subtensor import AsyncSubtensor
    from bittensor.utils.registration import torch


# TODO: Merge this logic with `commit_reveal_extrinsic` in SDKv10, because this is no longer CRv3.
async def _do_commit_reveal_v3(
    subtensor: "AsyncSubtensor",
    wallet: "Wallet",
    netuid: int,
    commit: bytes,
    reveal_round: int,
    commit_reveal_version: int = 4,
    wait_for_inclusion: bool = False,
    wait_for_finalization: bool = False,
    period: Optional[int] = None,
) -> tuple[bool, str]:
    """
    Submit a ``SubtensorModule.commit_timelocked_weights`` extrinsic signed with the wallet's hotkey.

    Arguments:
        subtensor: The AsyncSubtensor instance used to compose and submit the extrinsic.
        wallet: The wallet whose hotkey signs the extrinsic.
        netuid: The unique identifier of the subnet.
        commit: The encrypted commit payload, in bytes.
        reveal_round: The round number at which the commit is to be revealed.
        commit_reveal_version: The chain commit-reveal protocol version passed in the call parameters.
            Default is ``4``.
        wait_for_inclusion: Whether to wait for the extrinsic to be included in a block.
        wait_for_finalization: Whether to wait for the extrinsic to be finalized.
        period: The number of blocks during which the transaction will remain valid after it's submitted. If the
            transaction is not included in a block within that number of blocks, it will expire and be rejected.
            You can think of it as an expiration date for the transaction.

    Returns:
        Whatever ``subtensor.sign_and_send_extrinsic`` returns. This is annotated as a ``(success, message)`` tuple,
            where the message describes the error, if any.
    """
    logging.info(
        f"Committing weights hash [blue]{commit.hex()}[/blue] for subnet #[blue]{netuid}[/blue] with "
        f"reveal round [blue]{reveal_round}[/blue]..."
    )

    # Parameters for the timelocked-weights commit call.
    timelocked_params = {
        "netuid": netuid,
        "commit": commit,
        "reveal_round": reveal_round,
        "commit_reveal_version": commit_reveal_version,
    }
    extrinsic_call = await subtensor.substrate.compose_call(
        call_module="SubtensorModule",
        call_function="commit_timelocked_weights",
        call_params=timelocked_params,
    )

    # The commit is signed with the hotkey, not the coldkey.
    outcome = await subtensor.sign_and_send_extrinsic(
        call=extrinsic_call,
        wallet=wallet,
        sign_with="hotkey",
        wait_for_inclusion=wait_for_inclusion,
        wait_for_finalization=wait_for_finalization,
        period=period,
    )
    return outcome


# TODO: rename this extrinsic to `commit_reveal_extrinsic` in SDK.v10
async def commit_reveal_v3_extrinsic(
    subtensor: "AsyncSubtensor",
    wallet: "Wallet",
    netuid: int,
    uids: Union[NDArray[np.int64], "torch.LongTensor", list],
    weights: Union[NDArray[np.float32], "torch.FloatTensor", list],
    version_key: int = version_as_int,
    wait_for_inclusion: bool = False,
    wait_for_finalization: bool = False,
    block_time: Union[int, float] = 12.0,
    period: Optional[int] = None,
) -> tuple[bool, str]:
    """
    Commits timelock-encrypted weights for a given subtensor and wallet with provided uids and weights.

    The uids and weights are normalized with ``convert_and_normalize_weights_and_uids``. They are then encrypted
    with ``get_encrypted_commit`` using the subnet's ``tempo`` and ``commit_reveal_period``, both read at the current
    block. The result is submitted through ``_do_commit_reveal_v3``.

    Arguments:
        subtensor: The AsyncSubtensor instance.
        wallet: The wallet to use for committing and revealing. Its hotkey public key is bound into the commit.
        netuid: The id of the network.
        uids: The uids to commit.
        weights: The weights associated with the uids.
        version_key: The version key to use for committing and revealing. Default is version_as_int.
        wait_for_inclusion: Whether to wait for the inclusion of the transaction. Default is False.
        wait_for_finalization: Whether to wait for the finalization of the transaction. Default is False.
        block_time (float): The number of seconds for block duration. Default is 12.0 seconds.
        period (Optional[int]): The number of blocks during which the transaction will remain valid after it's submitted. If
            the transaction is not included in a block within that number of blocks, it will expire and be rejected.
            You can think of it as an expiration date for the transaction.

    Returns:
        tuple[bool, str]: ``(True, "reveal_round:<round>")`` on success, otherwise ``(False, <error message>)``.
            Any exception raised along the way is caught, logged, and returned as ``(False, str(exception))``.
    """
    try:
        uids, weights = convert_and_normalize_weights_and_uids(uids, weights)

        current_block = await subtensor.substrate.get_block(None)
        block_header = current_block["header"]

        # Read hyperparameters at the same block whose number is used for the encryption below.
        subnet_hyperparameters = await subtensor.get_subnet_hyperparameters(
            netuid, block_hash=block_header["hash"]
        )
        if not subnet_hyperparameters:
            # Without this guard the caller would only see an opaque AttributeError text.
            message = f"Unable to retrieve hyperparameters for subnet {netuid}."
            logging.error(message)
            return False, message

        tempo = subnet_hyperparameters.tempo
        subnet_reveal_period_epochs = subnet_hyperparameters.commit_reveal_period

        # Encrypt `commit_hash` with t-lock and `get reveal_round`
        commit_for_reveal, reveal_round = get_encrypted_commit(
            uids=uids,
            weights=weights,
            version_key=version_key,
            tempo=tempo,
            current_block=block_header["number"],
            netuid=netuid,
            subnet_reveal_period_epochs=subnet_reveal_period_epochs,
            block_time=block_time,
            hotkey=wallet.hotkey.public_key,
        )

        success, message = await _do_commit_reveal_v3(
            subtensor=subtensor,
            wallet=wallet,
            netuid=netuid,
            commit=commit_for_reveal,
            reveal_round=reveal_round,
            wait_for_inclusion=wait_for_inclusion,
            wait_for_finalization=wait_for_finalization,
            period=period,
        )

        if not success:
            logging.error(message)
            return False, message

        # Report only the stage that was actually awaited. Without any wait flag, the extrinsic
        # has just been submitted, so claiming finalization would be misleading.
        if wait_for_finalization:
            status = "Finalized!"
        elif wait_for_inclusion:
            status = "Included!"
        else:
            status = "Submitted!"
        logging.success(
            f"[green]{status}[/green] Weights committed with reveal round [blue]{reveal_round}[/blue]."
        )
        return True, f"reveal_round:{reveal_round}"

    except Exception as e:
        # Boundary handler: failures are reported to the caller as (False, reason) rather than raised.
        logging.error(f":cross_mark: [red]Failed. Error:[/red] {e}")
        return False, str(e)
