from textual.app import App, ComposeResult
from textual.widgets import Footer, RichLog, Tab, Tabs, DataTable, Header, Label
from textual.widgets.data_table import CellDoesNotExist
from exploitfarm.utils.config import ClientConfig, ExploitConfig
from multiprocessing import Queue
from exploitfarm.model import AttackExecutionStatus
from datetime import datetime as dt
import datetime
from rich.text import Text
from rich.markup import escape
import threading, traceback, psutil
import dateutil.parser
from copy import deepcopy
from exploitfarm.utils import mem_usage
from exploitfarm import Prio, nicenessify
from exploitfarm.xploit import shutdown

class g:
    """Module-wide namespace for the state shared by the TUI and its worker threads.

    The attributes are only declared (annotated) here; ``start_exploit_tui``
    assigns them before the app is started, so reading one earlier raises
    AttributeError.
    """
    # Client configuration; ``config.status`` is the server status mapping
    # (None until first received) read for "messages", "teams", "config", "status".
    config: ClientConfig
    # Shared runtime state read by the UI, e.g. "teams", "next_attack_at",
    # "running_workers", "submitter_status", "config_update", "ram_pressure".
    memory: dict
    # Exploit configuration; its ``name`` is shown in the title and status banner.
    exploit_config: ExploitConfig
    # Messages (str or Rich renderables) drained into the Log tab by ``update_log``.
    print_queue: Queue
    # Maximum worker count, shown next to the running-worker counter.
    pool_size: int
    # Set to request shutdown; worker loops wait on / poll this event.
    exit_event: threading.Event
    # Stored by ``start_exploit_tui``; not read in the code visible in this file section.
    restart_event: threading.Event

def stop_screen():
    """Request shutdown: set the exit event, stop the exploit runner, log a notice.

    Invoked from ``action_cancel``, ``exit_listener`` and after the app returns
    in ``start_exploit_tui``, so it can run more than once per session.
    """
    # Setting the event first wakes every thread blocked in g.exit_event.wait().
    g.exit_event.set()
    shutdown()
    notice = "Stopping exploit execution.."
    g.print_queue.put(notice)

def start_exploit_tui(config: ClientConfig, memory: dict, exploit_config: ExploitConfig, print_queue: Queue, pool_size:int, exit_event: threading.Event, restart_event: threading.Event):
    """Publish the shared state on ``g``, run the TUI, and stop execution when it closes.

    Args:
        config: client configuration (its ``status`` drives the screen).
        memory: shared runtime-state mapping read by the UI.
        exploit_config: configuration of the exploit being run.
        print_queue: queue whose items are shown in the Log tab.
        pool_size: maximum number of workers, displayed in the status banner.
        exit_event: event used to request and observe shutdown.
        restart_event: stored on ``g`` for other consumers.
    """
    g.config = config
    g.memory = memory
    g.exploit_config = exploit_config
    g.print_queue = print_queue
    g.pool_size = pool_size
    g.exit_event = exit_event
    g.restart_event = restart_event
    nicenessify(Prio.normal)
    try:
        XploitRun().run()
    finally:
        # Run even if the app fails to start or raises, so g.exit_event is set
        # and shutdown() is called instead of leaving execution running headless.
        stop_screen()

