from __future__ import annotations

import logging
from functools import lru_cache

from .. import envs
from . import pyrocmsmi
from .__types__ import Detector, Device, Devices, ManufacturerEnum
from .__utils__ import PCIDevice, get_pci_devices

# Module-level logger; the detector's debug/exception output goes through it.
logger = logging.getLogger(__name__)


class HygonDetector(Detector):
    """
    Detect Hygon GPUs.

    Detection is gated on the ``GPUSTACK_RUNTIME_DETECT`` setting, on the
    presence of PCI devices with Hygon's vendor ID, and on the ROCm SMI
    bindings (``pyrocmsmi``) initializing successfully.
    """

    @staticmethod
    @lru_cache
    def is_supported() -> bool:
        """
        Check if the Hygon detector is supported.

        Detection is supported only when all of the following hold:
        ``GPUSTACK_RUNTIME_DETECT`` is ``"auto"`` or ``"hygon"``
        (case-insensitive), at least one PCI device with the Hygon vendor ID
        is present, and ``pyrocmsmi.rsmi_init()`` succeeds.

        The result is memoized with ``lru_cache``, so later calls return the
        first computed value without re-probing.

        Returns:
            True if supported, False otherwise.

        """
        if envs.GPUSTACK_RUNTIME_DETECT.lower() not in ("auto", "hygon"):
            logger.debug("Hygon detection is disabled by environment variable")
            return False

        if not HygonDetector.detect_pci_devices():
            logger.debug("No Hygon PCI devices found")
            return False

        try:
            pyrocmsmi.rsmi_init()
        except pyrocmsmi.ROCMSMIError:
            # Best-effort probe: an init failure means "unsupported",
            # so it is logged (when DEBUG is on) rather than propagated.
            if logger.isEnabledFor(logging.DEBUG):
                logger.exception("Failed to initialize ROCM SMI")
            return False

        return True

    @staticmethod
    @lru_cache
    def detect_pci_devices() -> dict[str, PCIDevice] | None:
        """
        Find PCI devices that carry the Hygon vendor ID (``0x1d94``).

        The result is memoized with ``lru_cache``.

        Returns:
            A mapping of PCI address to device, or None if no such
            device was found.

        """
        # See https://pcisig.com/membership/member-companies?combine=Higon.
        pci_devs = get_pci_devices(vendor="0x1d94")
        if not pci_devs:
            return None
        return {dev.address: dev for dev in pci_devs}

    @staticmethod
    def _parse_version(version: str) -> list[int | str]:
        """
        Split a dotted version string into its components.

        Purely numeric components become ``int``; any other component is
        kept as a string, e.g. ``"6.3.0-rc1"`` -> ``[6, 3, "0-rc1"]``.
        """
        return [
            int(part) if part.isdigit() else part for part in version.split(".")
        ]

    def __init__(self):
        super().__init__(ManufacturerEnum.HYGON)

    def detect(self) -> Devices | None:
        """
        Detect Hygon GPUs using pyrocmsmi.

        Returns:
            A list of detected Hygon GPU devices,
            or None if not supported.

        Raises:
            pyrocmsmi.ROCMSMIError:
                If a ROCm SMI call fails. Logged (only when DEBUG logging
                is enabled) and re-raised unchanged.
            Exception:
                If any other error occurs while building the device list.
                Logged (only when DEBUG logging is enabled) and re-raised
                unchanged.

        """
        if not self.is_supported():
            return None

        ret: Devices = []

        try:
            # is_supported() is cached, so it may have run long before this
            # call; initialize again here.
            # NOTE(review): presumably rsmi_init() is safe to call repeatedly
            # (ROCm SMI reference-counts init) — confirm for pyrocmsmi.
            pyrocmsmi.rsmi_init()

            sys_driver_ver = pyrocmsmi.rsmi_driver_version_get()
            sys_driver_ver_t = (
                self._parse_version(sys_driver_ver) if sys_driver_ver else None
            )

            devs_count = pyrocmsmi.rsmi_num_monitor_devices()
            for dev_idx in range(devs_count):
                dev_uuid = pyrocmsmi.rsmi_dev_unique_id_get(dev_idx)
                dev_name = "Hygon " + pyrocmsmi.rsmi_dev_name_get(dev_idx)

                dev_cc = pyrocmsmi.rsmi_dev_target_graphics_version_get(dev_idx)
                dev_cc_t = None
                if dev_cc:
                    # Strip the "gfx" prefix only when it is actually present,
                    # instead of blindly dropping the first three characters.
                    if dev_cc.startswith("gfx"):
                        dev_cc = dev_cc[3:]
                    dev_cc_t = self._parse_version(dev_cc)

                # Core count is not queried for Hygon devices.
                dev_cores = None
                dev_cores_util = pyrocmsmi.rsmi_dev_busy_percent_get(dev_idx)
                dev_mem = pyrocmsmi.rsmi_dev_memory_total_get(dev_idx)
                dev_mem_used = pyrocmsmi.rsmi_dev_memory_usage_get(dev_idx)
                dev_temp = pyrocmsmi.rsmi_dev_temp_metric_get(dev_idx)

                dev_power = pyrocmsmi.rsmi_dev_power_cap_get(dev_idx)
                dev_power_used = pyrocmsmi.rsmi_dev_power_get(dev_idx)

                # Devices reported by this detector are never flagged as vGPUs.
                dev_appendix = {
                    "vgpu": False,
                }

                ret.append(
                    Device(
                        manufacturer=self.manufacturer,
                        index=dev_idx,
                        name=dev_name,
                        uuid=dev_uuid,
                        driver_version=sys_driver_ver,
                        driver_version_tuple=sys_driver_ver_t,
                        compute_capability=dev_cc,
                        compute_capability_tuple=dev_cc_t,
                        cores=dev_cores,
                        cores_utilization=dev_cores_util,
                        # `>> 20` divides by 2**20 (presumably bytes -> MiB;
                        # confirm pyrocmsmi's units).
                        memory=(dev_mem >> 20 if dev_mem > 0 else 0),
                        memory_used=(dev_mem_used >> 20 if dev_mem_used > 0 else 0),
                        # Integer percentage; 0 when total memory is unknown/zero.
                        memory_utilization=(
                            (dev_mem_used * 100 // dev_mem) if dev_mem > 0 else 0
                        ),
                        temperature=dev_temp,
                        power=dev_power,
                        power_used=dev_power_used,
                        appendix=dev_appendix,
                    ),
                )
        except pyrocmsmi.ROCMSMIError:
            if logger.isEnabledFor(logging.DEBUG):
                logger.exception("Failed to fetch devices")
            raise
        except Exception:
            if logger.isEnabledFor(logging.DEBUG):
                logger.exception("Failed to process devices fetching")
            raise

        return ret
