#!/usr/bin/python3
# SPDX-FileCopyrightText: (C) 2024 Avnet Embedded GmbH
# SPDX-License-Identifier: GPL-3.0-only

import argparse
import json
import re
import subprocess
import sys
from base64 import b64decode
from collections import defaultdict
from os import getenv
from typing import Dict, List, Set

import requests
from requests.adapters import HTTPAdapter, Retry


def get_args() -> argparse.Namespace:
    '''Parse the command-line arguments of scotty-data-gen.

    Returns:
        argparse.Namespace: namespace holding ``token`` (the Github token,
            defaulting to the ``GITHUB_PAT`` environment variable or ``''``)
            and ``output`` (the output file path, default ``data.json``).
    '''
    script_description = 'scotty-data-gen: create launcher data'
    parser = argparse.ArgumentParser(prog='scotty-data-gen',
                                     description=script_description)
    # '-t' is the conventional short form (mirroring '-o' below). The
    # historical '--t' spelling is kept so existing invocations keep working.
    parser.add_argument('-t', '--t', '--token', dest='token',
                        default=getenv('GITHUB_PAT', ''), help='Github token')
    parser.add_argument('-o', '--output', dest='output',
                        default='data.json', help='Output file')
    return parser.parse_args()


def get_api_pages(first_url: str, session: 'requests.Session', headers: dict,
                  timeout: float = 30) -> list:
    """
    Collect the JSON payload of every page of a paginated API endpoint.

    Pages are requested one after another by following the ``rel="next"`` URL
    found in each response's ``Link`` header until no such URL is left.

    Args:
        first_url (str): the first URL to call
        session (requests.Session): session object used for every request
        headers (dict): additional headers sent with every request
        timeout (float): seconds to wait for each request before giving up

    Returns:
        list: the unpaginated data; list payloads are concatenated, a dict
            payload is added as a single item

    Raises:
        requests.HTTPError: if a page is answered with an error status
    """
    link_regex = re.compile(r'<(?P<url>https[^>]+)>; rel="(?P<type>\w+)"')

    next_url = first_url
    data = []

    while next_url:
        res = session.get(next_url, headers=headers, timeout=timeout)
        # Fail loudly on HTTP errors instead of returning the error payload
        # (e.g. a "bad credentials" object) as if it were real data.
        res.raise_for_status()
        page = res.json()
        if isinstance(page, list):
            data.extend(page)
        elif isinstance(page, dict):
            # Append rather than replace, so a dict page never discards what
            # was collected from earlier pages.
            data.append(page)
        links = {
            m["type"]: m["url"] for m in link_regex.finditer(res.headers.get("Link", ""))
        }
        next_url = links.get("next")
    return data


class ElementArchSpecific():
    """Attributes of an element that are tracked separately per architecture."""

    # Names of the attributes exported by to_json(), in output order.
    _EXPORTED = ('additionalterms', 'public', 'size', 'sources', 'versions')

    def __init__(self):
        """Create an entry with every attribute at its empty value."""
        self.additionalterms: Dict[str, str] = dict()
        self.public: bool = False
        self.sources: str = str()
        self.size: int = 0
        self.versions: Set[str] = set()

    def to_json(self) -> Dict:
        """Return the exported attributes as a plain dict.

        The values are handed out as stored (not copied), so non-JSON types
        such as sets still have to be handled by the encoder.
        """
        return {field: getattr(self, field) for field in self._EXPORTED}


class ElementArchSpecificDefaultDict(defaultdict):
    """Mapping that creates an ElementArchSpecific entry for any unknown key."""

    def __missing__(self, key):
        """Store a fresh ElementArchSpecific under ``key`` and return it."""
        entry = ElementArchSpecific()
        self[key] = entry
        return entry


