from textual.app import App, ComposeResult
from textual.widgets import Footer, RichLog, Tab, Tabs, DataTable, Header, Label
from textual.widgets.data_table import CellDoesNotExist
from exploitfarm.utils.config import ClientConfig, ExploitConfig
from multiprocessing import Queue
from exploitfarm.model import AttackExecutionStatus
from datetime import datetime as dt
import datetime
from rich.text import Text
from rich.markup import escape
from textual_plotext import PlotextPlot
import threading
from textual.reactive import var
from datetime import timedelta
import dateutil.parser
from copy import deepcopy

class g:
    """Module-wide shared state for the exploit TUI.

    Used as a plain namespace (accessed as ``g.<attr>``); every attribute is a
    class-level annotation whose value is assigned at runtime.
    """

    # Channels and signals shared with the rest of the client.
    memory: dict
    print_queue: Queue
    exit_event: threading.Event
    restart_event: threading.Event

    # Configuration.
    # NOTE(review): `config` is declared here but start_exploit_tui never assigns it.
    config: ClientConfig
    exploit_config: ExploitConfig
    pool_size: int

def stop_screen():
    """Request shutdown of the exploit run and leave a note in the log queue.

    The shared exit event is set first; the message is then pushed onto
    ``g.print_queue``, so any reader blocked on that queue receives an item.
    """
    shutdown_signal = g.exit_event
    shutdown_signal.set()
    g.print_queue.put("Stopping exploit execution..")

def start_exploit_tui(memory: dict, exploit_config: ExploitConfig, print_queue: Queue, pool_size:int, exit_event: threading.Event, restart_event: threading.Event):
    """Publish the shared state on ``g`` and run the exploit TUI until it exits.

    Args:
        memory: Shared mapping the TUI reads its status from (it also writes the
            per-team "enabled" switches and the graph cache into it).
        exploit_config: Configuration of the running exploit; its ``name`` is displayed.
        print_queue: Queue of log entries (str or Rich renderables) for the Log tab.
        pool_size: Worker count, shown in the status line.
        exit_event: Shared shutdown signal; always set once the TUI has ended.
        restart_event: Stored on ``g``; not read by the TUI code shown here.
    """
    g.memory = memory
    g.exploit_config = exploit_config
    g.print_queue = print_queue
    g.pool_size = pool_size
    g.exit_event = exit_event
    g.restart_event = restart_event
    try:
        XploitRun().run()
    finally:
        # Signal shutdown even if the TUI crashed, so the exploit workers do not
        # keep running with nobody watching them.
        stop_screen()

