#!/usr/bin/env python3

import argparse
import logging
import netrc
import os
import requests
import sys
import yaml

from copy import deepcopy
from jsonobject import *
from jsonobject.base import get_dynamic_properties
from pprint import pprint
from urllib.parse import urlsplit, urlunsplit


# Every log line is prefixed with the name of the emitting function,
# right-aligned in a 20 character column.
FORMAT = "%(funcName)20s() ] %(message)s"
# Configure the root logger at import time with INFO as the threshold.
logging.basicConfig(level=logging.INFO, format=FORMAT)
# Module-level logger used by every function and class in this file.
logger = logging.getLogger(__name__)


# ExtLoaderMeta and ExtLoader copied from
# https://gist.github.com/joshbode/569627ced3076931b02f
class ExtLoaderMeta(type):
    """Metaclass that registers the ``!include`` tag on each class it builds."""

    def __new__(metacls, __name__, __bases__, __dict__):
        """Create the class and attach its ``!include`` constructor.

        The new class is expected to provide ``add_constructor`` and
        ``construct_include``; the latter becomes the handler for nodes
        tagged ``!include``.
        """
        loader_cls = super().__new__(metacls, __name__, __bases__, __dict__)

        # Registration happens once, at class creation time.
        loader_cls.add_constructor('!include', loader_cls.construct_include)
        return loader_cls


class ExtLoader(yaml.Loader, metaclass=ExtLoaderMeta):
    """YAML loader that understands the custom ``!include <path>`` tag.

    Relative include paths are resolved against the directory of the stream
    being parsed, or against the current directory when the stream has no
    ``name`` attribute (for example a plain string).
    """

    def __init__(self, stream):
        """Record the directory of *stream*, then initialise the base loader."""
        try:
            source_name = stream.name
        except AttributeError:
            # Not a named file object: fall back to the current directory.
            self._root = os.path.curdir
        else:
            self._root = os.path.dirname(source_name)

        super().__init__(stream)

    def construct_include(self, node):
        """Return the content of the file referenced by *node*.

        Files ending in ``.yaml`` or ``.yml`` are parsed with this same
        loader (so they may contain further includes); every other file is
        returned verbatim as one string.
        """
        relative_path = self.construct_scalar(node)
        target = os.path.abspath(os.path.join(self._root, relative_path))
        suffix = os.path.splitext(target)[1].lstrip('.')

        with open(target, 'r') as handle:
            if suffix in ('yaml', 'yml'):
                return yaml.load(handle, ExtLoader)
            return handle.read()


def yaml_validator(data):
    """Check that *data* parses as YAML.

    The parsed document is discarded and nothing is returned; a parse
    failure surfaces as whatever exception ``yaml.load`` raises.  The
    function is passed as a property validator in ``BackendConfig``.

    NOTE(review): ``ExtLoader`` derives from the full ``yaml.Loader``, so
    this should only ever see trusted input.
    """
    yaml.load(data, ExtLoader)


class SquadConnectionException(Exception):
    """Raised when a SQUAD lookup matches more than one object."""


