#!/usr/bin/env python3

from subprocess import Popen, PIPE
from termcolor import colored, cprint
import getpass
import time
import os
from shtop import Tui
from datetime import datetime


# Shared TUI instance: the functions below store display variables on it
# with setVar(), and the main loop at the bottom calls update() on it.
tui = Tui()

# get nodes
def getHosts(hostsPath="/etc/hosts"):
	"""Return the node addresses listed after the "# slurm" marker in a hosts file.

	Everything up to and including the first line that contains "# slurm" is
	ignored. After the marker, the first field of every entry that has at
	least two whitespace-separated fields is collected; blank lines, comment
	lines and single-field lines are skipped.

	Args:
		hostsPath: Path of the hosts file to read. Defaults to /etc/hosts.

	Returns:
		A list with the first field of each collected entry, in file order.
		The list is empty when the file cannot be read or has no marker.
	"""
	nodesHosts = []
	clusterNote = False

	try:
		# Read the file directly instead of spawning `cat` for it.
		with open(hostsPath, encoding="utf-8", errors="replace") as hostsFile:
			lines = hostsFile.read().split("\n")
	except OSError as err:
		print(f"Cannot read {hostsPath}: {err}")
		lines = []

	for line in lines:
		if not clusterNote:
			# The marker line itself is never treated as an entry.
			if "# slurm" in line:
				clusterNote = True
			continue

		# Fields in a hosts file may be separated by any mix of blanks and
		# tabs, so split on arbitrary whitespace rather than on "\t" only.
		fields = line.split()
		if len(fields) > 1 and not fields[0].startswith("#"):
			nodesHosts.append(fields[0])

	print(f"Founded nodes in {hostsPath}: {nodesHosts}")

	return nodesHosts


def progressbar(val, max, unit, pchar="|", size=60):
	"""Render a fixed-width text progress bar followed by a "val/max" label.

	Args:
		val: Current value.
		max: Value that corresponds to a full bar. When it is zero the bar is
			drawn for a ratio of zero instead of raising ZeroDivisionError.
		unit: Suffix appended to both numbers in the label.
		pchar: Character used for each filled cell.
		size: Number of cells in the bar.

	Returns:
		The bar (coloured green) inside square brackets, followed by the
		"val/max" label (coloured grey).
	"""
	# Guard the division: updateData() passes count * 100 and totalRAM,
	# both of which are 0 when no nodes were found.
	ratio = val / max if max else 0

	# Cell i is filled once the ratio reaches i / size, so cell 0 is always
	# filled, even for a ratio of zero.
	progress = ''.join(pchar if ratio >= (i / size) else " " for i in range(size))
	progress = colored(progress, "green")

	return f"[{progress}] " + colored(f"{val}{unit}/{max}{unit}", "grey")

def slurmGet():
	"""Run `squeue` and publish the job summary and job table to the TUI.

	Sets the TUI variables "jobsCount", "jobsRunning", "jobs" and "update".
	"""
	rawOutput, _ = Popen(['squeue'], stdout=PIPE).communicate()
	text = rawOutput.decode()

	# One line of the output is the header. Clamp at zero so that empty
	# output (squeue printed nothing) is not reported as -1 jobs.
	tui.setVar("jobsCount", max(0, text.count("\n") - 1))
	# Counts occurrences of " R " anywhere in the output (presumably the
	# running state in squeue's status column).
	tui.setVar("jobsRunning", text.count(" R "))

	# Drop the first 12 characters of every line.
	rows = [row[12:] for row in text.split("\n")]

	# os.get_terminal_size() raises OSError when stdout is not a terminal;
	# fall back to a conventional width instead of crashing.
	try:
		width = os.get_terminal_size().columns
	except OSError:
		width = 80

	# Pad the header row to the terminal width so the reverse-video
	# highlight spans the whole line.
	rows[0] += " " * (width - len(rows[0]))
	rows[0] = colored(rows[0], "blue", attrs=['reverse'])

	tui.setVar("jobs", "\n".join(rows))
	tui.setVar("update", datetime.now().strftime("%d/%m/%Y %H:%M:%S"))
	

