Coverage for nilearn/maskers/tests/test_multi_nifti_masker.py: 0%

146 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Test the multi_nifti_masker module.""" 

2 

3import shutil 

4from tempfile import mkdtemp 

5 

6import numpy as np 

7import pytest 

8from joblib import Memory, hash 

9from nibabel import Nifti1Image 

10from numpy.testing import assert_array_equal 

11from sklearn.utils.estimator_checks import parametrize_with_checks 

12 

13from nilearn._utils.estimator_checks import ( 

14 check_estimator, 

15 nilearn_check_estimator, 

16 return_expected_failed_checks, 

17) 

18from nilearn._utils.tags import SKLEARN_LT_1_6 

19from nilearn._utils.testing import write_imgs_to_path 

20from nilearn.image import get_data 

21from nilearn.maskers import MultiNiftiMasker 

22 

ESTIMATORS_TO_CHECK = [MultiNiftiMasker()]

# ``parametrize_with_checks(expected_failed_checks=...)`` only exists from
# scikit-learn 1.6 on; older versions go through nilearn's own lists of
# valid / invalid checks instead.
if SKLEARN_LT_1_6:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the sklearn checks flagged as invalid for this estimator.

        These checks are collected with ``valid=False`` and are expected to
        fail, hence the ``xfail`` marker.
        """
        check(estimator)

else:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Check compliance with sklearn estimators."""
        check(estimator)

53 

54 

# check_multi_masker_transformer_high_variance_confounds is slow,
# so the pytest-timeout limit is disabled (timeout 0) for this test.
@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn-specific estimator checks."""
    check(estimator)

64 

65 

@pytest.fixture
def data_2(shape_3d_default):
    """Return a 3D float array of zeros with an inner box set to 10.

    Along every axis the box runs from index 1 up to, but excluding, the
    last two indices.
    """
    inner_box = (slice(1, -2),) * 3
    volume = np.zeros(shape_3d_default)
    volume[inner_box] = 10
    return volume

72 

73 

@pytest.fixture
def img_1(data_1, affine_eye):
    """Return the ``data_1`` array wrapped in a Nifti image (``affine_eye``)."""
    image = Nifti1Image(data_1, affine_eye)
    return image

78 

79 

@pytest.fixture
def img_2(data_2, affine_eye):
    """Return ``data_2`` (zeros with an inner box of 10) as a Nifti image."""
    image = Nifti1Image(data_2, affine_eye)
    return image

84 

85 

def test_auto_mask(data_1, img_1, data_2, img_2):
    """Check the mask computed automatically when no mask_img is given."""
    masker = MultiNiftiMasker(mask_args={"opening": 0})

    # Smoke test the fit
    masker.fit([[img_1]])

    # Fit on both volumes at once: the computed mask is expected to be the
    # union (logical OR) of the non-zero voxels of the two volumes.
    masker.fit([[img_1, img_2]])

    assert_array_equal(
        get_data(masker.mask_img_), np.logical_or(data_1, data_2)
    )

    # Smoke test the transform
    masker.transform([[img_1]])
    # It should also work with a 3D image
    masker.transform(img_1)

104 

105 

def test_nan():
    """Check that voxels that are NaN in the fitted image are kept out of the mask."""
    data = np.ones((9, 9, 9))
    # The six faces of the volume, as (axis, edge index) pairs.
    faces = [(axis, edge) for edge in (0, -1) for axis in range(3)]

    def face_index(axis, edge):
        """Return an index tuple selecting slab ``edge`` along ``axis``."""
        index = [slice(None)] * 3
        index[axis] = edge
        return tuple(index)

    for axis, edge in faces:
        data[face_index(axis, edge)] = np.nan
    data[3:-3, 3:-3, 3:-3] = 10
    img = Nifti1Image(data, np.eye(4))

    masker = MultiNiftiMasker(mask_args={"opening": 0})
    masker.fit([img])

    mask = get_data(masker.mask_img_)

    # Every voxel inside the NaN border is part of the mask...
    assert mask[1:-1, 1:-1, 1:-1].all()
    # ...and no voxel on any NaN face is.
    for axis, edge in faces:
        assert not mask[face_index(axis, edge)].any()

