import json
import logging
from abc import abstractmethod
from datetime import datetime, timezone
from typing import TYPE_CHECKING, TypeVar, Generic, Self, Any, Callable, Iterable

import requests
from django.db import transaction
from django.db.models import QuerySet
from django.utils.timezone import make_aware
from pydantic import BaseModel, ValidationError
from redis.exceptions import LockError
from rest_framework.serializers import SerializerMetaclass

from wise.station.heap import ReplicatedHashHeap
from wise.utils.models import BaseModel as DjangoBaseModel
from wise.utils.redis import get_redis_client

# Return type of ModelUpdater.update / BindingUpdater.update: a subclass of the
# project's Django BaseModel.
ModelType = TypeVar("ModelType", bound=DjangoBaseModel)
# Type of the parent object passed to BindingUpdater.update / update_bindings.
BoundedType = TypeVar("BoundedType", bound=DjangoBaseModel)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Only imported for static type checkers. The `type: ignore` suggests the
    # name presumably comes from type stubs rather than being importable at
    # runtime -- confirm against the stubs in use.
    from django.db.models.manager import RelatedManager  # type: ignore

    RelatedManagerType = RelatedManager[ModelType]
else:
    # At runtime the alias is plain Any, so annotations using it never fail.
    RelatedManagerType = Any


class SkipUpdate(Exception):
    """Signal from an updater's ``update()`` that this message should be skipped.

    UpdaterHandler catches it, logs "update skipped", does not run the
    registered callbacks, and still refreshes the message's heap node (if any).
    """


class Updater:
    """Minimal interface for anything that can perform an ``update()``.

    The class does not use ABCMeta, so ``@abstractmethod`` only marks the
    method (it sets ``__isabstractmethod__``). It does not prevent direct
    instantiation of ``Updater`` or of subclasses that omit ``update``.
    """

    @abstractmethod
    def update(self):
        """Perform the update. Subclasses override this; the base returns None."""


class IterableUpdater(Updater):
    """Composite updater that runs ``update()`` on each member of a collection."""

    def __init__(self, col: Iterable[Updater]):
        # Kept as given (not copied). A one-shot iterator is therefore
        # exhausted after the first update() call.
        self.col = col

    def update(self):
        """Update every member in iteration order; an exception stops the rest."""
        for child in self.col:
            child.update()


class QuerysetUpdater(Updater):
    """Composite updater over a Django queryset whose rows expose ``update()``."""

    def __init__(self, qs: QuerySet):
        # Presumably a queryset of models implementing update() -- the only
        # requirement visible here is that each row has an update() method.
        self.qs = qs

    def update(self):
        """Call ``update()`` on every row of a fresh clone of the queryset."""
        # .all() clones the queryset, so every call re-queries the database
        # instead of reusing a previously evaluated result cache.
        rows = self.qs.all()
        for row in rows:
            row.update()


