from __future__ import annotations

import base64
import contextvars
import json
from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar

from fastapi import Query, Request
from pydantic import BaseModel, Field

# Generic item type shared by Paginated, PaginationContext and the list helpers below.
T = TypeVar("T")


# ---------- Core query models ----------
class CursorParams(BaseModel):
    # Cursor-mode request values, built by make_pagination_injector from ?cursor= / ?limit=.
    cursor: Optional[str] = None  # opaque token from a previous response; None when absent
    limit: int = 50  # max items per response; bounds are validated by the injector's Query, not here


class PageParams(BaseModel):
    # Page-mode request values, built by make_pagination_injector from ?page= / ?page_size=.
    page: int = 1  # 1-based page number (offset is computed as (page - 1) * page size)
    page_size: int = 50  # items per page; bounds are validated by the injector's Query, not here


class FilterParams(BaseModel):
    # Raw filter/sort query values, passed through unparsed as strings.
    # NOTE(review): the date bounds are presumably ISO-8601 timestamps — nothing in
    # this module parses them, so confirm the expected format with route handlers.
    q: Optional[str] = None  # free-text search string (suitable for text_filter)
    sort: Optional[str] = None  # sort spec; its syntax is not interpreted in this module
    created_after: Optional[str] = None
    created_before: Optional[str] = None
    updated_after: Optional[str] = None
    updated_before: Optional[str] = None


# ---------- Envelope model ----------
class Paginated(BaseModel, Generic[T]):
    # Response envelope returned by PaginationContext.wrap when envelope mode is on.
    # Documented with comments (not a class docstring) because pydantic would expose a
    # docstring as the model's JSON-schema description and change the OpenAPI output.
    items: List[T]  # the items for this response
    # None when the caller passes no cursor (e.g. on the last page)
    next_cursor: Optional[str] = Field(None, description="Opaque cursor for next page")
    total: Optional[int] = Field(None, description="Total items (optional)")


# ---------- Cursor helpers ----------
def _encode_cursor(payload: dict) -> str:
    raw = json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8")
    return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")


# public; handy if you need to decode an incoming cursor
def decode_cursor(token: Optional[str]) -> dict:
    """Decode a cursor token produced by ``_encode_cursor``.

    Cursor tokens arrive from clients, so anything that is not a base64url-encoded
    UTF-8 JSON object is rejected with a single, predictable error type.

    Args:
        token: The opaque cursor string (trailing ``=`` padding optional), or
            ``None``/empty for "no cursor".

    Returns:
        The decoded JSON object, or ``{}`` when *token* is falsy.

    Raises:
        ValueError: if the token is not valid base64url, not UTF-8 JSON, or
            does not decode to a JSON object.
    """
    if not token:
        return {}
    # Restore the padding that _encode_cursor strips.
    padded = token + "=" * (-len(token) % 4)
    try:
        # binascii.Error, UnicodeEncodeError/UnicodeDecodeError and JSONDecodeError
        # are all ValueError subclasses.
        raw = base64.urlsafe_b64decode(padded.encode("ascii"))
        payload = json.loads(raw.decode("utf-8"))
    except ValueError as exc:
        raise ValueError("invalid pagination cursor") from exc
    if not isinstance(payload, dict):
        # e.g. a token wrapping "[1,2]" or "3" — callers rely on dict access.
        raise ValueError("invalid pagination cursor: payload is not a JSON object")
    return payload


