import ast
import doctest
import logging
logger = logging.getLogger(__name__)

from synthegrator.code_problem_builders import make_simple_method_completion_problem
from synthegrator.code_problems import (
    CodeProblem,
    TestCase,
    TestCaseArbitraryMethod,
    TestCaseExpressionIsEq,
    TestCaseMethodCallIsEq,
)
from synthegrator.lang_specs.lang_spec_python import PythonLangSpec
from synthegrator.solution_eval import apply_solution
from synthegrator.util import IteratorWithLength


def produce_code_problem_for_human_eval(
    task_id: str,
    prompt: str,
    canonical_solution: str,
    test: str,
    entry_point: str,
    is_eval_plus: bool,
) -> CodeProblem:
    """Build a method-completion ``CodeProblem`` from one HumanEval record.

    Args:
    ----
        task_id: Record id; used as the problem id.
        prompt: HumanEval prompt text. Everything before the target function
            becomes model context; the target function's signature and
            docstring become the completion prompt.
        canonical_solution: Reference solution, stored as the known solution.
        test: HumanEval test code; wrapped as a single hidden test case.
        entry_point: Name of the function under test.
        is_eval_plus: Whether the record comes from HumanEval+; only affects
            the ``dataset_name`` recorded on the problem.

    Returns:
    -------
        The constructed problem.

    """
    # Split off everything before the target function; it is passed to the
    # model as context (``model_context`` below).
    imports, method = PythonLangSpec.split_before_last_method_name(prompt)
    # First line is the ``def`` signature; the remainder is the docstring that
    # serves as the prompt (hence ``prompt_already_docstring=True``).
    signature, docstring_prompt = method.split("\n", 1)

    # Tests are taken from the separate ``test`` field rather than the
    # docstring: human eval uses irregular test case style. Some are real
    # doctests, some are just text in an 'example' section. Some are doctests
    # but done wrong (like are == expression).
    test_cases = _parse_assert_style_test_cases(test, entry_point)

    prob = make_simple_method_completion_problem(
        prompt=docstring_prompt,
        signature=signature,
        test_cases=test_cases,
        model_context=imports,
        known_solutions=[canonical_solution],
        lang_spec=PythonLangSpec(),
        prompt_already_docstring=True,
        problem_id=task_id,
        dataset_name="humaneval_plus" if is_eval_plus else "humaneval",
    )
    # NOTE(review): the return value is discarded; presumably this is a sanity
    # check that the canonical solution can be applied to the problem (it would
    # raise otherwise) — confirm against ``apply_solution``.
    apply_solution(prob.known_solutions[0])
    return prob


def yield_human_eval(
    use_eval_plus: bool = False,
    max_problems: "int | None" = None,
) -> IteratorWithLength[CodeProblem]:
    """
    Yields code problems from the HumanEval dataset.

    Args:
    ----
        use_eval_plus: If True, use the eval+ (from https://evalplus.github.io/)
            dataset instead of the original HumanEval dataset. This expands the
            number of test cases.
        max_problems: If given, yield at most this many problems (taken in
            dataset order). ``None`` means yield every problem. ``0`` yields
            nothing; negative values are treated as ``0``.

    Returns:
    -------
        A lazy iterator of problems whose reported length equals the number of
        problems it actually yields.

    """
    from itertools import islice

    from datasets import load_dataset

    if not use_eval_plus:
        dataset = load_dataset("openai_humaneval")
    else:
        dataset = load_dataset("evalplus/humanevalplus")
    data = dataset["test"]

    # Compute the count once so the advertised length and the generator agree.
    # Previously ``max_problems`` only changed the length while every problem
    # was still yielded, and ``0`` was silently treated as "no limit".
    if max_problems is None:
        num_problems = len(data)
    else:
        num_problems = max(0, min(len(data), max_problems))

    def gen():
        for ex in islice(data, num_problems):
            yield produce_code_problem_for_human_eval(
                task_id=ex["task_id"],
                prompt=ex["prompt"],
                canonical_solution=ex["canonical_solution"],
                test=ex["test"],
                entry_point=ex["entry_point"],
                is_eval_plus=use_eval_plus,
            )

    return IteratorWithLength(gen(), num_problems)


def _get_docstring_test_in_outs(docstring: str) -> list[tuple[str, str]]:
    docstring = docstring.strip()[3:-3]
    parser = doctest.DocTestParser()

    # Parse the doctests from the docstring
    doc_test = parser.get_doctest(docstring, {}, "test", "test", 0)

    # Extract and return the individual test cases
    test_cases = []
    for example in doc_test.examples:
        source = example.source.strip()
        want = example.want.strip()
        test_cases.append((source, want))
    return test_cases