class SquadConnection(object):
    """Small client for the REST API of a single SQUAD instance.

    The API token is taken from the user's netrc entry for the host part of
    the instance URL and is sent as ``Authorization: Token <token>`` with
    every request.
    """

    def __init__(self, args):
        """Create a connection for the instance at ``args.url``.

        :param args: object with a ``url`` attribute (for example the
            argparse namespace built in ``main``).

        The process exits with status 1 when no credentials are available.
        """
        self.url = args.url
        urlparts = urlsplit(self.url)
        self.base_url = urlparts.netloc
        self.url_scheme = urlparts.scheme

        connection_token = "Token %s" % self.__get_connection_token__(args.url)
        self.headers = {
            "Authorization": connection_token
        }

    def _api_url(self, path):
        """Return the absolute URL of *path* below the instance's ``api/`` root."""
        return urlunsplit((self.url_scheme, self.base_url, "api/%s" % path, None, None))

    def get_prepared_request(self, endpoint, method):
        """Return a prepared, authenticated ``requests`` request.

        :param endpoint: path below ``api/``.
        :param method: HTTP method name.
        """
        req = requests.Request(method, self._api_url(endpoint), headers=self.headers)
        return req.prepare()

    def __get_connection_token__(self, url):
        """Look up the API token for this instance in the netrc file.

        :param url: unused; the lookup key is ``self.base_url``.  The
            parameter is kept so existing callers keep working.
        :return: the token stored in the password field of the entry.

        Logs an error and exits with status 1 when the netrc file cannot be
        read or parsed, or holds no entry for the host.
        """
        try:
            netrcauth = netrc.netrc()
        except (OSError, netrc.NetrcParseError) as exc:
            # Without this a missing ~/.netrc ends in a raw traceback.
            logger.error("Unable to read netrc file: %s", exc)
            sys.exit(1)

        credentials = netrcauth.authenticators(self.base_url)
        if credentials is None:
            logger.error("No credentials found for %s", self.base_url)
            sys.exit(1)

        self.username, _, self.token = credentials
        logger.info("Using username: %s", self.username)
        return self.token

    def download_list(self, endpoint, params=None):
        """Download every page of a paginated list endpoint.

        :param endpoint: path below ``api/``.
        :param params: optional query parameters for the first request.
        :return: list with the ``results`` of all pages that were fetched.
            Empty when the first request fails; partial (with an error
            logged) when a follow-up page fails.
        """
        URL = self._api_url(endpoint)
        logger.debug(URL)
        response = requests.get(URL, params=params, headers=self.headers)
        if response.status_code != 200:
            logger.error(URL)
            logger.error(response.status_code)
            logger.error(response.text)
            return []

        page = response.json()
        result_list = list(page['results'])
        next_url = page.get('next')
        while next_url is not None:
            response = requests.get(next_url, headers=self.headers)
            if response.status_code != 200:
                # Make the truncation visible instead of failing silently.
                logger.error(
                    "Fetching %s failed with HTTP %s; returning partial list",
                    next_url, response.status_code)
                break
            page = response.json()
            result_list.extend(page['results'])
            next_url = page.get('next')
        return result_list

    def download_object(self, url):
        """Return the JSON document at *url*, or None.

        None is returned when *url* is None or the server does not answer
        with HTTP 200 (the latter is logged).
        """
        if url is None:
            return None
        response = requests.get(url, headers=self.headers)
        if response.status_code == 200:
            return response.json()
        logger.error("Fetching %s failed with HTTP %s", url, response.status_code)
        return None

    def filter_object(self, endpoint, params):
        """Return the single object of *endpoint* matching *params*.

        :return: the matching object, or None when nothing matches.
        :raises SquadConnectionException: when more than one object matches.
        """
        old_configs = self.download_list(endpoint, params)
        if not old_configs:
            # the config is new
            return None
        if len(old_configs) > 1:
            logger.error("Found too many objects of type: %s", endpoint)
            logger.error("Params: %s", params)
            raise SquadConnectionException("Too many objects found")
        return old_configs[0]

    def put_object(self, config):
        """Update an existing SQUAD object with ``config.squad_config`` (HTTP PUT).

        Any answer other than HTTP 200 is logged as an error.
        """
        object_id = config.squad_config.get(config._squad_id)
        URL = self._api_url("%s/%s/" % (config._endpoint, object_id))
        logger.debug(URL)
        logger.debug(config.squad_config)
        response = requests.put(URL, data=config.squad_config, headers=self.headers)
        if response.status_code != 200:
            logger.error(response.text)

    def post_object(self, config):
        """Create a new SQUAD object from *config* (HTTP POST).

        ``config.populate_squad_config`` is called first to build the
        payload.  Any answer other than HTTP 201 is logged as an error.
        """
        URL = self._api_url("%s/" % config._endpoint)
        logger.debug(URL)
        config.populate_squad_config(self)
        logger.debug(config.squad_config)
        response = requests.post(URL, data=config.squad_config, headers=self.headers)
        if response.status_code != 201:
            logger.error(response.text)


class SquadBaseConfigException(Exception):
    """Raised when a config object is missing the data it needs."""


