import itertools
import logging
import os
from itertools import islice
from unittest.mock import MagicMock

import pytest
from tqdm.auto import tqdm

from synthegrator.code_solver import (
    LmCodeSolverAutoRegressive,
)
from synthegrator.execution_threading import evaluate_all_solutions, solve_all_problems
from synthegrator.memory_fs import ProjectDir
from synthegrator.problem_rendering_insertion_tags import (
    TaggedEditRenderer,
    TaggedEditResponseParser,
)
from synthegrator.sandboxing import run_on_docker
from synthegrator.solution_eval import (
    evaluate_code_problem_execution,
    parse_junit_test_cases,
)
from synthegrator.synthdatasets.defects4j import (
    DEFECTS_4_J_DOCKER_ENV,
    _default_projects_root,
    _get_deepest_path,
    make_discovered_test_case,
    yield_defects4j,
)

# Turn on verbose root logging so library internals show up in test output.
logging.getLogger().setLevel(logging.DEBUG)

# True only when the GITHUB_ACTIONS environment variable is exactly "true";
# the heavy Docker / API tests below are skipped in that environment.
IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"
# Docker-backed tests are currently disabled unconditionally.
SKIP_DOCKER_TESTS = True
skip_docker_reason = "Docker install takes a long time"


def test_read_from_zip():
    """``read_file_from_zip`` should return the text of a Chart_1 Java source file."""
    from synthegrator.synthdatasets.defects4j_data.dataset import read_file_from_zip

    member = "Defects4J_projects_clean/Chart_1/source/org/jfree/chart/renderer/category/AbstractCategoryItemRenderer.java"
    expected_header = (
        "/* ===========================================================\r\n"
        " * JFreeChart : a free chart library for the Java(tm) platform"
    )
    text = read_file_from_zip(member)
    assert text.startswith(expected_header)


def _run_chart_1_suite(*, is_buggy: bool):
    """Run the Defects4J Chart_1 test suite in Docker and return the parsed result.

    The relative path of the edited file is always discovered from the buggy
    checkout (``Chart_1``). The contents mounted as ``solution.java`` come from
    that same buggy checkout when ``is_buggy`` is true, and from
    ``Chart_1_fixed`` otherwise.

    Args:
        is_buggy: Whether to run the suite against the buggy source.

    Returns:
        The value of ``parse_junit_test_cases`` for the final command's stdout.
    """
    project_name = "Chart"
    problem_id = "1"
    projects_root = _default_projects_root()
    start_path = projects_root / f"{project_name}_{problem_id}"
    file_path = str(
        _get_deepest_path(
            start_path,
            file_only=True,
        ).relative_to(start_path),
    )
    assert (
        file_path
        == "source/org/jfree/chart/renderer/category/AbstractCategoryItemRenderer.java"
    )
    assert (start_path / file_path).exists()
    discovered_suite = make_discovered_test_case(
        project_name=project_name,
        problem_id=problem_id,
        is_buggy=is_buggy,
        file_path=file_path,
    )
    if is_buggy:
        solution_dir = start_path
    else:
        solution_dir = projects_root / f"{project_name}_{problem_id}_fixed"
    cmd_outputs = run_on_docker(
        docker_context=DEFECTS_4_J_DOCKER_ENV,
        cmds=discovered_suite.cmds,
        interactive_shell_debug=False,
        files=ProjectDir.construct_with_one_file(
            "solution.java",
            _get_deepest_path(solution_dir, file_only=True).read_text(),
        ),
    )
    # Only the stdout of the last command is parsed as JUnit output.
    exec_result = cmd_outputs[-1].stdout.decode()
    return parse_junit_test_cases(exec_result, MagicMock())


@pytest.mark.skipif(SKIP_DOCKER_TESTS, reason=skip_docker_reason)
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Downloads large docker images")
def test_can_run_work_passing():
    """The fixed Chart_1 source should pass every test in its suite."""
    test_result = _run_chart_1_suite(is_buggy=False)
    print(test_result)
    assert test_result.all_passed()


@pytest.mark.skipif(SKIP_DOCKER_TESTS, reason=skip_docker_reason)
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Downloads large docker images")
def test_can_run_work_buggy():
    """The buggy Chart_1 source should fail at least one test in its suite."""
    test_result = _run_chart_1_suite(is_buggy=True)
    assert not test_result.all_passed()


def test_read_all_problems():
    """``yield_defects4j`` should produce exactly 119 problems."""
    problem_count = sum(1 for _ in yield_defects4j())
    assert problem_count == 119


