# This file was auto-generated by Fern from our API Definition.

import json
import typing
from json.decoder import JSONDecodeError

import websockets
import websockets.sync.connection as websockets_sync_connection
from ..core.events import EventEmitterMixin, EventType
from ..core.pydantic_utilities import parse_obj_as
from ..types.error_message_payload import ErrorMessagePayload
from ..types.final_message_payload import FinalMessagePayload
from ..types.streaming_update_payload import StreamingUpdatePayload
from ..types.user_message_payload import UserMessagePayload

try:
    from websockets.legacy.client import WebSocketClientProtocol  # type: ignore
except ImportError:
    from websockets import WebSocketClientProtocol  # type: ignore

# Union of the payload models that incoming session-socket frames are parsed into.
# Both socket clients in this module decode each frame with json.loads and then
# validate it against this union via parse_obj_as.
# NOTE(review): member order may influence which model is selected when a payload
# validates against more than one member (depends on parse_obj_as / pydantic union
# mode) — confirm before reordering.
ApolloWsSessionSocketClientResponse = typing.Union[StreamingUpdatePayload, FinalMessagePayload, ErrorMessagePayload]


class AsyncApolloWsSessionSocketClient(EventEmitterMixin):
    """
    Asyncio wrapper around an open websocket connection.

    Incoming frames are JSON-decoded and validated into
    ApolloWsSessionSocketClientResponse models; outgoing models are serialized
    with ``.dict()`` and ``json.dumps`` before being sent.
    """

    def __init__(self, *, websocket: WebSocketClientProtocol):
        super().__init__()
        self._websocket = websocket

    @staticmethod
    def _parse_response(raw: typing.Any) -> ApolloWsSessionSocketClientResponse:
        """Decode one raw frame as JSON and validate it against the response union."""
        return parse_obj_as(ApolloWsSessionSocketClientResponse, json.loads(raw))  # type: ignore

    async def __aiter__(self):
        """Asynchronously yield every incoming frame as a parsed response model."""
        async for frame in self._websocket:
            yield self._parse_response(frame)

    async def start_listening(self):
        """
        Read the connection until it ends, dispatching events along the way.

        Events emitted:
        - EventType.OPEN once, before the first frame is read
        - EventType.MESSAGE with the parsed model for every incoming frame
        - EventType.ERROR with the exception when a websockets.WebSocketException
          or JSONDecodeError is raised
        - EventType.CLOSE once at the end, whether or not an error occurred

        Any other exception raised while reading, decoding or dispatching is not
        caught here; it propagates to the caller after CLOSE has been emitted.
        """
        await self._emit_async(EventType.OPEN, None)
        try:
            async for frame in self._websocket:
                await self._emit_async(EventType.MESSAGE, self._parse_response(frame))
        except (websockets.WebSocketException, JSONDecodeError) as error:
            await self._emit_async(EventType.ERROR, error)
        finally:
            # Runs on normal exhaustion, on the handled errors above, and on any
            # exception that escapes, so listeners always see CLOSE.
            await self._emit_async(EventType.CLOSE, None)

    async def send_user_message(self, message: UserMessagePayload) -> None:
        """
        Serialize a UserMessagePayload and send it over the websocket.
        """
        await self._send_model(message)

    async def recv(self) -> ApolloWsSessionSocketClientResponse:
        """
        Wait for the next frame and return it as a parsed response model.
        """
        return self._parse_response(await self._websocket.recv())

    async def _send(self, data: typing.Any) -> None:
        """
        Send ``data`` as-is, except that dicts are JSON-encoded first.
        """
        payload = json.dumps(data) if isinstance(data, dict) else data
        await self._websocket.send(payload)

    async def _send_model(self, data: typing.Any) -> None:
        """
        Convert a Pydantic model with ``.dict()`` and hand the result to ``_send``.
        """
        serialized = data.dict()
        await self._send(serialized)


class ApolloWsSessionSocketClient(EventEmitterMixin):
    """
    Blocking wrapper around an open ``websockets.sync`` connection.

    Incoming frames are JSON-decoded and validated into
    ApolloWsSessionSocketClientResponse models; outgoing models are serialized
    with ``.dict()`` and ``json.dumps`` before being sent.
    """

    def __init__(self, *, websocket: websockets_sync_connection.Connection):
        super().__init__()
        self._websocket = websocket

    @staticmethod
    def _parse_response(raw: typing.Any) -> ApolloWsSessionSocketClientResponse:
        """Decode one raw frame as JSON and validate it against the response union."""
        return parse_obj_as(ApolloWsSessionSocketClientResponse, json.loads(raw))  # type: ignore

    def __iter__(self):
        """Yield every incoming frame as a parsed response model."""
        for frame in self._websocket:
            yield self._parse_response(frame)

    def start_listening(self):
        """
        Read the connection until it ends, dispatching events along the way.

        Events emitted:
        - EventType.OPEN once, before the first frame is read
        - EventType.MESSAGE with the parsed model for every incoming frame
        - EventType.ERROR with the exception when a websockets.WebSocketException
          or JSONDecodeError is raised
        - EventType.CLOSE once at the end, whether or not an error occurred

        Any other exception raised while reading, decoding or dispatching is not
        caught here; it propagates to the caller after CLOSE has been emitted.
        """
        self._emit(EventType.OPEN, None)
        try:
            for frame in self._websocket:
                self._emit(EventType.MESSAGE, self._parse_response(frame))
        except (websockets.WebSocketException, JSONDecodeError) as error:
            self._emit(EventType.ERROR, error)
        finally:
            # Runs on normal exhaustion, on the handled errors above, and on any
            # exception that escapes, so listeners always see CLOSE.
            self._emit(EventType.CLOSE, None)

    def send_user_message(self, message: UserMessagePayload) -> None:
        """
        Serialize a UserMessagePayload and send it over the websocket.
        """
        self._send_model(message)

    def recv(self) -> ApolloWsSessionSocketClientResponse:
        """
        Block until the next frame arrives and return it as a parsed response model.
        """
        return self._parse_response(self._websocket.recv())

    def _send(self, data: typing.Any) -> None:
        """
        Send ``data`` as-is, except that dicts are JSON-encoded first.
        """
        payload = json.dumps(data) if isinstance(data, dict) else data
        self._websocket.send(payload)

    def _send_model(self, data: typing.Any) -> None:
        """
        Convert a Pydantic model with ``.dict()`` and hand the result to ``_send``.
        """
        serialized = data.dict()
        self._send(serialized)
