"""
Some basic sandboxing for running untrusted (model-generated) code.
Not intended to be fully secure.
"""

import logging
import time
logger = logging.getLogger(__name__)
from synthegrator._vendor.epicbox.exceptions import DockerError, EpicBoxError


import multiprocessing
import os
import sys
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Union
import docker
import pytest
import xxhash
from docker.errors import BuildError

from synthegrator._vendor import epicbox
logging.getLogger("synthegrator._vendor.epicbox").setLevel(logging.WARNING)

from synthegrator._vendor.human_eval_ref.execution import (
    create_tempdir,
    reliability_guard,
    swallow_io,
    time_limit,
)
from synthegrator.memory_fs import ProjectDir
from synthegrator.util import pretty_print_python_code

# Absolute path of the directory that contains this module. Despite the name,
# this is the parent directory of the file, not the file itself.
cur_file = Path(__file__).parent.absolute()
# True only when the GITHUB_ACTIONS environment variable is exactly "true".
IN_GITHUB_ACTION = os.environ.get("GITHUB_ACTIONS") == "true"


@dataclass(frozen=True)
class ExecLimits:
    """Resource limits applied to a sandboxed run.

    Besides normal attribute access, an instance can be read like a mapping
    (``limits["memory_limit_mb"]``, ``dict(**limits)``) because it provides
    ``keys()`` and ``__getitem__``.
    """

    # CPU-time budget, in seconds.
    timeout_cpu_s: int = 60
    # Wall-clock budget, in seconds.
    timeout_realtime_s: int = 1_000
    # Memory cap, in megabytes.
    memory_limit_mb: int = 4_000
    # Maximum number of processes.
    max_procs: int = 4_000
    # Whether the sandbox is allowed to use the network.
    networking_allowed: bool = False

    def keys(self):
        """Return a view of the field names stored on this instance."""
        return vars(self).keys()

    def __getitem__(self, key):
        """Return the attribute named ``key``.

        A missing name raises ``AttributeError`` (not ``KeyError``), because
        the lookup is a plain attribute access.
        """
        return getattr(self, key)


@dataclass(frozen=True)
class Cmd:
    cmd: str
    files: tuple[dict[str, bytes | str]] = ()

    @staticmethod
    def from_cmds_str_mix(cmds: Iterable[Union[str, "Cmd"]]) -> list["Cmd"]:
        return [Cmd(cmd) if isinstance(cmd, str) else cmd for cmd in cmds]

    @staticmethod
    def from_cmds_str_mix_tuple(
        *cmds: Iterable[Union[str, "Cmd"]],
    ) -> tuple["Cmd", ...]:
        return tuple(Cmd(cmd) if isinstance(cmd, str) else cmd for cmd in cmds)


@dataclass(frozen=True)
class CmdExecResult:
    """Outcome of running a single `Cmd` in the sandbox."""

    # The command that was executed.
    cmd: Cmd
    # Exit status reported for the command.
    exit_code: int
    # False when the run timed out or was OOM-killed (set in `run_on_docker`).
    completed: bool
    # Raw captured standard output.
    stdout: bytes
    # Raw captured standard error.
    stderr: bytes
    # How long the command ran (presumably seconds -- confirm against epicbox).
    duration: float
    # True when the run was stopped because it hit a time limit.
    timeout: bool
    # True when the run was killed for exceeding its memory limit.
    oom_killed: bool


@dataclass(frozen=True)
class TestsuiteExecResult:
    """Summary of a test-suite run plus the raw report it produced."""

    # The class name starts with "Test"; this tells pytest not to collect it.
    __test__ = False
    got_results: bool
    """whether got some xml"""
    collection_error: bool
    """Might be syntax error in the code or runtime error at collection"""
    exec_error: bool
    """The test suite didn't complete (like due to OOM or timeout)"""
    # Result of the command that ran the test suite itself.
    test_suite_exec_result: CmdExecResult
    # Text of the report file (junit XML when produced by `pytest_on_docker`);
    # may be an empty string when no report was written.
    xml_result: str


