#!/usr/bin/env python

"""A utility script that takes a pyspark script and runs it in the SAS environment.

Typical usage example:
    snowpark-submit ./tools/examples_row.py

"""
import argparse
import logging
import subprocess
import sys

import pyspark

import snowflake.snowpark_connect

logger = logging.getLogger("snowpark-submit")


def setup_logging(log_level):
    """Configure the snowpark-submit and SAS server loggers at *log_level*."""
    submit_logger = logging.getLogger("snowpark-submit")
    submit_logger.setLevel(log_level)
    if not submit_logger.hasHandlers():
        # Attach a console handler exactly once. Note hasHandlers() also
        # consults ancestor loggers, so a configured root logger suppresses
        # this branch.
        handler = logging.StreamHandler()
        handler.setLevel(log_level)
        handler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - [Thread %(thread)d] - %(message)s"
            )
        )
        submit_logger.addHandler(handler)
    # Mirror the level onto the SAS server logger and its existing handlers.
    sas_logger = logging.getLogger("snowflake_connect_server")
    sas_logger.setLevel(log_level)
    for existing_handler in sas_logger.handlers:
        existing_handler.setLevel(log_level)


def init_sas(remote_url: str):
    """Start a Snowpark Connect session against *remote_url*.

    Args:
        remote_url: Spark Connect URL, e.g. ``sc://host:port``.

    Raises:
        RuntimeError: presumably surfaced by start_session on connection
            failure — the caller (run) catches it; confirm against the
            snowpark_connect API.
    """
    snowflake.snowpark_connect.start_session(remote_url=remote_url)


def run_spark_code(args: list[str]):
    """Launch the spark-submit command line and return its exit code."""
    logger.info("running spark script as %s", args)
    completed = subprocess.run(args)
    return completed.returncode


