import dataclasses
import json
import math
import random
from collections.abc import Callable, Sequence

from lmwrapper.abstract_predictor import LmPredictor
from lmwrapper.structs import LmPrediction, LmPrompt

from synthegrator.code_problems import CodeProblem
from synthegrator.code_solver import ResponseParser
from synthegrator.problem_rendering import (
    LmPromptRender,
    ProblemNotSupportedError,
    PromptRenderer,
    render_markup_to_plain_text,
)
from synthegrator.response_parser import AnswerForSolutionNeverSupported
from synthegrator.transformation_spec import (
    MarkElement,
    SolveStep,
    StsCodeClassificationQuestion,
    StsCodeQuestion,
    StsSpecStatement,
    StsValueQuery,
    markup_path,
)
from synthegrator.util import (
    find_after_first_occurrence_of_string,
    get_only_item,
    normalize_probs_dict,
)

# Signature of a function that turns the raw LM prediction for a rendered
# prompt into the solve steps it implies.
PARSE_CALLBACK = Callable[[LmPrediction], list[SolveStep]]
# Signature of a function that turns solve steps back into a plain answer
# string.
ANSWER_FOR_SOLUTION_CALLBACK = Callable[[Sequence[SolveStep]], str]


@dataclasses.dataclass(frozen=True)
class PromptRenderWithCallback(LmPromptRender):
    """
    A prompt render that also carries the functions needed to interpret the
    model's response to it, so that parsing can be delegated back to
    whatever code produced the render.
    """

    # Converts the LM prediction for this prompt into solve steps.
    parse_callback: PARSE_CALLBACK
    # Converts solve steps into an answer string. May be left as None, in
    # which case CallbackResponseParser.answer_for_solution raises
    # AnswerForSolutionNeverSupported.
    answer_for_solution_callback: ANSWER_FOR_SOLUTION_CALLBACK = None


