import os
from unittest.mock import MagicMock

import numpy as np
import pytest
from lmwrapper.abstract_predictor import LmPredictor
from lmwrapper.openai_wrapper import (
    OpenAiModelNames,
    get_open_ai_lm,
)
from lmwrapper.structs import LmPrediction, LmPrompt

from synthegrator.code_problem_builders import (
    make_instruction_only_problem,
    make_simple_line_edit_problem,
    make_simple_method_completion_problem,
)
from synthegrator.code_problems import (
    CodeProblem,
    CodeSolution,
    TestCaseMethodCallIsEq,
)
from synthegrator.code_solver import (
    BaseCodeSolver,
    LmCodeSolverAutoRegressive,
    LmUncertaintyHook,
    LmUncertaintyProbAverage,
    LmUncertaintyProbTotal,
)
from synthegrator.few_shotting import FewShotConfig, FewShotLibrary
from synthegrator.lang_specs.lang_spec_python import PythonLangSpec
from synthegrator.lm_few_shotting_tools import FewShotRendererWrapper
from synthegrator.problem_rendering import (
    PromptRendererSingleEditGeneric,
)
from synthegrator.problem_rendering_insertion_tags import (
    TaggedEditRenderer,
    TaggedEditResponseParser,
)
from synthegrator.response_parser import (
    ResponseParserSingleEdit,
    _get_new_content_start,
)
from synthegrator.solution_eval import apply_solution, evaluate_code_problem_execution
from synthegrator.tests.non_dataset_tests.test_problem_rendering import factor_problems
from synthegrator.transformation_spec import (
    StsInsert,
)
from synthegrator.uncertainty_modeling import (
    DelayedUncertaintyEstimate,
    ProbabilityIsCorrect,
)
from synthegrator.util import estimate_best_overlap_span

# ``factor_problems`` is a pytest fixture that is imported only so pytest can
# inject it by name; touching it here keeps unused-import linters quiet.
assert factor_problems, "Unused import detectors struggle with fixtures"

# True when running under GitHub Actions, where tests that call the real
# OpenAI API are skipped (no API key is available there).
IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"

# Solver classes that the ``lmsolver``-parametrized tests are run against.
LM_SOLVERS = [LmCodeSolverAutoRegressive]


def run_with_oai(solver: BaseCodeSolver):
    """Solve a small "double the input" problem with ``solver`` and check it.

    Builds a method-completion problem for ``foo(x)`` in ``solution.py``,
    checks the rendered starting file, solves it, and asserts that:

    * there is exactly one solve step,
    * that step is a ``return x * 2`` line (with or without spaces around
      ``*`` and with or without a trailing newline), and
    * the resulting code passes both visible test cases.

    Args:
        solver: Solver under test. Callers in this file pass an
            ``LmCodeSolverAutoRegressive`` wrapping a real or mocked model.
    """
    problem = make_simple_method_completion_problem(
        signature="def foo(x) -> int:",
        prompt="Write a function that doubles its input.",
        test_cases=[
            TestCaseMethodCallIsEq(
                test_input=(2,),
                output=4,
                method_name="foo",
                is_hidden_test=False,
                fail_message="",
            ),
            TestCaseMethodCallIsEq(
                test_input=(3,),
                output=6,
                method_name="foo",
                is_hidden_test=False,
                fail_message="",
            ),
        ],
        target_path="solution.py",
    )
    # Check this before dereferencing it below. A missing working directory
    # then fails on this assertion instead of an AttributeError.
    assert problem.working_directory is not None
    content = problem.working_directory.files["solution.py"].content.decode("utf-8")
    print(content)
    # The prompt is rendered as a docstring directly under the signature.
    assert (
        content == "def foo(x) -> int:\n"
        '    """\n'
        "    Write a function that doubles its input.\n"
        '    """\n'
    )
    assert len(problem.transformation_spec.statements) == 1
    result = solver.solve(problem)
    # Print before asserting so an unexpected step count still shows what
    # the solver produced.
    print(result.solve_steps)
    assert len(result.solve_steps) == 1
    assert result.solve_steps[0].value in {
        "    return x * 2",
        "    return x * 2\n",
        "    return x*2",
        "    return x*2\n",
    }
    evaluations = evaluate_code_problem_execution(
        problem,
        result,
        user_prompted_guard=False,
    )
    assert evaluations.all_passed()


