#!/usr/bin/env python3

"""
Some of this code is taken from https://github.com/DestructiveVoice/DestructiveFarm/blob/master/client/start_sploit.py
That's the reason why this submitter is called "exploitfarm".
"""

import dateutil.parser
import os, re, orjson, math
import subprocess, time, threading, traceback
from concurrent.futures import ThreadPoolExecutor
from exploitfarm.utils.config import ClientConfig, ExploitConfig
from datetime import datetime as dt, timedelta
import datetime, io
from exploitfarm.model import AttackExecutionStatus, AttackMode
from rich import print
from uuid import UUID
from copy import deepcopy
from rich.markup import escape
from queue import Queue, Empty
import multiprocessing
from exploitfarm.utils import mem_usage
from exploitfarm import Prio, nicenessify
from exploitfarm import DEV_MODE

# Extra diagnostics (see dumpstacks) are enabled when XPLOIT_DEBUG is "1", "true" or "True".
XPLOIT_DEBUG = os.getenv("XPLOIT_DEBUG", "0") in ("1", "true", "True")

import sys
def dumpstacks():
    """Write the current stack of every live thread to ``stacktrace.log``.

    Does nothing unless XPLOIT_DEBUG is enabled. The file is overwritten on
    every call.
    """
    if not XPLOIT_DEBUG:
        return
    names_by_ident = {thread.ident: thread.name for thread in threading.enumerate()}
    lines = []
    for ident, frame in sys._current_frames().items():
        lines.append("\n# Thread: %s(%d)" % (names_by_ident.get(ident, ""), ident))
        for entry in traceback.extract_stack(frame):
            lines.append('File: "%s", line %d, in %s' % (entry.filename, entry.lineno, entry.name))
            if entry.line:
                lines.append("  %s" % (entry.line.strip()))
    with open("stacktrace.log", "w") as out:
        out.write("\n".join(lines))

def run_windows_ctrl_c_fix():
    """Make Ctrl+C reach this process on Windows even while sploits are running.

    By default, Ctrl+C does not work on Windows if we spawn subprocesses. Here
    we fix that using the WinAPI (see https://stackoverflow.com/a/43095532).
    A console control handler is installed that, on CTRL_C_EVENT, does three
    things: it starts ignoring further Ctrl+C events, it invokes the optional
    exit callback, and it calls shutdown(). It also prepares ``g.kernel32`` and
    ``g.win_ignore_ctrl_c``, which launch_sploit() uses. On other platforms it
    does nothing.
    """
    if not os_windows:
        return

    import signal
    import ctypes
    from ctypes import wintypes

    g.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)

    # BOOL WINAPI HandlerRoutine(
    #   _In_ DWORD dwCtrlType
    # );
    PHANDLER_ROUTINE = ctypes.WINFUNCTYPE(wintypes.BOOL, wintypes.DWORD)

    g.win_ignore_ctrl_c = PHANDLER_ROUTINE()  # = NULL

    def _errcheck_bool(result, _, args):
        if not result:
            raise ctypes.WinError(ctypes.get_last_error())
        return args

    # BOOL WINAPI SetConsoleCtrlHandler(
    #   _In_opt_ PHANDLER_ROUTINE HandlerRoutine,
    #   _In_     BOOL             Add
    # );
    g.kernel32.SetConsoleCtrlHandler.errcheck = _errcheck_bool
    g.kernel32.SetConsoleCtrlHandler.argtypes = (PHANDLER_ROUTINE, wintypes.BOOL)

    @PHANDLER_ROUTINE
    def win_ctrl_handler(dwCtrlType):
        if dwCtrlType == signal.CTRL_C_EVENT:
            g.kernel32.SetConsoleCtrlHandler(g.win_ignore_ctrl_c, True)
            # Inside `class g` the attribute `__callback_exit_win` is
            # name-mangled to `_g__callback_exit_win`. A plain
            # `g.__callback_exit_win` read from here would raise AttributeError
            # unless it was assigned from outside the class, so read it with a
            # default instead.
            callback = getattr(g, "__callback_exit_win", None)
            if callback:
                callback()
            shutdown()
        # False: let the next (default) handler run as well.
        return False

    # ctypes does not keep callback objects alive by itself. Keep a reference
    # for as long as the handler is registered; otherwise it can be garbage
    # collected once this function returns, and Windows would call freed code.
    g.win_ctrl_handler = win_ctrl_handler
    g.kernel32.SetConsoleCtrlHandler(win_ctrl_handler, True)

def dt_now():
    """Return the current time as a timezone-aware UTC datetime."""
    utc = datetime.timezone.utc
    return dt.now(utc)
        
# Platform flags selecting the OS-specific code paths in this module.
os_windows = os.name == "nt"
os_unix = os.name == "posix"

