#!/usr/bin/env python3
# coding: utf-8

"""HTTP request helpers built on httpx.

Public entry points are ``request_sync`` (blocking, uses ``httpx.Client``),
``request_async`` (coroutine, uses ``httpx.AsyncClient``) and ``request``,
which dispatches to one of the two according to its ``async_`` flag.
"""

__author__ = "ChenyangGao <https://chenyanggao.github.io>"
__version__ = (0, 1, 5)
__all__ = ["request", "request_sync", "request_async"]

from collections import UserString
from collections.abc import (
    Awaitable, Buffer, Callable, Iterable, Mapping, 
)
from contextlib import aclosing, closing
from inspect import isawaitable, signature
from os import PathLike
from types import EllipsisType
from typing import cast, overload, Any, Final, Literal

from asynctools import async_map, run_async
from argtools import argcount
from dicttools import get_all_items
from ensure import ensure_buffer
from filewrap import bio_chunk_iter, bio_chunk_async_iter, SupportsRead
from http_request import normalize_request_args, SupportsGeturl
from http_response import parse_response
from httpx import AsyncHTTPTransport, HTTPTransport, Request, Response
from httpx._types import SyncByteStream
from httpx._client import AsyncClient, Client, Response
from yarl import URL


type string = Buffer | str | UserString

_BUILD_REQUEST_KWARGS: Final = signature(Client.build_request).parameters.keys() - {"self"}
_INIT_CLIENT_KWARGS: Final   = signature(Client).parameters.keys() - _BUILD_REQUEST_KWARGS
_SEND_REQUEST_KWARGS: Final  = ("auth", "stream", "follow_redirects")

# Monkey-patch httpx so that clients and responses release their resources
# when they are garbage-collected. Each hook is installed only when the class
# itself (not a base class) does not already define it.
if "__del__" not in Client.__dict__:
    # Finalizer for the synchronous client: just close it.
    setattr(Client, "__del__", Client.close)
if "close" not in AsyncClient.__dict__:
    # Synchronous `close` for AsyncClient that drives `aclose()` through
    # `asynctools.run_async`.
    # NOTE(review): behavior when this runs while an event loop is already
    # running depends on `run_async` — confirm.
    def close(self, /):
        return run_async(self.aclose())
    setattr(AsyncClient, "close", close)
if "__del__" not in AsyncClient.__dict__:
    # Finalizer delegates to `AsyncClient.close` as it is now (either the
    # patched version above or a native one).
    setattr(AsyncClient, "__del__", getattr(AsyncClient, "close"))
if "__del__" not in Response.__dict__:
    def __del__(self, /):
        """Close a response that is still open when it is garbage-collected."""
        if self.is_closed:
            return
        # A synchronous byte stream can be closed directly; any other stream
        # is treated as asynchronous and closed through `run_async`.
        if isinstance(self.stream, SyncByteStream):
            self.close()
        else:
            return run_async(self.aclose())
    setattr(Response, "__del__", __del__)


@overload
def request_sync(
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | Client = None, 
    *, 
    parse: None = None, 
    **request_kwargs, 
) -> bytes:
    ...
@overload
def request_sync(
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | Client = None, 
    *, 
    parse: Literal[False], 
    **request_kwargs, 
) -> Response:
    ...
@overload
def request_sync(
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | Client = None, 
    *, 
    parse: Literal[True], 
    **request_kwargs, 
) -> bytes | str | dict | list | int | float | bool | None:
    ...
@overload
def request_sync[T](
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | Client = None, 
    *, 
    parse: Callable[[Response, bytes], T] | Callable[[Response], T], 
    **request_kwargs, 
) -> T:
    ...
def request_sync[T](
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | Client = None, 
    *, 
    parse: None | EllipsisType | bool | Callable[[Response, bytes], T] | Callable[[Response], T] = None, 
    **request_kwargs, 
) -> Response | bytes | str | dict | list | int | float | bool | None | T:
    request_kwargs["follow_redirects"] = follow_redirects
    request_kwargs.setdefault("stream", True)
    if session is None:
        init_kwargs = dict(get_all_items(request_kwargs, *_INIT_CLIENT_KWARGS))
        if "transport" not in init_kwargs:
            init_kwargs["transport"] = HTTPTransport(http2=True, retries=5)
        session = Client(**init_kwargs)
    if isinstance(url, Request):
        request = url
    else:
        if isinstance(data, PathLike):
            data = bio_chunk_iter(open(data, "rb"))
        elif isinstance(data, SupportsRead):
            data = map(ensure_buffer, bio_chunk_iter(data))
        request_kwargs.update(normalize_request_args(
            method=method, 
            url=url, 
            params=params, 
            data=data, 
            json=json, 
            headers=headers, 
        ))
        request = session.build_request(**dict(get_all_items(request_kwargs, _BUILD_REQUEST_KWARGS)))
    response = session.send(request, **dict(get_all_items(request_kwargs, _SEND_REQUEST_KWARGS)))
    # NOTE: keep ref to prevent gc
    setattr(response, "session", session)
    if response.status_code >= 400 and raise_for_status:
        response.raise_for_status()
    if parse is None:
        return response
    elif parse is ...:
        response.close()
        return response
    with closing(response):
        if isinstance(parse, bool):
            content = response.read()
            if parse:
                return parse_response(response, content)
            return content
        ac = argcount(parse)
        if ac == 1:
            return cast(Callable[[Response], T], parse)(response)
        else:
            return cast(Callable[[Response, bytes], T], parse)(response, response.read())


@overload
async def request_async(
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | AsyncClient = None, 
    *, 
    parse: None = None, 
    **request_kwargs, 
) -> bytes:
    ...
