Coverage for nilearn/maskers/tests/test_nifti_maps_masker.py: 0%
184 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""Test nilearn.maskers.nifti_maps_masker.
3Functions in this file only test features added by the NiftiMapsMasker class,
4rather than the underlying functions (clean(), img_to_signals_labels(), etc.).
6See test_masking.py and test_signal.py for details.
7"""
import warnings

import numpy as np
import pytest
from nibabel import Nifti1Image
from numpy.testing import assert_almost_equal, assert_array_equal
from sklearn.utils.estimator_checks import parametrize_with_checks

from nilearn._utils.data_gen import (
    generate_fake_fmri,
    generate_maps,
    generate_random_img,
)
from nilearn._utils.estimator_checks import (
    check_estimator,
    nilearn_check_estimator,
    return_expected_failed_checks,
)
from nilearn._utils.tags import SKLEARN_LT_1_6
from nilearn._utils.testing import write_imgs_to_path
from nilearn.conftest import _img_maps, _shape_3d_default
from nilearn.image import get_data
from nilearn.maskers import NiftiMapsMasker
ESTIMATORS_TO_CHECK = [NiftiMapsMasker()]

if not SKLEARN_LT_1_6:
    # Newer scikit-learn: let its own parametrizer drive the checks and
    # declare the known failures up front.

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run one scikit-learn API compliance check on the masker."""
        check(estimator)

else:
    # Older scikit-learn: nilearn's check_estimator splits the checks into
    # those expected to pass and those expected to fail.

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn compliance check that should pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn compliance check that is known to fail."""
        check(estimator)
# timeout(0) disables the pytest-timeout limit for these (slow) checks.
@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        # Use fewer regions than the default to speed up the tests.
        estimators=[NiftiMapsMasker(maps_img=_img_maps(n_regions=2))]
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn's own estimator checks."""
    check(estimator)
def test_nifti_maps_masker_data_atlas_different_shape(
    length, affine_eye, img_maps
):
    """Check that the atlas is resampled onto data of another geometry.

    The mask lives on a (12, 10, 14) grid with an identity affine, while the
    data uses a (5, 5, 6) grid with voxel size 2.  After fitting, the maps
    must carry the data's affine.
    """
    data_affine = 2 * affine_eye
    data_affine[-1, -1] = 1

    _, mask_img = generate_fake_fmri(
        (12, 10, 14), affine=affine_eye, length=length
    )
    fmri_img, _ = generate_fake_fmri(
        (5, 5, 6), affine=data_affine, length=length
    )

    masker = NiftiMapsMasker(img_maps, mask_img=mask_img)
    masker.fit(fmri_img)

    assert_array_equal(masker.maps_img_.affine, data_affine)
def test_nifti_maps_masker_fit(n_regions, img_maps):
    """Check the attributes that fit() sets when no resampling is done."""
    masker = NiftiMapsMasker(img_maps, resampling_target=None)
    masker.fit()

    # One element per map in the fitted atlas.
    expected_elements = n_regions
    assert masker.n_elements_ == expected_elements
def test_nifti_maps_masker_errors():
    """Fitting a masker built without any arguments raises a TypeError."""
    expected = "input should be a NiftiLike object"
    with pytest.raises(TypeError, match=expected):
        NiftiMapsMasker().fit()