# Image names treated as already available in this process. Names in this set
# are never built (or pulled); `_build_docker_image` adds a tag here after a
# successful build.
_BUILT_IMAGES = {
    # "synthegrator_defects4j",  # TODO: why here
    # "dypy_synthegrator",  # TODO: sort this out. Not clear this should be here
    "NO_DOCKER_IMAGE_BLOCKED",
}

# Image names that `_build_docker_image` refuses to build when running inside
# GitHub Actions (it raises RuntimeError instead).
_GITHUB_ACTION_IMAGE_DISALLOWED = {
    "dypy_synthegrator",
}


@dataclass(frozen=True)
class DockerExecutionContext:
    """Describes a docker image and the commands needed to build and use it."""

    image_name: str
    """The image to try to use or pull"""
    setup_cmds: tuple[Cmd, ...] = ()
    """These are run at the start of execution. Essentially like the CMD in a dockerfile"""
    before_file_copy_cmds: tuple[Cmd, ...] = ()
    """These are run before the working directory is copied into the docker"""
    dockerfile: str | None = None
    """The string text of a docker file. If provided, should be built when first
    executing code with this environment."""
    build_pre_cmds: tuple[Cmd, ...] = ()
    """These commands run before the dockerfile is built. During build we make a
    temporary directory that serves as the context for the docker. The wd of the
    shell is set to this directory. Thus commands like `cp ~/mydata .` can copy
    into the directory so the dockerfile can access it"""
    default_limits: "ExecLimits | None" = None
    """Limits to use when the caller does not pass explicit limits."""

    def __post_init__(self):
        """Validate that every command field contains only `Cmd` instances.

        Raises:
            TypeError: if any entry of `setup_cmds`, `before_file_copy_cmds`
                or `build_pre_cmds` is not a `Cmd` (e.g. a bare string).
        """
        for field_name in ("setup_cmds", "before_file_copy_cmds", "build_pre_cmds"):
            if not all(isinstance(cmd, Cmd) for cmd in getattr(self, field_name)):
                msg = f"Not all commands are of Cmd type (field: {field_name})"
                raise TypeError(msg)

    def build(self):
        """Make sure the image for this context is available.

        When a dockerfile is set, the image is built via `_build_docker_image`.
        Without a dockerfile, only names listed in `_BUILT_IMAGES` are accepted.

        Raises:
            NotImplementedError: if there is no dockerfile and the image name is
                not in `_BUILT_IMAGES` (pulling images is not implemented).
        """
        # NOTE(review): the client returned here is discarded. Presumably the
        #  call is kept so a broken docker setup fails early -- confirm before
        #  removing it.
        docker.from_env()
        # _BUILT_IMAGES |= { t for i in client.images.list() for t in i.tags  }
        if self.dockerfile is not None:
            _build_docker_image(
                self.dockerfile,
                self.image_name,
                self.build_pre_cmds,
            )
            return
        if self.image_name in _BUILT_IMAGES:
            return
        # TODO: fix pulling images
        # client.images.pull(self.image_name)
        msg = (
            "Cannot handle None docker file: pulling image "
            f"{self.image_name!r} is not implemented"
        )
        raise NotImplementedError(msg)

    def get_hash(self) -> str:
        """Return a hex xxh64 digest of ``str(self)``.

        The digest is computed once and cached on the instance. The cache is
        written with ``object.__setattr__`` because the dataclass is frozen.
        """
        cached = getattr(self, "_hash", None)
        if cached is None:
            hasher = xxhash.xxh64()
            hasher.update(str(self))
            cached = hasher.hexdigest()
            object.__setattr__(self, "_hash", cached)
        return cached


def make_py_env_with_dependencies(
    name: str = "python_synthegrator",
    base: str = "python:3.11-slim",
    dependencies: list[str] | None = None,
) -> DockerExecutionContext:
    """Create a python docker context with pytest and optional extra packages.

    Args:
        name: Name to give the resulting image.
        base: Base image used in the generated dockerfile's ``FROM`` line.
        dependencies: Extra pip requirement strings to install; each one is
            wrapped in double quotes on the generated ``pip install`` line.

    Returns:
        A `DockerExecutionContext` whose dockerfile installs pytest, then the
        extra dependencies (if any), and sets ``/app`` as the working directory.
    """
    # Pin the image's pytest to a release compatible with the one installed
    # here, so both sides agree on behavior.
    try:
        import pytest
    except ImportError:
        # pytest is not importable here: fall back to a fixed requirement.
        pytest_requirement = "pytest~=8.4.1"
    else:
        pytest_requirement = f"pytest~={pytest.__version__}"

    lines = [
        f"FROM {base}",
        f"RUN pip install --no-cache-dir {pytest_requirement}",
    ]
    if dependencies:
        quoted_deps = " ".join(f'"{dep}"' for dep in dependencies)
        lines.append(f"RUN pip install --no-cache-dir {quoted_deps}")
    lines.append("WORKDIR /app")

    return DockerExecutionContext(image_name=name, dockerfile="\n".join(lines))