class g:
    """Process-wide runtime state shared by every thread of the exploit runner.

    Used as a namespace and never instantiated. start_xploit()/xploit_one()
    fill in most fields before the worker threads start.
    """
    pool_size:int = 50  # sploit worker threads (None -> max_queue, see xploit)
    config: ClientConfig  # server client/config, set by start_xploit/xploit_one
    memory: dict = None  # status dict published for the UI ("teams", "running_workers", ...)
    submit_pool_timeout: int = 3  # seconds between flag-submission rounds
    server_status_refresh_period: int = 1  # seconds between config refreshes
    exploit_config: ExploitConfig = None
    print_queue: Queue = None  # messages consumed by the printer (see qprint)
    config_update_event:Queue = Queue()  # "update"/"shutdown" signals for the attack loop
    attack_storage:"AttackStorage" = None
    instance_storage:"InstanceStorage" = None
    exit_event: threading.Event = None  # set by shutdown() to stop every loop
    server_id = None  # server the queued attacks belong to (set on first config fetch)
    restart_event = None  # set by shutdown(restart=True)
    max_mem_usage = 95  # memory-usage threshold (percent) checked before launching sploits
    pool: ThreadPoolExecutor = None
    tot_teams:int = 0  # teams in the current attack round
    remaining_teams:int = 0  # teams of the round not launched yet
    initial_timeout:int = 0  # seconds until the next round, computed at round start
    is_ram_maxed = False
    max_queue = multiprocessing.cpu_count()*10
    flag_cache:set = set()  # every flag already queued, used to drop duplicates
    last_attack_time: dt = dt_now()  # start time of the current round
    teams_info = {}
    old_config = {}  # last seen config without volatile keys (see config_comp_and_update)
    _shutdown = False
    kernel32 = None  # kernel32 WinDLL handle (Windows only)
    win_ignore_ctrl_c = None  # NULL handler routine used to ignore Ctrl+C (Windows only)

# Names with two leading underscores are mangled inside a class body.
# `__callback_exit_win = None` written there would define
# `_g__callback_exit_win`, while module-level code reads `g.__callback_exit_win`.
# Set the default from outside the class so it exists under that exact name.
setattr(g, "__callback_exit_win", None)
class InstanceStorage:
    """
    Registry of the sploit processes currently running, plus completion counters.

    Every method takes ``self.lock`` (a re-entrant lock). Callers should hold
    that same lock, and not release it, between actually spawning/killing a
    process and calling register_start()/register_stop().
    """

    def __init__(self):
        self._counter = 0  # next instance id to hand out
        self._instances = {}  # instance id -> process
        self.n_completed = 0  # finished runs (killed or not)
        self.n_killed = 0  # finished runs that were killed
        self.lock = threading.RLock()

    @property
    def instances(self):
        """Live mapping of instance id -> running process (not a copy)."""
        return self._instances

    def register_start(self, process):
        """Register ``process`` and return its new instance id.

        Also decrements g.remaining_teams and publishes the running-worker
        count in g.memory["running_workers"].
        """
        with self.lock:
            assigned_id = self._counter
            g.remaining_teams -= 1
            self._instances[assigned_id] = process
            self._counter = assigned_id + 1
            g.memory["running_workers"] = len(self._instances)
        return assigned_id

    def register_stop(self, instance_id, was_killed):
        """Unregister a finished instance and update the completion counters."""
        with self.lock:
            self._instances.pop(instance_id)
            self.n_completed += 1
            self.n_killed += was_killed
            g.memory["running_workers"] = len(self._instances)

def default_team_info():
    """Return a fresh per-team stats record: no last attack, 0 flags, not executing."""
    return dict(last=None, flags=0, executing=False)

def set_attack_exec(team:int, status:bool):
    """Set the "executing" flag of ``team`` in g.memory["teams"], creating entries as needed."""
    if "teams" not in g.memory:
        g.memory["teams"] = {}
    teams = g.memory["teams"]
    info = teams.get(team, default_team_info())
    info["executing"] = status
    teams[team] = info

