Coverage for nilearn/maskers/tests/test_multi_nifti_maps_masker.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Test the multi_nifti_maps_masker module.""" 

2 

3import numpy as np 

4import pytest 

5from numpy.testing import assert_almost_equal, assert_array_equal 

6from sklearn.utils.estimator_checks import parametrize_with_checks 

7 

8from nilearn._utils.data_gen import generate_fake_fmri, generate_maps 

9from nilearn._utils.estimator_checks import ( 

10 check_estimator, 

11 nilearn_check_estimator, 

12 return_expected_failed_checks, 

13) 

14from nilearn._utils.exceptions import DimensionError 

15from nilearn._utils.tags import SKLEARN_LT_1_6 

16from nilearn._utils.testing import write_imgs_to_path 

17from nilearn.conftest import _img_maps 

18from nilearn.maskers import MultiNiftiMapsMasker, NiftiMapsMasker 

19 

# Estimator instances fed to the scikit-learn compliance checks below.
ESTIMATORS_TO_CHECK = [MultiNiftiMapsMasker()]

if SKLEARN_LT_1_6:
    # Older scikit-learn: nilearn's ``check_estimator`` helper yields
    # (estimator, check, name) triplets, split into checks expected to
    # pass (default) and checks expected to fail (``valid=False``).

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run the sklearn compliance checks that are expected to pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the sklearn compliance checks that are known to fail."""
        check(estimator)

else:
    # Newer scikit-learn: use sklearn's own parametrization, giving it the
    # list of expected failures reported by nilearn.

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run scikit-learn's estimator compliance checks."""
        check(estimator)

53 

54 

@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        # Fewer maps than the default number of regions keeps these
        # checks fast.
        estimators=[MultiNiftiMapsMasker(_img_maps(n_regions=2))],
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run nilearn's estimator-compliance checks on MultiNiftiMapsMasker."""
    check(estimator)

69 

70 

@pytest.mark.timeout(0)
def test_multi_nifti_maps_masker(
    affine_eye, length, n_regions, shape_3d_default, img_maps
):
    """Check basic functions of MultiNiftiMapsMasker.

    - fit, transform, fit_transform, inverse_transform.
    - 4D and list[4D] inputs
    """
    fmri_img, mask_img = generate_fake_fmri(
        shape_3d_default, affine=affine_eye, length=length
    )
    expected_shape = (length, n_regions)

    masker = MultiNiftiMapsMasker(
        img_maps, mask_img=mask_img, resampling_target=None
    )

    # A single 4D image gives one (time points, regions) array.
    single_signals = masker.fit_transform(fmri_img)
    assert single_signals.shape == expected_shape

    # Smoke test with default parameters (no mask, default resampling).
    MultiNiftiMapsMasker(img_maps).fit_transform(fmri_img)

    # A list of 4D images gives one signals array per image.
    signals_list = masker.fit_transform([fmri_img, fmri_img])
    for subject_signals in signals_list:
        assert subject_signals.shape == expected_shape

    # Inverse transform must restore the input image geometry.
    for subject_signals in signals_list:
        restored = masker.inverse_transform(subject_signals)
        assert restored.shape == fmri_img.shape
        assert_almost_equal(restored.affine, fmri_img.affine)

    # Inverse transform on a masker that was fitted but never transformed.
    fresh_masker = MultiNiftiMapsMasker(img_maps, resampling_target=None)
    fresh_masker.fit()
    fresh_masker.inverse_transform(signals_list[-1])

113 

114 

def test_multi_nifti_maps_masker_data_atlas_different_shape(
    affine_eye, length, img_maps
):
    """Test with data and atlas of different shape.

    The atlas should be resampled to the data.
    """
    # The mask and the data live on different grids (different shape and
    # voxel size), so the maps have to be resampled onto the data grid.
    shape2 = (12, 10, 14)
    shape22 = (5, 5, 6)
    # Data affine: voxel size 2 along every axis, no translation.
    affine2 = 2 * np.eye(4)
    affine2[-1, -1] = 1

    _, mask21_img = generate_fake_fmri(
        shape2, affine=affine_eye, length=length
    )
    fmri22_img, _ = generate_fake_fmri(shape22, affine=affine2, length=length)

    masker = MultiNiftiMapsMasker(img_maps, mask_img=mask21_img)

    masker.fit_transform(fmri22_img)

    # After transforming, the fitted maps must be on the data's affine.
    assert_array_equal(masker.maps_img_.affine, affine2)

