Coverage for nilearn/maskers/tests/test_nifti_spheres_masker.py: 0%
178 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""Test nilearn.maskers.nifti_spheres_masker."""
3import numpy as np
4import pytest
5from nibabel import Nifti1Image
6from numpy.testing import assert_array_almost_equal, assert_array_equal
7from sklearn.utils.estimator_checks import parametrize_with_checks
9from nilearn._utils.estimator_checks import (
10 check_estimator,
11 nilearn_check_estimator,
12 return_expected_failed_checks,
13)
14from nilearn._utils.tags import SKLEARN_LT_1_6
15from nilearn.image import get_data, new_img_like
16from nilearn.maskers import NiftiSpheresMasker
ESTIMATORS_TO_CHECK = [NiftiSpheresMasker(seeds=[(1, 1, 1)])]

# Recent scikit-learn releases provide ``parametrize_with_checks`` with the
# ``expected_failed_checks`` argument used below; when SKLEARN_LT_1_6 is set,
# nilearn's own ``check_estimator`` helper is used instead, which splits the
# checks into ones expected to pass and ones expected to fail.
if not SKLEARN_LT_1_6:

    @parametrize_with_checks(
        estimators=ESTIMATORS_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run one scikit-learn estimator check against the masker."""
        check(estimator)

else:

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn check the masker is expected to pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATORS_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run one scikit-learn check the masker is known to fail."""
        check(estimator)
@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(estimators=ESTIMATORS_TO_CHECK),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run one nilearn-specific estimator check against the masker.

    ``name`` is only part of the parametrization (it labels the test id)
    and is not used in the body.
    """
    run_check = check
    run_check(estimator)
def test_seed_extraction(rng, affine_eye):
    """With the default radius, the extracted signal is the seed voxel's."""
    signal_4d = rng.random((3, 3, 3, 5))
    fmri = Nifti1Image(signal_4d, affine_eye)

    seed_masker = NiftiSpheresMasker([(1, 1, 1)])
    # No image is needed at fit time: the seed is given in world space.
    seed_masker.fit()

    extracted = seed_masker.transform(fmri)

    assert_array_equal(extracted[:, 0], signal_4d[1, 1, 1])
def test_sphere_extraction(rng, affine_eye):
    """A radius-1 sphere averages the seed voxel and its face neighbours."""
    seed = (1, 1, 1)
    data = rng.random((3, 3, 3, 5))
    img = Nifti1Image(data, affine_eye)

    masker = NiftiSpheresMasker([seed], radius=1)
    masker.fit()

    # A single seed yields a single extracted element.
    assert masker.n_elements_ == 1

    signals = masker.transform(img)

    # With an identity affine and radius 1, the sphere is the 3D "cross"
    # through voxel (1, 1, 1): one full line along each axis.
    sphere = np.zeros((3, 3, 3), dtype=bool)
    for axis in range(3):
        index = [1, 1, 1]
        index[axis] = slice(None)
        sphere[tuple(index)] = True
    assert_array_equal(signals[:, 0], np.mean(data[sphere], axis=0))

    # Restrict extraction with a mask that only keeps the x == 1 plane.
    mask_data = np.zeros((3, 3, 3))
    mask_data[1, :, :] = 1
    mask_img = Nifti1Image(mask_data, affine_eye)

    masked = NiftiSpheresMasker([seed], radius=1, mask_img=mask_img)
    masked.fit()
    signals = masked.transform(img)

    # Only voxels that are both in the sphere and in the mask contribute.
    kept = np.logical_and(sphere, get_data(mask_img))
    assert_array_equal(signals[:, 0], np.mean(data[kept], axis=0))
def test_anisotropic_sphere_extraction(rng, affine_eye):
    """Test sphere extraction on an anisotropic image."""
    seed = (2, 1, 2)

    data = rng.random((3, 3, 3, 5))

    # Voxels are 2 mm wide along x and z and 1 mm along y. Work on a copy so
    # the ``affine_eye`` fixture itself is not mutated; the image must be
    # built from this anisotropic affine, not from the fixture.
    affine = affine_eye.copy()
    affine[0, 0] = 2
    affine[2, 2] = 2

    img = Nifti1Image(data, affine)

    masker = NiftiSpheresMasker([seed], radius=1)

    # Test the fit
    masker.fit()

    # Test the transform
    s = masker.transform(img)

    # The seed (2, 1, 2) maps to voxel (1, 1, 1). With a radius of 1, only
    # the 1-mm-spaced neighbours along y are close enough to be included.
    mask = np.zeros((3, 3, 3), dtype=bool)
    mask[1, :, 1] = True
    assert_array_equal(s[:, 0], np.mean(data[mask], axis=0))

    # Now with a mask whose voxels are 4 mm wide along x, 1 mm along y and
    # 2 mm along z. It is derived from ``affine`` (not from the fixture) so
    # that it keeps the 2 mm z-scaling.
    mask_img = np.zeros((3, 2, 3))
    mask_img[1, 0, 1] = 1

    affine_2 = affine.copy()
    affine_2[0, 0] = 4

    mask_img = Nifti1Image(mask_img, affine=affine_2)

    masker = NiftiSpheresMasker([seed], radius=1, mask_img=mask_img)
    masker.fit()
    s = masker.transform(img)

    assert_array_equal(s[:, 0], data[1, 0, 1])