class AttackStorage:
    """
    Thread-safe queue of attack results waiting to be submitted to the server.

    Any number of threads may call add(). A single consumer thread is expected
    to call pick_attacks() and then mark_as_sent() for the leading entries it
    has submitted. Flags already queued once in this process (g.flag_cache)
    are dropped from new results.
    """

    def __init__(self, client_id: UUID|None = None):
        self._queue = []
        self._lock = threading.RLock()
        self.client_id = client_id

    def add(self,
        flags: list[str],
        team: int,
        start_time: dt,
        end_time: dt,
        status: AttackExecutionStatus,
        output: bytes,
        stats: bool = True
    ):
        """Queue the result of one attack execution.

        Args:
            flags: flags found by the sploit; already-seen flags are dropped.
            team: target team id.
            start_time, end_time: execution interval.
            status: execution outcome. ``output`` is stored as the record's
                error text unless this is ``AttackExecutionStatus.done``.
            output: sploit output or error text.
            stats: also update the per-team summary in g.memory["teams"].
        """
        with self._lock:
            fresh = [flag for flag in flags if flag not in g.flag_cache]
            g.flag_cache.update(fresh)

            is_done = status == AttackExecutionStatus.done
            record = {
                "start_time": start_time,
                "end_time": end_time,
                "status": status.value,
                "error": None if is_done else output,
                "executed_by": self.client_id,
                "target": team,
                "flags": fresh,
            }

            if stats:
                summary = g.memory["teams"].get(team, default_team_info())
                summary["last"] = {
                    "start_time": start_time,
                    "end_time": end_time,
                    "status": status.value,
                    "flags": len(fresh),
                }
                summary["flags"] += len(fresh)
                g.memory["teams"][team] = summary

            self._queue.append(record)

    def pick_attacks(self):
        """Return a snapshot (copy) of the queued attack records, oldest first."""
        with self._lock:
            return list(self._queue)

    def mark_as_sent(self, count):
        """Drop the ``count`` oldest records (the ones just submitted)."""
        with self._lock:
            del self._queue[:count]

    @property
    def queue_size(self):
        """Number of records still waiting to be submitted."""
        with self._lock:
            return len(self._queue)

# NOTE(review): not referenced anywhere in this file; presumably a runtime
# threshold (seconds?) for warnings — confirm before relying on it.
WARNING_RUNTIME = 5

def qprint(*args, end="\n"):
    """Queue ``args`` (then ``end``) for the printer thread instead of printing directly.

    Strings are passed through rich's ``escape`` so they are shown literally
    rather than interpreted as markup; other objects are queued unchanged.
    ``end`` is skipped when empty.
    """
    try:
        for item in args:
            g.print_queue.put(escape(item) if isinstance(item, str) else item)
        if end:
            g.print_queue.put(end)
    except (KeyboardInterrupt, ValueError):
        traceback.print_exc()

class InvalidSploitError(Exception):
    """Exception type for an invalid sploit (not raised within this file)."""

if os_unix:
    # Raise the soft limit on open file descriptors: every running sploit keeps
    # a stdout pipe open, so large pools can exhaust the default soft limit.
    # The target is still half of the hard limit. This is best effort: the
    # limit is never lowered, and a refusal from the OS must not make
    # importing this module fail.
    import resource
    fdlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    _fd_soft, _fd_hard = fdlimit
    if _fd_hard == resource.RLIM_INFINITY:
        _fd_target = resource.RLIM_INFINITY
    else:
        _fd_target = _fd_hard // 2
    _fd_would_raise = _fd_soft != resource.RLIM_INFINITY and (
        _fd_target == resource.RLIM_INFINITY or _fd_target > _fd_soft
    )
    if _fd_would_raise:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (_fd_target, _fd_hard))
        except (ValueError, OSError):
            # Some platforms (e.g. macOS with an unlimited hard limit) reject
            # very large soft limits; keep the current limit in that case.
            pass




class APIException(Exception):
    """Exception type presumably meant for server API errors (not raised within this file)."""

def once_in_a_period(period):
    """Yield 1, 2, 3, ... at most once every ``period`` seconds until shutdown.

    After each iteration the generator sleeps for whatever is left of
    ``period`` by waiting on ``g.exit_event``, and stops as soon as that event
    is set. The event is also checked when an iteration took longer than
    ``period``, so slow loop bodies still stop promptly on shutdown.
    """
    iter_no = 1
    while True:
        # Monotonic clock: wall-clock jumps (NTP, manual changes) must not
        # shorten or stretch the wait.
        started = time.monotonic()
        yield iter_no
        remaining = period - (time.monotonic() - started)
        if remaining > 0:
            if g.exit_event.wait(timeout=remaining):
                return
        elif g.exit_event.is_set():
            return
        iter_no += 1

def repush_flags():
    """Re-queue the attacks saved in ``.flag_queue.json`` by a previous shutdown().

    The saved queue is only reused when it was written for the current server
    (g.server_id). On any error (unreadable file, server mismatch, bad entry)
    the file is deleted if it exists.
    """
    queue_file = ".flag_queue.json"
    try:
        with open(queue_file, "rt") as fp:
            saved = orjson.loads(fp.read())
        if saved["server_id"] != g.server_id:
            raise Exception("Server ID mismatch")
        for entry in saved["queue"]:
            g.attack_storage.add(
                entry["flags"],
                entry["target"],
                entry["start_time"],
                entry["end_time"],
                AttackExecutionStatus(entry["status"]),
                entry["error"],
                stats=False,
            )
    except Exception:
        if os.path.exists(queue_file):
            os.remove(queue_file)