@pytest.mark.parametrize("create_files", (True, False))
def test_nifti_maps_masker_errors_field_of_view(
    tmp_path, length, affine_eye, shape_3d_default, create_files, img_maps
):
    """Check field of view errors.

    With ``resampling_target=None`` nothing is resampled, so any shape or
    affine mismatch between maps, mask and data must raise a ValueError.
    The maps and mask are passed either as images or as paths written under
    ``tmp_path``, depending on ``create_files``.
    """
    error_msg = "Following field of view errors were detected"

    shape2 = (12, 10, 14)
    affine2 = np.diag((1, 2, 3, 1))

    # fmri12/mask12: default shape with a non-identity affine.
    fmri12_img, mask12_img = generate_fake_fmri(
        shape_3d_default, affine=affine2, length=length
    )
    # fmri21/mask21: a different shape with an identity affine.
    fmri21_img, mask21_img = generate_fake_fmri(
        shape2, affine=affine_eye, length=length
    )

    # NOTE(review): the cases below assume ``img_maps`` uses the default
    # 3D shape with an identity affine (fixture defined in conftest).

    # Mask whose grid differs from the maps.
    masker = NiftiMapsMasker(
        img_maps, mask_img=mask21_img, resampling_target=None
    )
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()

    # Test all kinds of mismatches between shapes and between affines.
    maps11, mask12 = write_imgs_to_path(
        img_maps,
        mask12_img,
        file_path=tmp_path,
        create_files=create_files,
    )

    masker = NiftiMapsMasker(maps11, resampling_target=None)

    # Data with affine2.
    with pytest.raises(ValueError, match=error_msg):
        masker.fit_transform(fmri12_img)

    # Data with shape2.
    with pytest.raises(ValueError, match=error_msg):
        masker.fit_transform(fmri21_img)

    # Mask with affine2.
    masker = NiftiMapsMasker(maps11, mask_img=mask12, resampling_target=None)
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()
def test_nifti_maps_masker_resampling_errors(
    n_regions, affine_eye, shape_3d_large
):
    """Unusable ``resampling_target`` settings raise a ValueError at fit."""
    maps33_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    # resampling_target value -> expected error message pattern.
    cases = {
        "mask": (
            "resampling_target has been set to 'mask' "
            "but no mask has been provided."
        ),
        "invalid": "invalid value for 'resampling_target' parameter: invalid",
    }
    for target, message in cases.items():
        masker = NiftiMapsMasker(maps33_img, resampling_target=target)
        with pytest.raises(ValueError, match=message):
            masker.fit()
def test_nifti_maps_masker_with_nans_and_infs(length, n_regions, affine_eye):
    """Apply a NiftiMapsMasker containing NaNs and infs.

    The masker should replace those NaNs and infs with zeros,
    without raising a warning.
    """
    fmri_img, mask_img = generate_random_img(
        (13, 11, 12, length), affine=affine_eye
    )
    maps_img, _ = generate_maps((13, 11, 12), n_regions, affine=affine_eye)

    # Restrict the maps to the mask; cast to float so NaN/inf can be stored.
    maps_data = get_data(maps_img).astype(np.float32)
    mask_data = get_data(mask_img).astype(np.float32)
    maps_data = maps_data * mask_data[..., None]

    # Choose two in-mask voxels with positive weight in the first map.
    vox_idx = np.where(maps_data[..., 0] > 0)
    i1, j1, k1 = vox_idx[0][0], vox_idx[1][0], vox_idx[2][0]
    i2, j2, k2 = vox_idx[0][1], vox_idx[1][1], vox_idx[2][1]

    # Fill the first map with NaNs, except for one inf and one finite voxel.
    maps_data[:, :, :, 0] = np.nan
    maps_data[i2, j2, k2, 0] = np.inf
    maps_data[i1, j1, k1, 0] = 1

    maps_img = Nifti1Image(maps_data, affine_eye)

    # No warning, because maps_img is run through clean_img
    # *before* safe_get_data.
    masker = NiftiMapsMasker(maps_img, mask_img=mask_img)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        signals = masker.fit_transform(fmri_img)

    # Only the non-finite warning is checked, so that unrelated warnings
    # (e.g. deprecations from dependencies) cannot make this test fail.
    non_finite_warnings = [
        w for w in caught if "Non-finite values detected" in str(w.message)
    ]
    assert not non_finite_warnings

    assert signals.shape == (length, n_regions)
    assert np.all(np.isfinite(signals))
