from __future__ import annotations

import logging
import time
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast

from typing_extensions import NotRequired

from dagster_ray.kuberay.client.base import BaseKubeRayClient, load_kubeconfig
from dagster_ray.kuberay.client.raycluster import RayClusterClient, RayClusterStatus
from dagster_ray.kuberay.client.raycluster.client import RayClusterEndpoints

if TYPE_CHECKING:
    from kubernetes.client import ApiClient

# Coordinates of the KubeRay ``RayJob`` custom resource, handed to ``BaseKubeRayClient`` by ``RayJobClient``.
GROUP, VERSION = "ray.io", "v1"
KIND, PLURAL = "RayJob", "rayjobs"

logger = logging.getLogger(__name__)


class RayJobStatus(TypedDict):
    """Typed view of a KubeRay ``RayJob`` status, as consumed by ``RayJobClient``.

    Keys keep the camelCase spelling used by the Kubernetes object. Keys declared ``NotRequired`` may be
    absent; for example ``jobId`` may be missing for ``InteractiveMode`` jobs.
    """

    # --- required keys ---
    jobDeploymentStatus: str
    rayClusterName: str
    rayClusterStatus: RayClusterStatus
    startTime: str

    # --- optional keys ---
    jobId: NotRequired[str]
    rayJobInfo: NotRequired[dict[str, Any]]
    jobStatus: NotRequired[Literal["PENDING", "RUNNING", "SUCCEEDED", "FAILED", "STOPPED"]]
    dashboardURL: NotRequired[str]
    endTime: NotRequired[str]
    message: NotRequired[str]
    failed: NotRequired[int]
    succeeded: NotRequired[int]


