Coverage for nilearn/regions/tests/test_rena_clustering.py: 0%

107 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1import numpy as np 

2import pytest 

3from joblib import Memory 

4from numpy.testing import assert_array_equal 

5from sklearn.utils.estimator_checks import parametrize_with_checks 

6 

7from nilearn._utils.data_gen import generate_fake_fmri 

8from nilearn._utils.estimator_checks import ( 

9 check_estimator, 

10 nilearn_check_estimator, 

11 return_expected_failed_checks, 

12) 

13from nilearn._utils.tags import SKLEARN_LT_1_6 

14from nilearn.conftest import _img_3d_mni, _shape_3d_default 

15from nilearn.image import get_data 

16from nilearn.maskers import NiftiMasker, SurfaceMasker 

17from nilearn.regions.rena_clustering import ( 

18 ReNA, 

19 _make_edges_and_weights_surface, 

20 make_edges_surface, 

21) 

22from nilearn.surface import SurfaceImage 

23 

ESTIMATORS_TO_CHECK = [ReNA()]

if not SKLEARN_LT_1_6:
    # scikit-learn >= 1.6: let sklearn's own parametrizer generate the checks,
    # with nilearn's helper supplying the checks that are known to fail.

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run one scikit-learn compliance check on a ReNA estimator."""
        check(estimator)

else:
    # scikit-learn < 1.6: nilearn's ``check_estimator`` helper yields the
    # checks; those selected with ``valid=False`` are marked as xfail.

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn compliance check expected to pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn compliance check expected to fail."""
        check(estimator)

54 

55 

