#!/usr/bin/python3
# SPDX-FileCopyrightText: (C) 2022 Avnet Embedded GmbH
# SPDX-License-Identifier: GPL-3.0-only

'''(Github) test report creator'''

__copyright__ = 'Copyright (C) 2022 Avnet Embedded GmbH'
__license__ = 'GPLv3'

import argparse
import copy
import json
import logging
import multiprocessing as mp
import os
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
from functools import partial
from typing import Dict, List

# Extracts the file path, line number and reported word from spell-check
# warnings of the form
#   WARNING: <path>:<line>: ... Spell check: <typo>: ...
# via the named groups ``path``, ``line`` and ``typo``.
r_spellcheck = re.compile(
    r'WARNING:\s+(?P<path>[A-Za-z0-9_/.-]+):(?P<line>\d+):.*Spell\scheck:\s+(?P<typo>.*):.*',
    re.MULTILINE,
)

# Send log records of level INFO and above to stderr.
logging.basicConfig(level=logging.INFO, stream=sys.stderr)


def get_args() -> argparse.Namespace:
    '''Build the command-line interface and parse ``sys.argv``.

    Global ``--with*`` options name optional inputs for extra report
    sections. A subcommand is mandatory: ``pr`` (comment on a pull request)
    or ``release`` (write release-level reports into an output directory).
    '''
    parser = argparse.ArgumentParser(
        prog='testreport',
        description='testreport: create a testreport as a comment')

    # Optional report inputs, accepted in front of either subcommand.
    optional_inputs = (
        ('--withcoverage', 'add recipe coverage report'),
        ('--withdeadlinks', 'add documentation dead links report'),
        ('--withqalogdir', 'Add QA logs from directory'),
        ('--withspellcheck', 'Report from documentation spell check'),
        ('--withshadoweddir', 'Directory with results from shadowed recipe search'),
    )
    for flag, text in optional_inputs:
        parser.add_argument(flag, default='', help=text)

    subparsers = parser.add_subparsers(
        title="subcommand", required=True, dest="subcommand")

    # Pull-request flavour: GitHub context and token default to the environment.
    pr = subparsers.add_parser('pr', help='pull request level test report')
    pr.add_argument('--srcdir', default='', help=argparse.SUPPRESS)
    pr.add_argument('--github', default=os.environ.get('GITHUB'),
                    help='Github pipeline json')
    pr.add_argument('--token', default=os.environ.get('GITHUB_TOKEN'),
                    help='Access Token for Github')
    pr.add_argument('dir', help='JUnit input files directory')

    # Release flavour: reports are written below ``output``.
    release = subparsers.add_parser('release', help='release level test report')
    release.add_argument('--testnames', default=None,
                         help='Store testname to...')
    release.add_argument('--rc', action='store_true',
                         help='Reset all manual testing reports')
    release.add_argument('dir', help='JUnit input files directory')
    release.add_argument('output', help='Output directory')

    return parser.parse_args()


def find_files_and_folders(args: argparse.Namespace) -> List[str]:
    test_reports = set()
    for root, _, files in os.walk(args.dir):
        for f in files:
            fullpath = os.path.join(root, f)
            if not fullpath.endswith('_testresults.json'):
                continue
            test_reports.add(fullpath)
    return sorted(test_reports)


