import ast
import logging
logger = logging.getLogger(__name__)
import math
from collections import defaultdict
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from timeit import default_timer as timer
from typing import Any

import diskcache
from lxml import etree

from synthegrator.code_problems import (
    CodeProblem,
    CodeSolution,
    DiscoveredTestsuite,
    TestCase,
)
from synthegrator.lang_specs.lang_spec_python import PythonLangSpec
from synthegrator.memory_fs import ProjectDir
from synthegrator.sandboxing import (
    PY_DEFAULT_DOCKER_ENV,
    ExecLimits,
    TestsuiteExecResult,
    pytest_on_docker,
    run_on_docker, CmdExecResult, Cmd,
)
from synthegrator.transformation_spec import (
    MarkElement,
    MarkText,
    SolveStep,
    StsCodeClassificationQuestion,
    StsMarkNode,
    StsPlaceTransforming,
    StsValueQuery,
    markup_path,
)
from synthegrator.util import (
    get_first_line_with_nonwhitespace,
    get_only_item,
    pretty_print_python_code,
)

# Module-wide on-disk cache of test execution results, backed by a diskcache
# store named "eval_cache" (presumably resolved relative to the current
# working directory -- TODO confirm against the diskcache documentation).
eval_cache = diskcache.Cache("eval_cache")
# TODO: make the cache location user-configurable.


@dataclass
class TestRunResult:
    success: bool
    output: Any
    completed: bool
    fail_message: str = None
    fail_type: str = None
    fail_text: str = None
    id_: str = None
    name: str = None


@dataclass
class TestRunSummary:
    """Store pytest result summary."""

    __test__ = False
    name: str
    errors: int
    failures: int
    skipped: int
    tests: int
    time: timedelta
    timestamp: datetime
    hostname: str
    valid_xml: bool = True


class TestRunResultList(list[TestRunResult]):
    """Per-test results of one evaluation, plus run-level metadata.

    The list items are the individual ``TestRunResult`` objects. The
    attributes describe the run as a whole:

    * ``code_problem``: the problem the results belong to.
    * ``summary``: suite-level counters, or ``None`` when none are available.
    * ``runtime``: duration of the evaluation, if it was measured.
    * ``syntax_check_result``: outcome of the syntax check, ``None`` if unknown.
    * ``collection_error``, ``exec_error``, ``timeout``: initialised to
      ``None``; meant to be filled in by whoever executed the tests.
    """

    # Not a pytest test class, despite the "Test" name prefix.
    __test__ = False
    METRIC_NAME: str = "test_passing"

    def __init__(
        self,
        code_problem: CodeProblem,
        summary: TestRunSummary | None = None,
        initial_list=None,
        meta_error: TestRunResult | None = None,
        runtime: float | None = None,
        syntax_check_result: bool | None = None,
    ) -> None:
        """Create the list, optionally pre-populated from ``initial_list``.

        Raises:
            NotImplementedError: if ``meta_error`` is given (not supported).
        """
        super().__init__(initial_list or [])

        self.code_problem = code_problem
        self.summary = summary
        self.runtime = runtime
        self.syntax_check_result = syntax_check_result
        self.collection_error = None
        self.exec_error = None
        self.timeout = None

        if meta_error is not None:
            raise NotImplementedError

    def passed(self) -> int:
        """Return the number of passed tests according to the summary.

        Computed as tests minus skipped, errors and failures; 0 when there is
        no summary.
        """
        if self.summary is None:
            return 0
        return (
            (self.summary.tests - self.summary.skipped) - self.summary.errors
        ) - self.summary.failures

    def all_passed(self) -> bool:
        """Return True when at least one test was recorded and all succeeded.

        The decision is based on the individual results held in the list, so
        an empty list (for example when execution was skipped) is never
        considered a pass.
        """
        return len(self) > 0 and all(t.success for t in self)

    def pass_rate(self) -> float:
        """Return the fraction of non-skipped tests that passed.

        Returns 0.0 when there is no summary, or when the syntax check did not
        succeed and the summary lacks its error/test counters.
        """
        if not self.summary:
            return 0.0
        if not self.syntax_check_result and (
            self.summary.errors is None or self.summary.tests is None
        ):
            return 0.0
        # max(..., 1) avoids dividing by zero when every test was skipped.
        return self.passed() / max(
            (self.summary.tests - self.summary.skipped),
            1,
        )

    def as_metric(self) -> "SolutionMetric":
        """Express these results as a ``SolutionMetric`` labelled METRIC_NAME."""
        return SolutionMetric(
            is_success=self.all_passed(),
            float_val=self.pass_rate(),
            label=self.METRIC_NAME,
        )

    def __str__(self) -> str:
        """Render the summary as a two-column text table."""
        if not self.summary:
            if self.syntax_check_result is False:
                return "Syntax check failed!"
            return "No summary available!"
        rows = [
            ("Attribute", "Value"),  # header
            ("-" * 15, "-" * 15),  # separating line
            ("Name", self.summary.name),
            ("Errors", self.summary.errors),
            ("Failures", self.summary.failures),
            ("Skipped", self.summary.skipped),
            ("Tests", self.summary.tests),
            ("Time", self.summary.time),
            ("Timestamp", self.summary.timestamp),
            ("Hostname", self.summary.hostname),
        ]
        # "!s" converts every value to text first, so a missing (None) name or
        # hostname is printed instead of raising a formatting TypeError.
        return "".join(f"{label:<15} | {value!s:<15}\n" for label, value in rows)


