from __future__ import print_function
from future.utils import itervalues
from builtins import str
from mrq.task import Task
from mrq.queue import Queue
from bson import ObjectId
from mrq.context import connections, get_current_config
from collections import defaultdict
from mrq.utils import group_iter
import datetime
import ujson as json


def get_task_cfg(taskpath):
    """Return the configuration dict for the task at *taskpath*.

    The lookup goes through the ``"tasks"`` mapping of the current mrq config.
    A missing (or falsy) entry yields an empty dict, so callers can always call
    ``.get()`` on the result.
    """
    tasks_config = get_current_config().get("tasks", {})
    task_config = tasks_config.get(taskpath)
    if task_config:
        return task_config
    return {}


class JobAction(Task):
    """Apply a bulk action (cancel / requeue) to the jobs matching some filters.

    Recognised task params:
      - ``action``: ``"cancel"``, ``"requeue"`` or ``"requeue_retry"``.
      - ``destination_queue``: optional queue name requeued jobs are moved to.
      - ``id``, ``queue``, ``status``, ``worker``, ``path``, ``dateretry``,
        ``exceptiontype``, ``params``: filters selecting the jobs
        (see :meth:`build_query`).
    """

    # Both are set in run(): the raw task params and the mrq_jobs collection.
    params = None
    collection = None

    def run(self, params):
        """Build the job filter from *params* and perform the requested action.

        Returns the stats dict produced by :meth:`perform_action`.
        """
        self.params = params
        self.collection = connections.mongodb_jobs.mrq_jobs

        query = self.build_query()

        return self.perform_action(
            self.params.get("action"), query, self.params.get("destination_queue")
        )

    def build_query(self):
        """Translate ``self.params`` into a MongoDB filter on the jobs collection.

        - ``id`` matches a single job by ObjectId.
        - Each simple field filter matches exactly, or with ``$in`` when a
          list/tuple of values is given.
        - ``params`` is a JSON object string. Each of its keys matches the job's
          ``params.<key>`` field.
        Falsy params are ignored.
        """
        query = {}
        if self.params.get("id"):
            query["_id"] = ObjectId(self.params.get("id"))

        # TODO use redis for queue
        for k in [
                "queue",
                "status",
                "worker",
                "path",
                "dateretry",
                "exceptiontype"]:
            if self.params.get(k):
                if isinstance(self.params[k], (list, tuple)):
                    query[k] = {"$in": list(self.params[k])}
                else:
                    query[k] = self.params[k]

        if self.params.get("params"):
            params_dict = json.loads(self.params.get("params"))  # pylint: disable=no-member

            for key, value in params_dict.items():
                query["params.%s" % key] = value

        return query

    def perform_action(self, action, query, destination_queue):
        """Run *action* on every job matching *query*.

        :param action: ``"cancel"``, ``"requeue"`` (also resets ``retry_count``)
            or ``"requeue_retry"`` (keeps ``retry_count``). Other values do
            nothing.
        :param query: MongoDB filter, as returned by :meth:`build_query`.
        :param destination_queue: when truthy, requeued jobs are moved to this
            queue. Otherwise each job goes back to its own queue.
        :returns: ``{"requeued": int, "cancelled": int}``; it is also printed.
        """
        stats = {
            "requeued": 0,
            "cancelled": 0
        }

        if action == "cancel":
            stats["cancelled"] = self._cancel(query)

        elif action in ("requeue", "requeue_retry"):
            stats["requeued"] = self._requeue(query, action, destination_queue)

        print(stats)

        return stats

    @staticmethod
    def _cancel_result_ttl(query):
        """Return how many seconds cancelled jobs are kept before they expire."""
        default_job_timeout = get_current_config()["default_job_timeout"]

        # Finding the ttl here to expire is a bit hard because we may have
        # mixed paths and hence mixed ttls.
        path_filter = query.get("path")
        if isinstance(path_filter, dict):
            # build_query turns a list of paths into {"$in": [...]}; any other
            # operator leaves us unable to tell which tasks match.
            paths = path_filter.get("$in") or []
        elif path_filter:
            paths = [path_filter]
        else:
            paths = []

        # If we are cancelling by path, keep the longest ttl among those paths.
        if paths:
            return max(
                get_task_cfg(path).get("result_ttl", default_job_timeout)
                for path in paths
            )

        # If not, use the maximum ttl of all tasks.
        tasks_defs = get_current_config().get("tasks", {})
        tasks_ttls = [(cfg or {}).get("result_ttl", 0) for cfg in itervalues(tasks_defs)]

        return max([default_job_timeout] + tasks_ttls)

    def _cancel(self, query):
        """Mark matching jobs as cancelled and return how many were updated."""
        result_ttl = self._cancel_result_ttl(query)

        now = datetime.datetime.utcnow()
        # NOTE(review): Collection.update is deprecated in PyMongo 3 and removed
        # in PyMongo 4; consider update_many once the supported version allows.
        ret = self.collection.update(query, {"$set": {
            "status": "cancel",
            "dateexpires": now + datetime.timedelta(seconds=result_ttl),
            "dateupdated": now
        }}, multi=True)
        cancelled = ret["n"]

        # Special case when emptying just by a single queue name: empty it
        # directly! Jobs queued after the MongoDB update may be lost here. Like
        # in the general case, they will be "lost" and requeued later, after
        # the Redis BLPOP. A {"$in": [...]} queue filter is not a queue name,
        # so it is skipped.
        if list(query.keys()) == ["queue"] and not isinstance(query["queue"], dict):
            Queue(query["queue"]).empty()

        return cancelled

    def _requeue(self, query, action, destination_queue):
        """Requeue matching jobs in MongoDB and Redis; return how many were requeued."""
        # Requeue jobs by groups of maximum 1k items (if all in the same queue).
        cursor = self.collection.find(query, projection=["_id", "queue"])

        # We must freeze the list because the updates below would change the
        # query results. This might not fit in memory; consider adding
        # {"status": {"$ne": "queued"}} to the query.
        fetched_jobs = list(cursor)

        requeued = 0
        for jobs in group_iter(fetched_jobs, n=1000):

            jobs_by_queue = defaultdict(list)
            for job in jobs:
                jobs_by_queue[job["queue"]].append(job["_id"])
                requeued += 1

            for queue, job_ids in jobs_by_queue.items():

                # Use the same rule for MongoDB and Redis so the stored queue
                # always matches the queue the job is actually pushed to.
                target_queue = destination_queue or queue

                updates = {
                    "status": "queued",
                    "dateupdated": datetime.datetime.utcnow()
                }

                if destination_queue:
                    updates["queue"] = target_queue

                # "requeue" resets the retry counter; "requeue_retry" keeps it.
                if action == "requeue":
                    updates["retry_count"] = 0

                self.collection.update({
                    "_id": {"$in": job_ids}
                }, {"$set": updates}, multi=True)

                # Between these two lines, jobs can become "lost" too.

                Queue(target_queue, add_to_known_queues=True).enqueue_job_ids(
                    [str(x) for x in job_ids])

        return requeued