def _remoteInt(node, command):
	"""Run `command` on `node` over ssh and return its output parsed as an int.

	Raises ValueError when the output is not an integer (for example when
	nothing was printed).
	"""
	# NOTE(review): the password is passed on the sshpass command line.
	stdout, _ = Popen(['sshpass', '-p', password, 'ssh', f'{user}@{node}', command], stdout=PIPE).communicate()
	return int(stdout)


def _usageBadge(percent):
	"""Render a percentage as a two-digit, colour-coded reverse-video badge."""
	# Three tiers: above 75 red, above 25 yellow, otherwise green.
	if percent > 75:
		color = 'red'
	elif percent > 25:
		color = 'yellow'
	else:
		color = 'green'

	# Cap at 99 so the badge never grows beyond two digits.
	return ' ' + colored('(%02d)' % min(percent, 99), color, attrs=['reverse']) + ' '


def updateData():
	"""Poll every node for CPU load and used RAM and publish them to the TUI.

	Reads the module-level nodesHosts, nodestotalRAM and totalRAM, and sets
	the TUI variables "nodesCpus", "load", "usedRAM", "nodesRams" and "nodes".
	"""
	# CPU load per node: 100 minus column 15 of the last vmstat sample
	# (presumably the idle percentage).
	nodesLoad = [
		_remoteInt(node, "echo $[100-$(vmstat 1 2|tail -1|awk '{print $15}')]")
		for node in nodesHosts
	]

	# Used RAM per node: third field of the second line of `free -m`.
	nodesRam = [
		_remoteInt(node, "free -m | awk 'NR==2{print $3}'")
		for node in nodesHosts
	]

	count   = len(nodesLoad)
	loadSum = sum(nodesLoad)
	ramSum  = sum(nodesRam)

	nodesLoadText = "".join(_usageBadge(load) for load in nodesLoad)

	# Express used RAM as a percentage of each node's total; a node that
	# reported a total of 0 is shown as 0% instead of dividing by zero.
	nodesRamText = "".join(
		_usageBadge(int(used / total * 100) if total else 0)
		for used, total in zip(nodesRam, nodestotalRAM)
	)

	tui.setVar("nodesCpus", nodesLoadText)
	tui.setVar("load", progressbar(loadSum, count * 100, '%', size=30))
	tui.setVar("usedRAM", progressbar(ramSum, totalRAM, 'M', size=30))
	tui.setVar("nodesRams", nodesRamText)
	tui.setVar("nodes", count)


tui.update()

print("Login to cluster")
print()

# Credentials used for every ssh call to the nodes
user     = os.getlogin()
password = getpass.getpass(prompt=f'Password for ssh user {user}: ')

tui.update()
print("Loading...")

# Discover the cluster nodes
nodesHosts = getHosts()

# Ask every node for its total RAM (second field of the second line of
# `free -m`) and keep both the per-node values and their sum.
nodestotalRAM = []
for host in nodesHosts:
	ramQuery = ['sshpass', '-p', password, 'ssh', f'{user}@{host}', "free -m | awk 'NR==2{print $2}'"]
	reply = Popen(ramQuery, stdout=PIPE).communicate()[0]
	nodestotalRAM.append(int(reply))
totalRAM = sum(nodestotalRAM)

tui.setVar("totalRAM", totalRAM)

# Screen layout, top to bottom; %name% placeholders match the names passed
# to tui.setVar() elsewhere in this file.
layout = (
	colored("CPU", "blue") + "%load%\t\t" + colored("Jobs: %jobsCount%, ", "blue") + colored("%jobsRunning% running; ", "green") + colored("Nodes: %nodes%", "blue"),
	colored("MEM", "blue") + "%usedRAM%\t\t" + colored("Last update: %update%", "blue"),
	"",
	"CPU load of nodes:",
	"%nodesCpus%",
	"",
	"RAM usage of nodes:",
	"%nodesRams%",
	"",
	"%jobs%",
)
for textLine in layout:
	tui.addTextLine(textLine)

# Refresh loop: poll the nodes, poll slurm, then update the TUI, forever.
while True:
	updateData()
	slurmGet()
	tui.update()