def _parse_doctest_style_test_cases(docstring: str, method_name: str) -> list[TestCase]:
    """Convert the doctest examples in ``docstring`` into visible test cases.

    An example whose source is exactly one call ``method_name(<literal args>)``
    becomes a ``TestCaseMethodCallIsEq``. Every other example (compound
    expressions, non-literal or keyword arguments, calls to other functions)
    becomes a ``TestCaseExpressionIsEq`` over the whole source.

    Args:
    ----
        docstring: Docstring to scan, optionally still wrapped in triple quotes.
        method_name: Name of the function under test.

    Returns:
    -------
        One test case per doctest example, all with ``is_hidden_test=False``.

    Raises:
    ------
        ValueError, SyntaxError: if an example's expected output is not a
            Python literal (it is parsed with ``ast.literal_eval``).

    """
    test_cases: list[TestCase] = []
    for i, (source, want) in enumerate(_get_docstring_test_in_outs(docstring)):
        # ``want`` is already stripped; empty means "no expected output".
        output = ast.literal_eval(want) if want else None
        is_method_call, args = _literal_call_args(source, method_name)
        if is_method_call:
            test_cases.append(
                TestCaseMethodCallIsEq(
                    is_hidden_test=False,
                    fail_message=f"Fail docstring test method call case {i}: {source}",
                    method_name=method_name,
                    test_input=args,
                    output=output,
                ),
            )
        else:
            test_cases.append(
                TestCaseExpressionIsEq(
                    is_hidden_test=False,
                    fail_message=f"Fail docstring test expression case {i}: {source}",
                    input_expression=source,
                    output=output,
                ),
            )
    return test_cases


def _literal_call_args(source: str, method_name: str) -> tuple[bool, object]:
    """Return ``(True, args)`` if ``source`` is a single literal-arg call to ``method_name``.

    ``args`` is ``ast.literal_eval`` of the raw argument text. A single argument
    therefore yields a bare value, and several arguments yield a tuple.
    NOTE(review): ``f((1, 2))`` and ``f(1, 2)`` both give ``(1, 2)`` — confirm
    this matches what ``TestCaseMethodCallIsEq.test_input`` expects.

    Returns ``(False, None)`` when ``source`` is not such a call.
    """
    prefix = f"{method_name}("
    if not (source.startswith(prefix) and source.endswith(")")):
        return False, None
    try:
        tree = ast.parse(source, mode="eval")
    except SyntaxError:
        return False, None
    call = tree.body
    # Reject sources that merely start with ``f(`` and end with ``)``, such as
    # ``f(1) + f(2)`` or ``f(1)(2)``.
    if not (
        isinstance(call, ast.Call)
        and isinstance(call.func, ast.Name)
        and call.func.id == method_name
    ):
        return False, None
    try:
        return True, ast.literal_eval(source[len(prefix) : -1])
    except (ValueError, SyntaxError):
        # These cannot be turned into a static input: non-literal args
        # (``f(x)``, ``f([0] * 3)``), keyword args (``f(a=1)``) or no args
        # (``f()``). Evaluate the expression as a whole instead.
        return False, None


def _parse_assert_style_test_cases(code: str, test_method_name: str) -> list[TestCase]:
    """Wrap HumanEval-style test code as a single hidden test case.

    ``code`` is expected to define a ``check`` callable. A trailing call
    ``check(<test_method_name>)`` is appended so that executing the method body
    runs ``check`` against the function under test.

    Args:
    ----
        code: Test source code (the dataset's ``test`` field).
        test_method_name: Name of the function under test, passed to ``check``.

    Returns:
    -------
        A one-element list holding a hidden ``TestCaseArbitraryMethod`` with no
        fail message.

    """
    # TODO split several asserts into different test cases
    runnable_body = "\n".join((code, f"check({test_method_name})"))
    whole_suite = TestCaseArbitraryMethod(
        is_hidden_test=True,
        fail_message=None,
        method_body=runnable_body,
        lang_spec=PythonLangSpec(),
    )
    return [whole_suite]


def main():
    """Smoke test: build the first HumanEval problem and log it."""
    # The module logger has no level/handler of its own, so under the default
    # WARNING threshold the ``logger.info`` below would print nothing.
    # ``basicConfig`` is a no-op if logging was already configured.
    logging.basicConfig(level=logging.INFO)
    first_problem = next(yield_human_eval(max_problems=1))
    logger.info("%s", first_problem)


if __name__ == "__main__":
    main()