class TestResultImageEnvironment():
    '''Description of the variant (machine, baseboard, feature, display
    outputs, version) that an image was tested on.

    Equality and hashing consider baseboard, machine, feature and the LVDS /
    HDMI flags only; ``version`` is not part of them. Ordering compares the
    text form, which does include the version. Unknown keyword arguments are
    accepted and ignored.
    '''

    def __init__(self,
                 machine: str = '',
                 baseboard: str = '',
                 feature: str = '',
                 hasLVDS: bool = False,
                 hasHDMI: bool = False,
                 version: str = '',
                 **kwargs):
        self.machine = machine
        self.baseboard = baseboard
        self.feature = feature
        self.hasLVDS = hasLVDS
        self.hasHDMI = hasHDMI
        self.version = version

    def _identity(self) -> tuple:
        # Fields that decide whether two environments are "the same".
        return (self.baseboard, self.machine, self.feature,
                self.hasLVDS, self.hasHDMI)

    def __repr__(self):
        outputs = f'[LVDS:{self.hasLVDS}, HDMI:{self.hasHDMI}]'
        return f'{self.machine}:{self.feature} on {self.baseboard}{outputs}@{self.version}'

    def __eq__(self, other) -> bool:
        return isinstance(other, TestResultImageEnvironment) and \
            self._identity() == other._identity()

    def __ne__(self, other) -> bool:
        return not self.__eq__(other)

    def __lt__(self, other):
        # Sort by the human-readable representation.
        return str(self) < str(other)

    def __hash__(self):
        return hash(self._identity())


class TestResultCase():
    '''One test outcome (name, description, status, optional log message)
    together with the environment it was executed in.'''

    def __init__(self, name: str, description: str, status: str, environment: dict, msg: str = None) -> None:
        self.name = name
        self.description = description
        self.status = status
        self.msg = msg
        # The raw environment mapping is expanded into keyword arguments.
        env_args = environment
        self.environment: TestResultImageEnvironment = TestResultImageEnvironment(**env_args)

    def failure_environment_to_str(self) -> str:
        return '@' + str(self.environment)

    def _fields(self) -> tuple:
        # Everything that makes two cases equal.
        return (self.name, self.description, self.status, self.msg, self.environment)

    def __eq__(self, other) -> bool:
        if not isinstance(other, TestResultCase):
            return False
        return self._fields() == other._fields()

    def __ne__(self, other) -> bool:
        return not self.__eq__(other)

    def __hash__(self):
        return hash((self.name, self.description, self.status, self.environment, self.msg))


class TestResultImage():
    '''All test cases recorded for one image, plus pass/skip/fail counters.'''

    def __init__(self, name: str):
        self.tests: List[TestResultCase] = []
        self.skipped = 0
        self.passed = 0
        self.failed = 0
        self.name = name

    def add_new_case(self, name: str, description: str, status: str, environment: dict, msg: str = None):
        '''Record one test case and bump the matching counter.

        Any status other than ``PASSED`` or ``SKIPPED`` is counted as failed.
        '''
        if status == 'PASSED':
            self.passed += 1
        elif status == 'SKIPPED':
            self.skipped += 1
        else:
            self.failed += 1
        self.tests.append(TestResultCase(
            name, description, status, environment, msg))

    def all_environment_to_str(self) -> str:
        '''Return a Markdown "Tested variants" list of the distinct
        environments, or '' when there are no tests.'''
        environments = {test.environment for test in self.tests}
        if not environments:
            return ''
        listing = ''.join(f'- {item}\n' for item in sorted(environments))
        return f'## Tested variants\n\n{listing}\n'

    def has_fails(self) -> bool:
        return self.failed > 0

    def failures(self) -> List['TestResultCase']:
        '''Return the cases whose status is ``FAILED`` or ``ERROR``.'''
        return [x for x in self.tests if x.status in ('FAILED', 'ERROR')]

    def merge(self, other: 'TestResultImage') -> None:
        '''Add the cases and counters of ``other`` to this image.'''
        # Rebind instead of extending in place: an image obtained through
        # copy.copy() shares its ``tests`` list with the original, and an
        # in-place ``+=`` would silently append to the original image too.
        self.tests = self.tests + other.tests
        self.passed += other.passed
        self.skipped += other.skipped
        self.failed += other.failed

    def mapped_by_name(self) -> Dict[str, List['TestResultCase']]:
        '''Group the cases by test name, keeping first-seen order.'''
        res = {}
        for test in self.tests:
            res.setdefault(test.name, []).append(test)
        return res


