#!/usr/bin/env python3
import argparse
import atexit
import logging
import os
import re
import signal
import subprocess
import sys
import time
from datetime import datetime, timezone

from locust_swarm.functions import *

# Root logger configuration for the whole tool.
# Each line is rendered as "<date>:<time>,<milliseconds> <LEVEL> [<file>:<line>] <message>".
# Milliseconds are zero-padded to three digits, so 5 ms is printed as ",005" and not
# as ",5" (which is easily misread as 500 ms and breaks column alignment / sorting).
LOG_FORMAT = "%(asctime)s,%(msecs)03d %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s"
LOG_DATE_FORMAT = "%Y-%m-%d:%H:%M:%S"

logging.basicConfig(format=LOG_FORMAT, datefmt=LOG_DATE_FORMAT, level=logging.INFO)

# Command line definition. Only the tool's own switches are declared here;
# parse_known_args() leaves every other argument in unrecognized_args so that it can be
# handed on to locust/jmeter untouched.
parser = argparse.ArgumentParser(
    epilog="Any parameters not listed here are forwarded to locust/jmeter unmodified, so go ahead and use things like -c, -r, --host, ... All env vars starting with LOCUST_ will be forwarded to the slaves.",
    description="A tool for running locust/jmeter in a distributed fashion.",
)

# Options that take a value, as (flag, add_argument keyword arguments), in help-output order.
for _flag, _options in (
    ("-f", dict(type=str, dest="testplan")),
    ("-t", dict(type=str, dest="run_time")),
    ("--processes-per-loadgen", dict(type=int, default=4)),
    ("--loadgens", dict(type=int, default=2)),
    # the LOADGEN_LIST env var supplies the default when the flag is not given
    ("--loadgen-list", dict(type=str, default=os.environ.get("LOADGEN_LIST"))),
    ("-L", dict(type=str, dest="loglevel")),
    ("--port", dict(type=str, default="5557")),
):
    parser.add_argument(_flag, **_options)

# Plain on/off switches.
for _flag in ("--jmeter", "--jmeter-gui", "--jmeter-kill-first"):
    parser.add_argument(_flag, action="store_true")

args, unrecognized_args = parser.parse_known_args()

# ---- Post-process and validate the parsed arguments, then derive the run settings ----

if args.loglevel:
    logging.getLogger().setLevel(args.loglevel.upper())

if args.jmeter_gui:
    # GUI mode implies jmeter mode and uses no load generators
    args.jmeter = True
    args.loadgens = 0

# Comma-separated host list. Whitespace around entries and empty entries (e.g. from a
# trailing comma or an empty LOADGEN_LIST) are dropped, so they cannot end up as bogus
# host names in the ssh commands built later.
server_list = [entry.strip() for entry in (args.loadgen_list or "").split(",") if entry.strip()]
if not server_list:
    parser.error(
        "No loadgens specified on command line (--loadgen-list) or env var (LOADGEN_LIST). Use a comma-separated list."
    )

if not args.jmeter:
    if not args.run_time:
        parser.error("Run time (-t) not specified. This is mandatory when running locust.")
    # -c is a locust option (so it lands in unrecognized_args), but its value is needed
    # here as well. If it is given several times, the last occurrence wins.
    client_count = None
    for position, argument in enumerate(unrecognized_args):
        if argument != "-c":
            continue
        try:
            client_count = int(unrecognized_args[position + 1])
        except IndexError:
            # "-c" was the very last argument
            parser.error("Client count (-c) requires a value.")
        except ValueError:
            parser.error(f"Client count (-c) must be an integer, got {unrecognized_args[position + 1]!r}.")
    if not client_count:
        # also rejects an explicit "-c 0"
        parser.error("Client count (-c) not specified. This is mandatory when running locust.")

if args.testplan is None:
    # default test plan name depends on which tool is being run
    args.testplan = "testplan.jmx" if args.jmeter else "locustfile.py"

testplan = args.testplan
testplan_filename = os.path.basename(testplan)
run_time = args.run_time
port = int(args.port)
processes_per_loadgen = args.processes_per_loadgen
loadgens = args.loadgens
slave_process_count = processes_per_loadgen * loadgens
slaves = []

