Coverage for nilearn/maskers/tests/test_nifti_maps_masker.py: 0%

184 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Test nilearn.maskers.nifti_maps_masker. 

2 

Functions in this file only test features added by the NiftiMapsMasker class,
rather than the underlying functions (clean(), img_to_signals_maps(), etc.).

5 

6See test_masking.py and test_signal.py for details. 

7""" 

8 

9import numpy as np 

10import pytest 

11from nibabel import Nifti1Image 

12from numpy.testing import assert_almost_equal, assert_array_equal 

13from sklearn.utils.estimator_checks import parametrize_with_checks 

14 

15from nilearn._utils.data_gen import ( 

16 generate_fake_fmri, 

17 generate_maps, 

18 generate_random_img, 

19) 

20from nilearn._utils.estimator_checks import ( 

21 check_estimator, 

22 nilearn_check_estimator, 

23 return_expected_failed_checks, 

24) 

25from nilearn._utils.tags import SKLEARN_LT_1_6 

26from nilearn._utils.testing import write_imgs_to_path 

27from nilearn.conftest import _img_maps, _shape_3d_default 

28from nilearn.image import get_data 

29from nilearn.maskers import NiftiMapsMasker 

30 

# Estimator instances run through the generic scikit-learn compliance checks.
ESTIMATORS_TO_CHECK = [NiftiMapsMasker()]

if SKLEARN_LT_1_6:
    # scikit-learn older than 1.6 (per the SKLEARN_LT_1_6 flag): nilearn's
    # own ``check_estimator`` helper yields (estimator, check, name)
    # triplets. Checks listed as invalid (``valid=False``) are expected to
    # fail and are collected under an xfail marker.

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators (checks expected to pass)."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators (checks expected to fail)."""
        check(estimator)

else:
    # Newer scikit-learn: rely on sklearn's ``parametrize_with_checks`` and
    # hand it nilearn's ``return_expected_failed_checks`` so known failures
    # are reported as expected.

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Check compliance with sklearn estimators."""
        check(estimator)

61 

62 

# A timeout of 0 disables the pytest-timeout limit: the nilearn checks
# can be slow.
@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        # Pass fewer than the default number of regions
        # to speed up the tests.
        estimators=[NiftiMapsMasker(maps_img=_img_maps(n_regions=2))]
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn estimators.

    Each ``check`` is one of the nilearn-specific estimator checks
    generated by ``nilearn_check_estimator`` for the masker above.
    """
    check(estimator)

76 

77 

def test_nifti_maps_masker_data_atlas_different_shape(
    length, affine_eye, img_maps
):
    """Fit on data whose shape and affine differ from the atlas.

    The atlas (maps) should be resampled into the space of the data image
    given to ``fit``, so the fitted maps carry the data affine.
    """
    data_affine = 2 * affine_eye
    data_affine[-1, -1] = 1

    # Mask: (12, 10, 14) voxels on an identity affine.
    mask_img = generate_fake_fmri(
        (12, 10, 14), affine=affine_eye, length=length
    )[1]
    # Data: (5, 5, 6) voxels on a 2x-scaled affine.
    data_img = generate_fake_fmri(
        (5, 5, 6), affine=data_affine, length=length
    )[0]

    masker = NiftiMapsMasker(img_maps, mask_img=mask_img)
    masker.fit(data_img)

    assert_array_equal(masker.maps_img_.affine, data_affine)

102 

103 

def test_nifti_maps_masker_fit(n_regions, img_maps):
    """Check that fitting sets ``n_elements_`` to the number of regions."""
    fitted_masker = NiftiMapsMasker(img_maps, resampling_target=None)
    fitted_masker.fit()

    # One extracted element per region of the maps image.
    expected_n_elements = n_regions
    assert fitted_masker.n_elements_ == expected_n_elements

112 

113 

def test_nifti_maps_masker_errors():
    """Fitting a masker built without maps must raise a TypeError."""
    no_maps_masker = NiftiMapsMasker()
    expected = "input should be a NiftiLike object"
    with pytest.raises(TypeError, match=expected):
        no_maps_masker.fit()

119 

120 

