#!/usr/bin/env python

"""
Thin wrapper around the "terraform" command line interface (CLI) for use
with LocalStack.

The "tflocal" CLI allows you to easily interact with your local services
without having to specify the local endpoints in the "provider" section of
your TF config.
"""

import os
import re
import sys
import subprocess

# If the folder above this script contains a ".venv" directory, put that
# folder at the front of the module search path (presumably so that a local
# source checkout takes precedence over an installed package - verify).
PARENT_FOLDER = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir))
if os.path.isdir(os.path.join(PARENT_FOLDER, ".venv")):
    sys.path.insert(0, PARENT_FOLDER)

from localstack_client import config  # noqa: E402

# region used when neither the environment nor boto3 provides one
DEFAULT_REGION = "us-east-1"
# domain name used to build service-specific hostnames (e.g., for S3)
LOCALHOST_HOSTNAME = "localhost.localstack.cloud"
# hostname of the S3 endpoint (overridable via the environment)
S3_HOSTNAME = os.getenv("S3_HOSTNAME") or f"s3.{LOCALHOST_HOSTNAME}"
# whether to run terraform via os.exec instead of a subprocess
USE_EXEC = str(os.getenv("USE_EXEC")).strip().lower() in ("1", "true")
# name/path of the terraform executable to invoke
TF_CMD = os.getenv("TF_CMD") or "terraform"
# file name of the temporary providers override file
LS_PROVIDERS_FILE = os.getenv("LS_PROVIDERS_FILE") or "localstack_providers_override.tf"
# default host and port used in the generated service endpoint URLs
LOCALSTACK_HOSTNAME = os.getenv("LOCALSTACK_HOSTNAME") or "localhost"
EDGE_PORT = int(os.getenv("EDGE_PORT") or 4566)
# template of the providers override file; the "<configs>" and "<endpoints>"
# placeholders are substituted when the file is created
TF_PROVIDER_CONFIG = """
provider "aws" {
  access_key                  = "test"
  secret_key                  = "test"
  skip_credentials_validation = true
  skip_metadata_api_check     = true
  <configs>
 endpoints {
<endpoints>
 }
}
"""
# handle of the running terraform subprocess (set once it has been started)
PROCESS = None


def create_provider_config_file():
    """Generate the providers override file and return its path.

    The file contains an "aws" provider block whose endpoints point to the
    local service endpoints, plus the region and (if applicable) the S3
    path style setting.

    Raises:
        Exception: if the providers override file already exists.
    """
    # service names that differ in TF; an empty value means "drop the service"
    renamed_services = {
        "apigatewaymanagementapi": "",
        "ce": "costexplorer",
        "edge": "",
        "iotdata": "",
        "iotjobsdata": "",
        "logs": "cloudwatchlogs",
        "opensearch": "",
        "timestream": ""
    }
    # service names to be excluded (not yet available in TF)
    excluded_services = ["meteringmarketplace"]

    # collect the service names (dashes removed), skipping excluded ones
    names = [
        name.replace("-", "")
        for name in config.get_service_ports()
        if name not in excluded_services
    ]
    for current, replacement in renamed_services.items():
        if current not in names:
            continue
        names.remove(current)
        if replacement:
            names.append(replacement)
    names.sort()

    # fill in the "<endpoints>" placeholder of the template
    endpoint_lines = "\n".join(
        f'{name} = "{get_service_endpoint(name)}"' for name in names
    )
    rendered = TF_PROVIDER_CONFIG.replace("<endpoints>", endpoint_lines)

    # fill in the "<configs>" placeholder of the template
    extra_settings = [" s3_use_path_style = true"] if use_s3_path_style() else []
    extra_settings.append(f' region = "{get_region()}"')
    rendered = rendered.replace("<configs>", "\n".join(extra_settings))

    # write the temporary config file, refusing to overwrite an existing one
    target = get_providers_file_path()
    if os.path.exists(target):
        raise Exception(
            f"Providers override file {target} already exists - please delete it first"
        )
    with open(target, mode="w") as out:
        out.write(rendered)
    return target


def get_providers_file_path() -> str:
    """Return the path of the providers override file.

    The file is placed in the directory given by the first "-chdir=..."
    command line argument, or in the current directory if there is none.
    """
    flag = "-chdir="
    matches = (arg.removeprefix(flag) for arg in sys.argv if arg.startswith(flag))
    base_dir = next(matches, ".")
    return os.path.join(base_dir, LS_PROVIDERS_FILE)