def init_args(args: list[str] | None = None) -> argparse.Namespace:
    """Build the snowpark-submit argument parser and parse *args*.

    The option surface mirrors spark-submit's; options that the SAS
    environment does not support are marked [DEPRECATED]/[Unsupported]
    in their help text but are still accepted for compatibility.

    Args:
        args: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``Namespace``. Tokens the parser does not recognize are
        collected into the extra attribute ``app_arguments`` so they can be
        passed through to the user's application.
    """
    parser = argparse.ArgumentParser(
        description="Run a Spark script in SAS environment.",
        add_help=False,
        usage="""NOTE: All spark-submit options are displayed here, currently unsupported options are marked [DEPRECATED]
    Usage: snowpark-submit [options] <app jar | python file> [app arguments]
    """,
    )
    # Other spark-submit usage (we may add support for these in the future):
    # Usage: snowpark-submit --kill [submission ID] --master [spark://...]
    # Usage: snowpark-submit --status [submission ID] --master [spark://...]
    # Usage: snowpark-submit run-example [options] example-class [example args]
    # Argument groups below only affect how --help output is organized.
    options_group = parser.add_argument_group("Options")
    spark_connect_group = parser.add_argument_group("Spark Connect only")
    cluster_deploy_group = parser.add_argument_group("Cluster Deploy mode only")
    spark_standalone_or_mesos_cluster_group = parser.add_argument_group(
        "[Unsupported] Spark standalone or Mesos with cluster deploy mode only"
    )
    k8s_group = parser.add_argument_group(
        "[Unsupported] Spark standalone, Mesos or K8s with cluster deploy mode only"
    )
    spark_standalone_mesos_group = parser.add_argument_group(
        "[Unsupported] Spark standalone and Mesos only"
    )
    spark_standalone_yarn_group = parser.add_argument_group(
        "[Unsupported] Spark standalone, YARN and Kubernetes only"
    )
    spark_yarn_k8s_group = parser.add_argument_group(
        "[Unsupported] Spark on YARN and Kubernetes only"
    )
    spark_yarn_group = parser.add_argument_group("Spark on YARN only")

    options_group.add_argument(
        "--master",
        metavar="MASTER_URL",
        type=str,
        help="[DEPRECATED] spark://host:port, mesos://host:port, yarn, k8s://https://host:port, or local (Default: local[*]).",
    )
    options_group.add_argument(
        "--deploy-mode",
        metavar="DEPLOY_MODE",
        type=str,
        choices=["client", "cluster"],
        help="[DEPRECATED] Whether to launch the driver program locally ('client') or on one of the worker machines inside the cluster ('cluster') (Default: client).",
    )
    options_group.add_argument(
        "--class",
        metavar="CLASS_NAME",
        type=str,
        help="Your application's main class (for Java / Scala apps).",
    )
    options_group.add_argument(
        "--name",
        metavar="NAME",
        type=str,
        help="A name of your application.",
    )
    options_group.add_argument(
        "--jars",
        metavar="JAR",
        type=str,
        help="Comma-separated list of jars to include on the driver and executor classpaths.",
    )
    # NOTE: options declared with nargs="*" parse into a list (one entry per
    # space-separated value), unlike spark-submit's single comma-separated
    # string.
    options_group.add_argument(
        "--packages",
        type=str,
        nargs="*",
        help="[DEPRECATED] Comma-separated list of maven coordinates of jars to include on the driver and executor classpaths. Will search the local maven repo, then maven central and any additional remote repositories given through --repositories. The format for the coordinates should be groupId:artifactId:version.",
    )
    options_group.add_argument(
        "--exclude-packages",
        type=str,
        nargs="*",
        help="Comma-separated list of groupId:artifactId, to exclude while resolving the dependencies provided in --packages to avoid dependency conflicts.",
    )
    options_group.add_argument(
        "--repositories",
        type=str,
        nargs="*",
        help="[DEPRECATED] Comma-separated list of additional remote repositories to search for the maven coordinates given with --packages.",
    )
    options_group.add_argument(
        "--py-files",
        metavar="PY_FILES",
        type=str,
        help="Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps.",
    )
    options_group.add_argument(
        "--files",
        metavar="FILES",
        type=str,
        nargs="*",
        help="[DEPRECATED] Comma-separated list of files to be placed in the working directory of each executor.",
    )
    options_group.add_argument(
        "--archives",
        metavar="ARCHIVES",
        type=str,
        nargs="*",
        help="[DEPRECATED] Comma-separated list of archives to be extracted into the working directory of each executor.",
    )
    options_group.add_argument(
        "--conf",
        "-c",
        metavar="PROP=VALUE",
        type=str,
        nargs="*",
        help="Arbitrary Spark configuration property.",
    )
    options_group.add_argument(
        "--properties-file",
        metavar="FILE",
        type=str,
        help="Path to a file from which to load extra properties. If not specified, this will look for conf/spark-defaults.conf.",
    )
    options_group.add_argument(
        "--driver-memory",
        metavar="MEM",
        type=str,
        help="[DEPRECATED] Memory for driver (e.g. 1000M, 2G) (Default: 1024M).",
    )
    options_group.add_argument(
        "--driver-java-options",
        type=str,
        help="[DEPRECATED] Extra Java options to pass to the driver.",
    )
    options_group.add_argument(
        "--driver-library-path",
        type=str,
        help="[DEPRECATED] Extra library path entries to pass to the driver.",
    )
    options_group.add_argument(
        "--driver-class-path",
        type=str,
        help="[DEPRECATED] Extra class path entries to pass to the driver. Note that jars added with --jars are automatically included in the classpath.",
    )
    options_group.add_argument(
        "--executor-memory",
        metavar="MEM",
        type=str,
        help="[DEPRECATED] Memory per executor (e.g. 1000M, 2G) (Default: 1G).",
    )
    options_group.add_argument(
        "--proxy-user",
        type=str,
        help="[DEPRECATED] User to impersonate when submitting the application. This argument does not work with --principal / --keytab.",
    )
    # --help is declared manually because the parser was built with
    # add_help=False (so the custom usage string controls the banner).
    options_group.add_argument(
        "--help",
        "-h",
        action="help",
        help="Show this help message and exit.",
    )
    options_group.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print additional debug output.",
    )
    # Reports the pyspark version, since this tool fronts a pyspark install.
    options_group.add_argument(
        "--version",
        action="version",
        version=pyspark.__version__,
    )
    spark_connect_group.add_argument(
        "--remote",
        metavar="CONNECT_URL",
        type=str,
        default="sc://localhost:15003",  # Different from snowpark-session, to avoid conflicts.
        help="URL to connect to the server for Spark Connect, e.g., sc://host:port. --master and --deploy-mode cannot be set together with this option. This option is experimental, and might change between minor releases.",
    )
    spark_connect_group.add_argument(
        "--skip-init-sas",
        action="store_true",
        help="If given, skip initialize SAS. This is used in server side testing.",
    )
    cluster_deploy_group.add_argument(
        "--driver-cores",
        metavar="NUM",
        type=str,
        help="[DEPRECATED] Number of cores used by the driver, only in cluster mode (Default: 1).",
    )
    spark_standalone_or_mesos_cluster_group.add_argument(
        "--supervise",
        action="store_true",
        help="[DEPRECATED] If given, restart the driver on failure.",
    )
    k8s_group.add_argument(
        "--kill",
        metavar="SUBMISSION_ID",
        type=str,
        help="[DEPRECATED] If given, kills the driver specified.",
    )
    k8s_group.add_argument(
        "--status",
        metavar="SUBMISSION_ID",
        type=str,
        help="[DEPRECATED] If given, requests the status of the driver specified.",
    )
    spark_standalone_mesos_group.add_argument(
        "--total-executor-cores",
        metavar="NUM",
        type=str,
        help="[DEPRECATED] Total cores for all executors.",
    )
    spark_standalone_yarn_group.add_argument(
        "--executor-cores",
        metavar="NUM",
        type=str,
        help="[DEPRECATED] Number of cores per executor. (Default: 1 in YARN mode, or all available cores on the worker in standalone mode).",
    )
    spark_yarn_k8s_group.add_argument(
        "--num-executors",
        metavar="NUM",
        type=str,
        help="[DEPRECATED] Number of executors to launch (Default: 2).\nIf dynamic allocation is enabled, the initial number of executors will be at least NUM.",
    )
    spark_yarn_k8s_group.add_argument(
        "--principal",
        metavar="PRINCIPAL",
        type=str,
        help="[DEPRECATED] Principal to be used to login to KDC.",
    )
    spark_yarn_k8s_group.add_argument(
        "--keytab",
        metavar="KEYTAB",
        type=str,
        help="[DEPRECATED] The full path to the file that contains the keytab for the principal specified.",
    )
    spark_yarn_group.add_argument(
        "--queue",
        metavar="QUEUE_NAME",
        type=str,
        help="[DEPRECATED] The YARN queue to submit to (Default: 'default').",
    )
    # Positional: the script/jar to run. Hidden from --help (the custom
    # usage string documents it instead).
    parser.add_argument(
        "filename",
        metavar="FILE",
        type=str,
        help=argparse.SUPPRESS,
    )

    # parse_known_args keeps unrecognized tokens (the user's application
    # arguments) instead of erroring out; stash them on the namespace.
    args, unknown_args = parser.parse_known_args(args)
    args.app_arguments = unknown_args

    return args