def test_basic_oai_mocked():
    """Mocked gpt-3.5-turbo-instruct returning an already-indented body line."""
    model = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo_instruct)
    canned = LmPrediction("    return x * 2", None, None)
    model.predict = MagicMock(return_value=canned)
    solver = LmCodeSolverAutoRegressive(model)
    run_with_oai(solver)


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No API key in GitHub Actions")
def test_basic_oai():
    """Live gpt-3.5-turbo-instruct call on the doubling problem."""
    solver = LmCodeSolverAutoRegressive(
        get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo_instruct),
    )
    run_with_oai(solver)


def test_basic_oai_mocked_chat():
    """Mocked chat model whose reply is missing the body's indentation."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    chat_lm.predict = MagicMock(
        return_value=LmPrediction("return x * 2", None, None),
    )
    solver = LmCodeSolverAutoRegressive(chat_lm)
    run_with_oai(solver)


def test_basic_oai_mocked_chat_with_full_method():
    """Mocked chat model that repeats the whole method, signature included."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    full_method = (
        "def foo(x) -> int:\n"
        '    """\n'
        "    Write a function that doubles its input.\n"
        '    """\n'
        "    return x * 2\n"
    )
    chat_lm.predict = MagicMock(return_value=LmPrediction(full_method, None, None))
    run_with_oai(LmCodeSolverAutoRegressive(chat_lm))


@pytest.mark.parametrize("with_lang_name", [True, False])
def test_basic_oai_mocked_chat_with_block(with_lang_name: bool):
    """Mocked chat reply: chatter, then the full method inside a code fence.

    The opening fence is tagged ``python`` when ``with_lang_name`` is true
    and left bare otherwise; both forms are expected to be handled.
    """
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    fence_tag = "python" if with_lang_name else ""
    reply = (
        "Sure, here's the solution!\n"
        f"```{fence_tag}\n"
        "def foo(x) -> int:\n"
        '    """\n'
        "    Write a function that doubles its input.\n"
        '    """\n'
        "    return x * 2\n"
        "```\n"
    )
    chat_lm.predict = MagicMock(return_value=LmPrediction(reply, None, None))
    run_with_oai(LmCodeSolverAutoRegressive(chat_lm))


def test_basic_oai_mocked_chat_with_block_no_docstring():
    """Mocked chat reply whose fenced method leaves out the docstring."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    fenced_method = (
        "Sure, here's the solution!\n"
        "```python\n"
        "def foo(x) -> int:\n"
        "    return x * 2\n"
        "```\n"
    )
    chat_lm.predict = MagicMock(return_value=LmPrediction(fenced_method, None, None))
    run_with_oai(LmCodeSolverAutoRegressive(chat_lm))


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No API key in GitHub Actions")
def test_basic_oai_chat():
    """Live gpt-3.5-turbo (chat) call on the doubling problem."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    solver = LmCodeSolverAutoRegressive(chat_lm)
    run_with_oai(solver)


def run_with_oai_few_shot(lm, factor_problems):
    """Solve ``factor_problems[0]`` with one few-shot example and check it.

    The few-shot library is built from the first two problems and passed as
    the renderer's default config. The examples are fixed across solves. The
    first solve step must be ``    return x * 2`` (trailing newline optional).
    """
    shot_config = FewShotConfig(
        library=FewShotLibrary(factor_problems[:2]),
        num_examples=1,
        change_examples_between_solves=False,
    )
    renderer = PromptRendererSingleEditGeneric(few_shot_config_default=shot_config)
    solver = LmCodeSolverAutoRegressive(lm, prompt_renderer=renderer)
    target = factor_problems[0]
    steps = solver.solve(target).solve_steps
    assert steps[0].value in {"    return x * 2", "    return x * 2\n"}


