"""Test ODE model functions"""

import numpy as np
import pytest

from gemlib.deterministic import ode_model


def test_deterministic_dormandprince(evaltest, homogeneous_sir_params):
    """Test we get a functioning deterministic model using DormandPrince"""

    model_params = homogeneous_sir_params(0.8, 0.14)
    times = np.arange(0.0, model_params["num_steps"], 1.0, dtype=np.float32)
    del model_params["num_steps"]

    sample = evaltest(lambda: ode_model(**model_params, times=times))

    assert sample.times.shape == times.shape
    assert sample.states.shape == (199, 1, 3)


@pytest.mark.skip(reason="BDF solver too slow")
def test_deterministic_bdf(evaltest, homogeneous_sir_params):
    """Test we get a functioning deterministic model using BDF"""

    model_params = homogeneous_sir_params(0.8, 0.14)

    sample = evaltest(lambda: ode_model(**model_params, solver="BDF"))

    assert sample.times.shape == (2, 3, 199)
    assert sample.states.shape == (2, 3, 199, 1, 3)
