"""
The CLI module for PentestTools.com. It uses PTT API and formats output.
"""

from __future__ import annotations
import argparse
from concurrent.futures import ThreadPoolExecutor
import dataclasses
import functools
import json
import sys
import time
from typing import Optional
import os

import requests

from . import api
from . import banner
from .context import Context
from .logger_config import logger, configure_logger_for_cli
from .text_format import log_runtime_status, print_summary, print_report


# Scans started by this process; stop_started_scans() iterates over this list to stop them.
STARTED_SCANS = []

# Exit codes returned by the CLI.
EXIT_SUCCESS, EXIT_FAILURE = 0, 1

# Accepted values for the failure criteria (--fail / PTT_FAIL) and for the scan type
# (--scan_type / PTT_SCAN_TYPE).
FAIL_CHOICES = ["low", "medium", "high", "none"]
SCAN_TYPES = ["light", "deep"]


class CLIException(Exception):
    """Error raised by this CLI module itself (for example, an unsupported output format)."""


class Timer:
    """Measures time against a fixed interval.

    The timer remembers a starting point (``time.perf_counter()``) and an interval ``delta`` in
    seconds. It can report how long has passed since the start, how long is left, and whether the
    interval is over.
    """

    def __init__(self, delta: float):
        self.delta = delta
        self.set()

    def set(self):
        """Restarts the interval from the current moment."""
        self.start_time = time.perf_counter()

    def elapsed(self) -> float:
        """Returns the seconds passed since the last (re)start."""
        return time.perf_counter() - self.start_time

    def is_done(self) -> bool:
        """Returns True once at least ``delta`` seconds have elapsed."""
        return self.elapsed() >= self.delta

    def remaining(self) -> float:
        """Returns the seconds still needed to reach ``delta`` (negative once overdue)."""
        return self.delta - self.elapsed()


class WaitingFuture:
    """Future-like placeholder that delays submitting a task until a timer expires.

    Each poll of ``done()`` checks the timer. Once it is done, the timer is restarted and the stored
    call is submitted through the executor's unthrottled ``_parent_submit``. From then on ``done`` and
    ``result`` are the real future's methods. ``result`` exists only after that submission.
    """

    def __init__(self, executor, timer, *args, **kwargs):
        self.executor = executor
        self.timer = timer
        self.task_args = args
        self.task_kwargs = kwargs

    def done(self):  # pylint: disable=method-hidden
        """Returns False; when the timer is done, submits the task and delegates to its future."""
        if self.timer.is_done():
            self.timer.set()
            real_future = self.executor._parent_submit(*self.task_args, **self.task_kwargs)
            # Rebind on the instance so later polls go straight to the real future.
            self.done = real_future.done
            self.result = real_future.result
        # The poll that performs the submission always reports "not done"; the next poll asks
        # the real future.
        return False


class ThrottledThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that can rate-limit submissions of specific callables.

    Callables registered with ``set_func_limit`` are submitted at most once per ``delta`` seconds. A
    submission that arrives too early gets a WaitingFuture, which defers the real submission until
    the interval has elapsed. Unregistered callables are submitted immediately.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.timers = {}
        # Keep a handle on the unthrottled submit so WaitingFuture can bypass the throttling.
        self._parent_submit = super().submit

    def set_func_limit(self, func, delta: float):
        """Limits ``func`` to one submission per ``delta`` seconds, counting from now."""
        self.timers[func] = Timer(delta)

    def submit(self, func, *args, **kwargs):
        timer = self.timers.get(func)
        if timer is not None and not timer.is_done():
            # Too early: hand back a placeholder that submits once the interval is over.
            return WaitingFuture(self, timer, func, *args, **kwargs)
        if timer is not None:
            timer.set()
        return super().submit(func, *args, **kwargs)


class StartScanException(Exception):
    """Raised when the API does not confirm that a scan was started.

    The decoded API response is kept on ``info`` and serialized as JSON into the message.
    """

    def __init__(self, info, message="Scan failed to start"):
        super().__init__(f"{message}: {json.dumps(info)}")
        self.info = info


