#!/usr/bin/env python3

# Copyright (C) 2021 David Härdeman <david@hardeman.nu>
# SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only OR Apache-2.0

"""Rspamd DKIM key generation/rollover script

This script is meant to be executed on a monthly basis (e.g. via cron
or systemd) and will automatically manage DKIM key rollover.

Keys go through four states:
* future  - a key which is created in advance of being used and which
            should be published in the DNS zone ASAP
* active  - key which is actively being used to sign emails
* expired - key which is no longer used to sign emails but the DNS
            entry is still kept in order to make sure that queued
            and in-flight emails will still have a valid signature
* dead    - key which can (hopefully) be safely removed from the file
            system and from the DNS zone after enough time has passed
            that no signed emails are in flight anymore

These states correspond to the published/active/inactive/removed states
known from DNSSEC key rollover.

The default settings can be overridden by creating a JSON snippet in
/etc/rspamd/dkim-keygen.json (the path can be overridden with the
-c/--config parameter, use -C/--print-config to print a config
template).

At least "domains" and "key_types" need to be set using the config file.
If "update_script" is set, it will be called for each key which is created
or deleted, with the following arguments:
    <add/del> <domain> <selector> <dns_entry_file_path>
"""

from __future__ import (print_function, division)
import os
import sys
import json
import time
import logging
import tempfile
import argparse
import subprocess
import shutil

# Module metadata; __version__ is what -V/--version prints.
# NOTE(review): the SPDX header at the top of the file also lists
# GPL-3.0-only, which __license__ does not mention - confirm which is right.
__license__ = 'GPL-2.0 OR Apache-2.0'
__author__ = 'David Härdeman <david@hardeman.nu>'
__version__ = '1.0.1'

# Module-wide logger; the level and format are configured in main().
LOG = logging.getLogger(__name__)


