Coverage for nilearn/decoding/searchlight.py: 22%
142 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""The searchlight is a widely used approach for the study \
2of the fine-grained patterns of information in fMRI analysis, \
3in which multivariate statistical relationships are iteratively tested \
4in the neighborhood of each location of a domain.
5"""
7import time
8import warnings
9from copy import deepcopy
11import numpy as np
12from joblib import Parallel, cpu_count, delayed
13from sklearn import svm
14from sklearn.base import BaseEstimator, TransformerMixin
15from sklearn.exceptions import ConvergenceWarning
16from sklearn.model_selection import KFold, cross_val_score
17from sklearn.utils import check_array
18from sklearn.utils.estimator_checks import check_is_fitted
20from nilearn._utils import check_niimg_3d, check_niimg_4d, fill_doc, logger
21from nilearn._utils.param_validation import check_params
22from nilearn._utils.tags import SKLEARN_LT_1_6
23from nilearn.image import new_img_like
24from nilearn.maskers.nifti_spheres_masker import apply_mask_and_get_affinity
26from .. import masking
27from ..image.resampling import coord_transform
# Shortcut names accepted for ``SearchLight(estimator=...)``, mapped to the
# scikit-learn estimator classes they stand for.
ESTIMATOR_CATALOG = dict(svc=svm.LinearSVC, svr=svm.SVR)
@fill_doc
def search_light(
    X,
    y,
    estimator,
    A,
    groups=None,
    scoring=None,
    cv=None,
    n_jobs=-1,
    verbose=0,
):
    """Compute a search_light.

    Parameters
    ----------
    X : array-like of shape at least 2D
        data to fit.

    y : array-like or None
        target variable to predict. If None, each neighborhood's estimator
        is fit on an internally generated dummy target and scored without
        cross-validation.

    estimator : estimator object implementing 'fit'
        object to use to fit the data

    A : scipy sparse matrix with a ``rows`` attribute
        adjacency matrix (e.g. :class:`scipy.sparse.lil_matrix`). Defines
        for each feature the neighboring features following a given
        structure of the data. ``A.rows`` is indexed directly, so the matrix
        must be in a row-list format such as LIL.

    groups : array-like, default=None
        group label for each sample for cross validation.

    scoring : :obj:`str` or callable or None, default=None
        The scoring strategy to use. See the scikit-learn documentation
        for possible values.
        If callable, it takes as arguments the fitted estimator, the
        test data (X_test) and the test target (y_test) if y is
        not None.

    cv : cross-validation generator, default=None
        A cross-validation generator. If None, scikit-learn's default
        5-fold cross-validation is used; it is stratified when the
        estimator is a classifier and ``y`` is binary or multiclass.

    %(n_jobs_all)s

    %(verbose0)s

    Returns
    -------
    scores : numpy.ndarray of shape (number of rows in A,)
        search_light scores, one per row of ``A``, in row order.
    """
    # Split the rows of A into one contiguous chunk per job; each chunk is
    # processed sequentially by a single worker and the per-chunk score
    # arrays are concatenated back in order.
    group_iter = GroupIterator(A.shape[0], n_jobs)
    scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(_group_iter_search_light)(
            A.rows[list_i],
            estimator,
            X,
            y,
            groups,
            scoring,
            cv,
            thread_id + 1,  # 1-based job number, only used for display
            A.shape[0],
            verbose,
        )
        for thread_id, list_i in enumerate(group_iter)
    )
    return np.concatenate(scores)
@fill_doc
class GroupIterator:
    """Group iterator.

    Provides group of features for search_light loop
    that may be used with Parallel.

    Negative ``n_jobs`` values follow the joblib convention:
    ``-1`` means all CPUs, ``-2`` all CPUs but one, and so on. The
    resulting number of groups is never less than one.

    Parameters
    ----------
    n_features : :obj:`int`
        Total number of features

    %(n_jobs)s

    """

    def __init__(self, n_features, n_jobs=1):
        self.n_features = n_features
        # joblib resolves n_jobs < 0 as ``cpu_count() + 1 + n_jobs``.
        # np.array_split (used in __iter__) rejects a non-positive number of
        # sections, so resolve the count here and clamp it to at least 1.
        # The type guard leaves invalid types for check_params to report.
        if isinstance(n_jobs, (int, np.integer)) and n_jobs < 0:
            n_jobs = max(cpu_count() + 1 + n_jobs, 1)
        self.n_jobs = n_jobs
        check_params(self.__dict__)

    def __iter__(self):
        """Yield ``n_jobs`` contiguous arrays of feature indices."""
        yield from np.array_split(np.arange(self.n_features), self.n_jobs)
def _group_iter_search_light(
    list_rows,
    estimator,
    X,
    y,
    groups,
    scoring,
    cv,
    thread_id,
    total,
    verbose=0,
):
    """Perform grouped iterations of search_light.

    Parameters
    ----------
    list_rows : array of arrays of int
        adjacency rows. For a voxel with index i in X, list_rows[i] is the list
        of neighboring voxels indices (in X).

    estimator : estimator object implementing 'fit'
        object to use to fit the data

    X : array-like of shape at least 2D
        data to fit.

    y : array-like or None
        Target variable to predict. If `y` is provided, it must be
        an array-like object with the same length as the number of
        samples in `X`, and each neighborhood is scored with
        ``cross_val_score``.
        When `y` is `None`, a dummy target is generated internally: the
        first ``n_samples // 2`` samples are labeled `0` and the remaining
        samples are labeled `1` (so odd sample counts are supported). The
        estimator is then fit on all samples and the score is the mean of
        its ``decision_function`` on those same samples (no
        cross-validation).

    groups : array-like, optional
        group label for each sample for cross validation. Not forwarded
        when ``cv`` is a :class:`~sklearn.model_selection.KFold`.

    scoring : string or callable, optional
        Scoring strategy to use. See the scikit-learn documentation.
        If callable, takes as arguments the fitted estimator, the
        test data (X_test) and the test target (y_test) if y is
        not None.

    cv : cross-validation generator, optional
        A cross-validation generator. If None, scikit-learn's default
        5-fold cross-validation is used (stratified when the estimator is
        a classifier and ``y`` is binary or multiclass).

    thread_id : int
        Job number, used only in progress messages.

    total : int
        Total number of voxels, used for display: when it equals
        ``len(list_rows)`` (single job) progress lines overwrite each
        other with ``\\r``.

    verbose : int, default=0
        Verbosity level; 0 means no progress messages.

    Returns
    -------
    par_scores : numpy.ndarray
        score for each voxel. dtype: float64.
    """
    n_rows = len(list_rows)
    par_scores = np.zeros(n_rows)

    # KFold ignores `groups`, so they are only forwarded to other splitters.
    kwargs = {"scoring": scoring}
    if not isinstance(cv, KFold):
        kwargs["groups"] = groups

    # Without ground truth, fit against a fixed two-class dummy target that
    # has exactly one label per sample (the previous version dropped one
    # sample for odd sample counts and then failed in `fit`).
    y_dummy = None
    if y is None:
        n_samples = X.shape[0]
        n_first = n_samples // 2
        y_dummy = np.array([0] * n_first + [1] * (n_samples - n_first))

    # One can't print less than each 10 iterations
    step = 11 - min(verbose, 10)
    # If there is only one job, progress information is fixed
    crlf = "\r" if total == n_rows else "\n"

    t0 = time.time()
    for i, row in enumerate(list_rows):
        with warnings.catch_warnings():  # might not converge
            warnings.simplefilter("ignore", ConvergenceWarning)
            if y_dummy is not None:
                X_row = X[:, row]
                estimator.fit(X_row, y_dummy)
                par_scores[i] = np.mean(estimator.decision_function(X_row))
            else:
                par_scores[i] = np.mean(
                    cross_val_score(
                        estimator, X[:, row], y, cv=cv, n_jobs=1, **kwargs
                    )
                )

        if verbose > 0 and i % step == 0:
            percent = round(float(i) / n_rows * 100, 2)
            dt = time.time() - t0
            # We use a max to avoid a division by zero
            remaining = (100.0 - percent) / max(0.01, percent) * dt
            logger.log(
                f"Job #{thread_id}, processed {i}/{n_rows} steps "
                f"({percent:0.2f}%, "
                f"{remaining:0.1f} seconds remaining){crlf}",
            )
    return par_scores
235##############################################################################
236# Class for search_light #####################################################
237##############################################################################
@fill_doc
class SearchLight(TransformerMixin, BaseEstimator):
    """Implement search_light analysis using an arbitrary type of classifier.

    Parameters
    ----------
    mask_img : Niimg-like object or None,
        See :ref:`extracting_data`.
        Boolean image giving location of voxels containing usable signals.

    process_mask_img : Niimg-like object, optional
        See :ref:`extracting_data`.
        Boolean image giving voxels on which searchlight should be
        computed. If not given, ``mask_img`` is used instead.

    radius : :obj:`float`, default=2.
        radius of the searchlight ball, in millimeters.

    estimator : 'svr', 'svc', or an estimator object implementing 'fit'
        The object to use to fit the data

    %(n_jobs)s

    scoring : :obj:`str` or callable, optional
        The scoring strategy to use. See the scikit-learn documentation
        If callable, takes as arguments the fitted estimator, the
        test data (X_test) and the test target (y_test) if y is
        not None.

    cv : cross-validation generator, optional
        A cross-validation generator. If None, scikit-learn's default
        5-fold cross-validation is used; it is stratified when the
        estimator is a classifier and ``y`` is binary or multiclass.

    %(verbose0)s

    Attributes
    ----------
    scores_ : numpy.ndarray
        3D array containing searchlight scores for each voxel, aligned
        with the mask.

        .. versionadded:: 0.11.0

    process_mask_ : numpy.ndarray
        Boolean mask array representing the voxels included in the
        searchlight computation.

        .. versionadded:: 0.11.0

    masked_scores_ : numpy.ndarray
        1D array containing the searchlight scores corresponding
        to the masked region only.

        .. versionadded:: 0.11.0

    Notes
    -----
    The searchlight [Kriegeskorte 06] is a widely used approach for the
    study of the fine-grained patterns of information in fMRI analysis.
    Its principle is relatively simple: a small group of neighboring
    features is extracted from the data, and the prediction function is
    instantiated on these features only. The resulting prediction
    accuracy is thus associated with all the features within the group,
    or only with the feature on the center. This yields a map of local
    fine-grained information, that can be used for assessing hypothesis
    on the local spatial layout of the neural code under investigation.

    Nikolaus Kriegeskorte, Rainer Goebel & Peter Bandettini.
    Information-based functional brain mapping.
    Proceedings of the National Academy of Sciences
    of the United States of America,
    vol. 103, no. 10, pages 3863-3868, March 2006
    """

    def __init__(
        self,
        mask_img=None,
        process_mask_img=None,
        radius=2.0,
        estimator="svc",
        n_jobs=1,
        scoring=None,
        cv=None,
        verbose=0,
    ):
        self.mask_img = mask_img
        self.process_mask_img = process_mask_img
        self.radius = radius
        self.estimator = estimator
        self.n_jobs = n_jobs
        self.scoring = scoring
        self.cv = cv
        self.verbose = verbose

    def _more_tags(self):
        """Return estimator tags.

        TODO remove when bumping sklearn_version > 1.5
        """
        return self.__sklearn_tags__()

    def __sklearn_tags__(self):
        """Return estimator tags.

        See the sklearn documentation for more details on tags
        https://scikit-learn.org/1.6/developers/develop.html#estimator-tags
        """
        # TODO
        # get rid of if block
        # bumping sklearn_version > 1.5
        if SKLEARN_LT_1_6:
            from nilearn._utils.tags import tags

            return tags()

        from sklearn.utils import ClassifierTags, RegressorTags

        from nilearn._utils.tags import InputTags

        tags = super().__sklearn_tags__()
        tags.input_tags = InputTags(surf_img=True)

        # Only the string shortcuts carry a known estimator type.
        if self.estimator == "svr":
            tags.estimator_type = "regressor"
            tags.regressor_tags = RegressorTags()
        elif self.estimator == "svc":
            tags.estimator_type = "classifier"
            tags.classifier_tags = ClassifierTags()

        return tags

    @property
    def _estimator_type(self):
        # TODO rm sklearn>=1.6
        if self.estimator == "svr":
            return "regressor"
        elif self.estimator == "svc":
            return "classifier"
        return ""

    def _make_estimator(self):
        """Return the estimator to fit, instantiating catalog shortcuts.

        ``"svc"`` builds a ``LinearSVC`` with ``dual=True``; any other string
        is looked up in ``ESTIMATOR_CATALOG``; estimator objects are
        returned unchanged.
        """
        estimator = self.estimator
        if estimator == "svc":
            return ESTIMATOR_CATALOG[estimator](dual=True)
        if isinstance(estimator, str):
            return ESTIMATOR_CATALOG[estimator]()
        return estimator

    @staticmethod
    def _world_coords(mask, affine):
        """Return the world coordinates of the non-zero voxels of ``mask``.

        The result has one row per voxel (in ``np.where`` order) and three
        columns, obtained by applying ``affine`` to the voxel indices.
        """
        i, j, k = np.where(mask != 0)
        return np.asarray(coord_transform(i, j, k, affine)).T

    def fit(self, imgs, y, groups=None):
        """Fit the searchlight.

        Parameters
        ----------
        imgs : Niimg-like object
            See :ref:`extracting_data`.
            4D image.

        y : 1D array-like
            Target variable to predict. Must have exactly as many elements as
            3D images in img.

        groups : array-like, default=None
            group label for each sample for cross validation. Must have
            exactly as many elements as 3D images in img.
        """
        check_params(self.__dict__)

        # check if image is 4D
        imgs = check_niimg_4d(imgs)

        check_array(y, ensure_2d=False, dtype=None)

        # Get the seeds
        self.mask_img_ = deepcopy(self.mask_img)
        if self.mask_img_ is not None:
            self.mask_img_ = check_niimg_3d(self.mask_img_)
        process_mask_img = self.process_mask_img or self.mask_img_

        # Compute world coordinates of the seeds
        process_mask, process_mask_affine = masking.load_mask_img(
            process_mask_img
        )
        self.process_mask_ = process_mask
        process_mask_coords = self._world_coords(
            process_mask, process_mask_affine
        )

        X, A = apply_mask_and_get_affinity(
            process_mask_coords,
            imgs,
            self.radius,
            True,
            mask_img=self.mask_img_,
        )

        scores = search_light(
            X,
            y,
            self._make_estimator(),
            A,
            groups,
            self.scoring,
            self.cv,
            self.n_jobs,
            self.verbose,
        )
        self.masked_scores_ = scores
        self.scores_ = np.zeros(process_mask.shape)
        self.scores_[np.where(process_mask)] = scores
        return self

    def __sklearn_is_fitted__(self):
        return (
            hasattr(self, "scores_")
            and hasattr(self, "process_mask_")
            and hasattr(self, "mask_img_")
            and self.scores_ is not None
            and self.process_mask_ is not None
        )

    @property
    def scores_img_(self):
        """Convert the 3D scores array into a NIfTI image."""
        check_is_fitted(self)
        return new_img_like(self.mask_img_, self.scores_)

    def transform(self, imgs):
        """Apply the fitted searchlight on new images.

        Each searchlight sphere of the fitted process mask is scored on
        ``imgs`` without a target: the estimator is fit on a dummy two-class
        target and the mean decision function is kept. The absolute values
        are returned as an array with the shape of ``process_mask_``.
        """
        check_is_fitted(self)

        imgs = check_niimg_4d(imgs)

        # Seeds must be given in world coordinates, exactly as in ``fit``:
        # map the fitted process-mask voxels through the process mask affine
        # (raw voxel indices would only be correct for an identity affine).
        _, process_mask_affine = masking.load_mask_img(
            self.process_mask_img or self.mask_img_
        )
        X, A = apply_mask_and_get_affinity(
            self._world_coords(self.process_mask_, process_mask_affine),
            imgs,
            self.radius,
            True,
            mask_img=self.mask_img_,
        )

        # y=None (and no groups) selects the dummy-target scoring path.
        result = search_light(
            X,
            None,
            self._make_estimator(),
            A,
            None,
            self.scoring,
            self.cv,
            self.n_jobs,
            self.verbose,
        )

        reshaped_result = np.zeros(self.process_mask_.shape)
        reshaped_result[np.where(self.process_mask_)] = result
        return np.abs(reshaped_result)

    def set_output(self, *, transform=None):
        """Set the output container when ``"transform"`` is called.

        .. warning::

            This has not been implemented yet.
        """
        raise NotImplementedError()