# global
import abc
import asyncio
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import requests

# local
import unify

# noinspection PyProtectedMember
from openai._types import Headers, Query
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionMessageParam,
    ChatCompletionStreamOptionsParam,
    ChatCompletionToolChoiceOptionParam,
    ChatCompletionToolParam,
)
from pydantic import BaseModel
from typing_extensions import Self
from unify import BASE_URL
from unify.utils import _requests

# noinspection PyProtectedMember
from unify.utils.helpers import _default, _validate_api_key

from ..clients import AsyncUnify, _Client, _UniClient


class _MultiClient(_Client, abc.ABC):
    def __init__(
        self,
        endpoints: Optional[Union[str, Iterable[str]]] = None,
        *,
        system_message: Optional[str] = None,
        messages: Optional[
            Union[
                List[ChatCompletionMessageParam],
                Dict[str, List[ChatCompletionMessageParam]],
            ]
        ] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]] = None,
        seed: Optional[int] = None,
        stop: Union[Optional[str], List[str]] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = None,
        service_tier: Optional[str] = None,
        tools: Optional[Iterable[ChatCompletionToolParam]] = None,
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        # platform arguments
        use_custom_keys: bool = False,
        tags: Optional[List[str]] = None,
        drop_params: Optional[bool] = True,
        region: Optional[str] = None,
        log_query_body: Optional[bool] = True,
        log_response_body: Optional[bool] = True,
        api_key: Optional[str] = None,
        # python client arguments
        stateful: bool = False,
        return_full_completion: bool = False,
        traced: bool = False,
        cache: Union[bool, str] = None,
        cache_backend: Optional[str] = None,
        # passthrough arguments
        extra_headers: Optional[Headers] = None,
        extra_query: Optional[Query] = None,
        **kwargs,
    ) -> None:
        """Initialize the Multi LLM Unify client.

        One per-endpoint client is created for every distinct endpoint name, each
        configured with the settings given here.

        Args:
            endpoints: A single endpoint name or an iterable of endpoint names, with
            each name in OpenAI API format: <model_name>@<provider_name>. Repeated
            names are collapsed (first occurrence wins). Defaults to None, which
            creates a client with no endpoints; they can be added later with
            `add_endpoints`.

            system_message: An optional string containing the system message. This
            always appears at the beginning of the list of messages.

            messages: A list of messages comprising the conversation so far. This will
            be appended to the system_message if it is not None, and any user_message
            will be appended if it is not None.

            frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on their existing frequency in the text so far, decreasing the
            model's likelihood to repeat the same line verbatim.

            logit_bias: Modify the likelihood of specified tokens appearing in the
            completion. Accepts a JSON object that maps tokens (specified by their token
            ID in the tokenizer) to an associated bias value from -100 to 100.
            Mathematically, the bias is added to the logits generated by the model prior
            to sampling. The exact effect will vary per model, but values between -1 and
            1 should decrease or increase likelihood of selection; values like -100 or
            100 should result in a ban or exclusive selection of the relevant token.

            logprobs: Whether to return log probabilities of the output tokens or not.
            If true, returns the log probabilities of each output token returned in the
            content of message.

            top_logprobs: An integer between 0 and 20 specifying the number of most
            likely tokens to return at each token position, each with an associated log
            probability. logprobs must be set to true if this parameter is used.

            max_completion_tokens: The maximum number of tokens that can be generated in
            the chat completion. The total length of input tokens and generated tokens
            is limited by the model's context length. Defaults to the provider's default
            max_completion_tokens when the value is None.

            n: How many chat completion choices to generate for each input message. Note
            that you will be charged based on the number of generated tokens across all
            of the choices. Keep n as 1 to minimize costs.

            presence_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on whether they appear in the text so far, increasing the
            model's likelihood to talk about new topics.

            response_format: An object specifying the format that the model must output.
            Setting to `{ "type": "json_schema", "json_schema": {...} }` enables
            Structured Outputs which ensures the model will match your supplied JSON
            schema. Learn more in the Structured Outputs guide. Setting to
            `{ "type": "json_object" }` enables JSON mode, which ensures the message the
            model generates is valid JSON.

            seed: If specified, a best effort attempt is made to sample
            deterministically, such that repeated requests with the same seed and
            parameters should return the same result. Determinism is not guaranteed, and
            you should refer to the system_fingerprint response parameter to monitor
            changes in the backend.

            stop: Up to 4 sequences where the API will stop generating further tokens.

            temperature: What sampling temperature to use, between 0 and 2.
            Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            It is generally recommended to alter this or top_p, but not both.
            Default value is 1.0. Defaults to the provider's default temperature when
            the value is None.

            top_p: An alternative to sampling with temperature, called nucleus sampling,
            where the model considers the results of the tokens with top_p probability
            mass. So 0.1 means only the tokens comprising the top 10% probability mass
            are considered. Generally recommended to alter this or temperature, but not
            both.

            service_tier: Forwarded unchanged to the base client and to every
            per-endpoint client. Presumably the OpenAI-style service tier selector;
            its interpretation is not visible in this class.

            tools: A list of tools the model may call. Currently, only functions are
            supported as a tool. Use this to provide a list of functions the model may
            generate JSON inputs for. A max of 128 functions are supported.

            tool_choice: Controls which (if any) tool is called by the
            model. none means the model will not call any tool and instead generates a
            message. auto means the model can pick between generating a message or
            calling one or more tools. required means the model must call one or more
            tools. Specifying a particular tool via
            `{ "type": "function", "function": {"name": "my_function"} }`
            forces the model to call that tool.
            none is the default when no tools are present. auto is the default if tools
            are present.

            parallel_tool_calls: Whether to enable parallel function calling during tool
            use.

            reasoning_effort: Forwarded unchanged to the base client and to every
            per-endpoint client. Presumably the OpenAI-style reasoning effort hint;
            its interpretation is not visible in this class.

            use_custom_keys: Whether to use custom API keys or our unified API keys
            with the backend provider.

            tags: Arbitrary number of tags to classify this API query as needed. Helpful
            for generally grouping queries across tasks and users, for logging purposes.

            drop_params: Whether or not to drop unsupported OpenAI params by the
            provider you are using.

            region: A string used to represent the region where the endpoint is
            accessed. Only relevant for on-prem deployments with certain providers like
            `vertex-ai`, `aws-bedrock` and `azure-ml`, where the endpoint is being
            accessed through a specified region.

            log_query_body: Whether to log the contents of the query json body.

            log_response_body: Whether to log the contents of the response json body.

            api_key: The API key for this client. It is passed through
            `_validate_api_key`, and the validated key is what the per-endpoint
            clients and `get_credit_balance` use.

            stateful: Whether the conversation history is preserved within the messages
            of this client. If True, then history is preserved. If False, then this acts
            as a stateless client, and message histories must be managed by the user.

            return_full_completion: If False, only return the message content
            chat_completion.choices[0].message.content.strip(" ") from the OpenAI
            return. Otherwise, the full response chat_completion is returned.
            Defaults to False.

            traced: Whether to trace the generate method. Passed to the base client
            only; it is not forwarded to the per-endpoint clients.

            cache: If True, then the arguments will be stored in a local cache file, and
            any future calls with identical arguments will read from the cache instead
            of running the LLM query. If "write" then the cache will only be written
            to, if "read" then the cache will be read from if a cache is available but
            will not write, and if "read-only" then the argument must be present in the
            cache, else an exception will be raised. Finally, an appending "-closest"
            will read the closest match from the cache, and overwrite it if cache writing
            is enabled. This argument only has any effect when stream=False.

            cache_backend: Passed to the base client only; it is not forwarded to the
            per-endpoint clients. Presumably selects where cached results are stored.

            extra_headers: Additional "passthrough" headers for the request which are
            provider-specific, and are not part of the OpenAI standard. They are handled
            by the provider-specific API.

            extra_query: Additional "passthrough" query parameters for the request which
            are provider-specific, and are not part of the OpenAI standard. They are
            handled by the provider-specific API.

            kwargs: Additional "passthrough" JSON properties for the body of the
            request, which are provider-specific, and are not part of the OpenAI
            standard. They will be handled by the provider-specific API.

        Raises:
            UnifyError: If the API key is missing (presumably raised by
            `_validate_api_key`; its implementation is not visible here).
        """
        # Every argument except `endpoints` is shared with the base client, so it
        # is captured once and reused for the base constructor and for the record
        # of constructor arguments. Streaming is always disabled at this level.
        self._base_constructor_args = dict(
            system_message=system_message,
            messages=messages,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            max_completion_tokens=max_completion_tokens,
            n=n,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=False,
            stream_options=None,
            temperature=temperature,
            top_p=top_p,
            service_tier=service_tier,
            tools=tools,
            tool_choice=tool_choice,
            parallel_tool_calls=parallel_tool_calls,
            reasoning_effort=reasoning_effort,
            # platform arguments
            use_custom_keys=use_custom_keys,
            tags=tags,
            drop_params=drop_params,
            region=region,
            log_query_body=log_query_body,
            log_response_body=log_response_body,
            api_key=api_key,
            # python client arguments
            stateful=stateful,
            return_full_completion=return_full_completion,
            traced=traced,
            cache=cache,
            cache_backend=cache_backend,
            # passthrough arguments
            extra_headers=extra_headers,
            extra_query=extra_query,
            **kwargs,
        )
        super().__init__(**self._base_constructor_args)
        self._constructor_args = dict(
            endpoints=endpoints,
            **self._base_constructor_args,
        )
        # Normalise `endpoints` to a list of unique names. None (the declared
        # default) means "no endpoints yet" instead of crashing in list(None), and
        # repeated names are dropped so the endpoint list always matches the keys
        # of the client dictionary built below.
        if endpoints is None:
            endpoints = []
        elif isinstance(endpoints, str):
            endpoints = [endpoints]
        else:
            endpoints = list(dict.fromkeys(endpoints))
        self._api_key = _validate_api_key(api_key)
        self._endpoints = endpoints
        self._client_class = AsyncUnify
        self._clients = self._create_clients(endpoints)

    def _create_clients(self, endpoints: List[str]) -> Dict[str, AsyncUnify]:
        return {
            endpoint: self._client_class(
                endpoint,
                system_message=self.system_message,
                messages=self.messages,
                frequency_penalty=self.frequency_penalty,
                logit_bias=self.logit_bias,
                logprobs=self.logprobs,
                top_logprobs=self.top_logprobs,
                max_completion_tokens=self.max_completion_tokens,
                n=self.n,
                presence_penalty=self.presence_penalty,
                response_format=self.response_format,
                seed=self.seed,
                stop=self.stop,
                temperature=self.temperature,
                top_p=self.top_p,
                service_tier=self.service_tier,
                tools=self.tools,
                tool_choice=self.tool_choice,
                parallel_tool_calls=self.parallel_tool_calls,
                reasoning_effort=self.reasoning_effort,
                # platform arguments
                use_custom_keys=self.use_custom_keys,
                tags=self.tags,
                drop_params=self.drop_params,
                region=self.region,
                log_query_body=self.log_query_body,
                log_response_body=self.log_response_body,
                api_key=self._api_key,
                # python client arguments
                stateful=self.stateful,
                return_full_completion=self.return_full_completion,
                cache=self.cache,
                # passthrough arguments
                extra_headers=self.extra_headers,
                extra_query=self.extra_query,
                **self.extra_body,
            )
            for endpoint in endpoints
        }

    def add_endpoints(
        self,
        endpoints: Union[List[str], str],
        ignore_duplicates: bool = True,
    ) -> Self:
        """
        Add extra endpoints to be queried for each call to generate.

        Args:
            endpoints: The extra endpoints to add.

            ignore_duplicates: Whether or not to ignore duplicate endpoints passed.

        Returns:
            This client, useful for chaining inplace calls.
        """
        if isinstance(endpoints, str):
            endpoints = [endpoints]
        # remove duplicates
        if ignore_duplicates:
            endpoints = [
                endpoint for endpoint in endpoints if endpoint not in self._endpoints
            ]
        elif len(self._endpoints + endpoints) != len(set(self._endpoints + endpoints)):
            raise Exception(
                "at least one of the provided endpoints to add {}"
                "was already set present in the endpoints {}."
                "Set ignore_duplicates to True to ignore errors like this".format(
                    endpoints,
                    self._endpoints,
                ),
            )
        # update endpoints
        self._endpoints = self._endpoints + endpoints
        # create new clients
        self._clients.update(self._create_clients(endpoints))
        return self

    def remove_endpoints(
        self,
        endpoints: Union[List[str], str],
        ignore_missing: bool = True,
    ) -> Self:
        """
        Remove endpoints from the current list, which are queried for each call to
        generate.

        Args:
            endpoints: The extra endpoints to add.

            ignore_missing: Whether or not to ignore endpoints passed which are not
            currently present in the client endpoint list.

        Returns:
            This client, useful for chaining inplace calls.
        """
        if isinstance(endpoints, str):
            endpoints = [endpoints]
        # remove irrelevant
        if ignore_missing:
            endpoints = [
                endpoint for endpoint in endpoints if endpoint in self._endpoints
            ]
        elif len(self._endpoints) != len(set(self._endpoints + endpoints)):
            raise Exception(
                "at least one of the provided endpoints to remove {}"
                "was not present in the current endpoints {}."
                "Set ignore_missing to True to ignore errors like this".format(
                    endpoints,
                    self._endpoints,
                ),
            )
        # update endpoints and clients
        for endpoint in endpoints:
            self._endpoints.remove(endpoint)
            del self._clients[endpoint]
        return self

    def get_credit_balance(self) -> Union[float, None]:
        """
        Get the remaining credits left on your account.

        Sends a GET request to `{BASE_URL}/credits`, authenticated with this
        client's API key as a bearer token.

        Returns:
            The value of the "credits" field of the JSON response.

        Raises:
            Exception: If the request itself fails (chained from the underlying
            requests.RequestException), or if the server answers with a status
            other than 200, in which case the exception carries the response body.
            ValueError: If a 200 response cannot be parsed as JSON or does not
            contain a "credits" field.
        """
        url = f"{BASE_URL}/credits"
        headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }
        # Only the network call is guarded here, so transport failures are not
        # confused with problems in the response content.
        try:
            response = _requests.get(url, headers=headers, timeout=10)
        except requests.RequestException as e:
            raise Exception("There was an error with the request.") from e
        if response.status_code != 200:
            # An error body is not guaranteed to be JSON; fall back to the raw
            # text (or just the status) so the HTTP failure is still reported.
            try:
                detail = response.json()
            except ValueError:
                detail = (
                    getattr(response, "text", None)
                    or f"HTTP {response.status_code}"
                )
            raise Exception(detail)
        try:
            return response.json()["credits"]
        except (KeyError, TypeError, ValueError) as e:
            raise ValueError("Error parsing JSON response.") from e

    # Read-only Properties #
    # ---------------------#

    @property
    def endpoints(self) -> Tuple[str, ...]:
        """
        Snapshot of the endpoint names currently registered on this client.

        Returns:
            A tuple copy of the internal endpoint list, in the same order.
        """
        return (*self._endpoints,)

    @property
    def clients(self) -> Dict[str, _UniClient]:
        """
        Get the current dictionary of clients, with endpoint names as keys and
        Unify or AsyncUnify instances as values.

        Returns:
            The dictionary of clients.
        """
        return self._clients

    # Representation #
    # ---------------#

    def __repr__(self):
        return "{}(endpoints={})".format(self.__class__.__name__, self._endpoints)

    def __str__(self):
        return "{}(endpoints={})".format(self.__class__.__name__, self._endpoints)

    # Generate #
    # ---------#

    def generate(
        self,
        arg0: Optional[Union[str, List[Union[str, Tuple[Any], Dict[str, Any]]]]] = None,
        /,
        system_message: Optional[str] = None,
        messages: Optional[
            Union[
                List[ChatCompletionMessageParam],
                Dict[str, List[ChatCompletionMessageParam]],
            ]
        ] = None,
        *,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]] = None,
        seed: Optional[int] = None,
        stop: Union[Optional[str], List[str]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[ChatCompletionStreamOptionsParam] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[Iterable[ChatCompletionToolParam]] = None,
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        # platform arguments
        use_custom_keys: Optional[bool] = None,
        tags: Optional[List[str]] = None,
        drop_params: Optional[bool] = None,
        region: Optional[str] = None,
        log_query_body: Optional[bool] = None,
        log_response_body: Optional[bool] = None,
        # python client arguments
        stateful: Optional[bool] = None,
        return_full_completion: Optional[bool] = None,
        cache: Optional[Union[bool, str]] = None,
        # passthrough arguments
        extra_headers: Optional[Headers] = None,
        extra_query: Optional[Query] = None,
        **kwargs,
    ):
        """Generate a response for the configured endpoints from the provided query
        parameters.

        Every argument left as None falls back to the value configured on this
        client. The assembled conversation and settings are handed to
        `self._generate`, which performs the actual query.

        Args:
            arg0: A string containing the user message. Only a string is added to
            the conversation by this method; any other value (including the list
            form allowed by the type hint) is not used here.

            system_message: An optional string containing the system message. This
            always appears at the beginning of the list of messages. It is only
            used when there is no existing conversation to continue.

            messages: A list of messages comprising the conversation so far, or
            optionally a dictionary of such messages, with clients as the keys in the
            case of multi-llm clients. This will be appended to the system_message if it
            is not None, and any user_message will be appended if it is not None.

            frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on their existing frequency in the text so far, decreasing the
            model's likelihood to repeat the same line verbatim.

            logit_bias: Modify the likelihood of specified tokens appearing in the
            completion. Accepts a JSON object that maps tokens (specified by their token
            ID in the tokenizer) to an associated bias value from -100 to 100.
            Mathematically, the bias is added to the logits generated by the model prior
            to sampling. The exact effect will vary per model, but values between -1 and
            1 should decrease or increase likelihood of selection; values like -100 or
            100 should result in a ban or exclusive selection of the relevant token.

            logprobs: Whether to return log probabilities of the output tokens or not.
            If true, returns the log probabilities of each output token returned in the
            content of message.

            top_logprobs: An integer between 0 and 20 specifying the number of most
            likely tokens to return at each token position, each with an associated log
            probability. logprobs must be set to true if this parameter is used.

            max_completion_tokens: The maximum number of tokens that can be generated in
            the chat completion. The total length of input tokens and generated tokens
            is limited by the model's context length. Default value is None, in which
            case the value configured on the client is used.

            n: How many chat completion choices to generate for each input message. Note
            that you will be charged based on the number of generated tokens across all
            of the choices. Keep n as 1 to minimize costs.

            presence_penalty: Number between -2.0 and 2.0. Positive values penalize new
            tokens based on whether they appear in the text so far, increasing the
            model's likelihood to talk about new topics.

            response_format: An object specifying the format that the model must output.
            Setting to `{ "type": "json_schema", "json_schema": {...} }` enables
            Structured Outputs which ensures the model will match your supplied JSON
            schema. Learn more in the Structured Outputs guide. Setting to
            `{ "type": "json_object" }` enables JSON mode, which ensures the message the
            model generates is valid JSON.

            seed: If specified, a best effort attempt is made to sample
            deterministically, such that repeated requests with the same seed and
            parameters should return the same result. Determinism is not guaranteed, and
            you should refer to the system_fingerprint response parameter to monitor
            changes in the backend. When neither this argument nor the client has a
            seed, the value of `unify.get_seed()` is used.

            stop: Up to 4 sequences where the API will stop generating further tokens.

            stream: If True, generates content as a stream. If False, generates content
            as a single response. Defaults to the value configured on the client.

            stream_options: Options for streaming response. Only set this when you set
            stream: true.

            temperature: What sampling temperature to use, between 0 and 2.
            Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            It is generally recommended to alter this or top_p, but not both.
            Defaults to the value configured on the client.

            top_p: An alternative to sampling with temperature, called nucleus sampling,
            where the model considers the results of the tokens with top_p probability
            mass. So 0.1 means only the tokens comprising the top 10% probability mass
            are considered. Generally recommended to alter this or temperature, but not
            both.

            tools: A list of tools the model may call. Currently, only functions are
            supported as a tool. Use this to provide a list of functions the model may
            generate JSON inputs for. A max of 128 functions are supported. When
            tools are in effect, the full completion is always requested.

            tool_choice: Controls which (if any) tool is called by the
            model. none means the model will not call any tool and instead generates a
            message. auto means the model can pick between generating a message or
            calling one or more tools. required means the model must call one or more
            tools. Specifying a particular tool via
            `{ "type": "function", "function": {"name": "my_function"} }`
            forces the model to call that tool.
            none is the default when no tools are present. auto is the default if tools
            are present.

            parallel_tool_calls: Whether to enable parallel function calling during tool
            use.

            reasoning_effort: Forwarded to `self._generate`; falls back to the value
            configured on the client. Presumably the OpenAI-style reasoning effort
            hint; its interpretation is not visible in this method.

            use_custom_keys: Whether to use custom API keys or our unified API keys
            with the backend provider. Defaults to the value configured on the client.

            tags: Arbitrary number of tags to classify this API query as needed. Helpful
            for generally grouping queries across tasks and users, for logging purposes.

            drop_params: Whether or not to drop unsupported OpenAI params by the
            provider you are using.

            region: A string used to represent the region where the endpoint is
            accessed. Only relevant for on-prem deployments with certain providers like
            `vertex-ai`, `aws-bedrock` and `azure-ml`, where the endpoint is being
            accessed through a specified region.

            log_query_body: Whether to log the contents of the query json body.

            log_response_body: Whether to log the contents of the response json body.

            stateful: Whether the conversation history is preserved within the messages
            of this client. If True, then history is preserved. If False, then this acts
            as a stateless client, and message histories must be managed by the user:
            neither the client's stored messages nor a list passed as `messages` is
            modified by the call.

            return_full_completion: If False, only return the message content
            chat_completion.choices[0].message.content.strip(" ") from the OpenAI
            return. Otherwise, the full response chat_completion is returned.
            Defaults to the value configured on the client.

            cache: If True, then the arguments will be stored in a local cache file, and
            any future calls with identical arguments will read from the cache instead
            of running the LLM query. If "write" then the cache will only be written
            to, if "read" then the cache will be read from if a cache is available but
            will not write, and if "read-only" then the argument must be present in the
            cache, else an exception will be raised. Finally, an appending "-closest"
            will read the closest match from the cache, and overwrite it if cache writing
            is enabled. This argument only has any effect when stream=False.

            extra_headers: Additional "passthrough" headers for the request which are
            provider-specific, and are not part of the OpenAI standard. They are handled
            by the provider-specific API.

            extra_query: Additional "passthrough" query parameters for the request which
            are provider-specific, and are not part of the OpenAI standard. They are
            handled by the provider-specific API.

            kwargs: Additional "passthrough" JSON properties for the body of the
            request, which are provider-specific, and are not part of the OpenAI
            standard. They are merged over the extra body configured on the client.

        Returns:
            Whatever `self._generate` returns for these arguments. Its
            implementation is not visible here; presumably a response, or a
            mapping from endpoint name to response.

        Raises:
            Any error raised by `self._generate` propagates unchanged.
        """
        system_message = _default(system_message, self._system_message)
        messages = _default(messages, self._messages)
        stateful = _default(stateful, self._stateful)
        has_prompt = isinstance(arg0, str)
        if messages:
            # system message only added once at the beginning
            if has_prompt:
                if isinstance(messages, dict):
                    messages = {
                        k: v + [{"role": "user", "content": arg0}]
                        for k, v in messages.items()
                    }
                elif stateful:
                    # history is meant to accumulate: extend the list in place so
                    # the prompt becomes part of the stored conversation
                    messages += [{"role": "user", "content": arg0}]
                else:
                    # stateless: build a fresh list so neither the client's
                    # stored messages nor the caller's list grows across calls
                    messages = list(messages) + [{"role": "user", "content": arg0}]
        else:
            messages = list()
            if system_message is not None:
                messages += [{"role": "system", "content": system_message}]
            if has_prompt:
                messages += [{"role": "user", "content": arg0}]
            if stateful:
                # only a stateful client remembers the conversation it started
                self._messages = messages
        return_full_completion = (
            True
            if _default(tools, self._tools)
            else _default(return_full_completion, self._return_full_completion)
        )
        ret = self._generate(
            messages=messages,
            frequency_penalty=_default(frequency_penalty, self._frequency_penalty),
            logit_bias=_default(logit_bias, self._logit_bias),
            logprobs=_default(logprobs, self._logprobs),
            top_logprobs=_default(top_logprobs, self._top_logprobs),
            max_completion_tokens=_default(
                max_completion_tokens,
                self._max_completion_tokens,
            ),
            n=_default(n, self._n),
            presence_penalty=_default(presence_penalty, self._presence_penalty),
            response_format=_default(response_format, self._response_format),
            seed=_default(_default(seed, self._seed), unify.get_seed()),
            stop=_default(stop, self._stop),
            stream=_default(stream, self._stream),
            stream_options=_default(stream_options, self._stream_options),
            temperature=_default(temperature, self._temperature),
            top_p=_default(top_p, self._top_p),
            tools=_default(tools, self._tools),
            tool_choice=_default(tool_choice, self._tool_choice),
            parallel_tool_calls=_default(
                parallel_tool_calls,
                self._parallel_tool_calls,
            ),
            reasoning_effort=_default(reasoning_effort, self._reasoning_effort),
            # platform arguments
            use_custom_keys=_default(use_custom_keys, self._use_custom_keys),
            tags=_default(tags, self._tags),
            drop_params=_default(drop_params, self._drop_params),
            region=_default(region, self._region),
            log_query_body=_default(log_query_body, self._log_query_body),
            log_response_body=_default(log_response_body, self._log_response_body),
            # python client arguments
            return_full_completion=return_full_completion,
            cache=_default(cache, self._cache),
            # passthrough arguments
            extra_headers=_default(extra_headers, self._extra_headers),
            extra_query=_default(extra_query, self._extra_query),
            **{**self._extra_body, **kwargs},
        )
        if stateful:
            # NOTE(review): this bookkeeping assumes `ret` is a single completion
            # (or string) and that the stored history is a list; a per-endpoint
            # dict for either would not fit here - TODO confirm against _generate.
            if return_full_completion:
                msg = [ret.choices[0].message.model_dump()]
            else:
                msg = [{"role": "assistant", "content": ret}]
            if self._messages is None:
                self._messages = []
            self._messages += msg
        return ret