@dataclass(frozen=True)
class SolutionMetric:
    """A single labelled measurement of a solution; immutable once created."""

    is_success: bool | None
    """Whether this result is considered a success. If there is no definition
    of success (for example a certain edit-distance or CodeBLEU), then this
    might be None"""
    float_val: float
    """A float representation of the metric"""
    label: str
    """The kind of this metric"""


@dataclass(frozen=True)
class SolutionEvaluation:
    """Immutable record of everything measured for one solution."""

    # The solution that was evaluated.
    solution: CodeSolution
    # Test execution results, or None when no tests were run.
    test_results: TestRunResultList | None
    # Additional metrics beyond the main one.
    extra_metrics: list[SolutionMetric]
    # The headline metric for this evaluation, if there is one.
    main_metric: SolutionMetric | None
    # Unexpected exception that interrupted the evaluation, if any.
    exception: Exception | None = None
    # Formatted traceback belonging to ``exception``, if any.
    exception_traceback: str | None = None

    @property
    def has_exception(self) -> bool:
        """Whether this evaluation did not complete due to an unexpected exception."""
        return not (self.exception is None)

    def get_metric_from_key(self, key: str) -> SolutionMetric | None:
        """Return the metric labelled ``key``, or None when there is none.

        The main metric is checked first, then the extra metrics in order.
        """
        main = self.main_metric
        if main and main.label == key:
            return main
        return next(
            (candidate for candidate in self.extra_metrics if candidate.label == key),
            None,
        )


# Signature of a pluggable metric: it receives a solution and returns either a
# single SolutionMetric or a sequence of them.
MetricFunc = Callable[[CodeSolution], SolutionMetric | Sequence[SolutionMetric]]


def exact_match_metric(solution: CodeSolution, strip: bool = False) -> SolutionMetric:
    """Check whether a solution is identical to one of the known solutions.

    A solution is reduced to a mapping from ``(path, mark_id)`` to the step
    value. It matches when that mapping equals the mapping of at least one
    known solution of the problem. Known solutions are compared one by one
    rather than merged, so alternatives that fill the same marks differently
    cannot overwrite each other.

    Args:
        solution: the solution to score.
        strip: when True, surrounding whitespace of every value is ignored
            (values are expected to be strings in that case).

    Returns:
        A metric labelled ``"exact_match"`` whose float value is 1.0 on a
        match and 0.0 otherwise.
    """

    def _marks(steps) -> dict:
        # Map each filled mark to its value, optionally whitespace-normalised.
        marks = {(step.path, step.mark_id): step.value for step in steps}
        if strip:
            marks = {key: value.strip() for key, value in marks.items()}
        return marks

    solve_marks = _marks(solution.solve_steps)

    # A problem may carry no known solutions at all (None or an empty list).
    known_solutions = solution.problem.known_solutions or []
    # With nothing to compare against, only a solution without any steps
    # counts as a (vacuous) match -- hence the single empty reference.
    references = [_marks(known.solve_steps) for known in known_solutions] or [{}]

    exact_match = any(solve_marks == reference for reference in references)

    return SolutionMetric(
        is_success=exact_match,
        float_val=1.0 if exact_match else 0.0,
        label="exact_match",
    )


