# global
import abc
import inspect
import threading

# noinspection PyProtectedMember
import time
import uuid
from typing import (
    Any,
    AsyncGenerator,
    Coroutine,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Type,
    Union,
)

import openai

# local
import unify
from openai._types import Headers, Query
from openai.types import CompletionUsage
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionMessage,
    ChatCompletionMessageParam,
    ChatCompletionStreamOptionsParam,
    ChatCompletionToolChoiceOptionParam,
    ChatCompletionToolParam,
)
from openai.types.chat.chat_completion import Choice
from pydantic import BaseModel
from typing_extensions import Self
from unify import BASE_URL, LOCAL_MODELS
from unify.universal_api.clients.helpers import (
    _assert_is_valid_endpoint,
    _assert_is_valid_model,
    _assert_is_valid_provider,
)
from unify.universal_api.utils.httpx_logging import (
    make_async_httpx_client_for_unify_logging,
    make_httpx_client_for_unify_logging,
)

from ...utils._caching import _get_cache, _write_to_cache, is_caching_enabled
from ...utils.helpers import _default
from ..clients.base import _Client
from ..types import Prompt


class _UniClient(_Client, abc.ABC):
    def __init__(
        self,
        endpoint: Optional[str] = None,
        *,
        model: Optional[str] = None,
        provider: Optional[str] = None,
        system_message: Optional[str] = None,
        messages: Optional[List[ChatCompletionMessageParam]] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]] = None,
        seed: Optional[int] = None,
        stop: Union[Optional[str], List[str]] = None,
        stream: Optional[bool] = False,
        stream_options: Optional[ChatCompletionStreamOptionsParam] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = None,
        service_tier: Optional[str] = None,
        tools: Optional[Iterable[ChatCompletionToolParam]] = None,
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        # platform arguments
        use_custom_keys: bool = False,
        tags: Optional[List[str]] = None,
        drop_params: Optional[bool] = True,
        region: Optional[str] = None,
        log_query_body: Optional[bool] = True,
        log_response_body: Optional[bool] = True,
        api_key: Optional[str] = None,
        # python client arguments
        stateful: bool = False,
        return_full_completion: bool = False,
        traced: bool = False,
        cache: Optional[Union[bool, str]] = None,
        cache_backend: Optional[str] = None,
        # passthrough arguments
        extra_headers: Optional[Headers] = None,
        extra_query: Optional[Query] = None,
        **kwargs,
    ):
        """Initialize the Uni LLM Unify client.

        Either `endpoint` or the `model` / `provider` pair may be given, but not
        both. Every other argument is forwarded unchanged to the base client
        constructor.

        Args:
            endpoint: Endpoint name in OpenAI API format:
            <model_name>@<provider_name>
            Defaults to None.

            model: Name of the model. Should only be set if endpoint is not set.

            provider: Name of the provider. Should only be set if endpoint is not set.

            system_message: An optional string containing the system message. This
            always appears at the beginning of the list of messages.

            messages: A list of messages comprising the conversation so far. This will
            be appended to the system_message if it is not None, and any user_message
            will be appended if it is not None.

            frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on their existing frequency in the text so far, decreasing the
            model's likelihood to repeat the same line verbatim.

            logit_bias: Modify the likelihood of specified tokens appearing in the
            completion. Accepts a JSON object that maps tokens (specified by their token
            ID in the tokenizer) to an associated bias value from -100 to 100.
            Mathematically, the bias is added to the logits generated by the model prior
            to sampling. The exact effect will vary per model, but values between -1 and
            1 should decrease or increase likelihood of selection; values like -100 or
            100 should result in a ban or exclusive selection of the relevant token.

            logprobs: Whether to return log probabilities of the output tokens or not.
            If true, returns the log probabilities of each output token returned in the
            content of message.

            top_logprobs: An integer between 0 and 20 specifying the number of most
            likely tokens to return at each token position, each with an associated log
            probability. logprobs must be set to true if this parameter is used.

            max_completion_tokens: The maximum number of tokens that can be generated in
            the chat completion. The total length of input tokens and generated tokens
            is limited by the model's context length. Defaults to the provider's default
            max_completion_tokens when the value is None.

            n: How many chat completion choices to generate for each input message. Note
            that you will be charged based on the number of generated tokens across all
            of the choices. Keep n as 1 to minimize costs.

            presence_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on whether they appear in the text so far, increasing the
            model's likelihood to talk about new topics.

            response_format: An object specifying the format that the model must output.
            Setting to `{ "type": "json_schema", "json_schema": {...} }` enables
            Structured Outputs which ensures the model will match your supplied JSON
            schema. Learn more in the Structured Outputs guide. Setting to
            `{ "type": "json_object" }` enables JSON mode, which ensures the message the
            model generates is valid JSON.

            seed: If specified, a best effort attempt is made to sample
            deterministically, such that repeated requests with the same seed and
            parameters should return the same result. Determinism is not guaranteed, and
            you should refer to the system_fingerprint response parameter to monitor
            changes in the backend.

            stop: Up to 4 sequences where the API will stop generating further tokens.

            stream: If True, generates content as a stream. If False, generates content
            as a single response. Defaults to False.

            stream_options: Options for streaming response. Only set this when you set
            stream: true.

            temperature: What sampling temperature to use, between 0 and 2.
            Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            It is generally recommended to alter this or top_p, but not both.
            Defaults to 1.0. The provider's default temperature is used when None is
            explicitly passed.

            top_p: An alternative to sampling with temperature, called nucleus sampling,
            where the model considers the results of the tokens with top_p probability
            mass. So 0.1 means only the tokens comprising the top 10% probability mass
            are considered. Generally recommended to alter this or temperature, but not
            both.

            service_tier: The processing tier to request for the query, as defined by
            the OpenAI chat completions API. Forwarded unchanged to the base client.

            tools: A list of tools the model may call. Currently, only functions are
            supported as a tool. Use this to provide a list of functions the model may
            generate JSON inputs for. A max of 128 functions are supported.

            tool_choice: Controls which (if any) tool is called by the
            model. none means the model will not call any tool and instead generates a
            message. auto means the model can pick between generating a message or
            calling one or more tools. required means the model must call one or more
            tools. Specifying a particular tool via
            `{ "type": "function", "function": {"name": "my_function"} }`
            forces the model to call that tool.
            none is the default when no tools are present. auto is the default if tools
            are present.

            parallel_tool_calls: Whether to enable parallel function calling during tool
            use.

            reasoning_effort: Constrains the effort spent on reasoning by reasoning
            models, as defined by the OpenAI chat completions API. Forwarded unchanged
            to the base client.

            use_custom_keys: Whether to use custom API keys or our unified API keys
            with the backend provider.

            tags: Arbitrary number of tags to classify this API query as needed. Helpful
            for generally grouping queries across tasks and users, for logging purposes.

            drop_params: Whether or not to drop unsupported OpenAI params by the
            provider you’re using.

            region: A string used to represent the region where the endpoint is
            accessed. Only relevant for on-prem deployments with certain providers like
            `vertex-ai`, `aws-bedrock` and `azure-ml`, where the endpoint is being
            accessed through a specified region.

            log_query_body: Whether to log the contents of the query json body.

            log_response_body: Whether to log the contents of the response json body.

            api_key: API key for the Unify platform. Forwarded to the base client.
            NOTE(review): presumably falls back to an environment variable when None;
            confirm against the base client.

            stateful: Whether the conversation history is preserved within the messages
            of this client. If True, then history is preserved. If False, then this acts
            as a stateless client, and message histories must be managed by the user.

            return_full_completion: If False, only return the message content
            chat_completion.choices[0].message.content.strip(" ") from the OpenAI
            return. Otherwise, the full response chat_completion is returned.
            Defaults to False.

            traced: Whether to trace the generate method.

            cache: If True, then the arguments will be stored in a local cache file, and
            any future calls with identical arguments will read from the cache instead
            of running the LLM query. If "write" then the cache will only be written
            to, if "read" then the cache will be read from if a cache is available but
            will not write, and if "read-only" then the argument must be present in the
            cache, else an exception will be raised. Finally, an appending "-closest"
            will read the closest match from the cache, and overwrite it if cache writing
            is enabled. This argument only has any effect when stream=False.

            cache_backend: Name of the cache backend to use when caching is enabled.
            Forwarded to the base client. NOTE(review): the accepted values are defined
            outside this class; confirm against the caching utilities.

            extra_headers: Additional "passthrough" headers for the request which are
            provider-specific, and are not part of the OpenAI standard. They are handled
            by the provider-specific API.

            extra_query: Additional "passthrough" query parameters for the request which
            are provider-specific, and are not part of the OpenAI standard. They are
            handled by the provider-specific API.

            kwargs: Additional "passthrough" JSON properties for the body of the
            request, which are provider-specific, and are not part of the OpenAI
            standard. They will be handled by the provider-specific API.

        Raises:
            UnifyError: If the API key is missing.
            Exception: If endpoint is passed together with model or provider.
        """
        # Everything except endpoint/model/provider is handed verbatim to the base
        # client constructor, and is also kept on the instance.
        self._base_constructor_args = dict(
            system_message=system_message,
            messages=messages,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            max_completion_tokens=max_completion_tokens,
            n=n,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            service_tier=service_tier,
            tools=tools,
            tool_choice=tool_choice,
            parallel_tool_calls=parallel_tool_calls,
            reasoning_effort=reasoning_effort,
            # platform arguments
            use_custom_keys=use_custom_keys,
            tags=tags,
            drop_params=drop_params,
            region=region,
            log_query_body=log_query_body,
            log_response_body=log_response_body,
            api_key=api_key,
            # python client arguments
            stateful=stateful,
            return_full_completion=return_full_completion,
            traced=traced,
            cache=cache,
            cache_backend=cache_backend,
            # passthrough arguments
            extra_headers=extra_headers,
            extra_query=extra_query,
            **kwargs,
        )
        super().__init__(**self._base_constructor_args)
        # The full set of arguments this client was built with, including the
        # endpoint selection which the base class does not receive.
        self._constructor_args = dict(
            endpoint=endpoint,
            model=model,
            provider=provider,
            **self._base_constructor_args,
        )
        if endpoint and (model or provider):
            raise Exception(
                "if the model or provider are passed, then the endpoint must not be "
                "passed.",
            )
        self._client = self._get_client()
        self._endpoint = None
        self._provider = None
        self._model = None
        if endpoint:
            self.set_endpoint(endpoint)
        # The provider is applied before the model because set_model reads
        # self._provider to decide how the model name is validated.
        if provider:
            self.set_provider(provider)
        if model:
            self.set_model(model)

    # Settable Properties #
    # --------------------#

    @property
    def endpoint(self) -> str:
        """
        Get the endpoint name.

        Returns:
            The endpoint name.
        """
        return self._endpoint

    @property
    def model(self) -> str:
        """
        Get the model name.

        Returns:
            The model name.
        """
        return self._model

    @property
    def provider(self) -> str:
        """
        Get the provider name.

        Returns:
            The provider name.
        """
        return self._provider

    # Setters #
    # --------#

    def set_endpoint(self, value: str) -> Self:
        """
        Validate and store the endpoint name.  # noqa: DAR101.

        The model and provider are derived from the endpoint, unless it is the
        literal "user-input", which is stored as-is. The text before the first
        "->" is inspected: if it contains "@" it is split into model and
        provider, otherwise it becomes the model and the segment after the
        first "->" becomes the provider.

        Args:
            value: The endpoint name.

        Returns:
            This client, useful for chaining inplace calls.
        """
        _assert_is_valid_endpoint(value, api_key=self._api_key)
        self._endpoint = value
        if value != "user-input":
            segments = value.split("->")
            head = segments[0]
            if "@" not in head:
                # "<model>-><provider>" form
                self._model = head
                self._provider = segments[1]
            else:
                # "<model>@<provider>" form
                model_name, provider_name = head.split("@")
                self._model = model_name
                self._provider = provider_name
        return self

    def set_model(self, value: str) -> Self:
        """
        Set the model name.  # noqa: DAR101.

        The model is validated, stored, and, when a provider is already set,
        the endpoint is rebuilt as "<model>@<provider>".

        Args:
            value: The model name.

        Returns:
            This client, useful for chaining inplace calls.
        """
        # The provider may not have been chosen yet (it starts out as None), in
        # which case the model is validated as a regular, non-custom model.
        provider = self._provider or ""
        custom_or_local = provider == "local" or "custom" in provider
        _assert_is_valid_model(
            value,
            custom_or_local=custom_or_local,
            api_key=self._api_key,
        )
        # Remember the model so the `model` property and a later set_provider
        # call both see the new value.
        self._model = value
        if provider:
            self._endpoint = "@".join([value, provider])
        return self

    def set_provider(self, value: str) -> Self:
        """
        Validate and store the provider name.  # noqa: DAR101.

        When a model is already set, the endpoint is rebuilt as
        "<model>@<provider>".

        Args:
            value: The provider name.

        Returns:
            This client, useful for chaining inplace calls.
        """
        _assert_is_valid_provider(value, api_key=self._api_key)
        self._provider = value
        current_model = self._model
        if current_model:
            self._endpoint = "@".join((current_model, value))
        return self

    @staticmethod
    def _handle_kw(
        prompt,
        endpoint,
        stream,
        stream_options,
        use_custom_keys,
        tags,
        drop_params,
        region,
        log_query_body,
        log_response_body,
    ):
        prompt_dict = prompt.components
        if "extra_body" in prompt_dict:
            extra_body = prompt_dict["extra_body"]
            del prompt_dict["extra_body"]
        else:
            extra_body = {}
        kw = dict(
            model=endpoint,
            **prompt_dict,
            stream=stream,
            stream_options=stream_options,
            extra_body={  # platform arguments
                "signature": "python",
                "use_custom_keys": use_custom_keys,
                "tags": tags,
                "drop_params": drop_params,
                "region": region,
                "log_query_body": log_query_body,
                "log_response_body": log_response_body,
                # passthrough json arguments
                **extra_body,
            },
        )
        return {k: v for k, v in kw.items() if v is not None}

    # Representation #
    # ---------------#

    def __repr__(self):
        return "{}(endpoint={})".format(self.__class__.__name__, self._endpoint)

    def __str__(self):
        return "{}(endpoint={})".format(self.__class__.__name__, self._endpoint)

    # --------------------------------------------------------------------- #
    #  Helper(s) – keep the public surface of _UniClient unchanged          #
    # --------------------------------------------------------------------- #

    def _append_to_history(self, assistant_msg: dict) -> None:
        """Append a single assistant message to the internal history."""
        if self._messages is None:
            self._messages = []
        self._messages.append(assistant_msg)

    # --------------------------------------------------------------------- #
    #  Streaming wrappers                                                   #
    # --------------------------------------------------------------------- #

    def _wrap_sync_stream(
        self,
        stream: Generator[Any, None, None],
        *,
        stateful: bool,
        return_full_completion: bool,
    ) -> Generator[Any, None, None]:
        """
        Proxy a *synchronous* stream, collecting the emitted content so we can
        update or clear the history once (and only once) when the stream
        finishes.
        """
        collected: list[str] = []

        def _take(item: Any) -> str:
            if return_full_completion:
                # ChatCompletionChunk → extract incremental delta
                try:
                    delta = item.choices[0].delta.content
                    return delta or ""  # may be None
                except Exception:  # noqa: BLE001
                    return ""
            return str(item)

        try:
            for chunk in stream:
                if stateful:
                    piece = _take(chunk)
                    if piece:
                        collected.append(piece)
                yield chunk
        finally:  # executes on normal end *and* on .close()
            if stateful:
                if collected:
                    self._append_to_history(
                        {"role": "assistant", "content": "".join(collected).strip()},
                    )
            elif self._messages:
                self._messages.clear()

    def _wrap_async_stream(  # noqa: WPS231
        self,
        stream: AsyncGenerator[Any, None],
        *,
        stateful: bool,
        return_full_completion: bool,
    ) -> AsyncGenerator[Any, None]:
        """
        Same as `_wrap_sync_stream` but for *async* generators.
        """
        collected: list[str] = []

        async def _internal():
            async for chunk in stream:
                if stateful:
                    if return_full_completion:
                        try:
                            delta = chunk.choices[0].delta.content
                            if delta:
                                collected.append(delta)
                        except Exception:  # noqa: BLE001
                            pass
                    else:
                        collected.append(str(chunk))
                yield chunk

            # async-generator exhausted
            if stateful:
                if collected:
                    self._append_to_history(
                        {"role": "assistant", "content": "".join(collected).strip()},
                    )
            elif self._messages:
                self._messages.clear()

        return _internal()

    # --------------------------------------------------------------------- #
    #  Single place that decides which helper to call                       #
    # --------------------------------------------------------------------- #

    def _apply_stateful_logic(  # noqa: WPS231,WPS211
        self,
        *,
        response: Any,
        stateful: bool,
        was_stream: bool,
        return_full_completion: bool,
    ) -> Any:
        """
        Ensures the conversation history is updated (or cleared) **once** per
        call, for all four modalities.
        """

        if was_stream:

            if inspect.iscoroutine(response):

                async def _await_then_wrap(coro):
                    inner = await coro  # real result from _generate

                    # inner is expected to be async-gen, but handle sync-gen too
                    if inspect.isasyncgen(inner):
                        return self._wrap_async_stream(
                            inner,
                            stateful=stateful,
                            return_full_completion=return_full_completion,
                        )
                    if isinstance(inner, (list, tuple)) or inspect.isgenerator(inner):
                        # rare case: provider gave back sync generator
                        return self._wrap_sync_stream(
                            inner,
                            stateful=stateful,
                            return_full_completion=return_full_completion,
                        )
                    # not a generator at all – treat like non-stream single result
                    return self._apply_stateful_logic(
                        response=inner,
                        stateful=stateful,
                        was_stream=False,
                        return_full_completion=return_full_completion,
                    )

                # Return *the coroutine itself* so the caller still needs to `await`
                return _await_then_wrap(response)

            # choose correct wrapper (sync vs async)
            if inspect.isasyncgen(response):
                return self._wrap_async_stream(
                    response,
                    stateful=stateful,
                    return_full_completion=return_full_completion,
                )
            return self._wrap_sync_stream(
                response,
                stateful=stateful,
                return_full_completion=return_full_completion,
            )

        # ───── coroutine (async-non-stream) path ──────────────────────────
        if inspect.iscoroutine(response):

            # 1. Capture the index where the assistant reply will go, but
            #    **do not** insert the placeholder yet.  This avoids sending an
            #    empty assistant message to the LLM while still fixing order.
            placeholder_idx: int | None = len(self._messages) if stateful else None

            # 2. Await the real coroutine and overwrite the placeholder in-place
            async def _await_and_process(coro: Coroutine[Any, Any, Any]):
                try:
                    res = await coro
                except Exception:
                    # nothing was inserted yet, just re-raise
                    raise

                if stateful and placeholder_idx is not None:
                    if return_full_completion:
                        self._messages.insert(
                            placeholder_idx,
                            res.choices[0].message.model_dump(),
                        )
                    else:
                        self._messages.insert(
                            placeholder_idx,
                            {"role": "assistant", "content": str(res)},
                        )
                elif self._messages:
                    self._messages.clear()
                return res

            return _await_and_process(response)

        # ---------- non-streaming path ----------
        if stateful:
            if return_full_completion:
                assistant_dict = response.choices[0].message.model_dump()
            else:
                assistant_dict = {"role": "assistant", "content": str(response)}
            self._append_to_history(assistant_dict)
        elif self._messages:
            self._messages.clear()

        return response

    # Abstract #
    # ---------#

    @abc.abstractmethod
    def _get_client(self):
        raise NotImplementedError

    # Generate #
    # ---------#

    def generate(
        self,
        user_message: Optional[str] = None,
        system_message: Optional[str] = None,
        messages: Optional[
            Union[
                List[ChatCompletionMessageParam],
                Dict[str, List[ChatCompletionMessageParam]],
            ]
        ] = None,
        *,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]] = None,
        seed: Optional[int] = None,
        stop: Union[Optional[str], List[str]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[ChatCompletionStreamOptionsParam] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[Iterable[ChatCompletionToolParam]] = None,
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        # platform arguments
        use_custom_keys: Optional[bool] = None,
        tags: Optional[List[str]] = None,
        drop_params: Optional[bool] = None,
        region: Optional[str] = None,
        log_query_body: Optional[bool] = None,
        log_response_body: Optional[bool] = None,
        # python client arguments
        stateful: Optional[bool] = None,
        return_full_completion: Optional[bool] = None,
        cache: Optional[Union[bool, str]] = None,
        cache_backend: Optional[str] = None,
        # passthrough arguments
        extra_headers: Optional[Headers] = None,
        extra_query: Optional[Query] = None,
        service_tier: Optional[str] = None,
        **kwargs,
    ):
        """Generate a ChatCompletion response for the specified endpoint,
        from the provided query parameters.

        Args:
            user_message: A string containing the user message.
            If provided, messages must be None.

            system_message: An optional string containing the system message. This
            always appears at the beginning of the list of messages.

            messages: A list of messages comprising the conversation so far, or
            optionally a dictionary of such messages, with clients as the keys in the
            case of multi-llm clients. This will be appended to the system_message if it
            is not None, and any user_message will be appended if it is not None.

            frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on their existing frequency in the text so far, decreasing the
            model's likelihood to repeat the same line verbatim.

            logit_bias: Modify the likelihood of specified tokens appearing in the
            completion. Accepts a JSON object that maps tokens (specified by their token
            ID in the tokenizer) to an associated bias value from -100 to 100.
            Mathematically, the bias is added to the logits generated by the model prior
            to sampling. The exact effect will vary per model, but values between -1 and
            1 should decrease or increase likelihood of selection; values like -100 or
            100 should result in a ban or exclusive selection of the relevant token.

            logprobs: Whether to return log probabilities of the output tokens or not.
            If true, returns the log probabilities of each output token returned in the
            content of message.

            top_logprobs: An integer between 0 and 20 specifying the number of most
            likely tokens to return at each token position, each with an associated log
            probability. logprobs must be set to true if this parameter is used.

            max_completion_tokens: The maximum number of tokens that can be generated in
            the chat completion. The total length of input tokens and generated tokens
            is limited by the model's context length. Defaults value is None. Uses the
            provider's default max_completion_tokens when None is explicitly passed.

            n: How many chat completion choices to generate for each input message. Note
            that you will be charged based on the number of generated tokens across all
            of the choices. Keep n as 1 to minimize costs.

            presence_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on whether they appear in the text so far, increasing the
            model's likelihood to talk about new topics.

            response_format: An object specifying the format that the model must output.
            Setting to `{ "type": "json_schema", "json_schema": {...} }` enables
            Structured Outputs which ensures the model will match your supplied JSON
            schema. Learn more in the Structured Outputs guide. Setting to
            `{ "type": "json_object" }` enables JSON mode, which ensures the message the
            model generates is valid JSON.

            seed: If specified, a best effort attempt is made to sample
            deterministically, such that repeated requests with the same seed and
            parameters should return the same result. Determinism is not guaranteed, and
            you should refer to the system_fingerprint response parameter to monitor
            changes in the backend.

            stop: Up to 4 sequences where the API will stop generating further tokens.

            stream: If True, generates content as a stream. If False, generates content
            as a single response. Defaults to False.

            stream_options: Options for streaming response. Only set this when you set
            stream: true.

            temperature:  What sampling temperature to use, between 0 and 2.
            Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            It is generally recommended to alter this or top_p, but not both.
            Default value is 1.0. Defaults to the provider's default temperature when
            None is explicitly passed.

            top_p: An alternative to sampling with temperature, called nucleus sampling,
            where the model considers the results of the tokens with top_p probability
            mass. So 0.1 means only the tokens comprising the top 10% probability mass
            are considered. Generally recommended to alter this or temperature, but not
            both.

            tools: A list of tools the model may call. Currently, only functions are
            supported as a tool. Use this to provide a list of functions the model may
            generate JSON inputs for. A max of 128 functions are supported.

            tool_choice: Controls which (if any) tool is called by the
            model. none means the model will not call any tool and instead generates a
            message. auto means the model can pick between generating a message or
            calling one or more tools. required means the model must call one or more
            tools. Specifying a particular tool via
            `{ "type": "function", "function": {"name": "my_function"} }`
            forces the model to call that tool.
            none is the default when no tools are present. auto is the default if tools
            are present.

            parallel_tool_calls: Whether to enable parallel function calling during tool
            use.

            use_custom_keys:  Whether to use custom API keys or our unified API keys
            with the backend provider. Defaults to False.

            tags: Arbitrary number of tags to classify this API query as needed. Helpful
            for generally grouping queries across tasks and users, for logging purposes.

            drop_params: Whether or not to drop unsupported OpenAI params by the
            provider you’re using.

            region: A string used to represent the region where the endpoint is
            accessed. Only relevant for on-prem deployments with certain providers like
            `vertex-ai`, `aws-bedrock` and `azure-ml`, where the endpoint is being
            accessed through a specified region.

            log_query_body: Whether to log the contents of the query json body.

            log_response_body: Whether to log the contents of the response json body.

            stateful:  Whether the conversation history is preserved within the messages
            of this client. If True, then history is preserved. If False, then this acts
            as a stateless client, and message histories must be managed by the user.

            return_full_completion: If False, only return the message content
            chat_completion.choices[0].message.content.strip(" ") from the OpenAI
            return. Otherwise, the full response chat_completion is returned.
            Defaults to False.

            cache: If True, then the arguments will be stored in a local cache file, and
            any future calls with identical arguments will read from the cache instead
            of running the LLM query. If "write" then the cache will only be written
            to, if "read" then the cache will be read from if a cache is available but
            will not write, and if "read-only" then the argument must be present in the
            cache, else an exception will be raised. Finally, appending "-closest" to
            any of these modes will read the closest match from the cache, and
            overwrite it if cache writing is enabled. This argument only has any effect
            when stream=False.

            extra_headers: Additional "passthrough" headers for the request which are
            provider-specific, and are not part of the OpenAI standard. They are handled
            by the provider-specific API.

            extra_query: Additional "passthrough" query parameters for the request which
            are provider-specific, and are not part of the OpenAI standard. They are
            handled by the provider-specific API.

            kwargs: Additional "passthrough" JSON properties for the body of the
            request, which are provider-specific, and are not part of the OpenAI
            standard. They will be handled by the provider-specific API.

        Returns:
            If stream is True, returns a generator yielding chunks of content.
            If stream is False, returns a single string response.

        Raises:
            UnifyError: If an error occurs during content generation.
        """
        system_message = _default(system_message, self._system_message)
        messages = _default(messages, self._messages)
        stateful = _default(stateful, self._stateful)
        if messages:
            sys_msg_inside = any(msg["role"] == "system" for msg in messages)
            if not sys_msg_inside and system_message is not None:
                messages = [
                    {"role": "system", "content": system_message},
                ] + messages
            if user_message is not None:
                messages += [{"role": "user", "content": user_message}]
        else:
            messages = list()
            if system_message is not None:
                messages += [{"role": "system", "content": system_message}]
            if user_message is not None:
                messages += [{"role": "user", "content": user_message}]
            self._messages = messages
        return_full_completion = (
            True
            if _default(tools, self._tools)
            else _default(return_full_completion, self._return_full_completion)
        )
        cache = _default(cache, self._cache)
        _cache_modes = ["read", "read-only", "write", "both"]
        assert cache in _cache_modes + [m + "-closest" for m in _cache_modes] + [
            True,
            False,
            None,
        ]
        ret = self._generate(
            messages=messages,
            frequency_penalty=_default(frequency_penalty, self._frequency_penalty),
            logit_bias=_default(logit_bias, self._logit_bias),
            logprobs=_default(logprobs, self._logprobs),
            top_logprobs=_default(top_logprobs, self._top_logprobs),
            max_completion_tokens=_default(
                max_completion_tokens,
                self._max_completion_tokens,
            ),
            n=_default(n, self._n),
            presence_penalty=_default(presence_penalty, self._presence_penalty),
            response_format=_default(response_format, self._response_format),
            seed=_default(_default(seed, self._seed), unify.get_seed()),
            stop=_default(stop, self._stop),
            stream=_default(stream, self._stream),
            stream_options=_default(stream_options, self._stream_options),
            temperature=_default(temperature, self._temperature),
            top_p=_default(top_p, self._top_p),
            service_tier=_default(service_tier, self._service_tier),
            tools=_default(tools, self._tools),
            tool_choice=_default(tool_choice, self._tool_choice),
            parallel_tool_calls=_default(
                parallel_tool_calls,
                self._parallel_tool_calls,
            ),
            reasoning_effort=_default(reasoning_effort, self._reasoning_effort),
            # platform arguments
            use_custom_keys=_default(use_custom_keys, self._use_custom_keys),
            tags=_default(tags, self._tags),
            drop_params=_default(drop_params, self._drop_params),
            region=_default(region, self._region),
            log_query_body=_default(log_query_body, self._log_query_body),
            log_response_body=_default(log_response_body, self._log_response_body),
            # python client arguments
            return_full_completion=return_full_completion,
            cache=_default(cache, is_caching_enabled()),
            cache_backend=_default(cache_backend, self._cache_backend),
            # passthrough arguments
            extra_headers=_default(extra_headers, self._extra_headers),
            extra_query=_default(extra_query, self._extra_query),
            **{**self._extra_body, **kwargs},
        )
        ret = self._apply_stateful_logic(
            response=ret,
            stateful=stateful,
            was_stream=_default(stream, self._stream),
            return_full_completion=return_full_completion,
        )
        return ret