class Key:
    """Keeps track of a real or virtual key

    The class attributes hold settings which are shared by all keys:
        dryrun             whether to change any files
        current_month      the current month, counted as year * 12 + month
        key_archive        directory to move deleted keys to (keys are
                           unlinked if unset)
        update_script      script to execute upon key creation/removal
        update_script_args list of arguments to pass to the update_script
        future_period      time between key creation and activation (months)
        active_period      key active time (months)
        expired_period     time between key inactivation and deletion (months)

    Each instance describes one key file. Its state is one of "init",
    "invalid", "future", "active", "expired" or "dead" and is derived
    from the date encoded in the file name.
    """
    dryrun = False
    current_month = 0
    key_archive = None
    update_script = None
    update_script_args = []
    future_period = 1
    active_period = 3
    expired_period = 1

    def __init__(self):
        self.state = "init"
        self.key_path = ""
        self.txt_path = ""
        self.valid_from = 0
        self.age = 0
        self.type = ""
        self.domain = ""
        self.selector = ""
        self.error = ""

    @staticmethod
    def _datestr(month_count):
        """Convert a month count (year * 12 + month) to a YYYYMM string

        The month part is 1-based, so the count has to be shifted by one
        before dividing; otherwise December ends up as month "00" of the
        following year.
        """
        year, month_index = divmod(month_count - 1, 12)
        return "{:04d}{:02d}".format(year, month_index + 1)

    def validate_filename(self, path, selector):
        """Validate the filename of a key

        The expected path is e.g. /some/path/YYYYMM_keytype_domain.key
        with the DNS record in /some/path/YYYYMM_keytype_domain.txt

        On success key_path, txt_path, type, domain, selector, valid_from
        and age are set. If no selector is given, "YYYYMM-keytype" is used.

        Raises ValueError if the name does not follow the expected format
        or if the key type is not a configured one.
        """

        abspath = os.path.abspath(path)
        if not abspath.endswith(".key"):
            raise ValueError("key name suffix invalid")
        abspath_wo_suffix = abspath[:-len(".key")]
        basename = os.path.basename(abspath_wo_suffix)

        parts = basename.split("_")
        if len(parts) != 3:
            raise ValueError("key name missing components")
        date = parts[0].strip()

        self.key_path = abspath
        self.txt_path = abspath_wo_suffix + ".txt"
        self.type = parts[1]
        self.domain = parts[2]
        self.selector = selector or "{}-{}".format(date, self.type)

        if len(date) != 6:
            raise ValueError("key validity incorrect size")
        try:
            year = int(date[0:4])
            month = int(date[4:6])
        except ValueError:
            raise ValueError("key validity not numeric") from None
        # Same unit as current_month. Month "00" is deliberately still
        # accepted, since it maps to December of the previous year.
        self.valid_from = year * 12 + month
        self.age = self.current_month - self.valid_from

        if self.type not in KeyType.key_types:
            raise ValueError("key type unknown")

    def read(self, path, selector=None, skip_exist_check=False):
        """Read a DKIM key

        Validates the file name, optionally checks that both the key and
        the txt file can be opened and derives the key state from its age.

        Returns True if the key is usable, otherwise False (in which case
        the state is "invalid" and the reason is stored in self.error).
        """

        self.error = ""
        try:
            self.validate_filename(path, selector)
        except Exception as err:
            self.state = "invalid"
            self.error = str(err)
            LOG.debug("Key {}, invalid filename: {}".format(path, self.error))
            return False

        if not skip_exist_check:
            candidates = (("key", self.key_path), ("txt", self.txt_path))
            for label, file_path in candidates:
                try:
                    with open(file_path, "r"):
                        pass
                except Exception as err:
                    self.state = "invalid"
                    self.error = "can't open {} file".format(label)
                    LOG.debug("Key {}, {}: {}".format(file_path, self.error, err))
                    return False

        if self.age >= self.active_period + self.expired_period:
            self.state = "dead"
        elif self.age >= self.active_period:
            self.state = "expired"
        elif self.age >= 0:
            self.state = "active"
        else:
            self.state = "future"
        LOG.debug("Key {} loaded, state {}".format(path, self.state))
        LOG.debug(vars(self))
        return True

    def create(self, domain, key_type, offset):
        """Create a DKIM key

        The key becomes valid "offset" months from the current month and
        is stored in the current working directory. An already existing,
        readable key with the same name is reused as-is and is not
        reported to the update script.
        """

        datestr = self._datestr(self.current_month + offset)
        selector = "{}-{}".format(datestr, key_type.type)
        self.key_path = os.path.abspath("{}_{}_{}.key".format(
            datestr, key_type.type, domain))
        self.txt_path = os.path.abspath("{}_{}_{}.txt".format(
            datestr, key_type.type, domain))

        LOG.debug("Asked to create key {}".format(self.key_path))
        if self.read(self.key_path, selector):
            return

        if self.dryrun:
            LOG.info("Would have created key {}".format(self.key_path))
        else:
            # The existing files (if any) could not be read as a complete
            # key pair, remove them so that both are generated anew.
            for stale_path in (self.key_path, self.txt_path):
                try:
                    os.unlink(stale_path)
                except FileNotFoundError:
                    pass
                except OSError as err:
                    LOG.debug("Can't remove {}: {}".format(stale_path, err))

            key_type.create(domain, selector, self.key_path, self.txt_path)
            LOG.info("Created key {}".format(self.key_path))

        # In a dry run no files exist, so only the name can be checked
        self.read(self.key_path, selector, skip_exist_check=self.dryrun)
        if self.state == "invalid":
            die("Created key is invalid!?")

        type(self).update_script_args.extend(
            ["add", self.domain, self.selector, self.txt_path])

    def delete(self):
        """Delete/archive a key from the file system

        The removal is always queued for the update script. Failing to
        archive or unlink a file is logged, but is not fatal.
        """

        type(self).update_script_args.extend(
            ["del", self.domain, self.selector, self.txt_path])

        for path in self.key_path, self.txt_path:
            if not os.path.exists(path):
                continue

            if self.dryrun:
                if self.key_archive:
                    LOG.info("Would have archived {}".format(path))
                else:
                    LOG.info("Would have unlinked {}".format(path))
                continue

            try:
                if self.key_archive:
                    LOG.info("Archiving {}".format(path))
                    os.rename(path, os.path.join(self.key_archive,
                                                 os.path.basename(path)))
                else:
                    LOG.info("Unlinking {}".format(path))
                    os.unlink(path)
            except Exception as err:
                # Best effort, but leave a trace of what went wrong
                LOG.warning("Failed to remove {}: {}".format(path, err))

    @classmethod
    def pending_changes(cls):
        """Return True if any key was created or deleted"""
        return len(cls.update_script_args) > 0

    @classmethod
    def run_update_script(cls):
        """Run the update_script if configured and if there are any changes

        The queued arguments are cleared afterwards, whether or not the
        script was executed.
        """
        if cls.update_script and cls.update_script_args:
            args = [cls.update_script] + cls.update_script_args

            if cls.dryrun:
                LOG.debug("Would have called key script {}".format(args))
            else:
                LOG.debug("Calling key script {}".format(args))
                try:
                    returncode = subprocess.call(args)
                except Exception as err:
                    LOG.error("Calling key script failed: {}".format(err))
                else:
                    if returncode != 0:
                        LOG.error("Key script exited with status {}".format(
                            returncode))

        cls.update_script_args = []