class MultiUnify(_MultiClient):
    """
    Synchronous multi-endpoint client.

    Every request is sent to each client held in `self._clients`; the underlying
    coroutines are driven to completion with `asyncio.run`.
    """

    async def _async_gen(
        self,
        messages: Optional[
            Union[
                List[ChatCompletionMessageParam],
                Dict[str, List[ChatCompletionMessageParam]],
            ]
        ] = None,
        *,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]] = None,
        seed: Optional[int] = None,
        stop: Union[Optional[str], List[str]] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = None,
        tools: Optional[Iterable[ChatCompletionToolParam]] = None,
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        # platform arguments
        use_custom_keys: bool = False,
        tags: Optional[List[str]] = None,
        drop_params: Optional[bool] = True,
        region: Optional[str] = None,
        log_query_body: Optional[bool] = True,
        log_response_body: Optional[bool] = True,
        # python client arguments
        return_full_completion: bool = False,
        cache_backend: Optional[str] = None,
        # passthrough arguments
        extra_headers: Optional[Headers] = None,
        extra_query: Optional[Query] = None,
        service_tier: Optional[str] = None,
        **kwargs,
    ) -> Union[Union[str, ChatCompletion], Dict[str, Union[str, ChatCompletion]]]:
        """
        Call `generate` on every endpoint client, one after another.

        Args:
            messages: Either a single message list sent to every endpoint, or a
                dict keyed by endpoint name holding one message list per endpoint.
            **kwargs: Any further keyword arguments, forwarded unchanged.

            All remaining named arguments are forwarded as-is to each client's
            `generate` method.

        Returns:
            The single response when exactly one endpoint is configured, otherwise
            a dict mapping each endpoint name to its response.

        Raises:
            KeyError: If `messages` is a dict that has no entry for one of the
                endpoints in `self._clients`.
        """
        kw = dict(
            messages=messages,
            max_completion_tokens=max_completion_tokens,
            stop=stop,
            temperature=temperature,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            n=n,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            top_p=top_p,
            tools=tools,
            tool_choice=tool_choice,
            parallel_tool_calls=parallel_tool_calls,
            reasoning_effort=reasoning_effort,
            use_custom_keys=use_custom_keys,
            tags=tags,
            drop_params=drop_params,
            region=region,
            log_query_body=log_query_body,
            log_response_body=log_response_body,
            return_full_completion=return_full_completion,
            # forwarded for parity with AsyncMultiUnify._generate; previously this
            # argument was accepted but silently discarded
            cache_backend=cache_backend,
            extra_headers=extra_headers,
            extra_query=extra_query,
            service_tier=service_tier,
            **kwargs,
        )
        multi_message = isinstance(messages, dict)
        # arguments left as None are not forwarded at all (presumably so that each
        # client falls back to its own configured value -- confirm in the client)
        kw_ = {k: v for k, v in kw.items() if v is not None}
        responses = dict()
        for endpoint, client in self._clients.items():
            these_kw = kw_.copy()
            if multi_message:
                # pick out the message list intended for this endpoint
                these_kw["messages"] = these_kw["messages"][endpoint]
            responses[endpoint] = await client.generate(**these_kw)
        return responses[self._endpoints[0]] if len(self._endpoints) == 1 else responses

    async def _multi_inp_gen(
        self,
        multi_input: List[Union[str, Tuple[Any, ...], Dict[str, Any]]],
        **kwargs,
    ) -> List[Union[str, ChatCompletion]]:
        """
        Run one `_async_gen` call per entry of `multi_input`, concurrently.

        Args:
            multi_input: A list whose entries all share one type. A `str` entry is
                passed as the sole positional argument, a `tuple` entry is unpacked
                as positional arguments, and a `dict` entry is unpacked as keyword
                arguments which take precedence over `kwargs`.
            **kwargs: Keyword arguments shared by every call.

        Returns:
            The results of the individual calls, in the order of `multi_input`.
            An empty `multi_input` yields an empty list.

        Raises:
            AssertionError: If `multi_input` is not a list, or its entries are of
                mixed types.
            Exception: If the entries are neither `str`, `tuple` nor `dict`.
        """
        assert isinstance(multi_input, list), (
            f"Expected multi_input to be a list, "
            f"but found {multi_input} of type {type(multi_input)}."
        )
        if not multi_input:
            # nothing to dispatch; avoids an IndexError on multi_input[0] below
            return []
        first = multi_input[0]
        assert all(
            type(first) is type(i) for i in multi_input
        ), "all entries in the list of inputs must be of the same type."
        if isinstance(first, str):
            coroutines = [self._async_gen(s, **kwargs) for s in multi_input]
        elif isinstance(first, tuple):
            coroutines = [self._async_gen(*a, **kwargs) for a in multi_input]
        elif isinstance(first, dict):
            # per-input entries override the shared keyword arguments
            coroutines = [self._async_gen(**{**kwargs, **kw}) for kw in multi_input]
        else:
            raise Exception(
                f"Invalid format for first argument in list, expected either str, "
                f"tuple or dict but found "
                f"{first} of type {type(first)}.",
            )
        return await asyncio.gather(*coroutines)

    def _multi_inp_generate(
        self,
        *args,
        **kwargs,
    ) -> List[Union[str, ChatCompletion]]:
        """
        Perform multiple generations to multiple inputs asynchronously, based on the
        list keywords arguments passed in.

        This is a blocking wrapper which runs `_multi_inp_gen` to completion via
        `asyncio.run`.
        """
        return asyncio.run(self._multi_inp_gen(*args, **kwargs))

    def _generate(  # noqa: WPS234, WPS211
        self,
        *args,
        **kwargs,
    ) -> Union[
        Union[
            Union[str, ChatCompletion],
            List[Union[str, ChatCompletion]],
        ],
        Dict[
            str,
            Union[
                Union[str, ChatCompletion],
                List[Union[str, ChatCompletion]],
            ],
        ],
    ]:
        """
        Blocking entry point for generation.

        When the first positional argument is a list, the call is treated as a
        batch and handled by `_multi_inp_generate`; otherwise a single
        `_async_gen` call is run to completion.

        Returns:
            Whatever `_multi_inp_generate` or `_async_gen` returns for the given
            arguments.
        """
        # refresh the openai client before doing a new event loop
        # (each asyncio.run call below creates a fresh event loop)
        for client in self._clients.values():
            client._client = client._get_client()
        if args and isinstance(args[0], list):
            return self._multi_inp_generate(*args, **kwargs)
        return asyncio.run(
            self._async_gen(
                *args,
                **kwargs,
            ),
        )

    def to_async_client(self):
        """
        Return an asynchronous version of the client (`AsyncMultiUnify` instance), with
        the exact same configuration as this synchronous (`MultiUnify`) client.

        Returns:
            An `AsyncMultiUnify` instance with the same configuration as this `MultiUnify`
            instance.
        """
        return AsyncMultiUnify(**self._constructor_args)