def create_exact_match_metric(strip=True) -> Callable:
    """Build a one-argument metric function around ``exact_match_metric``.

    Args:
        strip: forwarded to ``exact_match_metric`` on every call.

    Returns:
        A callable taking a solution and returning its exact-match metric.
    """

    def _metric(solution):
        return exact_match_metric(solution, strip=strip)

    return _metric


class SolutionEvaluator:
    """
    Used to evaluate a solution made by a solver. It is configurable (and
    inheritable) in case you want to try other metrics.
    """

    def __init__(
        self,
        metric_funcs: tuple[MetricFunc, ...] | None = None,
        do_not_execute_if_syntax_fail: bool = True,
        use_execution_cache: bool = True,
    ):
        """Configure the evaluator.

        Args:
            metric_funcs: extra metric functions applied to every solution.
            do_not_execute_if_syntax_fail: skip test execution when the
                solution fails the syntax check.
            use_execution_cache: reuse cached execution results when possible.
        """
        self._extra_metric_funcs = metric_funcs
        self._do_not_execute_if_syntax_fail = do_not_execute_if_syntax_fail
        self._use_execution_cache = use_execution_cache

    @classmethod
    def default_instance(cls):
        """Return the shared module-level evaluator with default settings."""
        return _solution_evaluator_default

    def evaluate(self, solution: CodeSolution) -> SolutionEvaluation:
        """Evaluate ``solution`` and collect all of its metrics.

        Tests are executed only when the problem defines test cases.
        Classification metrics are added when the transformation spec contains
        classification questions, followed by the configured extra metrics.

        The main metric is the test-passing metric when tests were run,
        otherwise the first extra metric, or ``None`` when there is none.
        """
        if solution.problem.test_cases:
            test_results = evaluate_code_problem_execution(
                solution.problem,
                solution,
                do_not_execute_if_syntax_fail=self._do_not_execute_if_syntax_fail,
                try_cache=self._use_execution_cache,
            )
        else:
            test_results = None

        spec = solution.problem.transformation_spec
        if len(list(spec.classification_question_statements())) > 0:
            extra_metrics = list(classification_metric_function(solution))
        else:
            extra_metrics = []

        for metric_func in self._extra_metric_funcs or ():
            metric = metric_func(solution)
            # A metric function may return one metric or a sequence of them.
            if isinstance(metric, SolutionMetric):
                extra_metrics.append(metric)
            else:
                extra_metrics.extend(metric)

        if test_results is not None:
            main_metric = test_results.as_metric()
        elif extra_metrics:
            main_metric = extra_metrics[0]
        else:
            # Nothing was measured; do not fail on indexing an empty list.
            main_metric = None

        return SolutionEvaluation(
            solution,
            test_results,
            extra_metrics,
            main_metric,
        )


# Shared evaluator with default settings; handed out by
# SolutionEvaluator.default_instance().
_solution_evaluator_default = SolutionEvaluator(None)


