# This file was auto-generated by Fern from our API Definition.

import typing
from contextlib import asynccontextmanager, contextmanager

import httpx
import websockets.exceptions
import websockets.sync.client as websockets_sync_client
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.request_options import RequestOptions
from .socket_client import AsyncWebsocketsSocketClient, WebsocketsSocketClient

try:
    from websockets.legacy.client import connect as websockets_client_connect  # type: ignore
except ImportError:
    from websockets import connect as websockets_client_connect  # type: ignore


class RawWebsocketsClient:
    """Low-level synchronous client for the ``/v0`` websocket endpoint.

    Opens connections with ``websockets.sync.client.connect`` and converts a
    rejected websocket handshake into :class:`ApiError`.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        # Supplies the environment base URL and the default request headers.
        self._client_wrapper = client_wrapper

    @contextmanager
    def connect(
        self, *, auth_token: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None
    ) -> typing.Iterator[WebsocketsSocketClient]:
        """
        Open a websocket connection to ``<environment websockets URL>/v0`` and
        yield a client bound to it; the connection is closed when the ``with``
        block exits.

        Parameters
        ----------
        auth_token : typing.Optional[str]
            Your API key. Required if Authorization header is not set.
            Sent as the ``auth_token`` query parameter.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration. Only ``additional_headers`` is used
            here; those headers are merged over the client's default headers.

        Returns
        -------
        WebsocketsSocketClient

        Raises
        ------
        ApiError
            If the server rejects the websocket handshake with a non-101 status.
        """
        ws_url = self._client_wrapper.get_environment().websockets + "/v0"
        query_params = httpx.QueryParams()
        if auth_token is not None:
            query_params = query_params.add("auth_token", auth_token)
        # Only append a query string when there is one, so the URL never ends in a bare "?".
        if query_params:
            ws_url = ws_url + f"?{query_params}"
        headers = self._client_wrapper.get_headers()
        if request_options and "additional_headers" in request_options:
            headers.update(request_options["additional_headers"])
        try:
            with websockets_sync_client.connect(ws_url, additional_headers=headers) as protocol:
                yield WebsocketsSocketClient(websocket=protocol)
        except websockets.exceptions.InvalidStatus as exc:
            # The sync client reports a rejected handshake as ``InvalidStatus``
            # (status on ``exc.response``); the legacy ``InvalidStatusCode`` is
            # only raised by the legacy asyncio client, never by this one.
            response = exc.response
            status_code: int = response.status_code
            # Expose the server's response headers rather than our request headers,
            # which may carry credentials (e.g. an Authorization header).
            # raw_items() avoids MultipleValuesError on repeated header names.
            response_headers: typing.Dict[str, str] = dict(response.headers.raw_items())
            if status_code == 401:
                raise ApiError(
                    status_code=status_code,
                    headers=response_headers,
                    body="Websocket initialized with invalid credentials.",
                ) from exc
            raise ApiError(
                status_code=status_code,
                headers=response_headers,
                body="Unexpected error when initializing websocket connection.",
            ) from exc


class AsyncRawWebsocketsClient:
    """Low-level asynchronous client for the ``/v0`` websocket endpoint.

    Opens connections with the module-level ``websockets_client_connect`` and
    converts a rejected websocket handshake into :class:`ApiError`.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        # Supplies the environment base URL and the default request headers.
        self._client_wrapper = client_wrapper

    @asynccontextmanager
    async def connect(
        self, *, auth_token: typing.Optional[str] = None, request_options: typing.Optional[RequestOptions] = None
    ) -> typing.AsyncIterator[AsyncWebsocketsSocketClient]:
        """
        Open a websocket connection to ``<environment websockets URL>/v0`` and
        yield an async client bound to it; the connection is closed when the
        ``async with`` block exits.

        Parameters
        ----------
        auth_token : typing.Optional[str]
            Your API key. Required if Authorization header is not set.
            Sent as the ``auth_token`` query parameter.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration. Only ``additional_headers`` is used
            here; those headers are merged over the client's default headers.

        Returns
        -------
        AsyncWebsocketsSocketClient

        Raises
        ------
        ApiError
            If the server rejects the websocket handshake with a non-101 status.
        """
        ws_url = self._client_wrapper.get_environment().websockets + "/v0"
        query_params = httpx.QueryParams()
        if auth_token is not None:
            query_params = query_params.add("auth_token", auth_token)
        # Only append a query string when there is one, so the URL never ends in a bare "?".
        if query_params:
            ws_url = ws_url + f"?{query_params}"
        headers = self._client_wrapper.get_headers()
        if request_options and "additional_headers" in request_options:
            headers.update(request_options["additional_headers"])
        try:
            async with websockets_client_connect(ws_url, extra_headers=headers) as protocol:
                yield AsyncWebsocketsSocketClient(websocket=protocol)
        except websockets.exceptions.InvalidStatusCode as exc:
            status_code: int = exc.status_code
            # Expose the server's response headers rather than our request headers,
            # which may carry credentials (e.g. an Authorization header). Read them
            # defensively: older websockets releases may not set ``exc.headers``.
            raw_response_headers = getattr(exc, "headers", None)
            response_headers: typing.Dict[str, str] = (
                dict(raw_response_headers.raw_items()) if raw_response_headers is not None else {}
            )
            if status_code == 401:
                raise ApiError(
                    status_code=status_code,
                    headers=response_headers,
                    body="Websocket initialized with invalid credentials.",
                ) from exc
            raise ApiError(
                status_code=status_code,
                headers=response_headers,
                body="Unexpected error when initializing websocket connection.",
            ) from exc
