"""Library defining the interface to the generative firewall."""
import atexit
import importlib
import time
from typing import Any, Dict, Iterable, List, Optional

from urllib3.util import Retry

from rime_sdk.authenticator import Authenticator
from rime_sdk.client import RETRY_HTTP_STATUS
from rime_sdk.internal.rest_error_handler import RESTErrorHandler
from rime_sdk.internal.utils import is_positive_number, remove_null_values_from_dict
from rime_sdk.swagger.swagger_client import (
    ApiClient,
    ApigenerativefirewallListFirewallInstancesResponse,
    Configuration,
    FirewallApi,
    FirewallinstanceFirewallInstanceFirewallInstanceIdUuidBody,
    FirewallInstanceIdUuidValidateBody,
    FirewallInstanceManagerApi,
    GenerativefirewallCreateFirewallInstanceRequest,
    GenerativefirewallCreateFirewallInstanceResponse,
    GenerativefirewallFirewallInstanceInfo,
    GenerativefirewallFirewallInstanceStatus,
    GenerativefirewallFirewallRuleConfig,
    GenerativefirewallGetFirewallEffectiveConfigResponse,
    GenerativefirewallGetFirewallInstanceResponse,
    RimeUUID,
)
from rime_sdk.swagger.swagger_client.models.generativefirewall_firewall_rule_type import (
    GenerativefirewallFirewallRuleType,
)
from rime_sdk.swagger.swagger_client.models.generativefirewall_individual_rules_config import (
    GenerativefirewallIndividualRulesConfig,
)
from rime_sdk.swagger.swagger_client.models.rime_language import RimeLanguage

# Default value (seconds) of FirewallClient's `channel_timeout`.
_DEFAULT_CHANNEL_TIMEOUT = 60.0

# Languages accepted for the optional `language` key of a rule config.
VALID_LANGUAGES = [RimeLanguage.EN, RimeLanguage.JA]

# Default `timeout_sec` of FirewallInstance.block_until_ready.
_DEFAULT_BLOCK_UNTIL_READY_TIMEOUT_SEC = 300.0
# Default block-until-ready timeout used by
# FirewallClient.create_firewall_instance (longer than the generic default).
_CREATE_FIREWALL_INSTANCE_DEFAULT_BLOCK_UNTIL_READY_TIMEOUT_SEC = 900.0


def _get_firewall_instance_info(
    instance_manager_client: FirewallInstanceManagerApi, firewall_instance_id: str
) -> GenerativefirewallFirewallInstanceInfo:
    """Fetch the info record of a single FirewallInstance.

    Args:
        instance_manager_client: FirewallInstanceManagerApi
            Client used to issue the `get_instance` call.
        firewall_instance_id: str
            UUID string of the FirewallInstance to look up.

    Returns:
        The `firewall_instance` field of the get-instance response.
    """
    with RESTErrorHandler(is_generative_firewall=True):
        response: GenerativefirewallGetFirewallInstanceResponse = (
            instance_manager_client.get_instance(firewall_instance_id)
        )
        info: GenerativefirewallFirewallInstanceInfo = response.firewall_instance
    return info


def _validate_individual_rules_config(individual_rules_config: Any) -> None:
    """Validate the raw individual rules config provided by the user.

    Raises an exception if the configuration is malformed.
    """
    if not isinstance(individual_rules_config, dict):
        raise TypeError("individual rules config must be a dict")
    swagger_module = importlib.import_module("rime_sdk.swagger.swagger_client")
    for k, v in individual_rules_config.items():
        if k not in GenerativefirewallIndividualRulesConfig.swagger_types:
            raise ValueError(
                f"Found unrecognized rule '{k}' in `rule_config.individual_rules_config`. "
                f"The list of accepted individual rule configs is: {list(GenerativefirewallIndividualRulesConfig.swagger_types.keys())}"
            )
        # Dynamically load the Python autogenerated Swagger class from the
        # top-level Swagger module, which makes all models accessible.
        # This is only janky because each `swagger_types` stores the string
        # name of the class but not a reference to the class itself.
        desired_class = getattr(
            swagger_module,
            GenerativefirewallIndividualRulesConfig.swagger_types[k],
        )
        swagger_types: dict = desired_class.swagger_types  # type: ignore[attr-defined]
        if not isinstance(v, dict):
            raise TypeError(
                "Each individual rule config must be a dictionary. "
                f"Value '{v}' for rule config '{k}' is not a dictionary"
            )
        unrecognized_keys = set(v.keys()) - set(swagger_types.keys())
        if len(unrecognized_keys) > 0:
            keys_str = ", ".join(unrecognized_keys)
            raise ValueError(
                f"Provided individual rule config for rule '{k}' has unrecognized keys '{keys_str}'. "
                f"Desired structure for rule '{k}': {swagger_types}"
            )