def classification_metric_function(
    solution: CodeSolution,
) -> Sequence[SolutionMetric]:
    """Score a classification answer against the first known solution.

    Both the chosen answer and the reference answer are converted to
    probability dictionaries via the classification verb.

    Returns:
        An empty list when the problem has no classification question or no
        known solution. Otherwise three metrics, all sharing the same
        ``is_success`` flag (whether the most probable chosen label equals the
        most probable reference label):

        * ``classification_correct_prob``: probability given to that label,
        * ``classification_correct``: 1.0 or 0.0,
        * ``classification_cross_entropy``: cross entropy of the chosen
          distribution relative to the reference distribution.

    Raises:
        NotImplementedError: if there is more than one classification question.
        TypeError: if the question's verb is not a classification question.
        ValueError: if the classification question is not static.
    """
    classifications = list(
        solution.problem.transformation_spec.classification_question_statements(),
    )
    if not classifications:
        return []
    if len(classifications) > 1:
        msg = "Multiple classification questions not supported"
        raise NotImplementedError(msg)
    classification = classifications[0].verb
    # Explicit check instead of an assert, which is stripped under ``-O``.
    if not isinstance(classification, StsCodeClassificationQuestion):
        msg = "Classification verb is not of StsCodeClassificationQuestion type"
        raise TypeError(msg)
    if not classification.operation_is_static():
        msg = "Classification question must be static"
        raise ValueError(msg)

    chosen_answer = get_only_item(
        [
            solve_step
            for solve_step in solution.solve_steps
            if solve_step.mark_id == classification.static_mark_id
        ],
    ).value
    chosen_answer = classification.value_to_prob_dict(chosen_answer)

    # Covers both ``None`` and an empty list of known solutions.
    if not solution.problem.known_solutions:
        return []
    solution_answer = get_only_item(
        [
            solve_step
            for solve_step in solution.problem.known_solutions[0].solve_steps
            if solve_step.mark_id == classification.static_mark_id
        ],
    ).value
    solution_answer = classification.value_to_prob_dict(solution_answer)

    # Find the key with the highest probability on each side.
    chosen_answer_key = max(chosen_answer, key=chosen_answer.get)
    solution_answer_key = max(solution_answer, key=solution_answer.get)
    match_correct = chosen_answer_key == solution_answer_key

    # Probability the chosen distribution assigns to the reference label.
    correct_answer_prob = chosen_answer[solution_answer_key]
    # The small epsilon keeps log() defined for zero probabilities.
    cross_entropy = -sum(
        solution_answer[key] * math.log(chosen_answer[key] + 1e-10)
        for key in solution_answer
    )

    return [
        SolutionMetric(
            is_success=match_correct,
            float_val=correct_answer_prob,
            label="classification_correct_prob",
        ),
        SolutionMetric(
            is_success=match_correct,
            float_val=1.0 if match_correct else 0.0,
            label="classification_correct",
        ),
        SolutionMetric(
            is_success=match_correct,
            float_val=cross_entropy,
            label="classification_cross_entropy",
        ),
    ]


def evaluate_code_problem_execution(
    code_problem: CodeProblem,
    solution: CodeSolution,
    user_prompted_guard: bool = False,
    do_not_execute_if_syntax_fail: bool = True,
    try_cache: bool = False,
) -> TestRunResultList:
    """
    Run the problem's test cases against a solution and collect the results.

    The solution is first syntax-checked; when that fails and
    ``do_not_execute_if_syntax_fail`` is set, an empty result list carrying
    only the syntax check outcome is returned. Otherwise the tests are run,
    either as a single discovered test suite or as self-contained test cases,
    and the resulting JUnit XML is parsed.

    Raises:
        NotImplementedError: if a discovered suite is mixed with other tests.
        TypeError: if the single test case is not a DiscoveredTestsuite.
    """
    started_at = timer()
    syntax_ok = is_solution_syntax_likely_valid(solution)
    if do_not_execute_if_syntax_fail and not syntax_ok:
        return TestRunResultList(code_problem, syntax_check_result=syntax_ok)

    has_discovered_suite = any(
        isinstance(tc, DiscoveredTestsuite) for tc in code_problem.test_cases
    )
    if not has_discovered_suite:
        exec_result = _evaluate_self_contained_problem(
            solution,
            code_problem.test_cases,
            user_prompted_guard,
            try_cache=try_cache,
        )
    else:
        if len(code_problem.test_cases) != 1:
            msg = (
                "Currently does not support mixing discovered test cases with"
                " other test cases"
            )
            raise NotImplementedError(msg)
        suite = code_problem.test_cases[0]
        if not isinstance(suite, DiscoveredTestsuite):
            msg = "Test case is not of DiscoveredTestsuite type"
            raise TypeError(msg)
        exec_result = _evaluate_suite_problem(
            solution,
            suite,
            user_prompted_guard,
            try_cache=try_cache,
        )

    outcome = parse_junit_test_cases(exec_result.xml_result, code_problem)
    # Attach run-level information that the XML report does not carry.
    outcome.syntax_check_result = syntax_ok
    outcome.runtime = timer() - started_at
    outcome.collection_error = exec_result.collection_error
    outcome.exec_error = exec_result.exec_error
    suite_exec = exec_result.test_suite_exec_result
    if suite_exec is not None:
        outcome.timeout = exec_result.test_suite_exec_result.timeout
    return outcome