class RayJobClient(BaseKubeRayClient[RayJobStatus]):
    """Client for KubeRay ``RayJob`` custom resources.

    Builds on the generic helpers of ``BaseKubeRayClient`` (e.g. ``get_status`` and ``delete``) and adds
    RayJob-specific operations: resolving the backing ``RayCluster``, waiting for the job to start, and
    fetching, tailing, or stopping the job through the cluster's Ray job submission client.
    """

    def __init__(
        self,
        kube_config: str | None = None,
        kube_context: str | None = None,
        api_client: ApiClient | None = None,
    ) -> None:
        """Create a RayJob client.

        Parameters:
            kube_config (str | None): Path to a kubeconfig file. Used to load the kubeconfig when ``api_client``
                is not given, and forwarded to the ``RayClusterClient`` built by ``ray_cluster_client``.
            kube_context (str | None): kubeconfig context. Used the same way as ``kube_config``.
            api_client (ApiClient | None): A pre-configured Kubernetes API client. When provided, no kubeconfig
                is loaded by this constructor.
        """
        self.kube_config = kube_config
        self.kube_context = kube_context

        # this call must happen BEFORE creating K8s apis
        if api_client is None:
            load_kubeconfig(config_file=kube_config, context=kube_context)

        super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

    def get_ray_cluster_name(self, name: str, namespace: str, timeout: float, poll_interval: float = 1.0) -> str:
        """Return the name of the ``RayCluster`` backing the RayJob, as reported in its status."""
        return self.get_status(name, namespace, timeout=timeout, poll_interval=poll_interval)["rayClusterName"]

    def get_job_submission_id(
        self, name: str, namespace: str, timeout: float, poll_interval: float = 1.0
    ) -> str | None:
        """Returns the ray job submission ID. It may be missing for mode: InteractiveMode."""
        return self.get_status(name, namespace, timeout=timeout, poll_interval=poll_interval).get("jobId")

    @property
    def ray_cluster_client(self) -> RayClusterClient:
        """A ``RayClusterClient`` using this client's kubeconfig path and context.

        A new instance is created on every access.
        """
        # NOTE(review): the ``api_client`` given to ``__init__`` is not forwarded here, so the RayClusterClient
        # presumably loads kubeconfig on its own even when this client was built from an explicit api_client.
        # Confirm against RayClusterClient's constructor.
        return RayClusterClient(kube_config=self.kube_config, kube_context=self.kube_context)

    def wait_until_ready(
        self,
        name: str,
        namespace: str,
        timeout: float = 600,
        failure_tolerance_timeout: float = 0.0,
        poll_interval: float = 1.0,
        log_cluster_conditions: bool = False,
    ) -> tuple[str, RayClusterEndpoints]:
        """Wait until the RayCluster attached to the RayJob is ready.

        This doesn't necessarily mean that the cluster has already taken a job, just that it is ready to accept connections.

        Parameters:
            name (str): The name of the `RayJob` resource
            namespace (str): The namespace of the `RayJob` resource
            timeout (float): The timeout in seconds to wait for the cluster to become ready.
            failure_tolerance_timeout (float): The period in seconds to wait for the cluster to transition out of `failed` state if it reaches it. This state can be transient under certain conditions. With the default value of 0, the first `failed` state appearance will raise an exception immediately.
            poll_interval (float): The interval in seconds to poll the cluster status.
            log_cluster_conditions (bool): Whether to log cluster conditions. See [KubeRay docs](https://docs.ray.io/en/latest/cluster/kubernetes/user-guides/observability.html#raycluster-status-conditions)

        Returns:
            tuple[str, RayClusterEndpoints]: The service ip address and a dictionary of ports.
        """
        ray_cluster_name = self.get_ray_cluster_name(name, namespace, timeout=timeout, poll_interval=poll_interval)
        ray_cluster_client = self.ray_cluster_client
        ray_cluster_client.wait_until_exists(
            name=ray_cluster_name, namespace=namespace, timeout=timeout, poll_interval=poll_interval
        )
        return ray_cluster_client.wait_until_ready(
            ray_cluster_name,
            namespace=namespace,
            timeout=timeout,
            failure_tolerance_timeout=failure_tolerance_timeout,
            poll_interval=poll_interval,
            log_cluster_conditions=log_cluster_conditions,
        )

    def wait_until_running(
        self,
        name: str,
        namespace: str,
        timeout: float = 600,
        poll_interval: float = 1.0,
        terminate_on_timeout: bool = True,
        port_forward: bool = False,
        log_cluster_conditions: bool = False,
    ) -> bool:
        """Block until the RayJob is deployed and its Ray job reports a ``jobStatus``.

        Waits for ``jobDeploymentStatus`` to become ``Running`` or ``Complete``, then for ``jobStatus`` to be set.
        Both phases share a single ``timeout`` budget measured from the start of the call.

        Parameters:
            name (str): The name of the `RayJob` resource.
            namespace (str): The namespace of the `RayJob` resource.
            timeout (float): Total time in seconds to wait.
            poll_interval (float): Seconds between status polls.
            terminate_on_timeout (bool): If the deployment phase times out, stop the Ray job (falling back to
                deleting the RayJob if stopping fails) before raising.
            port_forward (bool): Passed to ``terminate`` when stopping the job on timeout.
            log_cluster_conditions (bool): Currently unused by this method.

        Returns:
            bool: Always ``True``; failures are reported via exceptions.

        Raises:
            RuntimeError: If the RayJob deployment status becomes ``Failed``.
            TimeoutError: If either phase does not finish within ``timeout``.
        """
        start_time = time.time()

        # Phase 1: wait for the RayJob deployment to be running (or already finished).
        while True:
            ray_job_status = self.get_status(name, namespace, timeout=timeout, poll_interval=poll_interval)
            deployment_status = ray_job_status.get("jobDeploymentStatus")

            if deployment_status in ("Running", "Complete"):
                break
            if deployment_status == "Failed":
                raise RuntimeError(f"RayJob {namespace}/{name} deployment failed. Status:\n{ray_job_status}")

            if time.time() - start_time > timeout:
                if terminate_on_timeout:
                    logger.warning(f"Terminating RayJob {namespace}/{name} because of timeout {timeout}s")
                    try:
                        self.terminate(name, namespace, port_forward=port_forward)
                    except Exception as e:
                        logger.warning(
                            f"Failed to gracefully terminate RayJob {namespace}/{name}: {e}, will delete it instead."
                        )
                        self.delete(name, namespace)

                raise TimeoutError(
                    f"Timed out waiting for RayJob {namespace}/{name} to start. Status:\n{ray_job_status}"
                )

            time.sleep(poll_interval)

        # Phase 2: wait until the Ray job itself reports a status.
        while True:
            ray_job_status = self.get_status(name, namespace, timeout=timeout, poll_interval=poll_interval)

            if ray_job_status.get("jobStatus"):
                break

            if time.time() - start_time > timeout:
                raise TimeoutError(
                    f"Timed out waiting for RayJob {namespace}/{name} to start. Status:\n{ray_job_status}"
                )

            time.sleep(poll_interval)

        return True

    def _wait_for_job_submission(
        self,
        name: str,
        namespace: str,
        timeout: float = 600,
        poll_interval: float = 1.0,
    ) -> None:
        """Block until the Ray job has left ``PENDING`` or the RayJob deployment is ``Complete``/``Failed``.

        Raises:
            TimeoutError: If neither condition holds within ``timeout`` seconds.
        """
        start_time = time.time()

        while True:
            status = self.get_status(name, namespace, timeout=timeout, poll_interval=poll_interval)
            deployment_status = status.get("jobDeploymentStatus")
            if deployment_status in ("Complete", "Failed"):
                return

            job_status = status.get("jobStatus")
            if job_status is not None and job_status != "PENDING":
                return

            if time.time() - start_time > timeout:
                raise TimeoutError(f"Timed out waiting for job {name} to start")

            logger.debug(
                f"RayJob {namespace}/{name} deployment status is {deployment_status}, job status is {job_status}, "
                "waiting for it to start..."
            )

            time.sleep(poll_interval)

    def get_job_logs(
        self,
        name: str,
        namespace: str,
        timeout: float = 60 * 60,
        poll_interval: float = 1.0,
        port_forward: bool = False,
    ) -> str:
        """Return the Ray job's logs collected so far, after waiting for the job to be submitted.

        Parameters:
            name (str): The name of the `RayJob` resource.
            namespace (str): The namespace of the `RayJob` resource.
            timeout (float): Seconds to wait for submission and for status lookups.
            poll_interval (float): Seconds between status polls.
            port_forward (bool): Passed to the RayCluster job submission client.
        """
        self._wait_for_job_submission(name, namespace, timeout=timeout, poll_interval=poll_interval)
        with self.ray_cluster_client.job_submission_client(
            name=self.get_ray_cluster_name(name, namespace, timeout=timeout, poll_interval=poll_interval),
            namespace=namespace,
            port_forward=port_forward,
        ) as job_submission_client:
            return job_submission_client.get_job_logs(
                job_id=cast(
                    str, self.get_job_submission_id(name, namespace, timeout=timeout, poll_interval=poll_interval)
                )
            )

    def tail_job_logs(
        self,
        name: str,
        namespace: str,
        timeout: float = 60 * 60,
        poll_interval: float = 1.0,
        port_forward: bool = False,
    ) -> Iterator[str]:
        """Stream the Ray job's logs as a synchronous iterator.

        This is a generator, so no work (including waiting for submission) happens until the first item is
        requested. The job submission client stays open while the iterator is being consumed.

        Parameters:
            name (str): The name of the `RayJob` resource.
            namespace (str): The namespace of the `RayJob` resource.
            timeout (float): Seconds to wait for submission and for status lookups.
            poll_interval (float): Seconds between status polls.
            port_forward (bool): Passed to the RayCluster job submission client.
        """
        import asyncio

        self._wait_for_job_submission(name, namespace, timeout=timeout, poll_interval=poll_interval)
        with self.ray_cluster_client.job_submission_client(
            name=self.get_ray_cluster_name(name, namespace, timeout=timeout, poll_interval=poll_interval),
            namespace=namespace,
            port_forward=port_forward,
        ) as job_submission_client:
            async_tailer = job_submission_client.tail_job_logs(
                job_id=cast(
                    str, self.get_job_submission_id(name, namespace, timeout=timeout, poll_interval=poll_interval)
                )
            )

            # Backward compatible sync generator: drive the async iterator on a private event loop.
            # ``asyncio.get_event_loop()`` is deprecated when no loop is set, raises in non-main threads,
            # and raises on Python 3.14+, so a dedicated loop is created here and closed when tailing ends.
            loop = asyncio.new_event_loop()
            try:
                while True:
                    try:
                        line = loop.run_until_complete(async_tailer.__anext__())  # type: ignore
                    except StopAsyncIteration:
                        break
                    yield line
            finally:
                # Finalize the async iterator on the loop it ran on (e.g. when the consumer stops early),
                # then release the loop. Cleanup failures are logged so they don't mask the original outcome.
                aclose = getattr(async_tailer, "aclose", None)
                if aclose is not None:
                    try:
                        loop.run_until_complete(aclose())
                    except Exception as e:
                        logger.warning(f"Failed to close the log tailer for RayJob {namespace}/{name}: {e}")
                loop.close()

    def terminate(
        self, name: str, namespace: str, timeout: float = 10.0, poll_interval: float = 1.0, port_forward: bool = False
    ) -> bool:
        """
        Unlike the .delete method, this won't remove the Kubernetes object, but will instead stop the Ray Job.

        Waits (polling every ``poll_interval`` seconds, for at most ``timeout`` seconds) for the job to appear in
        the cluster's job list before asking Ray to stop it.

        Returns:
            bool: The result of the job submission client's ``stop_job`` call.

        Raises:
            TimeoutError: If the job does not appear in the cluster's job list within ``timeout`` seconds.
        """
        with self.ray_cluster_client.job_submission_client(
            name=self.get_ray_cluster_name(name, namespace, timeout=timeout, poll_interval=poll_interval),
            namespace=namespace,
            port_forward=port_forward,
        ) as job_submission_client:
            job_id = cast(
                str, self.get_job_submission_id(name, namespace, timeout=timeout, poll_interval=poll_interval)
            )

            start_time = time.time()

            # Only sleep while the job is genuinely missing; stop as soon as it shows up.
            while not any(job.submission_id == job_id for job in job_submission_client.list_jobs()):
                if time.time() - start_time > timeout:
                    raise TimeoutError(
                        f"Timed out after {timeout}s waiting for job {job_id} of RayJob {namespace}/{name} "
                        "to be submitted before terminating it"
                    )

                logger.debug(
                    f"Trying to terminate job {name}, but it wasn't submitted yet. Waiting for it to be submitted..."
                )
                time.sleep(poll_interval)

            return job_submission_client.stop_job(job_id=job_id)
