# Copyright (c) DataLab Platform Developers, BSD 3-Clause license, see LICENSE file.

"""
Module providing test utilities
"""

from __future__ import annotations

import atexit
import functools
import os
import os.path as osp
import pathlib
import subprocess
import sys
import tempfile
from collections.abc import Callable
from typing import Any

import numpy as np
from guidata.configtools import get_module_data_path

from sigima.config import MOD_NAME
from sigima.tests.env import execenv

# Registry of test data paths: filled in by `add_test_path` (which stores absolute
# paths) and iterated in insertion order by `get_test_fnames`
TST_PATH: list[str] = []


def get_test_paths() -> list[str]:
    """Return the registered test data paths.

    Returns:
        The module-level list of test data paths itself (not a copy)
    """
    registered_paths = TST_PATH
    return registered_paths


def add_test_path(path: str) -> None:
    """Register a test data path.

    The path is normalized and made absolute first. Nothing happens if the
    resulting path is already in the list of test data paths.

    Args:
        path: Path to add to the list of test data paths

    Raises:
        FileNotFoundError: if the path is not registered yet and does not exist
    """
    abs_path = osp.abspath(osp.normpath(path))
    # Already known: nothing to do (existence is not checked again)
    if abs_path in TST_PATH:
        return
    if osp.exists(abs_path):
        TST_PATH.append(abs_path)
        return
    raise FileNotFoundError(f"Test data path does not exist: {abs_path}")


def add_test_path_from_env(envvar: str) -> None:
    """Register the test data path stored in an environment variable.

    Nothing happens when the variable is unset or empty. Errors raised by
    `add_test_path` are not caught here.

    Args:
        envvar: name of the environment variable holding the path
    """
    # Note: this function is used in third-party plugins
    env_path = os.environ.get(envvar, "")
    if not env_path:
        return
    add_test_path(env_path)


# Register the test data files and folders pointed to by the `SIGIMA_DATA`
# environment variable (nothing is registered if the variable is unset or empty):
add_test_path_from_env("SIGIMA_DATA")


def add_test_module_path(modname: str, relpath: str) -> None:
    """Register a test data path located relative to a module.

    Used to add module local data that resides in a module directory
    but will be shipped under sys.prefix / share/ ...

    Args:
        modname: name of an already imported module, as found in sys.modules
        relpath: relative path of the data, handed to `get_module_data_path`
    """
    module_data_path = get_module_data_path(modname, relpath=relpath)
    add_test_path(module_data_path)


# Register the test data files and folders of the `MOD_NAME` module
# (relative path `data/tests`):
add_test_module_path(MOD_NAME, osp.join("data", "tests"))


def get_test_fnames(pattern: str, in_folder: str | None = None) -> list[str]:
    """
    Return the path list to test files matching the specified pattern

    Pattern may be a file name (basename), a wildcard (e.g. *.txt)...
    The search is recursive and follows the registration order of the test data
    paths; matches are sorted within each searched folder.

    Args:
        pattern: pattern to match
        in_folder: folder to search in, relative to each test data path
         (default: None, search in the test data paths themselves)

    Returns:
        Paths of the matching files, as strings

    Raises:
        FileNotFoundError: if no file matches the pattern
    """
    if in_folder:
        # Resolve the folder against every registered path (not only the first
        # one), so that data is found wherever it has been registered.
        # `dict.fromkeys` drops duplicated roots (e.g. absolute `in_folder`)
        # while keeping the registration order.
        roots = list(dict.fromkeys(osp.join(pth, in_folder) for pth in TST_PATH))
    else:
        roots = TST_PATH
    matches: list[pathlib.Path] = []
    for root in roots:
        matches.extend(sorted(pathlib.Path(root).rglob(pattern)))
    if not matches:
        raise FileNotFoundError(f"Test file(s) {pattern} not found")
    return [str(match) for match in matches]