def run_with_oai_few_shot_wrapped(lm, factor_problems):
    """Solve ``factor_problems[0]`` with few-shotting done by a renderer wrapper.

    A plain ``PromptRendererSingleEditGeneric`` is wrapped in a
    ``FewShotRendererWrapper``. The same ``ResponseParserSingleEdit``
    instance is given to both the wrapper and the solver. The first solve
    step must be ``    return x * 2`` (trailing newline optional).
    """
    base_renderer = PromptRendererSingleEditGeneric()
    single_edit_parser = ResponseParserSingleEdit()
    shots = FewShotConfig(
        library=FewShotLibrary(factor_problems[:2]),
        num_examples=1,
        change_examples_between_solves=False,
    )
    wrapped_renderer = FewShotRendererWrapper(
        base_renderer,
        single_edit_parser,
        few_shot_config=shots,
    )
    solver = LmCodeSolverAutoRegressive(
        lm,
        prompt_renderer=wrapped_renderer,
        response_parser=single_edit_parser,
    )
    steps = solver.solve(factor_problems[0]).solve_steps
    assert steps[0].value in {"    return x * 2", "    return x * 2\n"}


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No API key in GitHub Actions")
def test_chat_few_shot(factor_problems):
    """Live gpt-3.5-turbo few-shot solve using the renderer's default config."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    run_with_oai_few_shot(chat_lm, factor_problems)


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No API key in GitHub Actions")
def test_chat_few_shot_wrapped(factor_problems):
    """Live gpt-3.5-turbo few-shot solve using ``FewShotRendererWrapper``."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    run_with_oai_few_shot_wrapped(chat_lm, factor_problems)


def test_chat_few_shot_mocked(factor_problems):
    """Mocked chat reply with the body line inside a bare code fence."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    fenced_reply = LmPrediction("```\n    return x * 2\n```\n", None, None)
    chat_lm.predict = MagicMock(return_value=fenced_reply)
    run_with_oai_few_shot(chat_lm, factor_problems)


def test_chat_few_shot_mocked_wrapped(factor_problems):
    """Mocked fenced chat reply, with few-shotting done by the renderer wrapper."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    fenced_reply = LmPrediction("```\n    return x * 2\n```\n", None, None)
    chat_lm.predict = MagicMock(return_value=fenced_reply)
    run_with_oai_few_shot_wrapped(chat_lm, factor_problems)


@pytest.mark.parametrize(
    ("first_str", "new_content", "expected"),
    [
        # Every case expects the length of the longest suffix of
        # ``first_str`` that is also a prefix of ``new_content``.
        ("hello", "hello world", 5),
        ("abc", "abcd", 3),
        ("abc\n", "\n", 1),
        ("fooabc", "abcd", 3),
        ("", "abcd", 0),  # empty first string: no overlap
        ("abc", "abcabc", 3),
        ("abcdecd", "cdecdfgh", len("cdecd")),
        ("fooabc", "", 0),  # empty new content: no overlap
        ("aaacCC", "CCCb", len("CC")),
        # TODO: not yet supported; re-enable once the function handles it.
        # (
        #        "def foo(x) -> int:\n    '''\n    Write a func\n    '''\n",
        #        "def foo(x) -> int:\n    return x * 2\n",
        #        len("def foo() -> int:\n    ")
        # ),
    ],
)
def test_get_new_content_start(first_str, new_content, expected):
    """Check ``_get_new_content_start`` on small overlap cases."""
    start = _get_new_content_start(first_str, new_content)
    assert start == expected


@pytest.fixture
def example_edit_problem() -> CodeProblem:
    """Single-line edit problem: repair ``val * 10`` so ``foo`` doubles a list.

    The editable span ``(6, 7)`` is the ``out.append(val * 10)`` line. The
    span is created with ``current_text_visible=True``, and the one known
    solution multiplies by 2 instead.
    """
    buggy_code = (
        "def foo(vals: list[int]) -> list[int]:\n"
        '    """\n'
        "    Write a function that doubles every value in a list\n"
        '    """\n'
        "    out = []\n"
        "    for val in vals:\n"
        "        out.append(val * 10)\n"
        "    return out\n"
    )
    doubles_list = TestCaseMethodCallIsEq(
        test_input=([1, 2, 3],),
        output=[2, 4, 6],
        method_name="foo",
        is_hidden_test=False,
        fail_message="",
    )
    return make_simple_line_edit_problem(
        code=buggy_code,
        edit_lines_span=(6, 7),
        test_cases=[doubles_list],
        current_text_visible=True,
        target_path="solution.py",
        known_solutions=["        out.append(val * 2)\n"],
    )


