#!/usr/bin/env python3
# SPDX-FileCopyrightText: (C) 2024 Avnet Embedded GmbH
# SPDX-License-Identifier: GPL-3.0-only

# Stop running VMs

import argparse
import json
import logging
import os
import requests
from requests.adapters import HTTPAdapter, Retry
from time import sleep

# Power state codes for which a stop request is treated as successful:
# the VM is stopped/deallocated, on its way there, or in an unknown state.
AZURE_ACCEPTABLE_POWERSTATE = [
    "PowerState/deallocated",
    "PowerState/deallocating",
    "PowerState/stopped",
    "PowerState/stopping",
    "PowerState/unknown",
]

# Make INFO-level messages of the root logger visible.
logging.getLogger().setLevel(logging.INFO)

# Shared HTTP session: HTTPS requests are retried (up to 5 times, with
# backoff) when the server answers with a throttling or server-error status.
session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504])
session.mount("https://", HTTPAdapter(max_retries=retries))


class AzureClient():
    """Class to handle connection to the Azure Resource Manager REST API.

    All requests are sent through the module-level ``session`` and carry the
    bearer token given at construction time.
    """

    # Endpoint and API version shared by every request of this client.
    API_ROOT = "https://management.azure.com"
    API_VERSION = "2021-03-01"
    # Seconds to wait for the server before giving up on a request; without
    # a timeout a stalled connection would block the script forever.
    REQUEST_TIMEOUT = 60

    def __init__(self, azure_token: str):
        """Store the HTTP headers used to authenticate every request.

        Args:
            azure_token (str): bearer token sent in the authorization header
        """
        # NOTE(review): "content_type" is not the standard "Content-Type"
        # header name; it is kept unchanged so the requests sent stay the
        # same - confirm whether the standard name was intended.
        self.headers = {
            "content_type": "application/json",
            "authorization": f"bearer {azure_token}",
        }

    @classmethod
    def _vm_url(cls, subscription_id: str, ressource_group_name: str, path: str = "") -> str:
        """Build the URL of a virtualMachines endpoint (``path`` is appended after the collection)."""
        return (
            f"{cls.API_ROOT}/subscriptions/{subscription_id}"
            f"/resourceGroups/{ressource_group_name}"
            f"/providers/Microsoft.Compute/virtualMachines/{path}"
            f"?api-version={cls.API_VERSION}"
        )

    @staticmethod
    def _power_states(statuses: list[dict]) -> list[str]:
        """Return the status codes of ``statuses`` that start with 'PowerState'."""
        codes = (entry.get('code', '') for entry in statuses)
        return [code for code in codes if code.startswith('PowerState')]

    def _get_power_states(self, subscription_id: str, ressource_group_name: str, vm_name: str) -> list[str]:
        """Query the instance view of ``vm_name`` and return its power state codes."""
        statuses = session.get(
            self._vm_url(subscription_id, ressource_group_name, f"{vm_name}/instanceView"),
            headers=self.headers,
            timeout=self.REQUEST_TIMEOUT,
        ).json().get('statuses', [])
        return self._power_states(statuses)

    def get_vms(self, subscription_id: str, ressource_group_name: str) -> list[dict]:
        """Retrieve the VMs in the specified subscription and resource group.

        Each returned VM dict gets an extra ``status`` key holding its first
        power state code, or None when the instance view reports none.

        Args:
            subscription_id (str): the subscription in which we are currently working
            ressource_group_name (str): the resource group in which we are currently working

        Returns:
            list[dict]: the VM objects as sent by Azure, with ``status`` added

        Raises:
            SystemExit: when the VM list cannot be retrieved
        """
        try:
            response = session.get(
                self._vm_url(subscription_id, ressource_group_name),
                headers=self.headers,
                timeout=self.REQUEST_TIMEOUT,
            )
            # An error answer (e.g. 401/403) has no 'value' key and would
            # otherwise be silently treated as "no VMs".
            response.raise_for_status()
            virtual_machines = response.json().get('value', [])
        except Exception:
            logging.exception(
                "Got error when requesting virtual machines from Azure, does the token have all permissions to get them?")
            raise SystemExit(1)

        for vm in virtual_machines:
            power_states = self._get_power_states(
                subscription_id, ressource_group_name, vm['name'])
            vm['status'] = power_states[0] if power_states else None

        return virtual_machines

    def stop_vm(self, subscription_id: str, ressource_group_name: str, vm: dict) -> bool:
        """Stop the selected VM

        Requests a deallocation, prints a JSON summary of the VM on stdout,
        then reads the power state back from the instance view.

        Args:
            subscription_id (str): the subscription in which we are currently working
            ressource_group_name (str): the resource group in which we are currently working
            vm (dict): the VM object as sent by Azure

        Returns:
            bool: whether the VM is stopping or stopped
        """
        vm_name = vm['name']
        logging.info(f"Stopping VM {vm_name}")
        post_action = session.post(
            self._vm_url(subscription_id, ressource_group_name, f"{vm_name}/deallocate"),
            headers=self.headers,
            timeout=self.REQUEST_TIMEOUT,
        )
        logging.info(
            f"Requested stop for VM {vm_name} - {post_action.status_code} {post_action.text}",
        )

        # Azure omits 'tags' for untagged VMs, so do not assume the key exists.
        tags = vm.get('tags') or {}
        print(json.dumps({
            "name": vm_name,
            "status": vm.get('status'),
            "gh_runner_name": tags.get('GH_RUNNER_NAME', ''),
        }))

        power_states = self._get_power_states(
            subscription_id, ressource_group_name, vm_name)
        logging.info(f"Current powerstate status of {vm_name}: {power_states}")
        return any(state in AZURE_ACCEPTABLE_POWERSTATE for state in power_states)