@dataclasses.dataclass
class ResultSummary:
    """Counts of findings per severity, taken from the API's ``result_summary`` object."""

    high: int
    info: int
    low: int
    medium: int

    @classmethod
    def from_json(cls, data: dict) -> ResultSummary:
        """Builds an instance from a decoded dict with the keys high, info, low and medium."""
        return cls(
            high=data["high"],
            info=data["info"],
            low=data["low"],
            medium=data["medium"],
        )


@dataclasses.dataclass
class ScanStatus:
    """Holds the scan status data from the API (the ``data`` object of a scan status response)."""

    duration: int
    end_time: Optional[str]
    num_finished_tests: int
    num_tests: int
    progress: int
    result_summary: Optional[ResultSummary]
    status_name: str
    start_time: str
    target_id: int

    @staticmethod
    def from_json(data: dict) -> ScanStatus:
        """Converts a JSON-like dict to a ScanStatus object.

        ``end_time`` and ``result_summary`` are optional: a missing key or an explicit ``null``
        yields None. All other fields are required, and a missing one raises KeyError.
        """
        raw_summary = data.get("result_summary")
        # Treat `"result_summary": null` like an absent key instead of failing inside
        # ResultSummary.from_json(None).
        result_summary = None if raw_summary is None else ResultSummary.from_json(raw_summary)
        return ScanStatus(
            data["duration"],
            data.get("end_time", None),
            data["num_finished_tests"],
            data["num_tests"],
            data["progress"],
            result_summary,
            data["status_name"],
            data["start_time"],
            data["target_id"],
        )


class Scan:
    """Handle to a scan started through the API, identified by ``scan_id``."""

    def __init__(self, scan_id):
        self.scan_id = scan_id

    def get_output(self) -> dict:
        """Fetches the scan's output from the API and returns the decoded JSON body."""
        response = api.get_output(self.scan_id)
        return response.json()

    def get_scan_status(self) -> ScanStatus:
        """Fetches the scan's status from the API and returns its ``data`` object as a ScanStatus."""
        payload = api.get_scan_status(self.scan_id).json()
        return ScanStatus.from_json(payload["data"])


def start_scan_instance(target, tool_id, scan_type="light", max_scan_time: int = 10):
    """Starts a scan through the API and records it in ``STARTED_SCANS``.

    Args:
        target: The target passed to ``api.start_scan``.
        tool_id: Identifier of the tool to run (e.g. ``api.Tool.WEBSITE_SCANNER``).
        scan_type: Value sent as the ``scan_type`` tool parameter.
        max_scan_time: Value sent as the ``max_scan_time`` scan parameter.

    Returns:
        A Scan object associated to the started scan.

    Raises:
        StartScanException: If the response has no ``data.created_id``. The decoded response is
            attached as ``info``.
    """
    tool_params = {"scan_type": scan_type}
    scan_params = {"max_scan_time": max_scan_time}

    res_json = api.start_scan(target, tool_id, tool_params, scan_params).json()
    # Non-dict payloads (e.g. `"data": null`) must surface as StartScanException, which carries the
    # API's message, rather than as a TypeError from the `in` test.
    data = res_json.get("data") if isinstance(res_json, dict) else None
    if not isinstance(data, dict) or "created_id" not in data:
        raise StartScanException(res_json)

    scan = Scan(data["created_id"])
    # Remember the scan so stop_started_scans() can stop it if the CLI is interrupted.
    STARTED_SCANS.append(scan)
    return scan


def stop_started_scans():
    """Asks the API to stop every scan recorded in ``STARTED_SCANS``, in start order.

    Exceptions raised by ``api.stop_scan`` propagate and abort the remaining stops.
    """
    pending_ids = (started.scan_id for started in STARTED_SCANS)
    for scan_id in pending_ids:
        api.stop_scan(scan_id)