def is_solution_syntax_likely_valid(solution: CodeSolution) -> bool:
    """Heuristically check that every file changed by the solution parses.

    The solution is applied and each modified file is checked with the
    language spec the problem associates with its path. Every file is checked
    (no early exit); the result is True only when all of them pass.
    """
    patched_files = apply_solution(solution)
    verdicts = [
        solution.problem.get_lang_spec_for_path(path).check_no_syntax_errors(
            file.content_str,
        )
        for path, file in patched_files.walk(
            include_dirs=False,
            only_consider_dirty=True,
        )
    ]
    return all(verdicts)


def _evaluate_suite_problem(
    solution: CodeSolution,
    discovered_suite: DiscoveredTestsuite,
    user_prompted_guard: bool,
    try_cache: bool = False,
) -> TestsuiteExecResult:
    """Run a discovered test suite against the solved files inside docker.

    Args:
        solution: the solution whose files are tested.
        discovered_suite: the suite providing the commands to run and the
            conversion of their outputs into a suite result.
        user_prompted_guard: must be False; interactive confirmation is not
            supported for suite problems.
        try_cache: look up / store the result in the module-level cache.

    Raises:
        ValueError: if ``user_prompted_guard`` is set.
    """
    if user_prompted_guard:
        msg = "User prompted guard not supported for suite problems"
        raise ValueError(msg)
    new_files = apply_solution(solution)

    problem = solution.problem
    docker_env = problem.environment.docker_env or PY_DEFAULT_DOCKER_ENV
    limits = docker_env.default_limits or ExecLimits()

    cache_key = None
    if try_cache:
        # The key covers files, limits, environment and suite; its format is
        # kept unchanged so previously cached entries remain usable.
        cache_key = (
            f"{new_files.get_hash()}{limits!s}{docker_env.get_hash()}"
            f"{discovered_suite.get_hash()}"
        )
        cached_result = eval_cache.get(cache_key)
        # Compare against None: a miss is signalled by None, not falsiness.
        if cached_result is not None:
            return cached_result

    cmd_outputs = run_on_docker(
        docker_context=docker_env,
        cmds=discovered_suite.cmds,
        files=new_files,
        limits=limits,
        interactive_shell_debug=False,
    )
    result = discovered_suite.cmds_results_to_suite(cmd_outputs)

    if try_cache:
        eval_cache[cache_key] = result

    return result


