from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Union

import fsspec
from dateutil.parser import isoparse  # type: ignore
from lamin_logger import logger

from ._settings_instance import InstanceSettings
from .upath import UPath, infer_filesystem

# Age in seconds after which a non-zero lock entry (ticket or entering flag)
# of a user is treated as stale and reset by Locker.__init__.
EXPIRATION_TIME = 60 * 60  # one hour

# Number of busy-wait iterations on a competing user after which a
# "competing for the lock" message is logged.
MAX_MSG_COUNTER = 100


class empty_locker:
    """Locker that never touches storage and always reports holding the lock.

    Mirrors the ``lock`` / ``unlock`` / ``has_lock`` attributes of ``Locker``,
    but both methods are no-ops and ``has_lock`` is constantly ``True``.
    """

    # class-level constant: this locker always "holds" the lock
    has_lock = True

    @classmethod
    def lock(cls):
        """Do nothing; there is nothing to acquire."""
        return None

    @classmethod
    def unlock(cls):
        """Do nothing; there is nothing to release."""
        return None


class Locker:
    """Lock shared through small files on (cloud) storage.

    The protocol is modelled on Lamport's bakery algorithm. Under
    ``<storage_root>/exclusion/<name>`` every participating user owns an
    ``entering/<user>`` flag (``b"0"``/``b"1"``) and a ``numbers/<user>``
    ticket. The ``priorities`` file holds the ``*``-separated list of users;
    a user's position in it breaks ties between equal tickets. Unlike the
    textbook algorithm, a user that loses the comparison does not wait for
    the lock: it gives up and records who holds the lock.

    Args:
        user_id: Identifier of the current user, used in the lock file names.
        storage_root: Root under which the ``exclusion`` folder lives.
        name: Name of the locked resource; selects the ``exclusion`` subfolder.
    """

    def __init__(self, user_id: str, storage_root: Union[UPath, Path], name: str):
        logger.debug(f"Init cloud sqlite locker: {user_id}, {storage_root}, {name}.")

        # busy-wait iterations spent on competing users; see _msg_on_counter
        self._counter = 0

        self.user = user_id
        self.name = name

        self.root = storage_root
        self.fs, _ = infer_filesystem(storage_root)

        exclusion_path = storage_root / f"exclusion/{name}"
        self.mapper = fsspec.FSMap(str(exclusion_path), self.fs, create=True)

        priorities_path = str(exclusion_path / "priorities")
        if self.fs.exists(priorities_path):
            self.users = self.mapper["priorities"].decode().split("*")

            if self.user not in self.users:
                self.priority = len(self.users)
                self.users.append(self.user)
                # potential problem here if 2 users join at the same time
                # can be avoided by using separate files for each user
                # and giving priority by timestamp
                # here writing the whole list back because gcs
                # does not support the append mode
                self.mapper["priorities"] = "*".join(self.users).encode()
            else:
                self.priority = self.users.index(self.user)
        else:
            self.mapper["priorities"] = self.user.encode()
            self.users = [self.user]
            self.priority = 0

        # register this user as idle: no ticket and not choosing one
        self.mapper[f"numbers/{self.user}"] = b"0"
        self.mapper[f"entering/{self.user}"] = b"0"

        # clean up failures: reset non-zero entries whose files were last
        # modified more than EXPIRATION_TIME seconds ago
        for user in self.users:
            for endpoint in ("numbers", "entering"):
                user_endpoint = f"{endpoint}/{user}"
                user_path = str(exclusion_path / user_endpoint)
                if not self.fs.exists(user_path):
                    continue
                if self.mapper[user_endpoint] == b"0":
                    continue
                mtime = self.modified(user_path)
                if mtime is None:
                    # modification time unavailable, so staleness can't be
                    # judged; subtracting None would raise a TypeError
                    continue
                period = (datetime.now() - mtime).total_seconds()
                if period > EXPIRATION_TIME:
                    logger.info(
                        f"The lock of the user {user} seems to be stale, clearing"
                        f" {endpoint}."
                    )
                    self.mapper[user_endpoint] = b"0"

        # None: no attempt made yet; True/False: result of the last attempt
        self._has_lock = None
        self._locked_by = None

    def modified(self, path):
        """Return the modification time of ``path`` as a naive local datetime.

        When ``"gcs"`` is among the filesystem's protocols, the ``updated``
        field of ``fs.stat`` is parsed. If that field is missing, ``None`` is
        returned. Timestamps without a time zone are assumed to be UTC.
        """
        if "gcs" not in self.fs.protocol:
            mtime = self.fs.modified(path)
        else:
            stat = self.fs.stat(path)
            if "updated" in stat:
                mtime = stat["updated"]
                mtime = isoparse(mtime)
            else:
                return None
        # always convert to the local timezone before returning
        # assume in utc if the time zone is not specified
        if mtime.tzinfo is None:
            mtime = mtime.replace(tzinfo=timezone.utc)
        return mtime.astimezone().replace(tzinfo=None)

    def _msg_on_counter(self, user):
        """Count one busy-wait iteration and log once at MAX_MSG_COUNTER."""
        if self._counter == MAX_MSG_COUNTER:
            logger.info(f"Competing for the lock with the user {user}.")

        # stop counting after the message so it is logged only once
        if self._counter <= MAX_MSG_COUNTER:
            self._counter += 1

    def _lock_unsafe(self):
        """Make one attempt to take the lock; no cleanup on errors.

        Sets ``_has_lock`` to ``True`` on success. It is set to ``False``, and
        ``_locked_by`` to the winning user, if another user holds a smaller
        ``(ticket, priority index)`` pair. Use :meth:`lock` for error cleanup.
        """
        if self._has_lock:
            return None

        # optimistic: kept unless a competing user wins the comparison below
        self._has_lock = True
        self._locked_by = self.user

        # re-read in case other users joined since __init__
        self.users = self.mapper["priorities"].decode().split("*")

        # doorway: announce that a ticket is being chosen, then take 1 + max
        self.mapper[f"entering/{self.user}"] = b"1"

        # A user may already be listed in priorities before it has written its
        # numbers/entering files (__init__ writes them afterwards). A missing
        # entry is the initial state: ticket 0, not entering.
        numbers = [
            int(self.mapper.get(f"numbers/{user}", b"0")) for user in self.users
        ]
        number = 1 + max(numbers)
        self.mapper[f"numbers/{self.user}"] = str(number).encode()

        self.mapper[f"entering/{self.user}"] = b"0"

        for i, user in enumerate(self.users):
            if i == self.priority:
                continue

            # wait until the other user has finished choosing its ticket
            while self.mapper.get(f"entering/{user}", b"0") == b"1":
                self._msg_on_counter(user)

            c_number = int(self.mapper.get(f"numbers/{user}", b"0"))

            # ticket 0: the other user is not competing
            if c_number == 0:
                continue

            # the other user wins with a smaller ticket, or with an equal
            # ticket and a smaller priority index: give up instead of waiting
            if (number > c_number) or (number == c_number and self.priority > i):
                self._has_lock = False
                self._locked_by = user
                self.mapper[f"numbers/{self.user}"] = b"0"
                logger.info(f"The instance is already locked by the user {user}.")
                return None

    def lock(self):
        """Try to acquire the lock; on any exception reset state and re-raise."""
        try:
            self._lock_unsafe()
        except BaseException:
            # don't leave a ticket or a raised entering flag behind, otherwise
            # other users keep treating this user as competing
            self.unlock()
            self._clear()
            raise

    def unlock(self):
        """Reset this user's ticket to 0 and forget the local lock state."""
        self.mapper[f"numbers/{self.user}"] = b"0"
        self._has_lock = None
        self._locked_by = None
        self._counter = 0

    def _clear(self):
        """Reset this user's entering flag to 0."""
        self.mapper[f"entering/{self.user}"] = b"0"

    @property
    def has_lock(self):
        """Whether this user holds the lock.

        If no attempt has been made since construction or the last
        :meth:`unlock`, :meth:`lock` is called first.
        """
        if self._has_lock is None:
            logger.info("The lock has not been initialized, trying to obtain the lock.")
            self.lock()

        return self._has_lock


# module-level cache for the Locker returned by get_locker()
_locker: Optional[Locker] = None


def get_locker(isettings: InstanceSettings) -> Locker:
    """Return the cached :class:`Locker` for the current user and instance.

    A new locker is created if none is cached yet, or if the cached one
    differs in user id, storage root, or instance name.

    Args:
        isettings: Settings of the instance to lock.

    Returns:
        The (possibly newly created) module-level locker.
    """
    from .._settings import settings

    global _locker

    user_id = settings.user.id
    storage_root = isettings.storage.root
    instance_name = isettings.name

    # Compare the root by value, not identity. An equal but distinct path
    # object would otherwise create a new Locker, whose __init__ resets this
    # user's ticket to 0 and so silently drops a lock held via the old one.
    if (
        _locker is None
        or _locker.user != user_id
        or _locker.root != storage_root
        or _locker.name != instance_name
    ):
        _locker = Locker(user_id, storage_root, instance_name)

    return _locker