def try_open_test_data(title: str, pattern: str) -> Callable:
    """Build a decorator running a function on each test file matching a pattern.

    Args:
        title: title printed (and underlined) before the files are processed
        pattern: pattern handed to `get_test_fnames` to find the test files

    Returns:
        Decorator turning a one-argument function (the file name) into a
        function taking no argument
    """

    def decorate(func: Callable) -> Callable:
        """Wrap `func` so that it is called once per matching test file"""

        @functools.wraps(func)
        def run_on_test_files() -> None:
            """Print the header, then call the wrapped function on each file"""
            execenv.print(f"{title}:")
            execenv.print(len(title) * "-")
            try:
                for test_fname in get_test_fnames(pattern):
                    execenv.print(f"=> Opening: {test_fname}")
                    func(test_fname)
            except FileNotFoundError:
                # Raised when no file matches; note that a FileNotFoundError
                # coming from the wrapped function ends up here as well
                execenv.print(f"  No test data available for {pattern}")
            finally:
                execenv.print(os.linesep)

        return run_on_test_files

    return decorate


def get_default_test_name(suffix: str | None = None) -> str:
    """Return default test name based on script name"""
    name = osp.splitext(osp.basename(sys.argv[0]))[0]
    if suffix is not None:
        name += "_" + suffix
    return name


def get_output_data_path(extension: str, suffix: str | None = None) -> str:
    """Return the full path of a data file generated by a test script.

    Args:
        extension: file extension, without the leading dot
        suffix: optional suffix handed to `get_default_test_name`

    Returns:
        Path of `<default test name>.<extension>` inside the first registered
        test data path
    """
    file_name = f"{get_default_test_name(suffix)}.{extension}"
    return osp.join(TST_PATH[0], file_name)


def reduce_path(filename: str) -> str:
    """Shorten a file path to its last directory and file name

    Args:
        filename: path to reduce

    Returns:
        Path to the file, relative to the parent of the directory containing it
        (e.g. `b/c.txt` for `/a/b/c.txt`)
    """
    containing_dir = osp.dirname(filename)
    # One level above the containing directory, so that its name is kept
    anchor = osp.join(containing_dir, osp.pardir)
    return osp.relpath(filename, start=anchor)


class WorkdirRestoringTempDir(tempfile.TemporaryDirectory):
    """Temporary directory that puts the working directory back on cleanup.

    A subclass of :py:class:`tempfile.TemporaryDirectory` that:

    * Records the working directory at creation time and switches back to it
      when :py:meth:`cleanup` is called
    * Silently ignores PermissionError and RecursionError raised while the
      directory is removed

    Example::

        with WorkdirRestoringTempDir() as tmpdir:
            os.chdir(tmpdir)  # Directory change is automatically reverted at exit
    """

    def __init__(self) -> None:
        # Working directory at creation time, restored by `cleanup`
        self.__initial_cwd = os.getcwd()
        super().__init__()

    def cleanup(self) -> None:
        """Restore the working directory, then remove the temporary directory."""
        # Leave the temporary directory first (if we are in it) before removing it
        os.chdir(self.__initial_cwd)
        ignored_errors = (PermissionError, RecursionError)
        try:
            super().cleanup()
        except ignored_errors:
            return


def get_temporary_directory() -> str:
    """Create a temporary directory, cleaned up at exit, and return its path"""
    tempdir = WorkdirRestoringTempDir()
    # Registering the bound method keeps the object alive until interpreter exit
    atexit.register(tempdir.cleanup)
    return tempdir.name