def test_nifti_maps_masker_with_nans_and_infs_in_data(
    length, n_regions, affine_eye
):
    """Apply a NiftiMapsMasker to 4D data containing NaNs and infs.

    Non-finite input voxels must be zeroed out, and the masker must warn
    the user about them.
    """
    spatial_shape = (13, 11, 12)
    fmri_img, mask_img = generate_random_img(
        (*spatial_shape, length), affine=affine_eye
    )
    maps_img, _ = generate_maps(spatial_shape, n_regions, affine=affine_eye)

    # Corrupt two voxel columns of the functional data.
    corrupted = get_data(fmri_img)
    corrupted[:, 9, 9, :] = np.nan
    corrupted[:, 5, 5, :] = np.inf
    corrupted_img = Nifti1Image(corrupted, affine_eye)

    masker = NiftiMapsMasker(maps_img, mask_img=mask_img)
    with pytest.warns(UserWarning, match="Non-finite values detected."):
        signals = masker.fit_transform(corrupted_img)

    assert signals.shape == (length, n_regions)
    assert np.all(np.isfinite(signals))
def test_nifti_maps_masker_resampling_to_mask(
    length,
    n_regions,
    affine_eye,
    shape_mask,
    shape_3d_large,
    img_fmri,
):
    """With resampling_target='mask', maps and outputs follow the mask."""
    _, mask_img = generate_fake_fmri(
        shape_mask, length=length, affine=affine_eye
    )
    maps_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    masker = NiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="mask"
    )
    signals = masker.fit_transform(img_fmri)

    # The fitted mask keeps the geometry of the provided mask ...
    assert_almost_equal(masker.mask_img_.affine, mask_img.affine)
    assert masker.mask_img_.shape == mask_img.shape

    # ... and the maps are brought onto that grid.
    assert_almost_equal(masker.maps_img_.affine, masker.mask_img_.affine)
    assert masker.maps_img_.shape[:3] == masker.mask_img_.shape

    assert signals.shape == (length, n_regions)

    # inverse_transform returns an image on the mask grid, one volume per
    # time point.
    reconstructed = masker.inverse_transform(signals)
    assert_almost_equal(reconstructed.affine, masker.mask_img_.affine)
    assert reconstructed.shape == (*masker.mask_img_.shape[:3], length)
def test_nifti_maps_masker_resampling_to_maps(
    length,
    n_regions,
    affine_eye,
    shape_mask,
    shape_3d_large,
    img_fmri,
):
    """With resampling_target='maps', mask and outputs follow the maps."""
    _, mask_img = generate_fake_fmri(
        shape_mask, length=length, affine=affine_eye
    )
    maps_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    masker = NiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="maps"
    )
    signals = masker.fit_transform(img_fmri)

    # The fitted maps keep the geometry of the provided maps ...
    assert_array_equal(masker.maps_img_.affine, maps_img.affine)
    assert masker.maps_img_.shape == maps_img.shape

    # ... and the mask is brought onto that grid.
    assert_array_equal(masker.mask_img_.affine, masker.maps_img_.affine)
    assert masker.mask_img_.shape == masker.maps_img_.shape[:3]

    assert signals.shape == (length, n_regions)

    # inverse_transform returns an image on the maps grid, one volume per
    # time point.
    reconstructed = masker.inverse_transform(signals)
    assert_array_equal(reconstructed.affine, masker.maps_img_.affine)
    assert reconstructed.shape == (*masker.maps_img_.shape[:3], length)
def test_nifti_maps_masker_clipped_mask(n_regions, affine_eye):
    """Test with clipped maps: mask does not contain all maps.

    The maps keep their own grid (``resampling_target="maps"``) and the
    mask, defined with voxel size 2, is resampled onto it.
    """
    # Shapes do matter in this case.
    length = 21
    shape1 = (10, 11, 12, length)  # data
    shape2 = (8, 9, 10)  # mask
    shape3 = (16, 18, 20)  # maps
    affine2 = np.diag((2, 2, 2, 1))  # just for mask

    fmri11_img, _ = generate_random_img(shape1, affine=affine_eye)
    _, mask22_img = generate_fake_fmri(shape2, length=1, affine=affine2)
    maps33_img, _ = generate_maps(shape3, n_regions, affine=affine_eye)

    # Target: maps
    masker = NiftiMapsMasker(
        maps33_img, mask_img=mask22_img, resampling_target="maps"
    )

    signals = masker.fit_transform(fmri11_img)

    assert_almost_equal(masker.maps_img_.affine, maps33_img.affine)
    assert masker.maps_img_.shape == maps33_img.shape

    assert_almost_equal(masker.mask_img_.affine, masker.maps_img_.affine)
    assert masker.mask_img_.shape == masker.maps_img_.shape[:3]

    assert signals.shape == (length, n_regions)
    # Regions clipped by the mask may yield a constant (zero-variance)
    # signal, but not all of them: at least one region must still carry
    # a varying signal.
    assert (signals.var(axis=0) == 0).sum() < n_regions

    fmri11_img_r = masker.inverse_transform(signals)

    assert_almost_equal(fmri11_img_r.affine, masker.maps_img_.affine)
    assert fmri11_img_r.shape == (masker.maps_img_.shape[:3] + (length,))
