Coverage for nilearn/maskers/tests/test_nifti_masker.py: 0%

259 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-16 12:32 +0200

1"""Test the nifti_masker module. 

2 

3Functions in this file only test features added by the NiftiMasker class, 

4not the underlying functions used (e.g. clean()). See test_masking.py and 

5test_signal.py for this. 

6""" 

7 

8import shutil 

9from pathlib import Path 

10from tempfile import mkdtemp 

11 

12import numpy as np 

13import pytest 

14from nibabel import Nifti1Image 

15from numpy.testing import assert_array_equal 

16from sklearn.utils.estimator_checks import parametrize_with_checks 

17 

18from nilearn._utils import data_gen, exceptions, testing 

19from nilearn._utils.class_inspect import get_params 

20from nilearn._utils.estimator_checks import ( 

21 check_estimator, 

22 nilearn_check_estimator, 

23 return_expected_failed_checks, 

24) 

25from nilearn._utils.tags import SKLEARN_LT_1_6 

26from nilearn.image import get_data, index_img 

27from nilearn.maskers import NiftiMasker 

28from nilearn.maskers.nifti_masker import filter_and_mask 

29 

ESTIMATORS_TO_CHECK = [NiftiMasker()]

if not SKLEARN_LT_1_6:
    # Recent scikit-learn: use its own parametrization helper, which is told
    # which checks are expected to fail for this estimator.

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run scikit-learn's estimator checks on the masker."""
        check(estimator)

else:
    # Older scikit-learn: nilearn's helper yields the checks that should
    # pass (valid) and, separately, the ones that should fail (invalid).

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run the scikit-learn checks the masker is expected to pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the scikit-learn checks the masker is expected to fail."""
        check(estimator)

60 

61 

@pytest.mark.parametrize(
    ("estimator", "check", "name"),
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run every nilearn-specific estimator check against the masker."""
    # ``name`` is part of the parametrization tuple but unused in the body.
    check(estimator)

69 

70 

def test_detrend(img_3d_rand_eye, mask_img_1):
    """Detrending a single 3D image must not zero out the extracted signal."""
    detrending_masker = NiftiMasker(mask_img=mask_img_1, detrend=True)
    # Smoke test: at least one non-zero value must survive fit_transform.
    signals = detrending_masker.fit_transform(img_3d_rand_eye)
    assert np.any(signals != 0)

77 

78 

@pytest.mark.parametrize("y", [None, np.ones((9, 9, 9))])
def test_fit_transform(y, img_3d_rand_eye, mask_img_1):
    """Smoke test fit_transform with and without a mask, with and without y."""
    for candidate_mask in (mask_img_1, None):
        signals = NiftiMasker(mask_img=candidate_mask).fit_transform(
            X=img_3d_rand_eye, y=y
        )
        assert np.any(signals != 0)

87 

88 

def test_fit_transform_warning(img_3d_rand_eye, mask_img_1):
    """Check the mask-generation warning when a mask was already provided.

    The masker is built with ``mask_img`` and ``fit_transform`` is then called
    with a ``y`` argument; a UserWarning about mask generation is expected.
    """
    masker = NiftiMasker(mask_img=mask_img_1)
    target = np.ones((9, 9, 9))
    expected_message = (
        "Generation of a mask has been requested .*"
        "while a mask was given at masker creation."
    )

    with pytest.warns(UserWarning, match=expected_message):
        signals = masker.fit_transform(X=img_3d_rand_eye, y=target)

    assert np.any(signals != 0)

102 

103 

def test_resample(img_3d_rand_eye, mask_img_1):
    """Smoke test resampling requested through ``target_affine``."""
    doubled_affine = 2 * np.eye(3)
    masker = NiftiMasker(mask_img=mask_img_1, target_affine=doubled_affine)
    signals = masker.fit_transform(img_3d_rand_eye)
    assert np.any(signals != 0)

110 

111 

def test_resample_to_mask_warning(img_3d_rand_eye, affine_eye):
    """Check the warning raised when imgs are resampled to the mask grid."""
    # Mask whose field of view differs from the image's.
    mask_data = np.zeros((12, 12, 12))
    mask_data[3:-3, 3:-3, 3:-3] = 10
    mask_img = Nifti1Image(mask_data.astype("uint8"), affine_eye)
    masker = NiftiMasker(mask_img=mask_img)

    expected_message = (
        "imgs are being resampled to the mask_img resolution. "
        "This process is memory intensive. You might want to provide "
        "a target_affine that is equal to the affine of the imgs "
        "or resample the mask beforehand "
        "to save memory and computation time."
    )
    with pytest.warns(UserWarning, match=expected_message):
        masker.fit_transform(img_3d_rand_eye)

131 

132 

