Coverage for nilearn/decoding/tests/test_searchlight.py: 0%
133 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-16 12:32 +0200
1"""Test the searchlight module."""
3import numpy as np
4import pytest
5from nibabel import Nifti1Image
6from sklearn.model_selection import (
7 KFold,
8 LeaveOneGroupOut,
9)
10from sklearn.utils.estimator_checks import parametrize_with_checks
12from nilearn._utils.estimator_checks import (
13 check_estimator,
14 nilearn_check_estimator,
15 return_expected_failed_checks,
16)
17from nilearn._utils.tags import SKLEARN_LT_1_6
18from nilearn.conftest import _rng
19from nilearn.decoding import searchlight
ESTIMATOR_TO_CHECK = [searchlight.SearchLight()]

if not SKLEARN_LT_1_6:
    # Recent scikit-learn: ``parametrize_with_checks`` receives the list of
    # checks known to fail through ``expected_failed_checks``.

    @parametrize_with_checks(
        estimators=ESTIMATOR_TO_CHECK,
        expected_failed_checks=return_expected_failed_checks,
    )
    def test_check_estimator_sklearn(estimator, check):
        """Run scikit-learn's estimator checks on SearchLight."""
        check(estimator)

else:
    # Older scikit-learn: nilearn's ``check_estimator`` helper yields the
    # checks expected to pass (``valid=True``, the default) separately from
    # those expected to fail (``valid=False``).

    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATOR_TO_CHECK),
    )
    def test_check_estimator_sklearn_valid(estimator, check, name):  # noqa: ARG001
        """Run the scikit-learn checks that SearchLight should pass."""
        check(estimator)

    @pytest.mark.xfail(reason="invalid checks should fail")
    @pytest.mark.parametrize(
        "estimator, check, name",
        check_estimator(estimators=ESTIMATOR_TO_CHECK, valid=False),
    )
    def test_check_estimator_sklearn_invalid(estimator, check, name):  # noqa: ARG001
        """Run the scikit-learn checks that SearchLight is known to fail."""
        check(estimator)
# Full 5x5x5 mask given to the estimator under nilearn's own checks.
_NILEARN_CHECK_MASK_IMG = Nifti1Image(
    np.ones((5, 5, 5), dtype=bool).astype("uint8"), np.eye(4)
)


@pytest.mark.parametrize(
    "estimator, check, name",
    nilearn_check_estimator(
        estimators=[searchlight.SearchLight(mask_img=_NILEARN_CHECK_MASK_IMG)]
    ),
)
def test_check_estimator_nilearn(estimator, check, name):  # noqa: ARG001
    """Run nilearn-specific estimator checks on a masked SearchLight."""
    check(estimator)