def generate_spark_submit_cmd(
    args: argparse.Namespace,
    entrypoint_arg: str = "spark-submit",
) -> list[str]:
    """Translate parsed CLI arguments into a spark-submit command line.

    Args:
        args: Namespace produced by ``init_args`` (including the extra
            ``app_arguments`` attribute).
        entrypoint_arg: Executable placed at the front of the command.

    Returns:
        The full argv list: entrypoint, forwarded options, script filename,
        then the pass-through application arguments.

    Side effect: configures logging — DEBUG when ``--verbose`` was given,
    INFO otherwise.
    """
    # Options handled specially (or intentionally not forwarded) below.
    skipped_options = {
        "filename",
        "verbose",
        "supervise",
        "skip_init_sas",
        "deploy_mode",
        "app_arguments",
    }
    args_for_spark = [entrypoint_arg]
    for k, v in vars(args).items():
        if v is None or k in skipped_options:
            continue
        flag = f"--{k.replace('_', '-')}"
        if isinstance(v, list):
            # Options declared with nargs="*" (e.g. --conf, --packages)
            # parse into a list; appending the raw list object would make
            # subprocess.run() fail with a TypeError. Repeat the flag once
            # per value, which is how spark-submit expects e.g. --conf.
            for item in v:
                args_for_spark.extend([flag, item])
        else:
            args_for_spark.extend([flag, v])
    if args.verbose:
        args_for_spark.append("--verbose")
        setup_logging(logging.DEBUG)
    else:
        setup_logging(logging.INFO)
    args_for_spark.append(args.filename)
    args_for_spark.extend(args.app_arguments)
    return args_for_spark


def run():
    """Parse arguments, initialize SAS (unless skipped), and run the script.

    Returns:
        1 when SAS initialization fails; otherwise the spark-submit
        subprocess's exit code.
    """
    args = init_args()

    if not args.skip_init_sas:
        try:
            init_sas(args.remote)
        except RuntimeError as re:
            logger.error(
                "%s. Please check logs for details and reach out to Snowflake if needed.",
                re,
            )
            return 1

    return run_spark_code(generate_spark_submit_cmd(args))


def runner_wrapper(test_mode=False):
    """Run the tool and convert the result into a 0/1 process exit status.

    Args:
        test_mode: When True, return the status instead of exiting the
            process (used by tests).
    """
    logger.debug("Runner starts.")
    exit_status = run()
    # Collapse any non-zero status to 1 so callers see a simple 0/1 flag.
    if exit_status != 0:
        logger.error("Unexpected Exit: non-zero exit code.")
        exit_status = 1
    if not test_mode:
        sys.exit(exit_status)
    return exit_status


if __name__ == "__main__":
    # Script entry point: exits the process with status 0 on success, 1 otherwise.
    runner_wrapper()