# ---------- Context ----------
class PaginationContext(Generic[T]):
    """Per-request pagination state resolved from query parameters.

    Created by the dependency returned from ``make_pagination_injector`` (or with
    defaults by ``use_pagination``). Exposes the effective cursor, limit, page and
    offset for a route handler, plus ``wrap`` to shape the response.
    """

    # mode config
    envelope: bool  # wrap() returns a Paginated envelope instead of a bare list
    allow_cursor: bool  # cursor parameters are honoured for this route
    allow_page: bool  # page/page_size parameters are honoured for this route

    # values
    cursor_params: CursorParams | None
    page_params: PageParams | None
    filters: FilterParams | None
    # an explicit ?limit= used as a page-size alias in page mode; None when not applicable
    limit_override: int | None

    def __init__(
        self,
        *,
        envelope: bool,
        allow_cursor: bool,
        allow_page: bool,
        cursor_params: CursorParams | None,
        page_params: PageParams | None,
        filters: FilterParams | None,
        limit_override: int | None = None,
    ):
        self.envelope = envelope
        self.allow_cursor = allow_cursor
        self.allow_page = allow_page
        self.cursor_params = cursor_params
        self.page_params = page_params
        self.filters = filters
        self.limit_override = limit_override

    @property
    def cursor(self) -> Optional[str]:
        """Incoming cursor token, or None when cursor mode is disabled or no cursor was sent."""
        if not self.allow_cursor or self.cursor_params is None:
            return None
        return self.cursor_params.cursor

    def _page_size_effective(self) -> Optional[int]:
        """Page size used in page mode: the ``limit`` alias if given, else ``page_size``."""
        if self.limit_override is not None:
            return self.limit_override
        return self.page_size

    @property
    def limit(self) -> int:
        """Maximum number of items to return for this request.

        Resolution order:
          1. a cursor is present -> the cursor-mode ``limit``;
          2. page mode enabled   -> the ``limit`` alias if given, else ``page_size``;
          3. cursor mode only (first request, no cursor yet) -> the cursor-mode ``limit``;
          4. otherwise 50.
        """
        cp = self.cursor_params
        if self.allow_cursor and cp is not None and cp.cursor is not None:
            return cp.limit
        if self.allow_page and self.page_params is not None:
            return self._page_size_effective()
        if self.allow_cursor and cp is not None:
            # Cursor-only route, first request: honour ?limit= (and the injector's
            # default_limit) instead of silently falling back to the hard-coded 50.
            return cp.limit
        return 50

    @property
    def page(self) -> Optional[int]:
        """1-based page number in page mode, else None."""
        return self.page_params.page if (self.allow_page and self.page_params) else None

    @property
    def page_size(self) -> Optional[int]:
        """Raw ``page_size`` query value in page mode, else None (ignores the ``limit`` alias)."""
        return self.page_params.page_size if (self.allow_page and self.page_params) else None

    @property
    def offset(self) -> int:
        """Zero-based start index for page mode; 0 when a cursor is in use.

        Uses the same effective page size as ``limit``, so ``?page=2&limit=10``
        starts at item 10 rather than at ``(page - 1) * page_size``.
        """
        if self.cursor is None and self.allow_page and self.page and self.page_size:
            return (self.page - 1) * self._page_size_effective()
        return 0

    def wrap(
        self, items: list[T], *, next_cursor: Optional[str] = None, total: Optional[int] = None
    ) -> Paginated[T] | list[T]:
        """Return *items* as a ``Paginated`` envelope in envelope mode, else the bare list."""
        if self.envelope:
            return Paginated[T](items=items, next_cursor=next_cursor, total=total)
        return items

    # convenience: derive a naive next_cursor from the last item
    def next_cursor_from_last(
        self, items: Sequence[T], *, key: Callable[[T], str | int]
    ) -> Optional[str]:
        """Encode ``{"after": key(items[-1])}`` as a cursor; None for empty *items*.

        Produces the same cursor format that ``cursor_window`` emits and consumes.
        """
        if not items:
            return None
        last_key = key(items[-1])
        return _encode_cursor({"after": last_key})


# Holds the PaginationContext installed by the dependency built in
# make_pagination_injector. It defaults to None when no injector has run in the current
# context, and use_pagination then falls back to default settings.
_pagination_ctx: contextvars.ContextVar[Optional[PaginationContext]] = contextvars.ContextVar(
    "pagination_ctx", default=None
)


def use_pagination() -> PaginationContext:
    """Return the pagination context for the current request.

    Reads the context variable populated by the dependency from
    ``make_pagination_injector``. If none was installed (e.g. a route forgot to add
    the injector), a fresh context with permissive defaults is returned. That
    fallback is not stored back into the context variable.
    """
    installed = _pagination_ctx.get()
    if installed is not None:
        return installed

    # Fallback: both modes enabled, bare-list responses, default parameter values.
    return PaginationContext(
        envelope=False,
        allow_cursor=True,
        allow_page=True,
        cursor_params=CursorParams(),
        page_params=PageParams(),
        filters=FilterParams(),
    )


def text_filter(items: Iterable[T], q: Optional[str], *getters: Callable[[T], str]) -> list[T]:
    """Simple contains filter across one or more string fields.

    Keeps each item for which at least one getter returns a string containing *q*.
    Matching is caseless in the Unicode sense (``str.casefold``), so e.g.
    "STRASSE" matches "Straße". ``lower()`` would miss this.

    Args:
        items: Items to filter (iterated once).
        q: Search text; when falsy, all items are returned unchanged.
        *getters: Field accessors. A getter that raises, or returns a non-string
            (including None), simply doesn't match for that item.

    Returns:
        A new list of matching items, in their original order, each at most once.
    """
    if not q:
        return list(items)
    needle = q.casefold()

    def _matches(item: T) -> bool:
        # True if any getter yields a string containing the needle.
        for getter in getters:
            try:
                value = getter(item)
            except Exception:
                # Deliberate best-effort: a field that can't be read for this item
                # (missing key/attribute, etc.) just doesn't match.
                continue
            if isinstance(value, str) and needle in value.casefold():
                return True
        return False

    return [item for item in items if _matches(item)]