class KeyType:
    """Keeps track of configured key types and associated attributes

    Attributes:
        instances       list of created instances
        key_types       list of known key types

    Instances only carry a "type" attribute and, if configured, an
    "extra_args" attribute, which means that vars(instance) has the same
    layout as the dict the instance was created from.
    """
    instances = []
    key_types = []

    def __init__(self, key_type_dict):
        """Create a new key type from a dict object

        The dict must contain "type" and may contain "extra_args" (a list
        of additional command line arguments for the key generator).

        Raises ValueError if the type is missing or already defined, or
        if extra_args is not a list of strings.
        """

        key_type_name = key_type_dict.get("type")
        if key_type_name is None:
            raise ValueError("Key type definition missing key type")
        if key_type_name in self.__class__.key_types:
            raise ValueError(
                "Key type {} defined twice".format(key_type_name))

        extra_args = key_type_dict.get("extra_args")
        if extra_args is not None:
            # A plain string would otherwise be spliced into the command
            # line one character at a time by list.extend()
            if (not isinstance(extra_args, (list, tuple))
                    or not all(isinstance(arg, str) for arg in extra_args)):
                raise ValueError(
                    "Key type {} has invalid extra_args, expected a list "
                    "of strings".format(key_type_name))

        self.type = key_type_name
        if extra_args is not None:
            self.extra_args = extra_args

        self.__class__.key_types.append(key_type_name)
        self.__class__.instances.append(self)

    def create(self, domain, selector, key_path, txt_path):
        """Create a key of the given type

        Runs "rspamadm dkim_keygen", which writes the private key to
        key_path; its standard output (the DNS record) is written to
        txt_path. The script is terminated if the command fails.
        """

        args = ["rspamadm", "dkim_keygen",
                "--type", self.type,
                "--domain", domain,
                "--selector", selector,
                "--privkey", key_path]

        if hasattr(self, "extra_args"):
            args.extend(self.extra_args)

        LOG.debug("Key creation command: {}".format(args))

        with open(txt_path, "w") as tout:
            result = subprocess.run(args, text=True, stdout=tout,
                                    stderr=subprocess.PIPE)

        # stdout went to txt_path, so only stderr is available for logging
        if result.returncode != 0:
            LOG.debug("Key creation rc: {}".format(result.returncode))
            LOG.debug("Key creation stderr: {}".format(result.stderr))
            die("Key creation failed ({})".format(result.stderr))