def _validate_selected_rules(selected_rules: Any) -> None:
    """Validate the selected rules in a user's rule config."""
    if not isinstance(selected_rules, list):
        raise TypeError("selected_rules must be a list")

    valid_enum_values = [
        getattr(GenerativefirewallFirewallRuleType, attr)
        for attr in dir(GenerativefirewallFirewallRuleType)
        if attr.isupper() and not attr == "UNSPECIFIED"
    ]

    diff = set(selected_rules) - set(valid_enum_values)
    if len(diff) > 0:
        diff_str = ", ".join(diff)
        valid_enum_values_str = ", ".join(valid_enum_values)
        raise ValueError(
            f"Unrecognized firewall rule type enum values '{diff_str}'. "
            f"Valid firewall rule types are '{valid_enum_values_str}'."
        )


def _get_validated_firewall_config(
    config: dict,
) -> GenerativefirewallFirewallRuleConfig:
    """Validate a user-provided rule config and convert it to a Swagger model.

    Recognized keys are ``language``, ``individual_rules_config`` and
    ``selected_rules``; each is optional. The input dict is not modified.

    Args:
        config: dict
            The raw rule config provided by the user.

    Returns:
        GenerativefirewallFirewallRuleConfig built from the validated values;
        keys absent from `config` are passed as None.

    Raises:
        TypeError
            If `config` is not a dict, or a nested value has the wrong type.
        ValueError
            If the language, a rule, or a top-level key is not recognized.
    """
    # Fail fast with a clear message instead of an AttributeError from
    # `.copy()` (or a confusing `list.pop` TypeError) on non-dict input.
    if not isinstance(config, dict):
        raise TypeError(f"rule config must be a dict, got {type(config).__name__}")
    # Pop recognized keys off a shallow copy as they are validated, so that
    # anything left over is an unexpected key.
    _config = config.copy()
    language = _config.pop("language", None)
    if language is not None and language not in VALID_LANGUAGES:
        raise ValueError(
            f"Provided language {language} is invalid, please choose one of the "
            f"following values {VALID_LANGUAGES}"
        )

    individual_rules_config = _config.pop("individual_rules_config", None)
    if individual_rules_config is not None:
        _validate_individual_rules_config(individual_rules_config)

    selected_rules = _config.pop("selected_rules", None)
    if selected_rules is not None:
        _validate_selected_rules(selected_rules)

    if _config:
        raise ValueError(
            f"Found unexpected keys in `rule_config`: {list(_config.keys())}"
        )
    return GenerativefirewallFirewallRuleConfig(
        individual_rules_config=individual_rules_config,
        selected_rules=selected_rules,
        language=language,
    )


def _validate_block_until_ready_params(
    block_until_ready: bool = True,
    block_until_ready_timeout_sec: Optional[float] = None,
    block_until_ready_poll_rate_sec: Optional[float] = None,
    block_until_consecutive_ready_count: Optional[int] = None,
) -> None:
    if block_until_ready:
        if block_until_ready_timeout_sec is not None and not is_positive_number(
            block_until_ready_timeout_sec
        ):
            raise ValueError(
                "Invalid value for 'block_until_ready_timeout_sec'"
                f": {block_until_ready_timeout_sec}. The value must be a positive number."
            )
        if block_until_ready_poll_rate_sec is not None and not is_positive_number(
            block_until_ready_poll_rate_sec
        ):
            raise ValueError(
                "Invalid value for 'block_until_ready_poll_rate_sec'"
                f": {block_until_ready_poll_rate_sec}. The value must be a positive number."
            )
        if block_until_consecutive_ready_count is not None and not is_positive_number(
            block_until_consecutive_ready_count
        ):
            raise ValueError(
                "Invalid value for 'block_until_consecutive_ready_count'"
                f": {block_until_consecutive_ready_count}. The value must be a positive number."
            )


