#!/usr/bin/env python3

# Copyright (C) 2022 David Härdeman <david@hardeman.nu>
# Copyright (C) 2022 DebOps <https://debops.org/>
# SPDX-License-Identifier: GPL-3.0-only

"""BIND DNSSEC key rollover parent zone update/notification script

This script is meant to be called periodically from cron and will check
for keys which need to be added to, or removed from, the parent zone.

When such keys are detected, an external utility will be executed or
an email will be sent to the administrator, requesting manual action.
"""

import os
import re
import sys
import json
import shutil
import socket
import smtplib
import logging
import argparse
import datetime
import subprocess

# Module metadata; __version__ is what the -V/--version option reports.
__license__ = 'GPL-3.0'
__author__ = 'David Härdeman <david@hardeman.nu>'
__version__ = '1.0.0'

# Module-wide logger; level and output format are set up by the
# logging.basicConfig() call in main().
LOG = logging.getLogger(__name__)


def die(msg):
    """Terminate the script with exit status 1.

    Args:
        msg: Text to log at error level before exiting; pass None to exit
             without logging anything.
    """
    if msg is None:
        sys.exit(1)

    LOG.error(msg)
    sys.exit(1)


class ProcessResult(str):
    """Outcome of an external command.

    The object *is* the command's stripped stdout (it subclasses str), so
    callers which only want the output can treat it like any other string,
    while callers which need more detail can inspect the attributes.

    Attributes:
        stdout:     Stripped standard output of the command
        stderr:     Stripped standard error of the command
        rc:         Exit status of the command
        error_msg:  Description of the failure, None if the command succeeded
    """

    def __new__(cls, stdout='', stderr='', rc=0, error_msg=None):
        self = super().__new__(cls, stdout)
        self.stdout = str(stdout)
        self.stderr = stderr
        self.rc = rc
        self.error_msg = error_msg
        return self


def run_process(config, args, change_locale=True):
    """Runs some external program and returns its stdout

    Args:
        config:         Script configuration (used to find the log file when
                        a failure needs to be recorded)
        args:           Command and arguments, as a list of strings
        change_locale:  Run the command with LC_ALL=C so that its output is
                        not localized (needed when the output is parsed)

    Returns:
        A ProcessResult: usable as the stdout string, with the additional
        attributes stdout, stderr, rc and error_msg.

    A non-zero exit status is not fatal here; it is written to the log file
    and reported through rc/error_msg so that the caller can decide.
    """

    cmd = ' '.join(args)
    LOG.debug('Executing: "{}"'.format(cmd))
    tmp_env = os.environ.copy()
    if change_locale:
        tmp_env["LC_ALL"] = "C"

    with subprocess.Popen(args, text=True, env=tmp_env,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE) as process:
        stdout, stderr = process.communicate()
        rc = process.returncode

    # A plain dict cannot take attribute assignment, so the result is built
    # as a ProcessResult instead.
    result = ProcessResult(stdout.strip(), stderr=stderr.strip(), rc=rc)

    if rc != 0:
        msg = 'Command "{}" failed ({})'.format(cmd, result.stderr)
        LOG.debug('Command {} rc: {}'.format(cmd, rc))
        LOG.debug('Command {} stdout: {}'.format(cmd, result.stdout))
        LOG.debug('Command {} stderr: {}'.format(cmd, result.stderr))
        write_log_msg(config, msg + "\n")
        result.error_msg = msg

    return result


def run_process_get_stdout(config, args):
    """Run a command whose output must be parsed and return its stdout.

    The command is run with LC_ALL=C (change_locale=True).  Expects
    run_process() to hand back an object exposing rc, stdout and error_msg.

    Args:
        config: Script configuration, passed through to run_process()
        args:   Command and arguments, as a list of strings

    Returns:
        The stdout reported by run_process().  If the command exits with a
        non-zero status the script is terminated via die().
    """
    # The runner defined in this file is run_process(); no function named
    # do_run_process exists.
    result = run_process(config, args, change_locale=True)
    if result.rc != 0:
        die(result.error_msg)

    return result.stdout


