import logging
import re
from abc import ABC
from collections.abc import Iterable, Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, assert_never, cast
from uuid import UUID

from ldp.alg import (
    Callback,
    ComputeTrajectoryMetricsMixin,
    bulk_evaluate_consensus,
)
from ldp.data_structures import Trajectory
from lmi import LLMModel
from paperqa.agents.tools import Complete, EnvironmentState
from paperqa.docs import Docs
from paperqa.settings import Settings
from paperqa.types import DocDetails, PQASession

from aviary.core import (
    DEFAULT_EVAL_MODEL_NAME,
    TASK_DATASET_REGISTRY,
    Environment,
    Frame,
    MultipleChoiceQuestion,
    TaskDataset,
    ToolResponseMessage,
)
from aviary.env import ENV_REGISTRY
from aviary.envs.litqa import (
    DEFAULT_REWARD_MAPPING,
    GradablePaperQAEnvironment,
)

if TYPE_CHECKING:
    import pandas as pd
    from ldp.agent import Agent
    from ldp.data_structures import Transition

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


# Default Hugging Face Hub dataset name used for the train and eval splits
DEFAULT_LABBENCH_HF_HUB_NAME = "futurehouse/lab-bench"
# Test split from Aviary paper's section 4.3: https://doi.org/10.48550/arXiv.2412.21154
DEFAULT_AVIARY_PAPER_HF_HUB_NAME = "futurehouse/aviary-paper-data"

# Register the gradable PaperQA environment under ENV_NAME. The registry value is a
# two-tuple of (module path, class name) rather than the class object itself
ENV_NAME = "paperqa-local"
ENV_REGISTRY[ENV_NAME] = (
    GradablePaperQAEnvironment.__module__,
    GradablePaperQAEnvironment.__name__,
)


async def evaluate_consensus_sampling(
    data: Iterable[GradablePaperQAEnvironment | Frame],
    exclude_no_answer: bool = False,
    num_samples: int = 1,
    seed: int | None = None,
) -> tuple[dict[str, list[tuple[str, int]]], float]:
    """
    Group data by question, then evaluate the consensus answer of each group.

    Args:
        data: Gradable environments and/or frames to evaluate consensus upon.
        exclude_no_answer: Opt-in flag to drop empty answers (present when the
            Environment/Frame has no graded answer) from the returned consensus.
            The filtering happens after evaluation, so the returned accuracy is
            not affected by this flag.
        num_samples: Passed through to bulk_evaluate_consensus.
        seed: Passed through to bulk_evaluate_consensus.

    Returns:
        Two-tuple of consensus list generated by collections.Counter.most_common (keys
            are question, values are list of (answer, vote count)) and the proportion of
            groups for which the consensus matches the ideal.

    Raises:
        ImportError: If a TypeError surfaces from the consensus evaluation.
            NOTE(review): any TypeError raised during the evaluation (including one
            from the extraction helpers below, if they are invoked inside it) is
            reported as this ImportError, with the original context suppressed.
    """

    def get_query(
        x: GradablePaperQAEnvironment | Frame,
    ) -> str | MultipleChoiceQuestion | dict[str, Any]:
        # Environments hold the query directly, frames carry it inside their info
        if isinstance(x, GradablePaperQAEnvironment):
            return x._query
        return x.info["query"]  # type: ignore[call-overload,index]

    def extract_question(x: GradablePaperQAEnvironment | Frame) -> str:
        query = get_query(x)
        if isinstance(query, str):
            return query
        if isinstance(query, MultipleChoiceQuestion):
            return query.question_prompt
        return query["question"]

    def extract_answer(x: GradablePaperQAEnvironment | Frame) -> str:
        state = x.state
        ses: PQASession | dict[str, Any]
        if isinstance(state, EnvironmentState):
            ses = state.session
        else:
            ses = cast("PQASession | dict[str, Any]", state["session"])  # type: ignore[call-overload,index]
        if isinstance(ses, PQASession):
            graded_answer = ses.graded_answer
        else:
            graded_answer = ses["graded_answer"]
        # A missing graded answer becomes an empty string here, one can filter
        # these out of the consensus via the exclude_no_answer arg
        return graded_answer or ""

    def extract_ideal(x: GradablePaperQAEnvironment | Frame) -> str:
        query = get_query(x)
        if isinstance(query, MultipleChoiceQuestion):
            return query.ideal_answer
        if isinstance(query, str):
            # A bare string query has no ideal answer to compare against
            raise ValueError(  # noqa: TRY004
                f"We require a {MultipleChoiceQuestion.__name__} variant to extract"
                " ideal answer, not a string."
            )
        return query["ideal_answer"]

    try:
        consensus, accuracy = await bulk_evaluate_consensus(
            data=data,
            grouping_fn=extract_question,
            extract_answer_fn=extract_answer,
            ideal_answer_fn=extract_ideal,
            num_samples=num_samples,
            seed=seed,
        )
    except TypeError:
        raise ImportError(
            "Evaluating consensus requires the 'ldp' extra for 'ldp'. Please:"
            " `pip install paper-qa[ldp]`."
        ) from None

    if not exclude_no_answer:
        return consensus, accuracy
    # Drop the empty-string answers from each question's vote list
    answered_only: dict[str, list[tuple[str, int]]] = {}
    for question, votes in consensus.items():
        answered_only[question] = [
            (answer, count) for answer, count in votes if answer
        ]
    return answered_only, accuracy