@pytest.mark.parametrize("create_files", (True, False))
def test_nifti_maps_masker_errors_field_of_view(
    tmp_path, length, affine_eye, shape_3d_default, create_files, img_maps
):
    """Check field of view errors.

    With ``resampling_target=None``, a mismatch in shape or affine between
    the maps, the mask and the data must raise a ``ValueError``. Maps and
    mask are exercised both as in-memory images and as files written to
    ``tmp_path`` (see ``create_files``).
    """
    shape2 = (12, 10, 14)
    affine2 = np.diag((1, 2, 3, 1))

    # Default shape, non-identity affine.
    fmri12_img, mask12_img = generate_fake_fmri(
        shape_3d_default, affine=affine2, length=length
    )
    # Different shape, identity affine.
    fmri21_img, mask21_img = generate_fake_fmri(
        shape2, affine=affine_eye, length=length
    )

    error_msg = "Following field of view errors were detected"

    # Maps vs. mask mismatch, detected at fit time.
    masker = NiftiMapsMasker(
        img_maps, mask_img=mask21_img, resampling_target=None
    )
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()

    # Test all kinds of mismatches between shapes and between affines
    maps11, mask12 = write_imgs_to_path(
        img_maps,
        mask12_img,
        file_path=tmp_path,
        create_files=create_files,
    )

    # Maps vs. data mismatch, detected when fitting on the data.
    masker = NiftiMapsMasker(maps11, resampling_target=None)

    with pytest.raises(ValueError, match=error_msg):
        masker.fit_transform(fmri12_img)

    with pytest.raises(ValueError, match=error_msg):
        masker.fit_transform(fmri21_img)

    # Maps vs. mask mismatch with both given as (possibly file) inputs.
    masker = NiftiMapsMasker(maps11, mask_img=mask12, resampling_target=None)
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()

164 

165 

def test_nifti_maps_masker_resampling_errors(
    n_regions, affine_eye, shape_3d_large
):
    """Check errors raised at fit time by invalid resampling settings."""
    maps_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    # (resampling_target, expected error message pattern)
    cases = (
        (
            "mask",
            "resampling_target has been set to 'mask' "
            "but no mask has been provided.",
        ),
        (
            "invalid",
            "invalid value for 'resampling_target' parameter: invalid",
        ),
    )
    for target, pattern in cases:
        masker = NiftiMapsMasker(maps_img, resampling_target=target)
        with pytest.raises(ValueError, match=pattern):
            masker.fit()

189 

190 

def test_nifti_maps_masker_with_nans_and_infs(length, n_regions, affine_eye):
    """Apply a NiftiMapsMasker containing NaNs and infs.

    The masker should replace those NaNs and infs with zeros,
    without raising a warning.
    """
    import warnings

    fmri_img, mask_img = generate_random_img(
        (13, 11, 12, length), affine=affine_eye
    )
    maps_img, _ = generate_maps((13, 11, 12), n_regions, affine=affine_eye)

    # Add NaNs and infs to the maps, after restricting them to the mask.
    maps_data = get_data(maps_img).astype(np.float32)
    mask_data = get_data(mask_img).astype(np.float32)
    maps_data = maps_data * mask_data[..., None]

    # Pick two in-mask voxels of the first map: the first keeps a finite
    # value, the second is set to inf, and the rest of that map becomes NaN.
    vox_idx = np.where(maps_data[..., 0] > 0)
    i1, j1, k1 = vox_idx[0][0], vox_idx[1][0], vox_idx[2][0]
    i2, j2, k2 = vox_idx[0][1], vox_idx[1][1], vox_idx[2][1]

    maps_data[:, :, :, 0] = np.nan
    maps_data[i2, j2, k2, 0] = np.inf
    maps_data[i1, j1, k1, 0] = 1

    maps_img = Nifti1Image(maps_data, affine_eye)

    # No warning, because maps_img is run through clean_img
    # *before* safe_get_data.
    masker = NiftiMapsMasker(maps_img, mask_img=mask_img)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        signals = masker.fit_transform(fmri_img)

    # Enforce the "no warning" promise of the docstring for non-finite
    # values; unrelated warnings are not this test's concern.
    non_finite_warnings = [
        w for w in caught if "Non-finite values detected" in str(w.message)
    ]
    assert not non_finite_warnings

    assert signals.shape == (length, n_regions)
    assert np.all(np.isfinite(signals))

226 

227 