def config_comp_and_update():
    """Detect whether the server config changed in a way that matters.

    Compares g.config.status with g.old_config, ignoring volatile keys
    (server/round times, submitter, services, messages). On a change it
    reports it, stores the new reference copy in g.old_config and returns
    True; otherwise it returns False.
    """
    ignored = ("server_time", "start_time", "end_time", "submitter", "services", "messages")

    def _relevant(conf):
        # Drop ignored keys whatever their value. Only dropping truthy ones
        # kept keys holding [] or None, which produced spurious "config
        # updated" events (and attack restarts) when they became non-empty.
        # Never mutate `conf`: g.old_config may alias g.config.status.
        if not conf:
            return {}
        return {key: value for key, value in conf.items() if key not in ignored}

    new_conf = deepcopy(_relevant(g.config.status))
    old_conf = _relevant(g.old_config)
    if new_conf == old_conf:
        return False
    qprint('--- Config updated! ---')
    if DEV_MODE:
        qprint(new_conf)
    g.old_config = new_conf
    return True

def update_server_config():
    """Background loop refreshing the server status every g.server_status_refresh_period seconds.

    On each refresh:
    - if the client is not logged in, shut down;
    - on first contact, remember the server id and re-queue saved flags;
    - if the server id changed, shut down and request a restart;
    - if the relevant config changed, signal the attack loop ("update").

    Errors inside a refresh are reported and retried on the next period. A
    failure of the loop itself triggers shutdown().
    """
    try:
        for _ in once_in_a_period(g.server_status_refresh_period):
            try:
                g.config.fetch_status()
                status = g.config.status
                if not status:
                    continue
                if not status["loggined"]:
                    qprint('Not loggined, login first to get the config and run the exploit')
                    print("[red]Not loggined, login first to get the config and run the exploit[/red]")
                    g.memory["config_update"] = False
                    shutdown()
                    break
                g.memory["config_update"] = True
                if g.server_id:
                    if g.server_id != status["server_id"]:
                        qprint('Server ID changed, restart the exploit')
                        shutdown(restart=True)
                        return
                else:
                    g.server_id = status["server_id"]
                    repush_flags()
                if config_comp_and_update():
                    g.config_update_event.put("update")
            except Exception as e:
                qprint(f"Can't get config from the server: {repr(e)}")
                qprint(traceback.format_exc())
                g.memory["config_update"] = False
    except Exception as e:
        qprint(traceback.format_exc())
        qprint(f'Config update loop died: {repr(e)}')
        shutdown()

def post_attacks(attacks):
    """Submit ``attacks`` to the server and, on success, drop them from the queue.

    Returns True on success. On failure it reports the error and returns
    False. g.memory["submitter_status"] reflects the outcome.
    """
    try:
        g.config.reqs.submit_flags(attacks, exploit=g.exploit_config.uuid)
        g.attack_storage.mark_as_sent(len(attacks))
        g.memory["submitter_status"] = True
    except Exception as e:
        g.memory["submitter_status"] = False
        qprint(f"Can't post flags to the server: {repr(e)}")
        return False
    return True

def run_post_loop():
    """Background loop: every g.submit_pool_timeout seconds, submit the queued attacks.

    Attacks are sent in chunks of 50. A round stops at the first chunk that
    fails; the remaining attacks are retried on the next period.
    """
    nicenessify(Prio.normal)
    chunk_size = 50  # 50 is to avoid the backend goes in timeout
    try:
        for _ in once_in_a_period(g.submit_pool_timeout):
            pending = g.attack_storage.pick_attacks()
            g.memory["submitter_flags_in_queue"] = len(pending)
            offset = 0
            while offset < len(pending):
                chunk = pending[offset:offset + chunk_size]
                if chunk and not post_attacks(chunk):
                    break
                offset += chunk_size
    except Exception as e:
        g.memory["submitter_status"] = False
        qprint(traceback.format_exc())
        qprint(f'Posting loop died: {repr(e)}')

def process_sploit_filter(proc:subprocess.Popen, team: dict, start_time: dt, killed: bool, read_data:callable, final_output: bool = True):
    """Collect a finished sploit's output, extract its flags and queue the result.

    ``read_data()`` returns the captured output; bytes are decoded as UTF-8
    with replacement. With ``final_output`` the whole output is printed.
    Otherwise (the output was already streamed) only a timeout-kill notice is
    printed.

    The result is queued as:
    - ``crashed`` when the process was killed or exited non-zero;
    - ``done`` when flags matching FLAG_REGEX were found;
    - ``noflags`` otherwise.

    Errors are reported and swallowed.
    """
    try:
        end_time = dt_now()
        text = read_data()
        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='replace')

        if final_output:
            label = team.get('short_name') or team.get('name') or team.get('id')
            qprint(f"Output of the sploit of {team['host']} killed[{killed}] status[{proc.returncode}]: '{label}' started: {start_time} ended: {end_time}\n{text}")
        elif killed:
            qprint("☠️ xFarm killed the sploit due to timeout!")

        pattern = re.compile(g.config.status["config"]["FLAG_REGEX"])
        flags = [str(match) for match in set(pattern.findall(text))]

        if killed or proc.returncode != 0:
            if killed:
                text = "THIS PROCESS HAS BEEN KILLED BY EXPLOITFARM DUE TO TIMEOUT\n-------- OUTPUT -------\n\n" + text
            outcome, payload = AttackExecutionStatus.crashed, text
        elif flags:
            outcome, payload = AttackExecutionStatus.done, b''
        else:
            outcome, payload = AttackExecutionStatus.noflags, text
        g.attack_storage.add(flags, team["id"], start_time, end_time, outcome, payload)
    except Exception as e:
        traceback.print_exc()
        qprint(f'Failed to process sploit output: {repr(e)}')