def run_process_ok(config, args):
    """Run a command and report whether it succeeded.

    The command inherits the current locale (change_locale=False).  Expects
    run_process() to hand back an object exposing rc and error_msg.

    Args:
        config: Script configuration, passed through to run_process()
        args:   Command and arguments, as a list of strings

    Returns:
        True if the exit status was zero; otherwise the failure is logged
        at error level and False is returned.
    """
    # The runner defined in this file is run_process(); no function named
    # do_run_process exists.
    result = run_process(config, args, change_locale=False)
    if result.rc != 0:
        LOG.error(result.error_msg)
        return False

    return True


class Key:
    """Keeps track of a zone key.

    A Key is built from the "key: <id> (<algorithm>), <type>" header line
    and then fed the lines which follow it through parse_line().

    Attributes:
        alg_num:        Mapping from algorithm names to IANA assigned numbers

    See https://www.iana.org/assignments/dns-sec-alg-numbers/dns-sec-alg-numbers.xhtml.
    """
    alg_num = {
        "DSA": "3",
        "RSASHA1": "5",
        "DSA-NSEC3-SHA1": "6",
        "RSASHA1-NSEC3-SHA1": "7",
        "RSASHA256": "8",
        "RSASHA512": "10",
        "ECC-GOST": "12",
        "ECDSAP256SHA256": "13",
        "ECDSAP384SHA384": "14",
        "ED25519": "15",
        "ED448": "16"
    }

    def __init__(self, zone, line):
        """Parse a key header line; self.valid tells whether that worked."""
        self.valid = False
        self.zone = zone
        self.goal = self.ds_state = None
        self.key_id = self.key_alg = None
        self.key_alg_num = self.key_type = None

        line = re.sub('\\s+', ' ', line.strip())
        fields = line.split()
        if len(fields) != 4:
            LOG.debug('Invalid key description ({})'.format(line))
            return

        # fields[0] is the literal "key:" marker
        _, self.key_id, raw_alg, raw_type = fields

        self.key_alg = raw_alg.strip(' ,()')
        number = Key.alg_num.get(self.key_alg.upper())
        if number is None:
            LOG.debug('Unknown key alg {} for key {}'.format(self.key_alg, self.key_id))
            return
        self.key_alg_num = number

        self.key_type = raw_type.upper()
        if self.key_type not in ('KSK', 'CSK', 'ZSK'):
            LOG.debug('Unknown key type {}, key {}'.format(self.key_type, self.key_id))
            return

        self.valid = True

    def print(self):
        """Dump the key details to the debug log."""
        if not self.valid:
            LOG.debug('Invalid key')
            return

        goal = "<unknown>" if self.goal is None else self.goal
        ds_state = "<unknown>" if self.ds_state is None else self.ds_state
        rows = (
            ('ID  : {}', self.key_id),
            ('Alg : {}', self.key_alg),
            ('AlgN: {}', self.key_alg_num),
            ('Type: {}', self.key_type),
            ('Goal: {}', goal),
            ('DS  : {}', ds_state),
        )
        for template, value in rows:
            LOG.debug(template.format(value))

    def is_publishable(self):
        """True for a valid KSK/CSK whose goal and DS state are both known."""
        return bool(self.valid
                    and self.key_type.upper() in ('KSK', 'CSK')
                    and self.goal is not None
                    and self.ds_state is not None)

    def needs_publication(self):
        """True if the goal is omnipresent while the DS is only rumoured."""
        if not self.is_publishable():
            return False

        return (self.goal.lower() == 'omnipresent'
                and self.ds_state.lower() == 'rumoured')

    def needs_withdrawal(self):
        """True if the goal is hidden while the DS is still unretentive."""
        if not self.is_publishable():
            return False

        return (self.goal.lower() == 'hidden'
                and self.ds_state.lower() == 'unretentive')

    def parse_line(self, line):
        """Pick up the "- goal:" and "- ds:" states from a detail line.

        Lines which are not of the form "<param>: <value>" (exactly one
        colon) and all other parameters are ignored.
        """
        normalized = re.sub('\\s+', ' ', line.strip().lower())
        pieces = normalized.split(':')
        if len(pieces) != 2:
            return

        param, value = (piece.strip() for piece in pieces)
        known = ('hidden', 'rumoured', 'omnipresent', 'unretentive')
        state = value if value in known else None

        if param == '- goal':
            attribute, complaint = 'goal', 'Unknown key goal: {}'
        elif param == '- ds':
            attribute, complaint = 'ds_state', 'Unknown DS state: {}'
        else:
            return

        if state is None:
            LOG.debug(complaint.format(value))
        else:
            setattr(self, attribute, state)


