#  Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.
"""Kubernetes step operator implementation."""

from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast

from kubernetes import client as k8s_client

from zenml.config.base_settings import BaseSettings
from zenml.config.build_configuration import BuildConfiguration
from zenml.enums import StackComponentType
from zenml.integrations.kubernetes.flavors import (
    KubernetesStepOperatorConfig,
    KubernetesStepOperatorSettings,
)
from zenml.integrations.kubernetes.orchestrators import kube_utils
from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
    build_pod_manifest,
)
from zenml.logger import get_logger
from zenml.stack import Stack, StackValidator
from zenml.step_operators import BaseStepOperator

if TYPE_CHECKING:
    from zenml.config.step_run_info import StepRunInfo
    from zenml.models import PipelineDeploymentBase

logger = get_logger(__name__)

# Key under which this step operator registers the Docker build for each of its
# steps (`get_docker_builds`) and later looks up the built image (`launch`).
KUBERNETES_STEP_OPERATOR_DOCKER_IMAGE_KEY = "kubernetes_step_operator"


class KubernetesStepOperator(BaseStepOperator):
    """Step operator that runs individual steps in Kubernetes pods.

    Every step configured to use this step operator gets its own Docker build
    (registered under `KUBERNETES_STEP_OPERATOR_DOCKER_IMAGE_KEY`). When the
    step is launched, it is executed in a dedicated pod created from that image.
    """

    # Cached Kubernetes API client. `get_kube_client` reuses it until the
    # linked service connector expires; in in-cluster mode it is rebuilt on
    # every call.
    _k8s_client: Optional[k8s_client.ApiClient] = None

    @property
    def config(self) -> KubernetesStepOperatorConfig:
        """Returns the `KubernetesStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(KubernetesStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Kubernetes step operator.

        Returns:
            The settings class.
        """
        return KubernetesStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that requires a container registry and an image
            builder in the stack. It also checks that both the container
            registry and the artifact store are remote.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            if stack.artifact_store.config.is_local:
                return False, (
                    "The Kubernetes step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the "
                    "Kubernetes step operator."
                )

            # The container registry is listed in `required_components`
            # below; the assert only narrows the Optional type.
            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The Kubernetes step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "Kubernetes step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBase"
    ) -> List["BuildConfiguration"]:
        """Gets the Docker builds required for the component.

        One build is requested for every step that is configured to use this
        step operator. Each build uses that step's own Docker settings.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds.
        """
        return [
            BuildConfiguration(
                key=KUBERNETES_STEP_OPERATOR_DOCKER_IMAGE_KEY,
                settings=step.config.docker_settings,
                step_name=step_name,
            )
            for step_name, step in deployment.step_configurations.items()
            if step.config.uses_step_operator(self.name)
        ]

    def get_kube_client(self) -> k8s_client.ApiClient:
        """Get the Kubernetes API client.

        The client is resolved in this order:

        1. In-cluster mode (`config.incluster`): the in-cluster config is
           loaded and a fresh client is created on every call.
        2. A linked service connector: its client is cached and reused
           until the connector expires.
        3. Otherwise: the local kubeconfig for `config.kubernetes_context`
           is loaded and the resulting client is cached.

        Returns:
            The Kubernetes API client.

        Raises:
            RuntimeError: If the service connector returns an unexpected client.
        """
        if self.config.incluster:
            kube_utils.load_kube_config(incluster=True)
            self._k8s_client = k8s_client.ApiClient()
            return self._k8s_client

        # Reuse the cached client unless the connector credentials expired.
        if self._k8s_client and not self.connector_has_expired():
            return self._k8s_client

        connector = self.get_connector()
        if connector:
            client = connector.connect()
            if not isinstance(client, k8s_client.ApiClient):
                raise RuntimeError(
                    f"Expected a k8s_client.ApiClient while trying to use the "
                    f"linked connector, but got {type(client)}."
                )
            self._k8s_client = client
        else:
            kube_utils.load_kube_config(
                context=self.config.kubernetes_context,
            )
            self._k8s_client = k8s_client.ApiClient()

        return self._k8s_client

    @property
    def _k8s_core_api(self) -> k8s_client.CoreV1Api:
        """Getter for the Kubernetes Core API client.

        Returns:
            The Kubernetes Core API client.
        """
        return k8s_client.CoreV1Api(self.get_kube_client())

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launches a step on Kubernetes.

        A pod named after the run and step is created in the configured
        namespace. The method waits for the pod to start and then blocks
        until it is done, streaming its logs in the meantime.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step. The first
                three tokens become the container `command`; the remaining
                tokens are passed as the container `args`.
            environment: Environment variables to set in the step operator
                environment.
        """
        settings = cast(
            KubernetesStepOperatorSettings, self.get_settings(info)
        )
        image_name = info.get_image(
            key=KUBERNETES_STEP_OPERATOR_DOCKER_IMAGE_KEY
        )

        pod_name = kube_utils.sanitize_pod_name(
            f"{info.run_name}_{info.pipeline_step_name}",
            namespace=self.config.kubernetes_namespace,
        )

        command = entrypoint_command[:3]
        args = entrypoint_command[3:]

        # Build the manifest for the pod that runs this step.
        pod_manifest = build_pod_manifest(
            pod_name=pod_name,
            image_name=image_name,
            command=command,
            args=args,
            privileged=settings.privileged,
            service_account_name=settings.service_account_name,
            pod_settings=settings.pod_settings,
            env=environment,
            mount_local_stores=False,
            labels={
                "run_id": kube_utils.sanitize_label(str(info.run_id)),
                "pipeline": kube_utils.sanitize_label(info.pipeline.name),
            },
        )

        kube_utils.create_and_wait_for_pod_to_start(
            core_api=self._k8s_core_api,
            pod_display_name=f"pod of step `{info.pipeline_step_name}`",
            pod_name=pod_name,
            pod_manifest=pod_manifest,
            namespace=self.config.kubernetes_namespace,
            startup_max_retries=settings.pod_failure_max_retries,
            startup_failure_delay=settings.pod_failure_retry_delay,
            startup_failure_backoff=settings.pod_failure_backoff,
            startup_timeout=settings.pod_startup_timeout,
        )

        logger.info(
            "Waiting for pod of step `%s` to finish...",
            info.pipeline_step_name,
        )
        # A client factory (not a client) is passed so that `wait_pod` can
        # obtain a fresh client while waiting.
        # NOTE(review): whether a pod that ends in a failed state raises here
        # depends on `kube_utils.wait_pod` / `pod_is_done`; confirm before
        # relying on this method to surface step failures.
        kube_utils.wait_pod(
            kube_client_fn=self.get_kube_client,
            pod_name=pod_name,
            namespace=self.config.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
        logger.info("Pod of step `%s` completed.", info.pipeline_step_name)
