import json
import time

from howler.remote.datatypes import get_client, retry_call

# Lua helper: remove the member ARGV[2] from the set whose key is ARGV[1], then
# return that set's remaining cardinality (SCARD), both within one script run.
# Callers must pass ``args=[set_key, encoded_member]`` in that order.
# NOTE(review): the set's key is passed through ARGV rather than KEYS, so Redis
# cannot tell which key the script touches (relevant for Redis Cluster routing).
_drop_card_script = """
local set_name = ARGV[1]
local key = ARGV[2]

redis.call('srem', set_name, key)
return redis.call('scard', set_name)
"""

# Lua helper: add ARGV[1] to the set at KEYS[1] only while the set currently holds
# fewer than ARGV[2] members. The size check and the SADD run in one script, so no
# other client can grow the set in between.
# Per Redis's Lua-to-RESP conversion, ``true`` reaches the client as integer 1 and
# ``false`` as a nil reply. It returns true whenever the set was below the limit,
# even if the member was already present.
_limited_add = """
local set_name = KEYS[1]
local key = ARGV[1]
local limit = tonumber(ARGV[2])

if redis.call('scard', set_name) < limit then
    redis.call('sadd', set_name, key)
    return true
end
return false
"""


class Set(object):
    """A Redis set whose members are stored as JSON-encoded strings.

    Every value handed in is serialised with ``json.dumps`` before it is sent to
    Redis, and every member read back is decoded with ``json.loads``, so callers
    work with plain Python values. All Redis calls go through ``retry_call``.

    When used as a context manager, the whole set is deleted on exit.
    """

    def __init__(self, name, host=None, port=None):
        """Create a handle on the set stored at Redis key ``name``.

        :param name: Redis key holding the set.
        :param host: Redis host, forwarded to ``get_client``.
        :param port: Redis port, forwarded to ``get_client``.
        """
        self.c = get_client(host, port, False)
        self.name = name
        # Register the Lua helpers once per instance; calls reuse the registered scripts.
        self._drop_card = self.c.register_script(_drop_card_script)
        self._limited_add = self.c.register_script(_limited_add)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the ``with`` block deletes the set; exceptions are not suppressed.
        self.delete()

    def add(self, *values):
        """Add one or more values to the set.

        :return: the SADD reply (number of members that were newly added), or 0 when
            no values were given.
        """
        if not values:
            # SADD with zero members is a Redis syntax error; adding nothing adds nothing.
            return 0
        return retry_call(self.c.sadd, self.name, *[json.dumps(v) for v in values])

    def limited_add(self, value, size_limit):
        """Add a single value to the set, but only if that wouldn't make the set grow past a given size.

        The size check and the insertion happen atomically in one Lua script.

        :return: 1 when the set was below ``size_limit`` (value added or already
            present), ``None`` when the set was already full.
        """
        return retry_call(self._limited_add, keys=[self.name], args=[json.dumps(value), size_limit])

    def exist(self, value):
        """Return the SISMEMBER reply for ``value`` (truthy when it is in the set)."""
        return retry_call(self.c.sismember, self.name, json.dumps(value))

    def length(self):
        """Return the number of members in the set (0 if the key does not exist)."""
        return retry_call(self.c.scard, self.name)

    def members(self):
        """Return every member of the set, decoded, as a list (order is unspecified)."""
        return [json.loads(s) for s in retry_call(self.c.smembers, self.name)]

    def rand_member(self, number=1):
        """Return ``number`` random members (without removing them) as a list.

        An empty set yields an empty list.
        """
        result = retry_call(self.c.srandmember, self.name, number)
        if result is None:
            # Only reachable when no count was sent and the set is empty.
            return []
        if not isinstance(result, list):
            result = [result]

        return [json.loads(entry) for entry in result]

    def remove(self, *values):
        """Remove one or more values from the set.

        :return: the SREM reply (number of members actually removed), or 0 when no
            values were given.
        """
        if not values:
            # SREM with zero members is a Redis syntax error; removing nothing removes nothing.
            return 0
        return retry_call(self.c.srem, self.name, *[json.dumps(v) for v in values])

    def drop(self, value):
        """Remove ``value`` from the set and return the set's remaining size.

        The Lua script expects the set key as ARGV[1] and the encoded member as
        ARGV[2]. The member is JSON-encoded like in every other method so it matches
        what ``add`` stored.
        """
        return retry_call(self._drop_card, args=[self.name, json.dumps(value)])

    def random(self, num=None):
        """Return random member(s) without removing them.

        With ``num=None`` a single decoded member is returned, or ``None`` if the set
        is empty. With a count, a (possibly empty) list of decoded members is returned.
        """
        ret_val = retry_call(self.c.srandmember, self.name, num)
        if ret_val is None:
            # Empty set and no count given: nothing to decode.
            return None
        if isinstance(ret_val, list):
            return [json.loads(s) for s in ret_val]
        return json.loads(ret_val)

    def pop(self):
        """Remove and return one random member, or ``None`` if the set is empty."""
        data = retry_call(self.c.spop, self.name)
        return json.loads(data) if data is not None else None

    def pop_all(self):
        """Remove and return all members as a list.

        NOTE: the size is read before the SPOP, so members added concurrently in
        between may be left in the set.
        """
        popped = retry_call(self.c.spop, self.name, self.length())
        return [json.loads(s) for s in popped or []]

    def delete(self):
        """Delete the whole set key from Redis."""
        retry_call(self.c.delete, self.name)


class ExpiringSet(Set):
    """A :class:`Set` whose Redis key carries a time-to-live of ``ttl`` seconds.

    The TTL is (re)applied after ``add``, ``limited_add``, ``members`` and
    ``rand_member``. To avoid an EXPIRE on every call, this instance refreshes it at
    most once every ``ttl / 2`` seconds. A falsy ``ttl`` disables expiry entirely.

    NOTE: Redis removes a set once it becomes empty (e.g. after ``pop_all``). A key
    recreated afterwards gets its TTL on the next refresh window.
    """

    def __init__(self, name, ttl=86400, host=None, port=None):
        super(ExpiringSet, self).__init__(name, host, port)
        self.ttl = ttl
        # Time of the last EXPIRE that actually applied; 0 forces the next call to set it.
        self.last_expire_time: float = 0

    def _conditional_expire(self):
        """Refresh the key's TTL when more than ``ttl / 2`` seconds passed since the last refresh."""
        if not self.ttl:
            return
        ctime = time.time()
        if ctime > self.last_expire_time + (self.ttl / 2):
            # EXPIRE is a no-op (falsy reply) while the key does not exist, e.g. after a
            # read on an empty set. Only remember the time when the TTL was really set.
            # Otherwise a following add() would create the key with no expiry at all.
            if retry_call(self.c.expire, self.name, self.ttl):
                self.last_expire_time = ctime

    def add(self, *values):
        rval = super(ExpiringSet, self).add(*values)
        self._conditional_expire()
        return rval

    def limited_add(self, value, size_limit):
        rval = super(ExpiringSet, self).limited_add(value, size_limit)
        self._conditional_expire()
        return rval

    def members(self):
        rval = super(ExpiringSet, self).members()
        self._conditional_expire()
        return rval

    def rand_member(self, number=1):
        rval = super(ExpiringSet, self).rand_member(number)
        self._conditional_expire()
        return rval

    def delete(self):
        """Delete the set key and reset the refresh timer.

        The key is gone, so the next write must apply the TTL again immediately.
        """
        super(ExpiringSet, self).delete()
        self.last_expire_time = 0