@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        estimators=[ReNA(n_clusters=2, mask_img=_img_3d_mni())]
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run one nilearn-specific estimator check on a masked ReNA instance.

    ``name`` is part of the parametrization but is not used in the body.
    """
    check(estimator)

65 

66 

def test_rena_clustering_mask_error():
    """Fitting ReNA without ``mask_img`` must raise a TypeError."""
    data_img, mask_img = generate_fake_fmri(
        shape=_shape_3d_default(), length=5
    )
    rena = ReNA(n_clusters=10)

    volumes = get_data(data_img)
    in_mask = get_data(mask_img) != 0

    # One row per volume, one column per in-mask voxel, stored as float64.
    X = np.stack(
        [volumes[..., t][in_mask] for t in range(volumes.shape[3])]
    ).astype(np.float64)

    with pytest.raises(TypeError, match="The mask image should be a"):
        rena.fit_transform(X)

83 

84 

def test_rena_clustering():
    """Check ReNA on a volumetric mask: reduction, reconstruction and errors.

    Covers fit/transform/inverse_transform round trips (with and without
    ``scaling``), rejection of invalid ``n_clusters`` and ``n_iter`` values,
    and the cluster count obtained when only one iteration is allowed.
    """
    data_img, mask_img = generate_fake_fmri(shape=(10, 11, 12), length=5)

    data = get_data(data_img)
    in_mask = get_data(mask_img) != 0

    # One row per volume, one column per in-mask voxel. Boolean indexing on
    # the three spatial axes yields (n_voxels, n_timepoints), hence ``.T``.
    X = np.ascontiguousarray(data[in_mask].T, dtype=np.float64)

    # The manual extraction must yield as many features as NiftiMasker does.
    nifti_masker = NiftiMasker(mask_img=mask_img).fit()
    n_voxels = nifti_masker.transform(data_img).shape[1]
    assert X.shape[1] == n_voxels

    rena = ReNA(mask_img, n_clusters=10)
    X_red = rena.fit_transform(X)
    X_compress = rena.inverse_transform(X_red)

    assert rena.n_clusters_ == 10
    assert X_compress.shape == X.shape

    memory = Memory(location=None)

    # A negative number of clusters is rejected at fit time.
    rena = ReNA(mask_img, n_clusters=-2, memory=memory)
    with pytest.raises(ValueError):
        rena.fit(X)

    # The scaled variant must also round-trip back to the original shape.
    rena = ReNA(mask_img, n_clusters=10, scaling=True)
    X_red = rena.fit_transform(X)
    X_compress = rena.inverse_transform(X_red)
    assert X_compress.shape == X.shape

    # Non-positive iteration counts are rejected at fit time.
    for n_iter in [-2, 0]:
        rena = ReNA(mask_img, n_iter=n_iter, memory=memory)
        with pytest.raises(ValueError):
            rena.fit(X)

    # NOTE(review): with a single iteration the requested number of clusters
    # is expected NOT to be reached exactly; confirm this is the intent.
    for n_clusters in [1, 2, 4, 8]:
        rena = ReNA(
            mask_img, n_clusters=n_clusters, n_iter=1, memory=memory
        ).fit(X)
        assert n_clusters != rena.n_clusters_

127 

128 

129# ------------------------ surface tests ------------------------------------ # 

130 

131 

@pytest.mark.parametrize("part", ["left", "right"])
def test_make_edges_surface(surf_mask_1d, part):
    """Check the edges kept by make_edges_surface after masking a hemisphere.

    In ``surf_mask_1d`` the left part has 2 of its 4 vertices in the mask,
    which leaves a single edge. The right part has 3 of its 5 vertices in
    the mask, which leaves three edges.
    """
    # Shape of the edge array once the edge mask has been applied.
    expected_shape = {"left": (2, 1), "right": (2, 3)}[part]

    edges_unmasked, edges_mask = make_edges_surface(
        surf_mask_1d.mesh.parts[part].faces,
        surf_mask_1d.data.parts[part],
    )

    assert edges_unmasked[:, edges_mask].shape == expected_shape

147 

148 

def test_make_edges_and_weights_surface(surf_mesh, surf_img_2d):
    """Smoke test for _make_edges_and_weights_surface.

    A mask different from the one used in test_make_edges_surface is built
    here, to check that edge and weight computation is robust.
    """
    # Left: 3 of 4 vertices kept. Right: 3 of 5 vertices kept.
    mask = SurfaceImage(
        surf_mesh,
        {
            "left": np.array([False, True, True, True]),
            "right": np.array([True, True, False, True, False]),
        },
    )
    masker = SurfaceMasker(mask).fit()
    # 50 samples of surface data, restricted to the mask.
    X = masker.transform(surf_img_2d(50))

    edges, weights = _make_edges_and_weights_surface(X, mask)

    # Both outputs hold exactly one entry per hemisphere.
    for result in (edges, weights):
        assert len(result) == 2
        assert "left" in result
        assert "right" in result

    # Left and right edges must not share any vertex index.
    assert np.intersect1d(edges["left"], edges["right"]).size == 0

    expected_edges = {
        # Pairs (0, 1), (1, 2) and (0, 2) among the 3 kept left vertices.
        "left": np.array([[0, 1, 0], [1, 2, 2]]),
        # Pairs (3, 4), (3, 5) and (4, 5): right indices follow the left ones.
        "right": np.array([[3, 3, 4], [4, 5, 5]]),
    }
    for part, expected in expected_edges.items():
        assert_array_equal(edges[part], expected)
        # One weight per edge.
        assert len(weights[part]) == expected.shape[1]

191 

192 

@pytest.mark.parametrize("surf_mask_dim", [1, 2])
@pytest.mark.parametrize("mask_as", ["surface_image", "surface_masker"])
@pytest.mark.parametrize("n_clusters", [2, 4, 5])
def test_rena_clustering_input_mask_surface(
    surf_img_2d, surf_mask_dim, surf_mask_1d, surf_mask_2d, mask_as, n_clusters
):
    """Check ReNA accepts a SurfaceImage or a SurfaceMasker as ``mask_img``.

    The data are masked surface samples; the transformed output must have
    one feature per cluster, and the inverse transform must restore the
    original shape.
    """
    if surf_mask_dim == 1:
        surf_mask = surf_mask_1d
    else:
        surf_mask = surf_mask_2d()

    masker = SurfaceMasker(surf_mask).fit()
    # 50 samples of surface data, restricted to the mask.
    X = masker.transform(surf_img_2d(50))

    # Pick how the mask is handed to ReNA for this parametrization.
    mask_img = {
        "surface_image": surf_mask,
        "surface_masker": masker,
    }[mask_as]
    clustering = ReNA(mask_img=mask_img, n_clusters=n_clusters)

    X_transformed = clustering.fit_transform(X)
    X_inverse = clustering.inverse_transform(X_transformed)

    # Features are reduced to exactly one per cluster...
    assert X_transformed.shape[1] == n_clusters
    # ...and the inverse transform restores the original shape.
    assert X_inverse.shape == X.shape