# fmt: off
# Default python environment (used as the default `docker_env` of
# `pytest_on_docker`): pytest plus a pinned numpy.
PY_DEFAULT_DOCKER_ENV = make_py_env_with_dependencies(
    name="python_synthegrator",
    dependencies=["numpy==2.1.0"]
)
# fmt: on


# Default java environment: a maven image with the project's default pom and
# its dependencies already downloaded.
JAVA_DEFAULT_DOCKER_ENV = DockerExecutionContext(
    image_name="java_synthegrator",
    # Copy the bundled pom and dummy test into the build context so the
    # dockerfile's COPY lines can find them.
    build_pre_cmds=Cmd.from_cmds_str_mix_tuple(
        f"cp {cur_file / 'lang_specs/extra_files/default_java_pom.xml'} pom.xml",
        f"cp {cur_file / 'lang_specs/extra_files/DummyTest.java'} DummyTest.java",
    ),
    dockerfile=(
        # fmt: off
        "FROM maven:3.9.4-eclipse-temurin-17\n"
        "WORKDIR /app\n"
        "COPY pom.xml .\n"
        # Copy in a dummy test file that uses JUnit and force Maven to install its
        #   dependencies. Making the dependencies part of the container is important
        #   so that we aren't redownloading them every time we evaluate a problem.
        "COPY DummyTest.java ./src/test/java/DummyTest.java\n"
        "RUN mvn dependency:go-offline\n"
        "RUN mvn dependency:resolve-plugins\n"
        "RUN mvn test\n"
        # fmt: on
    ),
    # At the start of each execution, copy the pom stored in the image into the
    # current working directory.
    setup_cmds=Cmd.from_cmds_str_mix_tuple(
        "cp /app/pom.xml .",
    ),
)

# Directory holding the scripts that get copied into the DyPyBench image.
_dypy_root = Path(__file__).parent / "synthdatasets/dypybench_data/dypy_run_scripts"
# DyPyBench environment: the upstream dypybench image with this project's
# runner scripts copied over it.
DYPYBENCH_DOCKER_ENV = DockerExecutionContext(
    image_name="dypy_synthegrator",
    dockerfile=(
        # fmt: off
        "FROM dypybench/dypybench:v1.0\n"
        "COPY dypybench.py.nolint /DyPyBench/dypybench.py\n"
        "COPY run-test-temp.sh /DyPyBench/scripts/run-test-temp.sh\n"
        "COPY copy-project.sh /DyPyBench/scripts/copy-project.sh\n"
        "RUN chmod +x /DyPyBench/scripts/run-test-temp.sh /DyPyBench/scripts/copy-project.sh\n"
        # fmt: on
    ),
    # Copy the scripts into the build context so the COPY lines above work.
    build_pre_cmds=Cmd.from_cmds_str_mix_tuple(
        f"cp {_dypy_root / 'run-test-temp.sh'} .",
        f"cp {_dypy_root / 'copy-project.sh'} .",
        f"cp {_dypy_root / 'dypybench.py.nolint'} .",
    ),
    # One hour of CPU and wall-clock time, with networking enabled.
    default_limits=ExecLimits(
        timeout_cpu_s=60 * 60,
        timeout_realtime_s=60 * 60,
        networking_allowed=True,
    ),
)