139 

140 

def test_multi_nifti_maps_masker_errors(
    affine_eye, length, shape_3d_default, img_maps
):
    """Check NiftiMapsMasker rejects input that MultiNiftiMapsMasker accepts.

    A list of 4D images is valid input for MultiNiftiMapsMasker, but the
    single-image NiftiMapsMasker must raise a DimensionError on it.
    """
    fmri11_img, _ = generate_fake_fmri(
        shape_3d_default, affine=affine_eye, length=length
    )

    signals_input = [fmri11_img, fmri11_img]

    # NiftiMapsMasker should not work with 4D + 1D input
    masker = NiftiMapsMasker(img_maps, resampling_target=None)
    with pytest.raises(DimensionError, match="incompatible dimensionality"):
        masker.fit_transform(signals_input)

159 

160 

@pytest.mark.parametrize("create_files", [True, False])
def test_multi_nifti_maps_masker_errors_field_of_view(
    tmp_path,
    affine_eye,
    length,
    create_files,
    shape_3d_default,
    img_maps,
):
    """Check that field-of-view mismatches with the maps raise ValueError.

    Covers mismatching shapes and mismatching affines, for both the mask
    (at fit time) and the data (at transform time).
    """
    error_msg = "Following field of view errors were detected"

    # One pair keeps the default shape but uses a non-identity affine;
    # the other pair uses the identity affine but a different shape.
    other_shape = (12, 10, 14)
    other_affine = np.diag((1, 2, 3, 1))

    fmri_other_affine, mask_other_affine = generate_fake_fmri(
        shape_3d_default, affine=other_affine, length=length
    )
    fmri_other_shape, mask_other_shape = generate_fake_fmri(
        other_shape, affine=affine_eye, length=length
    )

    # A mask whose grid does not match the maps fails at fit.
    masker = MultiNiftiMapsMasker(
        img_maps, mask_img=mask_other_shape, resampling_target=None
    )
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()

    # create_files toggles whether write_imgs_to_path saves the images
    # under tmp_path (presumably returning paths) or passes them through.
    maps_input, mask_input = write_imgs_to_path(
        img_maps,
        mask_other_affine,
        file_path=tmp_path,
        create_files=create_files,
    )

    masker = MultiNiftiMapsMasker(maps_input, resampling_target=None)
    masker.fit()

    # Data whose grid does not match the maps fails at transform.
    for mismatched_fmri in (fmri_other_affine, fmri_other_shape):
        with pytest.raises(ValueError, match=error_msg):
            masker.transform(mismatched_fmri)

    masker = MultiNiftiMapsMasker(
        maps_input, mask_img=mask_input, resampling_target=None
    )
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()

211 

212 

def test_multi_nifti_maps_masker_resampling_error(
    affine_eye, n_regions, shape_3d_large
):
    """Check that invalid resampling settings raise ValueError at fit."""
    maps33_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    # (resampling_target, regex expected in the error message)
    cases = (
        # Resampling to the mask is impossible without a mask.
        (
            "mask",
            "resampling_target has been set to 'mask' "
            "but no mask has been provided",
        ),
        ("invalid", "invalid value for 'resampling_target' parameter:"),
    )

    for target, message in cases:
        masker = MultiNiftiMapsMasker(maps33_img, resampling_target=target)
        with pytest.raises(ValueError, match=message):
            masker.fit()

235 

236 

