"""
A PromptRenderer and parser that prompts a model with xml-like
tags on places to insert text. The model is then expected to output
the corresponding edit tags and locations. We expect this to
be pretty hard.
"""

import dataclasses
import warnings
from io import StringIO

from lmwrapper.abstract_predictor import LmPredictor
from lmwrapper.structs import LmPrediction, LmPrompt
from lxml import etree

from synthegrator.code_problems import CodeProblem
from synthegrator.environments import ProjectWorkingDirectory
from synthegrator.few_shotting import FewShotConfig
from synthegrator.problem_rendering import LmPromptRender, PromptRenderer
from synthegrator.prompting_test_case_selection import (
    PromptingTestCaseSelectionStrategy,
)
from synthegrator.response_parser import (
    ResponseParser,
    format_return_val_for_node,
)
from synthegrator.transformation_spec import (
    MarkElement,
    MarkText,
    SolveStep,
    StsEditable,
    StsInsert,
    StsPlaceTransforming,
    TransformationSpec,
    get_mark_element,
    get_verbs_per_path,
    map_paths_to_path_ids,
    markup_path,
)


class LmPromptRenderMultiEdit(LmPromptRender):
    """Base type for renders returned by ``PromptRendererMultiEdit`` renderers.

    Adds no fields or behavior on top of ``LmPromptRender``; it exists so that
    multi-edit renders (e.g. ``LmTaggedEditPrompt``) share a common type.
    """

    pass


@dataclasses.dataclass(frozen=True)
class LmTaggedEditPrompt(LmPromptRenderMultiEdit):
    """Render produced by ``TaggedEditRenderer``.

    Besides whatever the base render carries, it records the tag names that
    were used in the prompt so ``TaggedEditResponseParser`` can find the
    model's answers.
    """

    # Name of the tag that marks editable regions in the prompt text.
    tag_name_edit: str
    # Name of the tag the model is asked to wrap each answer in.
    tag_name_solve: str
    # Opening solve tag that was already appended to the end of the prompt
    # (empty when none was). The response parser prepends it to the completion
    # text before parsing.
    preprompted_tag_start: str = ""