def read_and_print(stdout:io.BytesIO, buffer:io.BytesIO):
    """Echo a process' output line by line (via qprint) while copying it into ``buffer``.

    Meant to run in its own thread. It returns once ``stdout`` reaches EOF,
    after closing ``stdout``.
    """
    while True:
        line = stdout.readline()
        if line == b'':
            break
        qprint(line.decode('utf-8', errors='replace'), end="")  # Print to main stdout in real-time
        buffer.write(line)
        buffer.flush()
    buffer.write(stdout.read())
    stdout.close()

def launch_sploit(team:dict, stream_output:bool=False) -> tuple[subprocess.Popen, int, callable]:
    """Spawn the sploit for ``team`` and register it as a running instance.

    The command (interpreter + run target from the exploit config) runs
    through the shell, with stderr merged into stdout. The team is passed in
    the XFARM_HOST and XFARM_TEAM (JSON) environment variables.

    Args:
        team: team record; must contain "host".
        stream_output: echo the output live (through qprint) while capturing it.

    Returns:
        (proc, instance_id, read_data). ``read_data()`` blocks until the
        output has been fully read and returns it as bytes.
    """
    # For sploits written in Python, this env variable forces the interpreter to flush
    # stdout and stderr after each newline. Note that this is not default behavior
    # if the sploit's output is redirected to a pipe.
    env = os.environ.copy()
    env['PYTHONUNBUFFERED'] = '1'
    env["XFARM_HOST"] = team["host"]
    env["XFARM_TEAM"] = orjson.dumps(team).decode()

    command = f"{g.exploit_config.interpreter} {g.exploit_config.run}".strip()
    popen_kwargs = dict(
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        close_fds=not os_windows,
        env=env,
        shell=True,
    )
    if not os_windows:
        popen_kwargs["preexec_fn"] = nicenessify

    # On Windows, we block Ctrl+C handling, spawn the process, and then
    # recover the handler. This is the only way to make Ctrl+C intercepted
    # by us instead of our child processes.
    block_ctrl_c = bool(os_windows and g.kernel32)
    if block_ctrl_c:
        g.kernel32.SetConsoleCtrlHandler(g.win_ignore_ctrl_c, True)
    try:
        proc = subprocess.Popen(command, **popen_kwargs)
        if os_windows:
            try:
                nicenessify(pid=proc.pid)
            except Exception as e:
                qprint(f"Error on setting priority: {e}")
    finally:
        # Re-enable Ctrl+C even when spawning or re-prioritizing failed;
        # otherwise Ctrl+C would stay ignored for the whole process.
        if block_ctrl_c:
            try:
                g.kernel32.SetConsoleCtrlHandler(g.win_ignore_ctrl_c, False)
            except Exception as e:
                qprint(f"Error on restoring the Ctrl+C handler: {e}")

    if stream_output:
        buffer = io.BytesIO()
        reader = threading.Thread(target=read_and_print, args=(proc.stdout, buffer))
        reader.start()

        def read_data():
            reader.join()
            proc.stdout.close()
            value = buffer.getvalue()
            buffer.close()
            return value
    else:
        def read_data():
            value = proc.stdout.read()
            proc.stdout.close()
            return value

    return proc, g.instance_storage.register_start(proc), read_data