def get_service_endpoint(service: str) -> str:
    """Return the endpoint URL to be used for the given service.

    A non-empty "<SERVICE>_ENDPOINT" environment variable (service name
    upper-cased, dashes replaced with underscores) takes precedence; it is
    prefixed with "http://" if it does not contain a scheme. Otherwise the
    URL is built from the service's hostname and the edge port.
    """
    # custom endpoint configured via the environment
    override_var = service.replace("-", "_").upper().strip() + "_ENDPOINT"
    override = os.environ.get(override_var, "").strip()
    if override:
        return override if "://" in override else f"http://{override}"

    # some services need specific hostnames
    special_hostnames = {
        "s3": S3_HOSTNAME,
        "mwaa": f"mwaa.{LOCALHOST_HOSTNAME}",
    }
    hostname = special_hostnames.get(service, LOCALSTACK_HOSTNAME)
    return f"http://{hostname}:{EDGE_PORT}"


def use_s3_path_style() -> bool:
    """Return whether the S3 endpoint host is "localhost" or consists of digits and dots only."""
    s3_endpoint = get_service_endpoint("s3")
    pattern = r"^[a-z]+://(localhost|[0-9.]+)(:[0-9]+)?$"
    return re.match(pattern, s3_endpoint) is not None


def get_region() -> str:
    """Return the AWS region to put into the provider config.

    Lookup order: the AWS_DEFAULT_REGION environment variable, then the
    region of the default boto3 session (only if boto3 is installed), then
    DEFAULT_REGION.
    """
    region = str(os.environ.get("AWS_DEFAULT_REGION") or "").strip()
    if region:
        return region
    try:
        # If boto3 is installed, try to get the region from local credentials.
        # Note that boto3 is currently not included in the dependencies, to
        # keep the library lightweight.
        import boto3
        region = boto3.session.Session().region_name
    except Exception:
        # best effort only - boto3 may be missing or misconfigured
        region = None
    # region_name is None if no region is configured locally; fall back to
    # the default region instead of returning None
    return region or DEFAULT_REGION


def to_bytes(obj) -> bytes:
    """Encode strings as UTF-8 bytes; any other object is returned unchanged."""
    if isinstance(obj, str):
        return obj.encode("UTF-8")
    return obj


def to_str(obj) -> str:
    """Decode bytes as UTF-8 into a string; any other object is returned unchanged."""
    if isinstance(obj, bytes):
        return obj.decode("UTF-8")
    return obj


def run_tf_exec(cmd, env):
    """Run terraform via os.exec, replacing the current process.

    No I/O handling for stdin/out/err is needed this way, but no cleanup
    logic can be performed afterwards either.
    """
    executable = cmd[0]
    os.execvpe(executable, cmd, env=env)


def run_tf_subprocess(cmd, env):
    """Run terraform in a subprocess - useful to perform cleanup logic at the end.

    The subprocess is stored in the global PROCESS so that the SIGINT handler
    can forward the signal to it. Once the subprocess has finished, this
    function exits the interpreter with the subprocess' return code.

    Args:
        cmd: the command line to run, as a list of arguments.
        env: the environment variables to run the subprocess with.

    Raises:
        SystemExit: always, carrying the return code of the subprocess.
    """
    global PROCESS

    # register signal handlers
    import signal
    signal.signal(signal.SIGINT, signal_handler)

    # pass the given environment on to the subprocess; note that stderr of
    # the subprocess is written to our stdout
    PROCESS = subprocess.Popen(
        cmd, env=env, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stdout)
    PROCESS.communicate()
    sys.exit(PROCESS.returncode)


def signal_handler(sig, frame):
    """Forward a received signal to the terraform subprocess.

    Args:
        sig: the number of the received signal.
        frame: the interrupted stack frame (unused).

    Raises:
        KeyboardInterrupt: if the signal arrives before the subprocess has
            been started.
    """
    if PROCESS is None:
        # The handler is registered (for SIGINT) before the subprocess is
        # started, so there may be nothing to forward the signal to yet.
        # Interrupt the wrapper itself, as an unhandled SIGINT would.
        raise KeyboardInterrupt
    PROCESS.send_signal(sig)


def main():
    """Entry point: create the providers override file, run terraform, clean up.

    The override file is removed again once the terraform command has
    finished. This cleanup is not reached if USE_EXEC is enabled, because
    the process is then replaced by terraform.
    """
    environment = dict(os.environ)
    command = [TF_CMD, *sys.argv[1:]]

    # create TF provider config file
    providers_file = create_provider_config_file()

    # call terraform command, either via exec or in a subprocess
    run_terraform = run_tf_exec if USE_EXEC else run_tf_subprocess
    try:
        run_terraform(command, environment)
    finally:
        os.remove(providers_file)


# only run the CLI when executed as a script, not when imported as a module
if __name__ == "__main__":
    main()