class Unify(_UniClient):
    """Class for interacting with the Unify chat completions endpoint in a synchronous
    manner."""

    def _get_client(self):
        """
        Build the synchronous OpenAI client used by this instance.

        The client targets `BASE_URL`, authenticates with `self._api_key`, and sends
        requests through the httpx client returned by
        `make_httpx_client_for_unify_logging`.

        Returns:
            An `openai.OpenAI` instance.

        Raises:
            Exception: If an `openai.OpenAIError` is raised while building the
            client. The original error is attached as the cause.
        """
        try:
            http_client = make_httpx_client_for_unify_logging(BASE_URL)
            return openai.OpenAI(
                base_url=f"{BASE_URL}",
                api_key=self._api_key,
                timeout=3600.0,  # one hour
                http_client=http_client,
            )
        except openai.OpenAIError as e:
            # Chain explicitly so the underlying OpenAI error stays attached.
            raise Exception(f"Failed to initialize Unify client: {str(e)}") from e

    def _generate_stream(
        self,
        endpoint: str,
        prompt: Prompt,
        # stream
        stream_options: Optional[ChatCompletionStreamOptionsParam],
        # platform arguments
        use_custom_keys: bool,
        tags: Optional[List[str]],
        drop_params: Optional[bool],
        region: Optional[str],
        log_query_body: Optional[bool],
        log_response_body: Optional[bool],
        # python client arguments
        return_full_completion: bool,
    ) -> Generator[str, None, None]:
        """
        Run a streaming chat completion and yield its output chunk by chunk.

        Args:
            endpoint: The endpoint to query. If it is a key of `LOCAL_MODELS`, the
            registered callable is invoked instead of the remote API.

            prompt: The prompt, forwarded to `_handle_kw` to build the request.

            stream_options: Options for the streaming response.

            use_custom_keys: Platform argument forwarded to `_handle_kw`.

            tags: Platform argument forwarded to `_handle_kw`. When tracing is
            enabled the tags are also appended to the span name.

            drop_params: Platform argument forwarded to `_handle_kw`.

            region: Platform argument forwarded to `_handle_kw`.

            log_query_body: Platform argument forwarded to `_handle_kw`.

            log_response_body: Platform argument forwarded to `_handle_kw`.

            return_full_completion: If True, every non-None chunk object is yielded
            unchanged. If False, only the non-None `delta.content` of the first
            choice of each chunk is yielded.

        Yields:
            Chunk objects or content strings, depending on `return_full_completion`.

        Raises:
            Exception: If the API responds with an `openai.APIStatusError`; the
            message is copied over and the original error is attached as the cause.
        """
        kw = self._handle_kw(
            prompt=prompt,
            endpoint=endpoint,
            stream=True,
            stream_options=stream_options,
            use_custom_keys=use_custom_keys,
            tags=tags,
            drop_params=drop_params,
            region=region,
            log_query_body=log_query_body,
            log_response_body=log_response_body,
        )
        try:
            if endpoint in LOCAL_MODELS:
                # These arguments are not passed on to local models. Keys may be
                # absent from `kw`, so a missing key is tolerated here.
                for key in ("extra_body", "model", "max_completion_tokens"):
                    kw.pop(key, None)
                chat_completion = LOCAL_MODELS[endpoint](**kw)
            else:
                if unify.CLIENT_LOGGING:
                    print(f"calling {kw['model']}... (thread {threading.get_ident()})")
                # `self._traced` matches the attribute read by the other
                # generation methods of this class.
                if self._traced:
                    chat_completion = unify.traced(
                        self._client.chat.completions.create,
                        span_type="llm-stream",
                        name=(
                            endpoint
                            if tags is None
                            else endpoint + "[" + ",".join([str(t) for t in tags]) + "]"
                        ),
                    )(**kw)
                else:
                    chat_completion = self._client.chat.completions.create(**kw)
                if unify.CLIENT_LOGGING:
                    print(f"done (thread {threading.get_ident()})")
            for chunk in chat_completion:
                if return_full_completion:
                    content = chunk
                elif not chunk.choices:
                    # A chunk may carry no choices (the OpenAI API documents this
                    # for the usage-only chunk sent when `stream_options` requests
                    # usage), so there is no delta to read.
                    continue
                else:
                    content = chunk.choices[0].delta.content  # type: ignore[union-attr]    # noqa: E501
                if content is not None:
                    yield content
        except openai.APIStatusError as e:
            raise Exception(e.message) from e

    def _generate_non_stream(
        self,
        endpoint: str,
        prompt: Prompt,
        # platform arguments
        use_custom_keys: bool,
        tags: Optional[List[str]],
        drop_params: Optional[bool],
        region: Optional[str],
        log_query_body: Optional[bool],
        log_response_body: Optional[bool],
        # python client arguments
        return_full_completion: bool,
        cache: Union[bool, str],
        cache_backend: str,
    ) -> Union[str, ChatCompletion]:
        """
        Run a single (non-streaming) chat completion, optionally through the cache.

        Args:
            endpoint: The endpoint to query. A key of `LOCAL_MODELS` is served by the
            registered callable, and `"user-input"` reads the reply from stdin.

            prompt: The prompt, forwarded to `_handle_kw` to build the request.

            use_custom_keys: Platform argument forwarded to `_handle_kw`.

            tags: Platform argument forwarded to `_handle_kw`. When tracing is
            enabled the tags are also appended to the span name.

            drop_params: Platform argument forwarded to `_handle_kw`.

            region: Platform argument forwarded to `_handle_kw`.

            log_query_body: Platform argument forwarded to `_handle_kw`.

            log_response_body: Platform argument forwarded to `_handle_kw`.

            return_full_completion: If True, return the full completion object. If
            False, return only the content of the first choice.

            cache: The cache mode. True, "both", "read" and "read-only" trigger a
            cache lookup. True, "both" and "write" allow the result to be stored. A
            "-closest" suffix is stripped and enables closest-match reading.

            cache_backend: The backend handed to the cache read and write helpers.

        Returns:
            The full completion if `return_full_completion` is True. Otherwise the
            content of the first choice with surrounding spaces stripped, or an empty
            string if there is no content. For the `"user-input"` endpoint the typed
            text is returned as-is, or wrapped in a `ChatCompletion` if
            `return_full_completion` is True.

        Raises:
            Exception: If the API responds with an `openai.APIStatusError`; the
            message is copied over and the original error is attached as the cause.
        """
        kw = self._handle_kw(
            prompt=prompt,
            endpoint=endpoint,
            stream=False,
            stream_options=None,
            use_custom_keys=use_custom_keys,
            tags=tags,
            drop_params=drop_params,
            region=region,
            log_query_body=log_query_body,
            log_response_body=log_response_body,
        )
        if isinstance(cache, str) and cache.endswith("-closest"):
            cache = cache.removesuffix("-closest")
            read_closest = True
        else:
            read_closest = False
        if "response_format" in kw:
            chat_method = self._client.beta.chat.completions.parse
            # .parse() does not accept the stream argument
            kw.pop("stream", None)
        elif endpoint == "user-input":
            chat_method = lambda *a, **kw: input("write your agent response:\n")
        else:
            chat_method = self._client.chat.completions.create
        chat_completion = None
        in_cache = False
        if cache in [True, "both", "read", "read-only"]:
            if self._traced:

                def _get_cache_traced(**kw):
                    return _get_cache(
                        fn_name="chat.completions.create",
                        kw=kw,
                        raise_on_empty=cache == "read-only",
                        read_closest=read_closest,
                        delete_closest=read_closest,
                        backend=cache_backend,
                    )

                chat_completion = unify.traced(
                    _get_cache_traced,
                    span_type="llm-cached",
                    name=(
                        endpoint
                        if tags is None
                        else endpoint + "[" + ",".join([str(t) for t in tags]) + "]"
                    ),
                )(**kw)
            else:
                chat_completion = _get_cache(
                    fn_name="chat.completions.create",
                    kw=kw,
                    raise_on_empty=cache == "read-only",
                    read_closest=read_closest,
                    delete_closest=read_closest,
                    backend=cache_backend,
                )
            # Record the hit on both the traced and the untraced path, so that a
            # traced cache hit is not written straight back to the cache below.
            in_cache = chat_completion is not None
        if chat_completion is None:
            try:
                if endpoint in LOCAL_MODELS:
                    # Strip the arguments that are not passed to local models from a
                    # copy, so `kw` stays identical to the key used for the cache
                    # read and the cache write below stores under that same key.
                    local_kw = {
                        key: value
                        for key, value in kw.items()
                        if key not in ("extra_body", "model", "max_completion_tokens")
                    }
                    chat_completion = LOCAL_MODELS[endpoint](**local_kw)
                else:
                    if unify.CLIENT_LOGGING:
                        print(
                            f"calling {kw['model']}... (thread {threading.get_ident()})",
                        )
                    if self._traced:
                        chat_completion = unify.traced(
                            chat_method,
                            span_type="llm",
                            name=(
                                endpoint
                                if tags is None
                                else endpoint
                                + "["
                                + ",".join([str(t) for t in tags])
                                + "]"
                            ),
                        )(**kw)
                    else:
                        chat_completion = chat_method(**kw)
                    if unify.CLIENT_LOGGING:
                        print(f"done (thread {threading.get_ident()})")
            except openai.APIStatusError as e:
                raise Exception(e.message) from e
        if (chat_completion is not None or read_closest) and cache in [
            True,
            "both",
            "write",
        ]:
            # Skip the write for results that came out of the cache, unless the
            # mode is "write".
            if not in_cache or cache == "write":
                _write_to_cache(
                    fn_name="chat.completions.create",
                    kw=kw,
                    response=chat_completion,
                    backend=cache_backend,
                )
        if return_full_completion:
            if endpoint == "user-input":
                # The typed reply is a plain string, so wrap it in a ChatCompletion.
                # The usage numbers are derived from `len()` of the messages and of
                # the reply rather than from a tokenizer.
                input_msg = sum(len(msg) for msg in prompt.components["messages"])
                return ChatCompletion(
                    id=str(uuid.uuid4()),
                    object="chat.completion",
                    created=int(time.time()),
                    model=endpoint,
                    choices=[
                        Choice(
                            index=0,
                            message=ChatCompletionMessage(
                                role="assistant",
                                content=chat_completion,
                            ),
                            finish_reason="stop",
                        ),
                    ],
                    usage=CompletionUsage(
                        prompt_tokens=input_msg,
                        completion_tokens=len(chat_completion),
                        total_tokens=input_msg + len(chat_completion),
                    ),
                )
            return chat_completion
        elif endpoint == "user-input":
            return chat_completion
        content = chat_completion.choices[0].message.content
        if content:
            return content.strip(" ")
        return ""

    def _generate(  # noqa: WPS234, WPS211
        self,
        messages: Optional[List[ChatCompletionMessageParam]],
        *,
        frequency_penalty: Optional[float],
        logit_bias: Optional[Dict[str, int]],
        logprobs: Optional[bool],
        top_logprobs: Optional[int],
        max_completion_tokens: Optional[int],
        n: Optional[int],
        presence_penalty: Optional[float],
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]],
        seed: Optional[int],
        stop: Union[Optional[str], List[str]],
        stream: Optional[bool],
        stream_options: Optional[ChatCompletionStreamOptionsParam],
        temperature: Optional[float],
        top_p: Optional[float],
        service_tier: Optional[str],
        tools: Optional[Iterable[ChatCompletionToolParam]],
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam],
        parallel_tool_calls: Optional[bool],
        reasoning_effort: Optional[str],
        # platform arguments
        use_custom_keys: bool,
        tags: Optional[List[str]],
        drop_params: Optional[bool],
        region: Optional[str],
        log_query_body: Optional[bool],
        log_response_body: Optional[bool],
        # python client arguments
        return_full_completion: bool,
        cache: Union[bool, str],
        cache_backend: str,
        # passthrough arguments
        extra_headers: Optional[Headers],
        extra_query: Optional[Query],
        **kwargs,
    ) -> Union[Generator[str, None, None], str]:  # noqa: DAR101, DAR201, DAR401
        """
        Bundle the model arguments into a `Prompt` and dispatch the request.

        The model-facing arguments, together with `extra_headers`, `extra_query`
        and the extra keyword arguments (as `extra_body`), are packed into a
        `Prompt`. The request is then handed to `_generate_stream` when `stream`
        is truthy, and to `_generate_non_stream` otherwise. The cache arguments
        are only passed to the non-streaming path.

        Returns:
            Whatever the selected helper returns: a generator for streaming
            requests, otherwise the result of the non-streaming call.
        """
        prompt_fields = {
            "messages": messages,
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "top_logprobs": top_logprobs,
            "max_completion_tokens": max_completion_tokens,
            "n": n,
            "presence_penalty": presence_penalty,
            "response_format": response_format,
            "seed": seed,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "service_tier": service_tier,
            "tools": tools,
            "tool_choice": tool_choice,
            "parallel_tool_calls": parallel_tool_calls,
            "reasoning_effort": reasoning_effort,
            "extra_headers": extra_headers,
            "extra_query": extra_query,
            "extra_body": kwargs,
        }
        prompt = Prompt(**prompt_fields)
        # Arguments shared by the streaming and the non-streaming helper.
        shared_args = {
            # platform arguments
            "use_custom_keys": use_custom_keys,
            "tags": tags,
            "drop_params": drop_params,
            "region": region,
            "log_query_body": log_query_body,
            "log_response_body": log_response_body,
            # python client arguments
            "return_full_completion": return_full_completion,
        }
        if not stream:
            return self._generate_non_stream(
                self._endpoint,
                prompt,
                cache=cache,
                cache_backend=cache_backend,
                **shared_args,
            )
        return self._generate_stream(
            self._endpoint,
            prompt,
            stream_options=stream_options,
            **shared_args,
        )

    def to_async_client(self):
        """
        Create the asynchronous counterpart of this client.

        The new `AsyncUnify` instance is constructed from this client's
        `_constructor_args`.

        Returns:
            An `AsyncUnify` instance built from the same constructor arguments as
            this `Unify` instance.
        """
        constructor_args = self._constructor_args
        return AsyncUnify(**constructor_args)


