Coverage for nilearn/decoding/searchlight.py: 22%

"""The searchlight is a widely used approach for the study \
of the fine-grained patterns of information in fMRI analysis, \
in which multivariate statistical relationships are iteratively tested \
in the neighborhood of each location of a domain.
"""

import time
import warnings
from copy import deepcopy

import numpy as np
from joblib import Parallel, cpu_count, delayed
from sklearn import svm
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import KFold, cross_val_score
from sklearn.utils import check_array
from sklearn.utils.estimator_checks import check_is_fitted

from nilearn._utils import check_niimg_3d, check_niimg_4d, fill_doc, logger
from nilearn._utils.param_validation import check_params
from nilearn._utils.tags import SKLEARN_LT_1_6
from nilearn.image import new_img_like
from nilearn.maskers.nifti_spheres_masker import apply_mask_and_get_affinity

from .. import masking
from ..image.resampling import coord_transform

ESTIMATOR_CATALOG = {"svc": svm.LinearSVC, "svr": svm.SVR}


@fill_doc
def search_light(
    X,
    y,
    estimator,
    A,
    groups=None,
    scoring=None,
    cv=None,
    n_jobs=-1,
    verbose=0,
):
    """Compute a search_light.

    Parameters
    ----------
    X : array-like of shape at least 2D
        data to fit.

    y : array-like
        target variable to predict.

    estimator : estimator object implementing 'fit'
        object to use to fit the data

    A : scipy sparse matrix
        adjacency matrix. Defines for each feature the neighboring features
        following a given structure of the data.

    groups : array-like, default=None
        group label for each sample for cross validation.

    scoring : :obj:`str` or callable or None, default=None
        The scoring strategy to use. See the scikit-learn documentation
        for possible values.
        If callable, it takes as arguments the fitted estimator, the
        test data (X_test) and the test target (y_test) if y is
        not None.

    cv : cross-validation generator, default=None
        A cross-validation generator. If None, a 3-fold cross
        validation is used or 3-fold stratified cross-validation
        when y is supplied.

    %(n_jobs_all)s

    %(verbose0)s

    Returns
    -------
    scores : array-like of shape (number of rows in A)
        search_light scores
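
    Examples
    --------
    A minimal, illustrative sketch on synthetic data. ``A`` is assumed to be
    a :class:`scipy.sparse.lil_matrix` (so that its ``rows`` attribute is
    available), and each "sphere" here contains a single feature:

    >>> import numpy as np
    >>> from scipy import sparse
    >>> from sklearn.svm import LinearSVC
    >>> rng = np.random.RandomState(0)
    >>> X = rng.randn(20, 5)
    >>> y = np.tile([0, 1], 10)
    >>> A = sparse.lil_matrix(np.eye(5))
    >>> scores = search_light(X, y, LinearSVC(dual=True), A, n_jobs=1)
    >>> scores.shape
    (5,)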

    """
    group_iter = GroupIterator(A.shape[0], n_jobs)
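    # Split the seeds into one chunk of adjacency rows per job; each worker
    # scores its chunk independently and the per-seed scores are
    # concatenated below.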

    scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(_group_iter_search_light)(
            A.rows[list_i],
            estimator,
            X,
            y,
            groups,
            scoring,
            cv,
            thread_id + 1,
            A.shape[0],
            verbose,
        )
        for thread_id, list_i in enumerate(group_iter)
    )
    return np.concatenate(scores)


@fill_doc
class GroupIterator:
    """Group iterator.

    Provides group of features for search_light loop
    that may be used with Parallel.

    Parameters
    ----------
    n_features : :obj:`int`
        Total number of features
    %(n_jobs)s
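
    Examples
    --------
    Illustrative sketch of how features are split across jobs (the exact
    chunk sizes are whatever :func:`numpy.array_split` produces):

    >>> for chunk in GroupIterator(10, n_jobs=3):
    ...     print(chunk)
    [0 1 2 3]
    [4 5 6]
    [7 8 9]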

    """

    def __init__(self, n_features, n_jobs=1):
        self.n_features = n_features
        if n_jobs == -1:
            n_jobs = cpu_count()
        self.n_jobs = n_jobs
        check_params(self.__dict__)

    def __iter__(self):
        yield from np.array_split(np.arange(self.n_features), self.n_jobs)


def _group_iter_search_light(
    list_rows,
    estimator,
    X,
    y,
    groups,
    scoring,
    cv,
    thread_id,
    total,
    verbose=0,
):
    """Perform grouped iterations of search_light.

    Parameters
    ----------
    list_rows : array of arrays of int
        adjacency rows. For a voxel with index i in X, list_rows[i] is the
        list of neighboring voxels indices (in X).

    estimator : estimator object implementing 'fit'
        object to use to fit the data

    X : array-like of shape at least 2D
        data to fit.

    y : array-like or None
        Target variable to predict. If `y` is provided, it must be an
        array-like object with the same length as the number of samples
        in `X`. When `y` is `None`, a dummy target is generated internally
        with half the samples labeled as `0` and the other half labeled
        as `1`. This is useful during transformations where the model is
        applied without ground truth labels.

    groups : array-like, optional
        group label for each sample for cross validation.

    scoring : string or callable, optional
        Scoring strategy to use. See the scikit-learn documentation.
        If callable, takes as arguments the fitted estimator, the
        test data (X_test) and the test target (y_test) if y is
        not None.

    cv : cross-validation generator, optional
        A cross-validation generator. If None, a 3-fold cross validation is
        used or 3-fold stratified cross-validation when y is supplied.

    thread_id : int
        process id, used for display.

    total : int
        Total number of voxels, used for display

    %(verbose0)s

    Returns
    -------
    par_scores : numpy.ndarray
        score for each voxel. dtype: float64.
    """
    par_scores = np.zeros(len(list_rows))
    t0 = time.time()
    for i, row in enumerate(list_rows):
        kwargs = {"scoring": scoring, "groups": groups}
        if isinstance(cv, KFold):
            kwargs = {"scoring": scoring}

        with warnings.catch_warnings():  # might not converge
            warnings.simplefilter("ignore", ConvergenceWarning)
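            # No target was supplied (transform-time use): fit this sphere on
            # a balanced dummy target and report the mean decision value
            # instead of a cross-validated score.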

            if y is None:
                y_dummy = np.array(
                    [0] * (X.shape[0] // 2) + [1] * (X.shape[0] // 2)
                )
                estimator.fit(
                    X[:, row], y_dummy[: X.shape[0]]
                )  # Ensure the size matches X
                par_scores[i] = np.mean(estimator.decision_function(X[:, row]))
            else:
                par_scores[i] = np.mean(
                    cross_val_score(
                        estimator, X[:, row], y, cv=cv, n_jobs=1, **kwargs
                    )
                )

        if verbose > 0:
            # Progress is printed at least once every 10 iterations;
            # higher verbosity prints more often.
            step = 11 - min(verbose, 10)
            if i % step == 0:
                # With a single job, use a carriage return so the progress
                # line is overwritten in place.
                crlf = "\r" if total == len(list_rows) else "\n"
                percent = float(i) / len(list_rows)
                percent = round(percent * 100, 2)
                dt = time.time() - t0
                # We use a max to avoid a division by zero
                remaining = (100.0 - percent) / max(0.01, percent) * dt
                logger.log(
                    f"Job #{thread_id}, processed {i}/{len(list_rows)} steps "
                    f"({percent:0.2f}%, "
                    f"{remaining:0.1f} seconds remaining){crlf}",
                )
    return par_scores


##############################################################################
# Class for search_light #####################################################
##############################################################################
@fill_doc
class SearchLight(TransformerMixin, BaseEstimator):
    """Implement search_light analysis using an arbitrary type of classifier.

    Parameters
    ----------
    mask_img : Niimg-like object or None
        See :ref:`extracting_data`.
        Boolean image giving location of voxels containing usable signals.

    process_mask_img : Niimg-like object, optional
        See :ref:`extracting_data`.
        Boolean image giving voxels on which searchlight should be
        computed.

    radius : :obj:`float`, default=2.
        radius of the searchlight ball, in millimeters.

    estimator : 'svr', 'svc', or an estimator object implementing 'fit'
        The object to use to fit the data

    %(n_jobs)s

    scoring : :obj:`str` or callable, optional
        The scoring strategy to use. See the scikit-learn documentation.
        If callable, takes as arguments the fitted estimator, the
        test data (X_test) and the test target (y_test) if y is
        not None.

    cv : cross-validation generator, optional
        A cross-validation generator. If None, a 3-fold cross
        validation is used or 3-fold stratified cross-validation
        when y is supplied.

    %(verbose0)s

    Attributes
    ----------
    scores_ : numpy.ndarray
        3D array containing searchlight scores for each voxel, aligned
        with the mask.

        .. versionadded:: 0.11.0

    process_mask_ : numpy.ndarray
        Boolean mask array representing the voxels included in the
        searchlight computation.

        .. versionadded:: 0.11.0

    masked_scores_ : numpy.ndarray
        1D array containing the searchlight scores corresponding
        to the masked region only.

        .. versionadded:: 0.11.0

    Notes
    -----
    The searchlight [Kriegeskorte 06] is a widely used approach for the
    study of the fine-grained patterns of information in fMRI analysis.
    Its principle is relatively simple: a small group of neighboring
    features is extracted from the data, and the prediction function is
    instantiated on these features only. The resulting prediction
    accuracy is thus associated with all the features within the group,
    or only with the feature on the center. This yields a map of local
    fine-grained information that can be used for assessing hypotheses
    on the local spatial layout of the neural code under investigation.

    Nikolaus Kriegeskorte, Rainer Goebel & Peter Bandettini.
    Information-based functional brain mapping.
    Proceedings of the National Academy of Sciences
    of the United States of America,
    vol. 103, no. 10, pages 3863-3868, March 2006
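
    Examples
    --------
    A minimal, illustrative sketch of the intended workflow (``mask_img``,
    ``fmri_img`` and ``y`` are placeholders for the user's own data):

    >>> from nilearn.decoding import SearchLight  # doctest: +SKIP
    >>> searchlight = SearchLight(mask_img, radius=5.0, n_jobs=2)  # doctest: +SKIP
    >>> searchlight.fit(fmri_img, y)  # doctest: +SKIP
    >>> searchlight.scores_img_.to_filename("scores.nii.gz")  # doctest: +SKIP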

    """

    def __init__(
        self,
        mask_img=None,
        process_mask_img=None,
        radius=2.0,
        estimator="svc",
        n_jobs=1,
        scoring=None,
        cv=None,
        verbose=0,
    ):
        self.mask_img = mask_img
        self.process_mask_img = process_mask_img
        self.radius = radius
        self.estimator = estimator
        self.n_jobs = n_jobs
        self.scoring = scoring
        self.cv = cv
        self.verbose = verbose

    def _more_tags(self):
        """Return estimator tags.

        TODO remove when bumping sklearn_version > 1.5
        """
        return self.__sklearn_tags__()

    def __sklearn_tags__(self):
        """Return estimator tags.

        See the sklearn documentation for more details on tags
        https://scikit-learn.org/1.6/developers/develop.html#estimator-tags
        """
        # TODO: get rid of this if block when bumping sklearn_version > 1.5
        if SKLEARN_LT_1_6:
            from nilearn._utils.tags import tags

            return tags()

        from sklearn.utils import ClassifierTags, RegressorTags

        from nilearn._utils.tags import InputTags

        tags = super().__sklearn_tags__()
        tags.input_tags = InputTags(surf_img=True)

        if self.estimator == "svr":
            if SKLEARN_LT_1_6:
                tags["multioutput"] = True
                return tags
            tags.estimator_type = "regressor"
            tags.regressor_tags = RegressorTags()

        elif self.estimator == "svc":
            if SKLEARN_LT_1_6:
                return tags
            tags.estimator_type = "classifier"
            tags.classifier_tags = ClassifierTags()

        return tags

    @property
    def _estimator_type(self):
        # TODO rm sklearn>=1.6
        if self.estimator == "svr":
            return "regressor"
        elif self.estimator == "svc":
            return "classifier"
        return ""

    def fit(self, imgs, y, groups=None):
        """Fit the searchlight.

        Parameters
        ----------
        imgs : Niimg-like object
            See :ref:`extracting_data`.
            4D image.

        y : 1D array-like
            Target variable to predict. Must have exactly as many elements as
            3D images in img.

        groups : array-like, default=None
            group label for each sample for cross validation. Must have
            exactly as many elements as 3D images in img.
        """
        check_params(self.__dict__)

        # check if image is 4D
        imgs = check_niimg_4d(imgs)

        check_array(y, ensure_2d=False, dtype=None)

        # Get the seeds
        self.mask_img_ = deepcopy(self.mask_img)
        if self.mask_img_ is not None:
            self.mask_img_ = check_niimg_3d(self.mask_img_)
        process_mask_img = self.process_mask_img or self.mask_img_

        # Compute world coordinates of the seeds
        process_mask, process_mask_affine = masking.load_mask_img(
            process_mask_img
        )

        self.process_mask_ = process_mask
        process_mask_coords = np.where(process_mask != 0)
        process_mask_coords = coord_transform(
            process_mask_coords[0],
            process_mask_coords[1],
            process_mask_coords[2],
            process_mask_affine,
        )
        process_mask_coords = np.asarray(process_mask_coords).T
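
        # X is the masked 2D data (n_samples, n_voxels within the mask) and A
        # is a sparse adjacency matrix whose row i lists the voxels that fall
        # within ``radius`` mm of seed i.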

        X, A = apply_mask_and_get_affinity(
            process_mask_coords,
            imgs,
            self.radius,
            True,
            mask_img=self.mask_img_,
        )

        estimator = self.estimator
        if estimator == "svc":
            estimator = ESTIMATOR_CATALOG[estimator](dual=True)
        elif isinstance(estimator, str):
            estimator = ESTIMATOR_CATALOG[estimator]()

        scores = search_light(
            X,
            y,
            estimator,
            A,
            groups,
            self.scoring,
            self.cv,
            self.n_jobs,
            self.verbose,
        )
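        # Keep the raw per-sphere scores and also unmask them into a 3D
        # volume aligned with the processing mask (zero outside the mask).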

        self.masked_scores_ = scores
        self.scores_ = np.zeros(process_mask.shape)
        self.scores_[np.where(process_mask)] = scores
        return self

    def __sklearn_is_fitted__(self):
        return (
            hasattr(self, "scores_")
            and hasattr(self, "process_mask_")
            and hasattr(self, "mask_img_")
            and self.scores_ is not None
            and self.process_mask_ is not None
        )

    @property
    def scores_img_(self):
        """Convert the 3D scores array into a NIfTI image."""
        check_is_fitted(self)
        return new_img_like(self.mask_img_, self.scores_)

    def transform(self, imgs):
        """Apply the fitted searchlight on new images."""
        check_is_fitted(self)

        imgs = check_niimg_4d(imgs)

        X, A = apply_mask_and_get_affinity(
            np.asarray(np.where(self.process_mask_)).T,
            imgs,
            self.radius,
            True,
            mask_img=self.mask_img_,
        )
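
        # The per-sphere estimators are re-fitted here: SearchLight only
        # stores scores and masks after ``fit``, not the fitted sphere
        # models themselves.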

        estimator = self.estimator
        if estimator == "svc":
            estimator = ESTIMATOR_CATALOG[estimator](dual=True)

        # Score with ``y=None`` so that ``_group_iter_search_light`` falls
        # back to its dummy-target path and no labels are required here.
        result = search_light(
            X,
            None,
            estimator,
            A,
            None,
            self.scoring,
            self.cv,
            self.n_jobs,
            self.verbose,
        )

        reshaped_result = np.zeros(self.process_mask_.shape)
        reshaped_result[np.where(self.process_mask_)] = result
        reshaped_result = np.abs(reshaped_result)

        return reshaped_result

    def set_output(self, *, transform=None):
        """Set the output container when ``"transform"`` is called.

        .. warning::

            This has not been implemented yet.
        """
        raise NotImplementedError()