class StoreForConsensusSamplingCallback(Callback):
    """
    Store frames of completed environments for later consensus sampling.

    One frame is exported and stored per trajectory, at the terminal transition of
    a trajectory that did not fail.
    """

    def __init__(self):
        super().__init__()
        # Frames exported from environments, in the order their trajectories finished
        self.stored: list[GradablePaperQAEnvironment | Frame] = []

    async def after_transition(
        self,
        traj_id: str,
        agent: "Agent",
        env: Environment,
        transition: "Transition",
    ) -> None:
        """
        Export and store the environment's frame once its trajectory is done.

        Args:
            traj_id: Identifier of the trajectory (unused here).
            agent: Agent that produced the transition (unused here).
            env: Environment the transition took place in.
            transition: Transition that was just completed.

        Raises:
            NotImplementedError: If the environment is not a
                GradablePaperQAEnvironment.
        """
        if not isinstance(env, GradablePaperQAEnvironment):
            raise NotImplementedError(
                f"So far only handled {GradablePaperQAEnvironment} in this callback,"
                f" not {type(env)}."
            )
        # Only store once per trajectory: non-terminal transitions are skipped so an
        # environment isn't stored at every step, and failed trajectories are
        # skipped entirely
        if transition.done and not transition.failed:
            self.stored.append(env.export_frame())

    async def evaluate_consensus_sampling(
        self, num_samples: int = 1, seed: int | None = None
    ) -> tuple[dict[str, list[tuple[str, int]]], float]:
        """
        Evaluate consensus over everything stored so far.

        Args:
            num_samples: Passed through to the module-level
                evaluate_consensus_sampling.
            seed: Passed through to the module-level evaluate_consensus_sampling.

        Returns:
            The module-level evaluate_consensus_sampling's return for the stored data.
        """
        return await evaluate_consensus_sampling(
            data=self.stored, num_samples=num_samples, seed=seed
        )


def read_litqa_v2_from_hub(
    train_eval_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
    test_dataset: str = DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
    randomize: bool = True,
    seed: int | None = None,
    train_eval_split: float = 0.8,
) -> "tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]":
    """
    Load LitQA v2 from the Hugging Face Hub as train, eval, and test DataFrames.

    Args:
        train_eval_dataset: Name of the Hugging Face Hub dataset whose "train" split
            gets divided into the returned train and eval DataFrames.
        test_dataset: Name of the Hugging Face Hub dataset whose "test" split
            becomes the returned test DataFrame.
        randomize: Opt-out flag to shuffle the rows after loading.
        seed: Random seed to use for the shuffling.
        train_eval_split: Fraction of the train/eval rows placed in train, default
            is 80% train and 20% eval.

    Returns:
        Three-tuple of train, eval, and test DataFrames.

    Raises:
        ImportError: If the 'datasets' package cannot be imported.
        DatasetNotFoundError: If any of the datasets are not found, or the
            user is unauthenticated.
    """  # noqa: DOC502
    try:
        from datasets import load_dataset
    except ImportError as exc:
        raise ImportError(
            "Reading in LitQA2 requires the 'datasets' extra for 'datasets'. Please:"
            " `pip install aviary.litqa[datasets]`."
        ) from exc

    # Both datasets are loaded before any further processing takes place
    train_eval, test = (
        load_dataset(hub_name, "LitQA2")[hub_split].to_pandas()
        for hub_name, hub_split in (
            (train_eval_dataset, "train"),
            (test_dataset, "test"),
        )
    )
    for frame in (train_eval, test):
        # Convert to list so it's not unexpectedly a numpy array
        frame["distractors"] = frame["distractors"].apply(list)
    # Other DataFrame columns aren't checked here, downstream usage in the
    # TaskDataset's environment factories is left to check for their presence
    if randomize:
        train_eval, test = (
            frame.sample(frac=1, random_state=seed) for frame in (train_eval, test)
        )
    # Truncating conversion, so any fractional row goes to the eval split
    num_train = int(len(train_eval) * train_eval_split)
    train, evaluation = train_eval[:num_train], train_eval[num_train:]
    return train, evaluation, test