130 

131 

def test_different_affines():
    """Check transforms when the mask and the EPI images have different affines."""
    mask_img = Nifti1Image(
        np.ones((2, 2, 2), dtype=np.int8), affine=np.diag((4, 4, 4, 1))
    )
    epi_img1 = Nifti1Image(np.ones((4, 4, 4, 3)), affine=np.diag((2, 2, 2, 1)))
    epi_img2 = Nifti1Image(np.ones((3, 3, 3, 3)), affine=np.diag((3, 3, 3, 1)))

    masker = MultiNiftiMasker(mask_img=mask_img)
    epis = masker.fit_transform([epi_img1, epi_img2])

    # One signal array per input image, expressed in the mask's space:
    # 3 volumes x 8 voxels (the all-ones 2x2x2 mask).
    assert len(epis) == 2
    for this_epi in epis:
        assert this_epi.shape == (3, 8)
        img = masker.inverse_transform(this_epi)
        assert img.shape == (2, 2, 2, 3)

144 

145 

def test_3d_images(rng):
    """Check that MultiNiftiMasker accepts a list of 3D images.

    Note that fit() requires all images in list to have the same affine,
    so both EPI images below share one.
    """
    mask_img = Nifti1Image(
        np.ones((2, 2, 2), dtype=np.int8), affine=np.diag((2, 2, 2, 1))
    )
    epi_affine = np.diag((4, 4, 4, 1))
    epi_imgs = [
        Nifti1Image(rng.random((2, 2, 2)), affine=epi_affine) for _ in range(2)
    ]

    MultiNiftiMasker(mask_img=mask_img).fit_transform(epi_imgs)

159 

160 

def test_joblib_cache(mask_img_1, tmp_path):
    """Check that reading the fitted mask data does not change its hash.

    ``hash`` here is ``joblib.hash``, not the builtin.
    """
    mask_path = write_imgs_to_path(
        mask_img_1, file_path=tmp_path, create_files=True
    )
    masker = MultiNiftiMasker(mask_img=mask_path)
    masker.fit()

    hash_before = hash(masker.mask_img_)
    get_data(masker.mask_img_)  # access the voxel data
    hash_after = hash(masker.mask_img_)

    assert hash_before == hash_after

172 

173 

@pytest.mark.timeout(0)
def test_shelving(rng):
    """Check that a shelving masker returns the same signals as a plain one.

    With ``_shelving = True`` and a joblib ``Memory`` cache, ``fit_transform``
    returns shelved results whose ``.get()`` yields the signal array.
    """
    mask_img = Nifti1Image(
        np.ones((2, 2, 2), dtype=np.int8), affine=np.diag((2, 2, 2, 1))
    )
    epi_img1 = Nifti1Image(rng.random((2, 2, 2)), affine=np.diag((4, 4, 4, 1)))
    epi_img2 = Nifti1Image(rng.random((2, 2, 2)), affine=np.diag((4, 4, 4, 1)))
    cachedir = mkdtemp()
    try:
        masker_shelved = MultiNiftiMasker(
            mask_img=mask_img,
            memory=Memory(location=cachedir, mmap_mode="r", verbose=0),
        )
        masker_shelved._shelving = True
        epis_shelved = masker_shelved.fit_transform([epi_img1, epi_img2])
        masker = MultiNiftiMasker(mask_img=mask_img)
        epis = masker.fit_transform([epi_img1, epi_img2])

        # zip() would silently truncate on a length mismatch.
        assert len(epis_shelved) == len(epis)
        for epi_shelved, epi in zip(epis_shelved, epis):
            epi_shelved = epi_shelved.get()
            assert_array_equal(epi_shelved, epi)

        epi = masker.fit_transform(epi_img1)
        epi_shelved = masker_shelved.fit_transform(epi_img1)
        epi_shelved = epi_shelved.get()

        assert_array_equal(epi_shelved, epi)

        # Drop references to the (possibly memory-mapped, mmap_mode="r")
        # cached outputs so the cache directory can be removed on Windows.
        del epis_shelved, epi_shelved
    finally:
        # Only use names bound before the ``try``: deleting names assigned
        # inside it would raise NameError and mask the original failure.
        shutil.rmtree(cachedir, ignore_errors=True)

