#!/usr/bin/env python

from argparse import ArgumentParser

# Command-line interface: two positionals pick the action and the target
# machine, the options tune the port forwarding and the verbosity.
parser = ArgumentParser()
# Action to run; compared against "status", "ssh", "sshfs", "stop", "sync",
# "sync_delete", "timeit", "jupyter" and "connect" further down in this script.
parser.add_argument("mode")

# Target machine; in "status" mode a host range such as "er{00-05}c".
parser.add_argument("vm")

# type=int makes argparse reject non-numeric ports up front instead of
# interpolating arbitrary text into the ssh/jupyter command lines.
parser.add_argument("-p", "--port", dest="port", type=int, help="Local Port.", default=8888)
parser.add_argument("-ep", "--eport", dest="eport", type=int, help="External port for the jupyter notebook.", default=8888)

parser.add_argument("-d", "--debug",
                    action="store_true", dest="debug", default=False,
                    help="All commands will be printed.")

# NOTE(review): --dryrun stores into the same destination as --debug, so today
# it only switches debug on; no separate dry-run flag is read in the visible
# code. Kept as is to preserve behaviour -- confirm what was intended.
parser.add_argument("-dr", "--dryrun",
                    action="store_true", dest="debug", default=False,
                    help="No command will be executed.")

pargs = parser.parse_args()

import time, timeit
import os
from concurrent.futures import ThreadPoolExecutor

import encap_lib.encap_settings as settings
from encap_lib.encap_lib import SSHMachine, GCMachine, LocalMachine, get_machine
from encap_lib import fabric_wrapper


def extract_hosts(a):
    """
    Expand a host range pattern into the list of host names it stands for.

    input "er{00-05}c"
    output ['er00c', 'er01c', 'er02c', 'er03c', 'er04c', 'er05c']

    The numbers are zero-padded to the width of the lower bound as it is
    written, so "n{8-10}" gives ['n8', 'n9', 'n10'] and "n{008-010}" gives
    ['n008', 'n009', 'n010']. Both bounds are inclusive.

    Raises ValueError when the pattern does not contain exactly one
    "{lo-hi}" group or when a bound is not an integer.
    """
    head, mid = a.split("{")
    mid, back = mid.split("}")

    n1, n2 = mid.split("-")

    # Pad to the written width of the lower bound. Counting characters up to
    # the first non-zero digit is not enough: "010" has one leading zero but
    # must still be rendered three characters wide.
    width = len(n1)

    # zfill: fills the beginning of the string with zeros up to `width`.
    return [head + str(i).zfill(width) + back for i in range(int(n1), int(n2) + 1)]

def get_status(vms):
    """
    Run a small psutil probe on every host in `vms` and collect its output.

    The probe prints two lines: the highest per-core CPU percentage measured
    over a 10 second interval, and the user name owning the process with the
    highest CPU usage (the probe process itself excluded).

    Connections are opened concurrently (up to 9 at a time). Hosts whose
    machine object has `conn` set to None are skipped, so the returned list
    can be shorter than `vms`; the remaining entries keep the order of `vms`.

    Returns a list with whatever `fabric_wrapper.follow_output` yields for
    each probe -- presumably the two printed lines; verify against the caller.
    """
    code = """python -c "import psutil, pwd, os
print(max(psutil.cpu_percent(interval=10, percpu=True)))
# Refresh
processes = list(psutil.process_iter(['pid', 'name', 'cpu_percent', 'uids']))

current_pid = os.getpid()

# Get all running processes except the current one
processes = (proc for proc in psutil.process_iter(['pid', 'name', 'cpu_percent', 'uids']) if proc.info['pid'] != current_pid)

# Find the process with the highest CPU usage
max_cpu_process = max(processes, key=lambda p: p.info['cpu_percent'])

# Get the UID of the owner of the process
uid = max_cpu_process.info['uids'].real

# Get the username corresponding to the UID
username = pwd.getpwuid(uid).pw_name
print(username)" """

    # The context manager shuts the executor down once every connection
    # attempt has finished, instead of leaving its worker threads around.
    # The results are consumed inside the block so each probe is still
    # launched (without waiting) as soon as its connection is available.
    with ThreadPoolExecutor(9) as pool:
        machines = pool.map(fabric_wrapper.SSHMachine, vms)
        procs = [machine.run_code(code, wait=False) for machine in machines if machine.conn is not None]

    return [fabric_wrapper.follow_output(proc, verbose=False) for proc in procs]


