# Copyright 2025 - AI4I. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import ssl
from typing import Any, Dict, Optional, Union

import httpx
from attrs import define, evolve, field


class MultipartFixClient(httpx.Client):
    """
    ``httpx.Client`` that drops boundary-less multipart ``Content-Type`` headers.

    Code generated by openapi-python-client may set ``Content-Type:
    multipart/form-data`` explicitly on file-upload requests, without the
    ``boundary`` parameter the receiver needs to split the body into parts.
    When ``files`` are supplied and the per-request Content-Type is
    ``multipart/form-data`` with no boundary, this client removes that header.
    httpx then generates the Content-Type itself, including the boundary that
    matches the encoded body.
    """

    @staticmethod
    def _strip_boundaryless_multipart_content_type(kwargs: Dict[str, Any]) -> None:
        """
        Remove a multipart Content-Type that lacks a boundary from ``kwargs["headers"]``.

        ``kwargs`` is mutated in place, and only when ``files`` is not ``None``
        and ``headers`` is given. ``headers`` may be any container httpx accepts
        (a dict, a sequence of ``(name, value)`` pairs, or ``httpx.Headers``).
        All headers other than Content-Type are kept.
        """
        raw_headers = kwargs.get("headers")
        if kwargs.get("files") is None or raw_headers is None:
            return

        # Normalise through httpx.Headers so the lookup and deletion are
        # case-insensitive, whatever container type the caller passed in.
        headers = httpx.Headers(raw_headers)
        content_type = headers.get("content-type")
        if content_type is None:
            return

        media_type, _, params = content_type.partition(";")
        if media_type.strip().lower() != "multipart/form-data":
            return
        has_boundary = any(
            param.strip().lower().startswith("boundary=")
            for param in params.split(";")
        )
        if has_boundary:
            # An explicit boundary is left untouched for httpx to use.
            return

        del headers["content-type"]
        kwargs["headers"] = headers

    def request(
        self, method: str, url: Union[str, httpx.URL], **kwargs: Any
    ) -> httpx.Response:
        """
        Send a request after dropping a boundary-less multipart Content-Type.

        Args:
            method: HTTP method, forwarded unchanged.
            url: Target URL, forwarded unchanged.
            **kwargs: Remaining ``httpx.Client.request`` arguments. ``headers``
                may be replaced by a cleaned copy when ``files`` is also given.

        Returns:
            The ``httpx.Response`` produced by ``httpx.Client.request``.
        """
        self._strip_boundaryless_multipart_content_type(kwargs)
        return super().request(method, url, **kwargs)


class AsyncMultipartFixClient(httpx.AsyncClient):
    """
    ``httpx.AsyncClient`` that drops boundary-less multipart ``Content-Type`` headers.

    This is the asynchronous counterpart of ``MultipartFixClient``. When
    ``files`` are supplied and the per-request Content-Type is
    ``multipart/form-data`` with no ``boundary`` parameter, that header is
    removed. httpx then generates the Content-Type itself, including the
    boundary that matches the encoded body.
    """

    @staticmethod
    def _strip_boundaryless_multipart_content_type(kwargs: Dict[str, Any]) -> None:
        """
        Remove a multipart Content-Type that lacks a boundary from ``kwargs["headers"]``.

        ``kwargs`` is mutated in place, and only when ``files`` is not ``None``
        and ``headers`` is given. ``headers`` may be any container httpx accepts
        (a dict, a sequence of ``(name, value)`` pairs, or ``httpx.Headers``).
        All headers other than Content-Type are kept.
        """
        raw_headers = kwargs.get("headers")
        if kwargs.get("files") is None or raw_headers is None:
            return

        # Normalise through httpx.Headers so the lookup and deletion are
        # case-insensitive, whatever container type the caller passed in.
        headers = httpx.Headers(raw_headers)
        content_type = headers.get("content-type")
        if content_type is None:
            return

        media_type, _, params = content_type.partition(";")
        if media_type.strip().lower() != "multipart/form-data":
            return
        has_boundary = any(
            param.strip().lower().startswith("boundary=")
            for param in params.split(";")
        )
        if has_boundary:
            # An explicit boundary is left untouched for httpx to use.
            return

        del headers["content-type"]
        kwargs["headers"] = headers

    async def request(
        self, method: str, url: Union[str, httpx.URL], **kwargs: Any
    ) -> httpx.Response:
        """
        Send a request after dropping a boundary-less multipart Content-Type.

        Args:
            method: HTTP method, forwarded unchanged.
            url: Target URL, forwarded unchanged.
            **kwargs: Remaining ``httpx.AsyncClient.request`` arguments.
                ``headers`` may be replaced by a cleaned copy when ``files``
                is also given.

        Returns:
            The ``httpx.Response`` produced by ``httpx.AsyncClient.request``.
        """
        self._strip_boundaryless_multipart_content_type(kwargs)
        return await super().request(method, url, **kwargs)


