#!/usr/bin/env python3

"""
Helper utility that lists and mounts AWS EFS filesystems.
"""

import os, sys, subprocess, argparse, logging, collections

import boto3, requests

# Log at INFO so the per-filesystem progress messages are visible by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Command line: the module docstring doubles as the --help description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--timeout", type=int, default=10,
                    help="seconds to wait for each mount command to finish (default: %(default)s)")
args = parser.parse_args()

def get_metadata(path, timeout=5):
    """
    Fetch a single value from the EC2 instance metadata service.

    :param path: metadata key relative to ``/latest/meta-data/``,
        e.g. ``"placement/availability-zone"``.
    :param timeout: seconds to wait for the metadata service before giving up.
    :returns: the response body decoded to ``str``.
    :raises requests.RequestException: if the service cannot be reached, does
        not answer within ``timeout``, or replies with an HTTP error status.
    """
    # NOTE(review): this is a plain unauthenticated GET (IMDSv1 style); an
    # instance configured to require IMDSv2 session tokens will reject it.
    response = requests.get("http://169.254.169.254/latest/meta-data/{}".format(path),
                            timeout=timeout)
    # Without this check the body of an error page would be handed back to the
    # caller as if it were the requested metadata value.
    response.raise_for_status()
    return response.content.decode()

# Look up this instance's availability zone and the AWS services DNS domain
# from the instance metadata service (in that order).
az, services_domain = (
    get_metadata(metadata_path)
    for metadata_path in ("placement/availability-zone", "services/domain")
)
# subnet = get_metadata("network/interfaces/macs/{}/subnet-id".format(get_metadata("mac")))

# The region name is taken to be the availability zone minus its last character.
aws_session = boto3.Session(region_name=az[:-1])
efs = aws_session.client("efs")

# Map each requested mount path to the filesystem that should be mounted there.
# A filesystem opts in by carrying a tag whose key is "mountpoint"; the tag
# value is the local path.
mountpoints = {}

filesystems = efs.describe_file_systems()["FileSystems"]
if not filesystems:
    # Checked explicitly: a `for ... else` clause runs whenever the loop ends
    # without `break`, so it would log this message on every run.
    logger.info("No EFS filesystems found")

for filesystem in filesystems:
    for tag in efs.describe_tags(FileSystemId=filesystem["FileSystemId"])["Tags"]:
        if tag["Key"] == "mountpoint":
            mountpoints[tag["Value"]] = filesystem

# Start one `mount` process per tagged filesystem so the mounts run in
# parallel, then wait for each of them in turn.
#
# rvals maps mountpoint -> exit status. Every entry starts out as a failure (1)
# and is only overwritten once the path is known to be mounted or `mount` has
# reported its own status.
mount_procs = collections.OrderedDict()
rvals = {}
for mountpoint, filesystem in mountpoints.items():
    rvals[mountpoint] = 1
    try:
        # The target directory has to exist before `mount` is started.
        if not os.path.isdir(mountpoint):
            os.makedirs(mountpoint)

        # Nothing to do if something is already mounted there; starting
        # another `mount` would stack a second mount on the same path.
        if os.path.ismount(mountpoint):
            logger.info("Path: %s is already a mountpoint ... skipping", mountpoint)
            rvals[mountpoint] = 0
            continue

        logger.info("Mounting %s on %s", filesystem["FileSystemId"], mountpoint)
        fs_url = "{az}.{fs}.efs.{region}.{domain}:/".format(az=az,
                                                            fs=filesystem["FileSystemId"],
                                                            region=efs.meta.region_name,
                                                            domain=services_domain)

        mount_procs[mountpoint] = subprocess.Popen(["mount", "-t", "nfs4", "-o",
                                                    "nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2",
                                                    fs_url, mountpoint])
    except OSError:
        # One unusable mountpoint should not prevent the remaining
        # filesystems from being mounted; it stays recorded as failed.
        logger.exception("Error while preparing to mount %s", mountpoint)

for m, p in mount_procs.items():
    try:
        logger.info("Waiting for mount of %s (timeout: %d)", m, args.timeout)
        rvals[m] = p.wait(timeout=args.timeout)
        logger.info("Mount exited with status: %d", rvals[m])
    except Exception:
        # Includes subprocess.TimeoutExpired; rvals[m] keeps its failure value.
        logger.exception("Error while mounting EFS filesystem!")
        try:
            p.terminate()
        except Exception:
            logger.exception("Caught exception terminating")

# Exit 0 only if every mountpoint succeeded. The status is clamped to 1..255
# because the OS keeps only the low 8 bits of an exit status: an unclamped sum
# such as 256 would be reported to the caller as success.
if any(rvals.values()):
    sys.exit(max(1, min(sum(rvals.values()), 255)))
sys.exit(0)
