import dataclasses
import json
from collections import Counter
from itertools import chain, islice
from pprint import pprint

from synthegrator.code_problems import CodeProblem, TestCaseArbitraryMethod
from synthegrator.code_problem_builders import make_simple_file_generation, \
    make_simple_method_completion_problem
from synthegrator.lang_specs.lang_spec_python import PythonLangSpec
from synthegrator.sandboxing import PY_DEFAULT_DOCKER_ENV, ExecLimits
from synthegrator.synthdatasets.apps import apps_assertion_code_gen
from synthegrator.synthdatasets.livecodebench.authorcode.benchmarks.code_generation import \
    load_code_generation_dataset, TestType, Difficulty, LivecodeInternalCodeGenerationProblem
from synthegrator.util import IteratorWithLength

# Import prelude handed to the problem builder as
# `execution_context_hidden_from_model` (see `_produce_livecode_ex`).
# One source line per entry; each entry is terminated with a newline.
LIVECODEBENCH_CODE_ENV = "".join(
    f"{statement}\n"
    for statement in (
        "import sys",
        "import time",
        "import itertools",
        "from itertools import accumulate, product, permutations, combinations",
        "import collections",
        "from collections import Counter, OrderedDict, deque, defaultdict, ChainMap",
        "from functools import lru_cache",
        "import math",
        "from math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2",
        "import fractions",
        "from typing import List, Tuple",
        "import numpy as np",
        "import random",
        "import heapq",
        "from heapq import *",
    )
)


def _create_stdin_test_case(input_data: str, expected_output: str, index: int, is_hidden: bool) -> TestCaseArbitraryMethod:
    """
    Create a test case for STDIN/STDOUT testing.

    The generated test body embeds ``input_data`` and ``expected_output`` as
    Python string literals, asserts that ``solution.py`` exists in the current
    working directory, patches ``sys.stdin``/``sys.stdout`` with in-memory
    buffers, imports ``solution`` and calls ``solution.main()``. A trailing
    newline is appended to the stdin data when it is missing. The captured
    output is then compared with the expected output line by line.

    Args:
        input_data: The input data to provide via stdin
        expected_output: The expected output to stdout
        index: Zero-based test case index; reported one-based in the fail message
        is_hidden: Whether this is a hidden test case

    Returns:
        A TestCaseArbitraryMethod instance configured for stdin/stdout testing
    """
    # NOTE(review): the generated comparison normalises asymmetrically: text
    # without any newline is stripped on both sides, while multi-line text is
    # only right-stripped per line and keeps trailing blank lines. Confirm this
    # strictness is intended before changing the template below.
    return TestCaseArbitraryMethod(
        is_hidden_test=is_hidden,
        fail_message=f"Test case {index+1} failed",
        # Doubled braces and doubled backslashes below are escapes for this
        # outer f-string; they become single braces / "\n" in the generated code.
        method_body=f"""
# LCB STDIN TEST CASE
import sys
import io
import os
from unittest.mock import patch

# Capture stdin/stdout
stdin_data = {repr(input_data)}
expected_output = {repr(expected_output)}

# Find the solution file
solution_path = os.path.join(os.getcwd(), 'solution.py')
assert os.path.exists(solution_path), "solution.py file does not exist"

# Create a patched stdin and capture stdout
#input_buffer = io.StringIO(stdin_data)
input_buffer = io.StringIO(stdin_data if stdin_data.endswith('\\n') else stdin_data+'\\n')
output_buffer = io.StringIO()

with patch('sys.stdin', input_buffer), patch('sys.stdout', output_buffer):
    # Execute the code directly instead of importing
    import solution
    solution.main()
    #with open(solution_path, 'r') as f:
    #    solution_code = f.read()
    #    exec(solution_code)
    
    # Get the captured output
    actual_output = output_buffer.getvalue()
    
    # Handle the case of single-line output without newlines
    if not actual_output.strip():
        actual_lines = []
    elif '\\n' not in actual_output:
        actual_lines = [actual_output.strip()]
    else:
        actual_lines = [line.rstrip() for line in actual_output.splitlines()]
    
    if not expected_output.strip():
        expected_lines = []
    elif '\\n' not in expected_output:
        expected_lines = [expected_output.strip()]
    else:
        expected_lines = [line.rstrip() for line in expected_output.splitlines()]
    
    # Simple string format to avoid variable reference issues
    if actual_lines != expected_lines:
        msg = f'Expected output does not match actual output.\\n'
        actual_lines_repr = repr(actual_lines)
        if len(actual_lines_repr) > 150:
            actual_lines_repr = actual_lines_repr[:150] + '...'
        def shorten(d):
            if len(repr(d)) > 150:
                return repr(d)[:150-1] + '...'
            return repr(d)
        msg += f'Input: {{shorten(stdin_data)}},\\nGot: {{shorten(actual_lines)}},\\nExpected: {{shorten(expected_lines)}}'
        raise AssertionError(msg)
""",
        lang_spec=PythonLangSpec(),
        fixtures=("monkeypatch", "capfd"),
    )


