Coverage for nilearn/decoding/tests/test_operators.py: 0%

31 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import itertools 

2 

3import numpy as np 

4import pytest 

5from numpy.testing import assert_almost_equal 

6 

7from nilearn.decoding._proximal_operators import prox_l1, prox_tvl1 

8 

9 

def test_prox_l1_nonexpansiveness(rng, n_features=10):
    """Check that prox_l1 is firmly non-expansive on random scalar pairs.

    With ``s = prox_l1(x, tau)`` and ``p = x - s`` (so that shrinkage plus
    projection is the identity), every pair of entries ``a, b`` of ``x``
    must satisfy

        (s(a) - s(b)) ** 2 <= (a - b) ** 2 - (p(a) - p(b)) ** 2

    The check runs over all ordered pairs of the ``n_features`` entries.
    """
    x = rng.standard_normal((n_features, 1))
    tau = 0.3
    s = prox_l1(x.copy(), tau)
    p = x - s  # projection + shrinkage = id

    # ``x`` has shape (n_features, 1), so ``z[0]`` would be a one-element row.
    # The pairs would then only ever compare an entry with itself, which is
    # trivially true. Flatten so that every pair of entries is really tested.
    # The bound can hold with equality, so allow a small absolute slack for
    # floating-point rounding in prox_l1 and in ``x - s``.
    atol = 1e-10
    for (a, b), (pa, pb), (sa, sb) in zip(
        *[itertools.product(z.ravel(), z.ravel()) for z in (x, p, s)]
    ):
        assert (sa - sb) ** 2 <= (a - b) ** 2 - (pa - pb) ** 2 + atol

22 

23 

@pytest.mark.parametrize("ndim", [3])
@pytest.mark.parametrize("weight", np.logspace(-10, 10, num=10))
def test_prox_tvl1_approximates_prox_l1_for_lasso(
    rng, ndim, weight, size=15, decimal=4, dgap_tol=1e-7
):
    """Check prox_tvl1 against the closed-form prox_l1 when l1_ratio is 1.

    With ``l1_ratio=1.0`` the penalty is pure L1 (LASSO). The iterative
    prox_tvl1 solver, capped at 10 iterations, should then agree with
    prox_l1 to ``decimal`` places in the l-infinity norm. Only the last
    slice along axis 0 of each result is compared.
    """
    volume = rng.standard_normal((size,) * ndim)

    # Iterative approximation; the first element of the return value is the
    # estimate.
    tvl1_estimate = prox_tvl1(
        volume.copy(),
        weight=weight,
        l1_ratio=1.0,  # pure LASSO
        dgap_tol=dgap_tol,
        max_iter=10,
    )[0]
    approx = tvl1_estimate[-1].ravel()

    # Exact closed-form soft-shrinkage reference.
    exact = prox_l1(volume.copy(), weight)[-1].ravel()

    # Compare in the l-infinity norm.
    max_abs_diff = np.abs(approx - exact).max()
    assert_almost_equal(max_abs_diff, 0.0, decimal=decimal)

48 

49 

@pytest.mark.parametrize("verbose", [True, False])
def test_prox_tvl1_verbose(rng, verbose):
    """Smoke-test prox_tvl1 with ``verbose`` switched on and off.

    No result is checked. The test only requires that the call completes
    without raising.
    """
    # NOTE(review): the weight below is negative. Confirm this is intentional
    # (e.g. to drive a particular solver path) and not a typo for 10.
    solver_options = {
        "weight": -10,
        "l1_ratio": 1.0,  # pure LASSO
        "dgap_tol": 1e-7,
        "max_iter": 10,
        "val_min": -np.inf,
        "val_max": np.inf,
        "verbose": verbose,
        "x_tol": 1e-7,
    }
    data = rng.standard_normal((15,) * 3)

    prox_tvl1(data.copy(), **solver_options)