def test_nan(affine_eye):
    """Check that a NaN-filled border is left out of the computed mask."""

    def faces():
        # Index tuples selecting each of the six faces of a 3D volume.
        for axis in range(3):
            for side in (0, -1):
                index = [slice(None)] * 3
                index[axis] = side
                yield tuple(index)

    data = np.ones((9, 9, 9))
    for face in faces():
        data[face] = np.nan
    data[3:-3, 3:-3, 3:-3] = 10
    img = Nifti1Image(data, affine_eye)

    masker = NiftiMasker(mask_args={"opening": 0})
    masker.fit(img)
    mask = get_data(masker.mask_img_)

    # Everything inside the NaN shell is kept ...
    assert mask[1:-1, 1:-1, 1:-1].all()
    # ... and no voxel on a NaN face is.
    for face in faces():
        assert not mask[face].any()

154 

155 

def test_matrix_orientation():
    """Test if processing is performed along the correct axis."""
    # "step" fake data gives every voxel the same heaviside-like time course:
    # standardizing along the wrong axis would flatten it to a null signal,
    # while standardizing along time keeps the step.
    fmri, mask = data_gen.generate_fake_fmri(shape=(40, 41, 42), kind="step")
    n_timepoints = fmri.shape[3]
    n_voxels = get_data(mask).sum()

    cleaning_masker = NiftiMasker(mask_img=mask, standardize=True, detrend=True)
    signals = cleaning_masker.fit_transform(fmri)
    assert signals.shape[0] == n_timepoints
    assert signals.shape[1] == n_voxels

    per_voxel_std = signals.std(axis=0)
    assert per_voxel_std.shape[0] == signals.shape[1]  # paranoid
    assert not np.any(per_voxel_std < 0.1)

    # Without cleaning, transform followed by inverse_transform must give
    # back the input data.
    plain_masker = NiftiMasker(mask_img=mask, standardize=False, detrend=False)
    plain_masker.fit()
    recovered = plain_masker.inverse_transform(plain_masker.transform(fmri))
    np.testing.assert_array_almost_equal(get_data(recovered), get_data(fmri))

176 

177 

def test_mask_4d(shape_3d_default, affine_eye):
    """Check that ``sample_mask`` selects the expected volumes in transform.

    The selection is checked for a list of 3D images and for a 4D image.
    """
    # Dummy mask: a 4x4x4 cube of ones.
    mask = np.zeros(shape_3d_default, dtype="int32")
    mask[3:7, 3:7, 3:7] = 1
    mask_bool = mask.astype(bool)
    mask_img = Nifti1Image(mask, affine_eye)

    # Dummy data: volumes 0, 1 and 2 are filled with 1, 2 and 3;
    # volumes 3 and 4 stay at zero.
    shape_4d = (*shape_3d_default, 5)
    data = np.zeros(shape_4d, dtype="int32")
    for volume in range(3):
        data[..., volume] = volume + 1
    data_img_4d = Nifti1Image(data, affine_eye)
    # The first three volumes as separate 3D images.
    data_imgs = [index_img(data_img_4d, volume) for volume in range(3)]

    def expected_signals(sample_mask):
        """Extract the selected volumes directly, as (n_samples, n_voxels)."""
        selected = get_data(index_img(data_img_4d, sample_mask))
        return np.swapaxes(selected[mask_bool, :], 0, 1)

    # A single fitted masker is enough: every check below only transforms.
    masker = NiftiMasker(mask_img=mask_img)
    masker.fit()

    # check whether transform is indeed selecting niimgs subset,
    # for a list of 3D images and for the equivalent 4D image
    sample_mask = np.array([0, 2])
    expected = expected_signals(sample_mask)
    assert_array_equal(
        masker.transform(data_imgs, sample_mask=sample_mask), expected
    )
    assert_array_equal(
        masker.transform(data_img_4d, sample_mask=sample_mask), expected
    )

    # A selection that includes one of the all-zero volumes.
    diff_sample_mask = np.array([2, 4])
    assert_array_equal(
        masker.transform(data_img_4d, sample_mask=diff_sample_mask),
        expected_signals(diff_sample_mask),
    )

225 

226 

def test_4d_single_scan(rng, shape_3d_default, affine_eye):
    """Test that list of 4D images with last dim=1 is treated as 3D."""
    # Build the mask on the same grid as the data, instead of a hard-coded
    # shape that only matches ``shape_3d_default`` by coincidence.
    mask = np.zeros(shape_3d_default)
    mask[3:7, 3:7, 3:7] = 1
    mask_img = Nifti1Image(mask, affine_eye)

    # Five single-volume 4D arrays, and the same data with the singleton
    # 4th dimension dropped (i.e. 3D).
    shape_4d = (*shape_3d_default, 1)
    arrays_4d = [rng.random(shape_4d) for _ in range(5)]
    imgs_4d = [Nifti1Image(arr, affine_eye) for arr in arrays_4d]
    imgs_3d = [Nifti1Image(arr[..., 0], affine_eye) for arr in arrays_4d]

    masker = NiftiMasker(mask_img=mask_img)
    masker.fit()

    # Check attributes defined at fit
    assert masker.n_elements_ == np.sum(mask)

    signals_from_4d = masker.transform(imgs_4d)
    signals_from_3d = masker.transform(imgs_3d)

    assert_array_equal(signals_from_3d, signals_from_4d)