class TestResultArch():
    '''Test results of one architecture, grouped into images by name.'''

    def __init__(self, arch: str):
        self.images: List[TestResultImage] = []
        self.arch = arch

    def get_image_or_new(self, imgname: str) -> 'TestResultImage':
        '''Return the image called ``imgname``, creating and registering it if absent.'''
        for img in self.images:
            if imgname == img.name:
                return img
        new_img = TestResultImage(imgname)
        self.images.append(new_img)
        return new_img

    def has_fails(self):
        return any(x.has_fails() for x in self.images)

    def get_coverage(self) -> tuple[int, List['TestResultCase'], List['TestResultCase']]:
        '''Summarise all images of this arch per unique test name.

        Returns ``(good, failed, skipped)``. A name is *failed* if any of its
        runs is ``FAILED``/``ERROR`` and *skipped* if every run is
        ``SKIPPED``. The lists hold the first case seen for each name, and
        ``good`` counts the remaining names.

        The images are only read, never modified.
        '''
        # Group directly instead of merging shallow copies of the images,
        # which used to append foreign tests to the first image's list.
        by_name = {}
        for img in self.images:
            for test in img.tests:
                by_name.setdefault(test.name, []).append(test)

        failed = []
        skipped = []
        for cases in by_name.values():
            if any(case.status in ('FAILED', 'ERROR') for case in cases):
                failed.append(cases[0])
            elif all(case.status == 'SKIPPED' for case in cases):
                skipped.append(cases[0])

        return (len(by_name) - len(failed) - len(skipped), failed, skipped)

    def merge(self, other: 'TestResultArch') -> None:
        '''Merge every image of ``other`` into the same-named image here.'''
        for image in other.images:
            self.get_image_or_new(image.name).merge(image)


class TestResult():
    '''Top-level container: test results grouped by architecture.'''

    def __init__(self) -> None:
        self.archs: List[TestResultArch] = []

    def get_arch_or_new(self, archname: str) -> 'TestResultArch':
        '''Return the arch called ``archname``, creating and registering it if absent.'''
        for arch in self.archs:
            if archname == arch.arch:
                return arch
        new_arch = TestResultArch(archname)
        self.archs.append(new_arch)
        return new_arch

    def has_fails(self):
        return any(x.has_fails() for x in self.archs)

    def get_coverage(self) -> tuple[int, List['TestResultCase'], List['TestResultCase']]:
        '''Summarise every image of every arch per unique test name.

        Returns ``(good, failed, skipped)``. A name is *failed* if any of its
        runs is ``FAILED``/``ERROR`` and *skipped* if every run is
        ``SKIPPED``. The lists hold the first case seen for each name, and
        ``good`` counts the remaining names.

        The images are only read, never modified.
        '''
        # Group directly instead of merging shallow copies of the images,
        # which used to append foreign tests to the first image's list.
        by_name = {}
        for arch in self.archs:
            for img in arch.images:
                for test in img.tests:
                    by_name.setdefault(test.name, []).append(test)

        failed = []
        skipped = []
        for cases in by_name.values():
            if any(case.status in ('FAILED', 'ERROR') for case in cases):
                failed.append(cases[0])
            elif all(case.status == 'SKIPPED' for case in cases):
                skipped.append(cases[0])

        return (len(by_name) - len(failed) - len(skipped), failed, skipped)

    def merge(self, other: 'TestResult') -> None:
        '''Merge every arch of ``other`` into the same-named arch here.'''
        for arch in other.archs:
            self.get_arch_or_new(arch.arch).merge(arch)