def _evaluate_self_contained_problem(
    solution: CodeSolution,
    test_cases: list[TestCase],
    user_prompted_guard: bool,
    try_cache: bool = False,
) -> TestsuiteExecResult:
    """Append the test cases to the solved file and run pytest on it in docker.

    The tests are appended to the file edited by the first solve step, which
    is then executed as the pytest target.

    Args:
        solution: the solution whose files are tested.
        test_cases: test cases rendered into pytest source.
        user_prompted_guard: when True, print the code and ask for
            confirmation on stdin before running it.
        try_cache: look up / store the result in the module-level cache.

    Raises:
        ValueError: if the user declines to run the code.
    """
    new_files = apply_solution(solution)
    code_problem = solution.problem
    edit_path = solution.solve_steps[0].path
    all_tests_src = "\n\n".join(
        test_case.format_for_test_framework(lang_spec=PythonLangSpec())
        for test_case in test_cases
    )
    runnable_code = new_files[edit_path].content_str  # TODO: avoid redecoding somehow
    runnable_code_with_tests = runnable_code + "\n\nimport pytest\n" + all_tests_src
    if user_prompted_guard:
        pretty_print_python_code(runnable_code_with_tests)
        if input("Run? [y/N] ").lower() != "y":
            msg = "User aborted"
            raise ValueError(msg)
    new_files = new_files.set_file_contents(edit_path, runnable_code_with_tests)
    use_docker_env = code_problem.environment.docker_env or PY_DEFAULT_DOCKER_ENV

    cache_key = None
    if try_cache:
        # Key format kept unchanged so previously cached entries stay valid.
        cache_key = f"{new_files.get_hash()}{edit_path}{use_docker_env.get_hash()}"
        cached_result = eval_cache.get(cache_key)
        # Compare against None: a miss is signalled by None, not falsiness.
        if cached_result is not None:
            return cached_result

    result = pytest_on_docker(
        files=new_files,
        run_file=edit_path,
        docker_env=use_docker_env,
    )
    if try_cache:
        eval_cache[cache_key] = result
    return result


def _group_steps_by_path(
    steps: Sequence[SolveStep],
) -> dict[str, list[SolveStep]]:
    """Bucket solve steps by the path they apply to, preserving step order."""
    grouped = defaultdict(list)
    for solve_step in steps:
        bucket = grouped[solve_step.path]
        bucket.append(solve_step)
    return grouped


def apply_steps_to_markup(
    content_node: StsMarkNode,
    steps: Sequence[SolveStep],
    still_hide_invisible: bool = False,
    stop_after_first_change: bool = False,
) -> str:
    """Render marked-up content with the solve steps substituted in.

    The markup is walked state by state. A mark element whose id has a step
    is replaced by that step's value (after validation) and its children are
    not visited; plain text is copied through.

    Args:
        content_node: root of the markup to render.
        steps: steps to apply; each ``mark_id`` may appear only once.
        still_hide_invisible: when True, text of non-visible states is omitted.
        stop_after_first_change: when True, rendering stops right after the
            first substituted value.

    Raises:
        ValueError: if two steps share a ``mark_id``.
    """
    steps_by_mark = {step.mark_id: step for step in steps}
    if len(steps_by_mark) != len(steps):
        msg = "The same mark_id was used twice in a solution"
        raise ValueError(msg)

    pieces = []
    walker = content_node.iterate_parser_states()
    state = next(walker, None)
    while state:
        node = state.node
        descend = True
        if isinstance(node, MarkElement) and node.mark_id in steps_by_mark:
            replacement = steps_by_mark[node.mark_id].value
            _verify_solution_value_is_valid_place_transforming(node, replacement)
            pieces.append(replacement)
            if stop_after_first_change:
                break
            # The replaced element's own content must not be emitted as well.
            descend = False
        elif isinstance(node, MarkText):
            if not still_hide_invisible or state.is_visible:
                pieces.append(node.text)
        try:
            # The sent flag presumably tells the walker whether to descend
            # into the current node's children.
            state = walker.send(descend)
        except StopIteration:
            state = None
    return "".join(pieces)


def make_mark_id_to_node(
    root_node: StsMarkNode,
) -> dict[str, MarkElement]:
    """Index every mark element under ``root_node`` by its ``mark_id``.

    The tree is walked depth first, always continuing into children. When a
    ``mark_id`` occurs more than once, the element visited last wins.
    """
    nodes_by_mark = {}
    walker = root_node.depth_first_iterate()
    current = next(walker, None)
    while current:
        if isinstance(current, MarkElement):
            nodes_by_mark[current.mark_id] = current
        try:
            current = walker.send(True)
        except StopIteration:
            current = None
    return nodes_by_mark


