# fastapi_easylimiter.py

import asyncio
import math
from abc import ABC, abstractmethod
from time import time
from typing import Optional, Callable, Dict

import redis.asyncio as redis_async
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

# ----------------------------
# Backend Interface
# ----------------------------
class RateLimitBackend(ABC):
    """Storage interface the rate limiter uses for expiring per-key counters.

    Concrete backends must implement all three coroutines; the class itself
    cannot be instantiated.
    """

    @abstractmethod
    async def incr(self, key: str, expire: Optional[int] = None) -> int:
        """Add one to the counter stored under *key* and return the new value.

        *expire*, if given, is a lifetime in seconds for the counter; see the
        concrete backends for exactly when it is applied.
        """

    @abstractmethod
    async def expire(self, key: str, ttl: int):
        """Make *key* expire *ttl* seconds from now."""

    @abstractmethod
    async def ttl(self, key: str) -> int:
        """Return the number of seconds until *key* expires."""

# ----------------------------
# Redis Backend (async, atomic)
# ----------------------------
class AsyncRedisBackend(RateLimitBackend):
    """Redis-backed counters, shareable across processes and hosts.

    ``incr`` runs a server-side Lua script so the INCR and the EXPIRE that
    follows a counter's creation execute as one atomic step: a counter gets
    its TTL on the first hit (when a positive expiry is supplied) and later
    hits leave that TTL alone.
    """

    LUA_SCRIPT = """
    local current = redis.call('INCR', KEYS[1])
    if current == 1 and tonumber(ARGV[1]) > 0 then
        redis.call('EXPIRE', KEYS[1], ARGV[1])
    end
    return current
    """

    def __init__(self, redis_client: redis_async.Redis):
        self.redis = redis_client
        # Registered once up front; every incr() goes through this object.
        self.incr_script = redis_client.register_script(self.LUA_SCRIPT)

    async def incr(self, key: str, expire: Optional[int] = None) -> int:
        """Atomically increment *key*; set its TTL to *expire* on the first hit.

        A missing/zero *expire* is sent as 0, which the script treats as
        "do not set a TTL".
        """
        ttl_arg = expire if expire else 0
        result = await self.incr_script(keys=[key], args=[ttl_arg])
        return int(result)

    async def expire(self, key: str, ttl: int):
        """Reset the TTL of *key* to *ttl* seconds (Redis ``EXPIRE``)."""
        client = self.redis
        await client.expire(key, ttl)

    async def ttl(self, key: str) -> int:
        """Seconds until *key* expires; Redis' negative sentinels map to 0.

        Redis ``TTL`` returns -2 for a missing key and -1 for a key without
        an expiry; both are reported as 0 here.
        """
        remaining = await self.redis.ttl(key)
        return max(remaining, 0)

# ----------------------------
# In-Memory Backend (single instance, high QPS friendly)
# ----------------------------
class InMemoryBackend(RateLimitBackend):
    """Process-local backend that keeps counters in a dict.

    Counters are not shared between processes/workers, so this backend only
    suits single-process deployments.

    Window semantics follow the Redis Lua script in ``AsyncRedisBackend``:
    ``incr`` sets a counter's expiry only when it creates the counter, and
    ``expire`` only touches keys that exist and have not yet expired (Redis
    ``EXPIRE`` is a no-op on a missing key).

    Args:
        cleanup_interval: Seconds between sweeps that delete expired entries.
    """

    def __init__(self, cleanup_interval: int = 60):
        # key -> (count, absolute expiry timestamp from time(); 0 = never expires)
        self.store: Dict[str, tuple[int, float]] = {}
        self.locks: Dict[str, asyncio.Lock] = {}  # Per-key locks
        self.cleanup_interval = cleanup_interval
        # Strong reference to the sweeper task: the event loop only keeps a
        # weak reference, so an unreferenced task may be garbage-collected.
        self._cleanup_handle: Optional[asyncio.Task] = None
        # asyncio.create_task() raises RuntimeError when no loop is running
        # (e.g. the backend is built at import time, before the server starts),
        # so start the sweeper now if possible and otherwise on first use.
        self._ensure_cleanup_task()

    def _ensure_cleanup_task(self) -> None:
        """Start the background sweeper unless it is already running."""
        handle = self._cleanup_handle
        if handle is not None and not handle.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return  # No running loop yet; the next async call retries.
        self._cleanup_handle = loop.create_task(self._cleanup_task())

    async def _get_lock(self, key: str) -> asyncio.Lock:
        """Return the lock for *key*, creating it on first use."""
        if key not in self.locks:
            self.locks[key] = asyncio.Lock()
        return self.locks[key]

    async def incr(self, key: str, expire: Optional[int] = None) -> int:
        """Increment *key* and return the new count.

        Args:
            key: Counter name.
            expire: TTL in seconds, applied only when this call creates the
                counter (fresh key, or the previous window has elapsed).
                ``None`` or a non-positive value leaves it without expiry.
        """
        self._ensure_cleanup_task()
        now = time()
        lock = await self._get_lock(key)
        async with lock:
            value, expire_ts = self.store.get(key, (0, 0))
            if expire_ts and now >= expire_ts:
                # Previous window is over: start counting from scratch.
                value, expire_ts = 0, 0
            value += 1
            # Arm the TTL only on the first hit, like the Redis script
            # (EXPIRE only when current == 1 and ARGV[1] > 0). Re-arming it on
            # every hit meant a client that kept sending requests never saw
            # its window close, and it overwrote backoff TTLs set by expire().
            if value == 1 and expire is not None and expire > 0:
                expire_ts = now + expire
            self.store[key] = (value, expire_ts)
            return value

    async def expire(self, key: str, ttl: int):
        """Make an existing, unexpired *key* expire *ttl* seconds from now.

        Missing or already-expired keys are left untouched, so this never
        creates a counter or revives an expired one.
        """
        self._ensure_cleanup_task()
        if key not in self.store:
            return  # Avoid allocating a lock for a key we will not touch.
        now = time()
        lock = await self._get_lock(key)
        async with lock:
            entry = self.store.get(key)
            if entry is None:
                return
            value, expire_ts = entry
            if expire_ts and now >= expire_ts:
                return  # Already expired; the next incr() starts a new window.
            self.store[key] = (value, now + ttl)

    async def ttl(self, key: str) -> int:
        """Whole seconds until *key* expires; 0 if missing, expired or without expiry."""
        entry = self.store.get(key)
        if not entry:
            return 0
        _, expire_ts = entry
        if not expire_ts:
            return 0
        return max(int(expire_ts - time()), 0)

    async def _cleanup_task(self):
        """Delete expired counters (and their locks) every ``cleanup_interval`` seconds."""
        while True:
            await asyncio.sleep(self.cleanup_interval)
            now = time()
            # Same expiry test as incr(): an entry is dead once now >= expiry.
            expired = [k for k, (_, ts) in self.store.items() if ts and ts <= now]
            for k in expired:
                del self.store[k]
                self.locks.pop(k, None)