def parse_report(report: str):
    '''Load one ``*_testresults.json`` file into a new TestResult.

    ``configuration.MACHINE`` selects the arch and
    ``configuration.IMAGE_BASENAME`` the image. Every entry of ``result``
    becomes a test case carrying the document's ``environment`` plus
    ``configuration.VERSION``. Logs longer than 1024 characters are truncated.
    Missing or ``null`` sections are treated as empty.
    '''
    res = TestResult()
    logging.info(f'Open {report}')
    # JSON is UTF-8 by specification; don't depend on the locale's encoding.
    with open(report, encoding='utf-8') as i:
        document = json.load(i)

    configuration = document.get('configuration') or {}
    machine = configuration.get('MACHINE', '')
    imagename = configuration.get('IMAGE_BASENAME', '')
    version = configuration.get('VERSION', '')
    image = res.get_arch_or_new(machine).get_image_or_new(imagename)

    # Work on a copy so the loaded document is left untouched.
    environment = dict(document.get('environment') or {})
    environment['version'] = version

    for name, v in (document.get('result') or {}).items():
        status = v.get('status', '')
        logging.info(f'{report}: result {name}')
        description = v.get('description', '')
        msg = v.get('log', None)
        if msg is not None and len(msg) > 1024:
            msg = msg[:1024] + '...'
        image.add_new_case(name, description, status, environment, msg)
    return res


def parse_reports(reports: List[str]) -> TestResult:
    '''Parse every report file in a process pool and merge the results.

    An empty ``reports`` list yields an empty TestResult without starting
    a pool.
    '''
    merged = TestResult()
    if not reports:
        return merged

    worker_count = min(mp.cpu_count(), len(reports))
    with mp.Pool(processes=worker_count) as pool:
        try:
            partials = pool.map(parse_report, reports)
        finally:
            pool.close()
            pool.join()

    for partial_result in partials:
        merged.merge(partial_result)
    return merged


def get_coverage(args: argparse.Namespace) -> str:
    '''Return the text of the ``--withcoverage`` file, or '' if none was given.'''
    path = args.withcoverage
    if not path:
        return ''
    with open(path) as report:
        text = report.read()
    return text


def get_deadlinks(args: argparse.Namespace) -> str:
    '''Return the text of the ``--withdeadlinks`` file, or '' if none was given.'''
    if not args.withdeadlinks:
        return ''
    with open(args.withdeadlinks) as report:
        return report.read()


def get_qalog(args: argparse.Namespace) -> str:
    '''Collect the non-empty QA log files below ``args.withqalogdir``.

    Each non-empty file becomes a ``### <filename>`` subsection, in sorted
    directory/file order so the output is stable between runs. Returns ''
    when no directory was given or no file had content. This lets the
    caller's emptiness check decide whether to add the section at all.
    '''
    if not args.withqalogdir:
        return ''
    sections = []
    for root, dirs, files in os.walk(args.withqalogdir):
        # Sorting ``dirs`` in place makes os.walk descend in sorted order.
        dirs.sort()
        for name in sorted(files):
            with open(os.path.join(root, name)) as handle:
                content = handle.read()
            if content:
                sections.append(f'### {name}\n\n{content}\n')
    if not sections:
        return ''
    return '## Build QA logs\n\n' + ''.join(sections)


def get_shadowedresults(args: argparse.Namespace) -> str:
    '''Collect shadowed-recipe findings from all files below ``args.withshadoweddir``.

    Every non-blank line (whitespace-stripped) is one entry, and duplicates
    across files are merged. Returns '' when nothing was found. Otherwise it
    returns a Markdown section listing the entries in sorted order.
    '''
    entries = set()
    if args.withshadoweddir:
        for root, _, files in os.walk(args.withshadoweddir):
            for name in files:
                with open(os.path.join(root, name)) as handle:
                    entries.update(line.strip() for line in handle if line.strip())
    logging.info(f'shadowed result: {entries}')
    if not entries:
        return ''
    # Sort so the report is identical between runs; set iteration order of
    # strings varies with hash randomization.
    listing = ''.join(f'- {item}\n' for item in sorted(entries))
    return f'## Recipes provided by other layer\n\n{listing}\n'