def run_on_docker(
    docker_context: DockerExecutionContext,
    cmds: Iterable[Cmd | str],  # Commands to run after starting
    files: ProjectDir = None,  # files to copy over
    limits: ExecLimits | None = None,
    interactive_shell_debug: bool = False,
    retries_on_docker_error: int = 2,
) -> list["CmdExecResult"]:
    """Run a set of commands in a docker container. Returns a list of results.

    Commands run in this order: the context's `before_file_copy_cmds`, the
    file copy (if `files` is given), the context's `setup_cmds`, then `cmds`.
    The file copy itself does not produce an entry in the returned list.

    Args:
        docker_context: Image description; it is built first if needed.
        cmds: Commands to run; plain strings are wrapped in `Cmd`.
        files: Optional project files to unpack into the working directory.
        limits: Resource limits. Defaults to the context's `default_limits`,
            or `ExecLimits()` when the context has none.
        interactive_shell_debug: If true, after the commands have run, prompt
            on stdin for extra commands until an empty line is entered.
        retries_on_docker_error: Total number of attempts made for the build
            step and for the run step. Values below 1 are treated as 1.

    Returns:
        One `CmdExecResult` per executed command, in execution order.

    Raises:
        Exception: if a command's files are not str-name/bytes-content dicts,
            or if every run attempt failed with a docker/epicbox error (the
            last such error is attached as ``__cause__``).
    """
    # TODO: This is just a hack to get this working. Will
    #  need to benchmark and optimize this later. In particular it restarts
    #  the container multiple times for each command, and it should just keep
    #  it running.
    # TODO: migrate away from epicbox. Epicbox has weird quirk where not only
    #  is it restarting every time, but it ignores the working directory in
    #  the dockerfile and is instead using a /sandbox volume
    cmds = Cmd.from_cmds_str_mix(cmds)
    if limits is None:
        limits = docker_context.default_limits or ExecLimits()

    # Always make at least one attempt, even when the caller passes 0.
    max_attempts = max(1, retries_on_docker_error)

    for attempt in range(max_attempts):
        try:
            docker_context.build()
            break
        except DockerError as e:
            logger.warning(
                "Docker error: %s.\n Retrying %d of %d...",
                e,
                attempt + 1,
                max_attempts,
            )
            # No point sleeping when there is no attempt left.
            if attempt + 1 < max_attempts:
                time.sleep(2**attempt)
    else:
        # Best effort: the image may still be usable, so carry on, but make
        # the failed build visible instead of falling through silently.
        logger.warning(
            "Could not build docker image %s after %d attempts; "
            "trying to run the commands anyway",
            docker_context.image_name,
            max_attempts,
        )

    profile_name = f"synthegrator_{docker_context.image_name}"
    epicbox.configure(
        profiles=[
            epicbox.Profile(
                profile_name,
                docker_context.image_name,
                network_disabled=not limits.networking_allowed,
            ),
        ],
    )

    limits_d = {
        "cputime": limits.timeout_cpu_s,
        "realtime": limits.timeout_realtime_s,
        "memory": limits.memory_limit_mb,
        "processes": limits.max_procs,
        # TODO: add a way to limit disk space
    }

    last_error = None
    for attempt in range(max_attempts):
        try:
            with epicbox.working_directory() as wd:
                # Start from scratch on every attempt so a partially completed
                # attempt never leaks results into the next one.
                results = []

                def run_cmd(cmd: Cmd):
                    """Run one command in the sandbox and record its result."""
                    if not all(
                        isinstance(f["name"], str) and isinstance(f["content"], bytes)
                        for f in cmd.files
                    ):
                        msg = "Files in cmd are not str/bytes pairs."
                        raise Exception(msg)

                    run = epicbox.run(
                        profile_name,
                        cmd.cmd,
                        workdir=wd,
                        limits=limits_d,
                        files=cmd.files,
                    )
                    results.append(
                        CmdExecResult(
                            cmd=cmd,
                            completed=not (run["timeout"] or run["oom_killed"]),
                            **run,
                        ),
                    )

                # Pre copy
                for cmd in docker_context.before_file_copy_cmds:
                    run_cmd(cmd)

                # Copy over working directory files
                if files:
                    # hackily move the files over. Have to work around how epicbox
                    #  handles files where can only have flat set of files.
                    transfer_tar_name = "transfer_data_to_val.tar"
                    epic_files = [
                        {
                            "name": transfer_tar_name,
                            "content": files.convert_to_tar(gzip=False),
                        },
                    ]
                    # NOTE(review): the outcome of this run is not inspected and
                    #  no explicit limits are passed, so a failed unpack goes
                    #  unnoticed here.
                    epicbox.run(
                        profile_name,
                        f"tar -xf {transfer_tar_name}; rm {transfer_tar_name}",
                        workdir=wd,
                        files=epic_files,
                    )

                # Run rest of commands
                for cmd in [*docker_context.setup_cmds, *cmds]:
                    run_cmd(cmd)

                if interactive_shell_debug:
                    while debug_cmd := input("cmd: "):
                        r = epicbox.run(
                            profile_name,
                            debug_cmd,
                            workdir=wd,
                            limits=limits_d,
                        )
                        print(r)
                        print(r["stdout"].decode("utf-8"))
                        print(r["stderr"].decode("utf-8"), file=sys.stderr)
            return results
        except (DockerError, EpicBoxError) as e:
            # `e` is unbound once the handler exits, so keep a reference for
            # chaining onto the final exception.
            last_error = e
            logger.warning(
                "Docker error: %s.\n Retrying %d of %d...",
                e,
                attempt + 1,
                max_attempts,
            )
            if attempt + 1 < max_attempts:
                time.sleep(2**attempt)
    raise Exception("Failed to run commands on docker") from last_error


