Coverage for nilearn/maskers/tests/test_multi_nifti_masker.py: 0%
146 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""Test the multi_nifti_masker module."""
3import shutil
4from tempfile import mkdtemp
6import numpy as np
7import pytest
8from joblib import Memory, hash
9from nibabel import Nifti1Image
10from numpy.testing import assert_array_equal
11from sklearn.utils.estimator_checks import parametrize_with_checks
13from nilearn._utils.estimator_checks import (
14 check_estimator,
15 nilearn_check_estimator,
16 return_expected_failed_checks,
17)
18from nilearn._utils.tags import SKLEARN_LT_1_6
19from nilearn._utils.testing import write_imgs_to_path
20from nilearn.image import get_data
21from nilearn.maskers import MultiNiftiMasker
# Estimator instances run through the generic estimator compliance checks.
ESTIMATORS_TO_CHECK = [MultiNiftiMasker()]

# On scikit-learn < 1.6, nilearn's own ``check_estimator`` helper splits the
# checks into those that must pass (``valid=True``) and those expected to
# fail (``valid=False``). Newer versions use sklearn's
# ``parametrize_with_checks`` with an explicit list of expected failures.
if SKLEARN_LT_1_6:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Check compliance with sklearn estimators."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the sklearn checks flagged as invalid; they are expected to fail."""
        check(estimator)

else:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Check compliance with sklearn estimators."""
        check(estimator)
# Timeout disabled (0): check_multi_masker_transformer_high_variance_confounds
# is slow.
@pytest.mark.timeout(0)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Check compliance with nilearn-specific estimator checks."""
    check(estimator)
@pytest.fixture
def data_2(shape_3d_default):
    """Return a 3D zero array whose interior block [1:-2, 1:-2, 1:-2] is 10."""
    volume = np.zeros(shape_3d_default)
    interior = (slice(1, -2),) * 3
    volume[interior] = 10
    return volume
@pytest.fixture
def img_1(data_1, affine_eye):
    """Wrap the ``data_1`` fixture array in a Nifti image with ``affine_eye``."""
    image = Nifti1Image(data_1, affine_eye)
    return image
@pytest.fixture
def img_2(data_2, affine_eye):
    """Wrap ``data_2`` (zeros with an interior block of 10) in a Nifti image."""
    image = Nifti1Image(data_2, affine_eye)
    return image
def test_auto_mask(data_1, img_1, data_2, img_2):
    """Check the mask computed automatically when no mask_img is given."""
    masker = MultiNiftiMasker(mask_args={"opening": 0})

    # Smoke test the fit
    masker.fit([[img_1]])

    # Fit on one entry made of two 3D volumes: the computed mask must be the
    # union (logical OR) of the voxels that are non-zero in either volume.
    masker.fit([[img_1, img_2]])

    assert_array_equal(
        get_data(masker.mask_img_), np.logical_or(data_1, data_2)
    )

    # Smoke test the transform
    masker.transform([[img_1]])
    # It should also work with a 3D image
    masker.transform(img_1)
def test_nan():
    """Check that NaN voxels of the fitted image are excluded from the mask.

    A 9x9x9 volume has NaN on each of its six outer faces, ones elsewhere
    and a central block of 10. The mask must cover the whole interior and
    none of the faces.
    """
    faces = [
        np.s_[0],
        np.s_[:, 0],
        np.s_[:, :, 0],
        np.s_[-1],
        np.s_[:, -1],
        np.s_[:, :, -1],
    ]

    volume = np.ones((9, 9, 9))
    for face in faces:
        volume[face] = np.nan
    volume[3:-3, 3:-3, 3:-3] = 10
    img = Nifti1Image(volume, np.eye(4))

    masker = MultiNiftiMasker(mask_args={"opening": 0})
    masker.fit([img])

    mask = get_data(masker.mask_img_)

    assert mask[1:-1, 1:-1, 1:-1].all()
    for face in faces:
        assert not mask[face].any()
def test_different_affines():
    """Check mask and EPI images with different affines.

    The mask (voxel size 4) and the two EPI images (voxel sizes 2 and 3) all
    have different affines and shapes; fit_transform followed by
    inverse_transform on each extracted signal should run without error.
    """
    mask_img = Nifti1Image(
        np.ones((2, 2, 2), dtype=np.int8), affine=np.diag((4, 4, 4, 1))
    )
    epi_img1 = Nifti1Image(np.ones((4, 4, 4, 3)), affine=np.diag((2, 2, 2, 1)))
    epi_img2 = Nifti1Image(np.ones((3, 3, 3, 3)), affine=np.diag((3, 3, 3, 1)))

    masker = MultiNiftiMasker(mask_img=mask_img)
    epis = masker.fit_transform([epi_img1, epi_img2])
    # Smoke test: the inverse transform result is not inspected.
    for this_epi in epis:
        masker.inverse_transform(this_epi)
def test_3d_images(rng):
    """Check that MultiNiftiMasker accepts a list of 3D images.

    Note that fit() requires all images in list to have the same affine.
    """
    mask_img = Nifti1Image(
        np.ones((2, 2, 2), dtype=np.int8), affine=np.diag((2, 2, 2, 1))
    )
    epi_affine = np.diag((4, 4, 4, 1))
    epi_imgs = [
        Nifti1Image(rng.random((2, 2, 2)), affine=epi_affine)
        for _ in range(2)
    ]
    masker = MultiNiftiMasker(mask_img=mask_img)

    masker.fit_transform(epi_imgs)
