#!/usr/bin/env python3

# Copyright (C) 2021 David Härdeman <david@hardeman.nu>
# SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only OR Apache-2.0

"""Rspamd DKIM key update/notification script

This script is meant to be called by rspamd-dkim-keygen and will perform
various actions (log, email the administrator, or attempt an automated
update using nsupdate) when DNS entries for DKIM keys should be
added/removed.
"""

from __future__ import (print_function)
import os
import sys
import json
import socket
import logging
import smtplib
import argparse
import datetime
import subprocess
import shutil

# Keep __license__ in sync with the SPDX-License-Identifier header above.
__license__ = 'GPL-2.0-only OR GPL-3.0-only OR Apache-2.0'
__author__ = 'David Härdeman <david@hardeman.nu>'
__version__ = '1.0.1'

# Module-level logger; main() configures output via logging.basicConfig().
LOG = logging.getLogger(__name__)


def die(msg):
    """Log *msg* at ERROR level, then terminate with exit status 1."""
    LOG.error(msg)
    raise SystemExit(1)


def get_args_parser():
    """Build the command-line parser for rspamd-dkim-update.

    The module docstring is shown verbatim as the help epilog. All remaining
    positional arguments are collected into ``keys``.

    Returns:
        argparse.ArgumentParser: a parser ready for ``parse_args()``.
    """
    parser = argparse.ArgumentParser(
        prog='rspamd-dkim-update',
        description="Update DKIM configuration in DNS.",
        epilog=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument('-V', '--version', action='version',
                        version=__version__)

    # -d and -v both store into "loglevel"; whichever is given last wins.
    verbosity_flags = (
        ('-d', '--debug', "Enable debug output", logging.DEBUG),
        ('-v', '--verbose', "Enable verbose output", logging.INFO),
    )
    for short_flag, long_flag, help_text, level in verbosity_flags:
        parser.add_argument(short_flag, long_flag, help=help_text,
                            action='store_const', dest='loglevel',
                            const=level)

    parser.add_argument('-c', '--config', dest="config_path",
                        default="/etc/rspamd/dkim-update.json",
                        help="Configuration file path")
    parser.add_argument('-C', '--print-config', dest="print_config",
                        action="store_true", default=False,
                        help="Print a configuration file example")
    parser.add_argument(
        dest="keys",
        nargs="*",
        metavar='<"add"/"del"> <domain> <selector> <dns_entry_file_path>',
        help="4-tuple(s) defining the key which has been added/removed",
    )

    return parser


def get_default_config():
    """Return a fresh dict of built-in configuration defaults.

    Insertion order is the order used by --print-config. The "method" value
    lists the supported methods separated by "/"; it is not itself a valid
    method, so a configuration file is expected to override it.
    """
    return dict(
        method="log/email/nsupdate/nsupdate_tsig/nsupdate_gsstsig",
        log_file="/var/log/rspamd/rspamd-dkim-update.log",
        email_to="root@localhost",
        email_from="noreply@" + socket.getfqdn(),
        email_host="localhost",
        email_port=25,
        email_subject="Rspamd DKIM DNS updates",
        nsupdate_keyfile="/etc/rspamd/dkim_dns_key",
        nsupdate_gsstsig_princ="dnsupdate@EXAMPLE.COM",
        nsupdate_ttl=3600,
        nsupdate_server="dns.example.com",
    )


def print_example_config():
    """Print the default configuration as indented JSON and exit with status 0."""
    defaults = get_default_config()
    print(json.dumps(defaults, indent=4, sort_keys=False))
    raise SystemExit(0)


def _validate_config(config):
    """Die with a specific message if a setting required by the method is empty.

    Unknown method names are not rejected here; main() handles those.
    """
    method = config["method"]

    if not method:
        die("Method not configured")

    elif method == "log":
        if not config["log_file"]:
            die("Log file path not configured")

    elif method == "email":
        required = (
            ("email_to", "Recipient email address not configured"),
            ("email_host", "SMTP host not configured"),
            ("email_port", "SMTP port not configured"),
            ("email_subject", "Email subject not configured"),
        )
        for setting, message in required:
            if not config[setting]:
                die(message)

    elif method in ["nsupdate", "nsupdate_tsig", "nsupdate_gsstsig"]:
        if (method in ["nsupdate_tsig", "nsupdate_gsstsig"] and
                not config["nsupdate_keyfile"]):
            die("nsupdate keyfile not configured")

        if (method == "nsupdate_gsstsig" and
                not config["nsupdate_gsstsig_princ"]):
            die("nsupdate gsstsig principal not configured")


def read_config_file(path):
    """Load the JSON configuration file at *path* on top of the defaults.

    If the file cannot be opened, this is logged at INFO level and the
    defaults from get_default_config() are used unchanged. The script is
    terminated via die() in these cases:
    - the file is not valid UTF-8;
    - it is not valid JSON;
    - its top-level value is not a JSON object;
    - a setting required by the selected method is empty.

    Args:
        path: filesystem path of the JSON configuration file.

    Returns:
        dict: the merged and validated configuration.
    """
    config = get_default_config()

    try:
        # JSON exchanged between systems must be UTF-8 (RFC 8259); do not
        # depend on the locale's default encoding, which may be ASCII.
        with open(path, "r", encoding="utf-8") as f:
            data = f.read()
    except UnicodeDecodeError as err:
        # The file exists but is malformed: report it instead of silently
        # continuing with the defaults.
        die("Can't parse cfg file {}: {}".format(path, str(err)))
    except OSError as err:
        LOG.info("Can't open cfg file {}: {}".format(path, str(err)))
        data = None

    if data is not None:
        try:
            overrides = json.loads(data)
        except ValueError as err:
            die("Can't parse cfg file {}: {}".format(path, str(err)))
        if not isinstance(overrides, dict):
            die("Can't parse cfg file {}: top-level value must be a JSON "
                "object".format(path))
        config.update(overrides)

    _validate_config(config)

    # Note: invalid methods are caught in main()
    return config


def _read_dns_file(dns_file, name):
    """Read an rspamadm DNS entry file and return its TXT data for nsupdate.

    The owner name in the file must equal *name* ("<selector>._domainkey").
    Any read or parse failure terminates the script via die(). The returned
    string keeps the file's quoted chunks, joined by single spaces.
    """
    # The public key files generated by rspamadm have the following format:
    # <selector>._domainkey IN TXT ( "v=DKIM1; k=<keytype>; "
    # \t"chunk1"
    # \t"chunk2"
    # \t...
    # \t"chunkN"
    # ) ;
    #
    # The last line can also be:
    # \t"chunkN" ) ;
    #
    try:
        with open(dns_file, "r", encoding="utf-8") as f:
            contents = f.read().strip()
    except (OSError, ValueError) as err:
        die("Can't open file {}: {}".format(dns_file, str(err)))

    parts = contents.split(maxsplit=3)
    if len(parts) != 4:
        die("Invalid file {}: missing parts".format(dns_file))
    elif parts[0] != name:
        die("Invalid file {}: name mismatch".format(dns_file))
    elif parts[1].upper() != "IN":
        die("Invalid file {}: IN missing".format(dns_file))
    elif parts[2].upper() != "TXT":
        die("Invalid file {}: TXT missing".format(dns_file))

    # Drop the surrounding "( ... ) ;" and put all chunks on one line.
    # nsupdate wants the data in < 256 byte chunks, so the quoted strings
    # are kept separate rather than merged.
    rr_lines = parts[3].strip("\r\n\t ();").split("\n")
    return " ".join(line.strip("\r\n\t ") for line in rr_lines)


def get_keys(key_list):
    """Turn flat (action, domain, selector, dns_file) 4-tuples into records.

    Args:
        key_list: flat sequence of strings. main() ensures its length is a
            non-zero multiple of 4. For "add" entries, the TXT data is read
            from dns_file. For "del" entries, dns_file is not opened.

    Yields:
        dict: one dict per 4-tuple, with these keys:
        - "action": "add" or "delete"
        - "domain", "selector", "dns_file": the values as passed in
        - "name": "<selector>._domainkey"
        - "domain_name": the absolute record name
        - "rr": "<domain_name> IN TXT"
        - "rr_data": the TXT data; "" for deletes

    An unknown action terminates the script via die().
    """
    for start in range(0, len(key_list), 4):
        action, domain, selector, dns_file = key_list[start:start + 4]
        name = "{}._domainkey".format(selector)
        # Accept the domain in absolute form ("example.com.") too; the
        # record name gets exactly one trailing dot either way.
        domain_name = "{}.{}.".format(name, domain.rstrip("."))
        rr = "{} IN TXT".format(domain_name)

        rr_data = ""
        if action == "add":
            rr_data = _read_dns_file(dns_file, name)
        elif action == "del":
            # Used verbatim in the nsupdate "update delete" command and
            # shown as "DELETE" in reports.
            action = "delete"
        else:
            die("Invalid action specified")

        yield {
            "action": action,
            "domain": domain,
            "selector": selector,
            "dns_file": dns_file,
            "name": name,
            "domain_name": domain_name,
            "rr": rr,
            "rr_data": rr_data,
        }


def _nsupdate_script(config, keys):
    """Return the nsupdate command script for *keys*, ending with "send"."""
    lines = []

    if config["nsupdate_server"]:
        lines.append("server {}".format(config["nsupdate_server"]))

    if config["nsupdate_ttl"]:
        lines.append("ttl {}".format(config["nsupdate_ttl"]))

    lines.extend(
        "update {} {} {}".format(key["action"].lower(), key["rr"],
                                 key["rr_data"])
        for key in keys)

    lines.append("send")
    return "\n".join(lines) + "\n"


def do_nsupdate(config, keys):
    """Push the DKIM record changes to DNS with nsupdate.

    config["method"] selects how nsupdate authenticates:

    * "nsupdate": nsupdate is run without extra options.
    * "nsupdate_tsig": ``nsupdate -k <nsupdate_keyfile>``.
    * "nsupdate_gsstsig": ``kinit`` with the keytab in nsupdate_keyfile and
      the nsupdate_gsstsig_princ principal, then ``nsupdate -g``, then
      ``kdestroy``.

    The update script is written to nsupdate's standard input. It consists
    of optional "server"/"ttl" lines, one "update" line per key and a final
    "send". Any failure terminates the script via die().
    """
    # First, prepare the commands and make sure all executables are present
    if shutil.which("nsupdate") is None:
        die("nsupdate executable not found")

    method = config["method"]
    pre_cmd = []
    main_cmd = ["nsupdate"]
    post_cmd = []

    if method in ["nsupdate_tsig", "nsupdate_gsstsig"]:
        if not os.path.isfile(config["nsupdate_keyfile"]):
            die("keyfile {} missing".format(config["nsupdate_keyfile"]))

    if method == "nsupdate_tsig":
        main_cmd += ["-k", config["nsupdate_keyfile"]]

    elif method == "nsupdate_gsstsig":
        main_cmd += ["-g"]

        if shutil.which("kinit") is None:
            die("kinit executable not found")
        # NOTE(review): in MIT krb5's kinit, "-p" requests proxiable tickets;
        # the principal is the positional argument that follows. Confirm
        # that "-p" is intended here.
        pre_cmd = [
            "kinit", "-k",
            "-t", config["nsupdate_keyfile"],
            "-p", config["nsupdate_gsstsig_princ"]
        ]

        if shutil.which("kdestroy") is None:
            die("kdestroy executable not found")
        post_cmd = ["kdestroy"]

    # Second, prepare the list of nsupdate commands
    script = _nsupdate_script(config, keys)

    # Third, launch our commands
    if pre_cmd:
        LOG.debug("Calling nsupdate pre cmd: {}".format(pre_cmd))
        if subprocess.call(pre_cmd) != 0:
            die("Failed to call {}".format(pre_cmd))

    try:
        LOG.debug("Calling {}".format(main_cmd))
        LOG.debug("nsupdate input:\n{}".format(script))
        # nsupdate reads its commands from standard input.
        main_cmd_rc = subprocess.run(
            main_cmd, input=script, universal_newlines=True).returncode
    finally:
        # Note: we always want to call post_cmd, even if main_cmd failed
        # or could not be started at all.
        if post_cmd:
            LOG.debug("Calling nsupdate post cmd: {}".format(post_cmd))
            if subprocess.call(post_cmd) != 0:
                die("Failed to call {}".format(post_cmd))

    if main_cmd_rc != 0:
        die("Failed to call {}".format(main_cmd))


def create_message(config, keys, sep):
    """Render a human-readable report of the pending DKIM DNS changes.

    Args:
        config: not used by this function.
        keys: records as yielded by get_keys(); only "action", "rr" and
            "rr_data" are read.
        sep: the line separator to emit (e.g. a newline or CRLF).

    Returns:
        str: the report, which contains:
        - the script path and the local execution time;
        - a header and an "=" rule;
        - one "ACTION rr rr_data" line per key, with consecutive key
          lines separated by a blank line;
        - a closing rule followed by two empty lines.
    """
    rule = "========================================================="
    timestamp = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")

    chunks = [
        sys.argv[0], sep,
        "Executed at ", timestamp, sep,
        "The following DKIM key changes need to be made in the DNS", sep,
        rule,
    ]
    for key in keys:
        entry = "{} {} {}".format(key["action"].upper(), key["rr"],
                                  key["rr_data"])
        chunks.extend((sep, entry, sep))
    chunks.extend((rule, sep, sep, sep))

    return "".join(chunks)


def do_email(config, keys):
    """Email the pending DKIM DNS changes to the configured recipient(s).

    The message is built with the email package, which provides:
    - the Date header required by RFC 5322;
    - correct encoding of non-ASCII header values and body text;
    - long key lines wrapped safely.

    email_to may be a single address, a comma-separated list, or a JSON list
    of addresses. Every listed recipient receives the message.
    config["email_from"] is used as the envelope sender.
    """
    # Local imports keep this function self-contained; only it needs them.
    from email.message import EmailMessage
    from email.utils import formatdate

    recipients = config["email_to"]
    if not isinstance(recipients, str):
        # A JSON list of addresses is also accepted.
        recipients = ", ".join(recipients)

    msg = EmailMessage()
    msg["From"] = str(config["email_from"])
    msg["To"] = recipients
    msg["Subject"] = str(config["email_subject"])
    msg["Date"] = formatdate(localtime=True)
    # smtplib converts line endings to CRLF when the message is sent.
    msg.set_content(create_message(config, keys, "\n"))

    with smtplib.SMTP(host=config["email_host"],
                      port=config["email_port"]) as smtp:
        # smtp.set_debuglevel(1)
        smtp.send_message(msg, from_addr=config["email_from"])


def do_log(config, keys):
    """Append the change report for *keys* to config["log_file"].

    Nothing is written when no log file is configured (empty or None).
    """
    log_path = config["log_file"]
    if log_path:
        report = create_message(config, keys, "\n")
        with open(log_path, "a") as log_fh:
            log_fh.write(report)


def main():
    """Command-line entry point.

    Steps:
    1. Parse the arguments and configure logging.
    2. Load the configuration and build the key records.
    3. Dispatch to the configured method.
    4. Always write the report to the log file (if one is configured).
    """
    parser = get_args_parser()
    parser.set_defaults(loglevel=logging.WARNING)
    args = parser.parse_args()

    # Source locations are only shown in debug mode.
    location = (' (%(filename)s:%(lineno)s)'
                if args.loglevel <= logging.DEBUG else '')
    logging.basicConfig(
        level=args.loglevel,
        format='%(levelname)s' + location + ', %(asctime)s: %(message)s',
    )

    if args.print_config:
        print_example_config()

    count = len(args.keys)
    if count < 4:
        die("At least 4 positional arguments are necessary")
    if count % 4:
        die("A multiple of 4 positional arguments are necessary")

    config = read_config_file(args.config_path)
    keys = list(get_keys(args.keys))

    method = config["method"]
    if method == "email":
        do_email(config, keys)
    elif method in ("nsupdate", "nsupdate_tsig", "nsupdate_gsstsig"):
        do_nsupdate(config, keys)
    elif method != "log":
        die("Invalid method configured")

    do_log(config, keys)


if __name__ == "__main__":
    # Run only when executed as a script, never on import.
    main()
