"""Wrapper around Prediction Guard APIs."""
import logging
from typing import Any, Dict, List, Optional

from pydantic import Extra, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class PredictionGuard(LLM):
    """Wrapper around Prediction Guard large language models.

    To use, you should have the ``predictionguard`` python package installed, and the
    environment variable ``PREDICTIONGUARD_TOKEN`` set with your access token, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            pgllm = PredictionGuard(name="text-gen-proxy-name", token="my-access-token")
    """

    client: Any  #: :meta private:
    name: Optional[str] = "default-text-gen"
    """Proxy name to use."""

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: float = 0.75
    """A non-negative float that tunes the degree of randomness in generation."""

    token: Optional[str] = None
    """Access token; falls back to the ``PREDICTIONGUARD_TOKEN`` environment variable."""

    stop: Optional[List[str]] = None
    """Default stop sequences; mutually exclusive with the per-call ``stop`` argument."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the access token and python package exists in environment.

        Resolves the token from ``values`` or ``PREDICTIONGUARD_TOKEN`` and stores a
        ``predictionguard`` client built with it under ``values["client"]``.

        Raises:
            ValueError: If the ``predictionguard`` package cannot be imported.
        """
        token = get_from_dict_or_env(values, "token", "PREDICTIONGUARD_TOKEN")
        # Keep the try body limited to the import so that only a missing package is
        # reported as an installation problem.
        try:
            import predictionguard as pg
        except ImportError as err:
            raise ValueError(
                "Could not import predictionguard python package. "
                "Please install it with `pip install predictionguard`."
            ) from err
        values["client"] = pg.Client(token=token)
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Prediction Guard API."""
        return {
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"name": self.name, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "predictionguard"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to Prediction Guard's model proxy.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences for this call. Must not be given when
                the instance-level ``stop`` is also set.
            run_manager: Optional callback manager for the run (currently unused
                by this method).

        Returns:
            The string generated by the model, cut at the first stop sequence
            when any stop sequences are in effect.

        Raises:
            ValueError: If ``stop`` is provided both here and on the instance.

        Example:
            .. code-block:: python

                response = pgllm("Tell me a joke.")
        """
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        stop_sequences = self.stop if self.stop is not None else stop

        params = self._default_params
        response = self.client.predict(
            name=self.name,
            data={
                "prompt": prompt,
                "max_tokens": params["max_tokens"],
                "temperature": params["temperature"],
            },
        )
        text = response["text"]

        # If stop tokens are provided, Prediction Guard's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        # An empty list means there is nothing to strip; skip the helper in that
        # case, since handing it no tokens would (with LangChain's regex-based
        # implementation) wipe out the whole completion instead of leaving it as is.
        if stop_sequences:
            text = enforce_stop_tokens(text, stop_sequences)

        return text