@pytest.fixture
def example_edit_problem_before_cur_after_invisible() -> CodeProblem:
    """Line-edit variant of ``example_edit_problem`` with hidden surroundings.

    The function is wrapped by an ``import math`` header and a trailing
    ``print`` call. Those regions are marked invisible via spans ``(0, 2)``
    and ``(10, 12)``. The edited line is span ``(8, 9)``, created with
    ``current_text_visible=False``.
    """
    wrapped_code = (
        "import math\n"
        "\n"
        "def foo(vals: list[int]) -> list[int]:\n"
        '    """\n'
        "    Write a function that doubles every value in a list\n"
        '    """\n'
        "    out = []\n"
        "    for val in vals:\n"
        "        out.append(val * 10)\n"
        "    return out\n"
        "\n"
        "print('hello world')"
    )
    doubles_list = TestCaseMethodCallIsEq(
        test_input=([1, 2, 3],),
        output=[2, 4, 6],
        method_name="foo",
        is_hidden_test=False,
        fail_message="",
    )
    return make_simple_line_edit_problem(
        code=wrapped_code,
        edit_lines_span=(8, 9),
        invisible_spans=[(0, 2), (10, 12)],
        test_cases=[doubles_list],
        current_text_visible=False,
        target_path="solution.py",
        known_solutions=["        out.append(val * 2)\n"],
    )


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No API key in GitHub Actions")
def test_edit_solve(example_edit_problem):
    """Live gpt-3.5-turbo solve of the line-edit fixture; the fix must pass."""
    chat_lm = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    result = LmCodeSolverAutoRegressive(chat_lm).solve(example_edit_problem)
    assert len(result.solve_steps) == 1
    outcome = evaluate_code_problem_execution(
        example_edit_problem,
        result,
        user_prompted_guard=False,
    )
    assert outcome.all_passed()


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
@pytest.mark.parametrize(
    "mocked_return",
    [
        "        out.append(val * 2)",
        "        out.append(val * 2)\n",
        "out.append(val * 2)\n",
        "    out.append(val * 2)\n",
        "        out.append(val * 2)\n    return out\n",
        "        out.append(val * 2)\n    return out",
        "        out.append(val * 2) # comment\n    return out # return the output\n",
        "        out.append(val * 2) # comment\n    return out # return the output",
    ],
)
def test_edit_solve_mocked(example_edit_problem, lmsolver, mocked_return):
    """Differently formatted mocked replies all produce the same one-line fix.

    The edited line is re-indented as needed and keeps an inline
    ``# comment`` when the reply has one. Any further lines in the reply are
    dropped, and the file's own ``return out`` line is left untouched.
    """
    # TODO: work even when not forced at 1 line
    comment_suffix = " # comment" if " # comment" in mocked_return else ""
    expected_file = (
        "def foo(vals: list[int]) -> list[int]:\n"
        '    """\n'
        "    Write a function that doubles every value in a list\n"
        '    """\n'
        "    out = []\n"
        "    for val in vals:\n"
        f"        out.append(val * 2){comment_suffix}\n"
        "    return out\n"
    )
    model = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    model.predict = MagicMock(return_value=LmPrediction(mocked_return, None, None))
    result = lmsolver(model).solve(example_edit_problem)
    assert len(result.solve_steps) == 1
    assert apply_solution(result).get_only_file().content_str == expected_file
    outcome = evaluate_code_problem_execution(
        example_edit_problem,
        result,
        user_prompted_guard=False,
    )
    assert outcome.all_passed()


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
@pytest.mark.parametrize(
    "mocked_return",
    [
        "        out.append(val * 2)",
    ],
)
def test_edit_solve_invisible_mocked(
    example_edit_problem_before_cur_after_invisible,
    lmsolver,
    mocked_return,
):
    """A mocked fix applies cleanly even though the surrounding lines are hidden.

    The invisible header (``import math``) and trailer (``print(...)``) must
    survive unchanged in the applied file.
    """
    # TODO: work even when not forced at 1 line
    problem = example_edit_problem_before_cur_after_invisible
    comment_suffix = " # comment" if " # comment" in mocked_return else ""
    expected_file = (
        "import math\n"
        "\n"
        "def foo(vals: list[int]) -> list[int]:\n"
        '    """\n'
        "    Write a function that doubles every value in a list\n"
        '    """\n'
        "    out = []\n"
        "    for val in vals:\n"
        f"        out.append(val * 2){comment_suffix}\n"
        "    return out\n"
        "\n"
        "print('hello world')"
    )
    model = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo)
    model.predict = MagicMock(return_value=LmPrediction(mocked_return, None, None))
    result = lmsolver(model).solve(problem)
    assert len(result.solve_steps) == 1
    assert apply_solution(result).get_only_file().content_str == expected_file
    outcome = evaluate_code_problem_execution(
        problem,
        result,
        user_prompted_guard=False,
    )
    assert outcome.all_passed()


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
@pytest.mark.parametrize(
    "model_name",
    [OpenAiModelNames.gpt_3_5_turbo, OpenAiModelNames.gpt_3_5_turbo_instruct],
)
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No API key in GitHub Actions")
def test_instructions_only(example_edit_problem, lmsolver, model_name):
    """Live solve of an instructions-only problem (no starter code).

    First checks that the known solution passes the problem's tests. Then it
    asks the model to write ``double`` from the instructions alone and checks
    that the generated code passes too.

    NOTE(review): ``example_edit_problem`` is requested but never used; it is
    kept so the test's fixture signature is unchanged.
    """
    problem = make_instruction_only_problem(
        instructions=(
            "Write a function that doubles its input. "
            "The function should be called `double` and take a single int argument `x`."
        ),
        # presumably lets generation continue past the first block end so the
        # whole function can be written -- confirm against StsInsert docs.
        insert_node=StsInsert(stop_at_block_end=False),
        test_cases=[
            TestCaseMethodCallIsEq(
                is_hidden_test=False,
                fail_message="",
                method_name="double",
                test_input=(x,),
                output=x * 2,
            )
            for x in range(3)
        ],
        known_solutions=["def double(x: int) -> int:\n    return x * 2"],
    )
    # Check the problem setup itself: the reference solution must pass.
    known_eval = evaluate_code_problem_execution(
        problem,
        problem.known_solutions[0],
    )
    assert known_eval.all_passed()
    solver = lmsolver(get_open_ai_lm(model_name))
    solve = solver.solve(problem)
    print("Solution:")
    print(apply_solution(solve).get_only_file().content_str)
    # Named ``solve_eval`` (not ``eval``) to avoid shadowing the builtin.
    solve_eval = evaluate_code_problem_execution(
        problem,
        solve,
    )
    assert solve_eval.all_passed()


