import logging
import re
from enum import Enum
from random import Random
from typing import TypedDict

from typing_extensions import Unpack

from inspect_ai._util.answer import answer_character, answer_index
from inspect_ai._util.logger import warn_once
from inspect_ai.util import resource

from ._solver import Generate, Solver, solver
from ._task_state import Choices, TaskState

logger = logging.getLogger(__name__)

SINGLE_ANSWER_TEMPLATE = r"""
Answer the following multiple choice question. The entire content of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of {letters}.

{question}

{choices}
""".strip()

SINGLE_ANSWER_TEMPLATE_COT = r"""
Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of {letters}. Think step by step before answering.

{question}

{choices}
""".strip()

MULTIPLE_ANSWER_TEMPLATE = r"""
Answer the following multiple choice question where multiple answers may be correct. The entire content of your response should be of the following format: 'ANSWER: $LETTERS' (without quotes) where LETTERS is one or more of {letters}.

{question}

{choices}
""".strip()


MULTIPLE_ANSWER_TEMPLATE_COT = r"""
Answer the following multiple choice question where multiple answers may be correct. The last line of your response should be of the following format: 'ANSWER: $LETTERS' (without quotes) where LETTERS is one or more of {letters}. Think step by step before answering.

{question}

{choices}
""".strip()


def unshuffle_choices(choices: Choices) -> Choices:
    # `sorted` returns `list[Choice]`, but for consistency we wrap this back
    # into a `Choices` object
    return Choices(sorted(choices, key=lambda choice: choice.original_position))


def answer_options(choices: Choices) -> str:
    r"""
    Returns the `choices` formatted as a multiple choice question, e.g.:

    ["choice 1", "choice 2", "choice 3"] ->
        "A) choice 1\nB) choice 2\nC) choice 3"
    """
    indexes = list(range(len(choices)))

    return "\n".join(
        [f"{answer_character(i)}) {choices[j].value}" for i, j in enumerate(indexes)]
    )


def prompt(question: str, choices: Choices, template: str) -> str:
    choices_text = answer_options(choices)
    letters = ",".join(answer_character(i) for i in range(len(choices)))

    return template.format(
        choices=choices_text,
        letters=letters,
        question=question,
    )


def parse_answers(state: TaskState, multiple_correct: bool) -> set[str]:
    """
    Convenience function for extracting answers from the state output.

    The generated response must be in the format 'ANSWER: <answers>',
    otherwise we can't extract what the model thinks is "true". We can be a
    bit flexible whether these are "AB" vs "A,B" vs "A B".

    However, if the answer isn't in the expected format the model has
    failed in the task so we'll ultimately just mark it as incorrect
    """
    # First check whether the string strictly ends with the expected answer
    # In this case, we're looking for a single line which contains the expected
    # ANSWER: <answer> string with only whitespace or a period/full stop at the end.
    match = re.search(
        r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n|\.)",
        state.output.completion,
        flags=re.MULTILINE,
    )

    # If we couldn't match the strict version, we can try the less strict
    # version for backward compatibility
    if match is None:
        match = re.search(
            r"(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$|\.)",
            state.output.completion,
        )

    if match is None:
        return set()

    matched = match.group(1)

    # Strip trailing period / full stop
    matched = matched.strip()
    matched = matched.rstrip(".")

    allowed_options = set(answer_character(i) for i in range(len(state.choices)))

    if multiple_correct:
        # Match must contain only the allowed choices
        # (may be separated by commas, spaces, the word 'and', or nothing at all)

        matched = matched.replace(" and ", "")

        matched = matched.replace(" ", "")

        split_comma = set(matched.split(","))
        if split_comma.issubset(allowed_options):
            answers = split_comma
            return answers

        split_nothing = set(matched)
        if split_nothing.issubset(allowed_options):
            answers = split_nothing
            return answers

    else:
        # Match must contain a single letter in the allowed choices
        if matched in allowed_options:
            answers = {matched}
            return answers

    return set()


def set_choices_based_on_generated_response(
    state: TaskState, answers: set[str]
) -> None:
    true_answers = [answer_index(letter) for letter in answers]

    for i in range(len(state.choices)):
        if i in true_answers:
            state.choices.mark_choice(i, True)
        else:
            state.choices.mark_choice(i, False)


def pretend_we_didnt_shuffle(
    state: TaskState, original_question: str, template: str
) -> None:
    """
    If we shuffled the choices, revert them to their unshuffled versions in the message history

    This is essentially just for usability. Without doing this, matching up the
    sample choices to the target value(s) can be misleading:

        * You may be expecting a particular result from your dataset which
          doesn't show up in the logs, for example you're expecting all correct
          answers to be "A" but they're not.
        * The Log Viewer knows nothing about the `TaskState` or shuffling, it's
          just looking at the messages and the Score. This leads to
          inconsistencies between the raw `Target` and the answers we're getting
          back from the models.

    By pretending we didn't shuffle in the message history, these
    inconsistencies are easily resolved as the output is what's expected for a
    given `Sample` and `Target`.

    Note that this just rewrites message history. The `TaskState.choices` are
    left shuffled, to allow us to be transparent about this elsewhere (e.g. in
    the scoring explanation).
    """
    # First, change the prompt to the unshuffled version
    pretend_prompt = prompt(
        question=original_question,
        choices=unshuffle_choices(state.choices),
        template=template,
    )
    state.user_prompt.text = pretend_prompt

    # Then, change the last message to appear as though we received unshuffled
    # answers
    answer_text = ", ".join(
        sorted(
            [
                answer_character(choice.original_position)
                for choice in state.choices
                if choice.correct is True
            ]
        )
    )
    pretend_answer = f"ANSWER: {answer_text}"

    state.output.completion = pretend_answer
    state.messages[-1].content = pretend_answer