# Execution limits applied to LiveCodeBench runs.
# The upstream runner's default seems to be about 6 seconds:
# https://github.com/LiveCodeBench/LiveCodeBench/blob/effd1e15b37cf0135e6aaceabc9ae0dc61a5e838/lcb_runner/runner/parser.py#L93
# so the CPU budget here is a little more generous, and the wall-clock budget
# is larger again because a solution might be multithreaded.
_LIVECODEBENCH_EXEC_LIMITS = ExecLimits(
    timeout_cpu_s=10,
    timeout_realtime_s=20,
)

# Copy of the default Python docker environment with only the limits swapped.
LIVECODEBENCH_DOCKER_ENV = dataclasses.replace(
    PY_DEFAULT_DOCKER_ENV,
    default_limits=_LIVECODEBENCH_EXEC_LIMITS,
)


def _produce_livecode_ex(ex: LivecodeInternalCodeGenerationProblem) -> CodeProblem:
    """
    Produce a method-completion CodeProblem for a LiveCodeBench example.

    Two shapes of example are supported:

    * every test case is FUNCTIONAL: the signature is taken from the second
      line of the starter code (with ``self, `` removed) and each test case
      becomes a direct function-call assertion;
    * every test case is STDIN: the model completes the body of
      ``def main() -> None:`` and each test case drives it via stdin/stdout.

    Private test cases come first and are marked hidden; public test cases
    follow and are marked visible.

    Args:
        ex: A CodeGenerationProblem instance from the livecodebench dataset

    Returns:
        A CodeProblem instance

    Raises:
        ValueError: If the test cases are FUNCTIONAL but the starter code is
            not a short ``class Solution:`` stub, or if the test cases are
            neither all FUNCTIONAL nor all STDIN
    """
    # Take visibility from the list a case comes from rather than testing
    # `tc in ex.private_test_cases`: that membership test is a linear scan per
    # case and would flag a public case as hidden if it equals a private one.
    tagged_cases = [(tc, True) for tc in ex.private_test_cases]
    tagged_cases.extend((tc, False) for tc in ex.public_test_cases)

    test_cases = []
    if all(tc.testtype == TestType.FUNCTIONAL for tc, _ in tagged_cases):
        starter_lines = ex.starter_code.strip().split("\n")
        solution_start_good = starter_lines[0].startswith("class Solution:")
        # At least the class line plus the method line, at most four lines.
        solution_lines_good = 2 <= len(starter_lines) <= 4
        if not (solution_start_good and solution_lines_good):
            raise ValueError("Starter code must be a simple function/class definition:\n" + ex.starter_code)
        # The method is presented to the model as a free function.
        signature = starter_lines[1].strip().replace("self, ", "")
        func_name = ex.metadata["func_name"]
        for i, (tc, is_hidden) in enumerate(tagged_cases):
            test_cases.append(
                _create_func_call_tc(tc.input, tc.output, i, is_hidden, func_name)
            )
        override_instructions = None
    else:
        if not all(tc.testtype == TestType.STDIN for tc, _ in tagged_cases):
            raise ValueError(
                f"All test cases must be of type STDIN, got {set(tc.testtype for tc, _ in tagged_cases)}"
            )
        signature = "def main() -> None:"
        for i, (tc, is_hidden) in enumerate(tagged_cases):
            test_cases.append(
                _create_stdin_test_case(tc.input, tc.output, i, is_hidden)
            )
        override_instructions = (
            "Complete the following method. "
            "The solution requires reading from standard input and writing to standard output. "
            "You can do this with the input() and print() functions, respectively. "
            "Give only the solution body that can go inside `def main() -> None`. It is fine and recommended "
            "to have any needed imports inside the function. Do not include the `def main` function "
            "signature. Do not repeat the instructions docstring."
        )

    # Drop carriage returns so the prompt uses plain "\n" line endings.
    question_content = ex.question_content.replace("\r", "")
    return make_simple_method_completion_problem(
        prompt=f"{ex.question_title}\n\n{question_content}",
        signature=signature,
        test_cases=test_cases,
        override_instructions=override_instructions,
        dataset_name="livecodebench",
        execution_context_hidden_from_model=LIVECODEBENCH_CODE_ENV,
        problem_id=f"{ex.platform.value}_{ex.question_id}",
        known_solutions=[],  # No known solutions for now
        docker_env=LIVECODEBENCH_DOCKER_ENV,
    )


