#!/usr/bin/env python

from spyt.dependency_utils import require_yt_client
require_yt_client()

from yt.wrapper import YtClient  # noqa: E402
from yt.wrapper.http_helpers import get_user_name  # noqa: E402
from spyt.enabler import SpytEnablers  # noqa: E402
from spyt.launch_utils import add_default_launch_options, add_hs_options, add_livy_options, add_parser_group  # noqa: E402
from spyt.spec import SparkDefaultArguments  # noqa: E402
from spyt.standalone import start_spark_cluster  # noqa: E402
from spyt.utils import default_token, default_tvm_id, default_tvm_secret, get_default_arg_parser, parse_args, show_unspecified_args  # noqa: E402


def main(raw_args=None):
    """Parse the "Spark Launch" command line and start a Spark cluster.

    Builds the argument parser (shared launch/history-server/Livy options come
    from the ``spyt.launch_utils`` helpers, the rest are declared here), parses
    ``raw_args``, validates the autoscaler settings, creates a ``YtClient`` for
    ``args.proxy`` and forwards every parsed option to ``start_spark_cluster``.

    Args:
        raw_args: argument list handed to ``parse_args`` as ``raw_args``.
            Presumably ``None`` means "use the process command line" --
            confirm against ``spyt.utils.parse_args``.

    Raises:
        SystemExit: with status ``-1`` when ``--autoscaler-period`` is given
            while multi-operation mode is disabled.
    """
    # Local import keeps the module-level import block (which must follow
    # require_yt_client()) untouched.
    import sys

    parser = get_default_arg_parser(description="Spark Launch")
    add_default_launch_options(parser)

    # Worker sizing: the three mandatory options first, then optional tuning.
    parser.add_argument("--worker-cores", required=True, type=int)
    parser.add_argument("--worker-memory", required=True)
    parser.add_argument("--worker-num", required=True, type=int)
    parser.add_argument("--worker-cores-overhead", required=False, type=int,
                        default=SparkDefaultArguments.SPARK_WORKER_CORES_OVERHEAD)
    parser.add_argument("--worker-memory-overhead", required=False,
                        default=SparkDefaultArguments.SPARK_WORKER_MEMORY_OVERHEAD)
    parser.add_argument("--worker-timeout", required=False, default=SparkDefaultArguments.SPARK_WORKER_TIMEOUT)
    parser.add_argument("--tmpfs-limit", required=False, default=SparkDefaultArguments.SPARK_WORKER_TMPFS_LIMIT)
    # COMPAT(alex-shishkin): replace with worker-disk-name='ssd_slots_physical' and worker-disk-limit
    parser.add_argument("--ssd-limit", required=False, default=None)
    # COMPAT(alex-shishkin): replace with worker-disk-name='ssd_slots_physical' and worker-disk-account
    parser.add_argument("--ssd-account", required=False, default=None)
    parser.add_argument("--worker-gpu-limit", type=int, default=SparkDefaultArguments.SPARK_WORKER_GPU_LIMIT)
    parser.add_argument("--worker-disk-name", required=False, default="default")
    parser.add_argument("--worker-disk-limit", required=False, default=None)
    parser.add_argument("--worker-disk-account", required=False, default=None)
    parser.add_argument("--worker-port", required=False, default=SparkDefaultArguments.SPARK_WORKER_PORT, type=int)

    # Master and history server.
    parser.add_argument("--master-memory-limit", required=False,
                        default=SparkDefaultArguments.SPARK_MASTER_MEMORY_LIMIT)
    parser.add_argument("--master-port", required=False, default=SparkDefaultArguments.SPARK_MASTER_PORT, type=int)
    add_hs_options(parser)
    parser.add_argument('--abort-existing', required=False, action='store_true', default=False)
    parser.add_argument("--worker-log-update-interval",
                        required=False, default=SparkDefaultArguments.SPARK_WORKER_LOG_UPDATE_INTERVAL)
    parser.add_argument("--worker-log-table-ttl", required=False,
                        default=SparkDefaultArguments.SPARK_WORKER_LOG_TABLE_TTL)

    # Paired enable/disable switches writing to one destination; off by default.
    parser.add_argument('--enable-multi-operation-mode', dest='enable_multi_operation_mode', action='store_true')
    parser.add_argument('--disable-multi-operation-mode', dest='enable_multi_operation_mode', action='store_false')
    parser.set_defaults(enable_multi_operation_mode=False)

    # Defaults for these switches are taken from a default-constructed SpytEnablers.
    default_enablers = SpytEnablers()
    add_parser_group(parser, '--enable-byop', '--disable-byop', 'enable_byop', default_enablers.enable_byop)
    add_parser_group(parser, '--enable-solomon-agent', '--disable-solomon-agent', 'enable_solomon_agent',
                     default_enablers.enable_solomon_agent)

    # No option declared in this function sets enable_tmpfs; it is supplied as a parser default.
    parser.set_defaults(enable_tmpfs=True)

    add_parser_group(parser, '--enable-history-server', '--disable-history-server', 'enable_history_server', True)

    add_parser_group(parser, '--enable-worker-log-transfer', '--disable-worker-log-transfer', 'worker_log_transfer',
                     False)
    add_parser_group(parser, '--enable-worker-log-json-mode', '--disable-worker-log-json-mode', 'worker_log_json_mode',
                     False)

    add_parser_group(parser, '--enable-dedicated-driver-operation-mode', '--disable-dedicated-driver-operation-mode',
                     'dedicated_operation_mode', False)

    subgroup = parser.add_argument_group(
        '(experimental) run driver in dedicated operation')
    subgroup.add_argument("--driver-cores", required=False,
                          type=int, help="same as worker-cores by default")
    # NOTE(review): --driver-memory is parsed as int while --worker-memory is kept
    # as an untyped string, yet the help text says it defaults to worker-memory.
    # Confirm which format start_spark_cluster expects before changing the type.
    subgroup.add_argument("--driver-memory", required=False,
                          type=int, help="same as worker-memory by default")
    subgroup.add_argument("--driver-num", required=False,
                          type=int, help="Number of driver workers")
    subgroup.add_argument("--driver-cores-overhead", required=False,
                          type=int, help="same as worker-cores-overhead by default")
    subgroup.add_argument("--driver-timeout", required=False,
                          help="same as worker-timeout by default")

    subgroup = parser.add_argument_group("(experimental) autoscaler")
    subgroup.add_argument("--autoscaler-period", required=False,
                          type=str,
                          help="""
                            Start autoscaler process with provided period between autoscaling actions.
                            Period format is '<number> <time_unit>', for example '1s', '5 seconds', '100millis' etc.
                          """.strip())
    subgroup.add_argument("--autoscaler-metrics-port", required=False,
                          type=int, help="expose autoscaler metrics on provided port")
    subgroup.add_argument("--autoscaler-sliding-window", required=False,
                          type=int, help="size of autoscaler actions sliding window (in number of action) to downscale")
    subgroup.add_argument("--autoscaler-max-free-workers", required=False,
                          type=int, help="autoscaler maximum number of free workers")
    subgroup.add_argument("--autoscaler-slot-increment-step", required=False,
                          type=int, help="autoscaler worker slots increment step")

    subgroup = parser.add_argument_group("Livy server")
    subgroup.add_argument('--enable-livy', action='store_true', default=False)
    add_livy_options(subgroup)

    args, unknown_args = parse_args(parser, raw_args=raw_args)
    show_unspecified_args(unknown_args)

    # Reject inconsistent flags before any client or token is created.
    # sys.exit is used instead of the site-provided exit(), which is meant for
    # interactive sessions and is not guaranteed to exist in every interpreter.
    if args.autoscaler_period and not args.enable_multi_operation_mode:
        print("Autoscaler could be enabled only with multi-operation mode")
        sys.exit(-1)

    yt_client = YtClient(proxy=args.proxy, token=default_token())

    start_spark_cluster(worker_cores=args.worker_cores,
                        worker_memory=args.worker_memory,
                        worker_num=args.worker_num,
                        worker_cores_overhead=args.worker_cores_overhead,
                        worker_memory_overhead=args.worker_memory_overhead,
                        worker_timeout=args.worker_timeout,
                        operation_alias=args.operation_alias,
                        discovery_path=args.discovery_path,
                        # Fall back to the current YT user name when no pool was given.
                        pool=args.pool or get_user_name(client=yt_client),
                        enable_tmpfs=args.enable_tmpfs,
                        tmpfs_limit=args.tmpfs_limit,
                        ssd_limit=args.ssd_limit,
                        ssd_account=args.ssd_account,
                        worker_disk_name=args.worker_disk_name,
                        worker_disk_limit=args.worker_disk_limit,
                        worker_disk_account=args.worker_disk_account,
                        worker_port=args.worker_port,
                        master_memory_limit=args.master_memory_limit,
                        master_port=args.master_port,
                        enable_history_server=args.enable_history_server,
                        history_server_memory_limit=args.history_server_memory_limit,
                        history_server_memory_overhead=args.history_server_memory_overhead,
                        history_server_cpu_limit=args.history_server_cpu_limit,
                        network_project=args.network_project,
                        tvm_id=default_tvm_id(),
                        tvm_secret=default_tvm_secret(),
                        abort_existing=args.abort_existing,
                        advanced_event_log=args.advanced_event_log,
                        worker_log_transfer=args.worker_log_transfer,
                        worker_log_json_mode=args.worker_log_json_mode,
                        worker_log_update_interval=args.worker_log_update_interval,
                        worker_log_table_ttl=args.worker_log_table_ttl,
                        params=args.params,
                        shs_location=args.shs_location,
                        spark_cluster_version=args.spyt_version,
                        enablers=SpytEnablers(
                            enable_byop=args.enable_byop,
                            enable_mtn=args.enable_mtn,
                            enable_solomon_agent=args.enable_solomon_agent,
                            enable_tcp_proxy=args.enable_tcp_proxy,
                            enable_squashfs=args.enable_squashfs
                        ),
                        enable_preference_ipv6=args.enable_preference_ipv6,
                        client=yt_client,
                        preemption_mode=args.preemption_mode,
                        cluster_log_level=args.cluster_log_level,
                        enable_multi_operation_mode=args.enable_multi_operation_mode,
                        dedicated_operation_mode=args.dedicated_operation_mode,
                        driver_cores=args.driver_cores,
                        driver_memory=args.driver_memory,
                        driver_num=args.driver_num,
                        driver_cores_overhead=args.driver_cores_overhead,
                        driver_timeout=args.driver_timeout,
                        autoscaler_period=args.autoscaler_period,
                        autoscaler_metrics_port=args.autoscaler_metrics_port,
                        autoscaler_sliding_window=args.autoscaler_sliding_window,
                        autoscaler_max_free_workers=args.autoscaler_max_free_workers,
                        autoscaler_slot_increment_step=args.autoscaler_slot_increment_step,
                        enable_livy=args.enable_livy,
                        livy_driver_cores=args.livy_driver_cores,
                        livy_driver_memory=args.livy_driver_memory,
                        livy_max_sessions=args.livy_max_sessions,
                        rpc_job_proxy=args.rpc_job_proxy,
                        rpc_job_proxy_thread_pool_size=args.rpc_job_proxy_thread_pool_size,
                        tcp_proxy_range_start=args.tcp_proxy_range_start,
                        tcp_proxy_range_size=args.tcp_proxy_range_size,
                        enable_stderr_table=args.enable_stderr_table,
                        group_id=args.group_id,
                        worker_gpu_limit=args.worker_gpu_limit,
                        cluster_java_home=args.cluster_java_home
                        )


# Run the launcher only when this file is executed as a script, not on import.
if '__main__' == __name__:
    main()