def cli_resource_handler(func):
    """CLI decorator that turns failures into a best-effort shutdown.

    When ``func`` raises a ``requests`` error, ``KeyboardInterrupt`` or any other ``Exception``, the
    wrapper does three things: it logs what happened, writes whatever scan output is available
    (``write_result``), and asks the API to stop every started scan (``stop_started_scans``). Scans
    are stopped even if writing the output fails. A failure to stop them is logged instead of hiding
    the original error.

    HTTP errors and interruptions are swallowed, and the wrapper returns ``EXIT_FAILURE`` so the
    process does not report success. With ``--verbose`` they are re-raised instead. Any other
    exception is always re-raised after cleanup.
    """

    def _cleanup(ctx):
        """Writes the available output, then stops the started scans no matter what."""
        try:
            write_result(ctx)
        finally:
            try:
                stop_started_scans()
            except requests.exceptions.RequestException:
                logger.error("Could not stop the started scans.\n")

    @functools.wraps(func)
    def wrapper(ctx, *args, **kwargs):
        raised_exc = None
        try:
            return func(ctx, *args, **kwargs)
        except requests.exceptions.RequestException as exc:
            raised_exc = exc
            logger.error(
                "An HTTP error has occured. Stopping scans and writing output. "
                "Press Ctrl+C to terminate forcefully.\n",
            )
            _cleanup(ctx)
        except KeyboardInterrupt as exc:
            raised_exc = exc
            logger.error(
                "A stop was requested. Stopping scans and writing output. "
                "Press Ctrl+C again to terminate forcefully.\n",
            )
            _cleanup(ctx)
        except Exception:
            logger.error(
                "An error has occured. Stopping scans and writing output. Press Ctrl+C to terminate forcefully.\n"
            )
            _cleanup(ctx)
            raise

        # ctx.args may not be populated if the failure happened before argument parsing finished.
        if getattr(getattr(ctx, "args", None), "verbose", False):
            raise raised_exc

        # The run did not complete: never report success to the shell / CI.
        return EXIT_FAILURE

    return wrapper


# Currently not used. Replaced by the WaitingFuture.
def throttled_call(delta=1):
    """Decorator factory that spaces out calls to the wrapped function by at least ``delta`` seconds.

    Each call blocks until ``delta`` seconds have passed since the previous call. A call made with a
    truthy ``nowait`` keyword argument runs immediately instead; ``nowait`` is still forwarded to the
    wrapped function. Either way, the call restarts the interval.
    """

    def decorator(func):
        timer = Timer(delta)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not kwargs.get("nowait"):
                # Sleep slightly less than the remaining time, then re-check the timer.
                while not timer.is_done():
                    time.sleep(max(0, timer.remaining() - 0.005))
            timer.set()
            return func(*args, **kwargs)

        return wrapper

    return decorator


def write_result(ctx: Context) -> None:
    """Writes the scan result at stdout or in a file.

    The report is JSON (``ctx.scan_output``) or text (report plus the summary of ``ctx.summary`` for
    ``ctx.target``), depending on ``ctx.args.format``. With ``--output``, colors are disabled for the
    duration of the write and the previous ``nocolor`` value is restored afterwards. File errors are
    logged, not raised. Nothing is written when no arguments were parsed or for ``ptt mcp``, whose
    namespace has no output options.

    Args:
        ctx: Global context.

    Raises:
        CLIException: If ``ctx.args.format`` is neither "json" nor "text".
    """
    if not ctx.has_args:
        return
    # `ptt mcp` parses no --output/--format options, so there is nothing to write.
    if getattr(ctx.args, "mcp", False) is True:
        return

    output_format = ctx.args.format
    # Validate before opening the output file so an unsupported format does not truncate it.
    if output_format not in ("json", "text"):
        raise CLIException("Unhandled output format")

    if not ctx.args.output:
        if output_format == "json":
            print(json.dumps(ctx.scan_output))
        else:
            print_report(ctx, ctx.scan_output, sys.stdout)
            print_summary(ctx.summary, ctx.target, sys.stdout)
        return

    prev_nocolor = ctx.args.nocolor
    ctx.args.nocolor = True
    try:
        with open(ctx.args.output, "w", encoding="utf-8") as file:
            if output_format == "json":
                file.write(json.dumps(ctx.scan_output))
            else:
                print_report(ctx, ctx.scan_output, file)
                print_summary(ctx.summary, ctx.target, file)
    except OSError:
        logger.error("Could not open/write file: %s\n", ctx.args.output)
    else:
        logger.info("Output written to: %s\n", ctx.args.output)
    finally:
        ctx.args.nocolor = prev_nocolor