def test_errors():
    """Seeds that are not a list of coordinates are rejected at fit time."""
    flat_seed = [1, 2]
    masker = NiftiSpheresMasker(flat_seed, radius=0.2)

    with pytest.raises(ValueError, match="Seeds must be a list .+"):
        masker.fit()
def test_nifti_spheres_masker_overlap(rng, affine_eye):
    """Raise when spheres overlap and ``allow_overlap=False``."""
    shape = (5, 5, 5)
    fmri_img = Nifti1Image(rng.random((*shape, 5)), affine_eye)

    # The seeds are sqrt(12) ~ 3.46 voxels apart: radius-1 spheres are
    # disjoint, radius-2 spheres overlap.
    seeds = [(0, 0, 0), (2, 2, 2)]

    # Overlap allowed: both radii are accepted.
    for radius in (1, 2):
        NiftiSpheresMasker(
            seeds, radius=radius, allow_overlap=True
        ).fit_transform(fmri_img)

    # Overlap forbidden: disjoint spheres are fine ...
    NiftiSpheresMasker(seeds, radius=1, allow_overlap=False).fit_transform(
        fmri_img
    )

    # ... but overlapping ones are rejected.
    strict_masker = NiftiSpheresMasker(seeds, radius=2, allow_overlap=False)
    with pytest.raises(ValueError, match="Overlap detected"):
        strict_masker.fit_transform(fmri_img)
def test_small_radius(rng):
    """Check behavior when radius smaller than voxel size."""
    shape = (3, 3, 3)

    data = rng.random(shape)

    mask = np.zeros(shape)
    mask[1, 1, 1] = 1
    mask[2, 2, 2] = 1

    # 1.2 mm isotropic voxels. Only the spatial diagonal is scaled: the
    # homogeneous bottom row of an affine must remain (0, 0, 0, 1), which
    # ``np.eye(4) * 1.2`` would break.
    affine = np.diag([1.2, 1.2, 1.2, 1.0])

    # The seed is not on a voxel centre, and the 0.1 mm radius used below is
    # smaller than the voxel size.
    seed = (1.4, 1.4, 1.4)

    masker = NiftiSpheresMasker(
        [seed], radius=0.1, mask_img=Nifti1Image(mask, affine)
    )
    spheres_data = masker.fit_transform(Nifti1Image(data, affine))
    masker.inverse_transform(spheres_data)

    # Test if masking is taken into account: once voxel (1, 1, 1) is removed
    # from the mask, the small sphere no longer contains any mask voxel.
    mask[1, 1, 1] = 0
    mask[1, 1, 0] = 1

    masker = NiftiSpheresMasker(
        [seed], radius=0.1, mask_img=Nifti1Image(mask, affine)
    )

    with pytest.raises(ValueError, match="These spheres are empty"):
        masker.fit_transform(Nifti1Image(data, affine))

    masker.fit(Nifti1Image(data, affine))

    with pytest.raises(ValueError, match="These spheres are empty"):
        masker.inverse_transform(spheres_data)

    # Inverse transform should still work with a larger radius, which reaches
    # the remaining mask voxel.
    masker = NiftiSpheresMasker(
        [seed], radius=1.6, mask_img=Nifti1Image(mask, affine)
    )
    masker.fit(Nifti1Image(data, affine))
    masker.inverse_transform(spheres_data)
def test_is_nifti_spheres_masker_give_nans(rng, affine_eye):
    """Extracted signals contain no NaN even if the input image has some."""
    # 10x10x10 volume that is NaN everywhere except in its 9x9x9 corner.
    data_with_nans = np.full((10, 10, 10), np.nan, dtype=np.float32)
    finite_block = rng.random((9, 9, 9))
    filled = np.nonzero(finite_block)
    data_with_nans[filled] = finite_block[filled]
    img = Nifti1Image(data_with_nans, affine_eye)

    # A radius-2 sphere around (7, 7, 7) reaches index 9, i.e. the NaN border.
    seeds = [(7, 7, 7)]
    unmasked = NiftiSpheresMasker(seeds=seeds, radius=2.0)
    assert not np.isnan(np.sum(unmasked.fit_transform(img)))

    # With a 9x9x9 all-ones mask_img, extraction is limited to the mask.
    mask_img = Nifti1Image(np.ones((9, 9, 9)), affine_eye)
    masked = NiftiSpheresMasker(seeds=seeds, radius=2.0, mask_img=mask_img)
    assert not np.isnan(np.sum(masked.fit_transform(img)))
