import json
import logging
logger = logging.getLogger(__name__)
from collections.abc import Iterator
from typing import Any

from synthegrator.code_problem_builders import (
    make_simple_file_generation,
    make_simple_method_completion_problem,
)
from synthegrator.code_problems import (
    CodeProblem,
    TestCase,
    TestCaseArbitraryMethod,
)
from synthegrator.lang_specs.lang_spec_python import PythonLangSpec

# Import preamble passed as ``execution_context_hidden_from_model`` when APPS
# problems are built below: these names are available to solutions at run
# time without being shown to the model. One import per entry; every line,
# including the last, is newline-terminated.
APPS_CODE_ENV = "".join(
    import_line + "\n"
    for import_line in (
        "import sys",
        "import time",
        "import itertools",
        "from itertools import accumulate, product, permutations, combinations",
        "import collections",
        "from collections import Counter, OrderedDict, deque, defaultdict, ChainMap",
        "from functools import lru_cache",
        "import math",
        "from math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2",
        "import fractions",
        "from typing import List, Tuple",
        "import numpy as np",
        "import random",
        "import heapq",
        "from heapq import *",
    )
)


def produce_app_file_generation(
    task_id: str,
    prompt: str,
    canonical_solutions: list[str],
    inputs_outputs: dict[str, list[str]],
    starter_code: str,
) -> "CodeProblem":
    """Build a whole-file generation problem for an APPS task without ``fn_name``.

    The model sees only ``prompt`` (``model_context`` is empty); the APPS
    import preamble is attached as hidden execution context.

    Args:
        task_id: APPS problem identifier, used as ``problem_id``.
        prompt: APPS question text.
        canonical_solutions: Known reference solutions. Must be a list of
            solutions, not a single solution string.
        inputs_outputs: Parsed APPS ``input_output`` record; converted into
            hidden test cases.
        starter_code: APPS starter code; only forwarded to test generation.

    Returns:
        The problem built by ``make_simple_file_generation``. The annotation
        was previously ``None`` even though the value is returned and yielded
        by ``yield_apps`` as a ``CodeProblem``.
    """
    # Guards against passing the raw JSON string instead of the decoded list.
    assert not isinstance(canonical_solutions, str), "Expect list of solutions?"
    return make_simple_file_generation(
        prompt=prompt,
        test_cases=parse_input_output_to_test_case(inputs_outputs, starter_code),
        model_context="",
        execution_context_hidden_from_model=APPS_CODE_ENV,
        dataset_name="apps",
        problem_id=task_id,
        known_solutions=canonical_solutions,
    )


def produce_app_method_problem(
    task_id: str,
    prompt: str,
    canonical_solutions: list[str],
    inputs_outputs: dict[str, list[str]],
    starter_code: str,
) -> CodeProblem:
    """Build a Python method-completion problem for an APPS task with ``fn_name``.

    The starter code, with trailing whitespace stripped and tabs expanded to
    four spaces, is passed as ``signature``. The question text is passed with
    ``prompt_already_docstring=True``.

    Args:
        task_id: APPS problem identifier, used as ``problem_id``.
        prompt: APPS question text.
        canonical_solutions: Known reference solutions. Must be a list, not a
            single solution string.
        inputs_outputs: Parsed APPS ``input_output`` record; converted into
            hidden test cases.
        starter_code: APPS starter code; may be empty.

    Returns:
        The problem built by ``make_simple_method_completion_problem``.

    NOTE(review): unlike ``produce_app_file_generation``, no
    ``dataset_name="apps"`` is passed here; confirm whether the builder
    accepts one.
    """
    if starter_code:
        normalized_signature = starter_code.rstrip().replace("\t", "    ")
    else:
        normalized_signature = ""
    # Guards against passing the raw JSON string instead of the decoded list.
    assert not isinstance(canonical_solutions, str), "Expect list of solutions?"

    # Test generation receives the raw (un-normalized) starter code.
    hidden_tests = parse_input_output_to_test_case(inputs_outputs, starter_code)
    return make_simple_method_completion_problem(
        prompt,
        signature=normalized_signature,
        test_cases=hidden_tests,
        problem_id=task_id,
        known_solutions=canonical_solutions,
        lang_spec=PythonLangSpec(),
        execution_context_hidden_from_model=APPS_CODE_ENV,
        prompt_already_docstring=True,
    )