def get_spellcheck(args: argparse.Namespace) -> str:
    '''Render the spell-check findings from ``args.withspellcheck`` as Markdown.

    Returns '' when no spell-check log was given. Otherwise it returns a
    "Spell check" section with one bullet per ``r_spellcheck`` match, or a
    single "no spelling issues" line when nothing matched.
    '''
    if not args.withspellcheck:
        return ''
    with open(args.withspellcheck) as log:
        text = log.read()
    bullets = [
        f'- {hit.group("path")}:{hit.group("line")}: {hit.group("typo")}\n'
        for hit in re.finditer(r_spellcheck, text)
    ]
    body = ''.join(bullets) if bullets else '✅ no spelling issues found\n'
    return '## Spell check\n\n' + body + '\n'


def create_markdown_report(results: 'TestResult') -> str:
    '''Render the pull-request comment body.

    For every arch/image this emits a passed/skipped/failed summary table.
    Each failing case follows as its own one-row table plus a fenced block
    with its log. The block is empty when the case carried no log.
    '''
    parts = ['**Test Report**\n\n']
    for arch in results.archs:
        for image in arch.images:
            parts.append(f'| Testsuite | {arch.arch}: {image.name} | |\n')
            parts.append('| - | - | - |\n')
            parts.append('| ✅ Passed | ⏩ Skipped | ❌ Failed |\n')
            parts.append(f'| {image.passed} | {image.skipped} | {image.failed} |\n')
            parts.append('\n\n')
            for test in image.failures():
                # A table header per failure: the fenced log block below
                # terminates the table, so each failure needs its own.
                parts.append('| Failed cases |\n')
                parts.append('| - |\n')
                parts.append(f'| 🔎 {test.name} |\n')
                # msg is None when the result had no 'log'; render it empty
                # instead of raising TypeError on str + None.
                parts.append('```\n' + (test.msg or '') + '\n```\n')
                parts.append('\n\n')
    return ''.join(parts)


def _create_coverage_report(in_: tuple[int, List[TestResultCase], List[TestResultCase]]) -> str:
    '''Render one coverage summary as reStructuredText.

    ``in_`` is a ``(good, failed, skipped)`` triple as returned by the
    ``get_coverage`` methods. The output is a ``.. plot::`` pie chart of the
    three counts, followed by collapsible lists of failed and skipped test
    names. Each list is omitted when empty.
    '''
    good, failed, skipped = in_
    chunks = [textwrap.dedent(f'''
    .. plot::
        :include-source: false
        :show-source-link: false
        :context: close-figs

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        ax.pie([{good}, {len(failed)}, {len(skipped)}],
               explode=(0.3, 0.2, 0.1),
               shadow=True,
               autopct="%1.0f%%",
               labels=["Passed", "Failed", "Skipped"])
        plt.show()

    ''')]

    dropdowns = (
        ('Failed tests', 'warning', failed),
        ('Skipped tests', 'info', skipped),
    )
    for title, color, cases in dropdowns:
        if not cases:
            continue
        chunks.append(f'.. dropdown:: {title}\n')
        chunks.append(f'    :color: {color}\n\n')
        chunks.extend(f'    - {case.name}\n' for case in cases)
        chunks.append('\n')
    return ''.join(chunks)


def create_report_coverage(results: TestResult, base_path: str):
    '''Write ``<base_path>/_statistics.rst.tpl``.

    The file holds a "Global" coverage dropdown (open by default) for all
    results, followed by one collapsed dropdown per architecture.
    '''
    indent = '    '
    global_summary = _create_coverage_report(results.get_coverage())
    sections = [
        '.. dropdown:: Global\n'
        '    :color: muted\n'
        '    :open:\n\n'
        + textwrap.indent(global_summary, indent)
    ]
    for arch in results.archs:
        header = f'.. dropdown:: {arch.arch}\n' + '    :color: muted\n\n'
        sections.append(header + textwrap.indent(_create_coverage_report(arch.get_coverage()), indent))

    output = os.path.join(base_path, '_statistics.rst.tpl')
    os.makedirs(os.path.dirname(output), exist_ok=True)
    with open(output, 'w') as o:
        o.write(''.join(sections))