class CacheUpdater(Updater):
    """Keep JSON-serialized copies of queryset instances cached in Redis.

    Two Redis keys are used for each instance:

    * ``{cache_key_prefix}:{instance.key}:value`` holds ``json.dumps`` of
      ``serializer(instance).data``.
    * ``{cache_key_prefix}:{instance.key}:timestamp`` holds the POSIX
      timestamp recorded when that value was written.

    ``update()`` refreshes stale entries for the whole queryset, one run at a
    time per prefix, guarded by a Redis lock. ``get_value()`` reads a single
    entry and refreshes it on demand. Staleness is decided by the optional
    hooks ``periodic_should_update_func`` and ``get_value_should_update_func``,
    called as ``func(instance, last_update)``. When a hook is None,
    ``default_should_update`` is used.
    """

    # Class-level defaults so the *_should_update methods can always read
    # these attributes (previously an AttributeError when no hook was passed).
    # __init__ only overrides them when a hook is given, so hooks defined as
    # methods on subclasses keep working.
    periodic_should_update_func: Callable | None = None
    get_value_should_update_func: Callable | None = None

    def __init__(
        self,
        queryset: QuerySet,
        serializer: SerializerMetaclass | Callable,
        cache_key_prefix: str,
        periodic_should_update_func: Callable | None = None,
        get_value_should_update_func: Callable | None = None,
        lock_timeout_seconds: int = 300,
    ):
        self.queryset = queryset
        self.serializer = serializer
        self.cache_key_prefix = cache_key_prefix
        if periodic_should_update_func is not None:
            self.periodic_should_update_func = periodic_should_update_func
        if get_value_should_update_func is not None:
            self.get_value_should_update_func = get_value_should_update_func
        self.lock_timeout_seconds = lock_timeout_seconds
        self.batch_size = 1000
        self.redis = get_redis_client()
        # Refresh time recorded for values written by a periodic run. It is
        # reset at the start of every _update().
        self.now = datetime.now(timezone.utc)

    def update(self):
        """Refresh stale cached values unless another run holds the lock.

        The lock is keyed by get_cache_key() and acquired non-blocking. A
        LockError (lock already held, or lost because a run outlived
        lock_timeout_seconds) ends the run without raising.
        """
        try:
            with self.redis.lock(
                self.get_cache_key(), self.lock_timeout_seconds, blocking=False
            ):
                self._update()
        except LockError:
            # Deliberately best-effort, but no longer silent.
            logger.info(
                "cache update skipped: lock not acquired or lost",
                extra={"cache_key_prefix": self.cache_key_prefix},
            )

    def _update(self):
        """Walk the queryset in batches and refresh instances that need it."""
        # Batch start time, recorded as the refresh time of every value written
        # by this run. It is taken before any row is read, so a row modified
        # while the run is in progress gets updated_at > this value and is
        # refreshed again later.
        self.now = datetime.now(timezone.utc)
        # Fresh clone: avoids reusing a result cache on self.queryset.
        queryset = self.queryset.all()
        if not queryset.ordered:
            # OFFSET/LIMIT slicing of an unordered queryset can return
            # overlapping or missing rows between batches.
            queryset = queryset.order_by("pk")
        count = queryset.count()
        for start_index in range(0, count, self.batch_size):
            instances = list(queryset[start_index : start_index + self.batch_size])
            if not instances:
                # Rows were deleted after count(); MGET with no keys is an error.
                break
            last_update_timestamps = self.redis.mget(
                [self.get_last_update_timestamp_redis_key(i) for i in instances]
            )
            for instance, raw_timestamp in zip(instances, last_update_timestamps):
                try:
                    last_update = self._parse_timestamp(raw_timestamp)
                    if self.periodic_should_update(instance, last_update):
                        self.force_update(instance, now=self.now)
                except Exception:
                    # Best-effort: one failing instance must not abort the run.
                    logger.exception(f"could not cache {instance}")

    def periodic_should_update(
        self, instance, last_update: datetime | None = None
    ) -> bool:
        """Return whether a periodic run should refresh ``instance``.

        ``last_update`` is the refresh time already read by the caller. When it
        is None it is re-read from Redis, and may still be None.
        """
        if last_update is None:
            last_update = self.get_last_update(instance)
        if self.periodic_should_update_func is None:
            return self.default_should_update(instance, last_update)
        return self.periodic_should_update_func(instance, last_update)

    def get_value_should_update(self, instance) -> bool:
        """Return whether get_value() should refresh ``instance`` before reading."""
        last_update = self.get_last_update(instance)
        if self.get_value_should_update_func is None:
            return self.default_should_update(instance, last_update)
        return self.get_value_should_update_func(instance, last_update)

    def default_should_update(self, instance, last_update: datetime | None) -> bool:
        """Refresh when nothing is cached or the cache predates ``instance.updated_at``."""
        return last_update is None or last_update < instance.updated_at

    def get_last_update(self, instance) -> datetime | None:
        """Return when ``instance`` was last cached (aware, UTC), or None."""
        raw_timestamp = self.redis.get(
            self.get_last_update_timestamp_redis_key(instance)
        )
        return self._parse_timestamp(raw_timestamp)

    @staticmethod
    def _parse_timestamp(raw_timestamp) -> datetime | None:
        """Convert a raw Redis timestamp (bytes/str/number) to an aware UTC datetime."""
        if raw_timestamp is None:
            return None
        # An explicit UTC tz avoids the naive-local -> make_aware round trip,
        # which relies on the OS and Django time zones matching and is
        # ambiguous around DST transitions.
        return datetime.fromtimestamp(float(raw_timestamp), tz=timezone.utc)

    def get_value(self, instance):
        """Return the JSON-decoded cached value for ``instance``.

        The value is re-serialized first when get_value_should_update() says so
        or when nothing is cached yet.
        """
        if self.get_value_should_update(instance):
            value = self.force_update(instance)
        else:
            value = self.redis.get(self.get_value_redis_key(instance))
            if value is None:
                value = self.force_update(instance)
        return json.loads(value)

    def force_update(self, instance, now: datetime | None = None) -> str:
        """Serialize ``instance``, store value and timestamp in Redis, return the JSON.

        ``now`` is the refresh time to record. It defaults to the current time,
        taken before serialization, rather than the possibly stale
        ``self.now``. A stale ``self.now`` made instances edited after it look
        outdated on every get_value() call, so they were re-serialized forever.
        """
        if now is None:
            now = datetime.now(timezone.utc)
        value = json.dumps(self.serializer(instance).data)
        # Value first, then timestamp: if the second write fails, the old
        # timestamp makes the entry look stale and it is refreshed again.
        self.redis.set(self.get_value_redis_key(instance), value)
        self.redis.set(
            self.get_last_update_timestamp_redis_key(instance), now.timestamp()
        )
        return value

    def get_last_update_timestamp_redis_key(self, instance) -> str:
        """Redis key holding the refresh timestamp of ``instance``."""
        return f"{self.cache_key_prefix}:{instance.key}:timestamp"

    def get_value_redis_key(self, instance) -> str:
        """Redis key holding the serialized JSON value of ``instance``."""
        return f"{self.cache_key_prefix}:{instance.key}:value"

    def get_cache_key(self) -> str:
        """Redis key of the lock that serializes periodic runs for this prefix."""
        return f"{self.cache_key_prefix}:run_lock"