def _verify_solution_value_is_valid_place_transforming(node, value):
    """Validate a replacement ``value`` against the constraints of ``node``.

    The node's verb must be an ``StsPlaceTransforming``; this is checked
    before any of the verb's constraint attributes are read. Line count,
    ending text and base indentation are then verified.

    Raises:
        TypeError: if the node's verb is not an ``StsPlaceTransforming``.
        NotImplementedError: if the verb sets ``max_length_chars``.
        ValueError: if the value violates one of the checked constraints.
    """
    if not isinstance(node.verb, StsPlaceTransforming):
        msg = "Node verb is not of StsPlaceTransforming type"
        raise TypeError(msg)

    if node.verb.max_length_chars is not None:
        raise NotImplementedError

    _verify_max_length_lines(node, value)
    _verify_end_text(node, value)
    _verify_space_indent(node, value)
    # TODO: verify that the stop_at_block_end flag is respected


def _verify_max_length_lines(node, value):
    """Raise ValueError if ``value`` has more lines than the verb allows.

    No check is made when the verb's ``max_length_lines`` is None.
    """
    limit = node.verb.max_length_lines
    if limit is None:
        return
    line_count = len(value.splitlines())
    if line_count > limit:
        msg = f"Too many lines in the solution for {node.mark_id}"
        raise ValueError(msg)


def _verify_space_indent(node, value):
    """Raise ValueError if ``value`` lacks the required base indentation.

    Only applies when the verb sets ``set_base_space_indent``. The first line
    containing non-whitespace must then start with that many spaces; a blank
    first line is accepted.
    """
    required_indent = node.verb.set_base_space_indent
    if not required_indent:
        return
    first_line = get_first_line_with_nonwhitespace(value)
    if first_line.strip() == "":
        return
    if first_line.startswith(" " * required_indent):
        return
    msg = (
        f"Solution for {node.mark_id} must start with"
        f" {required_indent} spaces.\n"
        " Value:\n"
        f" {value!r}"
    )
    raise ValueError(msg)


def _verify_end_text(node, value):
    """Raise ValueError if ``value`` does not end with the required text.

    No check is made when the verb's ``set_ending_text`` is empty or None.
    """
    required_suffix = node.verb.set_ending_text
    if not required_suffix or value.endswith(required_suffix):
        return
    # Show newlines in escaped form so the message stays on one line.
    escaped_expected = node.verb.set_ending_text.replace("\n", "\\n")
    msg = f"Solution for {node.mark_id} must end with '{escaped_expected}'"
    raise ValueError(msg)


def apply_solution(
    solution: CodeSolution,
    still_hide_invisible: bool = False,
    stop_after_first_change: bool = False,
) -> ProjectDir:
    """Return the working-directory files with the solution's steps applied.

    Steps are grouped per path; each affected file is re-rendered from its
    markup with the steps substituted and written into the returned files.
    The two flags are forwarded unchanged to ``apply_steps_to_markup``.
    """
    problem = solution.problem
    workdir = problem.working_directory
    updated_files = workdir.files
    steps_per_path = _group_steps_by_path(solution.solve_steps)
    for path, path_steps in steps_per_path.items():
        rendered = apply_steps_to_markup(
            markup_path(workdir, path, problem.transformation_spec),
            path_steps,
            still_hide_invisible=still_hide_invisible,
            stop_after_first_change=stop_after_first_change,
        )
        updated_files = updated_files.set_file_contents(path, rendered)
    return updated_files


# Raw bytes of the JUnit XSD that lives in the same directory as this module;
# read once, when the module is imported.
junit_xml_schema = (Path(__file__) / "../JUnit.xsd").resolve().read_bytes()


