Coverage for nilearn/maskers/tests/test_multi_nifti_maps_masker.py: 0%
139 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""Test the multi_nifti_maps_masker module."""
3import numpy as np
4import pytest
5from numpy.testing import assert_almost_equal, assert_array_equal
6from sklearn.utils.estimator_checks import parametrize_with_checks
8from nilearn._utils.data_gen import generate_fake_fmri, generate_maps
9from nilearn._utils.estimator_checks import (
10 check_estimator,
11 nilearn_check_estimator,
12 return_expected_failed_checks,
13)
14from nilearn._utils.exceptions import DimensionError
15from nilearn._utils.tags import SKLEARN_LT_1_6
16from nilearn._utils.testing import write_imgs_to_path
17from nilearn.conftest import _img_maps
18from nilearn.maskers import MultiNiftiMapsMasker, NiftiMapsMasker
ESTIMATORS_TO_CHECK = [MultiNiftiMapsMasker()]

if not SKLEARN_LT_1_6:
    # Recent scikit-learn: let sklearn's own parametrization generate the
    # checks, marking the ones nilearn knows to fail as expected failures.
    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Check compliance with sklearn estimators."""
        check(estimator)

else:
    # Older scikit-learn: nilearn's helper splits the checks into those
    # expected to pass (valid=True, the default) and those expected to fail.
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(
            estimators=ESTIMATORS_TO_CHECK,
            valid=False,
        ),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators."""
        check(estimator)
@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        # A single masker on a reduced 2-region atlas: passing fewer regions
        # than the default keeps this battery of checks fast.
        estimators=[MultiNiftiMapsMasker(_img_maps(n_regions=2))]
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run nilearn's estimator-compliance checks on MultiNiftiMapsMasker."""
    check(estimator)
@pytest.mark.timeout(0)
def test_multi_nifti_maps_masker(
    affine_eye, length, n_regions, shape_3d_default, img_maps
):
    """Check basic functions of MultiNiftiMapsMasker.

    - fit, transform, fit_transform, inverse_transform.
    - 4D and list[4D] inputs
    """
    fmri11_img, mask11_img = generate_fake_fmri(
        shape_3d_default, affine=affine_eye, length=length
    )

    masker = MultiNiftiMapsMasker(
        img_maps, mask_img=mask11_img, resampling_target=None
    )

    # A single 4D image gives one (length, n_regions) signal array.
    signals11 = masker.fit_transform(fmri11_img)

    assert signals11.shape == (length, n_regions)

    # Smoke test: no mask_img and resampling_target left at its default.
    MultiNiftiMapsMasker(img_maps).fit_transform(fmri11_img)

    # Should also work with a list of 4D images:
    # one signal array is returned per input image.
    signals_input = [fmri11_img, fmri11_img]

    signals11_list = masker.fit_transform(signals_input)

    assert len(signals11_list) == len(signals_input)
    for signals in signals11_list:
        assert signals.shape == (length, n_regions)

        # inverse_transform must rebuild an image with the same shape and
        # affine as the original data.
        fmri11_img_r = masker.inverse_transform(signals)

        assert fmri11_img_r.shape == fmri11_img.shape
        assert_almost_equal(fmri11_img_r.affine, fmri11_img.affine)

    # Now try on a masker that has never seen the call to "transform"
    masker = MultiNiftiMapsMasker(img_maps, resampling_target=None)
    masker.fit()
    masker.inverse_transform(signals11_list[-1])
def test_multi_nifti_maps_masker_data_atlas_different_shape(
    affine_eye, length, img_maps
):
    """Test with data and atlas of different shape.

    The atlas should be resampled to the data.
    """
    # The mask and the data differ from each other in both shape and affine.
    shape2 = (12, 10, 14)  # mask
    shape22 = (5, 5, 6)  # data
    # Data affine: uniform scaling by 2 (identity with a 2 diagonal,
    # homogeneous coordinate kept at 1).
    affine2 = 2 * np.eye(4)
    affine2[-1, -1] = 1

    _, mask21_img = generate_fake_fmri(
        shape2, affine=affine_eye, length=length
    )
    fmri22_img, _ = generate_fake_fmri(shape22, affine=affine2, length=length)

    masker = MultiNiftiMapsMasker(img_maps, mask_img=mask21_img)

    masker.fit_transform(fmri22_img)

    # After fitting on the data, the maps live on the data's grid.
    assert_array_equal(masker.maps_img_.affine, affine2)