class FlagGraph(PlotextPlot):
    """A plot of the flags submitted by the exploit."""

    marker: var[str] = var("hd")
    """The type of marker to use for the plot."""

    def __init__(
        self,
        title: str,
        *,
        name: str | None = None,
        id: str | None = None,  # pylint:disable=redefined-builtin
        classes: str | None = None,
        disabled: bool = False,
    ) -> None:
        """Initialise the flag graph.

        Args:
            title: The title drawn above the plot.
            name: The name of the flag graph.
            id: The ID of the flag graph widget in the DOM.
            classes: The CSS classes of the widget.
            disabled: Whether the widget is disabled or not.
        """
        super().__init__(name=name, id=id, classes=classes, disabled=disabled)
        self._title = title
        self._unit = "Loading..."
        self._data: list[int] = []
        self._time: list[str] = []

    def on_mount(self) -> None:
        """Configure the Plotext figure (date format, title, axis labels)."""
        self.plt.date_form("d/m/Y H:M:S.f", "H:M:S")
        self.plt.title(self._title)
        self.plt.xlabel("Time")
        self.plt.ylabel("Flags")

    def replot(self) -> None:
        """Redraw the plot from ``self._time`` / ``self._data``."""
        self.plt.clc()
        self.plt.cld()
        self.plt.plot(self._time, self._data, marker=self.marker)
        self.refresh()

    def update(self) -> None:
        """Fold attacks finished since the last call into per-tick buckets and redraw.

        Reads the shared config and each team's ``team-<id>`` attack history
        from ``g.memory``. The number of history entries already consumed per
        team and the bucket list are persisted in ``g.memory["graph_cache"]``;
        both are rebuilt from scratch when ``TICK_DURATION`` changes.

        Each bucket is ``(start_time, flags)`` and sums the flags of attacks that
        ended within the following ``TICK_DURATION`` seconds. Ticks with no
        finished attack get a zero bucket, so bucket stamps stay aligned to ticks.
        Returns early (without redrawing) when there is no config or no new data.
        """
        config = deepcopy(g.memory.get("config", None))
        if config is None:
            return
        tick_duration = config["config"]["TICK_DURATION"]
        timeoffset = timedelta(seconds=tick_duration)
        if timeoffset <= timedelta(0):
            # A non-positive tick cannot define buckets (and would never advance below).
            return
        teams = config.get("teams", [])
        default_cache = {"data_graph": [], "teams_index": {}, "tick_duration": tick_duration}
        graph_cache = g.memory.get("graph_cache", default_cache)

        if graph_cache["tick_duration"] != tick_duration:
            # Bucket width changed: forget old buckets and re-read all histories.
            graph_cache = default_cache

        data_to_analyse = []
        for team in teams:
            team_key = f"team-{team['id']}"
            already_read = graph_cache["teams_index"].get(team_key, 0)
            history = g.memory.get(team_key, [])[already_read:]
            graph_cache["teams_index"][team_key] = already_read + len(history)
            for ele in history:
                end_time = ele.get("end_time", None)
                if end_time is None:
                    continue
                if isinstance(end_time, str):
                    end_time = dateutil.parser.parse(end_time)
                data_to_analyse.append((end_time, ele.get("flags", 0)))
        if not data_to_analyse:
            return
        data_to_analyse.sort(key=lambda x: x[0].timestamp())

        data_graph = graph_cache["data_graph"] or [(data_to_analyse[0][0], 0)]
        timeend = data_graph[-1][0] + timeoffset
        for end_time, flags in data_to_analyse:
            # Move forward one tick at a time until the bucket covers end_time;
            # jumping a single tick would mislabel buckets after any gap.
            while end_time > timeend:
                data_graph.append((timeend, 0))
                timeend += timeoffset
            # Entries older than the last bucket (late arrivals) are added to it.
            start, total = data_graph[-1]
            data_graph[-1] = (start, total + flags)
        graph_cache["data_graph"] = data_graph
        g.memory["graph_cache"] = graph_cache

        self._data = [ele[1] for ele in data_graph]
        self._time = [ele[0].strftime("%d/%m/%Y %H:%M:%S.%f") for ele in data_graph]
        self.replot()