def test_nifti_maps_masker_with_nans_and_infs_in_data(
    length, n_regions, affine_eye
):
    """Extract signals from 4D data that contains NaNs and infs.

    Non-finite data values must be replaced with zeros, and the masker
    must warn about them.
    """
    spatial_shape = (13, 11, 12)
    fmri_img, mask_img = generate_random_img(
        (*spatial_shape, length), affine=affine_eye
    )
    maps_img, _ = generate_maps(spatial_shape, n_regions, affine=affine_eye)

    # Corrupt two voxel lines across every time point: NaN first, then inf.
    corrupted = get_data(fmri_img)
    for bad_value, idx in ((np.nan, 9), (np.inf, 5)):
        corrupted[:, idx, idx, :] = bad_value
    fmri_img = Nifti1Image(corrupted, affine_eye)

    masker = NiftiMapsMasker(maps_img, mask_img=mask_img)

    with pytest.warns(UserWarning, match="Non-finite values detected."):
        signals = masker.fit_transform(fmri_img)

    assert signals.shape == (length, n_regions)
    assert np.isfinite(signals).all()

256 

257 

def test_nifti_maps_masker_resampling_to_mask(
    length,
    n_regions,
    affine_eye,
    shape_mask,
    shape_3d_large,
    img_fmri,
):
    """Test resampling to the mask in NiftiMapsMasker.

    With ``resampling_target="mask"``, the fitted mask keeps its own space,
    and the maps and reconstructed images are moved into it.
    """
    mask_img = generate_fake_fmri(
        shape_mask, length=length, affine=affine_eye
    )[1]
    maps_img = generate_maps(shape_3d_large, n_regions, affine=affine_eye)[0]

    masker = NiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="mask"
    )
    signals = masker.fit_transform(img_fmri)

    fitted_mask = masker.mask_img_
    fitted_maps = masker.maps_img_

    # The mask stays where it was.
    assert_almost_equal(fitted_mask.affine, mask_img.affine)
    assert fitted_mask.shape == mask_img.shape

    # The maps now live in the mask space.
    assert_almost_equal(fitted_maps.affine, fitted_mask.affine)
    assert fitted_maps.shape[:3] == fitted_mask.shape

    assert signals.shape == (length, n_regions)

    reconstructed = masker.inverse_transform(signals)

    assert_almost_equal(reconstructed.affine, masker.mask_img_.affine)
    assert reconstructed.shape == (*masker.mask_img_.shape[:3], length)

291 

292 

def test_nifti_maps_masker_resampling_to_maps(
    length,
    n_regions,
    affine_eye,
    shape_mask,
    shape_3d_large,
    img_fmri,
):
    """Test resampling to the maps in NiftiMapsMasker.

    With ``resampling_target="maps"``, the fitted maps keep their own
    space, and the mask and reconstructed images are moved into it.
    """
    mask_img = generate_fake_fmri(
        shape_mask, length=length, affine=affine_eye
    )[1]
    maps_img = generate_maps(shape_3d_large, n_regions, affine=affine_eye)[0]

    masker = NiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="maps"
    )
    signals = masker.fit_transform(img_fmri)

    fitted_maps, fitted_mask = masker.maps_img_, masker.mask_img_

    # The maps are untouched.
    assert_array_equal(fitted_maps.affine, maps_img.affine)
    assert fitted_maps.shape == maps_img.shape

    # The mask now lives in the maps space.
    assert_array_equal(fitted_mask.affine, fitted_maps.affine)
    assert fitted_mask.shape == fitted_maps.shape[:3]

    assert signals.shape == (length, n_regions)

    reconstructed = masker.inverse_transform(signals)

    assert_array_equal(reconstructed.affine, masker.maps_img_.affine)
    assert reconstructed.shape == (*masker.maps_img_.shape[:3], length)

325 

326 