def exec_script(
    path: str,
    wait: bool = True,
    args: list[str] = None,
    env: dict[str, str] | None = None,
    verbose: bool = False,
) -> subprocess.Popen | None:
    """Run test script.

    Args:
        path: path to script
        wait: wait for script to finish
        args: arguments to pass to script
        env: environment variables to pass to script
        verbose: if True, print command and output

    Returns:
        subprocess.Popen object if wait is False, None otherwise
    """
    stderr = subprocess.DEVNULL if execenv.unattended else None
    # pylint: disable=consider-using-with
    if verbose:
        command = [sys.executable, path] + ([] if args is None else args)
        proc = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
        )
    else:
        command = [sys.executable, '"' + path + '"'] + ([] if args is None else args)
        proc = subprocess.Popen(" ".join(command), shell=True, stderr=stderr, env=env)
    if wait:
        if verbose:
            stdout, stderr = proc.communicate()
            print("Command:", " ".join(command))
            print("Return code:", proc.returncode)
            print("---- STDOUT ----\n", stdout)
            print("---- STDERR ----\n", stderr)
            return None
        proc.wait()
    return proc


def get_script_output(
    path: str, args: list[str] = None, env: dict[str, str] | None = None
) -> str:
    """Run test script and return its output.

    Args:
        path (str): path to script
        args (list): arguments to pass to script
        env (dict): environment variables to pass to script

    Returns:
        str: script output
    """
    command = [sys.executable, '"' + path + '"'] + ([] if args is None else args)
    result = subprocess.run(
        " ".join(command), capture_output=True, text=True, env=env, check=False
    )
    return result.stdout.strip()


def compare_lists(list1: list, list2: list, level: int = 1) -> bool:
    """Compare two lists

    Elements are compared pairwise: nested lists/tuples and dictionaries are
    compared recursively, any other element through its string representation.
    Lists of different lengths are different, and so are elements of which only
    one is a list/tuple or a dictionary. The status printed for each element
    ("OK"/"KO") is the status of that element alone.

    Args:
        list1: first list
        list2: second list
        level: recursion level (sets the indentation of the printed report)

    Returns:
        True if lists are the same, False otherwise
    """
    prefix = "  " * level
    # `zip` stops at the shortest list, hence the explicit length check
    same = len(list1) == len(list2)
    if not same:
        execenv.print(f"{prefix}Different lengths: {len(list1)} != {len(list2)}")
    for idx, (elem1, elem2) in enumerate(zip(list1, list2)):
        execenv.print(f"{prefix}Checking element {idx}...", end=" ")
        if isinstance(elem1, (list, tuple)):
            execenv.print("")
            if isinstance(elem2, (list, tuple)):
                same_elem = compare_lists(elem1, elem2, level + 1)
            else:
                execenv.print(f"{prefix}Different types: {elem1} != {elem2}")
                same_elem = False
        elif isinstance(elem1, dict):
            execenv.print("")
            if isinstance(elem2, dict):
                same_elem = compare_metadata(elem1, elem2, level + 1)
            else:
                execenv.print(f"{prefix}Different types: {elem1} != {elem2}")
                same_elem = False
        else:
            same_elem = str(elem1) == str(elem2)
            if not same_elem:
                execenv.print(f"Different values: {elem1} != {elem2}")
            execenv.print("OK" if same_elem else "KO")
        same = same and same_elem
    return same


