Coverage for nilearn/decoding/tests/test_fista.py: 0% (51 statements)

import numpy as np
import pytest

from nilearn.decoding._objective_functions import (
    logistic_loss,
    logistic_loss_lipschitz_constant,
    spectral_norm_squared,
    squared_loss,
    squared_loss_grad,
)
from nilearn.decoding._proximal_operators import prox_l1
from nilearn.decoding.fista import _check_lipschitz_continuous, mfista
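
# NOTE (commentary inferred from usage below, not from the implementation):
# _check_lipschitz_continuous is expected to draw random point pairs a, b of
# the given dimension and assert the Lipschitz bound
# ||f(a) - f(b)|| <= L * ||a - b||.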


@pytest.mark.parametrize("scaling", list(np.logspace(-3, 3, num=7)))
def test_logistic_lipschitz(rng, scaling, n_samples=4, n_features=2):
    X = rng.standard_normal((n_samples, n_features)) * scaling
    y = rng.standard_normal(n_samples)
    n_features = X.shape[1]

    L = logistic_loss_lipschitz_constant(X)
    # n_features + 1: the weight vector carries an extra intercept term.
    _check_lipschitz_continuous(
        lambda w: logistic_loss(X, y, w), n_features + 1, L
    )
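

# A minimal sketch (hypothetical helper, not nilearn's implementation) of the
# kind of check _check_lipschitz_continuous is expected to perform:
def _lipschitz_check_sketch(f, n_features, lipschitz_constant, n_pairs=20):
    rng = np.random.default_rng(0)
    for _ in range(n_pairs):
        a, b = rng.standard_normal((2, n_features))
        # np.atleast_1d lets the same bound cover scalar- and vector-valued f.
        lhs = np.linalg.norm(np.atleast_1d(f(a) - f(b)))
        assert lhs <= lipschitz_constant * np.linalg.norm(a - b) + 1e-10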


@pytest.mark.parametrize("scaling", list(np.logspace(-3, 3, num=7)))
def test_squared_loss_lipschitz(rng, scaling, n_samples=4, n_features=2):
    X = rng.standard_normal((n_samples, n_features)) * scaling
    y = rng.standard_normal(n_samples)
    n_features = X.shape[1]

    L = spectral_norm_squared(X)
    _check_lipschitz_continuous(
        lambda w: squared_loss_grad(X, y, w), n_features, L
    )
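

# Background for the test above: for the squared loss
# f(w) = 0.5 * ||y - X @ w||**2 the gradient is X.T @ (X @ w - y), so
# grad f(a) - grad f(b) = X.T @ X @ (a - b) and the gradient is Lipschitz
# with constant exactly ||X||_2 ** 2, which is what spectral_norm_squared(X)
# is expected to return.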


@pytest.mark.parametrize("cb_retval", [0, 1])
@pytest.mark.parametrize("verbose", [0, 2])
@pytest.mark.parametrize("dgap_factor", [1.0, None])
def test_input_args_and_kwargs(cb_retval, verbose, dgap_factor, rng):
    p = 125
    noise_std = 1e-1
    # Piecewise-constant sparse signal: a few isolated spikes and plateaus.
    sig = np.zeros(p)
    sig[[0, 2, 13, 4, 25, 32, 80, 89, 91, 93, -1]] = 1
    sig[:6] = 2
    sig[-7:] = 2
    sig[60:75] = 1
    y = sig + noise_std * rng.standard_normal(sig.shape)
    X = np.eye(p)  # identity design: the problem reduces to pure l1 denoising
    mask = np.ones((p,)).astype(bool)
    alpha = 0.01
    alpha_ = alpha * X.shape[0]
    l1_ratio = 0.2
    l1_weight = alpha_ * l1_ratio
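
    # With X = np.eye(p), the composite objective minimized below is
    # 0.5 * ||y - w||**2 + l1_weight * ||w||_1, whose exact minimizer is the
    # soft thresholding of y (see the reference sketch at the end of this
    # file), so m-FISTA is exercised on a problem with a known solution.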

    def f1(w):
        # Smooth part of the objective: 0.5 * ||y - X @ w||**2.
        return squared_loss(X, y, w, compute_grad=False)

    def f1_grad(w):
        return squared_loss(X, y, w, compute_grad=True, compute_energy=False)

    def f2_prox(w, step_size, *args, **kwargs):  # noqa: ARG001
        # Proximal operator of the l1 penalty, plus a dict of solver info.
        return prox_l1(w, step_size * l1_weight), {"converged": True}

    def total_energy(w):
        return f1(w) + l1_weight * np.sum(np.abs(w))
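
    # Judging from this call (an assumption from usage, not from mfista's
    # docs), the positional arguments are the gradient of the smooth part,
    # the prox of the non-smooth part, the total energy, the Lipschitz
    # constant of f1_grad (1.0 here, since ||np.eye(p)||_2 ** 2 == 1), and
    # the size of w.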
    best_w, objective, init = mfista(
        f1_grad,
        f2_prox,
        total_energy,
        1.0,
        p,
        dgap_factor=dgap_factor,
        callback=lambda _: cb_retval,
        verbose=verbose,
        max_iter=100,
    )

    assert best_w.shape == mask.shape
    assert isinstance(objective, list)
    assert isinstance(init, dict)
    for key in ["w", "t", "dgap_tol", "stepsize"]:
        assert key in init
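

# Reference check (a minimal illustrative sketch, not part of the original
# tests): with an identity design, the objective solved above has the
# closed-form minimizer soft_threshold(y, l1_weight), so a hypothetical
# helper like this one could be used to validate best_w.
def _soft_threshold_reference(y, threshold):
    # Exact minimizer of 0.5 * ||y - w||**2 + threshold * ||w||_1.
    return np.sign(y) * np.maximum(np.abs(y) - threshold, 0.0)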