from typing import Any, List, Optional, Union

import pydantic
from opik.evaluation.metrics import base_metric, score_result
from opik.evaluation.models import base_model, models_factory

from . import template, parser


class ContextPrecisionResponseFormat(pydantic.BaseModel):
    # Structured-output schema handed to the judge model as `response_format`
    # by ContextPrecision.score / ascore.
    # NOTE: documented with `#` comments rather than a docstring on purpose —
    # pydantic copies a class docstring into the generated JSON schema
    # ("description"), which would alter what is sent to the LLM.
    # Numeric score produced by the judge; presumably expected in [0.0, 1.0]
    # per the metric's docs — range is not enforced here, verify in `parser`.
    context_precision_score: float
    # Free-text justification the judge gives for its score.
    reason: str


class ContextPrecision(base_metric.BaseMetric):
    """
    A metric that evaluates the context precision of an input-output pair using an LLM.

    This metric uses a language model to assess how well the given output aligns with
    the provided context for the given input. It returns a score between 0.0 and 1.0,
    where higher values indicate better context precision.

    Args:
        model: The language model to use for evaluation. Can be a string (model name) or an `opik.evaluation.models.OpikBaseModel` subclass instance.
            `opik.evaluation.models.LiteLLMChatModel` is used by default.
        name: The name of the metric. Defaults to "context_precision_metric".
        few_shot_examples: A list of few-shot examples to provide to the model. If None, uses the default few-shot examples.
            An explicitly passed empty list means "no few-shot examples".
        track: Whether to track the metric. Defaults to True.
        project_name: Optional project name to track the metric in for the cases when
            there are no parent span/trace to inherit project name from.
        seed: Optional seed value for reproducible model generation. If provided, this seed will be passed to the model
            when it is created from a model name. Ignored when `model` is already an `OpikBaseModel` instance.
        temperature: Optional temperature value for model generation. If provided, this temperature will be passed to the model
            when it is created from a model name. If not provided, the model's default temperature will be used.
            Ignored when `model` is already an `OpikBaseModel` instance.

    Example:
        >>> from opik.evaluation.metrics import ContextPrecision
        >>> context_precision_metric = ContextPrecision()
        >>> result = context_precision_metric.score("What's the capital of France?", "The capital of France is Paris.", "Paris", ["France is a country in Europe."])
        >>> print(result.value)
        1.0
        >>> print(result.reason)
        The provided output perfectly matches the expected output of 'Paris' and accurately identifies it as the capital of France. ...
    """

    def __init__(
        self,
        model: Optional[Union[str, base_model.OpikBaseModel]] = None,
        name: str = "context_precision_metric",
        few_shot_examples: Optional[
            List[template.FewShotExampleContextPrecision]
        ] = None,
        track: bool = True,
        project_name: Optional[str] = None,
        seed: Optional[int] = None,
        temperature: Optional[float] = None,
    ):
        super().__init__(
            name=name,
            track=track,
            project_name=project_name,
        )
        self._seed = seed
        self._init_model(model, temperature=temperature)
        # Only fall back to the defaults when nothing was passed; `or` would
        # also discard an intentionally empty list.
        self.few_shot_examples = (
            template.FEW_SHOT_EXAMPLES
            if few_shot_examples is None
            else few_shot_examples
        )

    def _init_model(
        self,
        model: Optional[Union[str, base_model.OpikBaseModel]],
        temperature: Optional[float],
    ) -> None:
        """
        Set `self._model` from either a ready model instance or a model name.

        A pre-built `OpikBaseModel` is used as-is, so `temperature` and
        `self._seed` are not applied to it. Otherwise the model is created via
        `models_factory.get`, forwarding only the generation options that were
        actually provided.
        """
        if isinstance(model, base_model.OpikBaseModel):
            self._model = model
        else:
            model_kwargs = {}
            if temperature is not None:
                model_kwargs["temperature"] = temperature
            if self._seed is not None:
                model_kwargs["seed"] = self._seed

            self._model = models_factory.get(model_name=model, **model_kwargs)

    def _build_query(
        self,
        input: str,
        output: str,
        expected_output: str,
        context: List[str],
    ) -> str:
        """Render the judge prompt; shared by `score` and `ascore` so both stay in sync."""
        return template.generate_query(
            input=input,
            output=output,
            expected_output=expected_output,
            context=context,
            few_shot_examples=self.few_shot_examples,
        )

    def score(
        self,
        input: str,
        output: str,
        expected_output: str,
        context: List[str],
        **ignored_kwargs: Any,
    ) -> score_result.ScoreResult:
        """
        Calculate the context precision score for the given input-output pair.

        Args:
            input: The input text to be evaluated.
            output: The output text to be evaluated.
            expected_output: The expected output for the given input.
            context: A list of context strings relevant to the input.
            **ignored_kwargs: Additional keyword arguments that are ignored.

        Returns:
            score_result.ScoreResult: A ScoreResult object containing the context precision score
            (between 0.0 and 1.0) and a reason for the score.
        """
        llm_query = self._build_query(
            input=input,
            output=output,
            expected_output=expected_output,
            context=context,
        )
        model_output = self._model.generate_string(
            input=llm_query,
            response_format=ContextPrecisionResponseFormat,
        )

        return parser.parse_model_output(content=model_output, name=self.name)

    async def ascore(
        self,
        input: str,
        output: str,
        expected_output: str,
        context: List[str],
        **ignored_kwargs: Any,
    ) -> score_result.ScoreResult:
        """
        Asynchronously calculate the context precision score for the given input-output pair.

        This method is the asynchronous version of :meth:`score`. For detailed documentation,
        please refer to the :meth:`score` method.

        Args:
            input: The input text to be evaluated.
            output: The output text to be evaluated.
            expected_output: The expected output for the given input.
            context: A list of context strings relevant to the input.
            **ignored_kwargs: Additional keyword arguments that are ignored.

        Returns:
            score_result.ScoreResult: A ScoreResult object with the context precision score and reason.
        """
        llm_query = self._build_query(
            input=input,
            output=output,
            expected_output=expected_output,
            context=context,
        )
        model_output = await self._model.agenerate_string(
            input=llm_query,
            response_format=ContextPrecisionResponseFormat,
        )

        return parser.parse_model_output(content=model_output, name=self.name)