def run_sploit(team:dict, static_max_runtime:float|None = None, stream_output:bool=False):
    """Run the sploit once against ``team``, enforce its timeout and queue the result.

    While memory usage is above g.max_mem_usage it waits, polling every 0.1s,
    and only then launches the sploit. The process is killed once it exceeds
    ``static_max_runtime`` seconds or, when that is falsy, the dynamic budget
    from calc_exec_timeout(). Nothing is launched once shutdown has started.

    Args:
        team: team record ("id", "host", ...).
        static_max_runtime: fixed timeout in seconds; falsy means dynamic.
        stream_output: echo the output live instead of printing it at the end.
    """
    nicenessify()
    for _ in once_in_a_period(0.1):
        memory_usage = mem_usage()
        # After RAM was maxed out, usage must drop 1 point below the threshold
        # before launching again (hysteresis).
        if (memory_usage+(1 if g.is_ram_maxed else 0)) <= g.max_mem_usage: #Allow a 1% of margin
            g.is_ram_maxed = False
            g.memory["ram_pressure"] = False
            break
        g.memory["ram_pressure"] = True
        g.is_ram_maxed = True
    start_time = dt_now()
    if g.exit_event.is_set():
        return
    try:
        set_attack_exec(team["id"], True)
        proc, instance_id, read_data = launch_sploit(team, stream_output)
    except Exception as e:
        error_text = traceback.format_exc()
        qprint(error_text)
        if isinstance(e, FileNotFoundError):
            qprint(f'Sploit file or the interpreter for it not found: {repr(e)}')
        else:
            g.attack_storage.add([], team["id"], start_time, dt_now(), AttackExecutionStatus.crashed, error_text)
            qprint(f'Failed to run sploit: {repr(e)}')
        # The sploit never started: do not leave the team marked as executing.
        set_attack_exec(team["id"], False)
        return

    time_execution = g.last_attack_time
    need_kill = False
    try:
        try:
            start_time_exec = time.monotonic()
            for _ in once_in_a_period(0.1):
                if proc.poll() is not None:
                    break
                elapsed = time.monotonic() - start_time_exec
                if static_max_runtime:
                    time_left = static_max_runtime - elapsed
                else:
                    time_left = calc_exec_timeout(elapsed, time_execution)
                if time_left <= 0:
                    need_kill = True
                    break
        finally:
            # Always terminate the process (a no-op if it already exited),
            # even if computing the timeout failed.
            with g.instance_storage.lock:
                proc.kill()
            set_attack_exec(team["id"], False)
        process_sploit_filter(proc, team, start_time, need_kill, read_data, not stream_output)
    except Exception as e:
        qprint(traceback.format_exc())
        qprint(f'Failed to finish sploit: {repr(e)}')
    finally:
        # Always unregister; otherwise running_workers and shutdown() would
        # keep tracking a finished process.
        g.instance_storage.register_stop(instance_id, need_kill)

def normal_attack_period():
    attack_mode = AttackMode(g.config.status["config"]["ATTACK_MODE"])
    match attack_mode:
        case AttackMode.TICK_DELAY:
            return g.config.status["config"]["TICK_DURATION"]
        case AttackMode.WAIT_FOR_TIME_TICK:
            return g.config.status["config"]["TICK_DURATION"]
        case AttackMode.LOOP_DELAY:
            return g.config.status["config"]["LOOP_ATTACK_DELAY"]

def next_timeout(last_execution: dt|None = None) -> float|int:
    last_execution = last_execution if last_execution else g.last_attack_time
    attack_mode = AttackMode(g.config.status["config"]["ATTACK_MODE"])
    this_time = dt_now()
    match attack_mode:
        case AttackMode.TICK_DELAY:
            timeout = g.config.status["config"]["TICK_DURATION"] - (this_time - last_execution).total_seconds()
            return timeout if timeout > 0 else 0
        case AttackMode.WAIT_FOR_TIME_TICK:
            start_time = dateutil.parser.parse(g.config.status["config"]["START_TIME"])
            tick = 1
            next_time:dt = None
            if start_time > this_time:
                return (start_time - this_time).total_seconds()
            while True:
                next_time = start_time + timedelta(seconds=g.config.status["config"]["TICK_DURATION"] * tick) + timedelta(seconds=g.config.status["config"]["ATTACK_TIME_TICK_DELAY"])
                if next_time > last_execution:
                    break
                tick += 1
            return (next_time - last_execution).total_seconds()            
        case AttackMode.LOOP_DELAY:
            timeout = g.config.status["config"]["LOOP_ATTACK_DELAY"] - (this_time - last_execution).total_seconds()
            return timeout if timeout > 0 else 0

def get_teams():
    """Return a private (deep) copy of the team list from the last server status, or [] on error."""
    try:
        teams = g.config.status["teams"]
        return deepcopy(teams)
    except Exception as e:
        qprint(f"Can't get teams from the server: {repr(e)}")
        return []

