import itertools
import os
from unittest.mock import MagicMock

import pytest
from lmwrapper.huggingface_wrapper import get_huggingface_lm
from lmwrapper.openai_wrapper import OpenAiModelNames, get_open_ai_lm

from synthegrator.code_solver import (
    DummyCodeSolverAutoRegressive,
    LmCodeSolverAutoRegressive,
)
from synthegrator.df_converters import solution_evals_to_df
from synthegrator.execution_threading import solve_and_evaluate_problems
from synthegrator.synthdatasets.human_eval import yield_human_eval

# True only when the GITHUB_ACTIONS environment variable is exactly "true";
# used below to skip tests that need credentials unavailable in CI.
IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"


@pytest.mark.parametrize("max_threads_solve", [1, 2])
@pytest.mark.parametrize("max_threads_eval", [1, 2])
def test_df_converter_dummy(
    max_threads_solve,
    max_threads_eval,
):
    """Check the DataFrame has one row per problem for the dummy solver.

    The test runs for every combination of solve/eval thread counts. It
    feeds the first ``num_problems`` HumanEval problems through
    ``solve_and_evaluate_problems``. It then asserts that
    ``solution_evals_to_df`` yields exactly that many rows.
    """
    # Single source of truth for the problem count, so the slice size and the
    # expected row count cannot drift apart.
    num_problems = 10
    # MagicMock stands in for the solver's constructor argument (presumably an
    # LM); the dummy solver is assumed not to need a real one — confirm.
    solver = DummyCodeSolverAutoRegressive(MagicMock())
    problems = itertools.islice(yield_human_eval(), num_problems)
    evals = solve_and_evaluate_problems(
        solver,
        problems,
        max_threads_solve=max_threads_solve,
        max_threads_eval=max_threads_eval,
    )
    df = solution_evals_to_df(evals, solver_key="dummy")
    assert len(df) == num_problems


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="No api key")
def test_df_converter_gpt():
    """Convert gpt-3.5-turbo solutions for a few HumanEval problems to a DataFrame.

    The test is skipped on GitHub Actions, where no OpenAI API key is
    available. Solving and evaluation each run on a single thread.
    """
    num_problems = 5
    gpt_solver = LmCodeSolverAutoRegressive(
        get_open_ai_lm(OpenAiModelNames.gpt_3_5_turbo),
    )
    first_problems = itertools.islice(yield_human_eval(), num_problems)
    solution_evals = solve_and_evaluate_problems(
        gpt_solver,
        first_problems,
        max_threads_solve=1,
        max_threads_eval=1,
    )
    frame = solution_evals_to_df(solution_evals, solver_key="gpt35")
    # Expect one DataFrame row per evaluated problem.
    assert len(frame) == num_problems


def test_df_converter_codegen():
    """Convert CodeGen-350M solutions for a few HumanEval problems to a DataFrame.

    The model is loaded through ``get_huggingface_lm`` with
    ``trust_remote_code=True``. In Hugging Face ``transformers``, that flag
    allows code from the model repository to run. The solver is asked to
    include the raw LM response.
    """
    num_problems = 5
    codegen_lm = get_huggingface_lm(
        "Salesforce/codegen-350M-multi",
        trust_remote_code=True,
    )
    codegen_solver = LmCodeSolverAutoRegressive(
        codegen_lm,
        include_lm_response=True,
    )
    first_problems = itertools.islice(yield_human_eval(), num_problems)
    solution_evals = solve_and_evaluate_problems(
        codegen_solver,
        first_problems,
        max_threads_solve=1,
        max_threads_eval=1,
    )
    frame = solution_evals_to_df(solution_evals, solver_key="codegen")
    # Expect one DataFrame row per evaluated problem.
    assert len(frame) == num_problems
