"""Integration test using a small ONNX model (add.onnx)."""
import sys
import os
import numpy as np

# Make the in-repo package importable without installing it: the repository
# root (the parent of this file's directory) and its ``src`` directory are
# placed at the very front of ``sys.path``, with the root searched first.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path[0:0] = [ROOT, os.path.join(ROOT, 'src')]

import ryzenai


def get_test_model_path():
    """Return the path to ``data/add.onnx`` next to this file, generating it if absent.

    When the model file does not exist yet, ``generate_test_model.make_add_model()``
    is called to create it.

    Returns:
        The path ``<dir of this file>/data/add.onnx``.

    Raises:
        FileNotFoundError: if the file is still missing after generation, so
            the failure surfaces here with a clear message instead of as an
            obscure load error further down in the test.
    """
    model_path = os.path.join(os.path.dirname(__file__), 'data', 'add.onnx')
    if not os.path.exists(model_path):
        # Imported lazily: the generator is only needed when the model file is
        # missing. NOTE(review): relies on ``generate_test_model`` being
        # importable from the current sys.path — confirm its location.
        import generate_test_model
        generate_test_model.make_add_model()
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"generate_test_model.make_add_model() did not create {model_path}"
            )
    return model_path


def test_model_load_and_run():
    """Run the add model with a CPU device preference and check its output.

    Feeds two 1x2 float32 arrays as inputs ``A`` and ``B``, expects exactly one
    output equal to their element-wise sum, then smoke-tests
    ``ryzenai.benchmark_model`` (called with the model path and ``runs=3``,
    without a ``device_preference`` argument).
    """
    path = get_test_model_path()
    session = ryzenai.load_model(path, device_preference=['cpu'])

    lhs = np.array([[1.0, 2.0]], dtype=np.float32)
    rhs = np.array([[3.0, 4.0]], dtype=np.float32)

    results = session.run({'A': lhs, 'B': rhs})
    assert len(results) == 1
    actual = results[0]

    np.testing.assert_array_almost_equal(actual, lhs + rhs)

    # The benchmark helper must echo the requested run count and report a mean.
    bench = ryzenai.benchmark_model(path, runs=3)
    assert bench['runs'] == 3
    assert 'mean_s' in bench


def test_model_on_gpu():
    """Run the add model with a DirectML/GPU device preference, if one exists.

    A provider counts as a GPU when its name contains ``gpu`` or ``dml``
    (case-insensitive). When ``ryzenai.device_info()`` reports no such
    provider, the test is *skipped*. An early ``return`` would instead be
    reported as a vacuous pass.
    """
    info = ryzenai.device_info()
    # ``or []`` also covers a present-but-None 'providers' entry.
    providers = info.get('providers') or []

    has_gpu = any('gpu' in p.lower() or 'dml' in p.lower() for p in providers)
    if not has_gpu:
        # unittest.SkipTest is stdlib; pytest converts it into a skip for plain
        # test functions, so no pytest import is needed in this module.
        import unittest
        raise unittest.SkipTest('No GPU provider found, skipping GPU test')

    model_path = get_test_model_path()
    model = ryzenai.load_model(model_path, device_preference=['dml', 'gpu'])

    # Two 1x2 float32 inputs; the model output should be their element-wise sum.
    a = np.array([[1.0, 2.0]], dtype=np.float32)
    b = np.array([[3.0, 4.0]], dtype=np.float32)
    inputs = {'A': a, 'B': b}

    outputs = model.run(inputs)
    assert len(outputs) == 1
    y = outputs[0]

    # NOTE(review): this checks only the numeric result, not that load_model
    # actually selected a GPU provider rather than falling back — confirm
    # against ryzenai.load_model's behavior.
    expected = a + b
    np.testing.assert_array_almost_equal(y, expected)