class AsyncUnify(_UniClient):
    """Class for interacting with the Unify chat completions endpoint in an
    asynchronous manner."""

    def _get_client(self):
        """
        Build the asynchronous OpenAI client used by this instance.

        The client targets `BASE_URL`, authenticates with `self._api_key`, and sends
        requests through the httpx client returned by
        `make_async_httpx_client_for_unify_logging`.

        Returns:
            An `openai.AsyncOpenAI` instance.

        Raises:
            Exception: If an `openai.OpenAIError` is raised while building the
            client. The original error is attached as the cause.
        """
        try:
            # Async event hooks must use AsyncClient
            http_client = make_async_httpx_client_for_unify_logging(BASE_URL)
            return openai.AsyncOpenAI(
                base_url=f"{BASE_URL}",
                api_key=self._api_key,
                timeout=3600.0,  # one hour
                http_client=http_client,
            )
        except openai.OpenAIError as e:
            # Catch the same base error as the synchronous client does;
            # `APIStatusError` only describes HTTP error responses and is a
            # subclass of `OpenAIError`, so it is still covered.
            raise Exception(f"Failed to initialize Unify client: {str(e)}") from e

    async def _generate_stream(
        self,
        endpoint: str,
        prompt: Prompt,
        # stream
        stream_options: Optional[ChatCompletionStreamOptionsParam],
        # platform arguments
        use_custom_keys: bool,
        tags: Optional[List[str]],
        drop_params: Optional[bool],
        region: Optional[str],
        log_query_body: Optional[bool],
        log_response_body: Optional[bool],
        # python client arguments
        return_full_completion: bool,
    ) -> AsyncGenerator[str, None]:
        """
        Run a streaming chat completion and asynchronously yield its output.

        Args:
            endpoint: The endpoint to query. If it is a key of `LOCAL_MODELS`, the
            registered callable is awaited instead of the remote API.

            prompt: The prompt, forwarded to `_handle_kw` to build the request.

            stream_options: Options for the streaming response.

            use_custom_keys: Platform argument forwarded to `_handle_kw`.

            tags: Platform argument forwarded to `_handle_kw`. When tracing is
            enabled the tags are also appended to the span name.

            drop_params: Platform argument forwarded to `_handle_kw`.

            region: Platform argument forwarded to `_handle_kw`.

            log_query_body: Platform argument forwarded to `_handle_kw`.

            log_response_body: Platform argument forwarded to `_handle_kw`.

            return_full_completion: If True, every chunk object is yielded
            unchanged. If False, the `delta.content` of the first choice of each
            chunk is yielded, with an empty string standing in for missing content.

        Yields:
            Chunk objects or content strings, depending on `return_full_completion`.

        Raises:
            Exception: If the API responds with an `openai.APIStatusError`; the
            message is copied over and the original error is attached as the cause.
        """
        kw = self._handle_kw(
            prompt=prompt,
            endpoint=endpoint,
            stream=True,
            stream_options=stream_options,
            use_custom_keys=use_custom_keys,
            tags=tags,
            drop_params=drop_params,
            region=region,
            log_query_body=log_query_body,
            log_response_body=log_response_body,
        )
        try:
            if endpoint in LOCAL_MODELS:
                # These arguments are not passed on to local models. Keys may be
                # absent from `kw`, so a missing key is tolerated here.
                for key in ("extra_body", "model", "max_completion_tokens"):
                    kw.pop(key, None)
                async_stream = await LOCAL_MODELS[endpoint](**kw)
            else:
                if unify.CLIENT_LOGGING:
                    print(f"calling {kw['model']}... (thread {threading.get_ident()})")
                if self._traced:
                    # ToDo: test if this works, it probably won't
                    async_stream = await unify.traced(
                        self._client.chat.completions.create,
                        span_type="llm-stream",
                        name=(
                            endpoint
                            if tags is None
                            else endpoint + "[" + ",".join([str(t) for t in tags]) + "]"
                        ),
                    )(**kw)
                else:
                    async_stream = await self._client.chat.completions.create(**kw)
                if unify.CLIENT_LOGGING:
                    print(f"done (thread {threading.get_ident()})")
            async for chunk in async_stream:  # type: ignore[union-attr]
                if return_full_completion:
                    yield chunk
                elif not chunk.choices:
                    # A chunk may carry no choices (the OpenAI API documents this
                    # for the usage-only chunk sent when `stream_options` requests
                    # usage), so there is no delta to read.
                    continue
                else:
                    yield chunk.choices[0].delta.content or ""
        except openai.APIStatusError as e:
            raise Exception(e.message) from e

    async def _generate_non_stream(
        self,
        endpoint: str,
        prompt: Prompt,
        # platform arguments
        use_custom_keys: bool,
        tags: Optional[List[str]],
        drop_params: Optional[bool],
        region: Optional[str],
        log_query_body: Optional[bool],
        log_response_body: Optional[bool],
        # python client arguments
        return_full_completion: bool,
        cache: Union[bool, str],
        cache_backend: str,
    ) -> Union[str, ChatCompletion]:
        """
        Run a single (non-streaming) chat completion, optionally through the cache.

        Args:
            endpoint: The endpoint to query. A key of `LOCAL_MODELS` is served by
            awaiting the registered callable.

            prompt: The prompt, forwarded to `_handle_kw` to build the request.

            use_custom_keys: Platform argument forwarded to `_handle_kw`.

            tags: Platform argument forwarded to `_handle_kw`. When tracing is
            enabled the tags are also appended to the span name.

            drop_params: Platform argument forwarded to `_handle_kw`.

            region: Platform argument forwarded to `_handle_kw`.

            log_query_body: Platform argument forwarded to `_handle_kw`.

            log_response_body: Platform argument forwarded to `_handle_kw`.

            return_full_completion: If True, return the full completion object. If
            False, return only the content of the first choice.

            cache: The cache mode. True, "both", "read" and "read-only" trigger a
            cache lookup. True, "both" and "write" allow the result to be stored. A
            "-closest" suffix is stripped and enables closest-match reading.

            cache_backend: The backend handed to the cache read and write helpers.

        Returns:
            The full completion if `return_full_completion` is True. Otherwise the
            content of the first choice with surrounding spaces stripped, or an empty
            string if there is no content.

        Raises:
            Exception: If the API responds with an `openai.APIStatusError`; the
            message is copied over and the original error is attached as the cause.
        """
        kw = self._handle_kw(
            prompt=prompt,
            endpoint=endpoint,
            stream=False,
            stream_options=None,
            use_custom_keys=use_custom_keys,
            tags=tags,
            drop_params=drop_params,
            region=region,
            log_query_body=log_query_body,
            log_response_body=log_response_body,
        )
        if isinstance(cache, str) and cache.endswith("-closest"):
            cache = cache.removesuffix("-closest")
            read_closest = True
        else:
            read_closest = False
        if "response_format" in kw and kw["response_format"]:
            chat_method = self._client.beta.chat.completions.parse
            if "stream" in kw:
                del kw["stream"]  # .parse() does not accept the stream argument
        else:
            chat_method = self._client.chat.completions.create
        chat_completion = None
        in_cache = False
        if cache in [True, "both", "read", "read-only"]:
            if self._traced:

                def _get_cache_traced(**kw):
                    return _get_cache(
                        fn_name="chat.completions.create",
                        kw=kw,
                        raise_on_empty=cache == "read-only",
                        read_closest=read_closest,
                        delete_closest=read_closest,
                        backend=cache_backend,
                    )

                chat_completion = unify.traced(
                    _get_cache_traced,
                    span_type="llm-cached",
                    name=(
                        endpoint
                        if tags is None
                        else endpoint + "[" + ",".join([str(t) for t in tags]) + "]"
                    ),
                )(**kw)
            else:
                chat_completion = _get_cache(
                    fn_name="chat.completions.create",
                    kw=kw,
                    raise_on_empty=cache == "read-only",
                    read_closest=read_closest,
                    delete_closest=read_closest,
                    backend=cache_backend,
                )
            # Record the hit on both the traced and the untraced path, so that a
            # traced cache hit is not written straight back to the cache below.
            in_cache = chat_completion is not None
        if chat_completion is None:
            try:
                if endpoint in LOCAL_MODELS:
                    # Strip the arguments that are not passed to local models from a
                    # copy, so `kw` stays identical to the key used for the cache
                    # read and the cache write below stores under that same key.
                    local_kw = {
                        key: value
                        for key, value in kw.items()
                        if key not in ("extra_body", "model", "max_completion_tokens")
                    }
                    chat_completion = await LOCAL_MODELS[endpoint](**local_kw)
                else:
                    if unify.CLIENT_LOGGING:
                        print(
                            f"calling {kw['model']}... (thread {threading.get_ident()})",
                        )
                    # `self._traced` matches the attribute read by the cache lookup
                    # above and by the other generation methods of this class.
                    if self._traced:
                        chat_completion = await unify.traced(
                            chat_method,
                            span_type="llm",
                            name=(
                                endpoint
                                if tags is None
                                else endpoint
                                + "["
                                + ",".join([str(t) for t in tags])
                                + "]"
                            ),
                            fn_type="async",
                        )(**kw)
                    else:
                        chat_completion = await chat_method(
                            **kw,
                        )
                    if unify.CLIENT_LOGGING:
                        print(
                            f"done (thread {threading.get_ident()})",
                        )
            except openai.APIStatusError as e:
                raise Exception(e.message) from e
        if (chat_completion is not None or read_closest) and cache in [
            True,
            "both",
            "write",
        ]:
            # Skip the write for results that came out of the cache, unless the
            # mode is "write".
            if not in_cache or cache == "write":
                _write_to_cache(
                    fn_name="chat.completions.create",
                    kw=kw,
                    response=chat_completion,
                    backend=cache_backend,
                )
        if return_full_completion:
            return chat_completion
        content = chat_completion.choices[0].message.content
        if content:
            return content.strip(" ")
        return ""

    async def _generate(  # noqa: WPS234, WPS211
        self,
        messages: Optional[List[ChatCompletionMessageParam]],
        *,
        frequency_penalty: Optional[float],
        logit_bias: Optional[Dict[str, int]],
        logprobs: Optional[bool],
        top_logprobs: Optional[int],
        max_completion_tokens: Optional[int],
        n: Optional[int],
        presence_penalty: Optional[float],
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]],
        seed: Optional[int],
        stop: Union[Optional[str], List[str]],
        stream: Optional[bool],
        stream_options: Optional[ChatCompletionStreamOptionsParam],
        temperature: Optional[float],
        top_p: Optional[float],
        tools: Optional[Iterable[ChatCompletionToolParam]],
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam],
        parallel_tool_calls: Optional[bool],
        reasoning_effort: Optional[str],
        # platform arguments
        use_custom_keys: bool,
        tags: Optional[List[str]],
        drop_params: Optional[bool],
        region: Optional[str],
        log_query_body: Optional[bool],
        log_response_body: Optional[bool],
        # python client arguments
        return_full_completion: bool,
        cache: Union[bool, str],
        cache_backend: str,
        # passthrough arguments
        extra_headers: Optional[Headers],
        extra_query: Optional[Query],
        service_tier: Optional[str] = None,
        **kwargs,
    ) -> Union[AsyncGenerator[str, None], str]:  # noqa: DAR101, DAR201, DAR401
        """
        Bundle the model arguments into a `Prompt` and dispatch the request.

        The model-facing arguments, together with `extra_headers`, `extra_query`
        and the extra keyword arguments (as `extra_body`), are packed into a
        `Prompt`. When `stream` is truthy the async generator created by
        `_generate_stream` is returned without being awaited. Otherwise
        `_generate_non_stream` is awaited and its result returned. The cache
        arguments are only passed to the non-streaming path.

        Returns:
            An async generator for streaming requests, otherwise the awaited
            result of the non-streaming call.
        """
        prompt_fields = {
            "messages": messages,
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "top_logprobs": top_logprobs,
            "max_completion_tokens": max_completion_tokens,
            "n": n,
            "presence_penalty": presence_penalty,
            "response_format": response_format,
            "seed": seed,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "tools": tools,
            "tool_choice": tool_choice,
            "parallel_tool_calls": parallel_tool_calls,
            "extra_headers": extra_headers,
            "extra_query": extra_query,
            "extra_body": kwargs,
            "reasoning_effort": reasoning_effort,
            "service_tier": service_tier,
        }
        prompt = Prompt(**prompt_fields)
        # Arguments shared by the streaming and the non-streaming helper.
        shared_args = {
            # platform arguments
            "use_custom_keys": use_custom_keys,
            "tags": tags,
            "drop_params": drop_params,
            "region": region,
            "log_query_body": log_query_body,
            "log_response_body": log_response_body,
            # python client arguments
            "return_full_completion": return_full_completion,
        }
        if not stream:
            return await self._generate_non_stream(
                self._endpoint,
                prompt,
                cache=cache,
                cache_backend=cache_backend,
                **shared_args,
            )
        return self._generate_stream(
            self._endpoint,
            prompt,
            stream_options=stream_options,
            **shared_args,
        )

    def to_sync_client(self):
        """
        Create the synchronous counterpart of this client.

        The new `Unify` instance is constructed from this client's
        `_constructor_args`.

        Returns:
            A `Unify` instance built from the same constructor arguments as this
            `AsyncUnify` instance.
        """
        constructor_args = self._constructor_args
        return Unify(**constructor_args)

    async def close(self):
        """
        Shut down the client wrapped by this instance.

        Awaits `close()` on `self._client`.
        """
        underlying_client = self._client
        await underlying_client.close()