class KeyConfig:
    """Keeps track of a set of configured keys

    Each instance corresponds to one JSON fragment file which lists the
    keys that are supposed to be in a given state.

    Attributes:
        dryrun      whether to change any files
    """
    dryrun = False

    def __init__(self, path, state):
        # keys maps a domain name to a list of Key objects
        self.keys = {}
        # File content as last read/written, used to detect changes
        self.raw_json = None
        self.path = os.path.abspath(path)
        self.state = state

    def remove_invalid(self):
        """Remove and return all keys unsuited for this config

        A key is unsuited if its state differs from the state which this
        config is meant to hold.
        """

        invalid = []
        for domain, domain_keys in self.keys.items():
            valid = []
            for key in domain_keys:
                if key.state == self.state:
                    valid.append(key)
                else:
                    invalid.append(key)
            # Replacing the value of an existing dict key is safe while
            # iterating, since the set of dict keys does not change
            self.keys[domain] = valid

        return invalid

    def get(self, domain, key_type, state, remove=False):
        """Get and optionally remove a key from this config

        Returns the first error-free key for the domain which has the
        given type and state, or None if there is no such key.
        """

        for i, key in enumerate(self.keys.get(domain, [])):
            if key.type != key_type.type:
                continue
            if key.error:
                continue
            if key.state != state:
                continue
            if remove:
                LOG.debug("Removing key {} from {} config".format(
                    key.key_path, self.state))
                # Safe, the iteration ends right after the removal
                del self.keys[domain][i]
            return key

        return None

    def add(self, key):
        """Adds a key to this config"""

        LOG.debug("Adding key {} to {} config".format(key.key_path, self.state))
        self.keys.setdefault(key.domain, []).append(key)

    def read(self):
        """Read configuration file

        A missing or unparsable file is treated as an empty config; any
        other error while reading the file terminates the script. Keys
        with invalid file names are also added (in state "invalid") so
        that they can be cleaned up later.
        """
        try:
            with open(self.path, "r") as f:
                self.raw_json = f.read()
            # The file is a fragment without the enclosing curly braces
            cfg = json.loads("{" + self.raw_json + "}").get("domain", {})
            LOG.debug("Processing file {}".format(self.path))
        except FileNotFoundError:
            LOG.debug("File {} not found".format(self.path))
            return
        except json.JSONDecodeError as err:
            LOG.error("Failed to parse {}: {}".format(self.path, err))
            return
        except Exception as err:
            die("Error reading {}: {}".format(self.path, err))

        for domain in cfg:
            LOG.debug("Configured domain {}".format(domain))

            if "selectors" not in cfg[domain]:
                LOG.debug("File {} contains no selectors".format(self.path))
                continue

            for item in cfg[domain]["selectors"]:
                if "path" not in item:
                    LOG.error("Key entry in {} missing a path".format(self.path))
                    continue

                LOG.debug("Configured key {}, selector {}".format(
                    item["path"], item.get("selector", "<unknown>")))
                key = Key()
                key.read(item["path"], item.get("selector"))
                self.add(key)

        LOG.debug("Done processing file {}".format(self.path))

    def to_json(self, include_invalid):
        """Converts the current configuration to a JSON fragment

        Note that the fragment is suitable for inclusion in other
        JSON files (i.e. in the Rspamd configuration files),
        which is why the opening/closing curly braces are removed.

        Keys in the wrong state are only listed if include_invalid is set.
        """
        json_dict = {"domain": {}}
        for domain, domain_keys in self.keys.items():
            key_list = [{"path": key.key_path, "selector": key.selector}
                        for key in domain_keys
                        if include_invalid or key.state == self.state]

            if key_list:
                json_dict["domain"][domain] = {"selectors": key_list}

        raw_json = json.dumps(json_dict, sort_keys=True, indent=4)
        # Drop the lines with the outer braces and one level of indentation.
        # Python 2 and 3 differ in the trailing whitespace, hence rstrip().
        return "\n".join(line[4:].rstrip()
                         for line in raw_json.split("\n")[1:-1])

    def dump(self, reason):
        """Logs the current configuration as a JSON fragment (debug level)"""
        fragment = self.to_json(True)
        LOG.debug("Dumping {} {} config".format(reason, self.state))
        LOG.debug(fragment)

    def write(self):
        """Writes out the current configuration as a JSON fragment

        Nothing is written if the content is unchanged or in a dry run.
        The file is replaced atomically via a temporary file in the same
        directory, and has its final permissions before it is renamed.
        """
        fragment = self.to_json(False)

        if fragment == self.raw_json:
            LOG.info("Not updating {}, no changes".format(self.path))
            return

        self.raw_json = fragment
        if self.dryrun:
            LOG.info("Would have written {}".format(self.path))
            return

        LOG.info("Writing {}".format(self.path))
        wd = os.path.dirname(self.path)
        f = tempfile.NamedTemporaryFile(mode="w", dir=wd, delete=False)
        try:
            with f:
                f.write(fragment)
            os.chmod(f.name, 0o640)
            os.rename(f.name, self.path)
        except Exception:
            # Don't leave a stray temporary file behind
            try:
                os.unlink(f.name)
            except OSError:
                pass
            raise


def die(msg):
    """Log msg as an error and terminate the script with exit status 1"""
    LOG.error(msg)
    raise SystemExit(1)