# ---- Pre-flight: make sure we can ssh non-interactively to the machines we depend on ----

if args.jmeter:
    # Resolved outside the try block below: a missing variable is a configuration
    # problem, not an ssh failure, and must not be reported as one.
    jmeter_master = os.environ.get("JMETER_MASTER")
    if not jmeter_master:
        parser.error("JMETER_MASTER env var must be set when running jmeter.")

try:
    # BatchMode=yes makes ssh fail instead of prompting for a password/passphrase.
    # Only the first load generator is probed.
    subprocess.check_output(f"ssh -o LogLevel=error -o BatchMode=yes {server_list[0]} true", shell=True)
    if args.jmeter:
        subprocess.check_output(f"ssh -o LogLevel=error -o BatchMode=yes {jmeter_master} true", shell=True)
except subprocess.CalledProcessError:
    logging.error(
        "Error ssh:ing to load generators. Maybe you dont have permission to log on to them? Or your ssh key requires a password? (in that case, use ssh-agent)"
    )
    raise

# Route SIGTERM through sig_handler (defined in locust_swarm.functions).
# NOTE(review): presumably this turns the signal into a normal exit so the atexit
# cleanup below still runs - confirm against sig_handler.
signal.signal(signal.SIGTERM, sig_handler)

# Find a free port for the master, probing upwards in steps of two.
# NOTE(review): the step of two suggests the master also occupies port + 1 - confirm.
while is_port_in_use(port):
    port += 2

if args.jmeter_kill_first:
    # jmeter_master only exists in jmeter mode; when the flag is used without --jmeter,
    # clean up the load generators only instead of failing with a NameError.
    kill_targets = [*server_list, jmeter_master] if args.jmeter else list(server_list)
    # "|| true": finding nothing to kill is not an error
    check_output_multiple(
        f"ssh -q {server} 'pkill -u $USER -f jmeter/bin/ApacheJMeter.jar' || true" for server in kill_targets
    )

slaves = get_available_servers_and_lock_them(loadgens, server_list)
# Register cleanup immediately after the servers have been locked, so nothing can fail
# in between and skip it.
atexit.register(cleanup, slaves, args)
slave_procs = []
start_time = datetime.now(timezone.utc)

# ---- Launch the master process and the slave processes (jmeter or locust) ----
if args.jmeter:
    jmeter_grafana_url = os.environ["JMETER_GRAFANA_URL"]
    # everything swarm did not recognize is forwarded to jmeter as-is
    jmeter_params = " -t " + testplan_filename + " " + " ".join(unrecognized_args)

    if args.jmeter_gui:
        # GUI mode runs a local jmeter binary; no remote master or slaves are involved
        check_output("/usr/local/bin/jmeter" + jmeter_params)
        # sys.exit() instead of the exit() helper, which is injected by the site module
        # for interactive use and is not guaranteed to exist in every interpreter setup.
        sys.exit(0)

    # Refuse to start if the master host already runs java for this user. The "nothing
    # is running" case is detected via CalledProcessError from the pgrep command.
    # (An alternative would be to force-kill the old process with
    # "pkill -9 -u $USER java"; currently we bail out instead.)
    try:
        check_output(f"ssh -q {jmeter_master} pgrep -c -u \\$USER java")
    except subprocess.CalledProcessError:
        logging.debug("There was no jmeter master running, great.")
    else:
        raise Exception("Master was busy running another jmeter test. Bailing out!")

    for slave in slaves:
        slave_procs.append(start_jmeter_process(slave, port, unrecognized_args))

    # give the slaves a moment to come up before the master tries to reach them
    time.sleep(3)
    logging.debug("All jmeter slaves started")
    check_output(f"scp -q {testplan} {jmeter_master}: && scp -q load_profile.csv {jmeter_master}:")
    log_folder = f"logs/{start_time.strftime('%Y-%m-%d-%H.%M')}"
    check_output(f"ssh -q {jmeter_master} mkdir -p {log_folder}")
    # upload an extra copy of test plan & load profile to log folder, just for keeping track of previous runs:
    check_output(
        f"scp -q {testplan} {jmeter_master}:{log_folder} && scp -q load_profile.csv {jmeter_master}:{log_folder}"
    )
    # "-R" takes a comma-separated list of slave:port pairs
    remote_hosts = f":{port},".join(slaves) + f":{port}"
    master_command = [
        "ssh",
        "-q",
        jmeter_master,
        f"nohup bash -c 'jmeter/bin/jmeter -n -R {remote_hosts} -Jjmeterengine.nongui.port=0 {jmeter_params}'",
    ]
    logging.info(f"launching master: {' '.join(master_command)}")
    # timezone-aware, consistent with start_time; .timestamp() yields the same epoch value
    master_start_time = datetime.now(timezone.utc)
    master_proc = subprocess.Popen(master_command)