def _image_test_status(tests: list, footnote_map: list) -> str:
    '''Return the "Result" cell for all runs of one test name.

    For a mixed outcome, each failing environment's text is appended to
    ``footnote_map`` (once). The returned cell references the matching
    ``[^n]`` footnotes. ``FAILED`` and ``ERROR`` are both treated as failing,
    matching ``TestResultImage.failures``.
    '''
    failing = ('FAILED', 'ERROR')
    if all(x.status in ('PASSED', 'SKIPPED') for x in tests):
        return '✅ Passed'
    if all(x.status in failing + ('SKIPPED',) for x in tests):
        return '❌ All Failed'
    # Mixed outcome: footnote every environment the test failed in.
    footnotes = set()
    for item in tests:
        if item.status not in failing:
            continue
        env = item.failure_environment_to_str()
        if env not in footnote_map:
            footnote_map.append(env)
        footnotes.add(f'[^{footnote_map.index(env) + 1}]')
    return f'⚠️ Failed for {" ".join(sorted(footnotes))}'


def _render_image_report(arch_name: str, image, testnames_map: dict) -> str:
    '''Render the Markdown/MyST report for one image.

    Also records each test's cleaned description in ``testnames_map``
    unless a non-empty one is already present.
    '''
    footnote_map = []
    parts = [
        f'# {arch_name}: {image.name} test report\n\n',
        # generate header
        image.all_environment_to_str(),
        # generate result table
        ':::{table} Results\n',
        ':class: datatable\n',
        '| Test | Result | Description |\n',
        '| - | - | - |\n',
    ]
    for name, tests in image.mapped_by_name().items():
        clean_name = name.replace('.', ' > ')
        # '|' would break the table cell, newlines the table row.
        clean_dsc = tests[0].description.replace('|', '-').replace('\n', '<br/>')

        if not testnames_map.get(name):
            testnames_map[name] = clean_dsc

        # Tests that were skipped everywhere are left out of the table.
        if all(x.status == 'SKIPPED' for x in tests):
            continue

        status = _image_test_status(tests, footnote_map)
        parts.append(f'| {clean_name} | {status} | {clean_dsc} |\n')

    parts.append('\n\n')
    parts.append(':::\n')
    parts.extend(f'[^{index}]: {value}\n' for index, value in enumerate(footnote_map, start=1))
    parts.append('\n')
    return ''.join(parts)


def create_detailed_markdown_report(results: TestResult, base_path: str, testnames: str):
    '''Write the release-level reports below ``base_path``.

    Produces the coverage statistics, one ``<base_path>/<arch>/<image>.md``
    per image and, when ``testnames`` is set, a JSON map of test name to
    cleaned description at that path.
    '''
    testnames_map = {}

    create_report_coverage(results, base_path)

    for arch in results.archs:
        for image in arch.images:
            logging.info(f'Create release report {arch.arch}:{image.name}')
            res = _render_image_report(arch.arch, image, testnames_map)

            # write output
            output = os.path.join(base_path, arch.arch, f'{image.name}.md')
            os.makedirs(os.path.dirname(output), exist_ok=True)
            with open(output, 'w') as o:
                o.write(res)

    if testnames:
        testnames = os.path.abspath(testnames)
        os.makedirs(os.path.dirname(testnames), exist_ok=True)
        with open(testnames, 'w') as o:
            json.dump(testnames_map, o, sort_keys=True)


def send_comment(args: argparse.Namespace, input: Dict, url: str) -> bool:
    '''POST ``input`` as JSON to ``url`` with curl, authenticating with ``args.token``.

    Returns True when ``url`` is empty (nothing to send) or the request
    succeeded. Returns False, after logging the error, when curl exits
    non-zero. HTTP error responses count as failures too, because of
    ``--fail``.
    '''
    if not url:
        return True
    content = json.dumps(input)

    logging.info(f'{url}: {content}')

    with tempfile.NamedTemporaryFile(mode='w') as tmp:
        tmp.write(content)
        tmp.flush()
        try:
            # --fail: without it curl exits 0 on HTTP 4xx/5xx, so a rejected
            # comment (bad token, wrong URL) would be reported as success.
            # NOTE(review): the token is visible in the process list via argv;
            # consider passing the header from a file (curl -H @file).
            subprocess.check_call(
                ['curl', '--fail', '-X', 'POST',
                    '-H', 'Accept: application/vnd.github+json',
                    '-H', f'Authorization: token {args.token}',
                    url,
                    '-d', f'@{tmp.name}'])
        except subprocess.CalledProcessError as e:
            logging.exception(e)
            return False
    return True