@overload
async def request_async(
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | AsyncClient = None, 
    *, 
    parse: Literal[False], 
    **request_kwargs, 
) -> Response:
    ...
@overload
async def request_async(
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | AsyncClient = None, 
    *, 
    parse: Literal[True], 
    **request_kwargs, 
) -> bytes | str | dict | list | int | float | bool | None:
    ...
@overload
async def request_async[T](
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | AsyncClient = None, 
    *, 
    parse: Callable[[Response, bytes], T] | Callable[[Response, bytes], Awaitable[T]] | Callable[[Response], T] | Callable[[Response], Awaitable[T]], 
    **request_kwargs, 
) -> T:
    ...
async def request_async[T](
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | AsyncClient = None, 
    *, 
    parse: None | EllipsisType | bool | Callable[[Response, bytes], T] | Callable[[Response, bytes], Awaitable[T]] | Callable[[Response], T] | Callable[[Response], Awaitable[T]] = None, 
    **request_kwargs, 
) -> Response | bytes | str | dict | list | int | float | bool | None | T:
    request_kwargs["follow_redirects"] = follow_redirects
    request_kwargs.setdefault("stream", True)
    if session is None:
        init_kwargs = dict(get_all_items(request_kwargs, *_INIT_CLIENT_KWARGS))
        if "transport" not in init_kwargs:
            init_kwargs["transport"] = AsyncHTTPTransport(http2=True, retries=5)
        session = AsyncClient(**init_kwargs)
    if isinstance(url, Request):
        request = url
    else:
        if isinstance(data, PathLike):
            data = bio_chunk_async_iter(open(data, "rb"))
        elif isinstance(data, SupportsRead):
            data = async_map(ensure_buffer, bio_chunk_async_iter(data))
        request_kwargs.update(normalize_request_args(
            method=method, 
            url=url, 
            params=params, 
            data=data, 
            json=json, 
            headers=headers, 
        ))
        request = session.build_request(**dict(get_all_items(request_kwargs, _BUILD_REQUEST_KWARGS)))
    response = await session.send(request, **dict(get_all_items(request_kwargs, _SEND_REQUEST_KWARGS)))
    setattr(response, "session", session)
    if response.status_code >= 400 and raise_for_status:
        response.raise_for_status()
    if parse is None:
        return response
    elif parse is ...:
        await response.aclose()
        return response
    async with aclosing(response):
        if isinstance(parse, bool):
            content = await response.aread()
            if parse:
                return parse_response(response, content)
            return content
        ac = argcount(parse)
        if ac == 1:
            ret = cast(Callable[[Response], T] | Callable[[Response], Awaitable[T]], parse)(response)
        else:
            ret = cast(Callable[[Response, bytes], T] | Callable[[Response, bytes], Awaitable[T]], parse)(
                response, await response.aread())
        if isawaitable(ret):
            ret = await ret
        return ret


# Typing overload for the synchronous path: with `async_` false (the default)
# the call is forwarded to `request_sync` and its result is returned directly.
@overload
def request[T](
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | Client = None, 
    *, 
    parse: None | EllipsisType | bool | Callable[[Response, bytes], T] | Callable[[Response], T] = None, 
    async_: Literal[False] = False, 
    **request_kwargs, 
) -> Response | bytes | str | dict | list | int | float | bool | None | T:
    ...
# Typing overload for the asynchronous path: with `async_=True` the call is
# forwarded to `request_async` and the un-awaited awaitable is returned, so
# `parse` may also be an async callable and `session` an `AsyncClient`.
@overload
def request[T](
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | AsyncClient = None, 
    *, 
    parse: None | EllipsisType | bool | Callable[[Response, bytes], T] | Callable[[Response, bytes], Awaitable[T]] | Callable[[Response], T] | Callable[[Response], Awaitable[T]] = None, 
    async_: Literal[True], 
    **request_kwargs, 
) -> Awaitable[Response | bytes | str | dict | list | int | float | bool | None | T]:
    ...
def request[T](
    url: string | SupportsGeturl | URL | Request, 
    method: string = "GET", 
    params: None | string | Mapping | Iterable[tuple[Any, Any]] = None, 
    data: Any = None, 
    json: Any = None, 
    headers: None | Mapping[string, string] | Iterable[tuple[string, string]] = None, 
    follow_redirects: bool = True, 
    raise_for_status: bool = True, 
    session: None | Client | AsyncClient = None, 
    *, 
    parse: None | EllipsisType | bool | Callable[[Response, bytes], T] | Callable[[Response, bytes], Awaitable[T]] | Callable[[Response], T] | Callable[[Response], Awaitable[T]] = None, 
    async_: Literal[False, True] = False, 
    **request_kwargs, 
) -> Response | bytes | str | dict | list | int | float | bool | None | T | Awaitable[Response | bytes | str | dict | list | int | float | bool | None | T]:
    if async_:
        return request_async(
            url=url, 
            method=method, 
            params=params, 
            data=data, 
            json=json, 
            headers=headers, 
            follow_redirects=follow_redirects, 
            raise_for_status=raise_for_status, 
            session=cast(None | AsyncClient, session), 
            parse=parse, # type: ignore 
            **request_kwargs, 
        )
    else:
        return request_sync(
            url=url, 
            method=method, 
            params=params, 
            data=data, 
            json=json, 
            headers=headers, 
            follow_redirects=follow_redirects, 
            raise_for_status=raise_for_status, 
            session=cast(None | Client, session), 
            parse=parse, # type: ignore  
            **request_kwargs, 
        )