def test_multi_nifti_maps_masker_errors(
    affine_eye, length, shape_3d_default, img_maps
):
    """Check that NiftiMapsMasker rejects a list of 4D images.

    A list of 4D images is the multi-subject input handled by
    MultiNiftiMapsMasker; the single-subject NiftiMapsMasker must raise
    a DimensionError on it.
    """
    fmri11_img, _ = generate_fake_fmri(
        shape_3d_default, affine=affine_eye, length=length
    )

    signals_input = [fmri11_img, fmri11_img]

    # NiftiMapsMasker should not work with a list of 4D images.
    masker = NiftiMapsMasker(img_maps, resampling_target=None)
    with pytest.raises(DimensionError, match="incompatible dimensionality"):
        masker.fit_transform(signals_input)
@pytest.mark.parametrize("create_files", [True, False])
def test_multi_nifti_maps_masker_errors_field_of_view(
    tmp_path,
    affine_eye,
    length,
    create_files,
    shape_3d_default,
    img_maps,
):
    """Test all kinds of mismatches between shapes and between affines."""
    # Check working of shape/affine checks
    shape2 = (12, 10, 14)
    affine2 = np.diag((1, 2, 3, 1))

    # "12": default shape with affine2; "21": shape2 with identity affine.
    fmri12_img, mask12_img = generate_fake_fmri(
        shape_3d_default, affine=affine2, length=length
    )
    fmri21_img, mask21_img = generate_fake_fmri(
        shape2, affine=affine_eye, length=length
    )

    error_msg = "Following field of view errors were detected"

    # Mask of shape shape2 given at construction: detected at fit time.
    masker = MultiNiftiMapsMasker(
        img_maps, mask_img=mask21_img, resampling_target=None
    )
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()

    # Depending on create_files, exercise both file-path and in-memory inputs.
    maps11, mask12 = write_imgs_to_path(
        img_maps,
        mask12_img,
        file_path=tmp_path,
        create_files=create_files,
    )
    masker = MultiNiftiMapsMasker(maps11, resampling_target=None)
    masker.fit()

    # Data with affine2 is rejected at transform time.
    with pytest.raises(ValueError, match=error_msg):
        masker.transform(fmri12_img)

    # Data of shape shape2 is rejected at transform time.
    with pytest.raises(ValueError, match=error_msg):
        masker.transform(fmri21_img)

    # Mask with affine2 given at construction: detected at fit time.
    masker = MultiNiftiMapsMasker(
        maps11, mask_img=mask12, resampling_target=None
    )
    with pytest.raises(ValueError, match=error_msg):
        masker.fit()
def test_multi_nifti_maps_masker_resampling_error(
    affine_eye, n_regions, shape_3d_large
):
    """Test MultiNiftiMapsMasker when using resampling."""
    maps33_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    # Each bad configuration is accepted by the constructor and only
    # reported once fit() is called.
    bad_configs = (
        (
            "mask",
            "resampling_target has been set to 'mask' "
            "but no mask has been provided",
        ),
        ("invalid", "invalid value for 'resampling_target' parameter:"),
    )
    for target, expected_msg in bad_configs:
        masker = MultiNiftiMapsMasker(maps33_img, resampling_target=target)
        with pytest.raises(ValueError, match=expected_msg):
            masker.fit()