@define
class Client:
    """
    Base API client that stores connection settings and lazily builds httpx clients.

    The settings below are kept on the instance. On first use they are passed
    to a plain ``httpx.Client`` (sync) or ``httpx.AsyncClient`` (async).
    ``AuthenticatedClient`` in this module is a separate class: it does not
    inherit from ``Client``, and it builds ``MultipartFixClient`` /
    ``AsyncMultipartFixClient`` instances instead.

    Keyword arguments used to build the underlying httpx clients:

        base_url: Base URL that request paths are resolved against.
        cookies: Cookies sent with every request.
        headers: Headers sent with every request.
        timeout: ``httpx.Timeout`` for requests. ``None`` (the default) is
            forwarded to httpx unchanged.
        verify_ssl: ``True``/``False``, a CA bundle path, or an
            ``ssl.SSLContext``. Passed to httpx as ``verify``.
        follow_redirects: Whether redirects are followed. Defaults to ``False``.
        httpx_args: Extra keyword arguments forwarded verbatim to the
            ``httpx.Client`` and ``httpx.AsyncClient`` constructors.

    Attributes:
        raise_on_unexpected_status: Flag that is not read inside this class.
            It is presumably consulted by the generated endpoint functions to
            decide whether to raise ``errors.UnexpectedStatus`` for status
            codes the OpenAPI spec does not document.
    """

    raise_on_unexpected_status: bool = field(default=False, kw_only=True)
    # The only positional constructor argument.
    _base_url: str = field(alias="base_url")
    _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
    _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
    _timeout: Optional[httpx.Timeout] = field(
        default=None, kw_only=True, alias="timeout"
    )
    _verify_ssl: Union[str, bool, ssl.SSLContext] = field(
        default=True, kw_only=True, alias="verify_ssl"
    )
    _follow_redirects: bool = field(
        default=False, kw_only=True, alias="follow_redirects"
    )
    _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
    # Built lazily by the getters below (or injected by the setters); never an
    # __init__ argument, so copies made with ``evolve`` start out empty.
    _client: Optional[httpx.Client] = field(default=None, init=False)
    _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)

    def _common_httpx_kwargs(self) -> Dict[str, Any]:
        """Return the constructor arguments shared by the sync and async httpx clients."""
        return {
            "base_url": self._base_url,
            "cookies": self._cookies,
            "headers": self._headers,
            "timeout": self._timeout,
            "verify": self._verify_ssl,
            "follow_redirects": self._follow_redirects,
        }

    def with_headers(self, headers: Dict[str, str]) -> "Client":
        """
        Return a copy of this client with ``headers`` merged over the current ones.

        Side effect: httpx clients already built by *this* instance also have
        their headers updated in place.
        """
        for live_client in (self._client, self._async_client):
            if live_client is not None:
                live_client.headers.update(headers)
        merged_headers = {**self._headers, **headers}
        return evolve(self, headers=merged_headers)

    def with_cookies(self, cookies: Dict[str, str]) -> "Client":
        """
        Return a copy of this client with ``cookies`` merged over the current ones.

        Side effect: httpx clients already built by *this* instance also have
        their cookies updated in place.
        """
        for live_client in (self._client, self._async_client):
            if live_client is not None:
                live_client.cookies.update(cookies)
        merged_cookies = {**self._cookies, **cookies}
        return evolve(self, cookies=merged_cookies)

    def with_timeout(self, timeout: httpx.Timeout) -> "Client":
        """
        Return a copy of this client that uses ``timeout``.

        Side effect: httpx clients already built by *this* instance also have
        their timeout replaced.
        """
        for live_client in (self._client, self._async_client):
            if live_client is not None:
                live_client.timeout = timeout
        return evolve(self, timeout=timeout)

    def set_httpx_client(self, client: httpx.Client) -> "Client":
        """
        Use ``client`` as the synchronous httpx client and return ``self``.

        The supplied client is stored as-is. The cookies, headers, timeout and
        other settings held on this object are not applied to it. Pass a
        ``MultipartFixClient`` if the multipart header fix is wanted.
        """
        self._client = client
        return self

    def get_httpx_client(self) -> httpx.Client:
        """
        Return the synchronous httpx client, building one on first call.

        A freshly built client is a plain ``httpx.Client`` configured from this
        object's settings plus any ``httpx_args``.
        """
        if self._client is not None:
            return self._client
        self._client = httpx.Client(**self._common_httpx_kwargs(), **self._httpx_args)
        return self._client

    def __enter__(self) -> "Client":
        """Enter the synchronous httpx client's context (building it if needed)."""
        sync_client = self.get_httpx_client()
        sync_client.__enter__()
        return self

    def __exit__(self, *exc_info: Any, **kwargs: Any) -> None:
        """Delegate context exit to the synchronous httpx client."""
        self.get_httpx_client().__exit__(*exc_info, **kwargs)

    def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "Client":
        """
        Use ``async_client`` as the asynchronous httpx client and return ``self``.

        The supplied client is stored as-is. The settings held on this object
        are not applied to it. Pass an ``AsyncMultipartFixClient`` if the
        multipart header fix is wanted.
        """
        self._async_client = async_client
        return self

    def get_async_httpx_client(self) -> httpx.AsyncClient:
        """
        Return the asynchronous httpx client, building one on first call.

        A freshly built client is a plain ``httpx.AsyncClient`` configured from
        this object's settings plus any ``httpx_args``.
        """
        if self._async_client is not None:
            return self._async_client
        self._async_client = httpx.AsyncClient(
            **self._common_httpx_kwargs(), **self._httpx_args
        )
        return self._async_client

    async def __aenter__(self) -> "Client":
        """Enter the asynchronous httpx client's context (building it if needed)."""
        async_client = self.get_async_httpx_client()
        await async_client.__aenter__()
        return self

    async def __aexit__(self, *exc_info: Any, **kwargs: Any) -> None:
        """Delegate async context exit to the asynchronous httpx client."""
        await self.get_async_httpx_client().__aexit__(*exc_info, **kwargs)


