from __future__ import annotations

import argparse
import contextlib
import gc
import http.server
import importlib
import queue
import shlex
import shutil
import socket
import socketserver
import sys
import textwrap
import threading
import traceback
from dataclasses import dataclass
from pathlib import Path
from time import perf_counter
from typing import TYPE_CHECKING, NamedTuple, Protocol

from sonolus.backend.excepthook import print_simple_traceback
from sonolus.backend.utils import get_function, get_functions, get_tree_from_file
from sonolus.build.compile import CompileCache
from sonolus.script.internal.context import ProjectContextState
from sonolus.script.internal.error import CompilationError

if TYPE_CHECKING:
    from sonolus.script.project import BuildConfig, Project

# Compact one-line-per-command summary, printed when the console starts and
# again whenever a command line fails to parse.
HELP_TEXT = textwrap.dedent(
    """
    [r]ebuild
    [d]ecode <message_code>
    [h]elp
    [q]uit
    """
).strip()


class CommandHelp(NamedTuple):
    """Detailed help entry for a single console command."""

    # Short alias, printed inside square brackets by the help command (e.g. "r").
    alias: str
    # Full command form, including any argument placeholder (e.g. "decode <message_code>").
    command: str
    # Paragraphs of help text; the help command wraps and prints each one separately.
    description: list[str]


# Long-form help printed by the `help` command: one CommandHelp entry per
# command, in display order.
DETAILED_HELP_TEXT = [
    CommandHelp(alias=alias, command=command, description=paragraphs)
    for alias, command, paragraphs in (
        (
            "r",
            "rebuild",
            ["Rebuild the project with the latest changes from source files."],
        ),
        (
            "d",
            "decode <message_code>",
            [
                (
                    "Decode a debug message code to its original text with a stack trace. "
                    "Message codes are generated by assert statements and certain debug-related functions such as "
                    "debug.error() and debug.notify(). "
                    "The game is automatically paused when these are triggered in debug mode and the message code "
                    "can be found in the debug log."
                ),
                "Example: d 42",
            ],
        ),
        ("h", "help", ["Show this help message."]),
        ("q", "quit", ["Exit the development server."]),
    )
]


@dataclass
class ServerState:
    """Mutable state shared by the console commands of the development server.

    One instance is created by ``run_server`` in interactive mode and handed to
    every command's ``execute`` method.
    """

    # Project being served; RebuildCommand replaces it with the re-imported module's ``project``.
    project: Project
    # Module name passed to ``importlib.import_module`` on rebuild.
    project_module_name: str
    # Modules that RebuildCommand keeps in ``sys.modules``; all others are evicted before re-import.
    core_module_names: set[str]
    # Output directory passed to ``build_collection``.
    build_dir: Path
    # Build configuration passed to ``build_collection`` and used to create fresh project state.
    config: BuildConfig
    # Compile cache reused across rebuilds; RebuildCommand prunes unaccessed entries after a successful build.
    cache: CompileCache
    # Per-build context state; recreated on each rebuild and read by DecodeCommand for message lookup.
    project_state: ProjectContextState


class Command(Protocol):
    """Structural interface implemented by every parsed console command."""

    # Perform the command's action; implementations may read or mutate ``server_state``.
    def execute(self, server_state: ServerState) -> None: ...


@dataclass
class RebuildCommand:
    """Console command that re-imports the project module and rebuilds it."""

    def execute(self, server_state: ServerState):
        """Reload the project from source and rebuild it into the build directory.

        Import failures and ``CompilationError`` are reported on stdout; any
        other exception raised during the build propagates to the caller.
        """
        from sonolus.build.cli import build_collection

        # Evict everything except the core modules so the import below
        # re-executes the project's source files. Iterate over a snapshot
        # because entries are removed during the loop.
        keep = server_state.core_module_names
        for loaded_name in tuple(sys.modules):
            if loaded_name in keep:
                continue
            del sys.modules[loaded_name]

        try:
            reloaded = importlib.import_module(server_state.project_module_name)
        except Exception:
            print(traceback.format_exc())
            return

        # Drop memoized source lookups so they are recomputed for the new build.
        for memoized in (get_function, get_tree_from_file, get_functions):
            memoized.cache_clear()

        print("Rebuilding...")
        try:
            began = perf_counter()
            server_state.cache.reset_accessed()
            server_state.project_state = ProjectContextState.from_build_config(server_state.config)
            server_state.project = reloaded.project
            build_collection(
                server_state.project,
                server_state.build_dir,
                server_state.config,
                cache=server_state.cache,
                project_state=server_state.project_state,
            )
            server_state.cache.prune_unaccessed()
            elapsed = perf_counter() - began
            print(f"Rebuild completed in {elapsed:.2f} seconds")
        except CompilationError:
            print_simple_traceback(*sys.exc_info())