class SquadLoadedConfigException(Exception):
    """Raised when a loaded config file does not have the expected layout."""


class SquadBaseConfig(object):
    """Base class linking a locally declared config entry to its SQUAD object.

    Instance attributes:

    * ``config`` -- object built from the nested class named by
      ``_loaded_config_class_name``; holds the locally declared attributes
      (None until loaded).
    * ``squad_config`` -- dictionary describing the object as stored in
      SQUAD (None until downloaded or populated).
    * ``identifier`` -- identifying value of the entry (None until loaded).

    Subclasses describe their object type through the class attributes
    below.
    """
    _endpoint = ""  # path of the list endpoint below ``api/``
    _object_name = ""  # key naming the entry type in config files
    _config_id = ""  # attribute that identifies the object
    _squad_filter_fields = []  # fields used to look the object up in SQUAD
    _squad_id = "id"  # key of the object id in the SQUAD dictionary
    _squad_reference_id = "url"  # key used when another object refers to this one
    _loaded_config_class_name = ""  # name of the nested attributes class
    _nested_objects = {}  # dict of {name: type}

    def __init__(self, _obj=None, **kwargs):
        """Build the object, optionally from a loaded config entry.

        :param _obj: mapping with an ``attributes`` key and, optionally, a
            key named after ``_object_name``.
        :param kwargs: forwarded to the attributes class.
        :raises SquadBaseConfigException: when data is supplied but there
            is no ``attributes`` section to read.
        """
        self.config = None
        self.squad_config = None
        # Always defined so that later reads cannot hit a missing attribute.
        self.identifier = None
        if self._loaded_config_class_name and (_obj or kwargs):
            if not _obj or 'attributes' not in _obj:
                raise SquadBaseConfigException(
                    "config entry has no 'attributes' section")
            config_class = getattr(self, self._loaded_config_class_name)
            self.config = config_class(_obj['attributes'], **kwargs)
            self.__remove_redundant_keys__()
            self.identifier = _obj.get(self._object_name, None)

    def __download_squad_object__(self, connection):
        """Fetch the matching SQUAD object into ``self.squad_config``.

        The lookup parameters come from ``_squad_filter_fields``.  A field
        written as ``name.field`` refers to a nested object; only the part
        before the dot is used, and the filter value is that object's
        SQUAD id.
        """
        parameters = {}
        for key in self._squad_filter_fields:
            obj_type = key.split(".", 1)[0]
            if obj_type in self._nested_objects:
                # Filter by the id of the nested object rather than by its
                # local identifier.
                cls = self._nested_objects[obj_type]
                nested_value = self.config[obj_type]
                nested_object = cls({
                    cls._object_name: nested_value,
                    "attributes": {cls._config_id: nested_value},
                })
                nested_object_squad_id = nested_object.get_squad_id(connection)
                if nested_object_squad_id is not None:
                    parameters.update({obj_type: nested_object_squad_id})
                else:
                    logger.warning("nested object of type %s does not exist", obj_type)
            else:
                parameters.update({key: self.config[key]})
        logger.debug(parameters)
        self.squad_config = connection.filter_object(self._endpoint, parameters)

    def __remove_redundant_keys__(self):
        """Drop every key of ``config`` that is not a declared property."""
        if self.config:
            # delete 'dynamic items' such as IDs or URLs generated by API
            dynitems = get_dynamic_properties(self.config)
            # Iterate over a snapshot because keys are removed on the way.
            for key in list(dynitems):
                del self.config[key]

    def __download_object_by_url__(self, connection, nested_object, cls):
        """Translate the URL of a nested object into its config identifier.

        :param nested_object: URL of the object, or None.
        :param cls: config class describing the nested object.
        :return: the identifier, or None when there is no URL or the
            download fails.
        """
        if nested_object is None:
            return None
        squad_object = connection.download_object(nested_object)
        if squad_object is None:
            # A failed download used to end in an AttributeError further down.
            logger.warning("Unable to download nested object from %s", nested_object)
            return None
        nested_attr = cls()
        nested_attr.from_squad_object(squad_object)
        return nested_attr.to_yaml().get(nested_attr._object_name)

    def __update_nested_object__(self, connection, nested_object, cls):
        """Resolve one URL, or a list of URLs, to config identifiers."""
        if isinstance(nested_object, list):
            return [
                self.__download_object_by_url__(connection, item, cls)
                for item in nested_object
            ]
        return self.__download_object_by_url__(connection, nested_object, cls)

    def __update_nested_from_class__(self, connection, cls, value):
        """Return the SQUAD reference of the nested object named *value*.

        None is returned when the object cannot be resolved.
        """
        try:
            nested_object = cls({cls._object_name: value, "attributes": {cls._config_id: value}})
            return nested_object.get_squad_reference(connection)
        except SquadConnectionException:
            # this happens when filtering for null objects
            # assuming the new object was set to None
            pass
        return None

    def update_nested_objects(self, connection):
        """Replace nested object URLs in ``config`` by config identifiers."""
        if self.config:
            for name, cls in self._nested_objects.items():
                nested_object = getattr(self.config, name)
                nested_attr_config_id = self.__update_nested_object__(connection, nested_object, cls)
                setattr(self.config, name, nested_attr_config_id)

    def populate_squad_config(self, connection):
        """Build ``squad_config`` from scratch out of the local ``config``."""
        self.squad_config = deepcopy(self.config.to_json())
        self.__reverse_update_config__(connection)

    def __reverse_update_config__(self, connection):
        """Copy the local values into ``squad_config``.

        Plain values are copied when they differ; nested object names are
        translated to SQUAD references first.  Nested objects that cannot
        be resolved are reported with a warning and left out.
        """
        logger.debug("Updating squad_config")
        for key, value in self.config.to_json().items():
            if key not in self._nested_objects:
                if value != self.squad_config.get(key):
                    self.squad_config.update({key: value})
                continue

            nested_cls = self._nested_objects[key]
            if isinstance(value, list):
                nested_list = []
                for nested_item in value:
                    squad_id = self.__update_nested_from_class__(
                        connection, nested_cls, nested_item)
                    if squad_id:
                        nested_list.append(squad_id)
                    else:
                        logger.warning("Object %s identified with %s doesn't exist", nested_cls, nested_item)
                self.squad_config.update({key: nested_list})
            else:
                squad_id = self.__update_nested_from_class__(
                    connection, nested_cls, value)
                if squad_id:
                    self.squad_config.update({key: squad_id})
                else:
                    logger.warning("Object %s identified with %s doesn't exist", nested_cls, value)

    def to_yaml(self):
        """Return the entry as a dictionary in the config file layout."""
        ret_dict = {self._object_name: getattr(self.config, self._config_id, self.identifier)}
        ret_dict.update({"attributes": self.config.to_json()})
        return ret_dict

    def get_squad_id(self, connection):
        """Return the SQUAD id of the object, or None if it does not exist.

        The SQUAD object is downloaded on first use.
        """
        if self.squad_config is None:
            self.__download_squad_object__(connection)
        if self.squad_config is None:
            return None
        return self.squad_config.get(self._squad_id)

    def get_squad_reference(self, connection):
        """Return the value other objects use to refer to this one, or None.

        The SQUAD object is downloaded on first use.
        """
        if self.squad_config is None:
            self.__download_squad_object__(connection)
        if self.squad_config:
            return self.squad_config.get(self._squad_reference_id)
        return None

    def __is_field_equal__(self, field_name, field_compare_value):
        """Tell whether the local value of *field_name* equals the given one."""
        return self.config.to_json().get(field_name) == field_compare_value

    def to_squad_object(self, connection):
        """
        Returns modified object that can be uploaded to SQUAD
        and object_modifications dictionary. If object_modifications is empty
        the object wasn't modified

        When the object does not exist in SQUAD the result is
        ``(None, <local attributes>)``.

        :raises SquadBaseConfigException: when the object is not in SQUAD
            and there is no local config either.

        NOTE(review): the comparison only runs when ``squad_config`` had
        not been fetched before this call; confirm that is intended.
        """
        object_modifications = {}
        if self.squad_config is None:
            self.__download_squad_object__(connection)
            if self.squad_config is None:
                # no object found in SQUAD db
                if self.config:
                    return None, self.config.to_json()
                raise SquadBaseConfigException()
            for key, value in self.config.to_json().items():
                squad_value = self.squad_config.get(key, None)
                if key in self._nested_objects:
                    # Compare identifiers, not the URLs stored in SQUAD.
                    squad_value = self.__update_nested_object__(
                        connection, self.squad_config.get(key), self._nested_objects[key])
                if not self.__is_field_equal__(key, squad_value):
                    object_modifications.update(
                        {key: {'new': value, 'old': squad_value}})
            self.__reverse_update_config__(connection)
        return self.squad_config, object_modifications

    def from_squad_object(self, squad_dict):
        """
        Initializes 'config' from passed SQUAD object dictionary.
        This method is used to create initial config for SQUAD instance.
        """
        self.squad_config = squad_dict
        self.config = getattr(self, self._loaded_config_class_name)(squad_dict)
        self.__remove_redundant_keys__()
        self.identifier = self.squad_config.get(self._config_id, None)