207 

208 

@pytest.fixture
def list_random_imgs(img_3d_rand_eye):
    """Return a two-element list holding the same random 3D Nifti image twice."""
    return [img_3d_rand_eye, img_3d_rand_eye]

213 

214 

def test_mask_strategy_errors(list_random_imgs):
    """Check the error / warning raised for bad mask_strategy values.

    An unknown strategy raises a ValueError at fit time; the deprecated
    'template' strategy emits a UserWarning.
    """
    masker = MultiNiftiMasker(mask_strategy="foo")

    with pytest.raises(
        ValueError, match="Unknown value of mask_strategy 'foo'"
    ):
        masker.fit(list_random_imgs)

    # Deprecated 'template' strategy: only the deprecation warning is
    # asserted here; any exception raised by fit would fail the test.
    masker = MultiNiftiMasker(mask_strategy="template")
    with pytest.warns(
        UserWarning, match="Masking strategy 'template' is deprecated."
    ):
        masker.fit(list_random_imgs)

231 

232 

@pytest.mark.parametrize(
    "strategy", [f"{p}-template" for p in ["whole-brain", "gm", "wm"]]
)
def test_compute_mask_strategy(strategy, shape_3d_default, list_random_imgs):
    """Check the template-based mask strategies on random images.

    For these random inputs every strategy is expected to yield an all-zero
    mask, whatever the order of the input images.
    """
    masker = MultiNiftiMasker(mask_strategy=strategy, mask_args={"opening": 1})
    masker.fit(list_random_imgs)

    # Check that the order of the images does not change the output.
    # NOTE(review): ``list_random_imgs`` holds the same image twice, so the
    # reversed list equals the original; distinct images would be needed
    # for this to actually exercise ordering.
    masker2 = MultiNiftiMasker(
        mask_strategy=strategy, mask_args={"opening": 1}
    )
    masker2.fit(list_random_imgs[::-1])

    mask_ref = np.zeros(shape_3d_default, dtype="int8")
    assert_array_equal(get_data(masker.mask_img_), mask_ref)
    assert_array_equal(get_data(masker2.mask_img_), mask_ref)

250 

251 

def test_standardization(rng, shape_3d_default, affine_eye):
    """Check that ``standardize`` rescales each masked signal as requested.

    Covers ``"zscore_sample"`` (per-voxel zero mean and unit std) and
    ``"psc"`` (percent change relative to each voxel's temporal mean).
    """
    n_samples = 500
    n_voxels = np.prod(shape_3d_default)
    img_shape = (*shape_3d_default, n_samples)

    # Two runs of unit-variance noise around large voxel-specific baselines.
    signals = rng.standard_normal(size=(2, n_voxels, n_samples))
    signals += rng.standard_normal(size=(2, n_voxels, 1)) * 50 + 1000

    imgs = [Nifti1Image(run.reshape(img_shape), affine_eye) for run in signals]
    mask = Nifti1Image(np.ones(shape_3d_default), affine_eye)

    # z-score: each voxel's time course is centred and scaled.
    zscore_masker = MultiNiftiMasker(mask, standardize="zscore_sample")
    for ts in zscore_masker.fit_transform(imgs):
        np.testing.assert_almost_equal(ts.mean(0), 0)
        np.testing.assert_almost_equal(ts.std(0), 1, decimal=3)

    # psc: percent deviation from each voxel's temporal mean.
    psc_masker = MultiNiftiMasker(mask, standardize="psc")
    for ts, run in zip(psc_masker.fit_transform(imgs), signals):
        expected = (run / run.mean(axis=1, keepdims=True) * 100 - 100).T
        np.testing.assert_almost_equal(ts.mean(0), 0)
        np.testing.assert_almost_equal(ts, expected)