# Forward the command-line debug flag to the encap settings module.
if pargs.debug:
    settings.debug = True

local_project_dir = os.getcwd()

# Beginning of code
localmachine = LocalMachine(local_project_dir)

if pargs.mode == "status":
    # Only host ranges are handled here; say so instead of exiting silently
    # (SystemExit with a message prints it to stderr and exits with status 1).
    if not ("-" in pargs.vm and "{" in pargs.vm):
        raise SystemExit(f"The mode 'status' expects a host range such as 'er{{00-05}}c', got '{pargs.vm}'.")

    vms = extract_hosts(pargs.vm)
    percents_and_names = get_status(vms)
    # NOTE(review): get_status leaves out hosts it has no connection for, so
    # when a host is unreachable this zip pairs results with the wrong names.
    for vm, (percent, name) in zip(vms, percents_and_names):
        if float(percent) < 10:
            # Bold: every core is below 10 %, the host is reported as free.
            print(f"\033[1m{vm}: {percent} free\033[0m")
        else:
            print(f"{vm}: {percent} {name}")

    # Status never needs a single target machine, so stop before the lookup
    # below. SystemExit is used directly because the exit() helper is injected
    # by the site module for interactive sessions and may be absent.
    raise SystemExit

# Every remaining mode works on one target machine.
machine = get_machine(pargs.vm, local_project_dir)

if pargs.mode == "ssh":
    machine.interactive_ssh()

elif pargs.mode == "sshfs":
    # Mount the remote project directory under ~/encap_mount via sshfs.
    local_mount = "~/encap_mount"
    localmachine.run_code(f"mkdir -p {local_mount}")
    # Check if anything is mounted
    out = localmachine.run_code(f"mountpoint {local_mount}")
    assert len(out) == 1

    # mountpoint prints "<path> is a mountpoint" or "<path> is not a
    # mountpoint". Only the suffix is compared because the echoed path may not
    # be spelled like local_mount (e.g. with `~` expanded).
    if out[0].strip().endswith(" is a mountpoint"):
        localmachine.run_code(f"umount {local_mount}")

    command = f"sshfs -o ssh_command='ssh -C'  {machine.username}@{machine.ip}:{machine.remote_project_dir} {local_mount}"
    localmachine.run_code(command)

elif pargs.mode == "stop":
    assert isinstance(machine, GCMachine)
    localmachine.run_code(f"gcloud compute instances stop {machine.vm} --zone {machine.zone}")

elif pargs.mode == "sync":
    t0 = time.time()
    assert isinstance(machine, GCMachine)
    machine.rsync_push("", "", verbose=True)
    print("It took:", (time.time() - t0), "s")

elif pargs.mode == "sync_delete":
    assert isinstance(machine, GCMachine)
    t0 = time.time()
    machine.rsync_push("", "", verbose=True, last_timestamp_prevails=False, delete=True)
    print("It took:", (time.time() - t0), "s")

elif pargs.mode == "timeit":
    # Total wall time of ten round trips running `ls` on the machine.
    print(timeit.timeit(lambda: machine.run_code("ls", verbose=False), number=10))

elif pargs.mode == "jupyter":
    # ssh -L takes <local port>:<host>:<remote port>. The notebook below
    # listens on the remote side on --eport, so the tunnel has to go from the
    # local --port to the remote --eport (the two were swapped before, which
    # only worked while both ports were equal).
    pid = localmachine.run_code(f"ssh -N -f -L localhost:{pargs.port}:localhost:{pargs.eport} {machine.username}@{machine.ip}", verbose=False, wait=False, debug=True)
    try:
        machine.run_code(f"jupyter notebook --no-browser --port={pargs.eport}", verbose=True)
    except KeyboardInterrupt:
        # Ctrl-C: tear down the tunnel and the remote notebook server.
        localmachine.run_code('pkill -f "ssh -N"')
        machine.run_code('pkill -f "jupyter-notebook"')

elif pargs.mode == "connect":
    # Tunnel only: local --port -> remote --eport, same layout as in "jupyter".
    pid = localmachine.run_code(f"ssh -N -f -L localhost:{pargs.port}:localhost:{pargs.eport} {machine.username}@{machine.ip}", verbose=False, wait=False, debug=True)

else:
    raise ValueError(f"The mode '{pargs.mode}' is not available.")