@pytest.mark.timeout(0)
def test_multi_nifti_maps_masker_resampling_to_mask(
    shape_mask,
    affine_eye,
    length,
    n_regions,
    shape_3d_large,
    img_fmri,
):
    """Test resampling to mask in MultiNiftiMapsMasker.

    With resampling_target="mask", the fitted mask keeps its own grid and
    the maps end up on that same grid.
    """
    _, mask_img = generate_fake_fmri(
        shape_mask, affine=affine_eye, length=length
    )
    maps_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    masker = MultiNiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="mask"
    )
    signals = masker.fit_transform([img_fmri, img_fmri])

    # The fitted mask is unchanged ...
    assert_almost_equal(masker.mask_img_.affine, mask_img.affine)
    assert masker.mask_img_.shape == mask_img.shape

    # ... and the fitted maps share its grid.
    assert_almost_equal(masker.mask_img_.affine, masker.maps_img_.affine)
    assert masker.mask_img_.shape == masker.maps_img_.shape[:3]

    for subject_signals in signals:
        assert subject_signals.shape == (length, n_regions)

        restored = masker.inverse_transform(subject_signals)

        assert_almost_equal(restored.affine, masker.maps_img_.affine)
        assert restored.shape == (*masker.maps_img_.shape[:3], length)

271 

272 

def test_multi_nifti_maps_masker_resampling_to_maps(
    shape_mask,
    affine_eye,
    length,
    n_regions,
    shape_3d_large,
    img_fmri,
):
    """Test resampling to maps in MultiNiftiMapsMasker.

    With resampling_target="maps", the fitted maps keep their original grid
    and the fitted mask ends up on that same grid.
    """
    _, mask_img = generate_fake_fmri(
        shape_mask, affine=affine_eye, length=length
    )
    maps_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    masker = MultiNiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="maps"
    )
    signals = masker.fit_transform([img_fmri, img_fmri])

    # The fitted maps are unchanged ...
    assert_almost_equal(masker.maps_img_.affine, maps_img.affine)
    assert masker.maps_img_.shape == maps_img.shape

    # ... and the fitted mask shares their grid.
    assert_almost_equal(masker.mask_img_.affine, masker.maps_img_.affine)
    assert masker.mask_img_.shape == masker.maps_img_.shape[:3]

    for subject_signals in signals:
        assert subject_signals.shape == (length, n_regions)

        restored = masker.inverse_transform(subject_signals)

        assert_almost_equal(restored.affine, masker.maps_img_.affine)
        assert restored.shape == (*masker.maps_img_.shape[:3], length)

306 

307 

def test_multi_nifti_maps_masker_resampling_clipped_mask(
    affine_eye, length, n_regions, img_fmri
):
    """Test with clipped maps: mask does not contain all maps."""
    # Shapes do matter in that case
    shape2 = (8, 9, 10)  # mask
    shape3 = (16, 18, 20)  # maps
    affine2 = np.diag((2, 2, 2, 1))  # just for mask

    _, mask22_img = generate_fake_fmri(shape2, length=1, affine=affine2)
    maps33_img, _ = generate_maps(shape3, n_regions, affine=affine_eye)

    masker = MultiNiftiMapsMasker(
        maps33_img, mask_img=mask22_img, resampling_target="maps"
    )

    signals = masker.fit_transform([img_fmri, img_fmri])

    # resampling_target="maps": the maps keep their grid and the mask is
    # brought onto it.
    assert_almost_equal(masker.maps_img_.affine, maps33_img.affine)
    assert masker.maps_img_.shape == maps33_img.shape

    assert_almost_equal(masker.mask_img_.affine, masker.maps_img_.affine)
    assert masker.mask_img_.shape == masker.maps_img_.shape[:3]

    for t in signals:
        assert t.shape == (length, n_regions)
        # Regions clipped by the mask are expected to give a zero (hence
        # zero-variance) signal. This assertion only verifies that NOT every
        # region was clipped: fewer than n_regions columns may have zero
        # variance.
        assert (t.var(axis=0) == 0).sum() < n_regions

        fmri11_img_r = masker.inverse_transform(t)

        assert_almost_equal(fmri11_img_r.affine, masker.maps_img_.affine)
        assert fmri11_img_r.shape == (masker.maps_img_.shape[:3] + (length,))