def sort_by(
    items: Iterable[T],
    *,
    key: Callable[[T], Any],
    desc: bool = False,
) -> list[T]:
    """Return a new list of *items* ordered by *key*.

    The sort is stable in both directions: items with equal keys keep their
    original relative order even when ``desc`` is True.
    """
    ordered = list(items)
    ordered.sort(key=key, reverse=desc)
    return ordered


def cursor_window(
    items: Sequence[T],
    *,
    cursor: Optional[str],
    limit: int,
    key: Callable[[T], Any],
    descending: bool,
    offset: int = 0,
) -> tuple[Sequence[T], Optional[str]]:
    """Slice one page out of *items* and compute the cursor for the next page.

    *items* must already be filtered and sorted by *key* in the requested
    direction. With a cursor, the window starts at the first item whose key is
    strictly after the cursor's ``"after"`` value (greater when ascending, smaller
    when descending). Without one, it starts at *offset*.

    Args:
        items: Filtered, sorted sequence (must support ``len`` and slicing).
        cursor: Token from a previous call, or None/empty for the first page.
        limit: Maximum items in the window; must be >= 1.
        key: Returns the ordering key stored in the cursor.
        descending: Whether *items* are sorted by descending key.
        offset: Start index used when no cursor is given.

    Returns:
        ``(window, next_cursor)``. ``next_cursor`` is None when no items follow the window.

    Raises:
        ValueError: if *limit* < 1, or the cursor is malformed or lacks ``"after"``.
    """
    if limit < 1:
        # Guard explicitly: limit 0 would otherwise index into an empty window below.
        raise ValueError(f"limit must be >= 1, got {limit}")

    if cursor:
        after = decode_cursor(cursor).get("after")
        if after is None:
            raise ValueError("invalid pagination cursor: missing 'after' position")
        # Lazy scan: stop at the first item past the cursor instead of keying every item.
        if descending:
            start_index = next((i for i, x in enumerate(items) if key(x) < after), len(items))
        else:
            start_index = next((i for i, x in enumerate(items) if key(x) > after), len(items))
    else:
        start_index = offset

    # Take one extra item to learn whether another page exists.
    slice_ = items[start_index : start_index + limit + 1]
    has_more = len(slice_) > limit
    window = slice_[:limit]

    next_cur = None
    if has_more:
        # limit >= 1 and has_more imply window is non-empty.
        next_cur = _encode_cursor({"after": key(window[-1])})

    return window, next_cur


# ---------- Dependency factory (used by the router decorator) ----------
def make_pagination_injector(
    *,
    envelope: bool,
    allow_cursor: bool,
    allow_page: bool,
    default_limit: int = 50,
    max_limit: int = 200,
):
    """Build a FastAPI dependency that installs a ``PaginationContext``.

    The returned coroutine function declares the pagination and filter query
    parameters. It bundles them into ``CursorParams`` / ``PageParams`` /
    ``FilterParams`` according to the enabled modes and stores the resulting
    context in ``_pagination_ctx``, where ``use_pagination`` reads it.

    Args:
        envelope: Whether ``PaginationContext.wrap`` returns a ``Paginated`` envelope.
        allow_cursor: Build ``CursorParams`` from ``cursor`` / ``limit``.
        allow_page: Build ``PageParams`` from ``page`` / ``page_size``.
        default_limit: Default for both ``limit`` and ``page_size``.
        max_limit: Upper bound FastAPI enforces on ``limit`` and ``page_size``.
    """

    async def _inject(
        request: Request,
        cursor: str | None = Query(None),
        limit: int = Query(default_limit, ge=1, le=max_limit),
        page: int = Query(1, ge=1),
        page_size: int = Query(default_limit, ge=1, le=max_limit),
        q: str | None = Query(None),
        sort: str | None = Query(None),
        created_after: str | None = Query(None),
        created_before: str | None = Query(None),
        updated_after: str | None = Query(None),
        updated_before: str | None = Query(None),
    ):
        raw_params = request.query_params
        # A bare ?limit= (no page_size, no cursor) acts as a page-size alias;
        # PaginationContext.limit consults it via limit_override.
        limit_is_alias = (
            "limit" in raw_params and "page_size" not in raw_params and cursor is None
        )

        filter_values = FilterParams(
            q=q,
            sort=sort,
            created_after=created_after,
            created_before=created_before,
            updated_after=updated_after,
            updated_before=updated_before,
        )
        context = PaginationContext(
            envelope=envelope,
            allow_cursor=allow_cursor,
            allow_page=allow_page,
            cursor_params=CursorParams(cursor=cursor, limit=limit) if allow_cursor else None,
            page_params=PageParams(page=page, page_size=page_size) if allow_page else None,
            filters=filter_values,
            limit_override=limit if limit_is_alias else None,
        )
        _pagination_ctx.set(context)

    return _inject