def calc_exec_timeout(time_used:float=0.0, last_execution: dt|None = None) -> float|int:
    """Return how many more seconds a running sploit may run (<= 0 means kill it).

    The round time (g.initial_timeout) is split over the number of batches
    needed to attack every team. Normally that is ceil(teams / pool_size)
    batches; under RAM pressure it is ceil(teams / running), or one team at a
    time when nothing is running. Once every team has been launched, running
    sploits may use all the time left until the next round. The budget never
    extends past the next round.

    Args:
        time_used: seconds the sploit has already been running.
        last_execution: start of the sploit's round (defaults to
            g.last_attack_time, see next_timeout()).

    Side effect: stores the per-process budget in g.memory["runtime_timeout"].
    """
    if g.tot_teams == 0:
        return 0

    wait_for_execution_teams = g.remaining_teams
    running_teams = len(g.instance_storage.instances)
    tot_teams = g.tot_teams
    # Same fallback as the ThreadPoolExecutor in xploit(): None means max_queue.
    pool_size = g.max_queue if g.pool_size is None else g.pool_size
    total_time = g.initial_timeout

    remaining_time = next_timeout(last_execution)  # seconds until the next round

    if remaining_time <= 0:
        return 0  # kill all the processes

    if wait_for_execution_teams == 0:
        return remaining_time  # Give all the time to the running teams

    if g.is_ram_maxed:
        batches = tot_teams if running_teams == 0 else math.ceil(tot_teams / running_teams)
    else:
        batches = math.ceil(tot_teams / pool_size)
    budget = total_time / batches

    # remaining_time counts from now, while the budget counts from the
    # sploit's start. Compare the budget with time_used + remaining_time:
    # capping it at remaining_time before subtracting time_used would count
    # the elapsed time twice and kill sploits early.
    budget = min(budget, time_used + remaining_time)
    g.memory["runtime_timeout"] = budget = math.ceil(budget)

    result = min(budget - time_used, remaining_time)
    return math.ceil(result)

def set_http_timeout():
    """Set the HTTP timeout used by ``exploitfarm.utils.reqs`` to 30 seconds."""
    import exploitfarm.utils.reqs as reqs_module
    reqs_module.HTTP_TIMEOUT = 30

def xploit(path: str):
    """Attack loop: launch one sploit per team every round until shutdown.

    Changes into ``path`` and starts the submitter thread, the config-refresh
    thread and a worker pool. Each round it submits run_sploit for every team,
    then waits until the next round, or less if a config update restarts the
    round early. A "shutdown" event, Ctrl+C or a RuntimeError (e.g. from
    submitting to a pool already shut down) ends the loop. shutdown() always
    runs on exit.
    """
    set_http_timeout()
    os.chdir(path)
    for background in (run_post_loop, update_server_config):
        threading.Thread(target=background).start()
    workers = g.max_queue if g.pool_size is None else g.pool_size
    g.pool = ThreadPoolExecutor(max_workers=workers, thread_name_prefix="sploit")
    try:
        while True:
            storage = g.instance_storage
            if storage.n_completed > 0:
                killed_pct = float(storage.n_killed) / storage.n_completed * 100
                qprint('Total {:.1f}% of instances ran out of time'.format(killed_pct))

            teams = get_teams()
            with g.instance_storage.lock:
                g.tot_teams = g.remaining_teams = len(teams)
            g.last_attack_time = dt_now()
            g.initial_timeout = next_timeout()
            g.memory["next_attack_at"] = dt_now() + timedelta(seconds=g.initial_timeout)
            g.is_ram_maxed = False
            for team in teams:
                g.pool.submit(run_sploit, team)
            try:
                timeout = next_timeout()
                qprint(f"Next attack in {timeout:.1f} seconds")
                # If the config is updated, all attacks will be "restarted"
                if g.config_update_event.get(timeout=timeout) == "shutdown":
                    return
            except Empty:
                pass
            except Exception:
                qprint(traceback.format_exc())
    except (KeyboardInterrupt, RuntimeError):
        pass
    finally:
        shutdown()

def run_printer_queue():
    """Print the messages queued by qprint() until shutdown.

    The exit event is re-checked at least every 0.5s, so the loop ends even
    if nothing is queued after it is set. Messages still waiting in the queue
    at that point are flushed before returning, so the last messages of a run
    (e.g. submission results) are not lost.
    """
    try:
        while not g.exit_event.is_set():
            try:
                item = g.print_queue.get(timeout=0.5)
            except Empty:
                continue
            print(item, end="")
        while True:
            print(g.print_queue.get_nowait(), end="")
    except Empty:
        pass
    except Exception:
        # Best effort: a printing failure must not crash the runner.
        pass

def shutdown(restart:bool=False):
    """Stop everything: signal exit, kill running sploits, wait for the pool, save unsent flags.

    Only the first call has an effect; later calls return immediately. It also
    works before the runtime state is fully initialised (e.g. when
    xploit_one() fails early).

    Args:
        restart: also set g.restart_event to request a restart.
    """
    if g._shutdown:
        return
    g._shutdown = True
    if g.exit_event is not None:
        g.exit_event.set()
    g.config_update_event.put("shutdown")
    # Kill all child processes so the run_sploit workers stop as well.
    if g.instance_storage is not None:
        with g.instance_storage.lock:
            for proc in g.instance_storage.instances.values():
                proc.kill()
    if restart and g.restart_event is not None:
        g.restart_event.set()
    if g.print_queue is not None:
        g.print_queue.put("Stopping exploit execution..")
    if g.pool:
        g.pool.shutdown(wait=True, cancel_futures=True)
    _save_flag_queue()
    dumpstacks()


