"""
Various utilities for rendering code problems into prompt strings.
These are intended as fairly standard versions. Certain models might
choose to do some form of specialized conversions.
"""

import dataclasses
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable

from lmwrapper.abstract_predictor import LmPredictor
from lmwrapper.structs import ChatGptRoles, LmChatDialog, LmChatTurn, LmPrompt

from synthegrator.code_problems import (
    CodeProblem,
    TestCase,
    TestCaseMethodCallIsEq,
)
from synthegrator.few_shotting import FewShotConfig
from synthegrator.lang_specs.lang_spec_python import PythonLangSpec
from synthegrator.prompting_test_case_selection import (
    PromptingTestCaseSelectionStrategy,
)
from synthegrator.transformation_spec import (
    MarkElement,
    MarkText,
    StsPlaceTransforming,
    get_verbs_per_path,
    markup_path,
)
from synthegrator.util import ensure_newline_at_end


@dataclasses.dataclass(frozen=True)
class LmPromptRender:
    """Base result type returned by ``PromptRenderer.render``: holds the
    ``LmPrompt`` to send to the language model."""

    # The prompt (text or chat dialog) that will be given to the model.
    prompt: LmPrompt


class ProblemNotSupportedError(ValueError):
    """
    Signals that a PromptRenderer's ``render`` was handed a problem it is not
    able to render. Subclasses ValueError so existing ``except ValueError``
    handlers also catch it.
    """


@dataclasses.dataclass(frozen=True)
class LmPromptRenderSingleEdit(LmPromptRender):
    """A rendered prompt for a problem with a single edit point, together with
    the information needed to turn the model's output into a solve step."""

    code_prefix: str
    """The source code given to the model that is a prefix to any output
    solution. This is useful because dialogue models sometimes repeat part of
    the prompt in their output, so it can be used to extract just the new
    contents."""
    path: str
    """The path of the file that is being edited. Needed when making
    the solve step."""
    changed_element: MarkElement[StsPlaceTransforming] | None
    """The element that is changed by this edit. Needed when making the
    solve step."""


@dataclasses.dataclass(frozen=True)
class SingleLineBugRender(LmPromptRender):
    """Render produced for code-repair prompts, where the buggy region is
    wrapped in ``<buggy>`` tags and the model is asked for a fixed version.

    Has the same fields, in the same order, as ``LmPromptRenderSingleEdit``,
    but is not a subclass of it."""

    code_prefix: str
    """The code text placed at the start of the prompt, before the repair
    question. For the repair renderer this is the full visible file text
    (including the ``<buggy>`` tags), not only a prefix of the edit point."""
    path: str
    """The path of the file that is being edited. Needed when making
    the solve step."""
    changed_element: MarkElement[StsPlaceTransforming] | None
    """The element that is changed by this edit. Needed when making the
    solve step."""