251 

252 

def test_sessions(affine_eye):
    """Check that a runs vector of the wrong length raises an error."""
    data = np.ones((40, 40, 40, 4))
    # Create a zero border on all six spatial faces, so that the masking
    # works well. The third spatial axis must be indexed explicitly:
    # ``data[..., i]`` would select along the time axis instead.
    data[0] = 0
    data[-1] = 0
    data[:, 0] = 0
    data[:, -1] = 0
    data[:, :, 0] = 0
    data[:, :, -1] = 0
    data_img = Nifti1Image(data, affine_eye)

    # 3 run labels for 4 time points.
    masker = NiftiMasker(runs=np.ones(3, dtype=int))
    # ``match`` ensures the error comes from the runs check and not from an
    # unrelated ValueError (e.g. an empty mask during fit).
    with pytest.raises(ValueError, match="run"):
        masker.fit_transform(data_img)

268 

269 

def test_joblib_cache(tmp_path, mask_img_1):
    """Test using joblib cache."""
    from joblib import Memory
    from joblib import hash as joblib_hash  # do not shadow builtin ``hash``

    filename = testing.write_imgs_to_path(
        mask_img_1,
        file_path=tmp_path,
        create_files=True,
    )
    masker = NiftiMasker(mask_img=filename)
    masker.fit()
    # Reading the data of the fitted mask must not change its joblib hash.
    mask_hash = joblib_hash(masker.mask_img_)
    get_data(masker.mask_img_)
    assert mask_hash == joblib_hash(masker.mask_img_)

    # Test a tricky issue with memmapped joblib.memory that makes
    # imgs returned by inverse_transform impossible to save
    cachedir = Path(mkdtemp())
    try:
        masker.memory = Memory(location=cachedir, mmap_mode="r", verbose=0)
        X = masker.transform(mask_img_1)
        # inverse_transform a first time, so that the result is cached;
        # the second call is then presumably served from the memmapped cache.
        masker.inverse_transform(X)
        out_img = masker.inverse_transform(X)
        out_img.to_filename(cachedir / "test.nii")
    finally:
        # enables to delete "filename" on windows
        del masker
        shutil.rmtree(cachedir, ignore_errors=True)

299 

300 

def test_mask_strategy_errors_warnings(img_fmri):
    """Check errors and warnings raised for some ``mask_strategy`` values."""
    # Error with unknown mask_strategy
    masker = NiftiMasker(mask_strategy="oops", mask_args={"threshold": 0.0})
    with pytest.raises(
        ValueError, match="Unknown value of mask_strategy 'oops'"
    ):
        masker.fit(img_fmri)

    # Warning with deprecated 'template' strategy. Fitting must otherwise
    # succeed: any exception raised inside ``pytest.warns`` fails the test.
    masker = NiftiMasker(
        mask_strategy="template", mask_args={"threshold": 0.0}
    )
    with pytest.warns(
        UserWarning, match="Masking strategy 'template' is deprecated."
    ):
        masker.fit(img_fmri)

320 

321 

def test_compute_epi_mask(affine_eye):
    """Test that the masker class is passing parameters appropriately."""
    # Adapted from test_masking.py: here the point is that NiftiMasker
    # forwards ``mask_args`` to the epi mask computation.

    def epi_mask(img, **extra_args):
        # Fit an epi-strategy masker (opening disabled) and return its mask.
        masker = NiftiMasker(
            mask_strategy="epi",
            mask_args={"opening": False, **extra_args},
        )
        masker.fit(img)
        return get_data(masker.mask_img_)

    small = np.ones((9, 9, 3))
    small[3:-2, 3:-2, :] = 10
    small[5, 5, :] = 11
    small_img = Nifti1Image(small.astype(float), affine_eye)

    reference = epi_mask(small_img)
    no_zeros = epi_mask(small_img, exclude_zeros=True)
    # With an array with no zeros, exclude_zeros should not make
    # any difference
    np.testing.assert_array_equal(reference, no_zeros)

    # Check that padding with zeros does not change the extracted mask
    padded = np.zeros((30, 30, 3))
    padded[3:12, 3:12, :] = get_data(small_img)
    padded_img = Nifti1Image(padded, affine_eye)

    padded_excluding = epi_mask(padded_img, exclude_zeros=True)
    np.testing.assert_array_equal(reference, padded_excluding[3:12, 3:12])

    # However, without exclude_zeros, it does
    padded_including = epi_mask(padded_img)
    assert not np.allclose(reference, padded_including[3:12, 3:12])

