Coverage for nilearn/decoding/tests/test_fista.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import numpy as np 

2import pytest 

3 

4from nilearn.decoding._objective_functions import ( 

5 logistic_loss, 

6 logistic_loss_lipschitz_constant, 

7 spectral_norm_squared, 

8 squared_loss, 

9 squared_loss_grad, 

10) 

11from nilearn.decoding._proximal_operators import prox_l1 

12from nilearn.decoding.fista import _check_lipschitz_continuous, mfista 

13 

14 

@pytest.mark.parametrize("scaling", list(np.logspace(-3, 3, num=7)))
def test_logistic_lipschitz(rng, scaling, n_samples=4, n_features=2):
    """Check ``logistic_loss_lipschitz_constant`` against ``logistic_loss``.

    The constant is checked empirically with ``_check_lipschitz_continuous``
    on random design matrices whose scale spans 1e-3 to 1e3.
    """
    X = rng.standard_normal((n_samples, n_features)) * scaling
    y = rng.standard_normal(n_samples)

    L = logistic_loss_lipschitz_constant(X)
    # The coefficient vector has one more entry than X has columns
    # (presumably the intercept -- confirm in logistic_loss).
    # NOTE(review): this checks the loss *value*, whereas the squared-loss
    # test below checks the *gradient* against its constant; confirm which
    # quantity logistic_loss_lipschitz_constant is meant to bound.
    _check_lipschitz_continuous(
        lambda w: logistic_loss(X, y, w), n_features + 1, L
    )

25 

26 

@pytest.mark.parametrize("scaling", list(np.logspace(-3, 3, num=7)))
def test_squared_loss_lipschitz(rng, scaling, n_samples=4, n_features=2):
    """Check that ``spectral_norm_squared(X)`` bounds ``squared_loss_grad``.

    The gradient of the squared loss must be Lipschitz continuous with that
    constant, checked empirically on random design matrices whose scale
    spans 1e-3 to 1e3.
    """
    X = rng.standard_normal((n_samples, n_features)) * scaling
    y = rng.standard_normal(n_samples)

    L = spectral_norm_squared(X)
    # Unlike the logistic case, no extra coefficient is appended: the
    # gradient is taken w.r.t. exactly one weight per column of X.
    _check_lipschitz_continuous(
        lambda w: squared_loss_grad(X, y, w), n_features, L
    )

37 

38 

@pytest.mark.parametrize("cb_retval", [0, 1])
@pytest.mark.parametrize("verbose", [0, 2])
@pytest.mark.parametrize("dgap_factor", [1.0, None])
def test_input_args_and_kwargs(cb_retval, verbose, dgap_factor, rng):
    """Smoke-test ``mfista`` on an L1-penalized problem with an identity design.

    Runs the solver for every combination of callback return value,
    verbosity level and ``dgap_factor``. It checks the types and shapes of
    the three returned values, not the quality of the solution.
    """
    p = 125
    noise_std = 1e-1

    # Sparse, piecewise-constant ground truth. Later assignments overwrite
    # earlier ones where the index sets overlap (e.g. indices 0, 2, 4 end
    # up at 2).
    sig = np.zeros(p)
    sig[[0, 2, 13, 4, 25, 32, 80, 89, 91, 93, -1]] = 1
    sig[:6] = 2
    sig[-7:] = 2
    sig[60:75] = 1
    y = sig + noise_std * rng.standard_normal(sig.shape)
    X = np.eye(p)

    alpha = 0.01
    # Penalty strength scaled by the number of rows of X.
    alpha_ = alpha * X.shape[0]
    l1_ratio = 0.2
    l1_weight = alpha_ * l1_ratio

    def f1(w):
        # Smooth part of the objective: squared-loss value only.
        return squared_loss(X, y, w, compute_grad=False)

    def f1_grad(w):
        # Gradient of the smooth part, skipping the energy computation.
        return squared_loss(X, y, w, compute_grad=True, compute_energy=False)

    def f2_prox(w, step_size, *args, **kwargs):  # noqa: ARG001
        # Prox of the L1 penalty; returns the result together with an info
        # dict reporting convergence.
        return prox_l1(w, step_size * l1_weight), {"converged": True}

    def total_energy(w):
        # Full objective: smooth loss plus weighted L1 norm of w.
        return f1(w) + l1_weight * np.sum(np.abs(w))

    best_w, objective, init = mfista(
        f1_grad,
        f2_prox,
        total_energy,
        # Presumably the Lipschitz constant of f1_grad (1 for an identity
        # design) -- confirm against mfista's signature.
        1.0,
        # Presumably the size of the coefficient vector.
        p,
        dgap_factor=dgap_factor,
        # Both a falsy (0) and a truthy (1) callback return are exercised;
        # presumably a truthy value requests early stopping -- confirm.
        callback=lambda _: cb_retval,
        verbose=verbose,
        max_iter=100,
    )

    # One coefficient per column of the p x p design matrix.
    assert best_w.shape == (p,)
    assert isinstance(objective, list)
    assert isinstance(init, dict)
    for key in ["w", "t", "dgap_tol", "stepsize"]:
        assert key in init, f"mfista init dict is missing key {key!r}"