import dataclasses
import logging
logger = logging.getLogger(__name__)
import os
import random
from pathlib import Path
from random import randint

import pandas as pd
from tree_sitter import Node

from synthegrator.code_problem_builders import make_simple_line_edit_problem
from synthegrator.code_problems import (
    CodeProblem,
    DiscoveredTestsuite,
)
from synthegrator.lang_specs.lang_spec import (
    PythonFunctionParser,
)
from synthegrator.lang_specs.lang_spec_python import PythonLangSpec
from synthegrator.sandboxing import DYPYBENCH_DOCKER_ENV, Cmd
from synthegrator.util import (
    IteratorWithLength,
    MultilineString,
)

# Absolute directory containing this module; the bundled
# "dypybench_data/..." files used below are resolved relative to it.
cur_path = Path(__file__).parent.absolute()


def _default_projects_root() -> Path:
    ev = os.environ.get("DYPYBENCH_PROJECTS_ROOT", None)
    directory = None
    if ev is not None:
        directory = Path(ev).resolve()
    elif Path("~/dypybench_projects").expanduser().exists():
        directory = Path("~/dypybench_projects").expanduser().resolve()
    if directory is None or not directory.exists():
        if ev is not None:
            print(f"DYPYBENCH_PROJECTS_ROOT {ev} is not a valid directory")
        else:
            print(
                "DYPYBENCH_PROJECTS_ROOT is not set and ~/dypybench_projects does not exist. See "
                "synthegrator/synthdatasets/dypybench_setup.md for setup info on getting"
                "the projects to the source.",
            )
        msg = "Could not find the dypybench projects root"
        raise FileNotFoundError(msg)
    return directory


# Number of DyPyBench projects covered; both per-project tables below are
# indexed by ``int(project_id)`` for ids 0 .. _NUM_DYPYBENCH_PROJECTS - 1.
_NUM_DYPYBENCH_PROJECTS = 51

# Precompute a test suite for each project: run the project's tests into
# /sandbox, then print the produced result.xml so it can be collected.
_suite_problems = [
    DiscoveredTestsuite(
        cmds=[
            Cmd(
                "cd /DyPyBench && python3 dypybench.py --testtemp"
                f" {project_id} --target /sandbox",
            ),
            Cmd(f"cat /sandbox/project{project_id}/result.xml"),
        ],
    )
    for project_id in range(_NUM_DYPYBENCH_PROJECTS)
]

# Per-project docker env: the shared DyPyBench env with one extra
# ``before_file_copy_cmds`` step that copies the project into /sandbox.
# (The copy step lives here, not in the test-suite commands above.)
_project_docker_envs = [
    dataclasses.replace(
        DYPYBENCH_DOCKER_ENV,
        before_file_copy_cmds=(
            *DYPYBENCH_DOCKER_ENV.before_file_copy_cmds,
            Cmd(
                f"cd /DyPyBench && python3 dypybench.py --copy {project_id} --target"
                " /sandbox",
            ),
        ),
    )
    for project_id in range(_NUM_DYPYBENCH_PROJECTS)
]


def _sample_line(
    random_line_seed: int,
    path: Path,
    function: MultilineString,
    function_parser: PythonFunctionParser,
) -> int:
    sampled = 0
    path_seed = int.from_bytes(str(path).encode("utf-8"), "big") % 2**30
    while sampled <= 100:
        old_state = random.getstate()
        use_seed = random_line_seed + sampled + path_seed
        random.seed(use_seed)
        line_number = randint(
            function_parser.body_span_begin,
            function_parser.body_span_end,
        )
        random.setstate(old_state)
        line = function[line_number]
        if (
            "try:" not in line
            and line.strip() != ""
            and not line.lstrip().startswith("#")
        ):
            return line_number
        sampled += 1

    return None