class FirewallInstance:
    """An interface to a single instance of the firewall running on a cluster.

    Each FirewallInstance has its own rule configuration and can be accessed
    by its unique ID.
    This allows users to customize the behavior of the firewall for different
    use cases.
    Note: FirewallInstance should not be instantiated directly, but instead
    instantiated through methods of the FirewallClient.

    Args:
        firewall_instance_id: str
            The unique ID of the FirewallInstance.
        api_client: ApiClient
            API client for interacting with the cluster.
    """

    def __init__(
        self,
        firewall_instance_id: str,
        api_client: ApiClient,
    ) -> None:
        """Initialize a new FirewallInstance."""
        self._firewall_instance_id = firewall_instance_id

        self._api_client = api_client
        # Used for the `validate` and `effective_config` endpoints.
        self._firewall_client = FirewallApi(self._api_client)
        # Used for the `get_instance` and `update_instance` endpoints.
        self._instance_manager_client = FirewallInstanceManagerApi(self._api_client)

    def validate(
        self,
        user_input_text: Optional[str] = None,
        output_text: Optional[str] = None,
        contexts: Optional[List[str]] = None,
    ) -> dict:
        """Validate model input and/or output text.

        Args:
            user_input_text: Optional[str] = None
                The user input text to validate.
            output_text: Optional[str] = None
                The model output text to validate.
            contexts: Optional[List[str]] = None
                Context documents to validate.

        Returns:
            dict:
                The validation response converted to a dictionary, with null
                values removed by `remove_null_values_from_dict`.

        Raises:
            ValueError
                If none of `user_input_text`, `output_text` or `contexts` is given.
        """
        if user_input_text is None and output_text is None and contexts is None:
            raise ValueError(
                "Must provide either input text, output text, or context documents to validate."
            )

        body = FirewallInstanceIdUuidValidateBody(
            user_input_text=user_input_text,
            contexts=contexts,
            output_text=output_text,
            firewall_instance_id=RimeUUID(uuid=self.firewall_instance_id),
        )
        with RESTErrorHandler(is_generative_firewall=True):
            response = self._firewall_client.validate(
                body=body, firewall_instance_id_uuid=self.firewall_instance_id
            )
        res_dict = response.to_dict()
        remove_null_values_from_dict(res_dict)
        return res_dict

    def get_effective_config(self) -> dict:
        """Get the effective configuration for the FirewallInstance.

        This effective configuration has default values filled in and shows what
        is actually being used at runtime.

        Returns:
            dict:
                The effective config as a dictionary, or an empty dict when the
                response carries no config.
        """
        with RESTErrorHandler(is_generative_firewall=True):
            res: GenerativefirewallGetFirewallEffectiveConfigResponse = (
                self._firewall_client.effective_config(
                    firewall_instance_id_uuid=self.firewall_instance_id
                )
            )
        if res.config is None:
            return {}
        return res.config.to_dict()

    def block_until_ready(
        self,
        verbose: bool = True,
        timeout_sec: float = _DEFAULT_BLOCK_UNTIL_READY_TIMEOUT_SEC,
        poll_rate_sec: float = 5.0,
        consecutive_ready_count: int = 1,
    ) -> None:
        """Block until the FirewallInstance is ready.

        Polls `status` every `poll_rate_sec` seconds until it has reported READY
        for `consecutive_ready_count` consecutive polls, or until `timeout_sec`
        seconds have elapsed.

        Args:
            verbose: bool = True
                Whether to print the status and elapsed time while waiting.
            timeout_sec: float
                Maximum number of seconds to wait.
            poll_rate_sec: float = 5.0
                Seconds to sleep between status polls.
            consecutive_ready_count: int = 1
                Number of consecutive READY polls required before returning.

        Raises:
            ValueError
                If any of the numeric arguments is not a positive number.
            TimeoutError
                This error is raised if the FirewallInstance is not ready by the deadline
                set through `timeout_sec`.
        """
        if not is_positive_number(timeout_sec):
            raise ValueError(
                f"Invalid value for 'timeout_sec': {timeout_sec}. "
                f"The value must be a positive number."
            )
        if not is_positive_number(poll_rate_sec):
            raise ValueError(
                f"Invalid value for 'poll_rate_sec': {poll_rate_sec}. "
                f"The value must be a positive number."
            )
        if not is_positive_number(consecutive_ready_count):
            raise ValueError(
                f"Invalid value for 'consecutive_ready_count': {consecutive_ready_count}. "
                f"The value must be a positive number."
            )
        start_time = time.time()
        if verbose:
            print(
                "Waiting until FirewallInstance {} is ready with timeout {}s".format(
                    self.firewall_instance_id,
                    timeout_sec,
                )
            )
        ready_count = 0
        while True:
            status = self.status

            # Get the total time the caller has been waiting and print it out.
            elapsed_time = time.time() - start_time
            if verbose:
                minute, second = divmod(elapsed_time, 60)
                print(
                    "\rStatus: {}, Time Elapsed: {:02}:{:05.2f}".format(
                        status,
                        int(minute),
                        second,
                    ),
                    end="",
                )

            if status == GenerativefirewallFirewallInstanceStatus.READY:
                ready_count += 1
            else:
                ready_count = 0
            # Either the firewall reaches the ready status or we are ready to
            # time out.
            if ready_count >= consecutive_ready_count or elapsed_time >= timeout_sec:
                break

            time.sleep(poll_rate_sec)

        if verbose:
            # Print an extra line because the status has a carriage return.
            print()

        # Judge success by the consecutive count, not just the last status: the
        # loop can also exit on timeout right after a single READY poll, which
        # does not satisfy a `consecutive_ready_count` greater than 1.
        if ready_count < consecutive_ready_count:
            raise TimeoutError(
                f"FirewallInstance did not reach status READY by the timeout {timeout_sec}s"
            )

    def update_firewall_instance(
        self,
        config: Optional[dict] = None,
        description: Optional[str] = None,
        block_until_ready: bool = True,
        block_until_ready_verbose: Optional[bool] = None,
        block_until_ready_timeout_sec: Optional[float] = None,
        block_until_ready_poll_rate_sec: Optional[float] = None,
        block_until_consecutive_ready_count: Optional[int] = 2,
    ) -> None:
        """Update the config or description of the FirewallInstance.

        Args:
            config: Optional[dict] = None
                New rule config for the FirewallInstance. It is validated with
                the same rules as the config passed at creation time.

            description: Optional[str] = None
                New description for the FirewallInstance.

            block_until_ready: bool = True
                Whether to block until the FirewallInstance is ready.

            block_until_ready_verbose: Optional[bool] = None
                Whether to print out information while waiting for the FirewallInstance to come up.

            block_until_ready_timeout_sec: Optional[float] = None
                How many seconds to wait until the FirewallInstance comes up before timing out.

            block_until_ready_poll_rate_sec: Optional[float] = None
                How often to poll the FirewallInstance status.

            block_until_consecutive_ready_count: Optional[int] = 2
                Number of consecutive READY status polls to wait for before returning.

        Raises:
            ValueError
                If neither `config` nor `description` is provided, or an argument
                is invalid.
            TypeError
                If `config` (or a nested value) has the wrong type.
            TimeoutError
                If blocking and the FirewallInstance is not ready in time.
        """
        if config is None and description is None:
            raise ValueError("A new `config` or `description` must be provided.")
        _validate_block_until_ready_params(
            block_until_ready,
            block_until_ready_timeout_sec,
            block_until_ready_poll_rate_sec,
            block_until_consecutive_ready_count,
        )
        final_conf = (
            _get_validated_firewall_config(config) if config is not None else None
        )
        with RESTErrorHandler(is_generative_firewall=True):
            body = FirewallinstanceFirewallInstanceFirewallInstanceIdUuidBody(
                config=final_conf,
                description=description,
            )
            self._instance_manager_client.update_instance(
                body,
                self.firewall_instance_id,
            )

        if block_until_ready:
            # Forward only the options the caller actually set, so that
            # `block_until_ready` keeps its own defaults for the rest. Compare
            # against None (not truthiness) so that an explicit
            # `block_until_ready_verbose=False` is honoured.
            requested: Dict[str, Any] = {
                "verbose": block_until_ready_verbose,
                "timeout_sec": block_until_ready_timeout_sec,
                "poll_rate_sec": block_until_ready_poll_rate_sec,
                "consecutive_ready_count": block_until_consecutive_ready_count,
            }
            self.block_until_ready(
                **{name: value for name, value in requested.items() if value is not None}
            )

    @property
    def rule_config(self) -> dict:
        """Access the rule config of the FirewallInstance.

        Each access fetches the instance from the cluster, so the result
        reflects any change made through `update_firewall_instance`. Returns an
        empty dict when the instance carries no config.
        """
        fw_instance = _get_firewall_instance_info(
            self._instance_manager_client,
            self.firewall_instance_id,
        )
        conf = fw_instance.config
        return conf.to_dict() if conf is not None else {}

    @property
    def firewall_instance_id(self) -> str:
        """Access the UUID of the FirewallInstance."""
        return self._firewall_instance_id

    @property
    def status(self) -> str:
        """Access the current deployment status of the FirewallInstance.

        Each access fetches the instance from the cluster.
        """
        fw_instance = _get_firewall_instance_info(
            self._instance_manager_client,
            self.firewall_instance_id,
        )
        return fw_instance.deployment_status

    @property
    def description(self) -> str:
        """Access the description of the FirewallInstance ("" if unset).

        Each access fetches the instance from the cluster.
        """
        fw_instance = _get_firewall_instance_info(
            self._instance_manager_client,
            self.firewall_instance_id,
        )
        return fw_instance.description if fw_instance.description else ""

    def __str__(self) -> str:
        """String representation of the FirewallInstance (fetches the description)."""
        return f'FirewallInstance(id="{self.firewall_instance_id}", description="{self.description}")'