365 

366 

@pytest.fixture
def expected_mask(mask_args):
    """Return the mask expected for the given ``mask_args``.

    Empty ``mask_args`` give an all-zero (9, 9, 5) array; otherwise a 5x5
    square on slice 2 is set to one.
    """
    mask = np.zeros((9, 9, 5))
    if mask_args != {}:
        mask[2:7, 2:7, 2] = 1
    return mask

376 

377 

@pytest.mark.parametrize(
    "strategy", ["whole-brain-template", "gm-template", "wm-template"]
)
@pytest.mark.parametrize("mask_args", [{}])
def test_compute_brain_mask_empty_mask_error(strategy, mask_args):
    """Check that fitting raises when the template-based mask is empty."""
    img, _ = data_gen.generate_random_img((9, 9, 5))
    masker = NiftiMasker(mask_strategy=strategy, mask_args=mask_args)

    with pytest.raises(ValueError, match="masks all data"):
        masker.fit(img)

390 

391 

@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "strategy", ["whole-brain-template", "gm-template", "wm-template"]
)
@pytest.mark.parametrize("mask_args", [{"threshold": 0.0}])
def test_compute_brain_mask(strategy, expected_mask, mask_args):
    """Check the mask produced by each template masking strategy."""
    img, _ = data_gen.generate_random_img((9, 9, 5))
    masker = NiftiMasker(mask_strategy=strategy, mask_args=mask_args)
    masker.fit(img)

    computed = get_data(masker.mask_img_)
    np.testing.assert_array_equal(computed, expected_mask)

405 

406 

def test_filter_and_mask_error(affine_eye):
    """Check filter_and_mask fails if mask is 4D."""
    data = np.zeros([20, 30, 40, 5])
    # A 4D mask (two volumes), which filter_and_mask must reject.
    mask = np.zeros([20, 30, 40, 2])
    mask[10, 15, 20, :] = 1

    data_img = Nifti1Image(data, affine_eye)
    mask_img = Nifti1Image(mask, affine_eye)

    masker = NiftiMasker()
    params = get_params(NiftiMasker, masker)

    with pytest.raises(
        exceptions.DimensionError,
        match="Input data has incompatible dimensionality: "
        "Expected dimension is 3D and you provided "
        "a 4D image.",
    ):
        filter_and_mask(data_img, mask_img, params)

426 

427 

def test_filter_and_mask(affine_eye):
    """Test filter_and_mask returns output with correct shape."""
    data_shape = (20, 30, 40, 5)
    mask_shape = (20, 30, 40)
    data_img = Nifti1Image(np.zeros(data_shape), affine_eye)
    mask_img = Nifti1Image(np.ones(mask_shape), affine_eye)

    masker = NiftiMasker()
    params = get_params(NiftiMasker, masker)
    # NOTE(review): filter_and_mask presumably reads a ``clean_kwargs`` entry
    # that get_params does not provide — confirm.
    params["clean_kwargs"] = {}

    signals = filter_and_mask(data_img, mask_img, params)

    # One row per volume, one column per voxel of the all-ones mask.
    assert signals.shape == (data_shape[3], np.prod(mask_shape))

445 

446 

def test_standardization(rng, shape_3d_default, affine_eye):
    """Check output properly standardized with 'standardize' parameter."""
    n_samples = 500
    n_voxels = np.prod(shape_3d_default)

    # Unit-variance noise around voxel-specific means close to 1000.
    signals = rng.standard_normal(size=(n_voxels, n_samples))
    signals += rng.standard_normal(size=(n_voxels, 1)) * 50 + 1000
    img = Nifti1Image(
        signals.reshape((*shape_3d_default, n_samples)),
        affine_eye,
    )
    mask = Nifti1Image(np.ones(shape_3d_default), affine_eye)

    # z-score
    zscored = NiftiMasker(mask, standardize="zscore_sample").fit_transform(img)
    np.testing.assert_almost_equal(zscored.mean(0), 0)
    np.testing.assert_almost_equal(zscored.std(0), 1, decimal=3)

    # psc: percent signal change relative to each voxel's mean
    psc = NiftiMasker(mask, standardize="psc").fit_transform(img)
    expected_psc = (signals / signals.mean(1)[:, np.newaxis] * 100 - 100).T
    np.testing.assert_almost_equal(psc.mean(0), 0)
    np.testing.assert_almost_equal(psc, expected_psc)