class QuestionPromptRenderer(PromptRenderer):
    """
    A prompt renderer for solving StsCodeQuestion problems in
    typical autoregressive way.

    The prompt is assembled as: the problem instructions (only when they are
    flagged as essential), the code at the question's path inside a markdown
    fence, a ``Question: ...`` line, lettered choices when the question is a
    StsCodeClassificationQuestion, and finally ``Answer:``. The returned
    render also carries the callbacks that interpret the model's response.

    :param randomize_multiple_choice: Whether to randomize the order of the
        multiple choice answers. This might be nice if the model shows
        some bias towards certain answers.
    :param seed_on_problem_id: Only matters if randomize_multiple_choice is
        True. If True, then the randomization will depend on
        the problem id. Note that additional state can still be
        passed in to make it random across runs, but
        a given (problem id, randomizing_seed) pair will always
        produce the same randomization.
    """

    def __init__(
        self,
        randomize_multiple_choice: bool = False,
        seed_on_problem_id: bool = True,
    ):
        # TODO toggle for Chain of Thought
        super().__init__()
        self.randomize_multiple_choice = randomize_multiple_choice
        self.seed_on_problem_id = seed_on_problem_id

    def render(
        self,
        problem: CodeProblem,
        lm: LmPredictor,
        prompt_seed: int | None = None,
        chosen_statement: StsSpecStatement | None = None,
    ) -> LmPromptRender:
        """
        Renders a code problem into a prompt for a language model.

        :param problem: The code problem to render
        :param lm: The language model to render for. Useful if want to
            determine whether to make a dialog or how long can generate.
            (Not consulted by this implementation.)
        :param prompt_seed: A seed to use for any randomization (here: the
            order of multiple choice answers).
        :param chosen_statement: The statement to render for in case there are
            multiple questions in the problem. If None, will expect that there
            is exactly one value question in the problem.
        :return: A PromptRenderWithCallback holding the prompt plus the
            callbacks that parse the response.
        :raises ProblemNotSupportedError: If no single question can be
            selected, the statement's path is not an exact path, or a
            multiple choice question has more than 26 answers.
        :raises ValueError: If the chosen statement is not a StsCodeQuestion.
        """
        chosen_statement = self._get_and_check_chosen_statement(
            problem,
            chosen_statement,
        )
        try:
            chosen_statement.attempt_to_get_path_as_exact_path()
        except ValueError as err:
            msg = "Chosen statement must have a path that is a plain path"
            # Chain the cause so the underlying path problem stays visible.
            raise ProblemNotSupportedError(
                msg,
            ) from err
        output = ""
        if problem.instructions and problem.instructions_are_essential:
            # NOTE(review): no separator is inserted between the instructions
            # and the code fence; presumably instructions end with a newline.
            output += problem.instructions
        output += self._build_code_part(problem, chosen_statement)
        prompt_args, callback, ans_to_sol_cb = self._build_question_part(
            problem,
            chosen_statement,
            prompt_seed,
        )
        # "text" holds the question portion; whatever remains in prompt_args
        # is forwarded to LmPrompt as keyword arguments.
        output += prompt_args.pop("text")
        return PromptRenderWithCallback(
            prompt=LmPrompt(
                text=output,
                **prompt_args,
            ),
            parse_callback=callback,
            answer_for_solution_callback=ans_to_sol_cb,
        )

    def _build_question_part(
        self,
        problem,
        chosen_statement,
        randomizing_seed,
    ) -> tuple[dict, PARSE_CALLBACK, ANSWER_FOR_SOLUTION_CALLBACK]:
        """
        Build the question portion of the prompt and the matching callbacks.

        :return: ``(prompt_args, parse_callback, answer_for_solution_callback)``
            where ``prompt_args["text"]`` is the question text ending in
            ``Answer:`` and any other keys are extra LmPrompt arguments.
        :raises ProblemNotSupportedError: If a classification question has
            more answers than there are letters A-Z.
        """
        verb: StsCodeQuestion = chosen_statement.verb
        text = "Question: " + verb.question_text
        prompt_args = {}
        if isinstance(verb, StsCodeClassificationQuestion):
            text += " Answer with only the letter corresponding to the best answer."
            answers = self._get_valid_answers_maybe_randomized(
                verb,
                problem,
                randomizing_seed,
            )
            # Choices are labelled A..Z, so at most 26 can be expressed.
            if len(answers) > 26:
                msg = (
                    "Too many answers for multiple choice question. Expected at most"
                    f" 26. Got {len(answers)}"
                )
                raise ProblemNotSupportedError(
                    msg,
                )
            choice_to_answer = {}
            for i, answer in enumerate(answers):
                letter = chr(ord("A") + i)
                text += f"\n{letter}) {answer}"
                choice_to_answer[letter] = answer
            # Only a short completion is needed; the parse callback reads the
            # per-token alternatives from resp.top_token_logprobs.
            prompt_args["max_tokens"] = 2
            prompt_args["logprobs"] = 5
            callback, ans_to_sol_cb = self._build_multiple_choice_callback(
                problem,
                chosen_statement,
                choice_to_answer,
            )
        else:
            callback, ans_to_sol_cb = self._build_standard_callback(
                problem, chosen_statement
            )
        text += "\nAnswer:"
        prompt_args["text"] = text
        return prompt_args, callback, ans_to_sol_cb

    def _get_valid_answers_maybe_randomized(
        self,
        verb: StsCodeClassificationQuestion,
        problem: CodeProblem,
        randomizing_seed: int | None,
    ) -> list[str]:
        """
        Return the question's valid answers as a new list, shuffled when
        ``randomize_multiple_choice`` is set.

        The shuffle seed is the concatenation of the problem id (if
        ``seed_on_problem_id``) and ``randomizing_seed`` (if not None). When
        neither contributes, the shuffle draws from the module-level RNG.
        """
        answers = list(verb.valid_answers)
        if not self.randomize_multiple_choice:
            return answers
        seed = ""
        if self.seed_on_problem_id:
            seed += str(problem.problem_id)
        if randomizing_seed is not None:
            seed += str(randomizing_seed)
        if seed:
            # A private generator yields the same order as seeding the global
            # RNG would, without clobbering global random state for callers.
            random.Random(seed).shuffle(answers)
        else:
            random.shuffle(answers)
        return answers

    def _build_code_part(self, problem, chosen_statement: StsSpecStatement):
        """
        Render the code at the chosen statement's path as a markdown fenced
        block tagged with the language's markdown name.
        """
        path = chosen_statement.attempt_to_get_path_as_exact_path()
        markup_root = markup_path(
            problem.working_directory,
            path,
            problem.transformation_spec,
        )
        code_text = render_markup_to_plain_text(
            markup_root,
        )
        lang_spec = problem.get_lang_spec_for_path(path)
        # Drop surrounding blank lines so the fence hugs the code.
        code_text = code_text.strip("\n")
        # fmt: off
        return (
            f"```{lang_spec.get_lang_md_name()}\n"
            f"{code_text}\n"
            f"```\n"
        )
        # fmt: on

    def _get_and_check_chosen_statement(
        self,
        problem,
        chosen_statement,
    ) -> StsSpecStatement:
        """
        Resolve the statement to render for.

        If ``chosen_statement`` is None, the problem must contain exactly one
        question statement, which is then used.

        :raises ProblemNotSupportedError: If no statement was given and the
            problem has zero or several question statements.
        :raises ValueError: If the statement's verb is not a StsCodeQuestion.
        """
        if chosen_statement is None:
            question_nodes = list(problem.transformation_spec.question_statements())
            if not question_nodes:
                msg = "No question statements in problem"
                raise ProblemNotSupportedError(msg)
            if len(question_nodes) > 1:
                msg = (
                    "Multiple question statements in problem. Must specify"
                    " chosen_statement."
                )
                raise ProblemNotSupportedError(
                    msg,
                )
            chosen_statement = question_nodes[0]
        if not isinstance(chosen_statement.verb, StsCodeQuestion):
            msg = "chosen_statement must be a StsCodeQuestion"
            raise ValueError(msg)
        return chosen_statement

    def _get_mark_node_for_chosen_statement(
        self,
        problem,
        chosen_statement: StsSpecStatement,
        exception_on_not_found: bool = True,
    ) -> tuple[str, MarkElement[StsValueQuery] | None]:
        """
        Locate the mark node that the answer should be written to.

        :return: ``(path, node)`` where ``node`` is the first node in the
            markup of ``path`` whose verb is a StsValueQuery, or None when
            nothing is found and ``exception_on_not_found`` is False.
        :raises RuntimeError: If nothing is found and
            ``exception_on_not_found`` is True.
        """
        path = chosen_statement.attempt_to_get_path_as_exact_path()
        markup = markup_path(
            problem.working_directory,
            path,
            problem.transformation_spec,
        )
        # NOTE(review): this returns the first StsValueQuery in the file, not
        # necessarily the one belonging to chosen_statement - confirm this is
        # fine when a file holds several questions.
        for state in markup.iterate_parser_states():
            if isinstance(state.node.verb, StsValueQuery):
                return path, state.node
        if exception_on_not_found:
            msg = f"No StsValueQuery mark found in markup for path {path!r}"
            raise RuntimeError(msg)
        return path, None

    def _build_standard_callback(
        self,
        problem,
        chosen_statement,
    ) -> tuple[PARSE_CALLBACK, ANSWER_FOR_SOLUTION_CALLBACK]:
        """
        Build the callbacks for a free-form (non multiple choice) question.

        The parse callback stores everything the model produced after
        ``Answer:`` (or the whole completion when that marker is absent) as
        the value of a single solve step.
        """
        path, node = self._get_mark_node_for_chosen_statement(problem, chosen_statement)

        def cb(resp: LmPrediction) -> list[SolveStep]:
            resp_str, _ = grab_after_prompt(resp, "Answer:")
            return [
                SolveStep(
                    path,
                    node.mark_id,
                    value=resp_str,
                ),
            ]

        def answer_to_solution_cb(solve_steps: list[SolveStep]) -> str:
            val = get_only_item(solve_steps).value
            # The answer string always begins with a single leading space.
            if not val.startswith(" "):
                return " " + val
            return val

        return cb, answer_to_solution_cb

    def _build_multiple_choice_callback(
        self,
        problem,
        chosen_statement,
        choice_to_answer: dict[str, str],
        case_insensitive: bool = True,
        strip_whitespace: bool = True,
    ) -> tuple[PARSE_CALLBACK, ANSWER_FOR_SOLUTION_CALLBACK]:
        """
        Build the callbacks for a multiple choice question.

        The parse callback inspects the top token log-probabilities at the
        first answer token, sums the probability mass of every token that
        names a choice (by letter or by the answer text itself), and stores
        the normalized answer -> probability mapping as JSON in a single
        solve step.

        :param choice_to_answer: Maps the letter shown in the prompt to the
            answer it stands for.
        :param case_insensitive: Compare tokens to choices ignoring case.
        :param strip_whitespace: Compare tokens to choices ignoring
            surrounding whitespace; also makes the solution answer start
            with a space.
        """
        path, node = self._get_mark_node_for_chosen_statement(problem, chosen_statement)

        def _normalize(key: str) -> str:
            # Applied identically to lookup keys and to model tokens so the
            # two can actually meet.
            if strip_whitespace:
                key = key.strip()
            if case_insensitive:
                key = key.lower()
            return key

        answer_to_choice = {
            answer: letter for letter, answer in choice_to_answer.items()
        }
        token_to_answer = {
            _normalize(letter): answer for letter, answer in choice_to_answer.items()
        }
        available_answers = list(token_to_answer.values())
        # Make an actual answer map to itself.
        token_to_answer |= {answer: answer for answer in available_answers}
        # Tokens are normalized before lookup, so the answer text must also be
        # reachable under its normalized form (e.g. " True" -> "true").
        # setdefault keeps any existing letter / verbatim entry in charge.
        for answer in available_answers:
            token_to_answer.setdefault(_normalize(answer), answer)

        def cb(resp: LmPrediction) -> list[SolveStep]:
            _, token_index_start = grab_after_prompt(resp, "Answer:")
            token_log_probs = resp.top_token_logprobs[token_index_start]
            answer_to_prob = {}
            for tok, logprob in token_log_probs.items():
                answer = token_to_answer.get(_normalize(tok))
                if answer is None:
                    continue
                # Several tokens (e.g. "A" and " a") may name the same answer.
                answer_to_prob[answer] = answer_to_prob.get(answer, 0.0) + math.exp(
                    logprob
                )
            answer_to_prob = normalize_probs_dict(
                answer_to_prob,
                available_answers,
            )
            return [
                SolveStep(
                    path,
                    node.mark_id,
                    value=json.dumps(answer_to_prob),
                ),
            ]

        verb: StsCodeClassificationQuestion = chosen_statement.verb

        def answer_to_solution_cb(solve_steps: list[SolveStep]) -> str:
            answer_to_prob = verb.value_to_prob_dict(get_only_item(solve_steps).value)
            # Report the letter of the most probable answer.
            answer = max(answer_to_prob, key=answer_to_prob.get)
            if strip_whitespace:
                return " " + answer_to_choice[answer]
            return answer_to_choice[answer]

        return cb, answer_to_solution_cb


