import asyncio
from collections.abc import Awaitable, Callable
import functools
import inspect
import time
from typing import Any

from attrs import define
from pendulum import DateTime, Duration

from atta.observability.logging import get_logger
from atta.utils.backoff import exponential_backoff
from atta.utils.time import current_time

logger = get_logger(__name__)

# -------------------------------
# Shared
# -------------------------------


@define(kw_only=True, slots=True)
class FailedAttempt:
    """Record of one failed call made during a retry sequence.

    Attributes:
        attempt: Attempt number assigned by the retry loop (first attempt is 1)
        exception: Exception raised by the call
        call_time: Timestamp taken immediately before the call was made
        exception_time: Timestamp taken when the exception was caught
    """

    attempt: int
    exception: Exception
    call_time: DateTime
    exception_time: DateTime


class RetryError(Exception):
    """Raised when every retry attempt has failed.

    Attributes:
        failures: Details of each failed attempt, as passed to the constructor
    """

    def __init__(self, message: str, failures: list[FailedAttempt]) -> None:
        """Initialize the error.

        Args:
            message: Human-readable error message (becomes ``str(error)``)
            failures: Failed attempts that led to this error
        """
        super().__init__(message)
        self.failures = failures

    def __reduce__(self) -> tuple[Any, ...]:
        """Describe how to rebuild this error for ``copy`` and ``pickle``.

        The inherited implementation re-creates the instance from ``self.args``
        alone. ``args`` only holds the message, so the reconstruction would call
        ``__init__`` without ``failures`` and raise ``TypeError``. Supplying
        both constructor arguments keeps the error copyable and picklable while
        leaving ``args`` and ``str(error)`` untouched.

        Returns:
            ``(class, constructor args, instance state)`` reduce tuple
        """
        return (type(self), (*self.args, self.failures), vars(self))


@define(frozen=True, kw_only=True, slots=True)
class RetryConfig:
    """Immutable settings that control how a call is retried.

    Attributes:
        on_exceptions: Exception types that trigger a retry; anything else
            propagates to the caller
        max_attempts: Total number of calls allowed, including the first one
        backoff_func: Maps the number of the attempt that just failed to the
            delay handed to ``sleep`` before the next attempt. Defaults to
            ``exponential_backoff`` bounded by 1 second and 30 seconds
    """

    on_exceptions: tuple[type[Exception], ...]
    max_attempts: int = 3
    backoff_func: Callable[[int], float] = lambda attempt: exponential_backoff(
        attempt,
        min_delay=Duration(seconds=1),
        max_delay=Duration(seconds=30),
    )


# -------------------------------
# Core Retry Logic
# -------------------------------


async def async_retry[T](func: Callable[[], Awaitable[T]], config: RetryConfig) -> T:
    """Retry async function with exponential backoff.

    Args:
        func: Async callable to retry
        config: Retry configuration (exceptions, attempts, backoff)

    Returns:
        Result from successful func() call

    Raises:
        RetryError: All attempts exhausted, contains failure history
        Exception: Any non-retryable exception bubbles immediately
    """
    attempt = 1
    failures: list[FailedAttempt] = []

    while attempt <= config.max_attempts:
        call_time = current_time()

        try:
            return await func()

        except config.on_exceptions as e:
            exception_time = current_time()

            failures.append(
                FailedAttempt(
                    attempt=attempt,
                    exception=e,
                    call_time=call_time,
                    exception_time=exception_time,
                )
            )

            logger.warning(
                f"Retry attempt {attempt}/{config.max_attempts} failed",
                extra={
                    "attempt": attempt,
                    "max_attempts": config.max_attempts,
                    "exception_type": type(e).__name__,
                    "exception_message": str(e),
                    "call_time": call_time.isoformat(),
                    "exception_time": exception_time.isoformat(),
                },
            )

            if attempt < config.max_attempts:
                delay = config.backoff_func(attempt)
                logger.debug(f"Retrying in {delay}s")
                await asyncio.sleep(delay)

            attempt += 1

    # All attempts exhausted
    logger.error(
        f"Max retry attempts ({config.max_attempts}) exhausted",
        extra={
            "max_attempts": config.max_attempts,
            "total_failures": len(failures),
            "last_exception_type": type(failures[-1].exception).__name__,
        },
    )

    raise RetryError(
        f"Max attempts ({config.max_attempts}) reached",
        failures=failures,
    ) from failures[-1].exception


def sync_retry[T](func: Callable[[], T], config: RetryConfig) -> T:
    """Retry sync function with exponential backoff.

    Args:
        func: Callable to retry
        config: Retry configuration (exceptions, attempts, backoff)

    Returns:
        Result from successful func() call

    Raises:
        RetryError: All attempts exhausted, contains failure history
        Exception: Any non-retryable exception bubbles immediately
    """
    attempt = 1
    failures: list[FailedAttempt] = []

    while attempt <= config.max_attempts:
        call_time = current_time()

        try:
            return func()

        except config.on_exceptions as e:
            exception_time = current_time()

            failures.append(
                FailedAttempt(
                    attempt=attempt,
                    exception=e,
                    call_time=call_time,
                    exception_time=exception_time,
                )
            )

            logger.warning(
                f"Retry attempt {attempt}/{config.max_attempts} failed",
                extra={
                    "attempt": attempt,
                    "max_attempts": config.max_attempts,
                    "exception_type": type(e).__name__,
                    "exception_message": str(e),
                    "call_time": call_time.isoformat(),
                    "exception_time": exception_time.isoformat(),
                },
            )

            if attempt < config.max_attempts:
                delay = config.backoff_func(attempt)
                logger.debug(f"Retrying in {delay}s")
                time.sleep(delay)

            attempt += 1

    # All attempts exhausted
    logger.error(
        f"Max retry attempts ({config.max_attempts}) exhausted",
        extra={
            "max_attempts": config.max_attempts,
            "total_failures": len(failures),
            "last_exception_type": type(failures[-1].exception).__name__,
        },
    )

    raise RetryError(
        f"Max attempts ({config.max_attempts}) reached",
        failures=failures,
    ) from failures[-1].exception


# -------------------------------
# Decorator
# -------------------------------


def retry[T](
    config: RetryConfig,
) -> Callable[
    [Callable[..., T] | Callable[..., Awaitable[T]]],
    Callable[..., T] | Callable[..., Awaitable[T]],
]:
    """Decorator to add retry logic to sync or async functions.

    Automatically detects function type and applies appropriate retry logic.

    Args:
        config: Retry configuration
            - on_exceptions: Tuple of exception types to retry
            - max_attempts: Maximum retry attempts (default: 3)
            - backoff_func: Backoff function (default: exponential)

    Returns:
        Decorated function with retry logic

    Raises:
        RetryError: Contains failure history when exhausted
    """

    def decorator(
        func: Callable[..., T] | Callable[..., Awaitable[T]],
    ) -> Callable[..., T] | Callable[..., Awaitable[T]]:
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                async def call() -> T:
                    return await func(*args, **kwargs)  # type:ignore[no-any-return]

                return await async_retry(call, config)

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> T:
            def call() -> T:
                return func(*args, **kwargs)  # type:ignore[return-value]

            return sync_retry(call, config)

        return sync_wrapper

    return decorator
