Coverage for nilearn/maskers/nifti_maps_masker.py: 12%

210 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Transformer for computing ROI signals.""" 

2 

3import warnings 

4from copy import deepcopy 

5 

6import numpy as np 

7from sklearn.utils.estimator_checks import check_is_fitted 

8 

9from nilearn._utils import repr_niimgs 

10from nilearn._utils.class_inspect import get_params 

11from nilearn._utils.docs import fill_doc 

12from nilearn._utils.helpers import is_matplotlib_installed 

13from nilearn._utils.logger import find_stack_level, log 

14from nilearn._utils.niimg_conversions import check_niimg, check_same_fov 

15from nilearn._utils.param_validation import check_params 

16from nilearn.image import clean_img, get_data, index_img, resample_img 

17from nilearn.maskers._utils import compute_middle_image 

18from nilearn.maskers.base_masker import BaseMasker, filter_and_extract 

19from nilearn.masking import load_mask_img 

20 

21 

22class _ExtractionFunctor: 

23 func_name = "nifti_maps_masker_extractor" 

24 

25 def __init__(self, maps_img_, mask_img_, keep_masked_maps): 

26 self.maps_img_ = maps_img_ 

27 self.mask_img_ = mask_img_ 

28 self.keep_masked_maps = keep_masked_maps 

29 

30 def __call__(self, imgs): 

31 from ..regions import signal_extraction 

32 

33 return signal_extraction.img_to_signals_maps( 

34 imgs, 

35 self.maps_img_, 

36 mask_img=self.mask_img_, 

37 keep_masked_maps=self.keep_masked_maps, 

38 ) 

39 

40 

@fill_doc
class NiftiMapsMasker(BaseMasker):
    """Class for extracting data from Niimg-like objects \
    using maps of potentially overlapping brain regions.

    NiftiMapsMasker is useful when data from overlapping volumes should be
    extracted (contrarily to :class:`nilearn.maskers.NiftiLabelsMasker`).

    Use case:
        summarize brain signals from large-scale networks
        obtained by prior PCA or :term:`ICA`.

    .. note::
        Inf or NaN present in the given input images are automatically
        put to zero rather than considered as missing data.

    For more details on the definitions of maps in Nilearn,
    see the :ref:`region` section.

    Parameters
    ----------
    maps_img : 4D niimg-like object or None, default=None
        See :ref:`extracting_data`.
        Set of continuous maps. One representative time course per map is
        extracted using least square regression.

    mask_img : 3D niimg-like object, optional
        See :ref:`extracting_data`.
        Mask to apply to regions before extracting signals.

    allow_overlap : :obj:`bool`, default=True
        If False, an error is raised if the maps overlap (i.e. at least two
        maps have a non-zero value for the same voxel).

    %(smoothing_fwhm)s

    %(standardize_maskers)s

    %(standardize_confounds)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(detrend)s

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(dtype)s.

    resampling_target : {"data", "mask", "maps", None}, default="data"
        Gives which image gives the final shape/size. For example, if
        `resampling_target` is "mask" then maps_img and images provided to
        fit() are resampled to the shape and affine of mask_img. "None" means
        no resampling: if shapes and affines do not match, a ValueError is
        raised.

    %(memory)s

    %(memory_level)s

    %(verbose0)s

    %(keep_masked_maps)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="CMRmap_r"
        Only relevant for the report figures.

    %(clean_args)s
        .. versionadded:: 0.11.2dev

    %(masker_kwargs)s

    Attributes
    ----------
    maps_img_ : :obj:`nibabel.nifti1.Nifti1Image`
        The maps mask of the data.

    %(nifti_mask_img_)s

    n_elements_ : :obj:`int`
        The number of maps.
        This is equivalent to the number of volumes in the maps image.

        .. versionadded:: 0.9.2

    Notes
    -----
    If resampling_target is set to "maps", every 3D image processed by
    transform() will be resampled to the shape of maps_img. It may lead to a
    very large memory consumption if the voxel number in maps_img is large.

    See Also
    --------
    nilearn.maskers.NiftiMasker
    nilearn.maskers.NiftiLabelsMasker

    """

    # memory and memory_level are used by CacheMixin.

    def __init__(
        self,
        maps_img=None,
        mask_img=None,
        allow_overlap=True,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        high_variance_confounds=False,
        detrend=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        dtype=None,
        resampling_target="data",
        keep_masked_maps=True,
        memory=None,
        memory_level=0,
        verbose=0,
        reports=True,
        cmap="CMRmap_r",
        clean_args=None,
        **kwargs,  # TODO remove when bumping to nilearn >0.13
    ):
        self.maps_img = maps_img
        self.mask_img = mask_img

        # Maps Masker parameter
        self.allow_overlap = allow_overlap

        # Parameters for image.smooth
        self.smoothing_fwhm = smoothing_fwhm

        # Parameters for clean()
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.dtype = dtype
        self.clean_args = clean_args

        # TODO remove when bumping to nilearn >0.13
        self.clean_kwargs = kwargs

        # Parameters for resampling
        self.resampling_target = resampling_target

        # Parameters for joblib
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose

        self.reports = reports
        self.cmap = cmap

        self.keep_masked_maps = keep_masked_maps

    def generate_report(self, displayed_maps=10):
        """Generate an HTML report for the current ``NiftiMapsMasker`` object.

        .. note::
            This functionality requires to have ``Matplotlib`` installed.

        Parameters
        ----------
        displayed_maps : :obj:`int`, or :obj:`list`, \
                         or :class:`~numpy.ndarray`, or "all", default=10
            Indicates which maps will be displayed in the HTML report.

            - If "all": All maps will be displayed in the report.

            .. code-block:: python

                masker.generate_report("all")

            .. warning::
                If there are too many maps, this might be time and
                memory consuming, and will result in very heavy
                reports.

            - If a :obj:`list` or :class:`~numpy.ndarray`: This indicates
              the indices of the maps to be displayed in the report. For
              example, the following code will generate a report with maps
              6, 3, and 12, displayed in this specific order:

            .. code-block:: python

                masker.generate_report([6, 3, 12])

            - If an :obj:`int`: This will only display the first n maps,
              n being the value of the parameter. By default, the report
              will only contain the first 10 maps. Example to display the
              first 16 maps:

            .. code-block:: python

                masker.generate_report(16)

        Returns
        -------
        report : `nilearn.reporting.html_report.HTMLReport`
            HTML report for the masker.

        Raises
        ------
        TypeError
            If ``displayed_maps`` is not "all", an integer, or a list/array
            of integers.
        """
        from nilearn.reporting.html_report import generate_report

        if not is_matplotlib_installed():
            return generate_report(self)

        # Treat numpy integer scalars (e.g. ``np.int64(5)``) like plain ints,
        # so that ``_reporting`` takes the "first n maps" branch for them.
        if isinstance(displayed_maps, np.integer):
            displayed_maps = int(displayed_maps)

        if isinstance(displayed_maps, str):
            is_valid = displayed_maps == "all"
        elif isinstance(displayed_maps, (list, np.ndarray, int)):
            # Check the integer *kind* rather than comparing with ``int``:
            # ``int`` maps to a single platform-dependent dtype, which would
            # wrongly reject e.g. int32 or unsigned integer arrays.
            # Booleans are rejected because np.bool_ is not np.integer.
            is_valid = np.issubdtype(
                np.asarray(displayed_maps).dtype, np.integer
            )
        else:
            is_valid = False

        if not is_valid:
            raise TypeError(
                "Parameter ``displayed_maps`` of "
                "``generate_report()`` should be either 'all' or "
                "an int, or a list/array of ints. You provided a "
                f"{type(displayed_maps)}"
            )
        self.displayed_maps = displayed_maps

        return generate_report(self)

    def _reporting(self):
        """Return a list of all displays to be rendered.

        Returns
        -------
        displays : list
            A list of all displays to be rendered. ``[None]`` when no
            reporting data is available.

        Raises
        ------
        ValueError
            If a requested map index is not smaller than the number of maps.
        """
        from nilearn import plotting
        from nilearn.reporting.html_report import embed_img

        if self._reporting_data is not None:
            maps_image = self._reporting_data["maps_image"]
        else:
            maps_image = None

        if maps_image is None:
            return [None]

        n_maps = get_data(maps_image).shape[-1]

        # ``displayed_maps`` is set by ``generate_report``; fall back to its
        # documented default if the report is built without going through it.
        displayed_maps = getattr(self, "displayed_maps", 10)

        maps_to_be_displayed = range(n_maps)
        if isinstance(displayed_maps, int):
            if n_maps < displayed_maps:
                msg = (
                    "`generate_report()` received "
                    f"{displayed_maps} to be displayed. "
                    f"But masker only has {n_maps} maps. "
                    f"Setting number of displayed maps to {n_maps}."
                )
                warnings.warn(
                    category=UserWarning,
                    message=msg,
                    stacklevel=find_stack_level(),
                )
                displayed_maps = n_maps
                self.displayed_maps = n_maps
            maps_to_be_displayed = range(displayed_maps)

        elif isinstance(displayed_maps, (list, np.ndarray)):
            # Map indices are 0-based: valid indices are 0 .. n_maps - 1.
            if max(displayed_maps) >= n_maps:
                raise ValueError(
                    "Report cannot display the following maps "
                    f"{displayed_maps} because "
                    f"masker only has {n_maps} maps."
                )
            maps_to_be_displayed = displayed_maps

        self._report_content["number_of_maps"] = n_maps
        self._report_content["displayed_maps"] = list(maps_to_be_displayed)

        img = self._reporting_data["img"]
        embedded_images = []

        if img is None:
            # No functional image to use as background: plot the maps alone.
            msg = (
                "No image provided to fit in NiftiMapsMasker. "
                "Plotting only spatial maps for reporting."
            )
            warnings.warn(msg, stacklevel=find_stack_level())
            self._report_content["warning_message"] = msg
            for component in maps_to_be_displayed:
                display = plotting.plot_stat_map(
                    index_img(maps_image, component)
                )
                embedded_images.append(embed_img(display))
                display.close()
            return embedded_images

        if self._reporting_data["dim"] == 5:
            msg = (
                "A list of 4D subject images were provided to fit. "
                "Only first subject is shown in the report."
            )
            warnings.warn(msg, stacklevel=find_stack_level())
            self._report_content["warning_message"] = msg

        for component in maps_to_be_displayed:
            # Find the cut coordinates
            cut_coords = plotting.find_xyz_cut_coords(
                index_img(maps_image, component)
            )
            display = plotting.plot_img(
                img,
                cut_coords=cut_coords,
                black_bg=False,
                cmap=self.cmap,
            )
            display.add_overlay(
                index_img(maps_image, component),
                cmap=plotting.cm.black_blue,
            )
            embedded_images.append(embed_img(display))
            display.close()
        return embedded_images

    @fill_doc
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`list` of Niimg-like objects or None, default=None
            See :ref:`extracting_data`.
            Image data passed to the reporter.

        %(y_dummy)s
        """
        del y
        check_params(self.__dict__)
        if self.resampling_target not in ("mask", "maps", "data", None):
            raise ValueError(
                "invalid value for 'resampling_target' "
                f"parameter: {self.resampling_target}"
            )

        if self.mask_img is None and self.resampling_target == "mask":
            raise ValueError(
                "resampling_target has been set to 'mask' but no mask "
                "has been provided.\n"
                "Set resampling_target to something else or provide a mask."
            )

        self._sanitize_cleaning_parameters()
        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        self._report_content = {
            "description": (
                "This reports shows the spatial maps provided to the mask."
            ),
            "warning_message": None,
        }

        # Load images
        maps_img = self.maps_img
        if hasattr(self, "_maps_img"):
            # This is for RegionExtractor that first modifies
            # maps_img before passing to its parent fit method.
            maps_img = self._maps_img
        maps_repr = repr_niimgs(maps_img, shorten=(not self.verbose))
        msg = f"loading regions from {maps_repr}"
        log(msg=msg, verbose=self.verbose)
        self.maps_img_ = deepcopy(maps_img)
        self.maps_img_ = check_niimg(
            self.maps_img_, dtype=self.dtype, atleast_4d=True
        )
        # Replace non-finite values in the maps (see class note on Inf/NaN).
        self.maps_img_ = clean_img(
            self.maps_img_,
            detrend=False,
            standardize=False,
            ensure_finite=True,
        )

        if imgs is not None:
            imgs_ = check_niimg(imgs)

        self.mask_img_ = self._load_mask(imgs)

        # Check shapes and affines for resample.
        if self.resampling_target is None:
            images = {"maps": self.maps_img_}
            if self.mask_img_ is not None:
                images["mask"] = self.mask_img_
            if imgs is not None:
                images["data"] = imgs_
            check_same_fov(raise_error=True, **images)

        # Pick the image defining the target shape/affine, if any.
        ref_img = None
        if self.resampling_target == "data" and imgs is not None:
            ref_img = imgs_
        elif self.resampling_target == "mask":
            ref_img = self.mask_img_
        elif self.resampling_target == "maps":
            ref_img = self.maps_img_

        if ref_img is not None:
            if self.resampling_target != "maps" and not check_same_fov(
                ref_img, self.maps_img_
            ):
                log("Resampling maps...", self.verbose)
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                self.maps_img_ = self._cache(resample_img)(
                    self.maps_img_,
                    interpolation="continuous",
                    target_shape=ref_img.shape[:3],
                    target_affine=ref_img.affine,
                    copy_header=True,
                    force_resample=False,
                )
            if self.mask_img_ is not None and not check_same_fov(
                ref_img, self.mask_img_
            ):
                log("Resampling mask...", self.verbose)
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                self.mask_img_ = resample_img(
                    self.mask_img_,
                    target_affine=ref_img.affine,
                    target_shape=ref_img.shape[:3],
                    interpolation="nearest",
                    copy=True,
                    copy_header=True,
                    force_resample=False,
                )

                # Just check that the mask is valid
                load_mask_img(self.mask_img_)

        if self.reports:
            self._reporting_data = {
                "maps_image": self.maps_img_,
                "mask": self.mask_img_,
                "dim": None,
                "img": imgs,
            }
            if imgs is not None:
                imgs, dims = compute_middle_image(imgs)
                self._reporting_data["img"] = imgs
                self._reporting_data["dim"] = dims
        else:
            self._reporting_data = None

        # The number of elements is equal to the number of volumes
        self.n_elements_ = self.maps_img_.shape[3]

        return self

    def __sklearn_is_fitted__(self):
        return hasattr(self, "maps_img_") and hasattr(self, "n_elements_")

    @fill_doc
    def fit_transform(self, imgs, y=None, confounds=None, sample_mask=None):
        """Prepare and perform signal extraction.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.
            If a 3D niimg is provided, a 1D array is returned.

        %(y_dummy)s

        %(confounds)s

        %(sample_mask)s

                .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s
        """
        del y
        return self.fit(imgs).transform(
            imgs, confounds=confounds, sample_mask=sample_mask
        )

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from a single 4D niimg.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.

        confounds : CSV file or array-like, default=None
            This parameter is passed to :func:`nilearn.signal.clean`.
            Please see the related documentation for details.
            shape: (number of scans, number of confounds)

        %(sample_mask)s

                .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s

        Raises
        ------
        ValueError
            If ``allow_overlap`` is False and at least two maps are non-zero
            on the same voxel.
        """
        check_is_fitted(self)

        # imgs passed at transform time may be different
        # from those passed at fit time.
        # So it may be needed to resample mask and maps,
        # if 'data' is the resampling target.
        # We handle the resampling of maps and mask separately because the
        # affine of the maps and mask images should not impact the extraction
        # of the signal.
        #
        # Any resampling of the mask or maps is not 'kept' after transform,
        # to avoid modifying the masker after fit.
        #
        # If the resampling target is different,
        # then resampling was already done at fit time
        # (e.g resampling of the mask image to the maps image
        # if the target was 'maps'),
        # or resampling of the data will be done at extract time.

        mask_img_ = self.mask_img_
        maps_img_ = self.maps_img_

        imgs_ = check_niimg(imgs, atleast_4d=True)

        if self.resampling_target is None:
            images = {"maps": maps_img_, "data": imgs_}
            if mask_img_ is not None:
                images["mask"] = mask_img_
            check_same_fov(raise_error=True, **images)
        elif self.resampling_target == "data":
            ref_img = imgs_

            if not check_same_fov(ref_img, maps_img_):
                warnings.warn(
                    (
                        "Resampling maps at transform time...\n"
                        "To avoid this warning, make sure to pass the images "
                        "you want to transform to fit() first, "
                        "or directly use fit_transform()."
                    ),
                    stacklevel=find_stack_level(),
                )
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                maps_img_ = self._cache(resample_img)(
                    self.maps_img_,
                    interpolation="continuous",
                    target_shape=ref_img.shape[:3],
                    target_affine=ref_img.affine,
                    copy_header=True,
                    force_resample=False,
                )

            if self.mask_img_ is not None and not check_same_fov(
                ref_img,
                self.mask_img_,
            ):
                warnings.warn(
                    (
                        "Resampling mask at transform time...\n"
                        "To avoid this warning, make sure to pass the images "
                        "you want to transform to fit() first, "
                        "or directly use fit_transform()."
                    ),
                    stacklevel=find_stack_level(),
                )
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                mask_img_ = self._cache(resample_img)(
                    self.mask_img_,
                    interpolation="nearest",
                    target_shape=ref_img.shape[:3],
                    target_affine=ref_img.affine,
                    copy_header=True,
                    force_resample=False,
                )

        # Remove imgs_ from memory before loading the same image
        # in filter_and_extract.
        del imgs_

        if not self.allow_overlap:
            # Check if there is an overlap.
            # ``get_data`` may hand back the array backing ``maps_img_``
            # (possibly ``self.maps_img_`` itself), so it is only read here:
            # thresholding it in place would silently alter the maps used
            # for extraction and by inverse_transform.
            data = get_data(maps_img_)
            if data.dtype.kind == "f":
                # For floats, values below machine epsilon count as zero.
                active = data >= np.finfo(data.dtype).eps
            else:
                active = data > 0

            # Check the overlaps
            if np.any(np.sum(active, axis=3) > 1):
                raise ValueError(
                    "Overlap detected in the maps. The overlap may be "
                    "due to the atlas itself or possibly introduced by "
                    "resampling."
                )

        # Unless the data is the target, resample the data to the maps grid
        # at extraction time.
        target_shape = None
        target_affine = None
        if self.resampling_target != "data":
            target_shape = maps_img_.shape[:3]
            target_affine = maps_img_.affine

        params = get_params(
            NiftiMapsMasker,
            self,
            ignore=["resampling_target"],
        )
        params["target_shape"] = target_shape
        params["target_affine"] = target_affine
        params["clean_kwargs"] = self.clean_args_
        # TODO remove in 0.13.2
        if self.clean_kwargs:
            params["clean_kwargs"] = self.clean_kwargs_

        region_signals, _ = self._cache(
            filter_and_extract,
            ignore=["verbose", "memory", "memory_level"],
        )(
            # Images
            imgs,
            _ExtractionFunctor(
                maps_img_,
                mask_img_,
                self.keep_masked_maps,
            ),
            # Pre-treatments
            params,
            confounds=confounds,
            sample_mask=sample_mask,
            dtype=self.dtype,
            # Caching
            memory=self.memory,
            memory_level=self.memory_level,
            # kwargs
            verbose=self.verbose,
        )
        return region_signals

    @fill_doc
    def inverse_transform(self, region_signals):
        """Compute :term:`voxel` signals from region signals.

        Any mask given at initialization is taken into account.

        Parameters
        ----------
        %(region_signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_nifti)s

        """
        from ..regions import signal_extraction

        check_is_fitted(self)

        region_signals = self._check_array(region_signals)

        log("computing image from signals", verbose=self.verbose)
        return signal_extraction.signals_to_img_maps(
            region_signals,
            self.maps_img_,
            mask_img=self.mask_img_,
        )