def parse_junit_test_cases(
    xml_text: str | None,
    problem: CodeProblem,
) -> TestRunResultList:
    """Parse a JUnit XML report into a ``TestRunResultList``.

    The report is first validated against the bundled JUnit schema; a
    validation failure is only recorded (``summary.valid_xml`` is False) and
    logged, and the report is still parsed leniently afterwards.

    Args:
        xml_text: the JUnit XML report. Empty or missing text yields an empty
            result list without a summary.
        problem: the problem the results belong to.

    Returns:
        The per-test results with the suite summary attached.

    Raises:
        ValueError: if no unique ``testsuite`` element can be located.
        lxml.etree.XMLSyntaxError: if the text is not well-formed XML.
    """
    results = TestRunResultList(problem)
    # ``not xml_text`` also covers None, which would otherwise fail on encode.
    if not xml_text:
        logger.debug("EMPTY XML!!!!")
        return results
    schema_root = etree.XML(junit_xml_schema)
    schema = etree.XMLSchema(schema_root)

    strict_parser = etree.XMLParser(schema=schema, dtd_validation=False)
    # lxml rejects str input that carries an encoding declaration
    # ("Unicode strings with encoding declaration are not supported"), so the
    # text is handed over as bytes.
    xml_bytes = xml_text.encode(encoding="utf-8")
    try:
        etree.fromstring(xml_bytes, strict_parser)
        xml_valid = True
    except etree.XMLSyntaxError:
        xml_valid = False
        logger.warning("Error! JUnit test result is not valid XML")

    junit_xml = etree.fromstring(xml_bytes)
    testsuite = junit_xml.find("testsuite")
    if testsuite is None:
        # Deal with the java junit output where the root is just the testsuite
        #   and there is no namespace for the xml
        matches = junit_xml.xpath("//*[local-name() = 'testsuite']")
        # Explicit check instead of an assert, which is stripped under ``-O``.
        if len(matches) != 1:
            msg = (
                "Expected exactly one testsuite element in the JUnit XML,"
                f" found {len(matches)}"
            )
            raise ValueError(msg)
        testsuite = matches[0]

    raw_timestamp = testsuite.get("timestamp", None)
    timestamp = (
        datetime.fromisoformat(raw_timestamp) if raw_timestamp is not None else None
    )
    results.summary = TestRunSummary(
        name=testsuite.get("name"),
        errors=int(testsuite.get("errors", 0)),
        failures=int(testsuite.get("failures", 0)),
        skipped=int(testsuite.get("skipped", 0)),
        tests=int(testsuite.get("tests")),
        time=timedelta(seconds=float(testsuite.get("time", 0))),
        timestamp=timestamp,
        hostname=testsuite.get("hostname", ""),
        valid_xml=xml_valid,
    )

    for case in junit_xml.iter("testcase"):
        # A test did not pass if it has a <failure> or, failing that, an
        # <error> child element.
        fail_val = case.find("failure")
        if fail_val is None:
            fail_val = case.find("error")

        if fail_val is None:
            fail_message = fail_type = fail_text = None
        else:
            fail_message = fail_val.attrib.get("message", None)
            fail_type = fail_val.get("type", None)
            fail_text = fail_val.text

        results.append(
            TestRunResult(
                success=fail_val is None,
                output=None,
                completed=True,
                fail_message=fail_message,
                fail_type=fail_type,
                fail_text=fail_text,
                id_=case.get("id", None),
                name=case.get("name", None),
            ),
        )
    return results


def pull_value_answers_from_solution(
    solution: CodeSolution,
) -> list[tuple[MarkElement[StsValueQuery], str]]:
    """Collect the answers a solution gives to value-query marks.

    For every path touched by the solution, the markup is indexed by mark id
    and each step is matched to its mark element. A step counts as a value
    answer when the verb of its mark element is an ``StsValueQuery`` (the
    step's own value is the answer, not the query). Steps whose mark id is not
    present in the markup are skipped.

    Returns:
        ``(mark element, answer value)`` pairs in step order.
    """
    problem = solution.problem
    path_to_steps = _group_steps_by_path(solution.solve_steps)
    wd = problem.working_directory
    values = []
    for path, steps in path_to_steps.items():
        markup_node_root = markup_path(wd, path, problem.transformation_spec)
        mark_id_to_node = make_mark_id_to_node(markup_node_root)
        for step in steps:
            node = mark_id_to_node.get(step.mark_id)
            if node is not None and isinstance(node.verb, StsValueQuery):
                values.append((node, step.value))
    return values