def grab_after_prompt(
    resp: LmPrediction,
    prompt_str: str,
) -> tuple[str, int]:
    """
    Split off the part of a completion that follows ``prompt_str``.

    :param resp: The prediction whose completion text / tokens are examined.
    :param prompt_str: The marker to look for in the completion.
    :return: ``(text, token_index)``. When the marker does not occur in the
        completion text this is the whole completion text and ``0``.
        Otherwise it is the text after the first occurrence of the marker
        and the token offset reported by
        ``find_after_first_occurrence_of_string``.
    :raises NotImplementedError: If the marker ends part-way through a token.
    """
    full_text = resp.completion_text
    if prompt_str in full_text:
        tok_idx, char_idx_in_tok = find_after_first_occurrence_of_string(
            resp.completion_tokens,
            prompt_str,
        )
        if char_idx_in_tok == 0:
            _, _, remainder = full_text.partition(prompt_str)
            return remainder, tok_idx
        msg = "Not currently dealing with weird split tokens"
        raise NotImplementedError(msg)
    # Marker absent: the completion is already just the answer.
    return full_text, 0


def res_to_str(resp: LmPrediction) -> str:
    """Return the completion text of ``resp`` unchanged."""
    completion = resp.completion_text
    return completion


class CallbackResponseParser(ResponseParser):
    """
    A response parser that does no parsing of its own: it hands the work to
    the callbacks stored on the PromptRenderWithCallback it is given.
    """

    def __init__(self):
        super().__init__()

    def parse(
        self,
        render: PromptRenderWithCallback,
        resp: LmPrediction,
        problem: CodeProblem,
    ) -> list[SolveStep]:
        """Return whatever the render's ``parse_callback`` yields for ``resp``."""
        parse_fn = render.parse_callback
        return parse_fn(resp)

    def answer_for_solution(
        self,
        render: PromptRenderWithCallback,
        problem: CodeProblem,
        solve_steps: Sequence[SolveStep],
    ) -> str:
        """
        Return the render's ``answer_for_solution_callback`` applied to
        ``solve_steps``.

        :raises AnswerForSolutionNeverSupported: If the render carries no
            such callback.
        """
        to_answer = render.answer_for_solution_callback
        if to_answer is not None:
            return to_answer(solve_steps)
        raise AnswerForSolutionNeverSupported()