def test_joblib_cache(mask_img_1, tmp_path):
    """Check that reading the fitted mask data leaves its joblib hash intact.

    The mask is given as a file path; its joblib hash is computed after fit,
    the data are then loaded with get_data, and the hash is compared again.
    """
    mask_path = write_imgs_to_path(
        mask_img_1, file_path=tmp_path, create_files=True
    )
    masker = MultiNiftiMasker(mask_img=mask_path)
    masker.fit()

    hash_before = hash(masker.mask_img_)
    # Loading the data must not change what joblib hashes.
    get_data(masker.mask_img_)

    assert hash(masker.mask_img_) == hash_before
@pytest.mark.timeout(0)
def test_shelving(rng):
    """Check that a shelved masker returns the same signals as a plain one.

    With ``_shelving`` enabled, ``fit_transform`` returns objects whose
    ``.get()`` loads the cached result. These must match the output of a
    non-shelved masker, both for a list of images and for a single image.
    """
    mask_img = Nifti1Image(
        np.ones((2, 2, 2), dtype=np.int8), affine=np.diag((2, 2, 2, 1))
    )
    epi_img1 = Nifti1Image(rng.random((2, 2, 2)), affine=np.diag((4, 4, 4, 1)))
    epi_img2 = Nifti1Image(rng.random((2, 2, 2)), affine=np.diag((4, 4, 4, 1)))
    cachedir = mkdtemp()
    try:
        masker_shelved = MultiNiftiMasker(
            mask_img=mask_img,
            memory=Memory(location=cachedir, mmap_mode="r", verbose=0),
        )
        masker_shelved._shelving = True
        epis_shelved = masker_shelved.fit_transform([epi_img1, epi_img2])
        masker = MultiNiftiMasker(mask_img=mask_img)
        epis = masker.fit_transform([epi_img1, epi_img2])

        for epi_shelved, epi in zip(epis_shelved, epis):
            epi_shelved = epi_shelved.get()
            assert_array_equal(epi_shelved, epi)

        epi = masker.fit_transform(epi_img1)
        epi_shelved = masker_shelved.fit_transform(epi_img1)
        epi_shelved = epi_shelved.get()

        assert_array_equal(epi_shelved, epi)

        # The result loaded from the cache (mmap_mode="r") may be memory
        # mapped. Drop it so the cache file is not held open while the
        # directory is removed; open files cannot be deleted on Windows.
        del epi_shelved

    finally:
        # Best-effort cleanup. It never raises, so it cannot mask a failure
        # from the try body, and it runs however early that body failed.
        shutil.rmtree(cachedir, ignore_errors=True)
@pytest.fixture
def list_random_imgs(img_3d_rand_eye):
    """Return a two-element list holding the same random 3D image twice."""
    return [img_3d_rand_eye, img_3d_rand_eye]
def test_mask_strategy_errors(list_random_imgs):
    """Check the error for an unknown mask_strategy and the warning for the
    deprecated 'template' strategy.
    """
    masker = MultiNiftiMasker(mask_strategy="foo")

    with pytest.raises(
        ValueError, match="Unknown value of mask_strategy 'foo'"
    ):
        masker.fit(list_random_imgs)

    # The deprecated 'template' strategy only warns. fit must still complete:
    # any exception raised here would escape pytest.warns and fail the test.
    masker = MultiNiftiMasker(mask_strategy="template")
    with pytest.warns(
        UserWarning, match="Masking strategy 'template' is deprecated."
    ):
        masker.fit(list_random_imgs)
@pytest.mark.parametrize(
    "strategy", [f"{p}-template" for p in ["whole-brain", "gm", "wm"]]
)
def test_compute_mask_strategy(strategy, shape_3d_default, list_random_imgs):
    """Check the template-based mask strategies on random images.

    The masker is fitted on the images in their given order and in reverse
    order; both fits must yield the same all-zero int8 mask, so image order
    does not affect the output.
    """
    fitted = []
    for imgs in (list_random_imgs, list_random_imgs[::-1]):
        masker = MultiNiftiMasker(
            mask_strategy=strategy, mask_args={"opening": 1}
        )
        masker.fit(imgs)
        fitted.append(masker)

    expected = np.zeros(shape_3d_default, dtype="int8")
    for masker in fitted:
        assert_array_equal(get_data(masker.mask_img_), expected)
def test_standardization(rng, shape_3d_default, affine_eye):
    """Check that the 'standardize' parameter rescales extracted signals.

    Two 4D images of 500 noisy samples around a large per-voxel mean are
    masked with an all-ones mask, first with 'zscore_sample', then with
    'psc'.
    """
    n_samples = 500
    n_voxels = np.prod(shape_3d_default)

    signals = rng.standard_normal(size=(2, n_voxels, n_samples))
    signals += rng.standard_normal(size=(2, n_voxels, 1)) * 50 + 1000

    img1, img2 = (
        Nifti1Image(run.reshape((*shape_3d_default, n_samples)), affine_eye)
        for run in signals
    )
    mask = Nifti1Image(np.ones(shape_3d_default), affine_eye)

    # 'zscore_sample': every voxel time series has ~zero mean and unit std.
    masker = MultiNiftiMasker(mask, standardize="zscore_sample")
    for ts in masker.fit_transform([img1, img2]):
        np.testing.assert_almost_equal(ts.mean(0), 0)
        np.testing.assert_almost_equal(ts.std(0), 1, decimal=3)

    # 'psc': percent signal change relative to each voxel's temporal mean.
    masker = MultiNiftiMasker(mask, standardize="psc")
    for ts, run in zip(masker.fit_transform([img1, img2]), signals):
        expected = (run / run.mean(axis=1, keepdims=True) * 100 - 100).T
        np.testing.assert_almost_equal(ts.mean(0), 0)
        np.testing.assert_almost_equal(ts, expected)