class BackendConfig(SquadBaseConfig):
    """Config entries of type ``backend`` (``backends`` endpoint)."""

    class BackendConfigMeta(JsonObject):
        # Attributes accepted for a backend; ``backend_settings`` must be
        # a string that parses as YAML.
        name = StringProperty(required=True)
        url = StringProperty()
        username = StringProperty()
        token = StringProperty(exclude_if_none=True)
        implementation_type = StringProperty(choices=['null', 'fake', 'lava'])
        backend_settings = StringProperty(validators=[yaml_validator])
        poll_interval = IntegerProperty()
        max_fetch_attempts = IntegerProperty()

    _endpoint = "backends"
    _object_name = "backend"
    _config_id = "name"
    _squad_filter_fields = ["name"]
    _loaded_config_class_name = "BackendConfigMeta"

    @staticmethod
    def _settings_match(local, remote):
        """Compare two parsed ``backend_settings`` documents.

        Mappings match when they have the same keys and every value is
        equal; list values are compared without regard to order or
        duplicates.  Anything that is not a mapping is compared with ``==``.
        """
        if not isinstance(local, dict) or not isinstance(remote, dict):
            return local == remote
        if local.keys() != remote.keys():
            return False
        for key, value in local.items():
            other = remote[key]
            if isinstance(value, list) and isinstance(other, list):
                try:
                    same = set(value) == set(other)
                except TypeError:
                    # Unhashable items: fall back to an ordered comparison.
                    same = value == other
            else:
                same = value == other
            if not same:
                return False
        return True

    def __is_field_equal__(self, field_name, field_compare_value):
        """Tell whether the local value of *field_name* equals the given one.

        ``token`` always counts as equal and ``backend_settings`` is
        compared by content rather than by text; every other field uses the
        base class comparison.
        """
        if field_name == 'token':
            # ignore tokens as they can't be downloaded from SQUAD
            return True
        if field_name != "backend_settings":
            return super(BackendConfig, self).__is_field_equal__(field_name, field_compare_value)

        local_raw = self.config.backend_settings
        if local_raw == field_compare_value:
            # Identical text (or both unset) needs no parsing.
            return True
        try:
            # safe_load: calling yaml.load without a Loader fails on
            # current PyYAML, which made this field always look modified.
            local = yaml.safe_load(local_raw)
            remote = yaml.safe_load(field_compare_value)
        except (yaml.YAMLError, AttributeError, TypeError):
            # Unparsable or missing on one side only: treat as different.
            return False
        return self._settings_match(local, remote)