def test_nifti_maps_masker_clipped_mask(n_regions, affine_eye):
    """Test with clipped maps: mask does not contain all maps."""
    # Shapes do matter in that case
    length = 21
    shape1 = (10, 11, 12, length)
    shape2 = (8, 9, 10)  # mask
    shape3 = (16, 18, 20)  # maps
    affine2 = np.diag((2, 2, 2, 1))  # just for mask

    fmri11_img, _ = generate_random_img(shape1, affine=affine_eye)
    _, mask22_img = generate_fake_fmri(shape2, length=1, affine=affine2)
    maps33_img, _ = generate_maps(shape3, n_regions, affine=affine_eye)

    # Target: maps
    masker = NiftiMapsMasker(
        maps33_img, mask_img=mask22_img, resampling_target="maps"
    )

    signals = masker.fit_transform(fmri11_img)

    assert_almost_equal(masker.maps_img_.affine, maps33_img.affine)
    assert masker.maps_img_.shape == maps33_img.shape

    assert_almost_equal(masker.mask_img_.affine, masker.maps_img_.affine)
    assert masker.mask_img_.shape == masker.maps_img_.shape[:3]

    assert signals.shape == (length, n_regions)
    # Regions clipped by the mask may yield a constant (zero-variance)
    # signal, but the mask must not clip every region: at least one
    # region keeps a varying signal.
    assert (signals.var(axis=0) == 0).sum() < n_regions

    fmri11_img_r = masker.inverse_transform(signals)

    assert_almost_equal(fmri11_img_r.affine, masker.maps_img_.affine)
    assert fmri11_img_r.shape == (masker.maps_img_.shape[:3] + (length,))

361 

362 

def non_overlapping_maps():
    """Build a two-region maps image in which no voxel is shared.

    Region 0 covers indices ``:2`` along the first axis and region 1
    covers ``2:``, so every voxel belongs to exactly one region.
    """
    region_data = np.zeros((*_shape_3d_default(), 2))
    region_data[:2, ..., 0] = 1.0
    region_data[2:, ..., 1] = 1.0
    identity_affine = np.eye(4)
    return Nifti1Image(region_data, identity_affine)

375 

376 

def overlapping_maps():
    """Build a two-region maps image whose regions share voxels.

    Region 0 covers indices ``:3`` along the first axis and region 1
    covers ``2:``, so voxels at index 2 are non-zero in both regions.
    """
    region_data = np.zeros((*_shape_3d_default(), 2))
    region_data[:3, ..., 0] = 1.0
    region_data[2:, ..., 1] = 1.0
    identity_affine = np.eye(4)
    return Nifti1Image(region_data, identity_affine)

386 

387 

@pytest.mark.parametrize(
    "maps_img_fn", [overlapping_maps, non_overlapping_maps]
)
@pytest.mark.parametrize("allow_overlap", [True, False])
def test_nifti_maps_masker_overlap(maps_img_fn, allow_overlap, img_fmri):
    """Test the ``allow_overlap`` parameter of NiftiMapsMasker.

    Overlapping maps must raise a ``ValueError`` when ``allow_overlap`` is
    False. Non-overlapping maps, or any maps with ``allow_overlap=True``,
    must be accepted.
    """
    masker = NiftiMapsMasker(maps_img_fn(), allow_overlap=allow_overlap)

    # parametrize passes the factory function itself, so compare by identity
    # rather than by its ``__name__`` string.
    if maps_img_fn is overlapping_maps and not allow_overlap:
        with pytest.raises(ValueError, match="Overlap detected"):
            masker.fit_transform(img_fmri)
    else:
        masker.fit_transform(img_fmri)

401 

402 

def test_standardization(rng, affine_eye, shape_3d_default):
    """Check output properly standardized with 'standardize' parameter."""
    length = 500
    n_voxels = np.prod(shape_3d_default)

    # Random signals with large, voxel-specific offsets.
    signals = rng.standard_normal(size=(n_voxels, length))
    means = rng.standard_normal(size=(n_voxels, 1)) * 50 + 1000
    signals += means
    img = Nifti1Image(signals.reshape((*shape_3d_default, length)), affine_eye)

    maps, _ = generate_maps((9, 9, 5), 10)

    # Unstandardized
    masker = NiftiMapsMasker(maps, standardize=False)
    unstandardized_signals = masker.fit_transform(img)

    # z-score: "zscore_sample" scales by the sample standard deviation,
    # so check the std with ddof=1 rather than the population std.
    masker = NiftiMapsMasker(maps, standardize="zscore_sample")
    trans_signals = masker.fit_transform(img)

    assert_almost_equal(trans_signals.mean(0), 0)
    assert_almost_equal(trans_signals.std(0, ddof=1), 1, decimal=3)

    # psc: percent signal change relative to each region's mean
    masker = NiftiMapsMasker(maps, standardize="psc")
    trans_signals = masker.fit_transform(img)

    assert_almost_equal(trans_signals.mean(0), 0)
    assert_almost_equal(
        trans_signals,
        (
            unstandardized_signals / unstandardized_signals.mean(0) * 100
            - 100
        ),
    )