def valid_template(template: str) -> bool:
    """Check if a template has the required capture groups for a multiple choice question"""
    return bool(
        re.search(r"\{question\}", template) and re.search(r"\{choices\}", template)
    )


class MultipleChoiceTemplate(str, Enum):
    """Templates for multiple choice questions.

    Based on the multiple choice template in openai simple evals:
    https://github.com/openai/simple-evals/blob/main/mmlu_eval.py
    """

    SINGLE_ANSWER = SINGLE_ANSWER_TEMPLATE
    SINGLE_ANSWER_COT = SINGLE_ANSWER_TEMPLATE_COT
    MULTIPLE_ANSWER = MULTIPLE_ANSWER_TEMPLATE
    MULTIPLE_ANSWER_COT = MULTIPLE_ANSWER_TEMPLATE_COT


class DeprecatedArgs(TypedDict, total=False):
    shuffle: bool | Random


@solver
def multiple_choice(
    *,
    template: str | None = None,
    cot: bool = False,
    multiple_correct: bool = False,
    max_tokens: int | None = None,
    **kwargs: Unpack[DeprecatedArgs],
) -> Solver:
    """Multiple choice question solver. Formats a multiple choice question prompt, then calls `generate()`.

    Note that due to the way this solver works, it has some constraints:

    1. The `Sample` must have the `choices` attribute set.
    2. The only built-in compatible scorer is the `choice` scorer.
    3. It calls `generate()` internally, so you don't need to call it again

    Args:
      template: Template to use for the multiple choice question.
        The defaults vary based on the options and are taken from the `MultipleChoiceTemplate` enum. The template will have questions and possible answers substituted into it before being sent to the model. Consequently it requires three specific template variables:

          - `{question}`: The question to be asked.
          - `{choices}`: The choices available, which will be formatted as a
            list of A) ... B) ... etc. before sending to the model.
          - `{letters}`: (optional) A string of letters representing the choices, e.g.
            "A,B,C". Used to be explicit to the model about the possible answers.
      cot: Default `False`. Whether the solver should perform chain-of-thought
        reasoning before answering. NOTE: this has no effect if you provide a custom template.
      multiple_correct: Default `False`. Whether to allow multiple
        answers to the multiple choice question. For example, "What numbers are
        squares? A) 3, B) 4, C) 9" has multiple correct answers, B and C. Leave
        as `False` if there's exactly one correct answer from the choices
        available. NOTE: this has no effect if you provide a custom template.
      max_tokens: Default `None`. Controls the number of tokens generated through the call
        to generate().
      **kwargs (Any): Deprecated arguments for backward compatibility.

    #### Shuffling

    You can shuffle choices when you load your dataset by using the `shuffle_choices` method or parameter of the datasets API.
    """
    shuffle: bool | Random = False
    if "shuffle" in kwargs:
        shuffle = kwargs["shuffle"]

        if shuffle:
            warn_once(
                logger,
                "The multiple choice shuffle parameter is deprecated. Please shuffle choices at the time your dataset is read by using the shuffle_choices method/parameter of the datasets API.",
            )

    if template and not valid_template(template):
        raise ValueError(
            "The template must contain '{question}' and '{choices}' placeholders for string substitution."
        )

    if template is None:
        if multiple_correct:
            if cot:
                template = MULTIPLE_ANSWER_TEMPLATE_COT
            else:
                template = MULTIPLE_ANSWER_TEMPLATE
        else:
            if cot:
                template = SINGLE_ANSWER_TEMPLATE_COT
            else:
                template = SINGLE_ANSWER_TEMPLATE

    template = resource(template)

    if shuffle is True:
        shuffle = Random()

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        if not state.choices:
            raise ValueError("The multiple_choice solver requires samples with choices")

        if isinstance(shuffle, Random):
            state.choices.shuffle(shuffle)

        # Memoise the current prompt (which is the raw "question" part of the
        # sample). Required in case we unshuffle, because we then alter message
        # history based on the multiple-choice template.
        original_question = state.user_prompt.text

        state.user_prompt.text = prompt(
            question=state.user_prompt.text,
            choices=state.choices,
            template=str(template),
        )

        state = await generate(state, max_tokens=max_tokens)

        answers = parse_answers(state, multiple_correct)
        if answers:
            # If we've found answers, update the state appropriately
            set_choices_based_on_generated_response(
                state=state,
                answers=answers,
            )
            if shuffle:
                pretend_we_didnt_shuffle(
                    state=state,
                    original_question=original_question,
                    template=str(template),
                )

        return state

    return solve