class ModelUpdater(BaseModel, Generic[ModelType], Updater):
    """Pydantic-validated Updater whose ``update()`` returns a ModelType instance.

    UpdaterHandler builds instances from message bodies (``cls(**body)``), so
    the pydantic fields define the accepted payload.
    """

    @abstractmethod
    def update(self) -> ModelType:
        """Apply this payload and return the resulting model instance."""


class BindingUpdater(BaseModel, Generic[ModelType, BoundedType]):
    """Pydantic model describing one object bound to a parent (``bounded``).

    Subclasses implement ``update(bounded)``. ``update_bindings`` syncs a whole
    set of such objects against a related manager.
    """

    @abstractmethod
    def update(self, bounded: BoundedType) -> ModelType:
        """Create or update the described object for ``bounded`` and return it."""

    @classmethod
    def update_bindings(
        cls,
        *,
        updaters: list[Self],
        bounded: BoundedType,
        bindings: RelatedManagerType,  # type: ignore
    ) -> list[ModelType]:
        """Apply every updater and delete bindings no longer described.

        Runs in one database transaction. Bindings whose ``key`` is not among
        the returned instances' keys are deleted. An empty ``updaters`` list
        therefore deletes every binding, because ``exclude(key__in=[])``
        matches all rows. Returns the instances in ``updaters`` order.
        """
        with transaction.atomic():
            updated = [updater.update(bounded) for updater in updaters]
            kept_keys = [obj.key for obj in updated]
            stale = bindings.exclude(key__in=kept_keys)  # type: ignore
            stale.delete()
        return updated