else:
    # locust
    # set before the env var scan below, so the run id is forwarded to the slaves as well
    os.environ["LOCUST_RUN_ID"] = start_time.isoformat()
    master_command = [
        "locust",
        "--master",
        "--master-bind-port",
        str(port),
        "--master-bind-host=127.0.0.1",  # this avoids MacOS popups about opening firewall for python
        "--expect-slaves",
        str(slave_process_count),
        "--no-web",
        "-f",
        testplan,
        "-t",
        run_time,
        "--exit-code-on-error",
        "0",  # return zero even if there were failed samples (locust default is to return 1)
        *unrecognized_args,
    ]
    logging.info(f"launching master: {' '.join(master_command)}")
    master_proc = subprocess.Popen(master_command)

    # LOCUST_RPS is distributed over the locust slave processes. When client count <
    # slave process count, not all locust processes will get a client, so the minimum of
    # the two is used. Clamped to 1 so a run with zero slave processes cannot divide by zero.
    rps_share_count = max(1, min(slave_process_count, client_count))

    # collect the env vars (LOCUST_* and PG*) that are passed along to the slave processes
    locust_env_vars = []
    for varname, value in os.environ.items():
        if not varname.startswith(("LOCUST_", "PG")):
            continue
        if varname == "LOCUST_RPS":
            value = float(value) / rps_share_count
        locust_env_vars.append(f'{varname}="{value}"')

    for slave in slaves:
        # fail early if master has already terminated
        check_proc_running(master_proc)
        slave_procs.extend(
            start_locust_processes(slave, port, args.processes_per_loadgen, locust_env_vars, testplan_filename)
        )

# Give the slaves a moment, then verify that none of them terminated right away
# (for example because of invalid parameters).
time.sleep(2)
for worker in slave_procs:
    check_proc_running(worker)

# Block until the master exits, looking after the slaves between 10 second waits.
# The master's return code ends up in `code`.
while True:
    try:
        code = master_proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        pass
    else:
        break
    # master is still running: make sure no slave has died before it
    for worker in slave_procs:
        try:
            check_proc_running(worker)
        except subprocess.CalledProcessError as slave_error:
            # the master may just be about to finish; give it a second before judging
            time.sleep(1)
            code = master_proc.poll()
            if code is not None:
                # master is done too, so the dead slave is not a problem
                continue
            logging.error(
                f"Slave proc finished unexpectedly with ret code {slave_error.returncode} (and master was still running)"
            )
            raise

if args.jmeter:
    # jmeter (unlike locust with its TimescaleListener) doesn't output a link to its report, so we do it from here
    host_args = [argument for argument in unrecognized_args if "host=" in argument]
    if host_args:
        # split on the first "=" only, so a value that itself contains "=" is kept whole
        host_with_protocol = host_args[0].split("=", 1)[1]
        # strip just a leading http:// or https:// scheme
        application = re.sub(r"^https?://", "", host_with_protocol)
    else:
        # if host name is not explicitly set, then we can't know which application jmeter will log to, so we can't filter
        application = "All"
    # time range of the report, in epoch milliseconds: from master launch until now
    report_from = int(master_start_time.timestamp() * 1000)
    report_to = int(time.time() * 1000)
    logging.info(
        f"Report: {jmeter_grafana_url}&var-application={application}&var-send_interval=5&from={report_from}&to={report_to}\n"
    )

logging.info(f"Load gen master process finished (return code {code})")
# propagate the master's return code as this script's exit status
sys.exit(code)