def yield_apps(
    stream_subset=False, *, seed: int | None = None
) -> Iterator[CodeProblem]:
    """Yield synthegrator problems built from the ``codeparrot/apps`` test split.

    Args:
        stream_subset: If true, stream the split and take only its first 10
            records (intended for CI). Records skipped below still count
            toward those 10. Otherwise load the full split and shuffle it.
        seed: Optional seed forwarded to ``Dataset.shuffle`` so the
            non-streaming order can be reproduced. ``None`` (the default)
            keeps the unseeded shuffle.

    Yields:
        A file-generation problem when the record's ``input_output`` has no
        ``fn_name``, otherwise a method-completion problem. Records with an
        empty ``input_output``/``solutions`` field, or with JSON that fails
        to parse, are skipped.
    """
    from datasets import load_dataset

    # NOTE: ``trust_remote_code=True`` was left disabled in both calls below;
    # enable it if the installed ``datasets`` version requires it here.
    if stream_subset:
        dataset = load_dataset(
            "codeparrot/apps",
            split="test",
            streaming=True,
        ).take(10)
    else:
        dataset = load_dataset(
            "codeparrot/apps",
            split="test",
        ).shuffle(seed=seed)

    for ex in dataset:
        # Nothing to test or compare against without both fields.
        if not (ex["input_output"] and ex["solutions"]):
            continue

        try:
            solutions = json.loads(ex["solutions"])  # TODO only one solution?
        except ValueError:
            logger.info("Unable to load solution %s", ex["solutions"])
            continue

        try:
            inputs_outputs = json.loads(ex["input_output"])
        except ValueError:
            logger.info("Unable to load input/output for problem %s", ex["problem_id"])
            continue

        # see apps/eval/testing_util.py:145
        # No fn_name -> module-level program; otherwise a named function.
        if inputs_outputs.get("fn_name") is None:
            produce = produce_app_file_generation
        else:
            produce = produce_app_method_problem

        yield produce(
            task_id=ex["problem_id"],
            prompt=ex["question"],
            canonical_solutions=solutions,
            inputs_outputs=inputs_outputs,
            starter_code=ex["starter_code"],
        )


def _restore_int_keys(value: Any, *, use_first: bool) -> Any:
    """Best-effort undo of JSON's string-only dict keys (see apps/eval/testing_util.py:243).

    Inspects ``value[0]`` when ``use_first`` is true, otherwise ``value``
    itself. If that is a dict, returns ``[{int(k): v, ...}]``, assuming a
    singleton list. Values of any other shape, or with non-integer keys, are
    returned unchanged.
    """
    try:
        candidate = value[0] if use_first else value
        if isinstance(candidate, dict):
            return [{int(k): v for k, v in candidate.items()}]
    except (KeyError, IndexError, TypeError, ValueError):
        # Empty sequences, scalars (ints/bools), dicts without key 0 and
        # non-numeric keys all mean "nothing to restore".
        pass
    return value


def parse_input_output_to_test_case(
    inputs_outputs: dict[str, Any],
    starter_code: str,
) -> list[TestCase]:
    """Turn a parsed APPS ``input_output`` record into hidden pytest test cases.

    Args:
        inputs_outputs: Record with parallel ``"inputs"`` / ``"outputs"``
            lists and an optional ``"fn_name"``. It is not mutated.
        starter_code: APPS starter code, used to decide how a named function
            is invoked.

    Returns:
        One ``TestCaseArbitraryMethod`` (fixtures ``monkeypatch``/``capfd``)
        per input/output pair. Pairs that fit no template (string stdin with a
        non-string expected output and no ``fn_name``) are logged and skipped.

    Raises:
        KeyError: if ``"inputs"`` or ``"outputs"`` is missing.
        IndexError: if there are fewer outputs than inputs.
        Errors raised while rendering a named-function test propagate.
    """
    # TODO split several asserts into different test cases
    inputs = inputs_outputs["inputs"]
    outputs = inputs_outputs["outputs"]
    fn_name = inputs_outputs.get("fn_name")

    test_cases = []
    for i, raw_input in enumerate(inputs):
        # see apps/eval/testing_util.py:243
        # JSON forces dictionaries to have string keys; undo this on local
        # values (the input's first argument, then the output and its first
        # element) so the caller's lists stay untouched.
        case_input = _restore_int_keys(raw_input, use_first=True)
        case_output = _restore_int_keys(outputs[i], use_first=False)
        case_output = _restore_int_keys(case_output, use_first=True)

        if fn_name:
            code = _test_case_for_named_function(
                i,
                case_input,
                case_output,
                fn_name,
                starter_code,
            )
        elif isinstance(case_input, str) and isinstance(case_output, str):
            # The str template splices text between triple quotes, so
            # backslashes must be escaped here.
            code = _test_case_for_builtins_str(
                i,
                case_input.replace("\\", "\\\\"),
                case_output.replace("\\", "\\\\"),
            )
        elif not isinstance(case_input, str):
            code = _test_case_for_builtins(i, case_input, case_output)
        else:
            # Neither template fits; skip rather than reuse a stale `code`
            # from the previous pair or crash on an unbound name.
            logger.warning(
                "Skipping APPS test %d: string input with %s output is unsupported",
                i,
                type(case_output).__name__,
            )
            continue

        test_cases.append(
            TestCaseArbitraryMethod(
                is_hidden_test=True,
                fail_message=None,
                method_body=code,
                lang_spec=PythonLangSpec(),
                fixtures=("monkeypatch", "capfd"),
            )
        )
    return test_cases