def scan_loop(ctx: Context, scan: Scan, sec_per_tick=0.2, scan_status_wait=3, scan_output_wait=60) -> None:
    """Renders a live status of the scan, while talking to the API.

    The loop stops when the scan status is "finished". Status and output requests run on a
    two-worker ThrottledThreadPoolExecutor, so the render loop does not wait on them. The status is
    requested at most every ``scan_status_wait`` seconds and the output at most every
    ``scan_output_wait`` seconds. The latest output is stored in ``ctx.scan_output``. The pool is shut
    down (without waiting for in-flight requests) when the loop ends, including on errors.

    NOTE(review): only the "finished" status ends the loop. If the API has other terminal statuses
    (e.g. failures), this would poll forever. Confirm against the API's status list.

    Args:
        ctx: Global context.
        scan: The started scan.
        sec_per_tick: The number of seconds to wait, until rerendering.
        scan_status_wait: Minimum number of seconds between two status requests.
        scan_output_wait: Minimum number of seconds between two output requests.

    Raises:
        Any exception raised by a status/output request, re-raised when its result is read.
    """
    scan_status = scan.get_scan_status()
    del_prev = False

    pool = ThrottledThreadPoolExecutor(max_workers=2)
    try:
        pool.set_func_limit(scan.get_scan_status, scan_status_wait)
        pool.set_func_limit(scan.get_output, scan_output_wait)

        status_task = pool.submit(scan.get_scan_status)
        output_task = pool.submit(scan.get_output)

        while True:
            time.sleep(sec_per_tick)

            if status_task.done():
                scan_status = status_task.result()
                status_task = pool.submit(scan.get_scan_status)

            if output_task.done():
                ctx.scan_output = output_task.result()
                output_task = pool.submit(scan.get_output)

            log_runtime_status(ctx, scan_status, del_prev=del_prev)
            # From the second render on, replace the previously printed status.
            del_prev = True
            if scan_status.status_name == "finished":
                break
    finally:
        # Release the worker threads; don't block on requests that are still in flight.
        pool.shutdown(wait=False)

    log_runtime_status(ctx, scan_status, del_prev=del_prev)


def get_option(args, env_name, arg_name, cast=lambda x: x):
    """Resolves one option from the environment or the parsed command line, then casts it.

    A non-empty environment variable ``env_name`` wins over ``args.<arg_name>``, which already holds
    argparse's default when the flag was not given. The chosen value is passed through ``cast``.

    This is needed for action.yml, where it's much easier to use env vars.
    """
    from_env = os.environ.get(env_name)
    from_cli = getattr(args, arg_name)
    chosen = from_env if from_env else from_cli
    return cast(chosen)