def produce_dy_py_bench_problem(
    task_id: str,
    path: str,
    identifier: str,
    func_begin: int,
    func_end: int,
    function: str,
    project_id: str,
    line_to_replace: int | None = None,
    strip_docstring: bool = False,
    projects_root: str | Path | None = None,
    random_line_seed: int = 42,
) -> CodeProblem | None:
    """Build a single-line completion problem for one DyPyBench function.

    One line of the function becomes the edit target. Everything in the
    source file outside the function is hidden, and the original line is
    recorded as the known solution.

    Args:
        task_id: Used as the problem id and in error/log messages.
        path: Source file path relative to ``<projects_root>/project<id>``.
        identifier: Function identifier (not used by this function).
        func_begin: 1-indexed first line of the function in the source file.
        func_end: 1-indexed last line of the function in the source file.
        function: Source text of the function.
        project_id: DyPyBench project number; selects the test suite and
            docker env.
        line_to_replace: If given, used as-is as the line index into the
            source file (same indexing as the 0-indexed ``func_begin``).
            Otherwise a body line is sampled with ``_sample_line``.
        strip_docstring: Also hide the function's docstring lines.
        projects_root: Root of the project checkouts; defaults to
            ``_default_projects_root()``.
        random_line_seed: Base seed forwarded to ``_sample_line``.

    Returns:
        The problem, or ``None`` when ``strip_docstring`` is set and the
        function has no docstring node, or nothing follows the docstring.

    Raises:
        FileNotFoundError: If the projects root or the source file is missing.
        ValueError: If no acceptable line could be sampled.
    """
    projects_root = projects_root or _default_projects_root()
    func_begin -= 1  # convert to 0-indexed
    func_end -= 1  # convert to 0-indexed

    function = MultilineString(function)
    function_parser = PythonFunctionParser(str(function))

    source_path = Path(projects_root) / f"project{project_id}" / path
    if not source_path.exists():
        msg = f"Could not find the source file for {source_path}"
        raise FileNotFoundError(msg)

    source_file = MultilineString(source_path.read_text(encoding="utf-8"))

    if line_to_replace is None:
        function_line = _sample_line(
            random_line_seed,
            source_path,
            function,
            function_parser,
        )
        # Compare against None: a sampled index of 0 is a valid result.
        if function_line is None:
            msg = f"Could not sample valid line for {task_id}"
            raise ValueError(msg)
        source_line_to_replace = func_begin + function_line
    else:
        source_line_to_replace = line_to_replace

    # The source file is never rewritten; the docstring (if stripped) is only
    # hidden via invisible_spans. So the edit-line index needs no docstring
    # offset.
    known_solution = source_file[source_line_to_replace]

    # Hide everything before and after the function.
    invisible_spans = [
        (0, func_begin),
        (func_end + 1, len(source_file)),
    ]

    if strip_docstring:
        docstring_node = function_parser.docstring_node
        if not isinstance(docstring_node, Node):
            logger.error("No docstring node? %s", str(task_id))
            return None
        # Node rows are relative to the parsed function text; shift them to
        # source-file lines.
        ds_start = func_begin + docstring_node.start_point[0]
        ds_end = func_begin + docstring_node.end_point[0] + 1
        if ds_end == invisible_spans[1][0]:
            # The docstring runs to the function's last line: nothing would
            # remain visible after hiding it.
            logger.error("Is the function empty? %s", str(task_id))
            return None
        # NOTE(review): the edit line is not checked against this span --
        # confirm body_span excludes the docstring.
        invisible_spans.insert(1, (ds_start, ds_end))

    return make_simple_line_edit_problem(
        code=str(source_file),
        invisible_spans=invisible_spans,
        test_cases=[_suite_problems[int(project_id)]],
        edit_lines_span=(source_line_to_replace, source_line_to_replace + 1),
        current_text_visible=False,
        target_path=f"project{project_id}/" + path,
        known_solutions=[known_solution],
        lang_spec=PythonLangSpec(),
        dataset_name="dypybench_line_completion",
        problem_id=task_id,
        docker_env=_project_docker_envs[int(project_id)],
        override_instructions="",
        cap_to_max_same_num_lines=True,
    )