# ----------------------------
# Rate Limiter Middleware
# ----------------------------
class RateLimiterMiddleware(BaseHTTPMiddleware):
    """Prefix-based request counter with exponential backoff on violations.

    ``rules`` maps a path prefix to a config dict with optional keys:

    * ``limit`` (default 1): requests allowed per window.
    * ``period`` (default 1): window length in seconds.
    * ``key_func`` (default: client host): ``Request -> str`` identifying
      who is being limited.

    Every rule whose prefix matches the request path is counted, longest
    prefix first. Matching is plain ``str.startswith``, so a rule for
    ``/api`` also matches ``/apiary``.

    When a rule's count exceeds its limit, its window is reset to
    ``period * 2 ** max(0, violations - backoff_threshold)`` seconds, capped
    at ``max_backoff``, and the request is rejected with HTTP 429.
    """

    # The backoff multiplier is capped by max_backoff anyway; clamping the
    # exponent keeps sustained abuse from building enormous ints (or raising
    # OverflowError when period is a float).
    _MAX_BACKOFF_EXPONENT = 64

    def __init__(
        self,
        app: FastAPI,
        rules: dict,
        backend: RateLimitBackend,
        backoff_threshold: int = 3,
        max_backoff: int = 60
    ):
        super().__init__(app)
        self.backend = backend
        self.backoff_threshold = backoff_threshold
        self.max_backoff = max_backoff
        # Sort prefixes by length descending for specific routes first
        self.sorted_rules = sorted(rules.items(), key=lambda x: len(x[0]), reverse=True)

    @staticmethod
    def _default_key(request: Request) -> str:
        """Identify the caller by client host; ``request.client`` may be None."""
        client = request.client
        return client.host if client is not None else "unknown"

    async def _apply_backoff(self, key: str, period) -> int:
        """Record a violation for *key*, extend its window, return the wait in seconds."""
        violations = await self.backend.incr(f"{key}:violations", expire=period * 5)
        exponent = min(
            max(0, violations - self.backoff_threshold),
            self._MAX_BACKOFF_EXPONENT,
        )
        new_ttl = min(period * 2 ** exponent, self.max_backoff)
        # Round up and never go below 1s: truncating a sub-second TTL to 0
        # makes Redis EXPIRE delete the counter (erasing the backoff) and
        # would tell the client "Retry-After: 0".
        wait = max(1, math.ceil(new_ttl))
        await self.backend.expire(key, wait)
        return wait

    async def dispatch(self, request: Request, call_next):
        """Count the request against every matching rule; reject with 429 if any is exceeded."""
        path = request.url.path
        # Identity per distinct key_func, evaluated at most once per request.
        # (Caching a single identity applied the first rule's key_func to
        # every other matching rule.)
        identities: Dict[int, str] = {}
        exceeded_ttls = []
        limit_headers: Optional[Dict[str, str]] = None
        lowest_remaining: Optional[int] = None

        for prefix, cfg in self.sorted_rules:
            if not path.startswith(prefix):
                continue
            limit = cfg.get("limit", 1)
            period = cfg.get("period", 1)
            key_func: Callable[[Request], str] = cfg.get("key_func", self._default_key)

            func_id = id(key_func)
            if func_id not in identities:
                identities[func_id] = key_func(request)
            key = f"ratelimit:{identities[func_id]}:{prefix}"

            count = await self.backend.incr(key, expire=period)
            remaining = max(limit - count, 0)
            if count > limit:
                exceeded_ttls.append(await self._apply_backoff(key, period))

            # Report the most constrained rule (ties keep the more specific
            # prefix) so a 429 never advertises quota from a rule that still
            # has some left.
            if lowest_remaining is None or remaining < lowest_remaining:
                lowest_remaining = remaining
                limit_headers = {
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": str(remaining),
                }

        if limit_headers is None:
            # Fast path: no rules matched, just continue
            return await call_next(request)

        if exceeded_ttls:
            retry_after = max(exceeded_ttls)
            return JSONResponse(
                status_code=429,
                content={"detail": f"Too many requests. Retry in {retry_after} seconds."},
                headers={**limit_headers, "Retry-After": str(retry_after)},
            )

        response = await call_next(request)
        response.headers.update(limit_headers)
        return response
# ----------------------------