#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2018 Procube Co., Ltd.

import docker
import os
import logging
from datetime import datetime, timedelta
from certbot import main as certbot_main
import schedule
import json
import socket

# Configure the root logger; the level is taken from the LOG_LEVEL environment
# variable and falls back to INFO when that variable is not set.
logging.basicConfig(level=os.environ.get('LOG_LEVEL', logging.INFO), format='%(asctime)-15s %(message)s')

# Length of each docker-event polling window used by runner.run().
CHECK_INTERVAL = timedelta(seconds=30)


class webserver:
  """Certificate state of one service that publishes an FQDN.

  Holds the service's label dict together with the Let's Encrypt ``live``
  paths of the certificate chain and private key for
  ``labels["published_fqdn"]``.
  """

  def __init__(self, labels):
    """Keep a reference to ``labels`` and derive the PEM file paths.

    ``labels`` is stored as-is (not copied) and is modified in place by
    ``adjust``. Raises KeyError when it has no ``published_fqdn`` entry.
    """
    self.labels = labels
    live_dir = f'/etc/letsencrypt/live/{labels["published_fqdn"]}'
    self.cert_path = f'{live_dir}/fullchain.pem'
    self.key_path = f'{live_dir}/privkey.pem'

  def get_new_certificate(self, logger):
    """Run ``certbot certonly`` for the FQDN unless both PEM files are readable.

    Failures are reported through ``logger`` and never raised to the caller.
    """
    if os.access(self.cert_path, os.R_OK) and os.access(self.key_path, os.R_OK):
      # Both files are already present; nothing to request.
      return
    try:
      args = ['--agree-tos', '--text', '--renew-by-default', '-n', '--no-eff-email',
              '--authenticator', 'certbot-pdns:auth', 'certonly', '-d', self.labels["published_fqdn"]]
      if 'CERTBOT_EMAIL' in os.environ:
        args += ['-m', os.environ['CERTBOT_EMAIL']]
      error = certbot_main.main(args)
      if error:
        logger.error(f'error occuer in certbot new: {error}')
    except Exception as e:
      logger.exception(e)

  def adjust(self, logger):
    """Synchronise the ``cert`` and ``key`` labels with the PEM files on disk.

    Returns True when ``self.labels`` was changed and therefore has to be
    pushed back to the service, False otherwise.
    """
    # Both labels are always processed (no short-circuit), certificate first.
    cert_dirty = self._sync_label('cert', self.cert_path, 'certificate', logger)
    key_dirty = self._sync_label('key', self.key_path, 'key', logger)
    return cert_dirty or key_dirty

  def _sync_label(self, name, path, what, logger):
    """Make label ``name`` mirror the file at ``path``; return True if it changed."""
    fqdn = self.labels["published_fqdn"]
    if not os.access(path, os.R_OK):
      logger.info(f'no {what} for {fqdn}')
      # Drop a stale label; only a non-empty old value counts as a change.
      return bool(self.labels.pop(name, False))
    with open(path, 'r') as pem_file:
      content = pem_file.read()
    if self.labels.get(name) == content:
      return False
    # Decide new/renew BEFORE storing the value; once the label is assigned
    # the membership test is always true and every update would read "renew".
    mode = 'renew' if name in self.labels else 'new'
    self.labels[name] = content
    logger.info(f'set {what} for {fqdn} mode = {mode}')
    return True