class EmailTemplateConfig(SquadBaseConfig):
    """Config entries of type ``emailtemplate``, looked up in SQUAD by name."""

    _endpoint = "emailtemplates"
    _object_name = "emailtemplate"
    _config_id = "name"
    _squad_filter_fields = ["name"]
    _loaded_config_class_name = "EmailTemplateConfigMeta"

    class EmailTemplateConfigMeta(JsonObject):
        # Attributes accepted for an email template.
        name = StringProperty()
        subject = StringProperty()
        plain_text = StringProperty()
        html = StringProperty()


class UserGroupConfig(SquadBaseConfig):
    """Config entries of type ``usergroup``, looked up in SQUAD by name."""

    _endpoint = "usergroups"
    _object_name = "usergroup"
    _config_id = "name"
    _squad_filter_fields = ["name"]
    _loaded_config_class_name = "UserGroupConfigMeta"

    class UserGroupConfigMeta(JsonObject):
        # A user group carries nothing but its name.
        name = StringProperty()


class GroupConfig(SquadBaseConfig):
    """Config entries of type ``group``, looked up in SQUAD by slug."""

    _endpoint = "groups"
    _object_name = "group"
    _config_id = "slug"
    _squad_filter_fields = ["slug"]
    _loaded_config_class_name = "GroupConfigMeta"
    # ``user_groups`` entries refer to objects described by UserGroupConfig.
    _nested_objects = {'user_groups': UserGroupConfig}

    class GroupConfigMeta(JsonObject):
        # Attributes accepted for a group.
        slug = StringProperty()
        name = StringProperty()
        description = StringProperty()
        user_groups = ListProperty(str)