@pytest.mark.timeout(0)
def test_multi_nifti_maps_masker_resampling_to_mask(
    shape_mask,
    affine_eye,
    length,
    n_regions,
    shape_3d_large,
    img_fmri,
):
    """Check that resampling_target="mask" puts everything on the mask grid."""
    _, mask_img = generate_fake_fmri(
        shape_mask, affine=affine_eye, length=length
    )
    maps_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    masker = MultiNiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="mask"
    )
    all_signals = masker.fit_transform([img_fmri, img_fmri])

    fitted_mask = masker.mask_img_
    fitted_maps = masker.maps_img_

    # The mask keeps its own geometry...
    assert_almost_equal(fitted_mask.affine, mask_img.affine)
    assert fitted_mask.shape == mask_img.shape

    # ...and the maps are brought onto that same grid.
    assert_almost_equal(fitted_mask.affine, fitted_maps.affine)
    assert fitted_mask.shape == fitted_maps.shape[:3]

    for region_signals in all_signals:
        assert region_signals.shape == (length, n_regions)

        img_r = masker.inverse_transform(region_signals)

        assert_almost_equal(img_r.affine, masker.maps_img_.affine)
        assert img_r.shape == (*masker.maps_img_.shape[:3], length)
def test_multi_nifti_maps_masker_resampling_to_maps(
    shape_mask,
    affine_eye,
    length,
    n_regions,
    shape_3d_large,
    img_fmri,
):
    """Check that resampling_target="maps" puts everything on the maps grid."""
    _, mask_img = generate_fake_fmri(
        shape_mask, affine=affine_eye, length=length
    )
    maps_img, _ = generate_maps(shape_3d_large, n_regions, affine=affine_eye)

    masker = MultiNiftiMapsMasker(
        maps_img, mask_img=mask_img, resampling_target="maps"
    )
    all_signals = masker.fit_transform([img_fmri, img_fmri])

    fitted_mask = masker.mask_img_
    fitted_maps = masker.maps_img_

    # The maps keep their own geometry...
    assert_almost_equal(fitted_maps.affine, maps_img.affine)
    assert fitted_maps.shape == maps_img.shape

    # ...and the mask is brought onto that same grid.
    assert_almost_equal(fitted_mask.affine, fitted_maps.affine)
    assert fitted_mask.shape == fitted_maps.shape[:3]

    for region_signals in all_signals:
        assert region_signals.shape == (length, n_regions)

        img_r = masker.inverse_transform(region_signals)

        assert_almost_equal(img_r.affine, masker.maps_img_.affine)
        assert img_r.shape == (*masker.maps_img_.shape[:3], length)
def test_multi_nifti_maps_masker_resampling_clipped_mask(
    affine_eye, length, n_regions, img_fmri
):
    """Test with clipped maps: mask does not contain all maps."""
    # Shapes do matter in that case
    shape2 = (8, 9, 10)  # mask
    shape3 = (16, 18, 20)  # maps
    affine2 = np.diag((2, 2, 2, 1))  # just for mask

    _, mask22_img = generate_fake_fmri(shape2, length=1, affine=affine2)
    maps33_img, _ = generate_maps(shape3, n_regions, affine=affine_eye)

    masker = MultiNiftiMapsMasker(
        maps33_img, mask_img=mask22_img, resampling_target="maps"
    )

    signals = masker.fit_transform([img_fmri, img_fmri])

    # resampling_target="maps": the maps keep their geometry and the mask is
    # resampled onto the maps grid.
    assert_almost_equal(masker.maps_img_.affine, maps33_img.affine)
    assert masker.maps_img_.shape == maps33_img.shape

    assert_almost_equal(masker.mask_img_.affine, masker.maps_img_.affine)
    assert masker.mask_img_.shape == masker.maps_img_.shape[:3]

    for t in signals:
        assert t.shape == (length, n_regions)
        # Regions clipped by the mask may yield constant (zero-variance)
        # signals, but clipping must not flatten every region: at least
        # one region has to keep a non-constant signal.
        assert (t.var(axis=0) == 0).sum() < n_regions

        fmri11_img_r = masker.inverse_transform(t)

        assert_almost_equal(fmri11_img_r.affine, masker.maps_img_.affine)
        assert fmri11_img_r.shape == (masker.maps_img_.shape[:3] + (length,))