def _gather_dypybench_funcs_df_from_json(
    json_path: str | Path = cur_path / "dypybench_data/functions_dataset.jsonl",
    selected_ids_path: str | Path = cur_path / "dypybench_data/selected_ids.parquet",
) -> pd.DataFrame:
    """Load the DyPyBench function table and keep only usable rows.

    Args:
        json_path: JSON-lines file with one row per function.
        selected_ids_path: Optional parquet file. If it exists, only rows
            whose index appears in its index are kept.

    Returns:
        The filtered DataFrame. Rows are kept when ``loc_len > 3``,
        ``docstring_lines >= 1``, ``coverage >= 100`` and
        ``filecoverage > 0``, minus a hard-coded set of known-bad row ids.
    """
    df = pd.read_json(json_path, lines=True)
    if Path(selected_ids_path).exists():
        selected_ids = pd.read_parquet(selected_ids_path)
        df = df[df.index.isin(selected_ids.index)]
    df = df[
        (df.loc_len > 3)
        & (df.docstring_lines >= 1)
        & (df.coverage >= 100)
        & (df.filecoverage > 0)
    ]

    # Row ids excluded by hand.
    # TODO: Remove bad functions from dataset itself rather than hack here
    bad_functions = {
        109,
        7642,
        8226,
        8232,
        8239,
        8247,
        8255,
        8265,
        8306,
        8345,
        8384,
        8396,
        8457,
        8494,
        8507,
        8518,
        9068,
        9137,
        10776,
        10785,
        10789,
        10790,
        11438,
        20972,
        20973,
        20974,
        25916,
        26424,
        27592,
        28123,
    }

    return df[~df.index.isin(bad_functions)]


if __name__ == "__main__":
    # Regenerate the cached function subset that _get_dypy_df() reads back.
    subset_df = _gather_dypybench_funcs_df_from_json()
    output_path = cur_path / "dypybench_data/func_dataset_subset.parquet"
    subset_df.to_parquet(output_path)
    print(subset_df)


def _get_dypy_df() -> pd.DataFrame:
    """Load the precomputed function subset from the bundled parquet file.

    The file is written by this module's ``__main__`` block.
    """
    subset_path = cur_path / "dypybench_data/func_dataset_subset.parquet"
    return pd.read_parquet(subset_path)


def yield_dypybench(
    df: pd.DataFrame | None = None,
    strip_docstring: bool = False,
    max_problems: int | None = None,
    shuffle_seed: int = 42,
) -> IteratorWithLength[CodeProblem]:
    """Lazily iterate DyPyBench line-completion problems.

    Args:
        df: Function table to draw from; defaults to ``_get_dypy_df()``.
        strip_docstring: Forwarded to ``produce_dy_py_bench_problem``.
        max_problems: If truthy and smaller than the dataset, keep only this
            many rows, taken after a seeded shuffle. ``None`` or ``0`` means
            no limit.
        shuffle_seed: Seed for that shuffle (default 42, the previously
            hard-coded value).

    Returns:
        An iterator of problems. Rows for which
        ``produce_dy_py_bench_problem`` returns ``None`` are skipped, so the
        reported length is an upper bound on the number actually yielded.
    """
    # Imported here so that importing this module does not require `datasets`.
    from datasets import Dataset

    if df is None:
        df = _get_dypy_df()

    # The preserved pandas index lands in "__index_level_0__"; expose it as
    # "index" for use as the task id.
    dataset = Dataset.from_pandas(df, preserve_index=True).rename_column(
        "__index_level_0__",
        "index",
    )

    if max_problems and len(dataset) > max_problems:
        # Seeded shuffle, so the same subset is chosen on every run.
        dataset = dataset.shuffle(seed=shuffle_seed).select(range(max_problems))

    def inner():
        for row in dataset:
            # A missing or falsy value means "sample a line".
            line_number = row.get("line_number_to_replace")
            problem = produce_dy_py_bench_problem(
                task_id=str(row["index"]),
                path=str(row["path"]),
                identifier=str(row["identifier"]),
                func_begin=int(row["func_begin"]),
                func_end=int(row["func_end"]),
                function=str(row["function"]),
                project_id=str(row["project"]),
                line_to_replace=int(line_number) if line_number else None,
                strip_docstring=strip_docstring,
            )

            if problem:
                yield problem

    return IteratorWithLength(inner(), len(dataset))