class Zone:
    """Keeps track of a zone

    Attributes:
        name:       Zone name (e.g. "example.com")
        rr_class:   Resource record class (e.g. "IN")
        view:       Name of the view the zone is defined in
        keys:       List of Key objects, None until get_keys() has been run
    """

    def __init__(self, name, rr_class, view):
        self.name = name
        self.rr_class = rr_class
        self.view = view
        self.keys = None

    def get_keys(self, config, rndc_output=None):
        """Get the currently known keys for a zone

        The keys are parsed once and cached in self.keys.

        Args:
            config:         Script configuration, passed to run_process()
            rndc_output:    Output of "rndc dnssec -status" to parse; when
                            None, rndc is executed to obtain it.

        Returns:
            The list of Key objects (on the first call as well as on
            subsequent, cached, calls).
        """

        if self.keys is not None:
            return self.keys

        if rndc_output is None:
            args = ["rndc", "dnssec", "-status",
                    self.name, self.rr_class, self.view]
            rndc_output = run_process(config, args)

        # Entries can have the following states:
        #
        #  * hidden (not published)
        #  * rumoured (partially published)
        #  * omnipresent (completely published)
        #  * unretentive (stale)
        #
        # This is an example of the kind of output we need to parse:
        #
        # root@foobar:~# rndc dnssec -status example.com
        # dnssec-policy: kskzsk
        # current time:  Sat Jul 30 20:43:25 2022
        #
        # key: 38979 (ECDSAP256SHA256), KSK
        #   published:      yes - since Fri Jul 29 23:12:42 2022
        #   key signing:    yes - since Fri Jul 29 23:12:42 2022
        #
        #   Next rollover scheduled on Sat Jul 29 23:12:42 2023
        #   - goal:           omnipresent
        #   - dnskey:         omnipresent
        #   - ds:             rumoured
        #   - key rrsig:      omnipresent
        #
        # key: 3732 (ECDSAP256SHA256), ZSK
        #   published:      yes - since Fri Jul 29 22:58:05 2022
        #   zone signing:   yes - since Fri Jul 29 22:58:05 2022
        #
        #   Next rollover scheduled on Tue Sep 27 22:55:05 2022
        #   - goal:           omnipresent
        #   - dnskey:         omnipresent
        #   - zone rrsig:     omnipresent
        #
        # key: 52180 (ECDSAP256SHA256), KSK
        #   published:      yes - since Fri Jul 29 22:58:05 2022
        #   key signing:    yes - since Fri Jul 29 22:58:05 2022
        #
        #   Rollover is due since Fri Jul 29 23:18:42 2022
        #   - goal:           hidden
        #   - dnskey:         omnipresent
        #   - ds:             unretentive
        #   - key rrsig:      omnipresent
        #
        #
        # Note that there is:
        #
        #  * one KSK (38979) which has a goal state of omnipresent but where
        #    the DS record is rumoured (i.e. not known to be published in the
        #    parent zone)
        #
        #  * one KSK (52180) which has a goal state of hidden but where the
        #    DS record is unretentive (i.e. not known to be removed from the
        #    parent zone)

        self.keys = []
        key = None
        for line in rndc_output.splitlines():
            if line.startswith('key:'):
                # A new key header finishes the previous key (if any)
                if key is not None:
                    self.keys.append(key)
                key = Key(self, line)
            elif key is not None:
                # Lines before the first key header are not key details
                key.parse_line(line)

        if key is not None:
            self.keys.append(key)

        # Return the list on the first call too, not only for cached calls
        return self.keys

    def get_keys_to_publish(self):
        """Return the keys whose DS record needs to be added to the parent.

        Only zones of class IN are considered; an empty list is returned if
        get_keys() has not been run yet.
        """
        if self.rr_class.upper() != 'IN' or self.keys is None:
            return []

        return [key for key in self.keys if key.needs_publication()]

    def get_keys_to_withdraw(self):
        """Return the keys whose DS record needs removal from the parent.

        Only zones of class IN are considered; an empty list is returned if
        get_keys() has not been run yet.
        """
        if self.rr_class.upper() != 'IN' or self.keys is None:
            return []

        return [key for key in self.keys if key.needs_withdrawal()]

    def print_keys(self):
        """Dump all known keys to the debug log, one blank line after each."""
        if self.keys is None:
            return

        for key in self.keys:
            key.print()
            LOG.debug('')

    def print_zone(self):
        """Log a one-line summary of the zone at debug level."""
        # Argument order must follow the labels: view first, then zone name
        LOG.debug('View: {}, Zone: {}, Keys: {}'.format(
                  self.view,
                  self.name,
                  "<undefined>" if self.keys is None else len(self.keys)))

    @classmethod
    def get_zones(cls, config, path='/var/cache/bind/named_dump.db'):
        """Creates a list of zones, possibly by parsing BINDs dumpfile

        Args:
            config: Script configuration.  If it has a non-empty "zones"
                    entry, that is used: a list of "name" or "name/view"
                    strings (a single string is accepted as well).
            path:   Dump file read after running "rndc dumpdb -zones", used
                    when no zones are configured.

        Returns:
            A list of zone objects.  Configured zones always get class "IN";
            autodetected zones of other classes, or in the "_bind" view, are
            skipped.  Invalid configured zones terminate the script.
        """
        configured = config.get("zones")
        if isinstance(configured, str):
            # Iterating a bare string would yield one "zone" per character
            configured = [configured]

        if configured:
            zones = []
            for spec in configured:
                parts = spec.split("/")

                if len(parts) > 2:
                    die('Invalid zone: "{}"'.format(spec))
                elif len(parts) > 1:
                    view = parts[1].strip()
                else:
                    view = "_default"

                name = parts[0].strip()
                rr_class = "IN"

                if len(name) < 1 or len(view) < 1:
                    die('Invalid zone: "{}"'.format(spec))

                zones.append(cls(name, rr_class, view))

            return zones

        # No zones defined, time to dig in rndc dump files to autodetect them...
        args = ["rndc", "dumpdb", "-zones"]

        if not run_process_ok(config, args):
            die(None)

        # This is an example of the kind of output we need to parse:
        #
        # Note that the '(signed)' suffix may or may not be present for
        # signed zones depending on the version of BIND.
        #
        # Without views:
        # ;
        # ; Start view _default
        # ;
        # ; Zone dump of 'example.com/IN'
        # ;
        # ...
        # ;
        # ; Zone dump of 'example.net/IN (signed)'
        # ;
        # ...
        #
        # With views:
        # ;
        # ; Start view foo
        # ;
        # ; Zone dump of 'example.com/IN/foo'
        # ;
        # ...
        # ;
        # ; Start view bar
        # ;
        # ; Zone dump of 'example.net/IN/bar (signed)'
        # ;
        # ...

        zones = []
        start_view = '_default'
        with open(path) as file:
            for line in file:
                # Only the ";" comment lines carry view/zone markers
                if not line.startswith(';'):
                    continue
                line = line[1:].strip()
                if line.startswith('Start view '):
                    start_view = line[len('Start view '):]
                    continue
                elif not line.startswith('Zone dump of '):
                    continue

                line = line[len('Zone dump of '):].strip().strip("'")

                if line.endswith('(signed)'):
                    line = line[:-len('(signed)')].strip()

                parts = line.split('/')
                if len(parts) < 2 or len(parts) > 3:
                    LOG.debug('Invalid zone line ({})'.format(line))
                    continue

                name = parts[0].strip()
                rr_class = parts[1].strip()
                view = start_view

                if len(parts) == 3:
                    # An explicit view must agree with the enclosing
                    # "Start view" marker
                    view = parts[2].strip()
                    if view != start_view:
                        LOG.debug('Invalid zone view ({}, expected {})'.format(
                            view, start_view))
                        continue

                if rr_class != 'IN':
                    continue

                if view == '_bind':
                    continue

                zones.append(cls(name, rr_class, view))

        return zones