def convert_pytest_outputs_to_testsuite_exec_result(
    pytest_out: CmdExecResult,
    cat_out: CmdExecResult,
) -> TestsuiteExecResult:
    """Combine a pytest run and the `cat` of its report into one result.

    Args:
        pytest_out: Result of the command that ran pytest.
        cat_out: Result of the ``cat <report>`` command; its stdout is taken
            as the report text.

    Raises:
        Exception: if `cat_out` was not produced by a command starting with
            ``"cat "``.
    """
    if not cat_out.cmd.cmd.startswith("cat "):
        msg = "Output command does not start with cat"
        raise Exception(msg)

    # Exit codes that mean pytest actually got as far as running the suite.
    # INTERRUPTED is deliberately not in this set: it shows up for things like
    # a syntax error in the test file, so it is classed as an error below.
    ran_codes = {
        pytest.ExitCode.OK.value,
        pytest.ExitCode.NO_TESTS_COLLECTED.value,
        pytest.ExitCode.TESTS_FAILED.value,
    }
    # Exit codes reported as a collection error.
    error_codes = {
        pytest.ExitCode.USAGE_ERROR.value,
        pytest.ExitCode.INTERNAL_ERROR.value,
        pytest.ExitCode.INTERRUPTED.value,
    }

    completed = pytest_out.completed
    exit_code = pytest_out.exit_code
    ran_tests = completed and exit_code in ran_codes
    xml_result = cat_out.stdout.decode("utf-8")

    return TestsuiteExecResult(
        got_results=len(xml_result) > 0 and ran_tests,
        collection_error=completed and (exit_code in error_codes),
        exec_error=not completed,
        test_suite_exec_result=pytest_out,
        xml_result=xml_result,
    )


def pytest_on_docker(
    files: ProjectDir,
    run_file: str,
    docker_env: DockerExecutionContext = PY_DEFAULT_DOCKER_ENV,
    extra_setup_cmds: Iterable[Cmd] = (),
    limits: ExecLimits | None = None,  # Use whatever docker_env has, or default
) -> TestsuiteExecResult:
    """Run pytest on `run_file` inside docker and return the parsed outcome.

    Args:
        files: Project files copied into the sandbox before running.
        run_file: Test target inserted into the ``python -m pytest`` command.
        docker_env: Environment to run in.
        extra_setup_cmds: Commands run after the environment's own setup
            commands and before pytest.
        limits: Resource limits. Defaults to the environment's
            `default_limits`, or `ExecLimits()` when it has none.

    Returns:
        The result built from the pytest command and the ``cat`` of its
        junit XML report.
    """
    xml_name = "test_report_results_from_sb.xml"
    if limits is None:
        limits = docker_env.default_limits or ExecLimits()
    out = run_on_docker(
        docker_env,
        (
            # `run_on_docker` already runs `docker_env.setup_cmds` before these
            # commands, so they must not be repeated here or they run twice.
            *extra_setup_cmds,
            # Remove any stale report so an old file is never mistaken for
            # the output of this run.
            Cmd(f"rm {xml_name}"),
            Cmd(f"python -m pytest {run_file} -s --junitxml={xml_name}"),
            Cmd(f"cat {xml_name}"),
        ),
        files=files,
        interactive_shell_debug=False,
        limits=limits,
    )
    # The last two results are always the pytest run and the `cat` of its report.
    pytest_out, cat_out = out[-2:]
    return convert_pytest_outputs_to_testsuite_exec_result(pytest_out, cat_out)


