Coverage for nilearn/decomposition/tests/test_dict_learning.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1import numpy as np 

2import pytest 

3 

4from nilearn.decomposition.dict_learning import DictLearning 

5from nilearn.decomposition.tests.conftest import ( 

6 RANDOM_STATE, 

7 check_decomposition_estimator, 

8) 

9from nilearn.image import get_data, iter_img 

10from nilearn.maskers import NiftiMasker, SurfaceMasker 

11from nilearn.surface.surface import get_data as get_surface_data 

12 

13 

@pytest.mark.parametrize("data_type", ["nifti", "surface"])
@pytest.mark.parametrize("n_epochs", [1, 2, 10])
def test_check_values_epoch_argument_smoke(
    decomposition_mask_img, n_epochs, canica_components, canica_data, data_type
):
    """Fit DictLearning for several ``n_epochs`` values and sanity-check it.

    This is a smoke test: it only verifies that fitting succeeds and that
    the fitted estimator passes ``check_decomposition_estimator`` for both
    volume ("nifti") and surface inputs.
    """
    # Pick the masker class and the raw-data accessor for this image type.
    if data_type == "nifti":
        masker_class, read_mask_data = NiftiMasker, get_data
    else:
        masker_class, read_mask_data = SurfaceMasker, get_surface_data

    fitted_masker = masker_class(mask_img=decomposition_mask_img).fit()

    # Flattened boolean selector of the in-mask elements.
    in_mask = (read_mask_data(decomposition_mask_img) != 0).ravel()

    # Build the initial dictionary image from the in-mask part of the
    # reference components.
    initial_dictionary = fitted_masker.inverse_transform(
        canica_components[:, in_mask]
    )

    estimator = DictLearning(
        n_components=4,
        random_state=RANDOM_STATE,
        dict_init=initial_dictionary,
        mask=decomposition_mask_img,
        n_epochs=n_epochs,
        smoothing_fwhm=None,
        alpha=1,
    )
    estimator.fit(canica_data)

    check_decomposition_estimator(estimator, data_type)

42 

43 

def _unit_norm_rows(rows):
    """Return ``rows`` with each row scaled to unit L2 norm.

    All-zero rows are left unchanged (their norm is treated as 1) so that
    no division by zero occurs.
    """
    norms = np.sqrt(np.sum(rows**2, axis=1))
    norms[norms == 0] = 1
    return rows / norms[:, np.newaxis]


@pytest.mark.parametrize("data_type", ["nifti"])
def test_dict_learning(
    decomposition_mask_img, canica_components, canica_data, data_type
):
    """Check content of components_img_.

    Two estimators are fitted on the same data: one initialised from the
    reference components (``dict_init``) and one with automatic
    initialisation. Both must pass ``check_decomposition_estimator``.
    The maps of the explicitly initialised estimator are then compared
    with the reference components: at least two of them must match a
    reference component with an absolute normalised dot product above 0.9.
    """
    n_components = 4

    masker = NiftiMasker(mask_img=decomposition_mask_img).fit()
    mask = get_data(decomposition_mask_img) != 0
    flat_mask = mask.ravel()
    masked_components = canica_components[:, flat_mask]
    dict_init = masker.inverse_transform(masked_components)

    # Note that adding smoothing will make this test break
    smoothing_fwhm = None

    dict_learning = DictLearning(
        n_components=n_components,
        random_state=RANDOM_STATE,
        dict_init=dict_init,
        mask=decomposition_mask_img,
        smoothing_fwhm=smoothing_fwhm,
        alpha=1,
    )

    dict_learning_auto_init = DictLearning(
        n_components=n_components,
        random_state=RANDOM_STATE,
        mask=decomposition_mask_img,
        n_epochs=10,
        smoothing_fwhm=smoothing_fwhm,
        alpha=1,
    )

    maps = {}
    for estimator in [dict_learning, dict_learning_auto_init]:
        estimator.fit(canica_data)

        # Check the estimator that was just fitted, so that the
        # auto-initialised one is validated as well.
        check_decomposition_estimator(estimator, data_type)

        # Bring the component axis first and keep only in-mask voxels,
        # giving one row per component.
        components = get_data(estimator.components_img_)
        maps[estimator] = np.reshape(
            np.moveaxis(components, 3, 0)[:, mask],
            (n_components, flat_mask.sum()),
        )

    # Only the estimator initialised from the reference components is
    # expected to recover them.
    reference_maps = _unit_norm_rows(masked_components)
    learned_maps = _unit_norm_rows(maps[dict_learning])

    # Absolute value: a recovered map may have the opposite sign.
    K = np.abs(reference_maps.dot(learned_maps.T))
    recovered_maps = np.sum(K > 0.9)

    assert recovered_maps >= 2

101 

102 

@pytest.mark.parametrize("data_type", ["nifti", "surface"])
def test_component_sign(
    decomposition_mask_img, canica_data, data_type
) -> None:
    """Check sign of extracted components.

    Regression test:
    We should have a heuristic that flips the sign of components in
    DictLearning to have more positive values than negative values, for
    instance by making sure that the largest value is positive.

    Each component map must therefore contain at least as many strictly
    positive values as strictly negative ones.
    """
    dict_learning = DictLearning(
        n_components=4,
        random_state=RANDOM_STATE,
        mask=decomposition_mask_img,
        smoothing_fwhm=None,
        alpha=1,
    )
    dict_learning.fit(canica_data)

    check_decomposition_estimator(dict_learning, data_type)

    # Raw-data accessor matching the image type under test.
    read_data = get_data if data_type == "nifti" else get_surface_data

    for component_img in iter_img(dict_learning.components_img_):
        values = read_data(component_img)
        # Compare counts, not signed sums: a sum of non-positive values is
        # always <= a sum of positive values, so comparing the sums could
        # never fail and would not detect a wrongly signed component.
        # NOTE(review): assumes the sign heuristic is count-based, as the
        # docstring describes -- confirm against DictLearning.
        assert np.sum(values < 0) <= np.sum(values > 0)