import hmac
import logging
import os
import random
from typing import Any

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

from AlgoTuneTasks.base import register_task, Task


# AES-GCM size constants, all expressed in bytes (bit widths shown via // 8).
AES_KEY_SIZES = [128 // 8, 192 // 8, 256 // 8]  # valid AES-128 / AES-192 / AES-256 key lengths
GCM_NONCE_SIZE = 96 // 8  # 96-bit nonce, the recommended size for GCM
GCM_TAG_SIZE = 128 // 8  # full-length 128-bit GCM authentication tag


@register_task("aes_gcm_encryption")
class AesGcmEncryption(Task):
    """
    AesGcmEncryption Task:

    Encrypt plaintext using AES-GCM from the `cryptography` library.
    The difficulty scales with the size of the plaintext.

    Problem dict keys (all ``bytes``): ``key``, ``nonce``, ``plaintext``,
    ``associated_data``. Solution dict keys (all ``bytes``): ``ciphertext``
    (same length as the plaintext) and ``tag`` (the detached
    ``GCM_TAG_SIZE``-byte authentication tag).
    """

    DEFAULT_KEY_SIZE = 16  # Default to AES-128
    DEFAULT_PLAINTEXT_MULTIPLIER = 1024  # Bytes of plaintext per unit of n

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def generate_problem(self, n: int, random_seed: int = 1) -> dict[str, Any]:
        """
        Generate inputs for the AES-GCM encryption task.

        Every random byte is drawn from a ``random.Random`` seeded with
        ``random_seed``, so the same ``(n, random_seed)`` pair always yields
        the same problem. These values are benchmark inputs, not real
        secrets, so a non-cryptographic PRNG is appropriate here.

        Args:
            n (int): Scaling parameter. The plaintext is
                ``n * DEFAULT_PLAINTEXT_MULTIPLIER`` bytes (at least 1 byte).
                Even ``n`` also gets 32 bytes of associated data; odd ``n``
                gets empty associated data.
            random_seed (int): Seed controlling the key, nonce, plaintext and
                associated data.

        Returns:
            dict: A dictionary containing the problem parameters
            (key, nonce, plaintext, associated_data), all ``bytes``.

        Raises:
            ValueError: If ``DEFAULT_KEY_SIZE`` is not a valid AES key size.
        """
        # Use n to scale the plaintext size; always emit at least 1 byte.
        plaintext_size = max(1, n * self.DEFAULT_PLAINTEXT_MULTIPLIER)

        # Fixed key and nonce sizes keep the scaling purely in the plaintext.
        key_size = self.DEFAULT_KEY_SIZE
        nonce_size = GCM_NONCE_SIZE

        logging.debug(
            "Generating AES-GCM problem with n=%s (plaintext_size=%s), "
            "key_size=%s, nonce_size=%s, random_seed=%s",
            n,
            plaintext_size,
            key_size,
            nonce_size,
            random_seed,
        )

        if key_size not in AES_KEY_SIZES:
            raise ValueError(f"Unsupported key size: {key_size}. Must be one of {AES_KEY_SIZES}.")

        # Seeded PRNG so random_seed actually controls the generated instance.
        rng = random.Random(random_seed)
        key = rng.randbytes(key_size)
        nonce = rng.randbytes(nonce_size)
        plaintext = rng.randbytes(plaintext_size)
        # Exercise both the AAD and no-AAD paths: include AAD for even n only.
        associated_data = rng.randbytes(32) if n % 2 == 0 else b""

        return {
            "key": key,
            "nonce": nonce,
            "plaintext": plaintext,
            "associated_data": associated_data,
        }

    def solve(self, problem: dict[str, Any]) -> dict[str, bytes]:
        """
        Encrypt the plaintext using AES-GCM from the `cryptography` library.

        Args:
            problem (dict): The problem dictionary generated by `generate_problem`.
                ``associated_data`` may be omitted, which is treated as no
                associated data.

        Returns:
            dict: A dictionary containing 'ciphertext' (without the tag) and
            'tag' (the trailing ``GCM_TAG_SIZE`` bytes of AESGCM's output).

        Raises:
            ValueError: If the key length is not a valid AES key size, or the
                encrypted output is shorter than the tag. Errors raised by
                ``AESGCM`` itself are logged and re-raised unchanged.
        """
        key = problem["key"]
        nonce = problem["nonce"]
        plaintext = problem["plaintext"]
        # A missing AAD entry means "no associated data"; AESGCM accepts None.
        associated_data = problem.get("associated_data")

        try:
            # Validate key size based on provided key length
            if len(key) not in AES_KEY_SIZES:
                raise ValueError(f"Invalid key size: {len(key)}. Must be one of {AES_KEY_SIZES}.")

            aesgcm = AESGCM(key)
            ciphertext = aesgcm.encrypt(nonce, plaintext, associated_data)

            # AESGCM.encrypt returns ciphertext || tag; split off the fixed-size tag.
            if len(ciphertext) < GCM_TAG_SIZE:
                raise ValueError("Encrypted output is shorter than the expected tag size.")

            actual_ciphertext = ciphertext[:-GCM_TAG_SIZE]
            tag = ciphertext[-GCM_TAG_SIZE:]

            return {"ciphertext": actual_ciphertext, "tag": tag}

        except Exception as e:
            logging.error("Error during AES-GCM encryption in solve: %s", e)
            raise  # Re-raise so callers see the original exception

    def is_solution(self, problem: dict[str, Any], solution: dict[str, bytes] | Any) -> bool:
        """
        Verify the provided solution by comparing its ciphertext and tag
        against the result obtained from calling the task's own solve() method.

        AES-GCM is deterministic for a fixed key, nonce, plaintext and AAD,
        so an exact byte comparison against the reference is sufficient.

        Args:
            problem (dict): The problem dictionary.
            solution (dict): The proposed solution dictionary with 'ciphertext' and 'tag'.

        Returns:
            bool: True if the solution matches the result from self.solve().
            False if the solution is malformed, its fields are not ``bytes``,
            or the reference solve() fails.
        """
        if not isinstance(solution, dict) or "ciphertext" not in solution or "tag" not in solution:
            logging.error(
                "Invalid solution format. Expected dict with 'ciphertext' and 'tag'. Got: %s",
                type(solution),
            )
            return False

        try:
            # Get the correct result by calling the solve method
            reference_result = self.solve(problem)
            reference_ciphertext = reference_result["ciphertext"]
            reference_tag = reference_result["tag"]
        except Exception as e:
            # If solve itself fails, we cannot verify the solution
            logging.error("Failed to generate reference solution in is_solution: %s", e)
            return False

        solution_ciphertext = solution["ciphertext"]
        solution_tag = solution["tag"]

        # Ensure types are bytes before comparison
        if not isinstance(solution_ciphertext, bytes) or not isinstance(solution_tag, bytes):
            logging.error("Solution 'ciphertext' or 'tag' is not bytes.")
            return False

        # Constant-time comparison for security
        ciphertext_match = hmac.compare_digest(reference_ciphertext, solution_ciphertext)
        tag_match = hmac.compare_digest(reference_tag, solution_tag)

        return ciphertext_match and tag_match