def parse_args(argv=None):
    """Parses the command line arguments.

    Args:
        argv: Arguments to parse, without the program name. Defaults to ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``ptt mcp`` sets ``mcp=True``. ``ptt run website_scanner``
        additionally provides ``target``, ``scan_type``, ``max_scan_time``, ``output`` and ``format``.
    """
    parser = argparse.ArgumentParser(
        description="""Command-line utility for PentestTools.com.

        Example usage: ptt -q --key <api_key> run website_scanner https://pentest-ground.com
        """,
        prog="ptt",
        add_help=False,
    )
    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="Show this help message and exit.",
    )

    parser.add_argument(
        "--fail",
        choices=FAIL_CHOICES,
        default="none",
        help="Define failure criteria. `--fail low` fails if even a low vulnerability finding is found.",
    )

    parser.add_argument("--key", help="The API key. If not provided, obtain one automatically.")
    parser.add_argument(
        "--nocolor",
        action="store_true",
        help="If set, don't color the output at stdout and stderr.",
    )
    parser.add_argument("--verbose", action="store_true", help="If set, print debug information.")
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="If set, suppress info data and errors. Only print the final report.",
    )

    # Create the top-level parser.
    # An explicit `dest` avoids the TypeError that some Python versions raise when a required,
    # unnamed subparser group is missing from the command line (bpo-29298).
    subparsers = parser.add_subparsers(dest="command", required=True)

    # Add the model-context protocol subparser
    # It either runs the MCP server, or, if it was not installed, prints the install command and exits
    parser_mcp = subparsers.add_parser("mcp")
    parser_mcp.set_defaults(mcp=True)  # We want to know that `ptt mcp` has been invoked

    # Create the parser for the first subcommand
    parser_run = subparsers.add_parser("run")
    subparsers_run = parser_run.add_subparsers(dest="website_scanner", required=True)

    # Create the parser for the second subcommand
    parser_website_scanner = subparsers_run.add_parser("website_scanner", add_help=False)

    # Define the final subparser for target and arguments.
    # `target` is optional here because it may also come from PTT_TARGET (checked in transform_args).
    parser_website_scanner.add_argument("target", nargs="?", help="The URL to scan.")
    parser_website_scanner.add_argument(
        "--scan_type",
        choices=SCAN_TYPES,
        default="light",
        help="The type of the scan",
    )
    parser_website_scanner.add_argument(
        "--max_scan_time",
        default=10,
        type=int,
        help="Maximum scan duration",
    )

    parser_website_scanner.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="Show this help message and exit.",
    )
    parser_website_scanner.add_argument(
        "-o",
        "--output",
        help="File to write the report. If specified, suppress the default report output and redirect to the path.",
    )
    parser_website_scanner.add_argument(
        "--format",
        choices=["text", "json"],
        default="text",
        help="Specify the output format of the vulnerability report.",
    )

    return parser.parse_args(argv)


def transform_args(args: argparse.Namespace) -> None:
    """Transforms some of the arguments given by the user, in a format accepted by the API.

    Options may also come from environment variables: PTT_KEY, PTT_QUIET, PTT_FORMAT, PTT_TARGET,
    PTT_FAIL, PTT_SCAN_TYPE and PTT_MAX_SCAN_TIME. A non-empty environment variable takes precedence
    over the command-line value (see ``get_option``). Invalid values end the process through
    ``sys.exit`` with an error message. A target without a scheme gets an ``http://`` prefix.

    Args:
        args: The parsed command line arguments. Modified in place.
    """
    args.key = get_option(args, "PTT_KEY", "key")

    # Check for MCP.
    if getattr(args, "mcp", False) is True:
        # `ptt mcp` invoked, return early because we don't have to parse the other args
        return
    # Assigning so we don't have to check with `getattr` again
    args.mcp = False

    quiet = get_option(args, "PTT_QUIET", "quiet")
    if isinstance(quiet, str):
        # Values from PTT_QUIET are strings, and bool("false") is True; treat the usual "off"
        # spellings as False.
        quiet = quiet.strip().lower() not in ("", "0", "false", "no", "off")
    args.quiet = bool(quiet)

    args.format = get_option(args, "PTT_FORMAT", "format", lambda x: x if x in ["text", "json"] else "text")
    args.target = get_option(args, "PTT_TARGET", "target")
    args.fail = get_option(args, "PTT_FAIL", "fail")
    args.scan_type = get_option(args, "PTT_SCAN_TYPE", "scan_type")

    raw_max_scan_time = get_option(args, "PTT_MAX_SCAN_TIME", "max_scan_time")
    try:
        args.max_scan_time = int(raw_max_scan_time)
    except (TypeError, ValueError):
        sys.exit(
            "error: the max scan time must be an integer, either through PTT_MAX_SCAN_TIME or --max_scan_time <value>"
        )

    if args.fail not in FAIL_CHOICES:
        sys.exit(
            f"error: the failure criteria is not correct. Accepted values are {FAIL_CHOICES}, either through PTT_FAIL or --fail <choice>"
        )

    if args.scan_type not in SCAN_TYPES:
        sys.exit(
            f"error: the scan type is not correct. Accepted values are {SCAN_TYPES}, either through PTT_SCAN_TYPE or --scan_type <choice>"
        )

    if not args.target:
        sys.exit(
            "error: the `target` is required, either through the env var PTT_TARGET or through the command-line: ptt run website_scanner <target>"
        )

    # Prefix a scheme only when the target has none. http:// and https:// targets always contain
    # "://" within their first 10 characters.
    if "://" not in args.target[:10]:
        args.target = f"http://{args.target}"