class PromptRendererMultiEdit(PromptRenderer):
    """Abstract renderer whose renders are ``LmPromptRenderMultiEdit`` objects.

    Subclasses must override ``render``; calling an instance is the same as
    calling ``render``.
    """

    def __init__(self):
        super().__init__()

    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> "LmPromptRenderMultiEdit":
        """Render ``problem`` into a prompt for ``lm``.

        Not implemented here; always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def __call__(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> "LmPromptRenderMultiEdit":
        """Shorthand for ``self.render(problem, lm, prompt_seed)``."""
        return self.render(problem, lm, prompt_seed)


class TaggedEditRenderer(PromptRendererMultiEdit):
    """Prompt renderer that wraps each editable region in an xml-like tag.

    The rendered prompt is, in order: the problem instructions (if any), one
    ``@@ <path> @@`` section per path containing the tagged file text, a
    ``---`` separator, and closing instructions asking the model to answer in
    solve tags.

    Constructor arguments:
        few_shot_config: Not supported here. If it is not ``None`` a warning is
            emitted on every ``render`` call and it is otherwise ignored.
        prompt_test_case_selection_strategy: Stored on the instance; not read
            by any method of this class.
        tag_name_edit: Tag name used to mark editable regions in the prompt.
        tag_name_solve: Tag name the model is asked to answer with.
        include_first_tag_at_end: Whether the prompt ends with the opening
            solve tag of the first edit. ``None`` means "only when ``lm`` is
            not a chat model".
        custom_closing_lines: Replacement for the default closing instructions.
        add_stop: Whether the closing solve tag may be used as a stop sequence.
            ``None`` means "only when ``lm`` is not a chat model". A stop is
            only ever set when the prompt contains exactly one tag.
    """

    def __init__(
        self,
        few_shot_config: FewShotConfig = None,
        prompt_test_case_selection_strategy: PromptingTestCaseSelectionStrategy = None,
        tag_name_edit: str = "edit",
        tag_name_solve: str = "edit_solve",
        include_first_tag_at_end: bool | None = None,
        custom_closing_lines: str | None = None,
        add_stop: bool | None = None,
    ):
        super().__init__()
        # Prompt content options.
        self._few_shot_config = few_shot_config
        self._prompt_test_case_selection_strategy = prompt_test_case_selection_strategy
        self._custom_closing_lines = custom_closing_lines
        # Tag naming.
        self._tag_name_edit = tag_name_edit
        self._tag_name_solve = tag_name_solve
        # Tri-state switches; None is resolved against the model in render().
        self._include_first_tag_at_end = include_first_tag_at_end
        self._add_stop = add_stop

    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> LmTaggedEditPrompt:
        """Build the tagged prompt for ``problem``.

        ``prompt_seed`` is accepted for interface compatibility and is unused.

        Raises:
            NotImplementedError: if the transformation spec does not contain
                exactly one editing statement.
        """
        if self._few_shot_config is not None:
            few_shot_warning = (
                "Few shot not supported with this non-chat version. "
                "No few shot will be used. TODO"
            )
            warnings.warn(few_shot_warning)

        spec = problem.transformation_spec
        if spec.count_editing_statements() != 1:
            raise NotImplementedError(
                "Currently simple and assumes a exactly one edit statement",
            )

        working_dir = problem.working_directory
        path_to_verbs = get_verbs_per_path(working_dir, spec)
        sections = [problem.instructions] if problem.instructions else []
        path_ids = map_paths_to_path_ids(working_dir, spec)

        # One "@@ path @@" section per file, followed by a blank line.
        all_tag_ids = []
        for path in path_to_verbs:
            marked_up = markup_path(working_dir, path, spec)
            tagged_text, ids_in_path = _marked_up_text_to_tagged_prompt_text(
                marked_up,
                path_id=path_ids[path],
                tag_name=self._tag_name_edit,
            )
            sections += [f"@@ {path!s} @@", tagged_text, ""]
            all_tag_ids += ids_in_path

        # The model is only consulted when the caller left the switch unset.
        if self._include_first_tag_at_end is None:
            include_first_tag = not lm.is_chat_model
        else:
            include_first_tag = self._include_first_tag_at_end
        closing, prepend = _make_closing_lines(
            all_tag_ids=all_tag_ids,
            tag_name_edit=self._tag_name_edit,
            tag_name_solve=self._tag_name_solve,
            include_first_tag_at_end=include_first_tag,
            custom_closing_lines=self._custom_closing_lines,
        )
        sections += ["---", closing]

        use_stop = (not lm.is_chat_model) if self._add_stop is None else self._add_stop
        # Stopping at the closing tag is only safe with a single expected answer.
        if len(all_tag_ids) == 1 and use_stop:
            stop = [f"</{self._tag_name_solve}>"]
        else:
            stop = None

        return LmTaggedEditPrompt(
            prompt=LmPrompt(
                text="\n".join(sections),
                stop=stop,
            ),
            tag_name_edit=self._tag_name_edit,
            tag_name_solve=self._tag_name_solve,
            preprompted_tag_start=prepend,
        )

    def __call__(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
    ) -> "LmTaggedEditPrompt":
        """Shorthand for ``self.render(problem, lm, prompt_seed)``."""
        return self.render(problem, lm, prompt_seed)


def _make_closing_lines(
    all_tag_ids: list[str],
    tag_name_edit: str = "edit",
    tag_name_solve: str = "edit_solve",
    include_first_tag_at_end: bool = True,
    custom_closing_lines: str | None = None,
) -> tuple[str, str]:
    value = (
        f"Implement the correct changes in the above code in the <{tag_name_edit}>"
        f' tags. Please place each fix in a separate <{tag_name_solve} id="id">'
        f' tags.\nFor example:\n<{tag_name_solve} id="id_example">new replacement for'
        f" the text...</{tag_name_solve}>\nAnswer:\n"
    )
    if custom_closing_lines:
        value = custom_closing_lines
    if include_first_tag_at_end:
        preprepended_prompt = f'<{tag_name_solve} id="{all_tag_ids[0]}">'
    else:
        preprepended_prompt = ""
    return value + preprepended_prompt, preprepended_prompt


def _marked_up_text_to_tagged_prompt_text(
    markedup_text: MarkElement,
    path_id: str,
    tag_name: str = "edit",
) -> tuple[str, list[str]]:
    """
    Render a marked-up tree as prompt text with edit tags.

    Visible text nodes are copied verbatim. Every visible element whose verb
    is an ``StsEditable`` or ``StsInsert`` is wrapped in
    ``<tag_name id="{path_id}_{mark_id}">...</tag_name>``; elements with any
    other verb contribute no tag of their own.

    Returns the prompt text and the tag ids in the order they were emitted.
    """
    pieces = []
    emitted_ids = []
    for state in markedup_text.iterate_parser_states():
        if not state.is_visible:
            continue
        node = state.node
        if isinstance(node, MarkText):
            pieces.append(node.text)
        if not isinstance(node, MarkElement):
            continue
        verb = node.verb
        if not isinstance(verb, (StsEditable, StsInsert)):
            # Other verbs are deliberately left untagged.
            continue

        tag_id = f"{path_id}_{node.mark_id}"
        emitted_ids.append(tag_id)
        attributes = f'id="{tag_id}"'
        if verb.max_length_lines:
            attributes += f" max_length_lines={verb.max_length_lines}"
        pieces.append(f"<{tag_name} {attributes}>")
        # The closing tag can only be written once the children have been
        # visited, so defer it to the traversal's after-children hook.
        state.after_visit_children_callbacks.append(
            lambda: pieces.append(f"</{tag_name}>"),
        )
    return "".join(pieces), emitted_ids


class TaggedEditResponseParser(ResponseParser):
    """Turns a tagged model answer back into ``SolveStep`` objects."""

    def __init__(self):
        super().__init__()

    def parse(
        self,
        render: LmTaggedEditPrompt,
        resp: LmPrediction,
        problem: CodeProblem,
    ) -> list[SolveStep]:
        """Extract one ``SolveStep`` per answered tag that maps to a real mark.

        The tag that was pre-filled at the end of the prompt is re-attached to
        the completion before parsing. Answers whose id cannot be resolved to
        a known path and existing mark are skipped silently.
        """
        full_answer = render.preprompted_tag_start + resp.completion_text
        ids_to_values = _find_ids_and_values_in_tagged_answer(
            full_answer,
            tag_name_edit=render.tag_name_edit,
            tag_name_solve=render.tag_name_solve,
            default_id=_get_mark_default_id(problem),
        )
        spec = problem.transformation_spec
        wd = problem.working_directory
        resolved_ids = _ids_to_path_id_and_mark_id(ids_to_values, spec, wd)

        solve_steps = []
        for tag_id, (path, mark_id) in resolved_ids.items():
            node = get_mark_element(spec, path, wd, mark_id)
            if node is not None:
                lang_spec = problem.get_lang_spec_for_path(path)
                solution_text = format_return_val_for_node(
                    ids_to_values[tag_id],
                    node,
                    lang_spec,
                )
                solve_steps.append(SolveStep(path, mark_id, solution_text))
            # else: be permissive about responses naming nonexistent nodes.
        return solve_steps


def _find_ids_and_values_in_tagged_answer(
    answer: str,
    tag_name_edit: str = "edit",
    tag_name_solve: str = "edit_solve",
    default_id: str = None,
) -> dict[str, str]:
    """
    Map the id of each answer tag found in ``answer`` to its text.

    Solve tags are collected first (a later tag with the same id overwrites
    an earlier one). Edit tags are then accepted as a fallback for ids that
    are still missing. Tags without an ``id`` attribute use ``default_id``.
    If nothing was found at all and ``default_id`` is given, the raw answer up
    to the first closing edit/solve tag is stored under ``default_id``.

    NOTE(review): only each element's leading ``.text`` is captured, so text
    that follows a nested element inside a tag is not part of the value.
    """
    # lxml needs a single root element, so wrap the answer before parsing.
    document = etree.parse(StringIO(f"<root>{answer}</root>"), etree.HTMLParser())
    found = {}

    for solve_element in document.xpath(f"//{tag_name_solve}"):
        solve_id = solve_element.get("id", default_id)
        if solve_id:
            found[solve_id] = solve_element.text or ""

    # Be permissive: the model may have echoed the prompt's edit tag instead
    # of the solve tag. Never override an id the solve tags already supplied.
    for edit_element in document.xpath(f"//{tag_name_edit}"):
        edit_id = edit_element.get("id", default_id)
        if edit_id and edit_id not in found:
            found[edit_id] = edit_element.text or ""

    if not found and default_id is not None:
        # No usable tags at all: treat the untagged answer as the value,
        # cutting it at the first closing tag of either kind.
        untagged = answer.partition(f"</{tag_name_edit}>")[0]
        untagged = untagged.partition(f"</{tag_name_solve}>")[0]
        found[default_id] = untagged

    return found


def _ids_to_path_id_and_mark_id(
    ids_to_values: dict[str, str],
    spec: TransformationSpec,
    wd: ProjectWorkingDirectory,
) -> dict[str, tuple[str, str]]:
    """Resolve tag ids of the form ``{path_id}_{mark_id}``.

    Returns a mapping from tag id to ``(path, mark_id)``. Ids that do not
    split into exactly two ``_``-separated parts, or whose path id is
    unknown, are left out.
    """
    paths_by_id = {
        path_id: path for path, path_id in map_paths_to_path_ids(wd, spec).items()
    }
    resolved = {}
    for tag_id in ids_to_values:
        parts = tag_id.split("_")
        if len(parts) != 2:
            continue
        path_id, mark_id = parts
        path = paths_by_id.get(path_id)
        if path is not None:
            resolved[tag_id] = (path, mark_id)
    return resolved


def _get_mark_default_id(problem: CodeProblem) -> str | None:
    """Return the tag id to assume when an answer tag carries no id.

    This is ``"0_" + mark_id`` of the first visible element whose verb is an
    ``StsPlaceTransforming``, or ``None`` if there is no such element.

    Raises:
        NotImplementedError: if the spec has verbs for more than one path.
    """
    wd = problem.working_directory
    spec = problem.transformation_spec
    path_to_verbs = get_verbs_per_path(wd, spec)
    if len(path_to_verbs) > 1:
        raise NotImplementedError("Multiple paths not supported")

    for path, verbs in path_to_verbs.items():
        if not verbs:
            continue
        for state in markup_path(wd, path, spec).iterate_parser_states():
            if (
                state.is_visible
                and isinstance(state.node, MarkElement)
                and isinstance(state.node.verb, StsPlaceTransforming)
            ):
                return "0_" + state.node.mark_id
    return None