class AsyncMultiUnify(_MultiClient):
    """
    Asynchronous multi-endpoint client.

    Each request is awaited against every client held in `self._clients`.
    """

    async def _generate(  # noqa: WPS234, WPS211
        self,
        messages: Optional[
            Union[
                List[ChatCompletionMessageParam],
                Dict[str, List[ChatCompletionMessageParam]],
            ]
        ] = None,
        *,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, int]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Union[Type[BaseModel], Dict[str, str]]] = None,
        seed: Optional[int] = None,
        stop: Union[Optional[str], List[str]] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = None,
        tools: Optional[Iterable[ChatCompletionToolParam]] = None,
        tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        # platform arguments
        use_custom_keys: bool = False,
        tags: Optional[List[str]] = None,
        drop_params: Optional[bool] = True,
        region: Optional[str] = None,
        log_query_body: Optional[bool] = True,
        log_response_body: Optional[bool] = True,
        # python client arguments
        return_full_completion: bool = False,
        cache_backend: Optional[str] = None,
        # passthrough arguments
        extra_headers: Optional[Headers] = None,
        extra_query: Optional[Query] = None,
        service_tier: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, str]:
        """
        Await `generate` on every endpoint client in turn and gather the results.

        Args:
            messages: Either one message list shared by all endpoints, or a dict
                keyed by endpoint name with a dedicated message list per endpoint.
            **kwargs: Additional keyword arguments, passed through untouched.

            Every other named argument is handed to each client's `generate`
            method; arguments whose value is `None` are left out of the call.

        Returns:
            A dict mapping each endpoint name to the value returned by that
            endpoint's client. A dict is returned even for a single endpoint.
        """
        candidate_args = {
            "messages": messages,
            "max_completion_tokens": max_completion_tokens,
            "stop": stop,
            "temperature": temperature,
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "top_logprobs": top_logprobs,
            "n": n,
            "presence_penalty": presence_penalty,
            "response_format": response_format,
            "seed": seed,
            "top_p": top_p,
            "tools": tools,
            "tool_choice": tool_choice,
            "parallel_tool_calls": parallel_tool_calls,
            "reasoning_effort": reasoning_effort,
            "use_custom_keys": use_custom_keys,
            "tags": tags,
            "drop_params": drop_params,
            "region": region,
            "log_query_body": log_query_body,
            "log_response_body": log_response_body,
            "return_full_completion": return_full_completion,
            "cache_backend": cache_backend,
            "extra_headers": extra_headers,
            "extra_query": extra_query,
            "service_tier": service_tier,
            **kwargs,
        }
        # only explicitly-set (non-None) arguments are sent on to the clients
        shared_args = {
            name: value for name, value in candidate_args.items() if value is not None
        }
        per_endpoint_messages = isinstance(messages, dict)

        results = {}
        for endpoint, client in self._clients.items():
            call_args = dict(shared_args)
            if per_endpoint_messages:
                # swap the endpoint-keyed dict for this endpoint's own message list
                call_args["messages"] = call_args["messages"][endpoint]
            results[endpoint] = await client.generate(**call_args)
        return results

    def to_sync_client(self):
        """
        Build the synchronous counterpart of this client.

        Returns:
            A `MultiUnify` instance constructed from the same constructor arguments
            as this `AsyncMultiUnify` instance.
        """
        constructor_args = self._constructor_args
        return MultiUnify(**constructor_args)