def compare_metadata(
    dict1: dict[str, Any], dict2: dict[str, Any], level: int = 1
) -> bool:
    """Compare metadata dictionaries without private elements

    Private elements are the entries whose key starts with a double underscore:
    they are ignored in both dictionaries. Nested dictionaries and lists/tuples
    are compared recursively, any other value through its string
    representation. A key present in only one of the dictionaries is a
    difference, and so is a value that is a dictionary or a list/tuple on one
    side only. The status printed for each key ("OK"/"KO") is the status of that
    key alone.

    Args:
        dict1: first dictionary, exclusively with string keys
        dict2: second dictionary, exclusively with string keys
        level: recursion level (sets the indentation of the printed report)

    Returns:
        True if metadata is the same, False otherwise
    """
    dict_a = {key: val for key, val in dict1.items() if not key.startswith("__")}
    dict_b = {key: val for key, val in dict2.items() if not key.startswith("__")}
    same = True
    prefix = "  " * level
    for key, val_a in dict_a.items():
        if key not in dict_b:
            # Keep going: the remaining keys are still worth reporting
            execenv.print(f"{prefix}Missing key in second dictionary: {key}")
            same = False
            continue
        val_b = dict_b[key]
        execenv.print(f"{prefix}Checking key {key}...", end=" ")
        if isinstance(val_a, dict):
            execenv.print("")
            if isinstance(val_b, dict):
                same_value = compare_metadata(val_a, val_b, level + 1)
            else:
                execenv.print(f"{prefix}Different types for key {key}")
                same_value = False
        elif isinstance(val_a, (list, tuple)):
            execenv.print("")
            if isinstance(val_b, (list, tuple)):
                same_value = compare_lists(val_a, val_b, level + 1)
            else:
                execenv.print(f"{prefix}Different types for key {key}")
                same_value = False
        else:
            same_value = str(val_a) == str(val_b)
            if not same_value:
                execenv.print(f"Different values for key {key}: {val_a} != {val_b}")
            execenv.print("OK" if same_value else "KO")
        same = same and same_value
    # The comparison is symmetric: keys only found in the second dictionary
    # are a difference too
    extra_keys = [key for key in dict_b if key not in dict_a]
    if extra_keys:
        execenv.print(f"{prefix}Missing keys in first dictionary: {extra_keys}")
        same = False
    return same


def __array_to_str(data: np.ndarray) -> str:
    """Return a compact description of the array properties

    Args:
        data: array to describe

    Returns:
        Text with the shape, the data type and, unless the array is empty, the
        minimum, maximum and mean values
    """
    dims = "×".join(str(dim) for dim in data.shape)
    if data.size == 0:
        # `min` and `max` raise a ValueError on an empty array
        return f"{dims},{data.dtype},empty"
    return f"{dims},{data.dtype},{data.min():.2g}→{data.max():.2g},µ={data.mean():.2g}"


def check_array_result(
    title: str,
    res: np.ndarray,
    exp: np.ndarray,
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    sort: bool = False,
    verbose: bool = True,
) -> None:
    """Assert that two arrays are almost equal.

    The values are compared with `numpy.allclose` (NaN values at the same
    position are considered equal), then the data types are compared. Empty
    arrays are supported.

    Args:
        title: title of the test
        res: result array
        exp: expected array
        rtol: relative tolerance for comparison
        atol: absolute tolerance for comparison
        sort: if True, compare sorted, flattened copies of the arrays
         (default: False)
        verbose: if True, print detailed result (default: True)

    Raises:
        AssertionError: if arrays are not almost equal or have different dtypes
    """
    if sort:
        # `axis=None` flattens the copies before sorting them
        res = np.sort(np.array(res, copy=True), axis=None)
        exp = np.sort(np.array(exp, copy=True), axis=None)
    restxt = f"{title}: {__array_to_str(res)} (expected: {__array_to_str(exp)})"
    if verbose:
        execenv.print(restxt)
    assert np.allclose(res, exp, rtol=rtol, atol=atol, equal_nan=True), restxt
    assert res.dtype == exp.dtype, restxt


def check_scalar_result(
    title: str,
    res: float,
    exp: float | tuple[float, ...],
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    verbose: bool = True,
) -> None:
    """Assert that a scalar is almost equal to an expected value.

    When a tuple of expected values is given, the result has to be close to at
    least one of them.

    Args:
        title: title of the test
        res: result value
        exp: expected value or tuple of expected values
        rtol: relative tolerance for comparison
        atol: absolute tolerance for comparison
        verbose: if True, print detailed result (default: True)

    Raises:
        AssertionError: if the result is not close to the expected value (or to
         any of the expected values)
    """
    message = f"{title}: {res} (expected: {exp})"
    if verbose:
        execenv.print(message)
    # A single expected value is handled as a one-element tuple of candidates
    candidates = exp if isinstance(exp, tuple) else (exp,)
    matches = any(np.isclose(res, value, rtol=rtol, atol=atol) for value in candidates)
    assert matches, message