def get_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Three positional arguments are expected (subscription id, resource group
    name and the tag used to select VMs), plus an optional ``-t/--token``.
    """
    cli = argparse.ArgumentParser(prog="stop-vm", description="Stop all VMs in Azure")

    # Positional arguments, in the order they must appear on the command line.
    positionals = (
        ('subscription_id', "Azure Subscription ID"),
        ('ressource_group_name', "Azure Ressource Group Name"),
        ('tag', "VM tag to filter"),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, help=arg_help)

    cli.add_argument(
        '-t',
        '--token',
        default='',
        help="Azure Token, may also be passed through the environment variable AZ_BEARER",
    )

    return cli.parse_args()


def main():
    """Stop every VM of the resource group whose ROLE tag matches the requested tag.

    The Azure token is taken from the ``--token`` argument or, when that is
    empty, from the ``AZ_BEARER`` environment variable. Each matching VM gets
    up to four stop attempts, 30 seconds apart.

    Returns:
        int: 0, used as the process exit code

    Raises:
        SystemExit: when no Azure token is available
    """
    args = get_args()

    azure_token = args.token or os.getenv('AZ_BEARER')
    if not azure_token:
        logging.error(
            'Error: an Azure token is required, you may pass it as an argument or through the AZ_BEARER environment variable',
        )
        raise SystemExit(1)

    client = AzureClient(azure_token)

    virtual_machines = client.get_vms(
        args.subscription_id, args.ressource_group_name)

    # Azure omits 'tags' for untagged VMs, so do not assume the key exists.
    vm_summaries = [
        {
            "name": vm['name'],
            "status": vm['status'],
            "tags": vm.get('tags', {}),
        }
        for vm in virtual_machines
    ]
    logging.info(f"All vms status: {vm_summaries}")

    # Select the VMs whose ROLE tag matches the requested tag and stop them
    matching_vms = [
        vm for vm in virtual_machines if args.tag == vm.get("tags", {}).get("ROLE", "")
    ]

    max_attempts = 4
    retry_delay = 30  # seconds between two stop attempts
    for selected_vm in matching_vms:
        for attempt in range(1, max_attempts + 1):
            if client.stop_vm(
                    args.subscription_id, args.ressource_group_name, selected_vm):
                break
            # No point in waiting once the last attempt has failed.
            if attempt < max_attempts:
                sleep(retry_delay)
        else:
            logging.warning(
                f"VM {selected_vm['name']} did not reach a stopped state after {max_attempts} attempts",
            )
    return 0


# Script entry point. SystemExit is raised directly because the ``exit``
# builtin is injected by the ``site`` module for interactive use and is not
# guaranteed to exist (e.g. with ``python -S``).
if __name__ == "__main__":
    raise SystemExit(main())
