"""Backend implementations: CPU, GPU, and NPU (placeholder).

Each backend exposes: warmup(n_steps), run(inputs: dict), run_once(...), dispose().
"""
from typing import Dict, Optional, Tuple
import numpy as np

try:
    import onnxruntime as ort
except Exception:
    ort = None


class BaseBackend:
    """Common interface for inference backends.

    Holds the model path plus two slots, ``session`` and ``input_name``, that
    start out as ``None``. The inference hooks (``warmup``, ``run``,
    ``run_once``) raise ``NotImplementedError`` here and are meant to be
    overridden; ``dispose`` simply drops the session reference.
    """

    def __init__(self, model_path: str):
        # Model location exactly as supplied by the caller.
        self.model_path = model_path
        # Populated by concrete backends; both begin unset.
        self.session: Optional[object] = None
        self.input_name: Optional[str] = None

    def warmup(self, n_steps: int = 3):
        """Run ``n_steps`` throwaway inferences (override in subclasses)."""
        raise NotImplementedError()

    def run(self, inputs: Dict[str, np.ndarray]):
        """Run inference on a name -> array feed (override in subclasses)."""
        raise NotImplementedError()

    def run_once(self, input_tokens: np.ndarray, past_kv: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
        """Run a single step on ``input_tokens`` (override in subclasses)."""
        raise NotImplementedError()

    def dispose(self):
        """Release the backend by dropping its session reference."""
        self.session = None


class CPUBackend(BaseBackend):
    """ONNX Runtime backend restricted to ``CPUExecutionProvider``."""

    # ONNX Runtime describes input element types with strings such as
    # "tensor(int64)". Map the common ones to numpy dtypes so warmup feeds
    # match what the model declares; unlisted types fall back to float32.
    _ORT_TYPE_TO_NUMPY = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
        "tensor(int16)": np.int16,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
        "tensor(bool)": np.bool_,
    }

    def __init__(self, model_path: str):
        """Create a CPU-only inference session for ``model_path``.

        Raises:
            RuntimeError: if onnxruntime could not be imported.
        """
        super().__init__(model_path)
        if ort is None:
            raise RuntimeError("onnxruntime is required for CPUBackend")
        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        inputs = self.session.get_inputs()
        # Name of the first declared input; run_once() feeds tokens to it.
        self.input_name = inputs[0].name if inputs else None

    @classmethod
    def _dummy_feed(cls, session) -> Dict[str, np.ndarray]:
        """Build one placeholder array per input declared by ``session``."""
        feed = {}
        for inp in session.get_inputs():
            # Symbolic dims (reported as str/None) or negative dims become 1.
            shape = tuple(dim if isinstance(dim, int) and dim >= 0 else 1 for dim in inp.shape)
            dtype = cls._ORT_TYPE_TO_NUMPY.get(getattr(inp, "type", None), np.float32)
            if np.issubdtype(dtype, np.floating):
                # asarray also turns the bare Python float that rand() returns
                # for a scalar (shape ()) input into a 0-d array.
                value = np.asarray(np.random.rand(*shape), dtype=dtype)
            else:
                # Zeros are always valid for integer/bool inputs (e.g. token
                # ids fed to an embedding lookup), where random values may not be.
                value = np.zeros(shape, dtype=dtype)
            feed[inp.name] = value
        return feed

    def warmup(self, n_steps: int = 3):
        """Run ``n_steps`` inferences on placeholder inputs.

        Placeholders are derived from the session's input metadata. Dynamic
        dims are set to 1. Float inputs get uniform random values and
        integer/bool inputs get zeros, each in the dtype the model declares.
        """
        feed = self._dummy_feed(self.session)
        for _ in range(n_steps):
            self.session.run(None, feed)

    def run(self, inputs: Dict[str, np.ndarray]):
        """Run inference with provided input tensors.

        Args:
            inputs: dict mapping input names to numpy arrays
        Returns:
            list of output tensors
        Raises:
            ValueError: if any input declared by the model is absent from ``inputs``.
        """
        # Fail early, with a clear message, before handing the feed to ORT.
        required = {inp.name for inp in self.session.get_inputs()}
        missing = required - set(inputs.keys())
        if missing:
            raise ValueError(f"Required inputs ({list(missing)}) are missing from input feed ({list(inputs.keys())}).")
        return self.session.run(None, inputs)

    def run_once(self, input_tokens: np.ndarray, past_kv: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
        """Feed ``input_tokens`` to the first model input and run once.

        ``past_kv`` is accepted for interface compatibility but currently
        ignored. Returns ``(outputs, {})``, where ``outputs`` is the list
        returned by ``session.run``.
        """
        return self.session.run(None, {self.input_name: input_tokens}), {}


class GPUBackend(BaseBackend):
    """ONNX Runtime backend on an accelerator provider, with CPU as fallback."""

    # ONNX Runtime describes input element types with strings such as
    # "tensor(int64)". Map the common ones to numpy dtypes so warmup feeds
    # match what the model declares; unlisted types fall back to float32.
    _ORT_TYPE_TO_NUMPY = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
        "tensor(int16)": np.int16,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
        "tensor(bool)": np.bool_,
    }

    def __init__(self, model_path: str, provider: str = "DmlExecutionProvider"):
        """Create an inference session preferring ``provider``.

        Args:
            model_path: path to the ONNX model.
            provider: ONNX Runtime execution provider to try first.
        Raises:
            RuntimeError: if onnxruntime could not be imported.
        """
        super().__init__(model_path)
        if ort is None:
            raise RuntimeError("onnxruntime is required for GPUBackend")
        # Providers are listed in priority order; CPU is appended as the
        # lower-priority fallback. Skip it when the caller already picked CPU
        # so the list has no duplicate entries.
        providers = [provider]
        if provider != "CPUExecutionProvider":
            providers.append("CPUExecutionProvider")
        self.session = ort.InferenceSession(model_path, providers=providers)
        inputs = self.session.get_inputs()
        # Name of the first declared input; run_once() feeds tokens to it.
        self.input_name = inputs[0].name if inputs else None

    @classmethod
    def _dummy_feed(cls, session) -> Dict[str, np.ndarray]:
        """Build one placeholder array per input declared by ``session``."""
        feed = {}
        for inp in session.get_inputs():
            # Symbolic dims (reported as str/None) or negative dims become 1.
            shape = tuple(dim if isinstance(dim, int) and dim >= 0 else 1 for dim in inp.shape)
            dtype = cls._ORT_TYPE_TO_NUMPY.get(getattr(inp, "type", None), np.float32)
            if np.issubdtype(dtype, np.floating):
                # asarray also turns the bare Python float that rand() returns
                # for a scalar (shape ()) input into a 0-d array.
                value = np.asarray(np.random.rand(*shape), dtype=dtype)
            else:
                # Zeros are always valid for integer/bool inputs (e.g. token
                # ids fed to an embedding lookup), where random values may not be.
                value = np.zeros(shape, dtype=dtype)
            feed[inp.name] = value
        return feed

    def warmup(self, n_steps: int = 3):
        """Run ``n_steps`` inferences on placeholder inputs.

        Placeholders are derived from the session's input metadata. Dynamic
        dims are set to 1. Float inputs get uniform random values and
        integer/bool inputs get zeros, each in the dtype the model declares.
        """
        feed = self._dummy_feed(self.session)
        for _ in range(n_steps):
            self.session.run(None, feed)

    def run(self, inputs: Dict[str, np.ndarray]):
        """Run inference with provided input tensors.

        Args:
            inputs: dict mapping input names to numpy arrays
        Returns:
            list of output tensors
        Raises:
            ValueError: if any input declared by the model is absent from ``inputs``.
        """
        # Fail early, with a clear message, before handing the feed to ORT.
        required = {inp.name for inp in self.session.get_inputs()}
        missing = required - set(inputs.keys())
        if missing:
            raise ValueError(f"Required inputs ({list(missing)}) are missing from input feed ({list(inputs.keys())}).")
        return self.session.run(None, inputs)

    def run_once(self, input_tokens: np.ndarray, past_kv: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
        """Feed ``input_tokens`` to the first model input and run once.

        ``past_kv`` is accepted for interface compatibility but currently
        ignored. Returns ``(outputs, {})``, where ``outputs`` is the list
        returned by ``session.run``.
        """
        return self.session.run(None, {self.input_name: input_tokens}), {}


class NPUBackend(BaseBackend):
    """Placeholder for an NPU backend.

    Constructing it always raises ``NotImplementedError``; no vendor SDK is
    integrated yet.
    """

    def __init__(self, model_path: str):
        # TODO: vendor-specific SDK integration goes here.
        super().__init__(model_path)
        reason = "NPU backend is a placeholder until vendor SDK is provided"
        raise NotImplementedError(reason)