@dataclass
class DecodeCommand:
    """Console command that maps a debug message code back to its message text."""

    message_code: int

    def execute(self, server_state: ServerState):
        """Print the message registered for ``message_code``, or a notice if none is."""
        # The mapping is keyed by message text, so scan it for the first entry
        # whose value equals the requested code.
        for text, code in server_state.project_state.debug_str_mappings.items():
            if code == self.message_code:
                found = text
                break
        else:
            found = None

        if found is None:
            print(f"Unknown message code: {self.message_code}")
        else:
            print(found)


@dataclass
class HelpCommand:
    """Console command that prints the detailed help for every command."""

    def execute(self, server_state: ServerState):
        """Print each help entry, wrapped to the terminal width (capped at 120 columns)."""
        indent = "  "
        columns = shutil.get_terminal_size().columns
        # One wrapper is enough: every paragraph uses the same width and indents.
        wrapper = textwrap.TextWrapper(
            width=min(columns, 120) - len(indent),
            initial_indent=indent,
            subsequent_indent=indent,
        )

        print("Available Commands:\n")

        for alias, command, paragraphs in DETAILED_HELP_TEXT:
            print(f"[{alias}] {command}")
            for paragraph in paragraphs:
                # A blank line follows every paragraph.
                print(wrapper.fill(paragraph), end="\n\n")


@dataclass
class ExitCommand:
    """Console command that terminates the development server."""

    def execute(self, server_state: ServerState):
        """Announce the shutdown and raise ``SystemExit`` with status 0."""
        print("Exiting...")
        raise SystemExit(0)


def parse_dev_command(command_line: str) -> Command | None:
    """Parse one console line into a command object.

    Args:
        command_line: Raw text typed at the prompt.

    Returns:
        The matching command instance, or None when the line cannot be parsed
        (a diagnostic is printed in that case).
    """
    parser = argparse.ArgumentParser(prog="", add_help=False, exit_on_error=False)
    commands = parser.add_subparsers(dest="cmd")
    commands.add_parser("rebuild", aliases=["r"])
    commands.add_parser("decode", aliases=["d"]).add_argument(
        "message_code", type=int, help="Message code to decode"
    )
    commands.add_parser("help", aliases=["h"])
    commands.add_parser("quit", aliases=["q"])

    try:
        tokens = shlex.split(command_line)
    except ValueError as e:
        print(f"Error parsing command: {e}\n")
        return None

    try:
        parsed = parser.parse_args(tokens)
    except (argparse.ArgumentError, argparse.ArgumentTypeError) as e:
        print(f"Error parsing command: {e}\n")
        return None
    except SystemExit:
        # argparse throws this on some errors, and will print out help automatically
        print()
        return None

    # `cmd` holds whichever spelling the user typed, so fold aliases onto the
    # full command name before dispatching.
    full_names = {"r": "rebuild", "d": "decode", "h": "help", "q": "quit"}
    name = full_names.get(parsed.cmd, parsed.cmd)

    if name == "rebuild":
        return RebuildCommand()
    if name == "decode":
        return DecodeCommand(message_code=parsed.message_code)
    if name == "help":
        return HelpCommand()
    if name == "quit":
        return ExitCommand()

    # Really, we should not reach here, since argparse would have errored out earlier
    print("Unknown command.\n")
    return None


def command_input_thread(command_queue: queue.Queue, prompt_event: threading.Event):
    """Read console commands from stdin and queue them for execution.

    The loop waits for ``prompt_event`` before showing each prompt. It ends on
    end-of-file or once an ``ExitCommand`` has been queued.

    Args:
        command_queue: Queue that receives each successfully parsed command.
        prompt_event: Set when it is time to show the prompt; this function
            clears it before prompting and sets it again itself whenever no
            command was queued.
    """
    print(f"\nAvailable commands:\n{HELP_TEXT}")

    while True:
        try:
            prompt_event.wait()
            prompt_event.clear()

            print("\n> ", end="", flush=True)
            typed = input().strip()

            if not typed:
                # Nothing entered: prompt again straight away.
                prompt_event.set()
                continue

            parsed = parse_dev_command(typed)
            if not parsed:
                print(f"Available commands:\n{HELP_TEXT}")
                # Show prompt again
                prompt_event.set()
                continue

            command_queue.put(parsed)
            if isinstance(parsed, ExitCommand):
                break
        except EOFError:
            break
        except Exception as e:
            print(f"Error reading command: {e}\n")
            prompt_event.set()