def _build_docker_image(dockerfile_str, tag, pre_cmds: Iterable[Cmd]):
    """Build a docker image from dockerfile text unless it is already known.

    Args:
        dockerfile_str: Full text of the dockerfile.
        tag: Tag to give the built image.
        pre_cmds: Shell commands run inside the temporary build-context
            directory before building (e.g. to copy files the dockerfile
            needs).

    Returns:
        ``None`` when the build was skipped because the tag is already known,
        otherwise the ``(image, build_logs)`` pair from the docker client.

    Raises:
        RuntimeError: when running in GitHub Actions and `tag` is listed in
            `_GITHUB_ACTION_IMAGE_DISALLOWED`.
        docker.errors.BuildError: re-raised after its build log is logged.
    """
    import subprocess

    if IN_GITHUB_ACTION and tag in _GITHUB_ACTION_IMAGE_DISALLOWED:
        raise RuntimeError(f"Building docker image {tag} is disabled in GitHub Actions")
    client = docker.from_env()

    # NOTE(review): docker usually reports tags as "name:tag" (e.g.
    #  "foo:latest"), so a bare name such as "foo" probably never matches here
    #  and the first call in each process rebuilds. Left unchanged because
    #  skipping more builds could leave stale images in use -- confirm intent.
    built_images = {t for image in client.images.list() for t in image.tags}

    if tag in _BUILT_IMAGES | built_images:
        return None
    logger.info("Building docker image %s...", tag)
    # Create a temporary directory
    with TemporaryDirectory() as tmp_dir_obj:
        tmpdir = Path(tmp_dir_obj)
        # Write Dockerfile string to a file in the temporary directory
        (tmpdir / "Dockerfile").write_text(dockerfile_str, encoding="utf-8")

        for cmd in pre_cmds:
            # The commands are shell strings supplied by the execution
            # context, so they go through the shell. Using `cwd` avoids
            # having to quote the directory in a `cd ... &&` prefix.
            completed = subprocess.run(
                cmd.cmd,
                shell=True,
                cwd=tmpdir,
                check=False,
            )
            if completed.returncode != 0:
                # Keep going as before, but make the failure visible: a
                # missing file otherwise only surfaces as a confusing build
                # error later on.
                logger.warning(
                    "Pre-build command %r exited with status %d",
                    cmd.cmd,
                    completed.returncode,
                )

        # Build the docker image from the Dockerfile in the temporary directory
        try:
            image, build_logs = client.images.build(
                path=str(tmpdir),
                tag=tag,
                rm=True,
            )
        except BuildError as e:
            for chunk in e.build_log:
                if "stream" in chunk:
                    for line in chunk["stream"].splitlines():
                        logger.error(line)
            raise
        _BUILT_IMAGES.add(tag)
        logger.info("docker image %s build complete", tag)
        return (image, build_logs)


##############################################
# NON-DOCKER VERSION ###


# Name of the variable that receives the value of the result expression.
_result_var_name = "my_exec_result_value"


def combine_code_and_result_expression(
    code: str,
    result_expression: str,
) -> str:
    """Append an assignment of `result_expression` to the end of `code`.

    The returned program is `code`, a blank line, a marker comment, and a
    final line that assigns the expression to the result variable.
    """
    pieces = (
        code,
        "\n\n",
        "#### Result Expression ####\n",
        _result_var_name,
        " = ",
        result_expression,
    )
    return "".join(pieces)