@pytest.mark.skip(reason="Uncertainty hooks needs to be refactored")
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No API key in GitHub Actions")
def test_uncertainty_hook(factor_problems):
    """Compare ``LmUncertaintyProbAverage`` with a manual re-scoring.

    The solved file is scored again with ``echo=True`` / ``max_tokens=0``.
    The average probability of the tokens after the prompt portion must
    match the hook's estimate to within 0.005. Currently skipped.
    """
    model = get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo_instruct)
    solver = LmCodeSolverAutoRegressive(model)
    solver.add_uncertainty_hook(LmUncertaintyProbAverage())
    solution = solver.solve(factor_problems[0])
    assert len(solution.uncertainty_estimates) == 1
    solved_text = apply_solution(solution).get_only_file().content_str
    estimate = solution.uncertainty_estimates[0].get_estimate()
    assert isinstance(estimate, ProbabilityIsCorrect)
    scored = model.predict(LmPrompt(solved_text, echo=True, max_tokens=0))
    print(scored.get_full_tokens())
    # Tokenization of the prompt (signature + docstring), one source line per
    # row. Only its length is used, to skip past the prompt tokens.
    prompt_tokens = [
        "def", " double", "(", "x", ")", " ->", " int", ":", "\n",
        "   ", ' """', "\n",
        "   ", " Write", " a", " function", " that", " doubles", " its", " input", ".", "\n",
        "   ", ' """', "\n",
    ]
    n_prompt = len(prompt_tokens)
    continue_tokens = scored.get_full_tokens()[n_prompt:]
    print(continue_tokens)
    log_probs = scored.full_logprobs[n_prompt:]
    print(log_probs)
    avg = np.average(np.exp(log_probs))
    print(avg)
    assert avg == pytest.approx(estimate.value, abs=0.005)


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
@pytest.mark.skip(reason="Uncertainty hooks needs to be refactored")
def test_block_end(lmsolver):
    """Check what an uncertainty hook receives when the completion overruns.

    The mocked completion continues past the method body (``assert False``).
    The test expects the solver to:

    * keep only ``    return 2`` as the solution,
    * call the hook exactly once, and
    * (for ``LmCodeSolverAutoRegressive``) give the hook the full solved
      file text.

    Currently skipped until uncertainty hooks are refactored.
    """
    lm = get_open_ai_lm()
    # Fake prediction: the method body followed by code outside the block.
    pred = MagicMock()
    pred.completion_text = "    return 2\nassert False"
    pred.completion_tokens = ["   ", " return", " 2", "\n", "assert", " False"]
    completion_probs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]
    pred.completion_logprobs = completion_probs
    lm.predict = MagicMock(return_value=pred)
    solver = lmsolver(lm)
    # Incremented by ToyHook.prepare so the test can assert it ran once.
    num_calls = 0

    class ToyHook(LmUncertaintyHook):
        """Hook that only asserts on its inputs and counts its invocations."""

        def prepare(
            self,
            lm_prediction: LmPrediction,
            model: LmPredictor,
            solution: CodeSolution,
            modified_full_text: str,
        ) -> DelayedUncertaintyEstimate:
            # NOTE(review): returns None despite the annotation; this only
            # works because the test never reads the resulting estimate.
            nonlocal num_calls
            num_calls += 1
            assert len(solution.solve_steps) == 1
            # Only requires a truthy result. Note that ``lm.tokenize`` is not
            # mocked in this test, so the model's own tokenizer is used.
            assert estimate_best_overlap_span(
                lm_prediction.completion_text,
                larger_tokens=lm_prediction.completion_tokens,
                mod_string=solution.solve_steps[0].value,
                tokenizer_func=model.tokenize,
            )
            # The auto-regressive solver should pass the solved file text
            # (without the overrun ``assert False``).
            if lmsolver == LmCodeSolverAutoRegressive:
                assert modified_full_text == (  # fmt: off
                    #'Complete the following method\n'
                    #'\n'
                    'def foo():\n    """\n    Always return 2\n    """\n    return 2'
                )  # fmt: on

    solver.add_uncertainty_hook(ToyHook())
    problem = make_simple_method_completion_problem(
        prompt="Always return 2",
        signature="def foo():",
        lang_spec=PythonLangSpec(),
        test_cases=[],
    )
    assert not problem.instructions_are_essential
    solution = solver.solve(problem)
    assert num_calls == 1
    # The applied file ends at the method body; the overrun is discarded.
    content = apply_solution(solution).get_only_file().content_str
    assert content == (  # fmt: off
        'def foo():\n    """\n    Always return 2\n    """\n    return 2'
    )  # fmt: on


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
@pytest.mark.skip(reason="Uncertainty hooks needs to be refactored")
def test_block_end_avg(lmsolver):
    """``LmUncertaintyProbTotal`` should score only the tokens inside the method.

    The mocked completion continues past the body with ``assert False``. The
    expected estimate is ``exp`` of the summed logprobs of the first three
    tokens, i.e. the tokens of ``    return 2``. Currently skipped.
    """
    lm = get_open_ai_lm()
    pred = MagicMock()
    pred.completion_text = "    return 2\nassert False"
    pred.completion_tokens = ["   ", " return", " 2", "\n", "assert", " False"]

    def fake_tokenize(text):
        # Only the strings this test expects the solver to tokenize are known.
        if text == pred.completion_text:
            return pred.completion_tokens
        body_tokens = ["   ", " return", " 2"]
        if text == "    return 2":
            return body_tokens
        if text == "    return 2\n":
            return [*body_tokens, "\n"]
        msg = "Unexpected tokenizing input"
        raise AssertionError(msg)

    lm.tokenize = MagicMock(side_effect=fake_tokenize)
    logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]
    pred.completion_logprobs = logprobs
    lm.predict = MagicMock(return_value=pred)
    solver = lmsolver(lm)
    solver.add_uncertainty_hook(LmUncertaintyProbTotal())
    problem = make_simple_method_completion_problem(
        prompt="Always return 2",
        signature="def foo():",
        lang_spec=PythonLangSpec(),
        test_cases=[],
    )
    estimate_value = solver.solve(problem).uncertainty_estimates[0].get_estimate().value
    # Product of the probabilities of "   ", " return", " 2".
    expected = np.exp(np.sum(np.array(logprobs)[:3]))
    assert estimate_value == pytest.approx(expected)


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
@pytest.mark.skip(reason="Uncertainty hooks needs to be refactored")
def test_block_end2(lmsolver):
    """``LmUncertaintyProbAverage`` should average only the in-method tokens.

    Same overrunning completion as ``test_block_end_avg``, but the expected
    estimate is the mean probability of the first three tokens (those of
    ``    return 2``). Currently skipped.
    """
    lm = get_open_ai_lm()
    pred = MagicMock()
    pred.completion_text = "    return 2\nassert False"
    pred.completion_tokens = ["   ", " return", " 2", "\n", "assert", " False"]

    def fake_tokenize(text):
        # Only the strings this test expects the solver to tokenize are known.
        if text == pred.completion_text:
            return pred.completion_tokens
        body_tokens = ["   ", " return", " 2"]
        if text == "    return 2":
            return body_tokens
        if text == "    return 2\n":
            return [*body_tokens, "\n"]
        msg = "Unexpected tokenizing input"
        raise AssertionError(msg)

    lm.tokenize = MagicMock(side_effect=fake_tokenize)
    logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]
    pred.completion_logprobs = logprobs
    lm.predict = MagicMock(return_value=pred)
    solver = lmsolver(lm)
    solver.add_uncertainty_hook(LmUncertaintyProbAverage())
    problem = make_simple_method_completion_problem(
        prompt="Always return 2",
        signature="def foo():",
        lang_spec=PythonLangSpec(),
        test_cases=[],
    )
    estimate_value = solver.solve(problem).uncertainty_estimates[0].get_estimate().value
    # Mean probability of "   ", " return", " 2".
    expected = np.average(np.exp(np.array(logprobs)[:3]))
    assert estimate_value == pytest.approx(expected)


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
def test_lead_new_line(lmsolver):
    """A reply starting with a newline and a 2-space indent becomes 4-space."""
    lm = get_open_ai_lm()
    fake_pred = MagicMock()
    fake_pred.completion_text = "\n  return 2"
    lm.predict = MagicMock(return_value=fake_pred)
    problem = make_simple_method_completion_problem(
        prompt="Always return 2",
        signature="def foo():",
        lang_spec=PythonLangSpec(),
        test_cases=[],
    )
    solution = lmsolver(lm).solve(problem)
    assert solution.solve_steps[0].value == "    return 2"
    apply_solution(solution)