def get_args_parser():
    """Build the command line parser

    The module docstring is used as the help epilog. Note that no default
    is set for "loglevel" here, -d and -v merely store a level in it.
    """
    parser = argparse.ArgumentParser(
        prog='rspamd-dkim-keygen',
        description="Generate and rollover DKIM keys.",
        epilog=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (option strings, add_argument() keywords), in the order in which
    # the options are listed in the help output
    option_table = (
        (('-V', '--version'),
         {'action': 'version', 'version': __version__}),
        (('-d', '--debug'),
         {'help': "Enable debug output", 'action': 'store_const',
          'dest': 'loglevel', 'const': logging.DEBUG}),
        (('-v', '--verbose'),
         {'help': "Enable verbose output", 'action': "store_const",
          'dest': "loglevel", 'const': logging.INFO}),
        (('-c', '--config'),
         {'help': "Configuration file path",
          'default': "/etc/rspamd/dkim-keygen.json",
          'dest': "config_path"}),
        (('-C', '--print-config'),
         {'help': "Print a configuration file example",
          'action': "store_true", 'default': False,
          'dest': "print_config"}),
        (('-m', '--month'),
         {'help': "Set current month (for testing)", 'type': int,
          'default': 0, 'dest': "month"}),
        (("--dry-run",),
         {'help': "don't change any files", 'action': "store_true",
          'default': False, 'dest': "dryrun"}),
    )
    for option_strings, keywords in option_table:
        parser.add_argument(*option_strings, **keywords)

    return parser


def get_default_config():
    """Return the built-in settings as a dict

    The values are taken from the current Key class attributes and from
    the KeyType instances created so far; the order of the entries is the
    order in which the example configuration is printed.
    """
    defaults = dict(
        key_directory="/var/lib/rspamd/dkim",
        key_archive=Key.key_archive,
        update_script=Key.update_script,
        active_config="dkim-active.conf",
        future_config="dkim-future.conf",
        expired_config="dkim-expired.conf",
        future_period=Key.future_period,
        active_period=Key.active_period,
        expired_period=Key.expired_period,
    )
    defaults["domains"] = []
    defaults["key_types"] = list(map(vars, KeyType.instances))
    return defaults


def print_example_config():
    """Print a configuration file template to stdout and exit with status 0

    Two example key types are registered first so that they show up in
    the "key_types" entry of the template.
    """
    example_key_types = (
        {"type": "ed25519"},
        {"type": "rsa", "extra_args": ["--bits", "2048"]},
    )
    for key_type_dict in example_key_types:
        KeyType(key_type_dict)

    example = get_default_config()
    example.update(
        key_archive="/var/lib/rspamd/dkim_archive",
        update_script="/usr/local/sbin/rspamd-dkim-update",
        domains=["example.com"],
    )

    template = json.dumps(example, sort_keys=False, indent=4)
    print(template)
    sys.exit(0)


def read_config_file(path):
    """Read and validate the script configuration

    The settings found in the JSON file at path override the defaults; a
    file which can't be opened is not an error in itself. As side effects
    the working directory is changed to the key directory, the key types
    are registered and the Key class attributes are updated.

    Returns the resulting settings as a dict. Any invalid setting
    terminates the script.
    """
    settings = get_default_config()

    raw = None
    try:
        with open(path, "r") as handle:
            raw = handle.read()
    except Exception as exc:
        LOG.info(f"Can't open cfg file {path}: {exc}")

    if raw is not None:
        try:
            settings.update(json.loads(raw))
        except Exception as exc:
            die(f"Unable to parse cfg file {path}: {exc}")

    key_directory = settings["key_directory"]
    if not key_directory:
        die("Key directory not configured")
    if not os.path.isdir(key_directory):
        die(f'Key directory "{key_directory}" not found')

    # From here on, relative paths are relative to the key directory
    try:
        os.chdir(key_directory)
    except Exception as exc:
        die(f"Failed to chdir to key directory: {exc}")

    key_archive = settings["key_archive"]
    if key_archive and not os.path.isdir(key_archive):
        die(f'Key archive "{key_archive}" not found')
    Key.key_archive = key_archive

    update_script = settings["update_script"]
    if update_script and not os.access(update_script, os.X_OK):
        die(f'Update script "{update_script}" not found')
    Key.update_script = update_script

    domains = settings["domains"]
    if not (isinstance(domains, list) and len(domains) >= 1):
        die("No domains configured")

    key_types = settings["key_types"]
    if not isinstance(key_types, list) or len(key_types) == 0:
        die("No key types configured")
    for key_type_dict in key_types:
        try:
            KeyType(key_type_dict)
        except Exception as exc:
            die(f"Invalid key type definition: {exc}")

    # The periods share their names with the Key class attributes
    for period in ("future_period", "active_period", "expired_period"):
        months = settings[period]
        if not (isinstance(months, int) and months > 0):
            die(f"Invalid {period}")
        setattr(Key, period, months)

    return settings


def main():
    """Script entry point, performs one key rollover run

    Reads the settings and the three key configs (active, future and
    expired), makes sure that every domain/key type combination has a
    key to sign with (and a future key once expiry approaches), moves
    keys to the config matching their state, deletes dead/invalid keys
    and writes the configs back.

    Exits with status 2 (after running the update script, if any) when
    keys were created or deleted.
    """
    os.umask(0o027)

    parser = get_args_parser()
    parser.set_defaults(loglevel=logging.WARNING)
    args = parser.parse_args()

    # Only debug output includes the location of the log call
    location = ''
    if args.loglevel <= logging.DEBUG:
        location = ' (%(filename)s:%(lineno)s)'
    logging.basicConfig(
        level=args.loglevel,
        format='%(levelname)s' + location + ', %(asctime)s: %(message)s',
    )

    if args.print_config:
        print_example_config()

    if shutil.which("rspamadm") is None:
        die("rspamadm executable not found")

    if args.month == 0:
        now = time.localtime()
        args.month = now.tm_year * 12 + now.tm_mon
    Key.current_month = args.month

    Key.dryrun = args.dryrun
    KeyConfig.dryrun = args.dryrun

    config = read_config_file(args.config_path)

    LOG.debug(f"Current month: {Key.current_month}")

    key_configs = []
    for state in ("active", "future", "expired"):
        key_config = KeyConfig(config[state + "_config"], state)
        key_config.read()
        key_config.dump("old")
        key_configs.append(key_config)
    active_keys, future_keys, expired_keys = key_configs

    def create_key(domain, key_type, offset, purpose):
        """Create a key which becomes valid in offset months, or die"""
        try:
            key = Key()
            key.create(domain, key_type, offset)
        except Exception as err:
            die(f"Can't create {purpose} key for domain {domain}: {err}")
        return key

    def select_active_key(domain, key_type):
        """Pick (or create) the key to sign with, in order of preference"""
        # First, check for an already valid & configured key
        key = active_keys.get(domain, key_type, "active")
        if key is not None:
            LOG.info(f"Using existing key {key.key_path}")
            return key

        # Second, see if there's a future key which has become active
        key = future_keys.get(domain, key_type, "active", remove=True)
        if key is not None:
            LOG.info(f"Using future key {key.key_path}")
            active_keys.add(key)
            return key

        # Third, since expired keys should still have valid DNS records,
        # keep using an expired key while waiting for a future key to
        # become active.
        key = active_keys.get(domain, key_type, "expired")
        if key is not None:
            LOG.info(f"Using expired key {key.key_path}")
            key.state = "active"
            return key

        # As a fallback, create a brand new key
        key = create_key(domain, key_type, 0, "active")
        LOG.info(f"Using created key {key.key_path}")
        active_keys.add(key)
        return key

    LOG.debug("Handling active and future keys")
    for domain in config["domains"]:
        for key_type in KeyType.instances:
            active_key = select_active_key(domain, key_type)

            # Finally, make sure we have a future key if expiry is approaching
            if Key.active_period - active_key.age > Key.future_period:
                continue
            if future_keys.get(domain, key_type, "future") is not None:
                continue

            future_key = create_key(domain, key_type, Key.future_period,
                                    "future")
            LOG.info(f"Created future key {future_key.key_path}")
            future_keys.add(future_key)

    LOG.debug("Handling expired/invalid keys")
    invalid_keys = []
    misplaced_keys = active_keys.remove_invalid() + future_keys.remove_invalid()
    for key in misplaced_keys:
        if key.state != "expired":
            invalid_keys.append(key)
            continue
        LOG.info(f"Key {key.key_path} expired")
        expired_keys.add(key)

    LOG.debug("Handling dead/invalid keys")
    for key in invalid_keys + expired_keys.remove_invalid():
        LOG.info(f"Key {key.key_path} invalid/dead")
        key.delete()

    LOG.debug("Writing updated configuration files")
    for key_config in key_configs:
        key_config.write()

    for key_config in key_configs:
        key_config.dump("new")

    if not Key.pending_changes():
        return

    LOG.debug("Changes detected, potentially calling update_script")
    Key.run_update_script()
    sys.exit(2)


# Only perform a rollover run when executed as a script, not on import
if __name__ == '__main__':
    main()