def get_local_ips():
    """Return the sorted, de-duplicated addresses this host resolves to.

    The address resolved from the fully qualified domain name is included as
    is. Addresses from ``getaddrinfo`` are included only when they are neither
    ``127.*`` nor contain a colon. Resolution failures are ignored.
    """
    host = socket.gethostname()
    addresses = set()

    try:
        addresses.add(socket.gethostbyname(socket.getfqdn()))
    except socket.gaierror:
        pass

    with contextlib.suppress(socket.gaierror):
        for _family, _type, _proto, _canonname, sockaddr in socket.getaddrinfo(host, None):
            candidate = sockaddr[0]
            if candidate.startswith("127.") or ":" in candidate:
                continue
            addresses.add(candidate)

    return sorted(addresses)


def run_server(
    base_dir: Path,
    port: int,
    project_module_name: str | None,
    core_module_names: set[str] | None,
    build_dir: Path,
    config: BuildConfig,
    project,
):
    """Build the project, then serve ``base_dir`` over HTTP, optionally with a command console.

    The project is built once up front. When both ``project_module_name`` and
    ``core_module_names`` are given, the HTTP server runs in a daemon thread, a
    second daemon thread reads commands from stdin, and the calling thread
    executes those commands. Otherwise the calling thread blocks in
    ``serve_forever``.

    Args:
        base_dir: Directory whose contents are served.
        port: TCP port to listen on.
        project_module_name: Module name re-imported by the rebuild command, or None.
        core_module_names: Module names the rebuild command keeps in ``sys.modules``, or None.
        build_dir: Output directory passed to ``build_collection``.
        config: Build configuration passed to ``build_collection``.
        project: Project passed to ``build_collection`` for the initial build.
    """
    from sonolus.build.cli import build_collection

    cache = CompileCache()
    project_state = ProjectContextState.from_build_config(config)

    # Initial build, timed for the status line below.
    start_time = perf_counter()
    build_collection(project, build_dir, config, cache=cache, project_state=project_state)
    end_time = perf_counter()
    print(f"Build finished in {end_time - start_time:.2f}s")

    # The console is enabled only when there is enough information to re-import the project.
    interactive = project_module_name is not None and core_module_names is not None

    class DirectoryHandler(http.server.SimpleHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            # Serve files from base_dir.
            super().__init__(*args, directory=str(base_dir), **kwargs)

        def log_message(self, fmt, *args):
            # Request logs can arrive while the prompt is on screen: erase the
            # current line, write the log entry, then redraw the prompt.
            sys.stdout.write("\r\033[K")  # Clear line
            sys.stdout.write(f"{self.address_string()} [{self.log_date_time_string()}] {fmt % args}\n")
            if interactive:
                sys.stdout.write("> ")
            sys.stdout.flush()

    with socketserver.TCPServer(("", port), DirectoryHandler) as httpd:
        local_ips = get_local_ips()
        print(f"Server started on port {port}")
        print("Available on:")
        for ip in local_ips:
            print(f"  http://{ip}:{port}")

        if interactive:
            server_state = ServerState(
                project=project,
                project_module_name=project_module_name,
                core_module_names=core_module_names,
                build_dir=build_dir,
                config=config,
                cache=cache,
                project_state=project_state,
            )

            # Serve HTTP in the background so this thread can run commands.
            threading.Thread(target=httpd.serve_forever, daemon=True).start()

            command_queue = queue.Queue()
            prompt_event = threading.Event()
            input_thread = threading.Thread(
                target=command_input_thread, args=(command_queue, prompt_event), daemon=True
            )
            input_thread.start()

            # Allow the input thread to show its first prompt.
            prompt_event.set()

            try:
                while True:
                    cmd = command_queue.get()
                    try:
                        cmd.execute(server_state)
                        gc.collect()
                    except Exception:
                        # Report the failure and keep the console running.
                        print(f"{traceback.format_exc()}\n")
                    # Signal the input thread to prompt for the next command.
                    prompt_event.set()
            except KeyboardInterrupt:
                print("Exiting...")
                sys.exit(0)
            finally:
                # Runs on any exit from the loop, including SystemExit raised by a command.
                httpd.shutdown()
        else:
            httpd.serve_forever()