@pytest.mark.skipif(SKIP_DOCKER_TESTS, reason=skip_docker_reason)
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Downloads large docker images")
def test_new_sample_eval():
    """Execute the known solution of every 20th Defects4J problem.

    This is a smoke test: beyond checking that each sampled problem has exactly
    one known solution, the evaluation outcome is printed rather than asserted.
    """
    stride = 20
    all_problems = list(yield_defects4j())
    assert len(all_problems) == 119
    # Slice the materialised list (instead of using islice) so the progress bar is
    # sized from the real sample: indices 0, 20, ..., 100 is 6 problems, whereas
    # ``119 // 20`` gives only 5.
    sampled_problems = all_problems[::stride]
    for example_defects4j_prob in tqdm(sampled_problems, total=len(sampled_problems)):
        print("Problem id:", example_defects4j_prob.problem_id)
        assert len(example_defects4j_prob.known_solutions) == 1
        result = evaluate_code_problem_execution(
            example_defects4j_prob,
            example_defects4j_prob.known_solutions[0],
            try_cache=False,
            do_not_execute_if_syntax_fail=False,
        )
        print("All passed:", result.all_passed())


@pytest.mark.skipif(SKIP_DOCKER_TESTS, reason=skip_docker_reason)
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Downloads large docker images")
def test_prompt():
    """End-to-end smoke test: solve the first Defects4J problem with GPT-4 and evaluate it.

    The test renders a tagged-edit prompt, solves the problem once directly and
    once through ``solve_all_problems``, then evaluates every solve and prints
    the results.
    """
    from lmwrapper.openai_wrapper import OpenAiModelNames, get_open_ai_lm

    first_problem = next(yield_defects4j())
    print("Problem id:", first_problem.problem_id)
    prompt_renderer = TaggedEditRenderer()
    model = get_open_ai_lm(OpenAiModelNames.gpt_4)
    # Render once as a smoke check; the rendered prompt itself is discarded.
    prompt_renderer(first_problem, model)
    solver = LmCodeSolverAutoRegressive(
        model,
        prompt_renderer=prompt_renderer,
        response_parser=TaggedEditResponseParser(),
        include_lm_response=False,
    )
    problems = list(islice(yield_defects4j(), 1))
    direct_solution = solver.solve(problems[0])
    assert len(direct_solution.solve_steps) == 1

    thread_goal = 3
    # Fall back to a single worker when the solver cannot run concurrently.
    worker_count = thread_goal if solver.allows_multithreading else 1
    solves = list(
        solve_all_problems(
            solver=solver,
            problems=problems,
            max_threads=worker_count,
            cuda_clear_freq=False,
        ),
    )
    for solve in solves[:10]:
        print("Solve value:")
        print(solve.solve_steps[0].value)
    evals = list(evaluate_all_solutions(solves, max_threads=thread_goal))
    print(evals)


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No api")
def test_prompt2():
    """Check the tagged-edit prompt suffix, then solve the first Defects4J problem.

    This test calls the OpenAI-backed ``gpt_3_5_turbo_instruct`` model through
    ``get_open_ai_lm``.
    """
    from lmwrapper.openai_wrapper import OpenAiModelNames, get_open_ai_lm

    first_problem = next(yield_defects4j())
    print("Problem id:", first_problem.problem_id)
    prompt_renderer = TaggedEditRenderer()
    # TODO mocked version
    model = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo_instruct)
    rendered = prompt_renderer(first_problem, model)
    prompt_text = rendered.prompt.text
    print(prompt_text)
    assert prompt_text.endswith("Answer:\n" '<edit_solve id="0_1xtCqX">')

    solver = LmCodeSolverAutoRegressive(
        model,
        prompt_renderer=prompt_renderer,
        response_parser=TaggedEditResponseParser(),
        include_lm_response=True,
    )
    problems = list(islice(yield_defects4j(), 1))
    direct_solution = solver.solve(problems[0])
    assert direct_solution.lang_spec.get_lang_md_name() == "java"
    # Populated because include_lm_response=True above.
    prediction = direct_solution.lm_prediction
    print("Prompt")
    print(prediction.prompt.text)
    print("stop", prediction.prompt.stop)
    print("Completion text")
    print(prediction.completion_text)
    assert len(direct_solution.solve_steps) == 1

    solves = list(
        solve_all_problems(
            solver=solver,
            problems=problems,
            max_threads=1,
            cuda_clear_freq=False,
        ),
    )
    for solve in solves[:10]:
        print("Solve value:")
        print(solve.solve_steps[0].value)


@pytest.mark.skipif(SKIP_DOCKER_TESTS, reason=skip_docker_reason)
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Downloads large docker images")
def test_defect_apply():
    """The first problem's known solution should apply and pass its test suite."""
    prob = next(yield_defects4j())
    solution = prob.known_solutions[0]
    files = solution.apply()
    print(files["solution.java"].content)
    # Evaluate the same solution object whose applied files were printed above.
    result = evaluate_code_problem_execution(
        prob,
        solution,
        try_cache=False,
        do_not_execute_if_syntax_fail=True,
    )
    # Print before asserting so the outcome is visible in the log on failure too.
    print("All passed:", result.all_passed())
    assert result.all_passed()


if __name__ == "__main__":
    # Allow `python <this file>`: run this module's tests via pytest and
    # propagate pytest's exit code to the shell.
    raise SystemExit(pytest.main([__file__]))
