from __future__ import annotations

import logging
import os
import shlex
import subprocess
import sys
import time
from typing import Optional, Sequence

import click
from dask.base import tokenize
from dask.distributed import Client, get_worker
from dask.utils import parse_timedelta
from packaging.version import Version
from rich import print

import coiled
from coiled.compatibility import DISTRIBUTED_VERSION
from coiled.core import logger as coiled_logger
from coiled.utils import error_info_for_tracking, unset_single_thread_defaults

from ..shutdown import NoClientShutdown
from .cluster.better_logs import better_logs
from .utils import CONTEXT_SETTINGS

# Oldest `distributed` release `coiled run` supports: it relies on the scheduler having
# scratch space, which was added in https://github.com/dask/distributed/pull/7802.
MINIMUM_DISTRIBUTED_VERSION = Version("2023.6.0")


# Developer notes (kept as comments rather than in the docstring, because click renders
# the docstring as this command's --help text):
#
# `coiled run` creates a cluster with `n_workers=0` and `extra_worker_on_scheduler=True`,
# uploads local files referenced in COMMAND (plus any passed via --file), runs COMMAND
# via `subprocess.run` in the worker's local directory, and then prints the command's
# output by tailing the cluster logs between two echo'd sentinel lines.
# Without --keepalive the cluster is created with `shutdown_on_close=True`; with it, a
# `NoClientShutdown(keepalive=...)` scheduler plugin is registered instead.
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    "--name",
    default=None,
    help="Run name. If not given, defaults to a unique name.",
)
@click.option(
    "--account",
    default=None,
    help="Coiled account (uses default account if not specified)",
)
@click.option(
    "--software",
    default=None,
    help=(
        "Software environment name to use. If neither software nor container is specified, "
        "all the currently-installed Python packages are replicated on the VM using package sync."
    ),
)
@click.option(
    "--container",
    default=None,
    help=(
        "Container image to use. If neither software nor container is specified, "
        "all the currently-installed Python packages are replicated on the VM using package sync."
    ),
)
@click.option(
    "--vm-type",
    default=[],
    multiple=True,
    help="VM type to use. Specify multiple times to provide multiple options.",
)
@click.option(
    "--gpu",
    default=False,
    is_flag=True,
    help="Have a GPU available.",
)
@click.option(
    "--region",
    default=None,
    help="The cloud provider region in which to run the notebook.",
)
@click.option(
    "--keepalive",
    default=None,
    help=(
        "Keep your VM running for the specified time, even after your command completes. "
        "Default to shutdown immediately."
    ),
)
@click.option(
    "--file",
    "-f",
    default=[],
    multiple=True,
    help="Local files required to run command.",
)
@click.argument("command", nargs=-1)
def run(
    name: Optional[str],
    account: Optional[str],
    software: Optional[str],
    container: Optional[str],
    vm_type: Sequence[str],
    gpu: bool,
    region: Optional[str],
    keepalive,
    file: Sequence[str],
    command,
):
    """
    Run a command on the cloud.
    """
    if DISTRIBUTED_VERSION < MINIMUM_DISTRIBUTED_VERSION:
        # Relies on scheduler having scratch space, which was added in https://github.com/dask/distributed/pull/7802
        raise RuntimeError(
            f"`coiled run` requires distributed>={MINIMUM_DISTRIBUTED_VERSION} "
            f"(distributed={DISTRIBUTED_VERSION} is installed)"
        )

    env = unset_single_thread_defaults()
    if container and "rapidsai" in container:
        env = {"DISABLE_JUPYTER": "true", **env}  # needed for "stable" RAPIDS image

    if not command:
        raise ValueError("command must be specified")

    # No --keepalive (None) means the cluster shuts down as soon as we're done with it.
    keepalive = parse_timedelta(keepalive)
    shutdown_on_close = keepalive is None

    info = {"command": command, "keepalive": keepalive}
    success = True
    exception = None

    # if user tries `coiled run foo.py` they probably want to run `python foo.py` rather than `foo.py`
    if len(command) == 1 and command[0].endswith(".py"):
        command = ("python", command[0])

    # configure logging (remove once we have nicer widget)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(fmt="  %(message)s"))
    previous_log_level = coiled_logger.level
    coiled_logger.setLevel(logging.INFO)
    coiled_logger.addHandler(stream_handler)

    try:
        print()
        print(f"Setting up [bold]{shlex.join(command)}[/bold]...")
        print("-" * len(f"Setting up {shlex.join(command)}..."))
        print()

        with coiled.Cloud(account=account) as cloud:
            account = account or cloud.default_account
            info["account"] = account
            cluster_kwargs = dict(
                account=account,
                n_workers=0,
                software=software,
                container=container,
                scheduler_options={"idle_timeout": "24 hours"},
                scheduler_vm_types=list(vm_type) if vm_type else None,
                worker_vm_types=list(vm_type) if vm_type else None,
                allow_ssh=True,
                extra_worker_on_scheduler=True,
                environ=env,
                scheduler_gpu=gpu,
                region=region,
                shutdown_on_close=shutdown_on_close,
                tags={"coiled-cluster-type": "run/cli"},
            )
            if not name:
                # Honor --name when given; otherwise derive a deterministic name from the
                # Python executable and the cluster config, so identical invocations map to
                # the same cluster name (presumably letting a --keepalive VM be reused).
                token = tokenize(sys.executable, **cluster_kwargs)
                name = f"run-{token[:8]}"

            with coiled.Cluster(name=name, cloud=cloud, **cluster_kwargs) as cluster:
                info["cluster_id"] = cluster.cluster_id
                with Client(cluster) as client:
                    if not shutdown_on_close:
                        client.register_scheduler_plugin(NoClientShutdown(keepalive=keepalive))

                    # Only shell-split when the whole command arrived as one string
                    # (e.g. `coiled run "python foo.py"`). Otherwise keep click's argv tokens
                    # as-is: joining them with spaces and re-splitting would drop quoting
                    # inside arguments (e.g. `python -c "print('a b')"`).
                    args = shlex.split(command[0]) if len(command) == 1 else list(command)

                    # Extract and upload files referenced in the command
                    implicit_files = []
                    for idx, arg in enumerate(args):
                        if os.path.isfile(arg):
                            client.upload_file(arg, load=False)
                            implicit_files.append(arg)
                            # The command runs in the worker's local directory (see
                            # run_command below), so refer to the upload by its basename.
                            args[idx] = os.path.basename(arg)
                    info["files-implicit"] = implicit_files

                    # Upload user-specified files too
                    info["files-explicit"] = file
                    for f in file:
                        client.upload_file(f, load=False)

                    def run_command(command):
                        # Executed on the worker. The exit status is not checked, so a
                        # failing user command does not raise here.
                        subprocess.run(command, cwd=get_worker().local_directory)

                    info["command-parsed"] = args

                    # keep track of time before we start executing user code
                    start_ns = time.time_ns()

                    # TODO execute command in non-blocking way so we can also tail logs while it runs
                    # The echo'd sentinels delimit the user command's output in the logs.
                    client.submit(run_command, command=["echo", "@start-coiled-run"]).result()
                    client.submit(run_command, command=args).result()
                    client.submit(run_command, command=["echo", "@stop-coiled-run"]).result()

        if cluster.cluster_id:
            print()

            print("Output")
            print("-" * len("Output"))
            print()

            better_logs(
                cluster_id=cluster.cluster_id,
                instance_labels_dict={},
                show_label=False,
                show_all_instances=True,
                show_timestamp=False,  # only show the messages
                tail=True,  # use tail so that we keep polling logs until stop sentinel to shows up
                tail_max_times=10,  # only try 10 times (at 3s interval), logs should be ingested by then
                since=str(start_ns // 1_000_000),  # it wants timestamp in ms as a string
                start_sentinel="@start-coiled-run",
                stop_sentinel="@stop-coiled-run",
            )
            print()

    except Exception as e:
        success = False
        exception = e
        raise
    finally:
        try:
            coiled.add_interaction(
                "coiled-run-cli",
                success=success,
                **info,
                **error_info_for_tracking(exception),
            )
        finally:
            # Undo the logging setup so repeated in-process invocations don't stack
            # handlers (which would print every log line multiple times).
            coiled_logger.removeHandler(stream_handler)
            coiled_logger.setLevel(previous_log_level)