@pytest.mark.parametrize("lmsolver", LM_SOLVERS)
def test_lead_new_line2(lmsolver):
    """A reply starting with a newline and no indent gets the body's indent."""
    lm = get_open_ai_lm()
    fake_pred = MagicMock()
    fake_pred.completion_text = "\nreturn 2"
    lm.predict = MagicMock(return_value=fake_pred)
    problem = make_simple_method_completion_problem(
        prompt="Always return 2",
        signature="def foo():",
        lang_spec=PythonLangSpec(),
        test_cases=[],
    )
    solution = lmsolver(lm).solve(problem)
    assert solution.solve_steps[0].value == "    return 2"
    apply_solution(solution)


# fmt: off
@pytest.mark.parametrize(
    "lm_name",
    [
        OpenAiModelNames.gpt_3_5_turbo_instruct,
        OpenAiModelNames.gpt_4,
    ],
)
@pytest.mark.parametrize(
    "completion_text",
    [
        '<edit_solve id="0_qkIO0">    return x * 2</edit_solve>',
        '<edit_solve id="0_qkIO0">return x * 2</edit_solve>',
        (
            "The answer is"
            '\n<edit_solve id="0_qkIO0">    return x * 2</edit_solve>\n'
            "Let me know what else you need"
        ),
        "<edit_solve>return x * 2</edit_solve>",  # Default value
        "return x * 2</edit_solve>",
        "    return x * 2",
        (
            "```python\n"
            "<edit_solve>     return x * 2</edit_solve>\n"
            "```"
        ),
    ],
)
def test_edit_solver(factor_problems, lm_name, completion_text):
    """Tagged-edit parsing recovers ``    return x * 2`` from varied replies.

    The replies cover:

    * tags with and without an ``id`` attribute,
    * surrounding chatter,
    * a missing opening tag,
    * no tags at all, and
    * tags inside a markdown fence.

    The parser is checked on its own, then through the full solver.
    """
    # fmt: on
    expected_step = "    return x * 2"
    lm = get_open_ai_lm(lm_name)
    problem = factor_problems[0]
    canned_reply = MagicMock()
    canned_reply.completion_text = completion_text
    lm.predict = MagicMock(return_value=canned_reply)
    renderer = TaggedEditRenderer()
    rendered = renderer(problem, MagicMock())
    parser = TaggedEditResponseParser()
    # The parser alone must recover the edit from the rendered prompt.
    assert parser.parse(rendered, canned_reply, problem)[0].value == expected_step
    # Then the same renderer/parser pair is used through the solver.
    solver = LmCodeSolverAutoRegressive(
        model=lm,
        prompt_renderer=renderer,
        response_parser=parser,
    )
    solution = solver.solve(problem)
    assert len(solution.solve_steps) == 1
    assert solution.solve_steps[0].value == expected_step
    apply_solution(solution)