class ProjectConfig(SquadBaseConfig):
    """Config entries of type ``project``, looked up by slug and group."""

    _endpoint = "projects"
    _object_name = "project"
    _config_id = "slug"
    _squad_filter_fields = ["slug", "group.slug"]
    _loaded_config_class_name = "ProjectConfigMeta"
    # Attributes whose values refer to other config objects.
    _nested_objects = {
        'group': GroupConfig,
        'custom_email_template': EmailTemplateConfig,
    }

    class ProjectConfigMeta(JsonObject):
        # Attributes accepted for a project.
        slug = StringProperty()
        group = StringProperty()
        name = StringProperty()
        is_public = BooleanProperty()
        html_mail = BooleanProperty()
        moderate_notifications = BooleanProperty()
        custom_email_template = StringProperty()
        description = StringProperty()
        important_metadata_keys = StringProperty()
        enabled_plugins_list = StringProperty()
        wait_before_notification = IntegerProperty()
        notification_timeout = IntegerProperty()


def create_config_list(args, cls):
    """Download every SQUAD object of type *cls* as a config object.

    :param args: namespace passed on to ``SquadConnection``.
    :param cls: config class whose ``_endpoint`` is listed.
    :return: list of *cls* instances initialised from the downloaded
        objects, with nested objects resolved.
    """
    connection = SquadConnection(args)
    downloaded = []
    for squad_object in connection.download_list(cls._endpoint):
        entry = cls()
        entry.from_squad_object(squad_object)
        entry.update_nested_objects(connection)
        downloaded.append(entry)
    return downloaded


def create_config_object(key, config_dict):
    """Build the config object for an entry of type *key*.

    :param key: entry type name as used in config files.
    :param config_dict: the loaded entry, handed to the config class.
    :return: the new config object, or None for an unknown *key*.
    """
    # The classes are wrapped in lambdas so that a name is only looked up
    # for the entry type that is actually requested.
    builders = {
        "backend": lambda entry: BackendConfig(entry),
        "emailtemplate": lambda entry: EmailTemplateConfig(entry),
        "usergroup": lambda entry: UserGroupConfig(entry),
        "group": lambda entry: GroupConfig(entry),
        "project": lambda entry: ProjectConfig(entry),
    }
    build = builders.get(key)
    if build is None:
        return None
    return build(config_dict)