def validate_args(args: argparse.Namespace):
    """Checks the correctness of the CLI arguments.

    Currently this verifies that the ``--output`` path, if given, can be opened for appending. The
    probe creates the file if it does not exist yet and never truncates an existing one.

    Args:
        args: The parsed command line arguments.

    Returns:
        True, if the arguments are valid, False otherwise (the reason is logged).
    """
    if args.output:
        try:
            # Opening in append mode is the writability probe; a handle opened with "a" is always
            # writable, so failures surface here as OSError.
            with open(args.output, "a", encoding="utf-8"):
                pass
        except OSError:
            logger.error("Could not open file: %s\n", args.output)
            return False

    return True


# Substrings of start-scan error messages caused by the user's target or plan. These are logged and
# end the process with EXIT_FAILURE instead of raising.
_KNOWN_START_ERRORS = (
    "Target URL redirects to",
    "Target URL is not accessible (DNS error)",
    "Your current plan only allows",
)

# For each `--fail` level, the result_summary severities that make the run fail
# (the level itself and every higher one).
_FAILING_SEVERITIES = {
    "low": ("low", "medium", "high"),
    "medium": ("medium", "high"),
    "high": ("high",),
}


def _is_api_key_valid(res) -> bool:
    """Interprets the response of the probe request ``api.get_scan_status(0)``.

    The key is accepted when the API answers with JSON and either reports that scan 0 does not
    exist (404) or returns scan data containing an ``id`` (200).
    """
    content_type = res.headers.get("content-type") or ""
    # Accept media-type parameters such as "; charset=utf-8".
    if not content_type.startswith("application/json"):
        return False
    data = res.json()
    if not isinstance(data, dict):
        return False
    if res.status_code == 404:
        return data.get("status") == 404 and "not exist" in str(data.get("message", ""))
    if res.status_code == 200:
        payload = data.get("data")
        return isinstance(payload, dict) and "id" in payload
    return False


def _init_api_key(manual_key: Optional[str]) -> bool:
    """Initializes ``api`` with a working API key. Returns False (after logging) if none works.

    A key given by the user (``--key`` / PTT_KEY) is used as-is and is never regenerated. Otherwise
    ``api.init`` provides the key. If that key was read from the config (``api.KEY_READ``) and is
    rejected, one retry is made with ``api.init(True)`` to regenerate it.
    """
    if manual_key:
        api.API_KEY = manual_key
        api.KEY_READ = False

    regenerate = False
    for attempt in range(2):
        api.init(regenerate)
        # After this point api.API_KEY is populated
        if _is_api_key_valid(api.get_scan_status(0)):
            return True
        # If the key was inserted manually, fail with a log.
        if manual_key:
            logger.error(
                "The API %s didn't respond properly. Your API key might be invalid.\n",
                api.API_URL,
            )
            return False
        if not api.KEY_READ:
            return False
        # If the key was read from the config, regenerate it once.
        if attempt == 0:
            logger.info("API key may be old, regenerating\n")
            regenerate = True
            continue
        # If the key was read from the config and it still fails, fail with a log.
        logger.error(
            "The API %s didn't respond properly on the second retry.\n",
            api.API_URL,
        )
        return False
    return False


