from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import Any

from dapr_agents.agents.configs import WorkflowGrpcOptions

logger = logging.getLogger(__name__)


def apply_grpc_options(options: WorkflowGrpcOptions | None) -> None:
    """Install a durabletask channel factory that applies the given size limits.

    Replaces ``durabletask.internal.shared.get_grpc_channel`` with a factory
    that appends the configured message-length limits to the options of every
    channel it creates. Nothing is patched when ``options`` is falsy or when
    neither limit is set.

    Args:
        options: Configuration whose ``max_send_message_length`` and
            ``max_receive_message_length`` attributes are read. A falsy value
            for either attribute leaves that limit out.

    Raises:
        ImportError: If ``grpc`` or ``durabletask.internal.shared`` cannot be
            imported; the failure is logged before being re-raised.
    """
    if not options:
        return

    # Keep only the limits that were actually configured (falsy values are skipped).
    configured_limits: list[tuple[str, Any]] = [
        (key, value)
        for key, value in (
            ("grpc.max_send_message_length", options.max_send_message_length),
            ("grpc.max_receive_message_length", options.max_receive_message_length),
        )
        if value
    ]
    if not configured_limits:
        return

    try:
        import grpc
        from durabletask.internal import shared
    except ImportError as exc:  # pragma: no cover - defensive
        logger.error("Unable to patch durabletask gRPC settings: %s", exc)
        raise

    def get_grpc_channel_with_options(
        host_address: str | None,
        secure_channel: bool = False,
        interceptors: Sequence[grpc.ClientInterceptor] | None = None,
        options: Sequence[tuple[str, Any]] | None = None,
    ):
        # This ``options`` parameter shadows the outer one; the configured
        # limits are reached through the ``configured_limits`` closure variable.
        address = shared.get_default_host_address() if host_address is None else host_address
        use_tls = secure_channel

        # A recognised scheme prefix decides TLS and is stripped from the address.
        # Secure prefixes are checked first, then insecure ones against the
        # (possibly already stripped) address.
        for attr_name, implied_tls in (
            ("SECURE_PROTOCOLS", True),
            ("INSECURE_PROTOCOLS", False),
        ):
            for prefix in getattr(shared, attr_name, []):
                if address.lower().startswith(prefix):
                    use_tls = implied_tls
                    address = address[len(prefix) :]
                    break

        # Caller-supplied options come first, followed by the configured limits.
        merged_options = [*(options or []), *configured_limits]

        if use_tls:
            credentials = grpc.ssl_channel_credentials()
            channel = grpc.secure_channel(address, credentials, options=merged_options)
        else:
            channel = grpc.insecure_channel(address, options=merged_options)

        # Only wrap the channel when at least one interceptor was supplied.
        return grpc.intercept_channel(channel, *interceptors) if interceptors else channel

    shared.get_grpc_channel = get_grpc_channel_with_options
    logger.debug("Patched durabletask gRPC options: %s", dict(configured_limits))