def non_overlapping_maps():
    """Build a 2-map image in which every voxel belongs to exactly one map.

    Along the first axis, slices ``[:2]`` belong to map 0 and slices ``[2:]``
    to map 1.  The image uses an identity affine.
    """
    data = np.zeros((*_shape_3d_default(), 2))
    data[:2, ..., 0] = 1.0
    data[2:, ..., 1] = 1.0
    return Nifti1Image(data, np.eye(4))
def overlapping_maps():
    """Build a 2-map image where some voxels belong to both maps.

    Map 0 covers slices ``[:3]`` and map 1 covers slices ``[2:]`` along the
    first axis, so slice 2 is non-zero in both maps.  The image uses an
    identity affine.
    """
    data = np.zeros((*_shape_3d_default(), 2))
    data[:3, ..., 0] = 1.0
    data[2:, ..., 1] = 1.0
    return Nifti1Image(data, np.eye(4))
@pytest.mark.parametrize(
    "maps_img_fn", [overlapping_maps, non_overlapping_maps]
)
@pytest.mark.parametrize("allow_overlap", [True, False])
def test_nifti_maps_masker_overlap(maps_img_fn, allow_overlap, img_fmri):
    """Test overlap handling in NiftiMapsMasker.

    Overlapping maps must raise a ValueError when ``allow_overlap=False``.
    Every other combination must fit and transform without error.
    """
    masker = NiftiMapsMasker(maps_img_fn(), allow_overlap=allow_overlap)

    # Compare the helper object itself rather than its __name__, so the
    # check cannot silently stop matching if the helper is renamed.
    if maps_img_fn is overlapping_maps and not allow_overlap:
        with pytest.raises(ValueError, match="Overlap detected"):
            masker.fit_transform(img_fmri)
    else:
        masker.fit_transform(img_fmri)
def test_standardization(rng, affine_eye, shape_3d_default):
    """Check that each ``standardize`` option rescales region signals."""
    length = 500
    n_voxels = np.prod(shape_3d_default)

    # Unit-variance noise around voxel-specific offsets (mean 1000, sd 50).
    signals = rng.standard_normal(size=(n_voxels, length))
    signals += rng.standard_normal(size=(n_voxels, 1)) * 50 + 1000
    img = Nifti1Image(signals.reshape((*shape_3d_default, length)), affine_eye)

    maps, _ = generate_maps((9, 9, 5), 10)

    # Reference: no standardization.
    raw_masker = NiftiMapsMasker(maps, standardize=False)
    raw_signals = raw_masker.fit_transform(img)

    # z-score: zero mean and unit standard deviation per region.
    zscore_masker = NiftiMapsMasker(maps, standardize="zscore_sample")
    zscored = zscore_masker.fit_transform(img)
    assert_almost_equal(zscored.mean(0), 0)
    assert_almost_equal(zscored.std(0), 1, decimal=3)

    # psc: percent change relative to each region's mean, hence zero mean.
    psc_masker = NiftiMapsMasker(maps, standardize="psc")
    psc_signals = psc_masker.fit_transform(img)
    assert_almost_equal(psc_signals.mean(0), 0)
    expected_psc = raw_signals / raw_signals.mean(0) * 100 - 100
    assert_almost_equal(psc_signals, expected_psc)