@define
class AuthenticatedClient:
    """
    API client that attaches an authorization token to every request it builds.

    This class does not inherit from ``Client``. It accepts the same
    connection settings (``base_url``, ``cookies``, ``headers``, ``timeout``,
    ``verify_ssl``, ``follow_redirects``, ``httpx_args``), plus ``token``,
    ``prefix`` and ``auth_header_name``. Clients it builds itself are
    ``MultipartFixClient`` (sync) and ``AsyncMultipartFixClient`` (async), and
    they carry the auth header in their default headers.

    Attributes:
        token: Authentication token placed in the auth header.
        prefix: Text placed before the token, separated by a space (default
            ``"Bearer"``). When empty, the header value is the bare token.
        auth_header_name: Name of the header carrying the token (default
            ``"Authorization"``).
        raise_on_unexpected_status: Flag that is not read inside this class.
            It is presumably consulted by the generated endpoint functions to
            decide whether to raise ``errors.UnexpectedStatus``.
        _base_url: Base URL, defaulting to ``"https://hackagent.dev/"``. An
            explicit ``None`` is also replaced with that default.
        _cookies, _headers, _timeout, _verify_ssl, _follow_redirects,
        _httpx_args: Connection settings forwarded to the httpx constructors.
    """

    token: str
    raise_on_unexpected_status: bool = field(default=False, kw_only=True)
    _base_url: str = field(
        default="https://hackagent.dev/",
        alias="base_url",
    )
    _cookies: Dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
    _headers: Dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
    _timeout: Optional[httpx.Timeout] = field(
        default=None, kw_only=True, alias="timeout"
    )
    _verify_ssl: Union[str, bool, ssl.SSLContext] = field(
        default=True, kw_only=True, alias="verify_ssl"
    )
    _follow_redirects: bool = field(
        default=False, kw_only=True, alias="follow_redirects"
    )
    _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
    # Built lazily by the getters below (or injected by the setters); never an
    # __init__ argument, so copies made with ``evolve`` start out empty.
    _client: Optional[httpx.Client] = field(default=None, init=False)
    _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)

    prefix: str = "Bearer"
    auth_header_name: str = "Authorization"

    def __attrs_post_init__(self) -> None:
        """Fall back to the default base URL when ``base_url=None`` was passed."""
        if self._base_url is None:
            self._base_url = "https://hackagent.dev/"

    def _authenticated_httpx_kwargs(self) -> Dict[str, Any]:
        """Return shared httpx constructor arguments, with the auth header merged in."""
        headers_with_auth = self._headers.copy()
        headers_with_auth[self.auth_header_name] = (
            f"{self.prefix} {self.token}" if self.prefix else self.token
        )
        return {
            "base_url": self._base_url,
            "cookies": self._cookies,
            "headers": headers_with_auth,
            "timeout": self._timeout,
            "verify": self._verify_ssl,
            "follow_redirects": self._follow_redirects,
        }

    def with_headers(self, headers: Dict[str, str]) -> "AuthenticatedClient":
        """
        Return a copy of this client with ``headers`` merged over the current ones.

        Side effect: httpx clients already built by *this* instance also have
        their headers updated in place.
        """
        for live_client in (self._client, self._async_client):
            if live_client is not None:
                live_client.headers.update(headers)
        merged_headers = {**self._headers, **headers}
        return evolve(self, headers=merged_headers)

    def with_cookies(self, cookies: Dict[str, str]) -> "AuthenticatedClient":
        """
        Return a copy of this client with ``cookies`` merged over the current ones.

        Side effect: httpx clients already built by *this* instance also have
        their cookies updated in place.
        """
        for live_client in (self._client, self._async_client):
            if live_client is not None:
                live_client.cookies.update(cookies)
        merged_cookies = {**self._cookies, **cookies}
        return evolve(self, cookies=merged_cookies)

    def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedClient":
        """
        Return a copy of this client that uses ``timeout``.

        Side effect: httpx clients already built by *this* instance also have
        their timeout replaced.
        """
        for live_client in (self._client, self._async_client):
            if live_client is not None:
                live_client.timeout = timeout
        return evolve(self, timeout=timeout)

    def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedClient":
        """
        Use ``client`` as the synchronous httpx client and return ``self``.

        The client is stored as-is. No settings or auth header from this
        object are applied to it. Any ``httpx.Client`` is accepted, but only a
        ``MultipartFixClient`` (or a compatible subclass) performs the
        multipart Content-Type fix.
        """
        self._client = client
        return self

    def get_httpx_client(self) -> httpx.Client:
        """
        Return the synchronous httpx client, building one on first call.

        A freshly built client is a ``MultipartFixClient`` configured from this
        object's settings. It has the auth header added to its default headers
        and receives any ``httpx_args``.
        """
        if self._client is not None:
            return self._client
        self._client = MultipartFixClient(
            **self._authenticated_httpx_kwargs(), **self._httpx_args
        )
        return self._client

    def __enter__(self) -> "AuthenticatedClient":
        """Enter the synchronous httpx client's context (building it if needed)."""
        sync_client = self.get_httpx_client()
        sync_client.__enter__()
        return self

    def __exit__(self, *exc_info: Any, **kwargs: Any) -> None:
        """Delegate context exit to the synchronous httpx client."""
        self.get_httpx_client().__exit__(*exc_info, **kwargs)

    def set_async_httpx_client(
        self, async_client: httpx.AsyncClient
    ) -> "AuthenticatedClient":
        """
        Use ``async_client`` as the asynchronous httpx client and return ``self``.

        The client is stored as-is. No settings or auth header from this
        object are applied to it. Only an ``AsyncMultipartFixClient`` (or a
        compatible subclass) performs the multipart Content-Type fix.
        """
        self._async_client = async_client
        return self

    def get_async_httpx_client(self) -> httpx.AsyncClient:
        """
        Return the asynchronous httpx client, building one on first call.

        A freshly built client is an ``AsyncMultipartFixClient`` configured from
        this object's settings, with the auth header in its default headers.
        """
        if self._async_client is not None:
            return self._async_client
        self._async_client = AsyncMultipartFixClient(
            **self._authenticated_httpx_kwargs(), **self._httpx_args
        )
        return self._async_client

    async def __aenter__(self) -> "AuthenticatedClient":
        """Enter the asynchronous httpx client's context (building it if needed)."""
        async_client = self.get_async_httpx_client()
        await async_client.__aenter__()
        return self

    async def __aexit__(self, *exc_info: Any, **kwargs: Any) -> None:
        """Delegate async context exit to the asynchronous httpx client."""
        await self.get_async_httpx_client().__aexit__(*exc_info, **kwargs)