def get_url(args: argparse.Namespace) -> str:
    '''Derive the GitHub issue-comments API URL from the pipeline JSON in ``args.github``.

    * ``event.inputs`` (manually dispatched run): the repository is
      ``inputs['pr-reponame']``, or ``simplecore-manifest`` when absent.
      The PR number is taken from ``inputs['pr-sha']``, e.g.
      ``refs/pull/42/merge`` -> ``42``.
    * ``event.pull_request``: its ``comments_url`` is used as is.

    Returns '' when there is no GitHub context, when the event is neither
    kind, or when the context cannot be interpreted (a warning is logged).
    '''
    if not args.github:
        # Not running inside a GitHub pipeline: nothing to comment on.
        return ''
    try:
        event = json.loads(args.github).get('event', {})
        if 'inputs' in event:
            inputs = event['inputs']
            if 'pr-reponame' in inputs:
                repo = inputs['pr-reponame']
            else:
                # special case for the manifests
                repo = 'simplecore-manifest'
            pr_number = '/'.join(inputs['pr-sha'].split('/')[2:-1])
            return ('https://api.github.com/repos/avnet-embedded/'
                    + repo + '/issues/' + pr_number + '/comments')
        if 'pull_request' in event:
            return event['pull_request']['comments_url']
        return ''
    except (TypeError, ValueError, KeyError, AttributeError) as e:
        # ValueError covers json.JSONDecodeError; AttributeError a non-object JSON.
        logging.warning(f'Cannot derive comment URL from GitHub context: {e}')
        return ''


def main():
    '''Entry point: turn ``*_testresults.json`` files into a test report.

    ``pr``: posts a comment with the summary and any optional sections, then
    exits with status 1 if any test failed (0 otherwise).
    ``release``: mirrors the result files into the output directory (unless
    it already is the input directory), then writes the detailed reports.
    '''
    logging.getLogger().setLevel(logging.INFO)
    args = get_args()

    if args.subcommand == 'release':
        # Compare resolved paths: 'out' and 'out/' (or a symlink) are the same
        # directory, and copying a file onto itself raises SameFileError.
        if os.path.realpath(args.dir) != os.path.realpath(args.output):
            for r in find_files_and_folders(args):
                fullpath = os.path.join(args.output, os.path.relpath(r, args.dir))
                os.makedirs(os.path.dirname(fullpath), exist_ok=True)
                shutil.copy(r, fullpath)
        args.dir = args.output

    reports = find_files_and_folders(args)
    results = parse_reports(reports)

    if args.subcommand == 'pr':
        # The optional sections only feed the PR comment, so build them here.
        coverage_report = get_coverage(args)
        deadlinks_report = get_deadlinks(args)
        spellcheck = get_spellcheck(args)
        qa_log = get_qalog(args)
        shadow_res = get_shadowedresults(args)

        separator = '<br/><br/>\n\n---\n\n'
        extras = ''.join(
            separator + section
            for section in (qa_log, coverage_report, deadlinks_report, spellcheck, shadow_res)
            if section)

        url = get_url(args)
        res = {
            'body': create_markdown_report(results) + extras
        }
        send_comment(args, res, url)
        sys.exit(1 if results.has_fails() else 0)

    if args.subcommand == 'release':
        create_detailed_markdown_report(results, args.output, args.testnames)


if __name__ == '__main__':
    main()