class PromptRenderer(ABC):
    """
    Abstract interface for turning a ``CodeProblem`` into an ``LmPromptRender``.

    Subclasses implement ``render``. Instances are callable; calling one is
    the same as calling ``render`` with the same arguments.
    """

    def __init__(self):
        """No state is held at this level; subclasses call it via super()."""

    def __call__(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmPromptRender:
        """Shorthand for ``self.render(problem, lm, prompt_seed)``."""
        return self.render(problem, lm, prompt_seed)

    @abstractmethod
    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmPromptRender:
        """
        Renders a code problem into a prompt for a language model.
        :param problem: The code problem to render
        :param lm: The language model the prompt is for. Lets a renderer decide,
            for example, whether to build a dialog and how much to generate.
        :param prompt_seed: Seed for any randomization in the rendering (such
            as picking few-shot examples).
        """
        raise NotImplementedError


class PromptRendererSingleEdit(PromptRenderer):
    """
    A PromptRenderer whose renders are ``LmPromptRenderSingleEdit`` objects.

    ``render`` is overridden here without ``@abstractmethod``, so this class
    (unlike ``PromptRenderer``) can be instantiated. The missing
    implementation only surfaces as ``NotImplementedError`` when ``render`` is
    called. Subclasses are expected to override it.
    """

    def __init__(self):
        super().__init__()

    def __call__(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmPromptRenderSingleEdit:
        """Forward to ``render``; re-declared to narrow the return type."""
        return self.render(problem, lm, prompt_seed)

    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmPromptRenderSingleEdit:
        """Not implemented at this level; subclasses must provide it."""
        raise NotImplementedError


# Callable type with the same shape as ``PromptRendererSingleEdit.__call__``:
# (problem, lm, prompt_seed) -> LmPromptRenderSingleEdit.
RenderFuncSingleEdit = Callable[
    [CodeProblem, LmPredictor, int | None], LmPromptRenderSingleEdit
]


def render_markup_to_text_autoregressive(
    markup: MarkElement,
    path: str,
    problem: CodeProblem,
    include_instructions: bool | None = None,
) -> LmPromptRenderSingleEdit:
    """
    Renders markup up to the first edit/insert statement. Masks any invisible text.

    :param markup: The marked-up file contents to render.
    :param path: Path of the file ``markup`` was built from. Used to look up the
        language spec (for its comment syntax) and stored on the result.
    :param problem: The problem being rendered; supplies the instructions.
    :param include_instructions: Whether to prepend the problem's instructions
        as a comment block. If None, they are included only when
        ``problem.instructions_are_essential`` is set and the instructions are
        non-empty.
    :return: A render whose prompt is the optional instruction comment block
        followed by the visible text that precedes the first editable element.
        ``code_prefix`` is that visible text alone. ``changed_element`` is the
        first editable element, or None if there is none.
    """
    text = []
    final_instructions = ""
    if include_instructions is None:
        include_instructions = (
            problem.instructions_are_essential and problem.instructions
        )
    if include_instructions:
        comment = problem.get_lang_spec_for_path(path).get_comment_line_start()
        # Prefix *every* instruction line, including the first, with a comment
        # marker. (Joining with the marker as a separator would leave the
        # first line uncommented, so it would read as code.)
        commented_lines = "".join(
            f"{comment} {line}"
            for line in problem.instructions.splitlines(keepends=True)
        )
        # Terminate the last line exactly once (no stray blank line when the
        # instructions already end in a newline).
        if not commented_lines.endswith("\n"):
            commented_lines += "\n"
        final_instructions = (
            f"{comment * 2} Instructions\n"
            + commented_lines
            + f"{comment * 2}\n"
        )

    def _finish_up(node: MarkElement[StsPlaceTransforming] | None):
        joined_text = "".join(text)
        stop = []
        if node is not None and node.verb.max_length_lines == 1:
            # Single-line edits stop at the end of the line.
            stop = ["\n"]
        if node is not None and node.verb.stop_at_block_end:
            # Extra stops like in https://github.com/bigcode-project/bigcode-evaluation-harness/blob/136b93c0aea/lm_eval/tasks/humaneval.py#L54
            # TODO multilingual
            if node.verb.lang_spec_name == "python":
                # TODO make it work when not starting at the 0 block
                # TODO move this to autoregressive renderer
                stop.extend(["\nclass", "\ndef", "\n#"])
        if len(stop) == 0:
            stop = None

        prompt = LmPrompt(
            final_instructions + joined_text,
            stop=stop,
        )
        return LmPromptRenderSingleEdit(
            prompt,
            joined_text,
            path,
            node,
        )

    for state in markup.iterate_parser_states():
        if isinstance(state.node, MarkText):
            if state.is_visible:
                text.append(state.node.text)
        elif isinstance(state.node, MarkElement) and state.is_editable:
            # Stop at the first editable element; the model generates from here.
            return _finish_up(state.node)
    return _finish_up(None)


def render_markup_to_plain_text(
    markup: MarkElement,
) -> str:
    """
    Concatenate the text of every visible ``MarkText`` node in ``markup``.

    Markup elements contribute no text of their own, and invisible text is
    dropped.
    """
    return "".join(
        state.node.text
        for state in markup.iterate_parser_states()
        if isinstance(state.node, MarkText) and state.is_visible
    )


def render_single_edit_codeproblem_autoregressive(
    problem: CodeProblem,
    prompt_seed: int | None = None,
) -> LmPromptRenderSingleEdit:
    """
    Render a code problem as a single-file autoregressive (completion) prompt.

    The problem's transformation spec must contain exactly one editing
    statement, and its verbs must all be on exactly one path. The prompt is
    that file's visible text up to the editable element, as produced by
    ``render_markup_to_text_autoregressive``.

    :param problem: The problem to render.
    :param prompt_seed: Accepted for signature compatibility; not used here.
    :raises NotImplementedError: if the spec does not have exactly one
        editing statement.
    :raises ValueError: if the spec's verbs are not on exactly one path.
    :raises RuntimeError: if the rendered markup has no editable element.
    """
    spec = problem.transformation_spec
    if spec.count_editing_statements() != 1:
        msg = "Currently simple and assumes a exactly one edit statement"
        raise NotImplementedError(msg)

    verbs_by_path = get_verbs_per_path(problem.working_directory, spec)
    if len(verbs_by_path) != 1:
        msg = "Currently simple and assumes a single path"
        raise ValueError(msg)
    only_path = verbs_by_path.popitem()[0]

    rendered = render_markup_to_text_autoregressive(
        markup_path(
            problem.working_directory,
            only_path,
            problem.transformation_spec,
        ),
        only_path,
        problem,
    )
    if rendered.changed_element is not None:
        return rendered
    msg = "No insert element found in the markup"
    raise RuntimeError(msg)


def render_single_edit_code_repair_problem_autoregressive(
    problem: CodeProblem,
    prompt_seed: int | None = None,
    prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy | None = None,
) -> SingleLineBugRender:
    """
    Render a single-edit code repair problem as a completion prompt.

    The whole visible text of the (single) marked-up file is emitted, with the
    editable region wrapped in ``<buggy>`` / ``</buggy>`` tags. It is followed
    by a question asking for the corrected code, and the prompt ends with an
    opening ``<fixed>`` tag. Generation stops at ``</fixed>``. Invisible text
    is masked.

    :param problem: The repair problem to render.
    :param prompt_seed: Currently unused.
    :param prompt_test_case_selection_strategy: Must be None; test case
        selection is not implemented for this renderer.
    :return: A ``SingleLineBugRender``. Its ``code_prefix`` is the tagged file
        text (without the question).
    :raises NotImplementedError: if a test case selection strategy is given,
        or if there is not exactly one editing statement.
    :raises ValueError: if the spec's verbs are not on exactly one path.
    :raises RuntimeError: if no editable element is found in the markup.
    """
    if prompt_test_case_selection_strategy is not None:
        msg = "Currently not implemented for this renderer"
        raise NotImplementedError(
            msg,
        )
    spec = problem.transformation_spec
    if spec.count_editing_statements() != 1:
        msg = "Currently simple and assumes a exactly one edit statement"
        raise NotImplementedError(
            msg,
        )
    path_to_verbs = get_verbs_per_path(problem.working_directory, spec)
    if len(path_to_verbs) != 1:
        msg = "Currently simple and assumes a single path"
        raise ValueError(msg)

    path, _verbs = path_to_verbs.popitem()
    markedup_text = markup_path(
        problem.working_directory,
        path,
        problem.transformation_spec,
    )

    # Visible text pieces (plus the <buggy> tags) collected in document order.
    text: list[str] = []
    # Appended after the code; the prompt ends inside an open <fixed> tag so
    # the model's continuation is the fixed code.
    final_instructions = (
        "\nQuestion: There is a bug in the above code snippet tagged by <buggy> and"
        " </buggy>. Please generate the correct version.\nAnswer:\n<fixed>"
    )

    def _finish_up(
        node: MarkElement[StsPlaceTransforming] | None,
    ) -> SingleLineBugRender:
        joined_text = "".join(text)
        stop = None
        if node is not None:
            # Stop once the model closes the answer tag.
            stop = ["</fixed>"]
        prompt = LmPrompt(
            joined_text + final_instructions,
            stop=stop,
            # Caching is explicitly disabled for repair prompts.
            # NOTE(review): reason not visible here — confirm it is intended.
            cache=False,
        )
        return SingleLineBugRender(
            prompt,
            joined_text,
            path,
            node,
        )

    edit_node = None
    for state in markedup_text.iterate_parser_states():
        if isinstance(state.node, MarkText):
            if state.is_visible:
                text.append(state.node.text)
        elif isinstance(state.node, MarkElement) and state.is_editable:
            # Unlike the plain autoregressive renderer, keep walking past the
            # edit; a second editable element is rejected.
            if edit_node is not None:
                msg = "Currently simple and assumes a exactly one edit statement"
                raise NotImplementedError(
                    msg,
                )
            edit_node = state.node
            text.append("<buggy>")

            def after_done() -> None:
                # Trim trailing whitespace from the last emitted piece, then
                # close the tag on its own line.
                text[-1] = text[-1].rstrip()
                text.append("</buggy>\n")

            # Presumably run by iterate_parser_states once this element's
            # children have been visited, so the closing tag lands after the
            # buggy region's text — TODO confirm against the parser.
            state.after_visit_children_callbacks.append(after_done)

    out = _finish_up(edit_node)

    if out.changed_element is None:
        msg = "No insert element found in the markup"
        raise RuntimeError(msg)
    return out


def render_single_edit_codeproblem_dialog(
    problem: CodeProblem,
    few_shot_config: FewShotConfig = None,
    prompt_seed: int | None = None,
    prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy = None,
) -> LmPromptRenderSingleEdit:
    """
    Render a single-edit code problem as a chat dialog.

    When ``few_shot_config`` is given, ``few_shot_config.num_examples``
    problems are sampled from its library (seeded by ``prompt_seed``). Each is
    added as a user turn followed by an assistant turn holding its known
    solution. The target problem is appended last, as a user turn without an
    answer. If a test case selection strategy is given, the selected test
    cases are included in each problem's user turn.

    :return: A render whose prompt is the whole dialog. ``code_prefix``,
        ``path`` and ``changed_element`` come from the autoregressive render
        of the target problem.
    """
    strategy = prompt_test_case_selection_strategy

    def _tests_for(p):
        if strategy is None:
            return None
        return strategy.select_for_problem(p)

    dialog_turns = []
    examples = (
        few_shot_config.library.sample_problems(
            n=few_shot_config.num_examples,
            target_problem=problem,
            seed=prompt_seed,
        )
        if few_shot_config is not None
        else ()
    )
    for example in examples:
        example_turns, _ = _problem_instance_dialog(
            example,
            ex_index=len(dialog_turns),
            add_answer=True,
            prompt_test_cases=_tests_for(example),
        )
        dialog_turns.extend(example_turns)

    # The problem to solve goes last and carries no answer.
    target_turns, target_render = _problem_instance_dialog(
        problem,
        ex_index=len(dialog_turns),
        add_answer=False,
        prompt_test_cases=_tests_for(problem),
    )
    dialog_turns.extend(target_turns)

    return LmPromptRenderSingleEdit(
        prompt=LmPrompt(LmChatDialog(values=dialog_turns)),
        code_prefix=target_render.prompt.text,
        path=target_render.path,
        changed_element=target_render.changed_element,
    )


def format_together_test_cases_test_suite(
    test_cases: list[TestCase] | None,
    ex_index: int = 0,
) -> str:
    """
    Format test cases as Python test-framework source, one after another.

    Each case is rendered with ``format_for_test_framework`` using a Python
    language spec and a test id of ``"{category}_{i}"``. Here ``i`` is the
    case's position in ``test_cases``, and ``category`` is the method name for
    ``TestCaseMethodCallIsEq`` cases, otherwise ``"ex{ex_index}"``.

    :param test_cases: The test cases to format. None or an empty list yields "".
    :param ex_index: Index used to name cases that have no method name.
    :return: The formatted cases joined by newlines.
    """
    if not test_cases:
        # Nothing to format; avoid constructing a language spec at all.
        return ""
    py_spec = PythonLangSpec()
    formatted = []
    for i, test_case in enumerate(test_cases):
        if isinstance(test_case, TestCaseMethodCallIsEq):
            category_name = test_case.method_name
        else:
            category_name = f"ex{ex_index}"
        formatted.append(
            test_case.format_for_test_framework(
                py_spec,
                override_test_id=f"{category_name}_{i}",
            ),
        )
    return "\n".join(formatted)


def _problem_instance_dialog(
    problem: CodeProblem,
    ex_index: int,
    add_answer: bool,
    prompt_test_cases: list[TestCase] | None = None,
) -> tuple[list[LmChatTurn], LmPromptRenderSingleEdit]:
    """
    Build the chat turns for a single problem instance.

    The user turn holds the instructions (or a default "act like a completion
    model" instruction when the problem has none). It is followed by the
    autoregressive prompt text in a fenced block and, if any, the formatted
    prompt test cases in a second fenced block. With ``add_answer``, an
    assistant turn containing the problem's single known solve step (fenced)
    is appended as well.

    :return: The turns, and the autoregressive render of ``problem``.
    :raises NotImplementedError: if ``add_answer`` and the problem does not
        have exactly one known solution.
    :raises ValueError: if ``add_answer`` and that solution does not have
        exactly one solve step.
    """
    auto_regressive = render_single_edit_codeproblem_autoregressive(problem)
    verb = auto_regressive.changed_element.verb

    instructions = problem.instructions
    if instructions is None or instructions == "":
        # A default instruction to make the chat model behave like a
        # completion model.
        # NOTE(review): this says "Do not put in a markdown code block", yet
        # few-shot answers below are wrapped in ``` fences — confirm intended.
        default_parts = [
            "Give a completion to the following code to the best "
            "of your ability as if acting like a smart completion model. "
            "Give only the completion, no other text. "
            "Do not put in a markdown code block. ",
        ]
        if verb.max_length_lines == 1:
            default_parts.append("For now we are only predicting the next line. ")
            default_parts.append("Give the text of the next line.")
        else:
            default_parts.append("Do not include any other text than the completion.")
        instructions = "".join(default_parts)

    user_text = (
        ensure_newline_at_end(f"{instructions}\n```\n{auto_regressive.prompt.text}")
        + "```\n"
    )
    tests_text = format_together_test_cases_test_suite(prompt_test_cases, ex_index)
    if tests_text:
        user_text += "Test cases:\n```\n" + ensure_newline_at_end(tests_text) + "```"

    turns = [LmChatTurn(role=ChatGptRoles.user, content=user_text)]
    if not add_answer:
        return turns, auto_regressive

    if len(problem.known_solutions) != 1:
        msg = "In few shot currently only support problems with exactly one solution"
        raise NotImplementedError(msg)
    steps = problem.known_solutions[0].solve_steps
    if len(steps) != 1:
        msg = "In fewshot currently only supports single step solutions"
        raise ValueError(msg)

    answer = f"```\n{steps[0].value}"
    if not answer.endswith("\n"):
        answer += "\n"
    turns.append(LmChatTurn(role=ChatGptRoles.assistant, content=answer + "```"))
    return turns, auto_regressive


class PromptRendererSingleEditAutoregressive(PromptRendererSingleEdit):
    """
    Renders single-edit problems as plain (non-chat) completion prompts via
    ``render_single_edit_codeproblem_autoregressive``.

    ``few_shot_config`` and ``prompt_test_case_selection_strategy`` are
    accepted for parity with the dialog renderer, but neither is supported
    yet. Supplying either emits a warning on each ``render`` call, and the
    option is then ignored.
    """

    def __init__(
        self,
        few_shot_config: FewShotConfig | None = None,
        prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy
        | None = None,
    ):
        super().__init__()
        self._few_shot_config = few_shot_config
        self._prompt_test_case_selection_strategy = prompt_test_case_selection_strategy

    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmPromptRenderSingleEdit:
        if self._few_shot_config is not None:
            warnings.warn(
                "Few shot not supported with this non-chat version. "
                "No few shot will be used. TODO",
                stacklevel=2,
            )
        if self._prompt_test_case_selection_strategy is not None:
            # Previously ignored silently; surface it like the few-shot case.
            warnings.warn(
                "Prompt test case selection not supported with this non-chat "
                "version. No test cases will be included in the prompt. TODO",
                stacklevel=2,
            )
        return render_single_edit_codeproblem_autoregressive(problem, prompt_seed)


class PromptRendererSingleEditDialog(PromptRendererSingleEdit):
    """
    Renders single-edit problems as chat dialogs via
    ``render_single_edit_codeproblem_dialog``. It forwards the few-shot config
    and the prompt test case selection strategy given at construction.
    """

    def __init__(
        self,
        few_shot_config: FewShotConfig = None,
        prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy = None,
    ):
        super().__init__()
        self._few_shot_config = few_shot_config
        self._prompt_test_case_selection_strategy = prompt_test_case_selection_strategy

    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmPromptRenderSingleEdit:
        # ``lm`` is not consulted; the dialog form is always produced.
        strategy = self._prompt_test_case_selection_strategy
        options = {
            "few_shot_config": self._few_shot_config,
            "prompt_seed": prompt_seed,
            "prompt_test_case_selection_strategy": strategy,
        }
        return render_single_edit_codeproblem_dialog(problem, **options)


class PromptRendererSingleEditGeneric(PromptRendererSingleEdit):
    """
    A prompt render that combines several prompt renderers and tries to choose
    the best one based on the problem and language model.

    Chat models, or any model when ``force_use_dialogue`` is set, get the
    dialogue renderer. All other models get the completion renderer. The
    selected render's prompt is then copied with ``max_tokens`` set to this
    renderer's limit and with caching enabled.
    """

    @classmethod
    def build_from_defaults(
        cls,
        max_tokens: int = 500,
        force_use_dialogue: bool = False,
        few_shot_config: FewShotConfig | None = None,
        prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy
        | None = None,
    ) -> "PromptRendererSingleEditGeneric":
        """
        The different defaults component renderers have some common components (such
        as the few_shot_config). This builder method will construct all the defaults
        with the supplied configs.
        """
        return cls(
            max_tokens=max_tokens,
            force_use_dialogue=force_use_dialogue,
            completion_render_func=cls.default_completion(
                few_shot_config=few_shot_config,
                prompt_test_case_selection_strategy=prompt_test_case_selection_strategy,
            ),
            dialogue_render_func=cls.default_dialogue(
                few_shot_config,
                prompt_test_case_selection_strategy,
            ),
        )

    @classmethod
    def default_completion(
        cls,
        few_shot_config: FewShotConfig | None = None,
        prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy
        | None = None,
    ):
        """Default renderer used for non-chat (completion) models."""
        return PromptRendererSingleEditAutoregressive(
            few_shot_config=few_shot_config,
            prompt_test_case_selection_strategy=prompt_test_case_selection_strategy,
        )

    @classmethod
    def default_dialogue(
        cls,
        few_shot_config: FewShotConfig | None = None,
        prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy
        | None = None,
    ):
        """Default renderer used for chat models (or when dialogue is forced)."""
        return PromptRendererSingleEditDialog(
            few_shot_config=few_shot_config,
            prompt_test_case_selection_strategy=prompt_test_case_selection_strategy,
        )

    def __init__(
        self,
        max_tokens: int = 500,
        force_use_dialogue: bool = False,
        completion_render_func: PromptRendererSingleEdit | None = None,
        dialogue_render_func: PromptRendererSingleEdit | None = None,
        few_shot_config_default: FewShotConfig | None = None,
    ):
        """
        :param max_tokens: ``max_tokens`` applied to every rendered prompt.
        :param force_use_dialogue: Use the dialogue renderer even for models
            that are not chat models.
        :param completion_render_func: Renderer for completion models. Defaults
            to ``default_completion(few_shot_config=few_shot_config_default)``.
        :param dialogue_render_func: Renderer for chat models. Defaults to
            ``default_dialogue(few_shot_config=few_shot_config_default)``.
        :param few_shot_config_default: Few-shot config used only when building
            a default component renderer.
        """
        super().__init__()
        self._max_tokens = max_tokens
        self._force_use_dialogue = force_use_dialogue
        self._completion_render_func = (
            completion_render_func
            or self.default_completion(few_shot_config=few_shot_config_default)
        )
        self._dialogue_render_func = dialogue_render_func or self.default_dialogue(
            few_shot_config=few_shot_config_default,
        )

    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmPromptRenderSingleEdit:
        """
        Render ``problem`` for ``lm`` with the appropriate component renderer.

        :return: The component's render, with its prompt's ``max_tokens`` set
            to this renderer's limit and ``cache=True``.
        """
        # Forcing dialogue must select the dialogue renderer (the previous
        # condition sent forced-dialogue requests to the completion renderer).
        if lm.is_chat_model or self._force_use_dialogue:
            render = self._dialogue_render_func(problem, lm, prompt_seed)
        else:
            render = self._completion_render_func(problem, lm, prompt_seed)
        # ``dataclasses.replace`` returns a new object instead of mutating
        # (the render classes are frozen), so the copy must be returned.
        return dataclasses.replace(
            render,
            prompt=dataclasses.replace(
                render.prompt,
                max_tokens=self._max_tokens,
                cache=True,
            ),
        )