class XploitRun(App):
    
    def __init__(self):
        super().__init__()
        self.title = f"xFarm - Exploit execution of {g.exploit_config.name}"          
    
    BINDINGS = [
        ("ctrl+c", "cancel()", "Close attack"),
        ("1", "show_tab('Teams')", "Show teams"),
        ("2", "show_tab('Log')", "Show logs"),
        ("3", "show_tab('Graph')", "Show graph"),
    ]
    
    CSS = """
        .hidden {
            display: none;
        }
        .tab {
            width: 100%;
            height: 88%;
        }
        .tab-Log {
            align: center middle;
        }
        .tab-Teams {
            align: center middle;
        }
        .full-width {
            width: 100%;
        }
    """
    
    TABS = [
        Tab("Teams", id="Teams"),
        Tab("Log", id="Log"),
        Tab("Graph", id="Graph")
    ]
    
    COLUMNS = [
        "Team 👾",
        "Host 📡",
        "Errors 🚨",
        "Flags 🚩",
        "Last status 📊",
        "Last flag 🕰️",
        "Time to exploit ⏱️",
        "Enabled ☑️\n(Click to switch)",
        "Executing ⚠️"
    ]

    def get_team_row(self, team:dict):
        team_id = team.get("id", None)
        name = team.get("name", None)
        short_name = team.get("short_name", None)
        host = team.get("host", None)
        history = g.memory.get(f"team-{team_id}", [])
        errors = len([x for x in history if x.get("status", None) == AttackExecutionStatus.crashed.value])
        flags = sum([x.get("flags", 0) for x in history])
        last_attack_status = history[-1].get("status", None) if history else None
        last_flag_time = None
        for ele in reversed(history):
            if ele.get("flags", 0) > 0:
                last_flag_time = ele.get("end_time", None)
                if isinstance(last_flag_time, str):
                    last_flag_time = dateutil.parser.parse(last_flag_time)
                last_flag_time = str(dt.now(datetime.timezone.utc) - last_flag_time).split(".")[0] if last_flag_time else None
                break
        last_start_time = history[-1].get("start_time", None) if history else None
        if isinstance(last_start_time, str):
            last_start_time = dateutil.parser.parse(last_start_time)
        last_end_time = history[-1].get("end_time", None) if history else None
        if isinstance(last_end_time, str):
            last_end_time = dateutil.parser.parse(last_end_time)
        last_execution_time = str(last_end_time - last_start_time).split(".")[0] if last_start_time and last_end_time else None
        team_enabled = g.memory.get(f"team-{team_id}-enabled", True)
        executing = g.memory.get(f"team-{team_id}-executing", False)
        match last_attack_status:
            case AttackExecutionStatus.crashed.value:
                last_attack_status = "⚠️ ☠️ "
            case AttackExecutionStatus.done.value:
                last_attack_status = " ✅ 🚩"
            case AttackExecutionStatus.noflags.value:
                last_attack_status = " ❌ 🚩"
            case _:
                last_attack_status = "❓"
        return [
            f"{team_id}: {short_name if short_name else name if name else ''}",
            host if host else "",
            errors,
            flags if flags > 0 else "⚠️  0",
            last_attack_status,
            last_flag_time + " ago" if last_flag_time else "⚠️   Never",
            last_execution_time if last_execution_time else "",
            "🟢" if team_enabled else "❌",
            "🔥" if executing else "💤"
        ]
    
    def on_data_table_cell_selected(self, event: DataTable.CellSelected) -> None:
        if event.coordinate.column == len(self.COLUMNS)-2:
            team_id = int(event.cell_key.row_key.value.split("-")[1])
            enabled = g.memory.get(f"team-{team_id}-enabled", True)
            g.memory[f"team-{team_id}-enabled"] = not enabled
            event.data_table.update_cell(row_key=event.cell_key.row_key, column_key=event.cell_key.column_key, value=Text("\n"+str("🟢" if not enabled else "❌"), justify=("center")))
    
    def update_data(self):
        table = self.query_one(DataTable)
        table.rows.clear()
        column_keys = table.add_columns(*[Text(text="\n\n"+ele, justify="center") for ele in self.COLUMNS])
        while True:
            config: ClientConfig = deepcopy(g.memory.get("config", None))
            if config is None:
                if g.exit_event.wait(timeout=1): return
                continue
            next_attack = g.memory.get("next_attack_at", None)
            submitter_flag_queue = g.memory.get("submitter_flags_in_queue", None)
            submitter_flag_queue = submitter_flag_queue if not submitter_flag_queue is None else "❓"
            submitter_status = g.memory.get("submitter_status", None)
            submitter_status = "🟢" if submitter_status else "[bold red]❌ ( Can't submit flags )[/]" if not submitter_status is None else "❓"
            config_updater = g.memory.get("config_update", None)
            if config_updater == True:
                config_updater = "🟢" if config.get("status", None) == "running" else "[bold red]⚠️  Server is not in running mode![/]"
            else:
                config_updater = "🟢" if config_updater else "[bold red]❌ (Check server is alive)[/]" if not config_updater is None else "❓"
            this_time = dt.now(datetime.timezone.utc)
            delta_next_attack = str(next_attack - this_time).split(".")[0]+" s" if next_attack and next_attack > this_time else "❓"
            status_text = self.query_one("#status_text", Label)
            messages = config.get("messages", [])
            messages = ", ".join([f"[bold red]{ele['level']}[/]: {escape(ele['title'])}" for ele in messages])
            status_text.update(
                f"[yellow]Exploit [bold]{escape(g.exploit_config.name)}[/] is running.. [/]\t\t\t"
                f"[bold]{messages}[/]\n"
                f"Submitter status: [bold green]{submitter_status}[/] (queued: {submitter_flag_queue})\t"
                f"Exploit timeout: [bold yellow]{g.memory.get('runtime_timeout', 'Unknown')} s[/]\t\t\t"
                f"Tick Duration: [bold green]{config.get('config', {}).get('TICK_DURATION', 'Unknown')} s[/]\t\t\t"
                f"Workers: [bold green]{g.pool_size}[/]\n"
                f"Server connection: [bold]{config_updater}[/]\t\t\t"
                f"Next Attacks: [bold green]{delta_next_attack}[/]\t\t\t"
                f"Flag Format: [bold green]{escape(config.get('config', {}).get('FLAG_REGEX', 'Unknown'))}[/]"         
            )
            
            used_keys = set()
            teams = config.get("teams", [])
            for team in teams:
                cols = [Text("\n"+str(cell), justify=("center" if not i in (0,1) else "default")) for i, cell in enumerate(self.get_team_row(team))]
                team_key = f"team-{team['id']}"
                used_keys.add(team_key)
                try:
                    [table.update_cell(row_key=team_key, column_key=column_keys[i], value=cols[i]) for i in range(len(column_keys))]
                except CellDoesNotExist:
                    table.add_row(*cols, height=3, key=team_key)
            for key in table.rows.keys():
                if key.value not in used_keys:
                    table.remove_row(row_key=key.value)
            if g.exit_event.wait(timeout=1): return
    
    def update_log(self):
        logger = self.query_one("#log", RichLog)
        while True:
            if g.exit_event.is_set(): return
            log = g.print_queue.get()
            if isinstance(log, str):
                logger.write(escape(log))
            else:
                logger.write(log)
    
    def flag_graph_update(self):
        while True:
            with self.batch_update():
                self.query_one("#flags_graph", FlagGraph).update()
            if g.exit_event.wait(timeout=10): return
    
    def exit_listener(self):
        if g.exit_event.wait():
            stop_screen()
            self.exit(0)
    
    def on_ready(self) -> None:
        """Called when the DOM is ready."""
        self.run_worker(self.update_data, thread=True)
        self.run_worker(self.update_log, thread=True)
        self.run_worker(self.flag_graph_update, thread=True)
        self.run_worker(self.exit_listener, thread=True)

    def compose(self) -> ComposeResult:
        yield Header(self.title)
        yield Label(id="status_text", classes="full-width")
        yield Tabs(*self.TABS, id="tabs")
        yield DataTable(id="teams", classes="tab tab-Teams hidden", cell_padding=3, header_height=5 )
        yield RichLog(highlight=True, id="log", classes="tab tab-Log hidden", wrap=True, auto_scroll=True, max_lines=5000)
        yield FlagGraph("Flags got over time", id="flags_graph", classes="tab tab-Graph hidden")
        yield Footer()    
        
    def show_tab(self, tab_name:str):
        self.query(".tab").add_class("hidden")
        self.query(f".tab-{tab_name}").remove_class("hidden")

    def action_show_tab(self, tab_name:str):
        self.query_one("#tabs", Tabs)._activate_tab(self.query_one(f"#{tab_name}", Tab))

    def on_tabs_tab_activated(self, event: Tabs.TabActivated) -> None:
        """Handle TabActivated message sent by Tabs."""
        self.show_tab(event.tab.label_text)

    def action_cancel(self):
        with self.suspend(): # Needed to avoid the app to hang, and terminal to be unusable
            self.exit(0)
            stop_screen()
    