def _violates_fail_criteria(result_summary, fail: str) -> bool:
    """Returns True if ``result_summary`` has findings at or above the ``fail`` level ("none" never fails)."""
    return any(getattr(result_summary, severity) for severity in _FAILING_SEVERITIES.get(fail, ()))


@cli_resource_handler
def cli(ctx: Context, scan_loop_kwargs=None):
    """Handles command line arguments, starts the scan, logs the status while
    running, then outputs the vulnerability report.

    ``ptt mcp`` hands control to the MCP server entrypoint instead and exits with its return value.
    Known start-scan errors (see ``_KNOWN_START_ERRORS``) end the process with
    ``sys.exit(EXIT_FAILURE)``.

    Args:
        ctx: Global context; receives ``args``, ``scan_output``, ``summary`` and ``target``.
        scan_loop_kwargs: Extra keyword arguments forwarded to ``scan_loop`` (e.g. polling intervals).

    Returns:
        EXIT_SUCCESS, or EXIT_FAILURE when the arguments or the API key are invalid or the ``--fail``
        criteria are met.
    """
    if scan_loop_kwargs is None:
        scan_loop_kwargs = {}

    ctx.args = parse_args()
    transform_args(ctx.args)

    # `ptt mcp` was called
    if ctx.args.mcp is True:
        if ctx.args.key:
            api.api_key = ctx.args.key
            # NOTE(review): the scan path uses `api.API_KEY`; the lowercase attribute above looks
            # like a typo. Both are set until the mcp module's expectation is confirmed.
            api.API_KEY = ctx.args.key

        from .mcp import entrypoint as mcp_entrypoint  # pylint: disable=import-outside-toplevel

        sys.exit(mcp_entrypoint())

    configure_logger_for_cli(ctx)
    if not validate_args(ctx.args):
        return EXIT_FAILURE

    if not _init_api_key(ctx.args.key):
        return EXIT_FAILURE

    logger.error("%s\n", banner.BANNER)
    logger.error("Scanning target: %s\n\n", ctx.args.target)

    try:
        scan = start_scan_instance(
            ctx.args.target,
            api.Tool.WEBSITE_SCANNER,
            ctx.args.scan_type,
            ctx.args.max_scan_time,
        )
    except StartScanException as exc:
        message = exc.info.get("message") if isinstance(exc.info, dict) else None
        if isinstance(message, str) and any(hint in message for hint in _KNOWN_START_ERRORS):
            logger.error("%s\n", message)
            sys.exit(EXIT_FAILURE)
        raise

    if not isinstance(scan, Scan):
        logger.error("Scan info: %s\n", scan)
        return EXIT_FAILURE

    scan_loop(ctx, scan, **scan_loop_kwargs)
    ctx.scan_output = scan.get_output()
    ctx.summary = scan.get_scan_status()
    target_info = api.get_target_by_id(ctx.summary.target_id)
    ctx.target = target_info.json()["data"]["name"]
    write_result(ctx)

    # Fails if our failure criteria is found within the summary
    result_summary = ctx.summary.result_summary
    if result_summary is not None and _violates_fail_criteria(result_summary, ctx.args.fail):
        print(
            f"`ptt` scan failed, because it ran with `--fail {ctx.args.fail}` and the report has "
            "vulnerabilities with higher or equal risk.",
            file=sys.stderr,
        )
        return EXIT_FAILURE

    return EXIT_SUCCESS


def entrypoint():
    """Entrypoint function for the CLI: runs ``cli`` with a fresh Context and returns its result."""
    return cli(Context())