class LitQATaskDataset(
    TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
    """
    Abstract base class for a task dataset of LitQA v1 or v2 questions.

    It is kept abstract since nothing here is specific to one LitQA version,
    subclasses can be LitQA v1, v2, or a test stub version of LitQA.
    """

    def __init__(
        self,
        settings: Settings | dict | None = None,
        base_docs: Docs | dict | None = None,
        rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
        question_kwargs: Mapping[str, Any] | None = None,
        eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
        **env_kwargs,
    ):
        """
        Store the configuration shared by all environments this dataset creates.

        Args:
            settings: Settings for the environments, a dict is used as keyword
                arguments to build a Settings, and None leads to a default Settings.
            base_docs: Docs each environment gets a copy of, a dict is used as
                keyword arguments to build a Docs, and None leads to a default Docs.
            rewards: Reward mapping passed to each environment, its "correct" and
                "unsure" entries are also read when computing trajectory metrics.
            question_kwargs: Optional extra keyword arguments for each
                MultipleChoiceQuestion.
            eval_model: Evaluation model, stored on the instance.
            **env_kwargs: Extra keyword arguments forwarded to each environment.
        """
        resolved_settings: Settings
        if settings is None:
            resolved_settings = Settings()
        elif isinstance(settings, dict):
            resolved_settings = Settings(**settings)
        else:
            resolved_settings = settings
        self._settings = resolved_settings

        resolved_docs: Docs
        if base_docs is None:
            resolved_docs = Docs()
        elif isinstance(base_docs, dict):
            resolved_docs = Docs(**base_docs)
        else:
            resolved_docs = base_docs
        self._base_docs = resolved_docs

        self._rewards = rewards
        self._question_kwargs = question_kwargs
        self._eval_model = eval_model
        self._env_kwargs = env_kwargs

    def _make_gradable_environment(
        self,
        ideal_answer: str,
        distractors: str | list[str],
        question_id: str | UUID,
        question: str,
        sources: str | list[str] | None = None,
    ) -> GradablePaperQAEnvironment:
        """
        Build a gradable environment around one multiple choice question.

        Args:
            ideal_answer: Ideal answer of the question.
            distractors: Options of the question, a list is used as is, and a
                string is first split via MultipleChoiceQuestion.split_options.
            question_id: Identifier of the question.
            question: Text of the question.
            sources: Optional sources forwarded to the environment.

        Returns:
            Environment holding the question and a copy of the base docs.
        """
        if isinstance(distractors, list):
            options = distractors
        else:
            options = MultipleChoiceQuestion.split_options(distractors)
        extra_question_kwargs = self._question_kwargs or {}
        mc_question = MultipleChoiceQuestion(
            question_id=question_id,
            question=question,
            options=options,
            ideal_answer=ideal_answer,
            prompt_without_id=True,
            **extra_question_kwargs,
        )
        # Each environment is handed its own copy of the base docs
        docs_copy = self._base_docs.model_copy()
        return GradablePaperQAEnvironment(
            query=mc_question,
            settings=self._settings,
            docs=docs_copy,
            sources=sources,
            rewards=self._rewards,
            **self._env_kwargs,
        )

    def compute_trajectory_metrics(
        self, trajectories: "Sequence[Trajectory]"
    ) -> dict[str, list[float]]:
        """
        Extend the parent's trajectory metrics with LitQA-specific ones.

        Args:
            trajectories: Trajectories to compute metrics upon, the last step of
                each one is what gets inspected.

        Returns:
            The parent's metrics plus per-trajectory lists for "total_paper_count",
                "relevant_paper_count", "evidence_count", "correct", and
                "correct_unsure".
        """
        total_paper_count: list[float] = []
        relevant_paper_count: list[float] = []
        evidence_count: list[float] = []
        for trajectory in trajectories:
            status_splits: list[list[str]] = []
            for obs in trajectory.steps[-1].next_observation:
                # Only responses from the complete tool are inspected
                if not (
                    isinstance(obs, ToolResponseMessage)
                    and obs.name == Complete.TOOL_FN_NAME
                ):
                    continue
                pieces = re.split(
                    pattern=Complete.CERTAINTY_SPLIT_REGEX_PATTERN,
                    string=obs.content,
                    maxsplit=1,
                )
                # Keep only the places where the regex split succeeded
                if len(pieces) >= 4:  # noqa: PLR2004
                    status_splits.append(pieces)

            # Regex extraction of status starts after has_successful_answer, hence
            # the enumeration beginning at 1
            for piece_idx, metric_values in enumerate(
                (total_paper_count, relevant_paper_count, evidence_count), start=1
            ):
                if not status_splits:
                    # Avoid div0 (when complete wasn't called)
                    metric_values.append(0)
                    continue
                # NOTE: a mean is taken so 2+ complete calls (which we're prompted
                # not to do) don't break this. If that happens, all calls should
                # report the same status, so the mean equals each individual value
                status_total = sum(int(pieces[piece_idx]) for pieces in status_splits)
                metric_values.append(status_total / len(status_splits))

        parent_metrics = super().compute_trajectory_metrics(trajectories)
        final_rewards = [trajectory.steps[-1].reward for trajectory in trajectories]
        correct = [int(reward == self._rewards["correct"]) for reward in final_rewards]
        correct_unsure = [
            int(reward in {self._rewards["correct"], self._rewards["unsure"]})
            for reward in final_rewards
        ]
        return parent_metrics | {
            "total_paper_count": total_paper_count,
            "relevant_paper_count": relevant_paper_count,
            "evidence_count": evidence_count,
            "correct": correct,
            "correct_unsure": correct_unsure,
        }


class LitQAv2TaskSplit(StrEnum):
    TRAIN = "train"
    EVAL = "eval"
    TEST = "test"

    def get_index(self) -> int:
        """
        Get the index of the train (0), eval (1), or test (2) split.

        NOTE: the value matches the index in read_litqa_v2_from_hub's returned splits.
        """
        if self == self.TRAIN:
            return 0
        if self == self.EVAL:
            return 1
        if self == self.TEST:
            return 2
        assert_never(self)  # type: ignore[arg-type]


class LitQAv2TaskDataset(LitQATaskDataset):
    """Task dataset of LitQA v2 questions."""

    def __init__(
        self,
        *args,
        train_eval_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
        test_dataset: str = DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
        read_data_kwargs: Mapping[str, Any] | None = None,
        split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
        **kwargs,
    ):
        """
        Read LitQA v2 from the Hugging Face Hub and keep the requested split.

        Args:
            *args: Positional arguments forwarded to LitQATaskDataset.
            train_eval_dataset: Hub dataset name for the train and eval splits.
            test_dataset: Hub dataset name for the test split.
            read_data_kwargs: Optional extra keyword arguments for
                read_litqa_v2_from_hub.
            split: Which split's DataFrame to keep as this dataset's data.
            **kwargs: Keyword arguments forwarded to LitQATaskDataset.
        """
        super().__init__(*args, **kwargs)
        extra_read_kwargs = read_data_kwargs or {}
        all_splits = read_litqa_v2_from_hub(
            train_eval_dataset, test_dataset, **extra_read_kwargs
        )
        # The split is only parsed once the data has been read
        split_index = LitQAv2TaskSplit(split).get_index()
        self.data = all_splits[split_index]

    def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
        """
        Create an environment for the question at the given row position.

        Args:
            idx: Positional index of the row within this dataset's data.

        Returns:
            Environment for the row's question, with the DOIs of the row's sources
                given as the environment's sources.

        Raises:
            NotImplementedError: If a source doesn't contain exactly one of the DOI
                URL formats, so extracting its DOI is not handled.
        """

        def extract_doi(source: str) -> str:
            try:
                # Unpacking into a one-tuple requires exactly one DOI URL format
                # to be present in the source
                (doi,) = (
                    source.split(url_format, maxsplit=1)[1]
                    for url_format in DocDetails.DOI_URL_FORMATS
                    if url_format in source
                )
            except ValueError as exc:
                raise NotImplementedError(
                    f"Didn't handle DOI extraction from source {source!r}."
                ) from exc
            return doi

        row = self.data.iloc[idx]
        sources = [extract_doi(source) for source in row.sources]
        return self._make_gradable_environment(
            ideal_answer=row.ideal,
            distractors=row.distractors,
            question_id=UUID(row.id),
            question=row.question,
            sources=sources,
        )

    def __len__(self) -> int:
        """Get the number of rows in the kept split."""
        return len(self.data)


# Register the LitQA v2 task dataset under TASK_DATASET_NAME. The registry value is
# a two-tuple of (module path, class name) rather than the class object itself
TASK_DATASET_NAME = "litqa-v2"
TASK_DATASET_REGISTRY[TASK_DATASET_NAME] = (
    LitQAv2TaskDataset.__module__,
    LitQAv2TaskDataset.__name__,
)