def get_args_parser():
    """Build the command line parser for the script.

    Returns:
        An argparse.ArgumentParser which understands -V/--version,
        -d/--debug, -v/--verbose, -c/--config and -C/--print-config.
    """
    parser = argparse.ArgumentParser(
        prog='debops-bind-rollkey',
        description="Publish/unpublish DNSSEC keys.",
        epilog=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument('-V', '--version',
                        action='version', version=__version__)

    # Both verbosity switches store a logging level in the same destination
    verbosity = (
        (('-d', '--debug'), "Enable debug output", logging.DEBUG),
        (('-v', '--verbose'), "Enable verbose output", logging.INFO),
    )
    for flags, description, level in verbosity:
        parser.add_argument(*flags,
                            help=description,
                            action='store_const',
                            dest='loglevel',
                            const=level)

    parser.add_argument('-c', '--config',
                        help="Configuration file path",
                        default="/etc/bind/debops-bind-rollkey.json",
                        dest="config_path")

    parser.add_argument('-C', '--print-config',
                        help="Print a configuration file example",
                        action="store_true",
                        default=False,
                        dest="print_config")

    return parser


def get_default_config():
    """Return a fresh dict holding the built-in configuration defaults.

    The sender address is derived from the fully qualified host name.
    """
    sender = "noreply@{}".format(socket.getfqdn())

    defaults = dict(
        method="log/email/external",
        log_file="/var/log/debops-bind-rollkey.log",
        email_to="root@localhost",
        email_from=sender,
        email_host="localhost",
        email_port=25,
        email_subject="BIND DNSSEC key updates",
        external_script="/usr/local/sbin/debops-bind-rollkey-action",
    )
    return defaults


def print_example_config():
    """Print an example configuration as JSON on stdout, then exit(0).

    The example consists of the defaults plus a sample "zones" list.
    """
    example = dict(get_default_config(),
                   zones=["zone1/view1", "zone2/view2", "(optional)"])
    rendered = json.dumps(example, sort_keys=False, indent=4)
    print(rendered)
    sys.exit(0)


def read_config_file(path):
    """Load the configuration, layering the JSON file over the defaults.

    An unreadable file is only noted at info level and the defaults are
    used; a file which cannot be parsed terminates the script, as does a
    configuration which lacks what the selected method needs.

    Args:
        path: Location of the JSON configuration file

    Returns:
        The resulting configuration dict.
    """
    config = get_default_config()

    raw = None
    try:
        with open(path, "r") as handle:
            raw = handle.read()
    except Exception as err:
        LOG.info("Can't open cfg file {}: {}".format(path, str(err)))
        raw = None

    if raw is not None:
        try:
            config.update(json.loads(raw))
        except Exception as err:
            die("Can't parse cfg file {}: {}".format(path, str(err)))

    def require(option, complaint):
        # Bail out if a setting needed by the chosen method is empty
        if not config[option]:
            die(complaint)

    method = config["method"]
    if not method:
        die("Method not configured")

    elif method == "log":
        require("log_file", "Log file path not configured")

    elif method == "email":
        require("email_to", "Recipient email address not configured")
        require("email_host", "SMTP host not configured")
        require("email_port", "SMTP port not configured")
        require("email_subject", "Email subject not configured")

    elif method == "external":
        require("external_script", "External script not configured")

        # The script has to exist and be executable by us
        for mode, complaint in ((os.F_OK, "External script not found"),
                                (os.X_OK, "External script not executable")):
            if not os.access(config["external_script"], mode):
                die(complaint)

    else:
        die("Invalid method configured")

    return config


def get_actionable_keys(zones):
    """Collect the keys which need a change in the parent zone.

    Each returned key is tagged with two attributes: "action" ("Publish"
    or "Withdraw") and "action_goal" ("published" or "withdrawn").  Per
    zone, keys to publish are listed before keys to withdraw.

    Args:
        zones: Iterable of zones offering get_keys_to_publish() and
               get_keys_to_withdraw()

    Returns:
        A list of the tagged keys.
    """
    actionable = []

    for zone in zones:
        batches = (
            (zone.get_keys_to_publish(), "Publish", "published"),
            (zone.get_keys_to_withdraw(), "Withdraw", "withdrawn"),
        )
        for batch, action, goal in batches:
            for key in batch:
                key.action = action
                key.action_goal = goal
                actionable.append(key)

    return actionable


def do_external(config, zones):
    """Hand every actionable key to the external script, then tell BIND.

    For each key the configured script is called with the arguments
    <action> <key id> <algorithm number> <zone> <class> <view>.  Only if
    it succeeds is "rndc dnssec -checkds" run for that key.  A summary
    line is appended to the log file at the end.

    Args:
        config: Script configuration
        zones:  Zones to process
    """
    all_succeeded = True

    for key in get_actionable_keys(zones):
        zone = key.zone

        script_cmd = [
            config["external_script"],
            key.action.lower(),
            key.key_id,
            key.key_alg_num,
            zone.name,
            zone.rr_class,
            zone.view,
        ]

        checkds_cmd = [
            "rndc", "dnssec", "-checkds",
            "-key", key.key_id,
            "-alg", key.key_alg_num,
            key.action_goal,
            zone.name,
            zone.rr_class,
        ]
        # The view is only passed to rndc for non-default views
        if zone.view != '_default':
            checkds_cmd.append(zone.view)

        # rndc is skipped when the external script failed
        if not (run_process_ok(config, script_cmd)
                and run_process_ok(config, checkds_cmd)):
            all_succeeded = False

    summary = "Done\n\n" if all_succeeded else "Some commands failed\n\n"
    write_log_msg(config, summary)


def create_message(config, keys, sep):
    """Render the report listing the pending DNSSEC key changes.

    Args:
        config: Script configuration (not used by this function)
        keys:   Keys tagged by get_actionable_keys()
        sep:    Line separator to use ("\\n" for logs, "\\r\\n" for email)

    Returns:
        The report text: script name, timestamp, one line per key (or
        "None"), framed by ruler lines and followed by an empty line.
    """
    ruler = "================================================"
    timestamp = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")

    lines = [
        "{}".format(sys.argv[0]),
        "Executed at {}".format(timestamp),
        "The following DNSSEC key changes need to be made",
        ruler,
    ]

    if not keys:
        lines.append("None")
    else:
        lines.extend(
            "{} {} {} {} IN {}".format(key.action,
                                       key.key_id,
                                       key.key_alg_num,
                                       key.zone.name,
                                       key.zone.view)
            for key in keys)

    lines.append(ruler)
    lines.append("")

    # Every line, including the trailing empty one, is terminated by sep
    return sep.join(lines) + sep


def do_email(config, zones):
    """Mail the list of pending key changes to the administrator.

    Nothing is sent when no key needs attention.

    Args:
        config: Script configuration (email_from, email_to, email_subject,
                email_host and email_port are used)
        zones:  Zones to report on
    """
    keys = get_actionable_keys(zones)
    if not keys:
        return

    sender = config["email_from"]
    recipient = config["email_to"]

    headers = "From: {}\r\nTo: {}\r\nSubject: {}\r\n\r\n".format(
        sender, recipient, config["email_subject"])
    message = headers + create_message(config, keys, "\r\n")

    with smtplib.SMTP(host=config["email_host"],
                      port=config["email_port"]) as smtp:
        # smtp.set_debuglevel(1)
        smtp.sendmail(sender, recipient, message)


def write_log_msg(config, msg):
    """Append msg to the configured log file.

    Does nothing when the "log_file" setting is empty.

    Args:
        config: Script configuration
        msg:    Text to append (written as-is, no newline is added)
    """
    log_path = config["log_file"]
    if log_path:
        with open(log_path, "a") as log:
            log.write(msg)


def do_log(config, zones):
    """Write the list of pending key changes to the log file.

    The report is written even when it is empty; in that case a debug
    message is emitted as well.

    Args:
        config: Script configuration
        zones:  Zones to report on
    """
    pending = get_actionable_keys(zones)
    write_log_msg(config, create_message(config, pending, "\n"))

    if not pending:
        LOG.debug("Nothing to do")


def main():
    """Script entry point.

    Parses the command line, sets up logging, loads the configuration,
    collects the keys of all zones, logs the pending changes and finally
    runs the configured notification method (email or external script).
    """
    parser = get_args_parser()
    parser.set_defaults(loglevel=logging.WARN)
    args = parser.parse_args()

    # Source locations are only included in the output in debug mode
    if args.loglevel <= logging.DEBUG:
        location = ' (%(filename)s:%(lineno)s)'
    else:
        location = ''
    log_format = '%(levelname)s{}, %(asctime)s: %(message)s'.format(location)
    logging.basicConfig(format=log_format, level=args.loglevel)

    if args.print_config:
        print_example_config()

    config = read_config_file(args.config_path)

    if shutil.which("rndc") is None:
        die("rndc executable not found")

    zones = Zone.get_zones(config)
    for zone in zones:
        zone.get_keys(config)
        zone.print_zone()
        zone.print_keys()

    # The log file is always written, whichever method is configured
    do_log(config, zones)

    method = config["method"]
    if method == "external":
        do_external(config, zones)
    elif method == "email":
        do_email(config, zones)


# Only run main() when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