@pytest.mark.parametrize(
    "lm_name",
    [
        OpenAiModelNames.gpt_3_5_turbo_instruct,
        OpenAiModelNames.gpt_4,
    ],
)
@pytest.mark.parametrize(
    "completion_text",
    [
        (
            "```java\n"
            "<fix>    if (options.dependencyOptions.needsManagement() && options.shouldRunClosurePass()) {</fix>\n"
            "```"
        )
    ],
)
def test_edit_solver_custom_tags(factor_problems, lm_name, completion_text):
    """Tagged-edit parsing with custom ``<buggy>``/``<fix>`` tag names.

    This test was previously also named ``test_edit_solver``. Redefining that
    name at module level replaced the earlier parametrized test, so pytest
    silently never collected it. The distinct name lets both run.

    The renderer is configured with a bug-fix style closing question. The
    fenced ``<fix>`` reply must come back verbatim (indentation included),
    both from the parser directly and through the full solver.
    """
    lm = get_open_ai_lm(lm_name)
    problem = factor_problems[0]
    pred = MagicMock()
    pred.completion_text = completion_text
    lm.predict = MagicMock(return_value=pred)
    renderer = TaggedEditRenderer(
        tag_name_edit="buggy",
        tag_name_solve="fix",
        include_first_tag_at_end=None,
        custom_closing_lines=(
            "Question: There is a bug in the above code snippet tagged by <buggy> and </buggy>. "
            "Please generate the correct version inside of tags <fix> and </fix>.\n"
            "Answer:\n"
        ),
    )
    render = renderer(problem, MagicMock())
    parser = TaggedEditResponseParser()
    v = "    if (options.dependencyOptions.needsManagement() && options.shouldRunClosurePass()) {"
    result = parser.parse(render, pred, problem)[0].value
    # Printed so a mismatch in leading whitespace is easy to spot.
    print("Expected: ", repr(v))
    print("Result: ", repr(result))
    assert result == v
    solver = LmCodeSolverAutoRegressive(
        model=lm,
        prompt_renderer=renderer,
        response_parser=parser,
    )
    solution = solver.solve(problem)
    assert len(solution.solve_steps) == 1
    assert solution.solve_steps[0].value == v
    apply_solution(solution)
