from __future__ import annotations

from typing import Any, Optional, Tuple, cast

from llama_index import Prompt
from llama_index.callbacks import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata, LLM
from llama_index.llm_predictor.vellum.exceptions import VellumGenerateException
from llama_index.llm_predictor.vellum.prompt_registry import VellumPromptRegistry
from llama_index.llm_predictor.vellum.types import (
    VellumCompiledPrompt,
    VellumRegisteredPrompt,
)
from llama_index.types import TokenAsyncGen, TokenGen


class VellumPredictor(BaseLLMPredictor):
    """LLM predictor that executes prompts through Vellum deployments.

    Each call resolves the given ``Prompt`` to a Vellum registered prompt via
    ``VellumPromptRegistry``, issues a generate request against that prompt's
    deployment, and reports the call to the callback manager as a
    ``CBEventType.LLM`` event.
    """

    def __init__(
        self,
        vellum_api_key: str,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Create the sync/async Vellum clients and the prompt registry.

        Args:
            vellum_api_key: API key passed to both Vellum clients and to the
                prompt registry.
            callback_manager: Receives LLM start/end events. A manager with no
                handlers is used when omitted.

        Raises:
            ImportError: If the ``vellum`` package is not installed.
        """
        import_err_msg = (
            "`vellum` package not found, please run `pip install vellum-ai`"
        )
        try:
            # Imported lazily so that `vellum-ai` stays an optional dependency.
            from vellum.client import AsyncVellum, Vellum
        except ImportError as err:
            # Chain explicitly so the original import failure stays visible.
            raise ImportError(import_err_msg) from err

        self.callback_manager = callback_manager or CallbackManager([])

        # Vellum-specific
        self._vellum_client = Vellum(api_key=vellum_api_key)
        self._async_vellum_client = AsyncVellum(api_key=vellum_api_key)
        self._prompt_registry = VellumPromptRegistry(vellum_api_key=vellum_api_key)

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""

        # Note: We use default values here, but ideally we would retrieve this metadata
        # via Vellum's API based on the LLM that backs the registered prompt's
        # deployment. This is not currently possible, so we use default values.
        return LLMMetadata()

    @property
    def llm(self) -> LLM:
        """Get the LLM.

        Raises:
            NotImplementedError: Always; Vellum does not expose the LLM.
        """
        raise NotImplementedError("Vellum does not expose the LLM.")

    def predict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Predict the answer to a query.

        Args:
            prompt: Prompt to run through its Vellum deployment.
            **prompt_args: Values for the prompt's variables.

        Returns:
            The completion text of the generate response.
        """
        registered_prompt, compiled_prompt, event_id = self._prepare_generate_call(
            prompt, **prompt_args
        )

        result = self._vellum_client.generate(
            deployment_id=registered_prompt.deployment_id,
            requests=self._build_generate_requests(prompt, prompt_args),
        )

        return self._process_generate_response(result, compiled_prompt, event_id)

    def stream(self, prompt: Prompt, **prompt_args: Any) -> TokenGen:
        """Stream the answer to a query.

        The LLM callback event is started, and the streaming request is
        issued, when this method is called. The event is ended only once the
        returned generator has been fully consumed; if an error is raised
        mid-stream or the consumer stops early, no end event is emitted.

        Args:
            prompt: Prompt to run through its Vellum deployment.
            **prompt_args: Values for the prompt's variables.

        Returns:
            A generator yielding the completion text delta of each streamed
            response.

        Raises:
            VellumGenerateException: Raised while iterating, when a streamed
                result carries an error or carries no data.
        """
        from vellum import GenerateStreamResult

        registered_prompt, compiled_prompt, event_id = self._prepare_generate_call(
            prompt, **prompt_args
        )

        responses = self._vellum_client.generate_stream(
            deployment_id=registered_prompt.deployment_id,
            requests=self._build_generate_requests(prompt, prompt_args),
        )

        def text_generator() -> TokenGen:
            complete_text = ""

            for stream_response in responses:
                result: GenerateStreamResult = stream_response.delta

                if result.error:
                    raise VellumGenerateException(result.error.message)
                elif not result.data:
                    raise VellumGenerateException(
                        "Unknown error occurred while generating"
                    )

                completion_text_delta = result.data.completion.text
                complete_text += completion_text_delta

                yield completion_text_delta

            # Stream exhausted: report the accumulated completion.
            self.callback_manager.on_event_end(
                CBEventType.LLM,
                payload={
                    EventPayload.RESPONSE: complete_text,
                    EventPayload.PROMPT: compiled_prompt.text,
                },
                event_id=event_id,
            )

        return text_generator()

    async def apredict(self, prompt: Prompt, **prompt_args: Any) -> str:
        """Asynchronously predict the answer to a query.

        Only the generate request itself is awaited (via the async Vellum
        client); preparing the call goes through the same synchronous helper
        as ``predict``.

        Args:
            prompt: Prompt to run through its Vellum deployment.
            **prompt_args: Values for the prompt's variables.

        Returns:
            The completion text of the generate response.
        """
        registered_prompt, compiled_prompt, event_id = self._prepare_generate_call(
            prompt, **prompt_args
        )

        result = await self._async_vellum_client.generate(
            deployment_id=registered_prompt.deployment_id,
            requests=self._build_generate_requests(prompt, prompt_args),
        )

        return self._process_generate_response(result, compiled_prompt, event_id)

    async def astream(self, prompt: Prompt, **prompt_args: Any) -> TokenAsyncGen:
        """Asynchronously stream the answer to a query.

        This wraps the synchronous ``stream`` in an async generator, so each
        token is produced by the synchronous Vellum client. ``stream`` is not
        invoked until the returned generator is first iterated.

        Args:
            prompt: Prompt to run through its Vellum deployment.
            **prompt_args: Values for the prompt's variables.

        Returns:
            An async generator yielding the same tokens as ``stream``.
        """

        async def gen() -> TokenAsyncGen:
            for token in self.stream(prompt, **prompt_args):
                yield token

        # NOTE: convert generator to async generator
        return gen()

    def _build_generate_requests(
        self, prompt: Prompt, prompt_args: Any
    ) -> list[Any]:
        """Build the single-element request list sent to Vellum generate calls.

        The request's input values are ``prompt.get_full_format_args`` applied
        to ``prompt_args``.
        """
        # Imported lazily so that `vellum-ai` stays an optional dependency.
        from vellum import GenerateRequest

        return [GenerateRequest(input_values=prompt.get_full_format_args(prompt_args))]

    def _prepare_generate_call(
        self, prompt: Prompt, **prompt_args: Any
    ) -> Tuple[VellumRegisteredPrompt, VellumCompiledPrompt, str]:
        """Prepare a generate call.

        Resolves the registered and compiled prompts through the prompt
        registry and starts an LLM callback event whose payload contains the
        prompt args plus the deployment and model version ids.

        Returns:
            ``(registered_prompt, compiled_prompt, event_id)`` where
            ``event_id`` identifies the started callback event.
        """
        registered_prompt = self._prompt_registry.from_prompt(prompt)
        # NOTE(review): the prompt is compiled from `prompt_args` only, while the
        # generate requests use `prompt.get_full_format_args(prompt_args)` --
        # confirm this difference is intended.
        compiled_prompt = self._prompt_registry.get_compiled_prompt(
            registered_prompt, prompt_args
        )

        # The fixed keys come last, so they win over same-named prompt args.
        cb_payload = {
            **prompt_args,
            "deployment_id": registered_prompt.deployment_id,
            "model_version_id": registered_prompt.model_version_id,
        }
        event_id = self.callback_manager.on_event_start(
            CBEventType.LLM,
            payload=cb_payload,
        )
        return registered_prompt, compiled_prompt, event_id

    def _process_generate_response(
        self,
        result: Any,
        compiled_prompt: VellumCompiledPrompt,
        event_id: str,
    ) -> str:
        """Process the response from a generate call.

        Ends the LLM callback event identified by ``event_id`` with the
        completion text and the compiled prompt text, then returns the
        completion text.

        Args:
            result: Response of a generate call; treated as a
                ``vellum.GenerateResponse``.
            compiled_prompt: Compiled prompt whose text is reported in the
                callback payload.
            event_id: Id of the callback event to end.

        Returns:
            The response's ``text``.
        """
        from vellum import GenerateResponse

        # `cast` is for type checkers only; it performs no runtime check.
        result = cast(GenerateResponse, result)

        completion_text = result.text

        self.callback_manager.on_event_end(
            CBEventType.LLM,
            payload={
                EventPayload.RESPONSE: completion_text,
                EventPayload.PROMPT: compiled_prompt.text,
            },
            event_id=event_id,
        )

        return completion_text