class Element():
    """Launcher metadata of one element, collected from its container images.

    An element is identified by the last path component of its image name.
    populate() may be called several times on the same instance (once per
    image sharing that name): per-architecture data is accumulated in
    ``arch_specific`` while the remaining attributes are overwritten by the
    most recently processed image.
    """

    # Registry namespace hosting every image inspected by this class.
    _REGISTRY = 'docker://ghcr.io/avnet-embedded/'
    # Upper bound used as the starting point when tracking the earliest
    # creation timestamp.
    _CREATED_AT_MAX = '2999-12-31T23:59:59Z'

    # Instances are mutable and compared by name, so they stay unhashable.
    # Python does the same implicitly once __eq__ is defined; this makes the
    # intent explicit.
    __hash__ = None

    def __init__(self):
        self.arch: Set[str] = set()
        self.arch_specific = ElementArchSpecificDefaultDict()
        self.capabilities: Set[str] = set()
        self.categories: Set[str] = set()
        self.conflicts: Set[str] = set()
        self.created_at: str = self._CREATED_AT_MAX
        self.dependencies_mandatory: Set[str] = set()
        self.dependencies_recommended: Set[str] = set()
        self.description: str = ''
        self.license: str = ''
        self.name: str = ''
        self.short_name: str = ''
        self.picture: str = ''
        # populate() fills this from the 'sections' label with a dict default;
        # a set of dicts cannot exist because dicts are unhashable.
        self.section: Dict = {}
        self.summary: str = ''
        self.supersededby: Set[str] = set()
        self.video: str = ''

        # deprecated
        self.public: bool = False
        self.sources: str = ''
        self.versions: Set[str] = set()
        self.size: Dict[str, int] = {}
        self.additionalterms: Dict[str, str] = {}

    def _skopeo(self, *args: str) -> object:
        """Run ``skopeo`` with the given arguments and return its parsed JSON output."""
        # An argument list (no shell) keeps the image name from ever being
        # interpreted as shell syntax.
        return json.loads(subprocess.check_output(['skopeo', *args]))

    def _get_size(self, image_name: str) -> int:
        """Return the summed ``size`` of all layers in the image's raw manifest.

        Yields 0 when the manifest has no ``layers`` entry.
        """
        image = self._skopeo('inspect', '--retry-times=3', '--raw',
                             self._REGISTRY + image_name)
        return sum(x.get('size', 0) for x in image.get('layers', []))

    def _get_tags(self, image_name: str) -> List[str]:
        """Return the ``Tags`` list reported by ``skopeo list-tags``."""
        return self._skopeo('list-tags', self._REGISTRY + image_name)['Tags']

    def _get_visibility(self, image_name: str) -> str:
        """Query the GitHub packages API for the visibility of a container package.

        The token is read from the ``GITHUB_PAT`` environment variable.
        Returns ``''`` when the response carries no ``visibility`` field.
        """
        url_name = image_name.replace('/', '%2F')
        headers = {
            "Accept": "application/vnd.github+json",
            "X-GitHub-Api-Version": "2022-11-28"
        }
        token = getenv('GITHUB_PAT')
        if token:
            # Never send a literal "Bearer None" when the variable is unset.
            headers["Authorization"] = f"Bearer {token}"
        r = requests.get("https://api.github.com/orgs/avnet-embedded/packages/container/" + url_name,
                         headers=headers,
                         # Avoid hanging forever on a stalled connection.
                         timeout=30)
        return r.json().get('visibility', '')

    def _get_oci_label(self, image_name) -> Dict:
        """Return the labels of the image, or ``{}`` if there are none."""
        if not image_name:
            return {}
        image = self._skopeo('inspect', '--retry-times=3',
                             self._REGISTRY + image_name)
        # 'Labels' can be present but null; normalise that to an empty dict so
        # callers can always use .get() on the result.
        return image.get('Labels') or {}

    def _get_from_label(self, _labels, label, default='') -> object:
        """Return the raw value of ``label``, or ``default`` when it is absent."""
        return _labels.get(label, default)

    def _get_from_label_b64(self, _labels, label, default=b'') -> bytes:
        """Return the base64-decoded value of ``label``.

        A missing or empty label is returned as-is (i.e. ``default`` or the
        empty value) without decoding.
        """
        res = self._get_from_label(_labels, label, default=default)
        if res:
            return b64decode(res)
        return res

    def _get_from_label_b64_utf8(self, _labels, label, default='', dec='utf-8') -> str:
        """Return the base64-decoded value of ``label`` as text decoded with ``dec``."""
        res = self._get_from_label_b64(_labels, label, default=default)
        if res:
            return res.decode(dec)
        return res

    def _get_from_label_json(self, _labels, label, default=None) -> object:
        """Return the base64-decoded value of ``label`` parsed as JSON.

        Args:
            _labels: label mapping of the image
            label: name of the label to read
            default: value returned when the label is absent; a new empty
                list is used when omitted
        """
        if default is None:
            # A fresh list per call, so results are never shared between
            # elements through a mutable default argument.
            default = []
        res = self._get_from_label_b64(_labels, label, default=default)
        if res:
            return json.loads(res)
        return res

    def to_json(self) -> Dict:
        """Return the element as a dict ready for JSON serialisation.

        ``section`` is only included when it is non-empty. Sets and
        per-architecture objects are left for the JSON encoder to convert.
        """
        res = {
            'additionalterms': self.additionalterms,
            'arch': self.arch,
            'arch.specific': self.arch_specific,
            'capabilities': self.capabilities,
            'categories': self.categories,
            'conflicts': self.conflicts,
            'created_at': self.created_at,
            'dependencies.mandatory': self.dependencies_mandatory,
            'dependencies.recommended': self.dependencies_recommended,
            'description': self.description,
            'license': self.license,
            'picture': self.picture,
            'public': self.public,
            'size': self.size,
            'sources': self.sources,
            'summary': self.summary,
            'supersededby': self.supersededby,
            'versions': self.versions,
            'video': self.video,
        }
        if self.section:
            res['section'] = self.section
        return res

    def populate(self, gh_info: Dict) -> None:
        """Fill the element from one package entry and its image labels.

        Does nothing when the entry has no name. Images without a
        ``com.avnet.simpleswitch.arch`` label are reported and skipped,
        leaving the element untouched.

        Args:
            gh_info: package entry; ``name`` and ``created_at`` are read, and
                ``visibility`` is used when present
        """
        image_name = gh_info.get('name', '')
        if not image_name:
            return
        _labels = self._get_oci_label(image_name)
        _arch = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.arch', [])
        if not _arch:
            print(
                f'WARNING: {image_name} has no SimpleSwitch attributes, skipping!')
            return

        self.name = image_name.split('/')[-1]
        _size = self._get_size(image_name)
        _tags = self._get_tags(image_name)
        # Prefer the visibility already present in the package entry
        # (presumably provided by the package listing - TODO confirm); only
        # fall back to a dedicated API call, which depends on GITHUB_PAT being
        # set in the environment, when it is missing.
        _visibility = gh_info.get('visibility') or self._get_visibility(image_name)
        _public = _visibility == 'public'
        _sources = self._get_from_label(
            _labels, 'org.opencontainers.image.source')

        # ISO-8601 timestamps of the same format order correctly as strings.
        self.created_at = min(self.created_at, gh_info.get('created_at', self._CREATED_AT_MAX))
        # auto calc and updated
        self.arch.update(_arch)
        for arch in _arch:
            self.arch_specific[arch].public = _public
            self.arch_specific[arch].sources = _sources
            self.arch_specific[arch].size = _size
            # Parsed per architecture so entries do not share one dict object.
            self.arch_specific[arch].additionalterms = self._get_from_label_json(
                _labels, 'com.avnet.simpleswitch.additionalterms', default={})
            self.arch_specific[arch].versions = _tags

        # overwrite get
        self.capabilities = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.capabilities')
        self.categories = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.categories')
        self.conflicts = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.conflicts')
        self.dependencies_mandatory = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.dependencies.mandatory')
        self.dependencies_recommended = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.dependencies.recommended')
        # NOTE(review): 'unicode-escape' decodes the raw bytes as Latin-1, so
        # non-ASCII UTF-8 text in the label would be garbled - confirm the
        # labels are ASCII-only.
        self.description = self._get_from_label_b64_utf8(
            _labels, 'com.avnet.simpleswitch.description', dec='unicode-escape').strip('"').replace('\\n', '\n')
        self.license = self._get_from_label(
            _labels, 'org.opencontainers.image.licenses')
        self.picture = self._get_from_label(
            _labels, 'com.avnet.simpleswitch.picture')
        self.section = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.sections', default={})
        self.summary = self._get_from_label_b64_utf8(
            _labels, 'com.avnet.simpleswitch.summary', dec='unicode-escape').strip('"')
        self.supersededby = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.supersededby')
        self.video = self._get_from_label_b64_utf8(
            _labels, 'com.avnet.simpleswitch.video').strip('"')

        # deprecated
        self.public = _public
        self.sources = _sources
        self.additionalterms = self._get_from_label_json(
            _labels, 'com.avnet.simpleswitch.additionalterms', default={})
        self.versions = _tags
        self.size.update({arch: _size for arch in _arch})

    def __eq__(self, other: object):
        """Two elements are equal when their names match."""
        if not isinstance(other, Element):
            # Let Python fall back to its default handling instead of failing
            # on a missing 'name' attribute.
            return NotImplemented
        return self.name == other.name

    @staticmethod
    def singleton(list_: List['Element'], gh_info: dict) -> 'Element':
        """Return the element of ``list_`` matching the package, creating it if needed.

        The match is done on the last path component of ``gh_info['name']``.
        When no element matches, a new empty one is appended to ``list_`` and
        returned.
        """
        # Tolerate a missing/None name the same way populate() does.
        name = (gh_info.get('name') or '').split('/')[-1]
        for item in list_:
            if item.name == name:
                return item
        item = Element()
        list_.append(item)
        return item


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also serializes sets and ElementArchSpecific objects."""

    def default(self, obj):
        """Convert values the stock encoder cannot handle.

        Sets become sorted lists and ElementArchSpecific instances become the
        dict returned by their to_json(); everything else is delegated to the
        base class.
        """
        if isinstance(obj, set):
            return sorted(obj)
        if not isinstance(obj, ElementArchSpecific):
            return super().default(obj)
        return obj.to_json()


def main() -> int:
    """Generate the launcher data file from the organisation's container packages.

    Lists the container packages of the ``avnet-embedded`` organisation,
    keeps those whose name starts with ``simpleswitch/``, populates one
    Element per distinct image name and writes all named elements as JSON to
    the requested output file.

    Returns:
        int: 0 on success (used as the process exit status). The process
            exits with status 1 beforehand when no Github token is available.
    """
    _args = get_args()

    if not _args.token:
        sys.stderr.write('A Github token needs to be specified\n')
        sys.exit(1)

    retries = Retry(
        total=5,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    headers = {
        "X-GitHub-Api-Version": "2022-11-28",
        "Authorization": f"Bearer {_args.token}",
        "Accept": "application/vnd.github.v3.repository+json",
    }

    # The context manager closes the session's pooled connections once the
    # listing is done.
    with requests.Session() as session:
        session.mount("https://", HTTPAdapter(max_retries=retries))
        all_images = get_api_pages(
            "https://api.github.com/orgs/avnet-embedded/packages?package_type=container&per_page=100",
            session=session,
            headers=headers
        )

    # .get() skips entries without a name (e.g. an error object returned by
    # the API) instead of failing with a KeyError.
    images = [
        x for x in all_images
        if isinstance(x, dict) and x.get('name', '').startswith('simpleswitch/')
    ]

    _elements = []
    for image in images:
        print(f'Checking {image["name"]}')
        Element.singleton(_elements, image).populate(image)

    # Elements that were skipped during populate() keep an empty name and are
    # left out of the output.
    with open(_args.output, 'w', encoding='utf-8') as f:
        json.dump({
            item.name: item.to_json() for item in _elements if item.name
        }, f, indent=2, cls=CustomJSONEncoder, sort_keys=True)
    return 0


if __name__ == '__main__':
    # sys.exit() instead of the exit() helper: the latter is injected by the
    # site module for interactive use and is not guaranteed to exist
    # (e.g. when Python is started with -S).
    sys.exit(main())