class XploitRun(App):
    
    def __init__(self):
        super().__init__()
        self.title = f"xFarm - Exploit execution of {g.exploit_config.name}"          
    
    BINDINGS = [
        ("ctrl+c", "cancel()", "Close attack"),
        ("1", "show_tab('Teams')", "Show teams"),
        ("2", "show_tab('Log')", "Show logs"),
    ]
    
    CSS = """
        .hidden {
            display: none;
        }
        .tab {
            width: 100%;
        }
        .tab-Log {
            align: center middle;
        }
        .tab-Teams {
            align: center middle;
            height: 1fr;
        }
        .full-width {
            width: 100%;
        }
        #status_text{
            height: 5;
        }
    """
    
    TABS = [
        Tab("Teams", id="Teams"),
        Tab("Log", id="Log")
    ]
    
    COLUMNS = [
        "Team 👾",
        "Host 📡",
        "Flags 🚩",
        "Last status 📊",
        "Last attack 🕰️",
        "Time to exploit ⏱️",
        "Executing ⚠️"
    ]

    def get_team_row(self, team:dict):
        team_info = deepcopy(g.memory.get("teams", {}).get(team.get("id", None), {
            "last": None,
            "flags": 0,
            "executing": False,
        }))
        team_id = team.get("id", None)
        name = team.get("name", None)
        short_name = team.get("short_name", None)
        host = team.get("host", None)
        flags = team_info["flags"]
        last_attack = team_info["last"]
        last_start_time = last_attack.get("start_time", None) if last_attack else None
        if isinstance(last_start_time, str):
            last_start_time = dateutil.parser.parse(last_start_time)
        last_end_time = last_attack.get("end_time", None) if last_attack else None
        if isinstance(last_end_time, str):
            last_end_time = dateutil.parser.parse(last_end_time)
        time_delta = str(dt.now(datetime.timezone.utc) - last_end_time).split(".")[0] if last_end_time else None
        last_execution_time = str(last_end_time - last_start_time).split(".")[0] if last_start_time and last_end_time else None
        executing = team_info.get("executing", False)
        last_attack_status = last_attack.get("status", None) if last_attack else None
        match last_attack_status:
            case AttackExecutionStatus.crashed.value:
                last_attack_status = "⚠️ ☠️ "
            case AttackExecutionStatus.done.value:
                last_attack_status = " ✅ 🚩"
            case AttackExecutionStatus.noflags.value:
                last_attack_status = " ❌ 🚩"
            case _:
                last_attack_status = "❓"
        return [
            f"{team_id}: {short_name if short_name else name if name else ''}",
            host if host else "",
            flags if flags > 0 else "⚠️  0",
            last_attack_status,
            time_delta + " ago" if time_delta else "⚠️   Never",
            last_execution_time if last_execution_time else "",
            "🔥" if executing else "💤"
        ]
    
    def update_data(self):
        try:
            table = self.query_one(DataTable)
            table.rows.clear()
            column_keys = table.add_columns(*[Text(text="\n\n"+ele, justify="center") for ele in self.COLUMNS])
            while True:
                if g.config.status is None:
                    if g.exit_event.wait(timeout=1): return
                    continue
                next_attack = g.memory.get("next_attack_at", None)
                running_workers = g.memory.get("running_workers", 0)
                submitter_flag_queue = g.memory.get("submitter_flags_in_queue", None)
                submitter_flag_queue = submitter_flag_queue if not submitter_flag_queue is None else "❓"
                submitter_status = g.memory.get("submitter_status", None)
                submitter_status = "🟢" if submitter_status else "[bold red]❌ ( Can't submit flags )[/]" if not submitter_status is None else "❓"
                config_updater = g.memory.get("config_update", None)
                if config_updater == True:
                    config_updater = "🟢" if g.config.status.get("status", None) == "running" else "[bold red]⚠️  Server is not in running mode![/]"
                else:
                    config_updater = "🟢" if config_updater else "[bold red]❌ (Check server is alive)[/]" if not config_updater is None else "❓"
                this_time = dt.now(datetime.timezone.utc)
                delta_next_attack = str(next_attack - this_time).split(".")[0]+" s" if next_attack and next_attack > this_time else "❓"
                status_text = self.query_one("#status_text", Label)
                messages = deepcopy(g.config.status.get("messages", []))
                if g.memory.get("ram_pressure", False):
                    messages.append({"level": "warning", "title": "The system is running out of memory, cannot run more workers!"})
                messages = ", ".join([f"[bold red]{ele['level']}[/]: {escape(ele['title'])}" for ele in messages])
                status_text.update(
                    f"[yellow]Exploit [bold]{escape(g.exploit_config.name)}[/] is running.. [/]\t\t\t"
                    f"[bold]{messages}[/]\n"
                    f"Submitter status: [bold green]{submitter_status}[/] (queued: {submitter_flag_queue})\t"
                    f"Exploit timeout: [bold yellow]{g.memory.get('runtime_timeout', 'Unknown')} s[/]\t\t\t"
                    f"Tick Duration: [bold green]{g.config.status.get('config', {}).get('TICK_DURATION', 'Unknown')} s[/]\n"
                    f"Server connection: [bold]{config_updater}[/]\t\t\t"
                    f"Next Attacks: [bold green]{delta_next_attack}[/]\t\t\t"
                    f"Flag Format: [bold green]{escape(g.config.status.get('config', {}).get('FLAG_REGEX', 'Unknown'))}[/]\n"
                    f"System Memory: [bold green]{mem_usage()}%[/]\t\t\t"
                    f"System CPU: [bold green]{psutil.cpu_percent()}%[/]\t\t\t"
                    f"Running worker: [bold green]{running_workers} (max: {g.pool_size})[/]"
                )
                
                used_keys = set()
                teams = deepcopy(g.config.status.get("teams", []))
                for team in teams:
                    cols = [Text("\n"+str(cell), justify=("center" if not i in (0,1) else "default")) for i, cell in enumerate(self.get_team_row(team))]
                    team_key = f"team-{team['id']}"
                    used_keys.add(team_key)
                    try:
                        [table.update_cell(row_key=team_key, column_key=column_keys[i], value=cols[i]) for i in range(len(column_keys))]
                    except CellDoesNotExist:
                        table.add_row(*cols, height=3, key=team_key)
                for key in table.rows.keys():
                    if key.value not in used_keys:
                        table.remove_row(row_key=key.value)
                if g.exit_event.wait(timeout=1): return
        except Exception as e:
            g.print_queue.put(f"[bold red]Error on update_data: {e}[/]")
            g.print_queue.put(traceback.format_exc())
    
    def update_log(self):
        logger = self.query_one("#log", RichLog)
        while True:
            if g.exit_event.is_set(): return
            log = g.print_queue.get()
            if isinstance(log, str):
                logger.write(escape(log), scroll_end=False)
            else:
                logger.write(log, scroll_end=False)
    
    def exit_listener(self):
        if g.exit_event.wait():
            stop_screen()
            self.exit(0)
    
    def on_ready(self) -> None:
        """Called when the DOM is ready."""
        self.run_worker(self.update_data, thread=True)
        self.run_worker(self.update_log, thread=True)
        self.run_worker(self.exit_listener, thread=True)

    def compose(self) -> ComposeResult:
        yield Header(self.title)
        yield Label(id="status_text", classes="full-width")
        yield Tabs(*self.TABS, id="tabs")
        yield DataTable(id="teams", classes="tab tab-Teams hidden", cell_padding=3, header_height=5, cursor_type="none" )
        yield RichLog(highlight=True, id="log", classes="tab tab-Log hidden", wrap=True, auto_scroll=True, max_lines=5000)
        yield Footer()    
        
    def show_tab(self, tab_name:str):
        self.query(".tab").add_class("hidden")
        self.query(f".tab-{tab_name}").remove_class("hidden")

    def action_show_tab(self, tab_name:str):
        self.query_one("#tabs", Tabs)._activate_tab(self.query_one(f"#{tab_name}", Tab))

    def on_tabs_tab_activated(self, event: Tabs.TabActivated) -> None:
        """Handle TabActivated message sent by Tabs."""
        self.show_tab(event.tab.label_text)

    def action_cancel(self):
        with self.suspend(): # Needed to avoid the app to hang, and terminal to be unusable
            self.exit(0)
            stop_screen()
    
