Coverage for nilearn/maskers/multi_nifti_masker.py: 18%

128 statements  

coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

"""Transformer used to apply basic transformations \
on multi-subject MRI data.
"""

import collections.abc
import itertools
import warnings
from functools import partial

import numpy as np
from joblib import Parallel, delayed
from sklearn.utils.estimator_checks import check_is_fitted

from nilearn._utils import (
    fill_doc,
    logger,
    repr_niimgs,
    stringify_path,
)
from nilearn._utils.class_inspect import (
    get_params,
)
from nilearn._utils.logger import find_stack_level
from nilearn._utils.niimg_conversions import iter_check_niimg
from nilearn._utils.param_validation import check_params
from nilearn._utils.tags import SKLEARN_LT_1_6
from nilearn.image import (
    resample_img,
)
from nilearn.maskers._utils import compute_middle_image
from nilearn.maskers.base_masker import prepare_confounds_multimaskers
from nilearn.maskers.nifti_masker import NiftiMasker, filter_and_mask
from nilearn.masking import (
    compute_multi_background_mask,
    compute_multi_brain_mask,
    compute_multi_epi_mask,
    load_mask_img,
)
from nilearn.typing import NiimgLike


def _get_mask_strategy(strategy):
    """Return the mask computing method based on a provided strategy."""
    if strategy == "background":
        return compute_multi_background_mask
    elif strategy == "epi":
        return compute_multi_epi_mask
    elif strategy == "whole-brain-template":
        return partial(compute_multi_brain_mask, mask_type="whole-brain")
    elif strategy == "gm-template":
        return partial(compute_multi_brain_mask, mask_type="gm")
    elif strategy == "wm-template":
        return partial(compute_multi_brain_mask, mask_type="wm")
    elif strategy == "template":
        warnings.warn(
            "Masking strategy 'template' is deprecated. "
            "Please use 'whole-brain-template' instead.",
            stacklevel=find_stack_level(),
        )
        return partial(compute_multi_brain_mask, mask_type="whole-brain")
    else:
        raise ValueError(
            f"Unknown value of mask_strategy '{strategy}'. "
            "Acceptable values are 'background', "
            "'epi', 'whole-brain-template', "
            "'gm-template', and 'wm-template'."
        )
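
# Illustrative sketch (not part of the module): the callable returned by
# _get_mask_strategy is applied to a list of subject images, roughly as in
# MultiNiftiMasker.fit below, e.g.
#
#     compute_mask = _get_mask_strategy("epi")
#     group_mask = compute_mask(subject_imgs, target_affine=None, n_jobs=1)
#
# compute_multi_epi_mask and compute_multi_background_mask derive one mask
# per subject and combine them into a single group mask, while the
# "*-template" strategies resample a template-based mask to the data; see
# nilearn.masking for details.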

@fill_doc
class MultiNiftiMasker(NiftiMasker):
    """Apply a mask to extract time-series from multiple Niimg-like objects.

    MultiNiftiMasker is useful when dealing with image sets from multiple
    subjects.

    Use case:
    integrates well with decomposition by MultiPCA and CanICA
    (multi-subject models).

    Parameters
    ----------
    mask_img : Niimg-like object
        See :ref:`extracting_data`.
        Mask of the data. If not given, a mask is computed in the fit step.
        Optional parameters can be set using mask_args and mask_strategy to
        fine-tune the mask extraction.

    %(smoothing_fwhm)s

    %(standardize_maskers)s

    %(standardize_confounds)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(detrend)s

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(target_affine)s

        .. note::
            This parameter is passed to :func:`nilearn.image.resample_img`.

    %(target_shape)s

        .. note::
            This parameter is passed to :func:`nilearn.image.resample_img`.

    %(mask_strategy)s

        .. note::
            Depending on this value, the mask will be computed from
            :func:`nilearn.masking.compute_multi_background_mask`,
            :func:`nilearn.masking.compute_multi_epi_mask`, or
            :func:`nilearn.masking.compute_multi_brain_mask`.

        Default='background'.

    mask_args : :obj:`dict`, optional
        If mask is None, these are additional parameters passed to
        the mask computation function selected by ``mask_strategy``
        (for example :func:`nilearn.masking.compute_multi_background_mask`
        or :func:`nilearn.masking.compute_multi_epi_mask`)
        to fine-tune mask computation.
        Please see the related documentation for details.

    %(dtype)s

    %(memory)s

    %(memory_level)s

    %(n_jobs)s

    %(verbose0)s

    %(clean_args)s

    %(masker_kwargs)s

    Attributes
    ----------
    mask_img_ : A 3D binary :obj:`nibabel.nifti1.Nifti1Image`
        The mask of the data, or the one computed from ``imgs`` passed to fit.
        If a ``mask_img`` is passed at masker construction,
        then ``mask_img_`` is the resulting binarized version of it
        where each voxel is ``True`` if all values across samples
        (for example across timepoints) are finite and different from 0.

    affine_ : 4x4 :obj:`numpy.ndarray`
        Affine of the transformed image.

    n_elements_ : :obj:`int`
        The number of voxels in the mask.

        .. versionadded:: 0.9.2

    See Also
    --------
    nilearn.image.resample_img: image resampling
    nilearn.masking.compute_epi_mask: mask computation
    nilearn.masking.apply_mask: mask application on image
    nilearn.signal.clean: confounds removal and general filtering of signals

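    Examples
    --------
    A minimal, illustrative sketch; the file names below are hypothetical
    placeholders for per-subject 4D images::

        from nilearn.maskers import MultiNiftiMasker

        masker = MultiNiftiMasker(smoothing_fwhm=6, n_jobs=2)
        masker.fit(["sub-01_bold.nii.gz", "sub-02_bold.nii.gz"])
        # transform returns one (n_timepoints, n_voxels) array per subject
        signals = masker.transform(
            ["sub-01_bold.nii.gz", "sub-02_bold.nii.gz"]
        )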

    """

    def __init__(
        self,
        mask_img=None,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        detrend=False,
        high_variance_confounds=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        target_affine=None,
        target_shape=None,
        mask_strategy="background",
        mask_args=None,
        dtype=None,
        memory=None,
        memory_level=0,
        n_jobs=1,
        verbose=0,
        cmap="CMRmap_r",
        clean_args=None,
        **kwargs,  # TODO remove when bumping to nilearn >0.13
    ):
        super().__init__(
            # Mask is provided or computed
            mask_img=mask_img,
            smoothing_fwhm=smoothing_fwhm,
            standardize=standardize,
            standardize_confounds=standardize_confounds,
            high_variance_confounds=high_variance_confounds,
            detrend=detrend,
            low_pass=low_pass,
            high_pass=high_pass,
            t_r=t_r,
            target_affine=target_affine,
            target_shape=target_shape,
            mask_strategy=mask_strategy,
            mask_args=mask_args,
            dtype=dtype,
            memory=memory,
            memory_level=memory_level,
            verbose=verbose,
            cmap=cmap,
            clean_args=clean_args,
            # TODO remove when bumping to nilearn >0.13
            **kwargs,
        )
        self.n_jobs = n_jobs

    def __sklearn_tags__(self):
        """Return estimator tags.

        See the sklearn documentation for more details on tags
        https://scikit-learn.org/1.6/developers/develop.html#estimator-tags
        """
        # TODO
        # get rid of if block
        # bumping sklearn_version > 1.5
        if SKLEARN_LT_1_6:
            from nilearn._utils.tags import tags

            return tags(masker=True, multi_masker=True)

        from nilearn._utils.tags import InputTags

        tags = super().__sklearn_tags__()
        tags.input_tags = InputTags(masker=True, multi_masker=True)
        return tags

    @fill_doc
    def fit(
        self,
        imgs=None,
        y=None,
    ):
        """Compute the mask corresponding to the data.

        Parameters
        ----------
        imgs : Niimg-like objects, :obj:`list` of Niimg-like objects or None, \
               default=None
            See :ref:`extracting_data`.
            Data on which the mask must be calculated.
            If this is a list, the affine is considered the same for all.

        %(y_dummy)s

        """
        del y
        check_params(self.__dict__)
        if getattr(self, "_shelving", None) is None:
            self._shelving = False

        self._report_content = {
            "description": (
                "This report shows the input Nifti image overlaid "
                "with the outlines of the mask (in green). We "
                "recommend inspecting the report for the overlap "
                "between the mask and its input image. "
            ),
            "warning_message": None,
            "n_elements": 0,
            "coverage": 0,
        }
        self._overlay_text = (
            "\n To see the input Nifti image before resampling, "
            "hover over the displayed image."
        )

        self._sanitize_cleaning_parameters()
        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        self.mask_img_ = self._load_mask(imgs)

        if imgs is not None:
            logger.log(
                f"Loading data from {repr_niimgs(imgs, shorten=False)}.",
                self.verbose,
            )

        # Compute the mask if not given by the user
        if self.mask_img_ is None:
            if imgs is None:
                raise ValueError(
                    "Parameter 'imgs' must be provided to "
                    f"{self.__class__.__name__}.fit() "
                    "if no mask is passed to mask_img."
                )

            logger.log("Computing mask", self.verbose)

            imgs = stringify_path(imgs)
            if not isinstance(imgs, collections.abc.Iterable) or isinstance(
                imgs, str
            ):
                imgs = [imgs]

            mask_args = self.mask_args if self.mask_args is not None else {}
            compute_mask = _get_mask_strategy(self.mask_strategy)
            self.mask_img_ = self._cache(
                compute_mask,
                ignore=["n_jobs", "verbose", "memory"],
            )(
                imgs,
                target_affine=self.target_affine,
                target_shape=self.target_shape,
                n_jobs=self.n_jobs,
                memory=self.memory,
                verbose=max(0, self.verbose - 1),
                **mask_args,
            )
        elif imgs is not None:
            warnings.warn(
                f"[{self.__class__.__name__}.fit] "
                "Generation of a mask has been requested (imgs != None) "
                "while a mask was given at masker creation. "
                "Given mask will be used.",
                stacklevel=find_stack_level(),
            )

        self._reporting_data = None
        if self.reports:  # save inputs for reporting
            self._reporting_data = {
                "mask": self.mask_img_,
                "dim": None,
                "images": imgs,
            }
            if imgs is not None:
                imgs, dims = compute_middle_image(imgs)
                self._reporting_data["images"] = imgs
                self._reporting_data["dim"] = dims

        # If resampling is requested, resample the mask as well.
        # Resampling: allows the user to change the affine, the shape or both.
        logger.log("Resampling mask")

        # TODO switch to force_resample=True
        # when bumping to version > 0.13
        self.mask_img_ = self._cache(resample_img)(
            self.mask_img_,
            target_affine=self.target_affine,
            target_shape=self.target_shape,
            interpolation="nearest",
            copy=False,
            copy_header=True,
            force_resample=False,
        )

        if self.target_affine is not None:
            self.affine_ = self.target_affine
        else:
            self.affine_ = self.mask_img_.affine

        # Load data in memory, while also checking that mask is binary/valid
        data, _ = load_mask_img(self.mask_img_, allow_empty=True)

        # Infer the number of elements (voxels) in the mask
        self.n_elements_ = int(data.sum())
        self._report_content["n_elements"] = self.n_elements_
        self._report_content["coverage"] = (
            self.n_elements_ / np.prod(data.shape) * 100
        )

        if (self.target_shape is not None) or (
            (self.target_affine is not None) and self.reports
        ):
            resampl_imgs = None
            if imgs is not None:
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                resampl_imgs = self._cache(resample_img)(
                    imgs,
                    target_affine=self.affine_,
                    copy=False,
                    interpolation="nearest",
                    copy_header=True,
                    force_resample=False,
                )

            self._reporting_data["transform"] = [resampl_imgs, self.mask_img_]

        return self

    @fill_doc
    def transform_imgs(
        self, imgs_list, confounds=None, sample_mask=None, copy=True, n_jobs=1
    ):
        """Prepare multi-subject data in parallel.

        Parameters
        ----------
        %(imgs)s
            Images to process.

        %(confounds_multi)s

        %(sample_mask_multi)s

            .. versionadded:: 0.8.0

        copy : :obj:`bool`, default=True
            If True, guarantees that output array has no memory in common with
            input array.

        %(n_jobs)s

        Returns
        -------
        %(signals_transform_imgs_multi_nifti)s

        """
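        # Overall flow: check that the masker is fitted, load the subject
        # images while enforcing a consistent field of view, align confounds
        # and sample_mask entries with the image list, then filter and mask
        # each subject independently, in parallel via joblib.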

        check_is_fitted(self)

        target_fov = "first" if self.target_affine is None else None
        niimg_iter = iter_check_niimg(
            imgs_list,
            ensure_ndim=None,
            atleast_4d=False,
            target_fov=target_fov,
            memory=self.memory,
            memory_level=self.memory_level,
        )

        confounds = prepare_confounds_multimaskers(self, imgs_list, confounds)

        if sample_mask is None:
            sample_mask = itertools.repeat(None, len(imgs_list))
        elif len(sample_mask) != len(imgs_list):
            raise ValueError(
                f"number of sample_mask ({len(sample_mask)}) unequal to "
                f"number of images ({len(imgs_list)})."
            )

        # Ignore the mask-computing params: they are not useful and will
        # just invalidate the cache for no good reason
        # target_shape and target_affine are conveyed implicitly in mask_img
        params = get_params(
            self.__class__,
            self,
            ignore=[
                "mask_img",
                "mask_args",
                "mask_strategy",
                "copy",
            ],
        )
        params["clean_kwargs"] = self.clean_args_
        # TODO remove in 0.13.2
        if self.clean_kwargs:
            params["clean_kwargs"] = self.clean_kwargs_

        func = self._cache(
            filter_and_mask,
            ignore=[
                "verbose",
                "memory",
                "memory_level",
                "copy",
            ],
            shelve=self._shelving,
        )
        data = Parallel(n_jobs=n_jobs)(
            delayed(func)(
                imgs,
                self.mask_img_,
                params,
                memory_level=self.memory_level,
                memory=self.memory,
                verbose=self.verbose,
                confounds=cfs,
                copy=copy,
                dtype=self.dtype,
                sample_mask=sms,
            )
            for imgs, cfs, sms in zip(niimg_iter, confounds, sample_mask)
        )
        return data

    @fill_doc
    def transform(self, imgs, confounds=None, sample_mask=None):
        """Apply mask, spatial and temporal preprocessing.

        Parameters
        ----------
        imgs : Niimg-like object, or a :obj:`list` of Niimg-like objects
            See :ref:`extracting_data`.
            Data to be preprocessed.

        %(confounds_multi)s

        %(sample_mask_multi)s

            .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_multi_nifti)s

        """
        check_is_fitted(self)

        if not (confounds is None or isinstance(confounds, list)):
            raise TypeError(
                "'confounds' must be None or a list. "
                f"Got {confounds.__class__.__name__}."
            )
        if not (sample_mask is None or isinstance(sample_mask, list)):
            raise TypeError(
                "'sample_mask' must be None or a list. "
                f"Got {sample_mask.__class__.__name__}."
            )
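        # A single image (or path) is delegated to the single-subject
        # NiftiMasker.transform; a list of images is processed entry by
        # entry, one per subject, via transform_imgs.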

        if isinstance(imgs, NiimgLike):
            if isinstance(confounds, list):
                confounds = confounds[0]
            if isinstance(sample_mask, list):
                sample_mask = sample_mask[0]
            return super().transform(
                imgs, confounds=confounds, sample_mask=sample_mask
            )

        return self.transform_imgs(
            imgs,
            confounds=confounds,
            sample_mask=sample_mask,
            n_jobs=self.n_jobs,
        )

    @fill_doc
    def fit_transform(self, imgs, y=None, confounds=None, sample_mask=None):
        """
        Fit to data, then transform it.

        Parameters
        ----------
        imgs : Niimg-like object, or a :obj:`list` of Niimg-like objects
            See :ref:`extracting_data`.
            Data to be preprocessed.

        y : None
            This parameter is unused. It is solely included for scikit-learn
            compatibility.

        %(confounds_multi)s

        %(sample_mask_multi)s

            .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_multi_nifti)s
        """
        return self.fit(imgs, y=y).transform(
            imgs, confounds=confounds, sample_mask=sample_mask
        )
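

# Illustrative usage sketch (not part of the module; file and confound names
# are hypothetical placeholders):
#
#     masker = MultiNiftiMasker(smoothing_fwhm=6, n_jobs=2)
#     signals = masker.fit_transform(
#         ["sub-01_bold.nii.gz", "sub-02_bold.nii.gz"],
#         confounds=["sub-01_confounds.tsv", "sub-02_confounds.tsv"],
#     )
#     # one (n_timepoints, n_voxels) array per subject in `signals`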