def _save_flag_queue():
    """Persist unsent attacks to ``.flag_queue.json`` for repush_flags() (best effort).

    When nothing is pending the file is removed, but only once g.server_id is
    known. Before that, repush_flags() has not consumed a previously saved
    queue yet, so the file is left untouched.
    """
    if g.attack_storage is None:
        return
    try:
        pending = g.attack_storage.pick_attacks()
        if pending:
            with open(".flag_queue.json", "wb") as f:
                f.write(orjson.dumps({"server_id": g.server_id, "queue": pending}))
        elif g.server_id is not None and os.path.exists(".flag_queue.json"):
            os.remove(".flag_queue.json")
    except Exception:
        # Best effort during shutdown: never let saving flags prevent exiting.
        pass

def xploit_one(config: ClientConfig, team: str, path: str, timeout: float = 30):
    """Run the sploit once against a single host and submit its flags as manual submissions.

    Args:
        config: client configuration used to reach the server.
        team: host to attack (exposed to the sploit as XFARM_HOST).
        path: exploit directory; it becomes the working directory.
        timeout: seconds after which the sploit is killed.

    The sploit output is streamed live. shutdown() always runs at the end.
    """
    set_http_timeout()
    # Resolve first: ExploitConfig.read(path) runs after the chdir, so a
    # relative path would otherwise be resolved against the new directory.
    path = os.path.abspath(path)
    os.chdir(path)
    try:
        g.config = config
        # Set up the runtime state before talking to the server, so qprint()
        # and shutdown() work even when the server cannot be reached.
        g.memory = {}
        g.print_queue = Queue()
        g.attack_storage = AttackStorage(config.client_id)
        g.instance_storage = InstanceStorage()
        g.exit_event = threading.Event()
        g.restart_event = threading.Event()
        threading.Thread(target=run_printer_queue).start()
        try:
            g.config.fetch_status()
        except Exception as e:
            qprint(f"[bold yellow]Can't get config from the server: {repr(e)}")
            return
        g.exploit_config = ExploitConfig.read(path)
        run_sploit({"id":0, "host":team}, timeout, stream_output=True)
        try:
            while True:
                attacks = g.attack_storage.pick_attacks()
                if not attacks:
                    break
                for attack in attacks:
                    if attack["flags"]:
                        g.config.reqs.submit_flags({"flags":attack["flags"]})
                    g.attack_storage.mark_as_sent(1)
                    if attack["flags"]:
                        qprint(f"Submitted {len(attack['flags'])} flags as manual submissions.")
        except Exception as e:
            qprint(f"[bold yellow]Can't submit flags due missing config: {repr(e)}")
    except KeyboardInterrupt:
        pass
    finally:
        shutdown()

def start_xploit(config: ClientConfig, shared_dict:dict, print_queue: Queue, pool_size:int=50, max_mem_usage:int=95, path:str="./", submit_pool_timeout:int=3, server_status_refresh_period: int = 1, exit_event: threading.Event = None, restart_event: threading.Event = None) -> threading.Thread:
    """Initialise the global runtime state and start the attack loop in a new thread.

    Args:
        config: client configuration; its current status seeds the config diff.
        shared_dict: dict receiving status information (becomes g.memory).
        print_queue: queue receiving output messages (see qprint).
        pool_size: concurrent sploit workers (None -> g.max_queue).
        max_mem_usage: memory-usage threshold above which new launches wait.
        path: exploit directory.
        submit_pool_timeout: seconds between flag submissions.
        server_status_refresh_period: seconds between config refreshes.
        exit_event: optional externally owned event; a fresh one is created otherwise.
        restart_event: optional externally owned event; a fresh one is created otherwise.

    Returns:
        The started thread running xploit(path).
    """
    g.config = config
    g.pool_size = pool_size
    g.memory = shared_dict
    # Snapshot, not an alias of the live status object, so the config diff
    # never touches g.config.status.
    g.old_config = deepcopy(g.config.status)
    g.submit_pool_timeout = submit_pool_timeout
    g.server_status_refresh_period = server_status_refresh_period
    g.exploit_config = ExploitConfig.read(path)
    g.print_queue = print_queue
    g.attack_storage = AttackStorage(config.client_id)
    g.instance_storage = InstanceStorage()
    g.max_mem_usage = max_mem_usage
    g.exit_event = exit_event if exit_event else threading.Event()
    g.restart_event = restart_event if restart_event else threading.Event()
    # Clear state a previous run in this process may have left behind:
    # - shutdown() sets g._shutdown, which would turn later shutdown() calls
    #   into no-ops;
    # - it also queues "shutdown" in g.config_update_event, which would end
    #   the new attack loop after its first round;
    # - a stale g.server_id would skip repush_flags(), or keep requesting
    #   restarts after a server change.
    g._shutdown = False
    g.server_id = None
    g.config_update_event = Queue()
    g.pool = None
    thr = threading.Thread(target=xploit, args=(path,))
    thr.start()
    return thr