class runner:
  """Keeps certificate labels of docker services with a ``published_fqdn`` label up to date."""

  def __init__(self):
    """Create the docker client from the environment and a logger named after the docker host."""
    self.client = docker.from_env()
    self.proxy_id = None
    host_name = self.client.info()['Name']
    self.logger = logging.getLogger(__name__ + '@' + host_name)
    self.pending_fqnds = []

  # Does the DNS query fail when it goes through the DNS authorization server?
  # (docker) [centos@hive2 ~]$  dig @169.254.169.254 pdns.procube-demo.jp ns
  # ; <<>> DiG 9.9.4-RedHat-9.9.4-74.el7_6.2 <<>> @169.254.169.254 pdns.procube-demo.jp ns
  # ; (1 server found)
  # ;; global options: +cmd
  # ;; Got answer:
  # ;; ->>HEADER<<- opcode: QUERY, status: SERVFAIL, id: 7336
  # ;; flags: qr rd ra; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 0
  # ;; QUESTION SECTION:
  # ;pdns.procube-demo.jp.    IN  NS
  # ;; Query time: 3003 msec
  # ;; SERVER: 169.254.169.254#53(169.254.169.254)
  # ;; WHEN: Tue Sep 17 19:44:20 JST 2019
  # ;; MSG SIZE  rcvd: 38
  # So the following method cannot be used to find out whether a DNS record
  # is available or not.
  # def check_dns(self, fqdn):
  #   try:
  #     socket.gethostbyname(fqdn)
  #   except Exception:
  #     self.pending_fqnds.append(fqdn)
  #     return False
  #   return True

  def retry_pending_fqdn(self):
    """Run ``check_servers`` once as soon as one of the pending FQDNs resolves.

    Stops after the first name that resolves; errors raised by
    ``check_servers`` are logged and swallowed.
    """
    for name in self.pending_fqnds:
      try:
        socket.gethostbyname(name)
      except Exception:
        # Still not resolvable: move on to the next pending name.
        continue
      self.logger.info(f'Found new DNS entry {name}')
      try:
        self.check_servers()
      except Exception as err:
        self.logger.exception(err)
      return

  def check_servers(self):
    """Obtain certificates for every service labelled ``published_fqdn`` and push changed labels."""
    self.pending_fqnds = []
    for svc in self.client.services.list():
      svc_labels = svc.attrs.get('Spec', {}).get('Labels', {})
      if 'published_fqdn' not in svc_labels:
        continue
      svc_name = svc.attrs["Spec"]["Name"]
      fqdn = svc_labels.get("published_fqdn")
      self.logger.info(f'Upstream service {svc.id} as {svc_name} is found (fqdn = {fqdn}).')
      site = webserver(svc_labels)
      site.get_new_certificate(self.logger)
      if not site.adjust(self.logger):
        # Labels already match the files on disk.
        continue
      # Fetch the service again right before updating it.
      current = self.client.services.get(svc.id)
      self.logger.info(f'Update labels of service {svc.id} as {site.labels}')
      current.update(labels=site.labels)

  def run(self):
    """Main loop: run scheduled jobs and react to docker service update events.

    Never returns. Exceptions are not caught here (the catch-all handler
    that used to wrap the loop body is disabled), so they propagate to the
    caller.
    """
    self.logger.info('Started')
    self.check_servers()
    window_start = datetime.now()
    while True:
      schedule.run_pending()
      window_end = window_start + CHECK_INTERVAL
      update_completed = False
      # Iterate the events docker reports for the current window.
      for event in self.client.events(decode=True, since=window_start.timestamp(), until=window_end.timestamp()):
        if event.get('Type') != 'service' or event.get('Action') != 'update':
          continue
        new_state = event.get('Actor', {}).get('Attributes', {}).get('updatestate.new')
        if new_state == 'completed':
          self.logger.info(f'found update completed event={event}')
          update_completed = True
        elif new_state == 'paused':
          # sometimes update is paused due to failure or early termination of task ..
          # so force update here
          self.logger.info(f'found update paused event={event}')
          paused_id = event.get('Actor', {}).get('ID')
          if paused_id:
            self.logger.info('Update is paused so force update')
            # TODO: ensure the service is updated by me
            # TODO: prevent infinite loop update->pause->update->pause ...
            paused_service = self.client.services.get(paused_id)
            paused_service.force_update()
      if update_completed:
        self.logger.info('Receive a service update completed event.')
        self.check_servers()
      window_start = window_end

  def renew_certs(self):
    """Run ``certbot renew`` and, when it reports no error, re-check all services.

    Any exception is logged instead of being raised.
    """
    try:
      failure = certbot_main.main(['renew'])
      if not failure:
        self.check_servers()
      else:
        self.logger.error(f'error occuer in certbot renew: {failure}')
    except Exception as err:
      self.logger.exception(err)


# Module-level runner instance used by the timer callbacks and main() below.
# NOTE: runner.__init__ calls docker.from_env() and client.info(), so this line
# contacts docker when the module is loaded.
RUNNER = runner()


def renew_certs_on_timer():
  """Scheduler callback: delegate to ``RUNNER.renew_certs()``."""
  RUNNER.renew_certs()


def retry_pending_fqdn_on_timer():
  """Scheduler callback: re-check every service.

  The call to ``RUNNER.retry_pending_fqdn()`` is disabled; a full
  ``RUNNER.check_servers()`` pass is run instead.
  """
  # RUNNER.retry_pending_fqdn()
  RUNNER.check_servers()


def main():
  """Register the periodic jobs, write the certbot-pdns config file and start the runner loop.

  Raises KeyError when CERTBOT_PDNS_API_KEY or CERTBOT_PDNS_BASE_URL is not
  set in the environment.
  """
  renew_at = os.environ.get('RENEW_TIME', '01:00')
  schedule.every().day.at(renew_at).do(renew_certs_on_timer)
  schedule.every(3).minutes.do(retry_pending_fqdn_on_timer)

  # Index os.environ directly on purpose: a missing variable must raise.
  api_key = os.environ['CERTBOT_PDNS_API_KEY']
  base_url = os.environ['CERTBOT_PDNS_BASE_URL']
  pdns_settings = {'api-key': api_key, 'base-url': base_url, 'axfr-time': 5}

  with open('/etc/letsencrypt/certbot-pdns.json', 'w') as pdns_file:
    json.dump(pdns_settings, pdns_file)

  RUNNER.run()


# Start the daemon only when this file is executed as a script.
if __name__ == "__main__":
    main()