def _test_case_for_builtins_str(i: int, function_input: str, output: str) -> str:
    return f"""
inputs = iter(\"\"\"{function_input}\"\"\".split("\\n"))
monkeypatch.setattr('builtins.input', lambda: next(inputs))
main()
check_outputs(capfd)
out, err = capfd.readouterr()
assert out == \"\"\"{output}\"\"\"
""".strip()


def _test_case_for_builtins(i: int, function_input: any, output: any) -> str:
    return f"""
inputs = iter({function_input})
monkeypatch.setattr('builtins.input', lambda: next(inputs))
main()
out, err = capfd.readouterr()
assert out == {output}
""".strip()


def apps_assertion_code_gen(x: Any) -> str:
    match str(type(x).__name__):
        case "str":
            y = '["' + x + '"]'
        case "list":
            y = str(x)
        case "bool":
            y = str(x)
        case "int":
            y = str(x)
        case _:
            msg = "Unsupported type"
            raise Exception(msg, type(x))
    return y


def _test_case_for_named_function(
    i: int,
    ref_input: Any,
    ref_output: Any,
    fn_name: str,
    starter_code: str,
) -> str:
    """Render a one-line assert that calls ``fn_name`` on a reference input.

    ``ref_input`` is treated as the argument list and is always spread with
    ``*``. The call result is compared with ``==`` to the rendered
    ``ref_output``. How the function is reached depends on ``starter_code``:

    * ``class Solution`` with ``def fn_name(self...`` -> ``Solution().fn_name(*args)``
    * ``class Solution`` without a ``self`` parameter -> ``Solution.fn_name(*args)``
    * no ``class Solution`` -> module-level ``fn_name(*args)``

    Args:
        i: Unused; kept for signature parity with the other renderers.
        ref_input: Reference argument list.
        ref_output: Reference return value.
        fn_name: Name of the function/method under test.
        starter_code: APPS starter code (``None``/empty is treated as "").

    Returns:
        The generated assert statement as source text.
    """
    import re

    args = apps_assertion_code_gen(ref_input)
    out = apps_assertion_code_gen(ref_output)
    source = starter_code or ""

    # TODO: the reference implementation compares more leniently than `==`:
    #     if isinstance(ref_output, tuple):
    #         output = list(output)
    #
    #     tmp_result = output == ref_output
    #     if isinstance(ref_input, list) and ref_output:
    #         tmp_result = tmp_result or (output == ref_input[0])
    #
    #     # ground truth sequences are not tuples
    #     try:
    #         if isinstance(output[0], tuple):
    #             tmp_result = tmp_result or ([list(x) for x in output] == output[0])
    #     except:
    #         True
    # NOTE(review): `ref_input[0]` / `output[0]` on the right-hand sides may be
    # transcription slips for `ref_output[0]`; verify against testing_util.py.

    if re.search(r"\bclass\s+Solution\b", source):
        method_with_self = rf"\bdef\s+{re.escape(fn_name)}\s*\(\s*self\b"
        if re.search(method_with_self, source):
            return f"assert Solution().{fn_name}(*{args}) == {out}"
        # No `self` parameter: call it on the class (static/class method).
        return f"assert Solution.{fn_name}(*{args}) == {out}"
    # Spread the argument list here too, matching the class branches.
    return f"assert {fn_name}(*{args}) == {out}"
