"""Clean, minimal public API for the lightweight ryzenai implementation."""
from typing import List, Optional, Dict
import numpy as np

from .core.scheduler import BackendScheduler
from .core.device_manager import detect_providers


class Model:
    """High level model wrapper.

    - Chooses a backend via the scheduler
    - Exposes warmup/run/run_once/dispose
    - Usable as a context manager: ``with Model(path) as m: ...`` calls
      ``dispose()`` on exit, including when the body raises.

    Attributes set by ``__init__``:
        model_path: the path handed to the scheduler.
        device_preference: the ordered preference list actually used.
        scheduler: the ``BackendScheduler`` built from that list.
        backend: whatever ``scheduler.load_backend(model_path)`` returned;
            every public method delegates to it.
    """

    def __init__(self, model_path: str, device_preference: Optional[List[str]] = None):
        """Build the scheduler and load a backend for ``model_path``.

        Args:
            model_path: Path of the model, passed through to the scheduler.
            device_preference: Ordered device names. ``None`` or an empty
                list falls back to ``["npu", "dml", "cpu"]``.
        """
        self.model_path = model_path
        # `or` means an explicitly empty list also selects the default order.
        self.device_preference = device_preference or ["npu", "dml", "cpu"]
        self.scheduler = BackendScheduler(self.device_preference)
        self.backend = self.scheduler.load_backend(model_path)

    def __enter__(self) -> "Model":
        """Return ``self`` so the model can be bound in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Dispose the backend on leaving the ``with`` block.

        Returns ``None`` so any exception raised in the block propagates.
        """
        self.dispose()

    def warmup(self, n_steps: int = 3):
        """Forward ``n_steps`` to the backend's ``warmup``."""
        self.backend.warmup(n_steps)

    def run(self, inputs: Dict[str, np.ndarray]):
        """Run the backend on ``inputs`` and return its result unchanged."""
        return self.backend.run(inputs)

    def run_once(self, input_tokens: np.ndarray, past_kv: Optional[Dict] = None):
        """Delegate to the backend's ``run_once`` and return its result.

        NOTE(review): presumably a single decoding step with an optional
        key/value cache -- confirm against the backend implementation.
        """
        return self.backend.run_once(input_tokens, past_kv)

    def dispose(self):
        """Ask the backend to release whatever it holds via ``dispose``."""
        self.backend.dispose()


def load_model(model_path: str, device_preference: Optional[List[str]] = None) -> Model:
    """Construct and return a ready-to-use ``Model``.

    Args:
        model_path: Path to the ONNX model file.
        device_preference: Optional ordered list of devices to try, e.g.
            ['dml', 'cpu'] to prefer DirectML and then CPU. Passed to
            ``Model`` unchanged.

    Returns:
        Model: the wrapper instance, exposing warmup/run/run_once/dispose.
    """
    model = Model(model_path, device_preference)
    return model


def device_info() -> Dict:
    """Report what ``detect_providers`` finds on this machine.

    The result is passed through untouched. It is documented as a dict with
    the keys 'providers' (list of provider names) and 'platform'.
    Example: {'providers': ['DmlExecutionProvider', 'CPUExecutionProvider'], 'platform': 'Windows'}
    """
    info = detect_providers()
    return info


def benchmark_model(model_path: str, runs: int = 50, device_preference: Optional[List[str]] = None):
    """Load a model, benchmark it, and always dispose of it afterwards.

    Args:
        model_path: Path of the model, forwarded to ``load_model``.
        runs: Forwarded to ``benchmark_inference`` as ``runs``.
        device_preference: Optional ordered device list, forwarded to
            ``load_model``.

    Returns:
        Whatever ``benchmark_inference`` returns for the loaded model.

    Raises:
        Any exception from loading or benchmarking propagates; the model is
        disposed first if it had been loaded.
    """
    # Kept function-local as in the original, but resolved before the model
    # is loaded so a failed import cannot strand an already-loaded backend.
    from .utils.benchmark import benchmark_inference

    m = load_model(model_path, device_preference=device_preference)
    try:
        return benchmark_inference(m, runs=runs)
    finally:
        # Release the backend even when the benchmark raises.
        m.dispose()