def test_standardization(rng, affine_eye):
    """Check the output of each ``standardize`` strategy."""
    data = rng.random((3, 3, 3, 5))
    img = Nifti1Image(data, affine_eye)

    # z-scoring: zero mean and roughly unit standard deviation. The loose
    # tolerance presumably accounts for 'zscore_sample' using the sample
    # (ddof=1) std while ndarray.std uses ddof=0 — TODO confirm.
    zscored = NiftiSpheresMasker(
        [(1, 1, 1)], standardize="zscore_sample"
    ).fit_transform(img)
    np.testing.assert_almost_equal(zscored.mean(), 0)
    np.testing.assert_almost_equal(zscored.std(), 1, decimal=1)

    # Percent signal change relative to the seed voxel's temporal mean.
    psc = NiftiSpheresMasker([(1, 1, 1)], standardize="psc").fit_transform(
        img
    )
    seed_series = data[1, 1, 1]
    np.testing.assert_almost_equal(psc.mean(), 0)
    np.testing.assert_almost_equal(
        psc.ravel(),
        seed_series / seed_series.mean() * 100 - 100,
    )
def test_nifti_spheres_masker_inverse_transform(rng, affine_eye):
    """Invert the extraction done in ``test_sphere_extraction``."""
    data = rng.random((3, 3, 3, 5))
    img = Nifti1Image(data, affine_eye)

    # inverse_transform requires a mask_img; without one it must refuse.
    unmasked = NiftiSpheresMasker([(1, 1, 1)], radius=1)
    unmasked.fit()
    signal = unmasked.transform(img)
    with pytest.raises(ValueError, match="Please provide mask_img"):
        unmasked.inverse_transform(signal)

    # Mask keeping only the x == 1 plane.
    mask_data = np.zeros((3, 3, 3))
    mask_data[1, :, :] = 1
    mask_img = Nifti1Image(mask_data, affine_eye)

    masker = NiftiSpheresMasker([(1, 1, 1)], radius=1, mask_img=mask_img)
    masker.fit()
    s = masker.transform(img)

    # Extent of the radius-1 sphere around voxel (1, 1, 1).
    sphere = np.zeros((3, 3, 3), dtype=bool)
    sphere[:, 1, 1] = True
    sphere[1, :, 1] = True
    sphere[1, 1, :] = True

    # Voxels that are both inside the sphere and inside the mask.
    array_mask = np.logical_and(sphere, get_data(mask_img))

    inverse_map = masker.inverse_transform(s)
    inverse_data = get_data(inverse_map)

    # Only sphere-and-mask voxels are non-zero after inversion ...
    assert_array_equal(np.mean(inverse_data, axis=-1) != 0, array_mask)
    # ... and they carry the extracted signal.
    assert_array_equal(inverse_data[array_mask].mean(0), s[:, 0])
    # The inverted image has the mask's spatial shape.
    assert_array_equal(inverse_map.shape[:3], mask_img.shape)
def test_nifti_spheres_masker_inverse_overlap(rng, affine_eye):
    """Check inverse_transform when spheres overlap.

    Overlapping voxels are averaged when ``allow_overlap=True``; a
    ``ValueError`` is raised when ``allow_overlap=False``.
    """
    shape = (5, 5, 5)
    fmri_img = Nifti1Image(rng.random((*shape, 5)), affine_eye)

    # A full-volume mask_img is required for inverse_transform.
    mask_img = new_img_like(fmri_img, np.ones(shape))
    seeds = [(0, 0, 0), (2, 2, 2)]
    # One value per sphere, projected back into the volume.
    inv_data = rng.random(len(seeds))

    def fitted_masker(radius, allow_overlap):
        # Build and fit a masker for the shared seeds and mask.
        return NiftiSpheresMasker(
            seeds,
            radius=radius,
            allow_overlap=allow_overlap,
            mask_img=mask_img,
        ).fit()

    fitted_masker(1, True).inverse_transform(inv_data)

    overlap = fitted_masker(2, True).inverse_transform(inv_data)
    # Voxel (1, 1, 1) lies in both radius-2 spheres, so it gets their mean.
    assert_array_almost_equal(get_data(overlap)[1, 1, 1], np.mean(inv_data))

    fitted_masker(1, False).inverse_transform(inv_data)

    strict_masker = fitted_masker(2, False)
    with pytest.raises(ValueError, match="Overlap detected"):
        strict_masker.inverse_transform(inv_data)