def _create_func_call_tc(
    input_data: str,
    expected_output: str,
    index: int,
    is_hidden: bool,
    func_name: str,
) -> TestCaseArbitraryMethod:
    """
    Create a test case for functional testing.

    The generated test body is a single ``assert func(args) == expected``
    statement. The call arguments come from ``_lcb_assertion_code_gen``.

    Args:
        input_data: The input data to provide to the function
        expected_output: The expected output from the function; decoded as
            JSON when possible, otherwise used verbatim
        index: Zero-based test case index; reported one-based in the fail message
        is_hidden: Whether this is a hidden test case
        func_name: The name of the function to call

    Returns:
        A TestCaseArbitraryMethod instance configured for functional testing
    """
    input_str = _lcb_assertion_code_gen(input_data)
    # The expected value is pasted into Python source. The upstream
    # LiveCodeBench grader decodes functional outputs with json.loads, so
    # JSON spellings such as `true`, `false` and `null` must become Python
    # literals or the assert raises NameError. Text that is not valid JSON
    # (or is not a string at all) is used verbatim, as before.
    try:
        expected_literal = repr(json.loads(expected_output))
    except (TypeError, ValueError):
        expected_literal = expected_output
    return TestCaseArbitraryMethod(
        is_hidden_test=is_hidden,
        fail_message=f"Test case {index+1} failed",
        method_body=f"assert {func_name}({input_str}) == {expected_literal}",
        lang_spec=PythonLangSpec(),
    )


def _lcb_assertion_code_gen(x: str) -> str:
    if type(x).__name__ != "str":
        raise ValueError(f"Unsupported type: {type(x).__name__}")
    return ", ".join(x.split("\n"))
    match str(type(x).__name__):
        case "str":
            y = x
        case "list":
            y = str(x)
        case "bool":
            y = str(x)
        case "int":
            y = str(x)
        case _:
            msg = "Unsupported type"
            raise Exception(msg, type(x))
    return y


def yield_livecode_problems(
    filter_difficulty: Difficulty | str = None,
    filter_test_type: TestType = None,
    version_tag: str = "release_v6",
    max_problems: int = None,
):
    """
    Lazily convert LiveCodeBench examples into CodeProblems.

    Args:
        filter_difficulty: Keep only examples of this difficulty. A string is
            converted with ``Difficulty(...)``. A falsy value disables the filter.
        filter_test_type: Keep only examples whose private and public test
            cases all have this test type. A falsy value disables the filter.
        version_tag: Passed to ``load_code_generation_dataset``.
        max_problems: Stop after this many problems have been produced. A
            falsy value (``None`` or ``0``) means no limit.

    Returns:
        An IteratorWithLength wrapping the generator, constructed with
        ``max_problems`` as its second argument.
    """
    if isinstance(filter_difficulty, str):
        filter_difficulty = Difficulty(filter_difficulty)

    def _passes_filters(example) -> bool:
        # Difficulty is checked first, then the test-type filter.
        if filter_difficulty and example.difficulty != filter_difficulty:
            return False
        if not filter_test_type:
            return True
        every_case = chain(example.private_test_cases, example.public_test_cases)
        return all(case.testtype == filter_test_type for case in every_case)

    def _generate():
        produced = 0
        for example in load_code_generation_dataset(version_tag):
            # The cap is checked before filtering, at the top of each iteration.
            if max_problems and produced >= max_problems:
                return
            if not _passes_filters(example):
                continue
            yield _produce_livecode_ex(example)
            produced += 1

    return IteratorWithLength(_generate(), max_problems)


def explore():
    """Tally the test types over every private and public test case and print the Counter."""
    type_counts = Counter(
        case.testtype
        for example in load_code_generation_dataset()
        for case in chain(example.private_test_cases, example.public_test_cases)
    )
    print(type_counts)


# When run as a script, print the test-case type counts for the dataset.
if __name__ == "__main__":
    explore()