def _build_argument_parser():
    """Return the command line parser of the tool."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-c",
                        "--config-files",
                        nargs="+",
                        default=[],
                        help="Instance config files",
                        dest="config_files")
    parser.add_argument("-d",
                        "--dry-run",
                        action="store_true",
                        default=False,
                        help="Dry run",
                        dest="dry_run")
    parser.add_argument("-v",
                        "--debug",
                        action="store_true",
                        default=False,
                        help="Enable debug",
                        dest="debug")
    parser.add_argument("-u",
                        "--url",
                        help="SQUAD instance URL",
                        dest="url")
    parser.add_argument("-s",
                        "--save-to-file",
                        default=None,
                        help="Save current config to file",
                        dest="save_filename")
    parser.add_argument("-y",
                        "--assume-yes",
                        action="store_true",
                        default=False,
                        help="Assume 'yes' answer for all questions",
                        dest="assume_yes")
    return parser


def _load_config_files(paths):
    """Parse the YAML files in *paths* into a list of config objects."""
    loaded = []
    for path in paths:
        with open(path, 'r') as stream:
            logger.debug("loading config file: %s", path)
            # The try block also covers object creation because property
            # validators may raise YAML errors of their own.
            try:
                # NOTE: ExtLoader is based on the full yaml.Loader, so
                # config files must come from a trusted source.
                loaded_config = yaml.load(stream, Loader=ExtLoader)
                if not isinstance(loaded_config, list):
                    raise SquadLoadedConfigException()
                for item in loaded_config:
                    for key in item.keys():
                        if key == 'attributes':
                            continue
                        config = create_config_object(key, item)
                        if config is not None:
                            loaded.append(config)
            except yaml.YAMLError as exc:
                logger.error(exc)
                sys.exit(1)
    return loaded


def _print_pending_changes(modifications_list, creations_list):
    """Print the objects that would be modified or created."""
    if modifications_list:
        print("Modifications")
    for item in modifications_list:
        changed = item['config']
        print("%s: %s" % (changed._object_name, changed.config[changed._config_id]))
        print("modified fields:")
        for key, value in item['modified'].items():
            print(key)
            print("old: %s" % value["old"])
            print("new: %s" % value["new"])
        print("\n")
    if creations_list:
        print("New objects")
    for item in creations_list:
        print(item['config'].config)
        print("\n")


def main():
    """Entry point of the command line tool.

    With ``--config-files`` the listed files are compared with the SQUAD
    instance, the differences are printed and, unless ``--dry-run`` is
    given, uploaded after confirmation.  With ``--save-to-file`` together
    with ``--dry-run`` the current configuration of the instance is written
    to the given file.
    """
    parser = _build_argument_parser()
    args = parser.parse_args()

    if args.debug:
        logger.setLevel(logging.DEBUG)

    wants_save = args.save_filename is not None and args.dry_run
    if (args.config_files or wants_save) and not args.url:
        # Fail early with a usage message instead of a confusing
        # credentials error further down.
        parser.error("--url is required to talk to a SQUAD instance")

    config_list = []
    if args.config_files:
        connection = SquadConnection(args)
        config_list = _load_config_files(args.config_files)

        modifications_list = []
        creations_list = []
        for item in config_list:
            config, modified = item.to_squad_object(connection)
            if not modified:
                continue
            entry = {"config": item, "modified": modified}
            if config:
                # changes detected
                modifications_list.append(entry)
            else:
                # new object created
                creations_list.append(entry)

        _print_pending_changes(modifications_list, creations_list)

        if args.dry_run:
            return

        if modifications_list or creations_list:
            if args.assume_yes:
                logger.info("Assuming 'YES' for all changes")
            else:
                proceed = input("Deploy changes (y/n)?")
                # Only an explicit yes deploys; anything else aborts.
                if proceed.strip().lower() not in ("y", "yes"):
                    return

        for item in modifications_list:
            # send PUT requests
            connection.put_object(item['config'])
        for item in creations_list:
            # send POST requests
            connection.post_object(item['config'])

    if args.save_filename is None:
        return

    # NOTE(review): saving has only ever happened in combination with
    # --dry-run; confirm whether that restriction is intended.
    if not args.dry_run:
        logger.warning("--save-to-file is only honoured together with --dry-run; nothing was saved")
        return

    for cls in (BackendConfig, EmailTemplateConfig, UserGroupConfig,
                GroupConfig, ProjectConfig):
        config_list.extend(create_config_list(args, cls))

    dump_list = [config.to_yaml() for config in config_list]
    with open(args.save_filename, 'w') as outfile:
        yaml.dump(
            dump_list,
            outfile,
            default_flow_style=False,
            indent=4,
            width=96)
    return


# Run the command line tool only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