class UpdaterHandler:
    """Route incoming update messages to registered Updater classes.

    A message is either an envelope
    ``{"object_name": ..., "body": {...}, "heap_node": {...}}`` (``heap_node``
    optional) or, when ``default_object_name`` is set, a bare body routed to
    that name. The body is passed as keyword arguments to the registered
    updater class. Its ``update()`` is run, and each registered callback
    receives the returned instance.
    """

    # Timeout in seconds for fetching a heap node from a remote endpoint.
    # Without one, requests.get may block indefinitely.
    heap_fetch_timeout_seconds: float = 10

    def __init__(self, default_object_name: str | None = None) -> None:
        # object_name -> (updater class, callbacks run on its update() result)
        self.handlers: dict[
            str, tuple[type[Updater], tuple[Callable[[Any], None], ...]]
        ] = {}
        self.default_object_name = default_object_name

    def add(
        self,
        object_name: str,
        updater_class: type[Updater],
        *callbacks: Callable[[Any], None],
    ) -> Self:
        """Register ``updater_class`` and ``callbacks`` for ``object_name``.

        Returns self so registrations can be chained.

        Raises:
            ValueError: if ``object_name`` is already registered. An explicit
                check is used because ``assert`` is stripped under ``python -O``.
        """
        if object_name in self.handlers:
            raise ValueError(f"an updater is already registered for {object_name!r}")
        self.handlers[object_name] = updater_class, callbacks
        return self

    def get_heaps(self) -> Iterable[ReplicatedHashHeap]:
        """Yield a heap for every registered updater class that defines one."""
        for name, (updater_class, _) in self.handlers.items():
            heap = self.get_heap(name, updater_class)
            if heap:
                yield heap

    def get_heap(
        self, name: str, updater_class: type[Updater]
    ) -> ReplicatedHashHeap | None:
        """Build the ReplicatedHashHeap for ``updater_class``, or None.

        ``heap_endpoint`` on the class may be a base URL or a callable returning
        one. Classes without it (or with a falsy value) have no heap.
        """
        if f := getattr(updater_class, "heap_endpoint", None):
            if callable(f):
                endpoint = f()
            else:
                endpoint = f
            return ReplicatedHashHeap(
                name=name,
                get_node=lambda index: self._fetch_heap_node(name, endpoint, index),
            )
        return None

    def _fetch_heap_node(self, name: str, endpoint: str, index: int) -> dict:
        """Fetch node ``index`` of heap ``name`` from ``endpoint`` and apply it.

        The response is processed like an incoming message, without a further
        heap update and raising on failure. Its ``heap_node`` entry is then
        returned.
        """
        resp = requests.get(
            f"{endpoint}/station/heap/node",
            params={"name": name, "index": str(index)},
            timeout=self.heap_fetch_timeout_seconds,
        )
        resp.raise_for_status()
        d = resp.json()
        self(d, update_heap=False, raise_exception=True)
        return d["heap_node"]

    def __call__(
        self, value, update_heap: bool = True, raise_exception: bool = False
    ) -> None:
        """Process one message: validate, update, run callbacks, refresh heap.

        ``raise_exception`` re-raises validation and callback errors instead of
        logging them. ``update_heap`` False suppresses the heap-node refresh.
        """
        heap_node = None

        if "object_name" in value:
            name = value["object_name"]
            data = value["body"]
            heap_node = value.get("heap_node")

        elif self.default_object_name:
            name = self.default_object_name
            data = value
        else:
            logger.info("unknown object received")
            return

        if name not in self.handlers:
            logger.info("unknown object received", extra={"object_name": name})
            return

        updater_class, callbacks = self.handlers[name]

        def do_heap_update():
            # Refresh the replicated heap node referenced by the message, if any.
            if not heap_node or not update_heap:
                return
            heap = self.get_heap(heap_node["name"], updater_class)
            if heap is None:
                # updater_class has no heap_endpoint, so there is nothing to
                # refresh. Calling update_node on None would crash after the
                # update has already been applied.
                logger.warning(
                    "heap node received for updater without heap_endpoint",
                    extra={"object_name": name},
                )
                return
            heap.update_node(int(heap_node["index"]), timeout=10)  # TODO: timeout

        try:
            updater = updater_class(**data)
        except ValidationError:
            if raise_exception:
                raise
            logger.exception("invalid object received")
            return

        try:
            instance = updater.update()
        except SkipUpdate:
            logger.info("update skipped")
            do_heap_update()
            return

        if instance is None:
            do_heap_update()
            return

        if isinstance(instance, DjangoBaseModel):
            logger.info(
                "object persisted",
                extra={
                    "model": instance.__class__.__name__,
                    "key": instance.key,
                    "created_at": instance.created_at,
                    "updated_at": instance.updated_at,
                },
            )
        else:
            logger.info("object persisted", extra={"object_name": name})

        for callback in callbacks:
            try:
                callback(instance)
            except Exception:
                if raise_exception:
                    raise
                # Best-effort: one failing callback must not block the others.
                logger.exception("callback failed")

        do_heap_update()


class UpdaterSet:
    """Named registry of updaters that can be run individually or all at once."""

    def __init__(self, updaters: dict[str, Updater]) -> None:
        # The caller's dict is stored as-is (not copied); add() mutates it.
        self.updaters = updaters

    def add(self, name: str, updater: Updater) -> None:
        """Register ``updater`` under ``name``, replacing any previous entry."""
        self.updaters[name] = updater

    def update(self, name: str | None = None) -> None:
        """Run the updater called ``name``, or every updater when ``name`` is falsy.

        An unknown truthy ``name`` raises KeyError before anything runs.
        """
        targets = [self.updaters[name]] if name else self.updaters.values()
        for target in targets:
            target.update()