class FirewallClient:
    """An interface to connect to FirewallInstances on a firewall cluster.

    Create a firewall instance by specifying the rule configuration.
    It will take anywhere from a few seconds to a few minutes to spin up, but
    once it is ready, it can respond to validation requests with the custom
    configuration.
    A single firewall cluster can have many firewall instances.
    They are independent from each other.

    Args:
        domain: str
            The base domain/address of the firewall. Trailing slashes are
            removed, and ``https://`` is prepended when no scheme is given.
        auth_token: str
            The auth token is generated in the Firewall UI and is used to authenticate
            to the firewall. Auth tokens are only available when the Firewall UI has been
            enabled. It may be left empty and set later through `login`.
        channel_timeout: float
            The amount of time in seconds to wait for responses from the firewall.
    """

    def __init__(
        self,
        domain: str,
        auth_token: str = "",
        channel_timeout: float = _DEFAULT_CHANNEL_TIMEOUT,
    ):
        """Create a new Client connected to the services available at `domain`."""
        configuration = Configuration()
        configuration.api_key["X-Firewall-Auth-Token"] = auth_token
        domain = domain.rstrip("/")
        if not domain.startswith(("https://", "http://")):
            domain = "https://" + domain
        configuration.host = domain
        self._api_client = ApiClient(configuration)
        # Prevent race condition in pool.close() triggered by swagger generated code
        atexit.register(self._api_client.pool.close)
        # Sets the timeout and hardcoded retries parameter for the api client.
        self._api_client.rest_client.pool_manager.connection_pool_kw[
            "timeout"
        ] = channel_timeout
        # Configure the retries against the firewall cluster.
        # Due to NGINX refresh issues breaking connections, we want to allow
        # retries on `read` errors on all idempotent HTTP methods.
        # The current implementation of PATCH with gRPC-gateway proto field
        # masks is idempotent, so it can be safely retried.
        # Note: the Retry will only retry on read errors if the HTTP method
        # belongs to one of the configured `allowed_methods`.
        self._api_client.rest_client.pool_manager.connection_pool_kw["retries"] = Retry(
            read=3,
            status=3,
            status_forcelist=RETRY_HTTP_STATUS,
            allowed_methods=Retry.DEFAULT_ALLOWED_METHODS.union(["PATCH"]),
            backoff_factor=0.5,
        )
        self._instance_manager_client = FirewallInstanceManagerApi(self._api_client)

    def login(self, email: str, system_account: bool = False) -> None:
        """Login to obtain an auth token.

        The token is read from ``./token.txt`` after authentication and used
        as the ``X-Firewall-Auth-Token`` for subsequent requests.

        Args:
            email: str
                The user's email address that is used to authenticate.

            system_account: bool
                This flag specifies whether it is for a system account token or not.

        Example:
             .. code-block:: python

                firewall.login("dev@robustintelligence.com", True)
        """
        authenticator = Authenticator()
        authenticator.auth(self._api_client.configuration.host, email, system_account)
        # NOTE(review): assumes Authenticator.auth writes the token to
        # ./token.txt in the current working directory — confirm.
        # Read-only mode: the file is never written here, so write permission
        # should not be required.
        with open("./token.txt", "r", encoding="utf-8") as token_file:
            # Strip surrounding whitespace such as a trailing newline, which is
            # not permitted inside an HTTP header value.
            self._api_client.configuration.api_key[
                "X-Firewall-Auth-Token"
            ] = token_file.read().strip()

    def list_firewall_instances(self) -> Iterable[FirewallInstance]:
        """List the FirewallInstances for the given cluster.

        This is a generator: the cluster is queried when iteration starts,
        not when the method is called.

        Yields:
            FirewallInstance
                One handle per instance reported by the cluster.
        """
        with RESTErrorHandler(is_generative_firewall=True):
            res: ApigenerativefirewallListFirewallInstancesResponse = (
                self._instance_manager_client.list_instances()
            )
        # The generated model may leave this field as None (e.g. if the server
        # omits an empty list); treat that as "no instances" instead of failing.
        # Yielding outside the error handler keeps consumer-side events such as
        # GeneratorExit from passing through it.
        firewall_instances: List[GenerativefirewallFirewallInstanceInfo] = (
            res.firewall_instances or []
        )
        for fwinfo in firewall_instances:
            yield FirewallInstance(fwinfo.firewall_instance_id.uuid, self._api_client)

    def create_firewall_instance(
        self,
        rule_config: dict,
        description: str = "",
        block_until_ready: bool = True,
        block_until_ready_verbose: Optional[bool] = None,
        block_until_ready_timeout_sec: Optional[
            float
        ] = _CREATE_FIREWALL_INSTANCE_DEFAULT_BLOCK_UNTIL_READY_TIMEOUT_SEC,
        block_until_ready_poll_rate_sec: Optional[float] = None,
    ) -> FirewallInstance:
        """Create a FirewallInstance with the specified rule configuration.

        By default this method blocks until the FirewallInstance is ready.

        Args:
            rule_config: dict
                Dictionary containing the rule config to customize the behavior
                of the FirewallInstance.

            description: str = ""
                Human-readable description of the FirewallInstance.

            block_until_ready: bool = True
                Whether to block until the FirewallInstance is ready.

            block_until_ready_verbose: Optional[bool] = None
                Whether to print out information while waiting for the FirewallInstance to come up.

            block_until_ready_timeout_sec: Optional[float] = 900.0
                How many seconds to wait until the FirewallInstance comes up before timing out.

            block_until_ready_poll_rate_sec: Optional[float] = None
                How often to poll the FirewallInstance status.

        Returns:
            FirewallInstance that is ready to accept validation requests
            (when `block_until_ready` is True).

        Raises:
            ValueError
                If `rule_config` or a block-until-ready parameter is invalid.
            TypeError
                If `rule_config` (or a nested value) has the wrong type.
            TimeoutError
                This error is generated if the FirewallInstance is not ready by the deadline
                set through `block_until_ready_timeout_sec`.
        """
        _validate_block_until_ready_params(
            block_until_ready,
            block_until_ready_timeout_sec,
            block_until_ready_poll_rate_sec,
        )
        final_conf = _get_validated_firewall_config(rule_config)
        req = GenerativefirewallCreateFirewallInstanceRequest(
            config=final_conf, description=description
        )
        with RESTErrorHandler(is_generative_firewall=True):
            res: GenerativefirewallCreateFirewallInstanceResponse = (
                self._instance_manager_client.create_instance(req)
            )
        fw = FirewallInstance(res.firewall_instance_id.uuid, self._api_client)

        if block_until_ready:
            # Forward only the options the caller actually set, so that
            # `block_until_ready` keeps its own defaults for the rest. Compare
            # against None (not truthiness) so that an explicit
            # `block_until_ready_verbose=False` is honoured.
            requested: Dict[str, Any] = {
                "verbose": block_until_ready_verbose,
                "timeout_sec": block_until_ready_timeout_sec,
                "poll_rate_sec": block_until_ready_poll_rate_sec,
            }
            fw.block_until_ready(
                **{name: value for name, value in requested.items() if value is not None}
            )

        return fw

    def get_firewall_instance(self, firewall_instance_id: str) -> FirewallInstance:
        """Get a FirewallInstance from the cluster.

        The instance is fetched once, so an unknown ID fails here rather than on
        first use.

        Args:
            firewall_instance_id: str
                The UUID string of the FirewallInstance to retrieve.

        Returns:
            FirewallInstance:
                A firewall instance on which to perform validation.

        """
        with RESTErrorHandler(is_generative_firewall=True):
            fwinfo = _get_firewall_instance_info(
                self._instance_manager_client, firewall_instance_id
            )
        return FirewallInstance(
            firewall_instance_id=fwinfo.firewall_instance_id.uuid,
            api_client=self._api_client,
        )

    def delete_firewall_instance(self, firewall_instance_id: str) -> None:
        """Delete a FirewallInstance from the cluster.

        Careful when deleting a FirewallInstance: in-flight validation traffic
        will be interrupted.

        Args:
            firewall_instance_id: str
                The UUID string of the FirewallInstance to hard delete.
        """
        with RESTErrorHandler(is_generative_firewall=True):
            self._instance_manager_client.delete_instance(firewall_instance_id)