def safer_exec(
    code,
    result_expression,
    timeout=10.0,
    user_confirmation=True,
) -> "ExecResult":
    """
    DEPRECATED: switching to docker running instead.
    Runs code in a separate python process with just some more risky globals (like
    file writing) set to None. This is based off how HumanEval ran their code. It is
    not intended to be fully secure.

    Args:
        code: Source code to execute.
        result_expression: Expression evaluated after `code`; its value becomes
            the result's `output`.
        timeout: Time limit passed to the child for the execution. The parent
            waits one extra second for the result to arrive.
        user_confirmation: If true, show the code and ask on stdin before
            running it.

    Raises:
        Exception: if the user does not answer "y" at the confirmation prompt.
        queue.Empty: if the child process never reports a result in time.
    """
    if user_confirmation:
        logger.info("About to exec")
        logger.info("---")
        pretty_print_python_code(
            combine_code_and_result_expression(code, result_expression),
        )
        logger.info("---")
        # Ask the user to confirm
        if input("Are you sure you want to exec this code? (y/n)") != "y":
            msg = "User did not confirm"
            raise Exception(msg)

    result_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=_unsafe_execute,
        args=(result_queue, code, result_expression, timeout),
    )
    process.start()

    try:
        # Read the result before joining: joining a process that still has
        # data buffered for a queue can block.
        result = result_queue.get(timeout=timeout + 1)
        process.join(timeout=1)
    finally:
        # Never leave the child running, including when no result arrived
        # and `get` raised.
        if process.is_alive():
            process.kill()
            process.join(timeout=1)

    # Only show executed code if it failed due to a type or assertion error
    # TODO make configurable?
    if isinstance(result.exception, AssertionError | TypeError):
        pretty_print_python_code(
            combine_code_and_result_expression(code, result_expression),
        )
        logger.info("Result: %s", result)

    return result


def _unsafe_execute(q, code, result_expression: str, timeout):
    """Execute `code` in this process and put an `ExecResult` on `q`.

    Meant to be the target of a child process (see `safer_exec`). Exactly one
    `ExecResult` is put on the queue: a completed one holding the value of
    `result_expression`, or a failed one holding the exception that was raised
    (including a timeout).

    Args:
        q: Queue that receives the `ExecResult`.
        code: Source code to execute.
        result_expression: Expression whose value is reported as the output.
        timeout: Time limit handed to `time_limit` for the execution.
    """
    # Copied from HumanEval. DEPRECATED for running on docker
    with create_tempdir():
        # These system calls are needed when cleaning up tempdir.
        import os
        import shutil

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir

        # Disable functionalities that can make destructive changes to the test.
        reliability_guard()
        check_program = combine_code_and_result_expression(
            code,
            result_expression,
        )

        try:
            # Use the shared constant so the key always matches the variable
            # name that `combine_code_and_result_expression` assigns to.
            exec_globals = {_result_var_name: None}
            # track the start of the execution
            import time

            start_time = time.time()

            with swallow_io(), time_limit(timeout):
                # WARNING
                # This program exists to execute untrusted model-generated code. Although
                # it is highly unlikely that model-generated code will do something overtly
                # malicious in response to this test suite, model-generated code may act
                # destructively due to a lack of model capability or alignment.
                # Users are strongly encouraged to sandbox this evaluation suite so that it
                # does not perform destructive actions on their host or network.
                exec(check_program, exec_globals)
            q.put(
                ExecResult(
                    output=exec_globals[_result_var_name],
                    completed=True,
                    runtime=time.time() - start_time,
                    exception=None,
                ),
            )
        except BaseException as e:
            # Deliberately broad: any failure of the executed code, including a
            # timeout, is reported to the parent as a failed result.
            q.put(
                ExecResult(
                    output=None,
                    completed=False,
                    runtime=time.time() - start_time,
                    exception=e,
                ),
            )
        finally:
            # Needed for cleaning up. Done in `finally` so the functions are
            # put back even if reporting the result fails.
            shutil.rmtree = rmtree
            os.rmdir = rmdir
            os.chdir = chdir


@dataclass
class ExecResult:
    """Result of one in-process execution (see `safer_exec`)."""

    # Deprecated with docker running
    # Value of the result expression; None when execution did not complete.
    output: Any
    # True when the code ran to the end without raising.
    completed: bool
    # Elapsed time measured with time.time() around the execution.
    runtime: float
    # The exception that stopped the execution, if any.
    exception: BaseException | None = None