def _make_searchlight_test_data(frames):
    """Build a toy 4D image in which only the centre voxel is informative.

    Parameters
    ----------
    frames : int
        Number of volumes (samples) along the fourth axis.

    Returns
    -------
    data_img : Nifti1Image
        Random 5x5x5x``frames`` image with identity affine. Voxel
        (2, 2, 2) is 0 for label ``False`` volumes and 2 for label
        ``True`` volumes.
    cond : numpy.ndarray of bool
        Labels: ``False`` for the first ``frames // 2`` volumes and
        ``True`` afterwards (balanced when ``frames`` is even).
    mask_img : Nifti1Image
        All-ones 5x5x5 mask with identity affine.
    """
    signal = _rng().random((5, 5, 5, frames))
    mask_img = Nifti1Image(
        np.ones((5, 5, 5), dtype=bool).astype("uint8"), np.eye(4)
    )
    labels = np.arange(frames, dtype=int) >= (frames // 2)

    # Make the centre voxel a perfect predictor of the labels.
    signal[2, 2, 2, :] = 0
    signal[2, 2, 2, labels] = 2

    return Nifti1Image(signal, np.eye(4)), labels, mask_img
def define_cross_validation():
    """Return the cross-validation settings shared by the searchlight tests.

    Returns
    -------
    cv : KFold
        A 4-fold ``KFold`` splitter with default options.
    n_jobs : int
        Always 1, so the searchlight runs in a single job.
    """
    single_job = 1
    return KFold(n_splits=4), single_job
def test_searchlight_no_mask():
    """Check that a non-image ``mask_img`` raises a TypeError on fit."""
    estimator = searchlight.SearchLight(mask_img=1)
    data_img, cond, _ = _make_searchlight_test_data(frames=30)

    with pytest.raises(TypeError, match="input should be a NiftiLike object"):
        estimator.fit(data_img, y=cond)
def test_searchlight_small_radius():
    """Check that with radius 0.5 only the informative voxel scores 1.

    A 0.5 radius selects a single voxel per sphere, so only the centre
    voxel (2, 2, 2) can be classified perfectly.
    """
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    sl = searchlight.SearchLight(
        mask_img=mask_img,
        process_mask_img=mask_img,
        radius=0.5,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
        verbose=1,
    )
    sl.fit(data_img, y=cond)

    assert np.count_nonzero(sl.scores_ == 1) == 1
    assert sl.scores_[2, 2, 2] == 1.0
def test_searchlight_mask_far_from_signal(affine_eye):
    """Check that no voxel scores 1 when only a far corner is processed.

    The process mask keeps voxel (0, 0, 0) only, and with radius 0.5 its
    sphere never reaches the informative centre voxel.
    """
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    corner_only = np.zeros((5, 5, 5), dtype=bool)
    corner_only[0, 0, 0] = True
    process_mask_img = Nifti1Image(corner_only.astype("uint8"), affine_eye)

    sl = searchlight.SearchLight(
        mask_img=mask_img,
        process_mask_img=process_mask_img,
        radius=0.5,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    sl.fit(data_img, y=cond)

    assert np.count_nonzero(sl.scores_ == 1) == 0
def test_searchlight_medium_radius():
    """Check that with radius 1 the centre and its 6 face neighbours score 1."""
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    sl = searchlight.SearchLight(
        mask_img=mask_img,
        process_mask_img=mask_img,
        radius=1,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    sl.fit(data_img, cond)

    assert np.count_nonzero(sl.scores_ == 1) == 7

    # Every radius-1 sphere containing the centre voxel is perfect.
    perfect_voxels = [
        (2, 2, 2),
        (1, 2, 2),
        (2, 1, 2),
        (2, 2, 1),
        (3, 2, 2),
        (2, 3, 2),
        (2, 2, 3),
    ]
    for voxel in perfect_voxels:
        assert sl.scores_[voxel] == 1.0, voxel
def test_searchlight_large_radius():
    """Check that with radius 2, 33 voxels (centre included) score 1.

    33 is the number of grid voxels within distance 2 of the centre, i.e.
    the voxels whose sphere contains the informative voxel.
    """
    data_img, cond, mask_img = _make_searchlight_test_data(frames=30)
    cv, n_jobs = define_cross_validation()

    sl = searchlight.SearchLight(
        mask_img=mask_img,
        process_mask_img=mask_img,
        radius=2,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    sl.fit(data_img, cond)

    assert np.count_nonzero(sl.scores_ == 1) == 33
    assert sl.scores_[2, 2, 2] == 1.0
def test_searchlight_group_cross_validation(rng):
    """Check SearchLight with ``LeaveOneGroupOut`` and a ``groups`` array."""
    frames = 30
    data_img, cond, mask_img = _make_searchlight_test_data(frames)
    n_jobs = define_cross_validation()[1]

    # Two groups (False / True) shuffled randomly across the volumes.
    group_split = np.arange(frames, dtype=int) > (frames // 2)
    groups = rng.permutation(group_split)

    sl = searchlight.SearchLight(
        mask_img=mask_img,
        process_mask_img=mask_img,
        radius=1,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=LeaveOneGroupOut(),
    )
    sl.fit(data_img, y=cond, groups=groups)

    assert np.count_nonzero(sl.scores_ == 1) == 7
    assert sl.scores_[2, 2, 2] == 1.0
def test_searchlight_group_cross_validation_with_extra_group_variable(
    rng,
    affine_eye,
):
    """Check that ``groups`` is accepted alongside a non-group splitter.

    ``KFold`` does not use ``groups``, so results should match the plain
    radius-1 case. The test also checks that SearchLight can be fitted on
    a list of 3D images and produces a score map covering the mask.
    """
    frames = 30
    data_img, cond, mask_img = _make_searchlight_test_data(frames)
    cv, n_jobs = define_cross_validation()

    groups = rng.permutation(np.arange(frames, dtype=int) > (frames // 2))

    sl = searchlight.SearchLight(
        mask_img,
        process_mask_img=mask_img,
        radius=1,
        n_jobs=n_jobs,
        scoring="accuracy",
        cv=cv,
    )
    sl.fit(data_img, y=cond, groups=groups)

    assert np.where(sl.scores_ == 1)[0].size == 7
    assert sl.scores_[2, 2, 2] == 1.0

    # Check whether searchlight works on list of 3D images
    volume_img = Nifti1Image(rng.random((5, 5, 5)), affine=affine_eye)
    imgs = [volume_img] * 12

    # labels
    y = [0, 1] * 6

    # run searchlight on list of 3D images
    sl = searchlight.SearchLight(mask_img)
    sl.fit(imgs, y)

    # Without this assertion the check above only detected crashes: make
    # sure fitting actually produced a score map over the whole mask.
    assert sl.scores_.shape == mask_img.shape
def test_mask_img_dimension_mismatch():
    """Check that SearchLight fits when ``mask_img`` is smaller than the data.

    The 4x4x4 mask does not match the 5x5x5 images; fitting must not raise
    and the score map must take the mask's shape.
    """
    data_img, cond, _ = _make_searchlight_test_data(frames=20)
    smaller_mask_img = Nifti1Image(
        np.ones((4, 4, 4), dtype="uint8"), np.eye(4)
    )

    sl = searchlight.SearchLight(smaller_mask_img, radius=1.0)
    sl.fit(data_img, y=cond)

    assert sl.scores_ is not None
    assert sl.scores_.shape == smaller_mask_img.shape
def test_transform_applies_mask_correctly():
    """Check that ``transform`` on the fitted data returns a 5x5x5 map."""
    data_img, cond, mask_img = _make_searchlight_test_data(frames=20)

    sl = searchlight.SearchLight(mask_img, radius=1.0)
    sl.fit(data_img, y=cond)

    # Fitted attributes must be populated before transform is meaningful.
    assert sl.scores_ is not None
    assert sl.process_mask_ is not None

    scores = sl.transform(data_img)

    assert scores is not None
    assert scores.shape == (5, 5, 5)
    assert scores.size > 0
def test_process_mask_shape_mismatch():
    """Check SearchLight with a process mask smaller than the images.

    The 4x4x4 process mask does not match the 5x5x5 data. Fitting must
    still succeed and the score map must take the process mask's shape.
    """
    data_img, cond, mask_img = _make_searchlight_test_data(frames=20)
    small_process_mask_img = Nifti1Image(
        np.ones((4, 4, 4), dtype="uint8"), np.eye(4)
    )

    sl = searchlight.SearchLight(
        mask_img=mask_img,
        process_mask_img=small_process_mask_img,
        radius=1.0,
    )
    sl.fit(data_img, y=cond)

    assert sl.scores_ is not None
    assert sl.scores_.shape == small_process_mask_img.shape