Coverage for nilearn/maskers/nifti_labels_masker.py: 14%

290 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Transformer for computing ROI signals.""" 

2 

3import warnings 

4from copy import deepcopy 

5from pathlib import Path 

6from typing import Union 

7 

8import numpy as np 

9import pandas as pd 

10from nibabel import Nifti1Image 

11from sklearn.utils.estimator_checks import check_is_fitted 

12 

13from nilearn._utils import logger, repr_niimgs 

14from nilearn._utils.bids import ( 

15 generate_atlas_look_up_table, 

16 sanitize_look_up_table, 

17) 

18from nilearn._utils.class_inspect import get_params 

19from nilearn._utils.docs import fill_doc 

20from nilearn._utils.logger import find_stack_level 

21from nilearn._utils.niimg import safe_get_data 

22from nilearn._utils.niimg_conversions import ( 

23 check_niimg, 

24 check_niimg_3d, 

25 check_same_fov, 

26) 

27from nilearn._utils.param_validation import ( 

28 check_params, 

29 check_reduction_strategy, 

30) 

31from nilearn.image import get_data, load_img, resample_img 

32from nilearn.maskers._utils import compute_middle_image 

33from nilearn.maskers.base_masker import BaseMasker, filter_and_extract 

34from nilearn.masking import load_mask_img 

35 

36 

37class _ExtractionFunctor: 

38 func_name = "nifti_labels_masker_extractor" 

39 

40 def __init__( 

41 self, 

42 labels_img, 

43 background_label, 

44 strategy, 

45 keep_masked_labels, 

46 mask_img, 

47 ): 

48 self.labels_img = labels_img 

49 self.background_label = background_label 

50 self.strategy = strategy 

51 self.keep_masked_labels = keep_masked_labels 

52 self.mask_img = mask_img 

53 

54 def __call__(self, imgs): 

55 from ..regions.signal_extraction import img_to_signals_labels 

56 

57 signals, labels, masked_labels_img = img_to_signals_labels( 

58 imgs, 

59 self.labels_img, 

60 background_label=self.background_label, 

61 strategy=self.strategy, 

62 keep_masked_labels=self.keep_masked_labels, 

63 mask_img=self.mask_img, 

64 return_masked_atlas=True, 

65 ) 

66 return signals, (labels, masked_labels_img) 

67 

68 

@fill_doc
class NiftiLabelsMasker(BaseMasker):
    """Class for extracting data from Niimg-like objects \
    using labels of non-overlapping brain regions.

    NiftiLabelsMasker is useful when data from non-overlapping volumes should
    be extracted (contrarily to :class:`nilearn.maskers.NiftiMapsMasker`).

    Use case:
    summarize brain signals from clusters that were obtained by prior
    K-means or Ward clustering.

    For more details on the definitions of labels in Nilearn,
    see the :ref:`region` section.

    Parameters
    ----------
    labels_img : Niimg-like object or None, default=None
        See :ref:`extracting_data`.
        Region definitions, as one image of labels.

    labels : :obj:`list` of :obj:`str`, optional
        Full labels corresponding to the labels image.
        This is used to improve reporting quality if provided.
        Mutually exclusive with ``lut``.

        .. warning::
            The labels must be consistent with the label values
            provided through ``labels_img``.

    %(masker_lut)s

    background_label : :obj:`int` or :obj:`float`, default=0
        Label used in labels_img to represent background.

        .. warning::

            This value must be consistent with label values and image provided.

    mask_img : Niimg-like object, optional
        See :ref:`extracting_data`.
        Mask to apply to regions before extracting signals.

    %(smoothing_fwhm)s

    %(standardize_maskers)s

    %(standardize_confounds)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(detrend)s

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(dtype)s

    resampling_target : {"data", "labels", None}, default="data"
        Determines which image gives the final shape/size.
        For example, if ``resampling_target`` is ``"data"``,
        the atlas is resampled to the shape of the data if needed.
        If it is ``"labels"`` then mask_img and images provided to fit()
        are resampled to the shape and affine of labels_img.
        ``None`` means no resampling:
        if shapes and affines do not match, a ValueError is raised.

    %(memory)s

    %(memory_level1)s

    %(verbose0)s

    %(strategy)s

    %(keep_masked_labels)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="CMRmap_r"
        Only relevant for the report figures.

    %(clean_args)s
        .. versionadded:: 0.11.2dev

    %(masker_kwargs)s

    Attributes
    ----------
    %(nifti_mask_img_)s

    labels_img_ : :obj:`nibabel.nifti1.Nifti1Image`
        The labels image.

    lut_ : :obj:`pandas.DataFrame`
        Look-up table derived from the ``labels`` or ``lut``
        or from the values of the label image.

    region_atlas_ : Niimg-like object
        Regions definition as labels.
        The labels correspond to the indices in ``region_ids_``.
        The region in ``region_atlas_`` that takes the value ``region_ids_[i]``
        is used to compute the signal in ``region_signal[:,i]``.

        .. versionadded:: 0.10.3

    See Also
    --------
    nilearn.maskers.NiftiMasker

    """

    # memory and memory_level are used by _utils.CacheMixin.

    def __init__(
        self,
        labels_img=None,
        labels=None,
        lut=None,
        background_label=0,
        mask_img=None,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        high_variance_confounds=False,
        detrend=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        dtype=None,
        resampling_target="data",
        memory=None,
        memory_level=1,
        verbose=0,
        strategy="mean",
        keep_masked_labels=True,
        reports=True,
        cmap="CMRmap_r",
        clean_args=None,
        **kwargs,  # TODO remove when bumping to nilearn >0.13
    ):
        self.labels_img = labels_img
        self.background_label = background_label

        self.labels = labels
        self.lut = lut

        self.mask_img = mask_img
        self.keep_masked_labels = keep_masked_labels

        # Parameters for smooth_array
        self.smoothing_fwhm = smoothing_fwhm

        # Parameters for clean()
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.dtype = dtype
        self.clean_args = clean_args

        # TODO remove when bumping to nilearn >0.13
        self.clean_kwargs = kwargs

        # Parameters for resampling
        self.resampling_target = resampling_target

        # Parameters for joblib
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose

        # Parameters for reports
        self.reports = reports
        self.cmap = cmap

        self.strategy = strategy

    def _fitted_lut(self):
        """Return the look-up table describing the current regions.

        ``_lut_`` is set by ``transform_single_imgs`` (restricted to the
        regions actually extracted); when it does not exist yet, the
        fit-time ``lut_`` is returned.
        """
        check_is_fitted(self)
        return getattr(self, "_lut_", self.lut_)

    @property
    def _region_id_name(self):
        """Return dictionary used to store region names and
        the corresponding region ids as keys.
        """
        check_is_fitted(self)
        lut = self.lut_
        return (
            lut.loc[lut["name"] != "Background"]
            .set_index("index")["name"]
            .to_dict()
        )

    @property
    def labels_(self) -> list:
        """Return the label values (``index`` column of the look-up table)."""
        return self._fitted_lut()["index"].to_list()

    @property
    def region_names_(self) -> dict[int, str]:
        """Return a dictionary containing the region names corresponding \n
        to each column in the array returned by `transform`.

        The region names correspond to the labels provided
        in labels in input.
        The region name corresponding to ``region_signal[:,i]``
        is ``region_names_[i]``.

        .. versionadded:: 0.10.3
        """
        lut = self._fitted_lut()
        return lut.loc[lut["index"] != self.background_label, "name"].to_dict()

    @property
    def region_ids_(self) -> dict[Union[str, int], int]:
        """Return dictionary containing the region ids corresponding \n
        to each column in the array \n
        returned by `transform`.

        The region id corresponding to ``region_signal[:,i]``
        is ``region_ids_[i]``.
        The background label, when present in the look-up table,
        is included as well.

        .. versionadded:: 0.10.3
        """
        return self._fitted_lut()["index"].to_dict()

    @property
    def n_elements_(self) -> int:
        """Return number of regions.

        This is equal to the number of unique values
        in the fitted label image,
        minus the background value.

        .. versionadded:: 0.9.2
        """
        lut = self._fitted_lut()
        return len(lut[lut["index"] != self.background_label])

    def _post_masking_atlas(self, visualize=False):
        """
        Find the masked atlas before transform and return it.

        Also return the removed region ids and names.
        if visualize is True, plot the masked atlas.

        Returns
        -------
        masked_atlas : Nifti1Image
            Labels image where voxels outside ``mask_img_`` are set to
            ``background_label``. Stored as int32 so label values above
            127 are not wrapped around (as an int8 cast would do).
        removed_region_ids : list
            Label values present before masking but absent afterwards
            (may include the background label).
        removed_region_names : list of str
            Names of the removed non-background regions.
        display : display object or None
            Output of ``plot_roi`` when ``visualize`` is True, else None.
        """
        # copy: the array is modified in place below
        labels_data = safe_get_data(self.labels_img_, ensure_finite=True).copy()
        region_ids_before_masking = np.unique(labels_data).tolist()
        if self.mask_img_ is not None:
            mask_data = safe_get_data(self.mask_img_, ensure_finite=True)
            # apply the mask to the atlas
            labels_data[np.logical_not(mask_data)] = self.background_label
        region_ids_after_masking = set(np.unique(labels_data).tolist())
        masked_atlas = Nifti1Image(
            labels_data.astype(np.int32), self.labels_img_.affine
        )
        removed_region_ids = [
            region_id
            for region_id in region_ids_before_masking
            if region_id not in region_ids_after_masking
        ]
        removed_region_names = [
            self._region_id_name[region_id]
            for region_id in removed_region_ids
            if region_id != self.background_label
        ]
        display = None
        if visualize:
            from nilearn.plotting import plot_roi

            display = plot_roi(masked_atlas, title="Masked atlas")

        return masked_atlas, removed_region_ids, removed_region_names, display

    def generate_report(self):
        """Generate a report."""
        from nilearn.reporting.html_report import generate_report

        return generate_report(self)

    def _reporting(self):
        """Return a list of all displays to be rendered.

        Returns
        -------
        displays : list
            A list of all displays to be rendered.

        """
        import matplotlib.pyplot as plt

        from nilearn import plotting

        # ``_reporting_data`` is only created by ``fit``.
        reporting_data = getattr(self, "_reporting_data", None)
        labels_image = None
        if reporting_data is not None:
            labels_image = reporting_data["labels_image"]

        if (
            labels_image is None
            # must be *called*: the bound method itself is always truthy
            or not self.__sklearn_is_fitted__()
            or not self.reports
        ):
            self._report_content["summary"] = None
            return [None]

        # Remove warning message in case where the masker was
        # previously fitted with no func image and is re-fitted
        if "warning_message" in self._report_content:
            self._report_content["warning_message"] = None

        table = (
            self._fitted_lut()[["index", "name"]]
            .astype({"index": int})
            .rename(columns={"name": "region name", "index": "label value"})
        )

        labels_image = load_img(labels_image, dtype="int32")
        labels_image_data = get_data(labels_image)
        voxel_volume = np.abs(np.linalg.det(labels_image.affine[:3, :3]))

        # Hoisted out of the per-label loop: it does not depend on the label.
        n_foreground_voxels = np.count_nonzero(
            labels_image_data != self.background_label
        )
        region_sizes = [
            np.count_nonzero(labels_image_data == label)
            for label in table["label value"].to_list()
        ]
        # Positional assignment: the LUT index may not be 0..n-1 after
        # filtering, so index-aligned concatenation would misalign rows.
        table["size (in mm^3)"] = [
            round(size * voxel_volume) for size in region_sizes
        ]
        table["relative size (in %)"] = [
            round(size / n_foreground_voxels * 100, 2)
            if n_foreground_voxels
            else 0.0
            for size in region_sizes
        ]

        table = table[table["label value"] != self.background_label]

        self._report_content["summary"] = table
        self._report_content["number_of_regions"] = self.n_elements_

        img = reporting_data["img"]

        # compute the cut coordinates on the label image in case
        # we have a functional image
        cut_coords = plotting.find_xyz_cut_coords(
            labels_image, activation_threshold=0.5
        )

        # If we have a func image to show in the report, use it
        if img is not None:
            if reporting_data["dim"] == 5:
                msg = (
                    "A list of 4D subject images were provided to fit. "
                    "Only first subject is shown in the report."
                )
                warnings.warn(msg, stacklevel=find_stack_level())
                self._report_content["warning_message"] = msg
            display = plotting.plot_img(
                img,
                cut_coords=cut_coords,
                black_bg=False,
                cmap=self.cmap,
            )
            plt.close()
            display.add_contours(labels_image, filled=False, linewidths=3)

        # Otherwise, simply plot the ROI of the label image
        # and give a warning to the user
        else:
            msg = (
                "No image provided to fit in NiftiLabelsMasker. "
                "Plotting ROIs of label image on the "
                "MNI152Template for reporting."
            )
            warnings.warn(msg, stacklevel=find_stack_level())
            self._report_content["warning_message"] = msg
            display = plotting.plot_roi(labels_image)
            plt.close()

        # If we have a mask, show its contours
        if reporting_data["mask"] is not None:
            display.add_contours(
                reporting_data["mask"],
                filled=False,
                colors="g",
                linewidths=3,
            )

        return [display]

    @fill_doc
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`list` of Niimg-like objects or None, default=None
            See :ref:`extracting_data`.
            Image data passed to the reporter.

        %(y_dummy)s
        """
        del y
        check_params(self.__dict__)
        check_reduction_strategy(self.strategy)

        if self.resampling_target not in ("labels", "data", None):
            raise ValueError(
                "invalid value for 'resampling_target' "
                f"parameter: {self.resampling_target}"
            )

        self._sanitize_cleaning_parameters()
        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        self._report_content = {
            "description": (
                "This reports shows the regions "
                "defined by the labels of the mask."
            ),
            "warning_message": None,
        }

        labels_repr = repr_niimgs(self.labels_img, shorten=(not self.verbose))
        msg = f"loading data from {labels_repr}"
        logger.log(msg=msg, verbose=self.verbose)
        self.labels_img_ = check_niimg_3d(deepcopy(self.labels_img))

        if self.labels:
            if self.lut is not None:
                raise ValueError(
                    "Pass either labels "
                    "or a lookup table (lut) to the masker, "
                    "but not both."
                )
            self._check_labels()
            # NOTE: the "background" -> "Background" renaming is done on a
            # copy inside _generate_lut so that self.labels (an init
            # parameter) is never modified by fit.

        self.lut_ = self._generate_lut()

        self._original_region_ids = self.lut_["index"].to_list()

        if imgs is not None:
            imgs_ = check_niimg(imgs, atleast_4d=True)

        self.mask_img_ = self._load_mask(imgs)

        # Check shapes and affines for resample.
        if self.resampling_target is None:
            images = {"labels": self.labels_img_}
            if self.mask_img_ is not None:
                images["mask"] = self.mask_img_
            if imgs is not None:
                images["data"] = imgs_
            check_same_fov(raise_error=True, **images)

        # resample labels
        if (
            self.resampling_target == "data"
            and imgs is not None
            and not check_same_fov(imgs_, self.labels_img_)
        ):
            self.labels_img_ = self._resample_labels(imgs_)

        # resample mask
        ref_img = None
        if self.resampling_target == "data" and imgs is not None:
            ref_img = imgs_
        elif self.resampling_target == "labels":
            ref_img = self.labels_img_
        if (
            self.mask_img_ is not None
            and ref_img is not None
            and not check_same_fov(ref_img, self.mask_img_)
        ):
            logger.log("Resampling mask...", self.verbose)
            self.mask_img_ = self._resample_mask_img(self.mask_img_, ref_img)

            # Just check that the mask is valid
            load_mask_img(self.mask_img_)

        if self.reports:
            self._reporting_data = {
                "labels_image": self.labels_img_,
                "mask": self.mask_img_,
                "dim": None,
                "img": imgs,
            }
            if imgs is not None:
                imgs, dims = compute_middle_image(imgs)
                self._reporting_data["img"] = imgs
                self._reporting_data["dim"] = dims
        else:
            self._reporting_data = None

        return self

    def _resample_mask_img(self, mask_img, ref_img):
        """Resample ``mask_img`` onto the grid of ``ref_img``.

        Nearest-neighbour interpolation keeps the mask binary.
        """
        # TODO switch to force_resample=True
        # when bumping to version > 0.13
        return self._cache(resample_img, func_memory_level=2)(
            mask_img,
            interpolation="nearest",
            target_shape=ref_img.shape[:3],
            target_affine=ref_img.affine,
            copy_header=True,
            force_resample=False,
        )

    def _check_labels(self):
        """Check labels.

        - checks that labels is a list of strings.
        """
        labels = self.labels
        if not isinstance(labels, list):
            raise TypeError(
                f"'labels' must be a list. Got: {type(labels)}",
            )
        if not all(isinstance(x, str) for x in labels):
            types_labels = {type(x) for x in labels}
            raise TypeError(
                "All elements of 'labels' must be a string.\n"
                f"Got a list of {types_labels}",
            )

    def _generate_lut(self):
        """Generate a look up table if one was not provided.

        Also sanitize its content if necessary.
        User-provided ``lut`` / ``labels`` are copied first so that they
        are never modified.
        """
        if self.lut is not None:
            if isinstance(self.lut, (str, Path)):
                lut = pd.read_table(self.lut, sep=None, engine="python")
            else:
                lut = deepcopy(self.lut)

        elif self.labels:
            names = deepcopy(self.labels)
            # Normalize the first "background" entry to the canonical name.
            if "background" in names:
                names[names.index("background")] = "Background"
            lut = generate_atlas_look_up_table(
                function=None,
                name=names,
                index=self.labels_img_,
            )

        else:
            lut = generate_atlas_look_up_table(
                function=None, index=self.labels_img_
            )

        # passed labels or lut may not include background label
        # because of poor data standardization
        # so we need to update the lut accordingly
        mask_background_name = lut["name"] == "Background"
        mask_background_index = lut["index"] == self.background_label
        if (mask_background_index).any():
            # Ensure background is the first row with name "Background"
            # Shift the 'name' column down by one
            # if background row was not named properly
            first_rows = lut[mask_background_index]
            other_rows = lut[~mask_background_index]
            lut = pd.concat([first_rows, other_rows], ignore_index=True)

            if not (mask_background_name).any():
                lut["name"] = lut["name"].shift(1)

            lut.loc[0, "name"] = "Background"

        return sanitize_look_up_table(lut, atlas=self.labels_img_)

    @fill_doc
    def fit_transform(self, imgs, y=None, confounds=None, sample_mask=None):
        """Prepare and perform signal extraction from regions.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.
            If a 3D niimg is provided, a 1D array is returned.

        y : None
            This parameter is unused. It is solely included for scikit-learn
            compatibility.

        %(confounds)s

        %(sample_mask)s

                .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s

        """
        del y
        return self.fit(imgs).transform(
            imgs, confounds=confounds, sample_mask=sample_mask
        )

    def __sklearn_is_fitted__(self):
        return hasattr(self, "labels_img_") and hasattr(self, "lut_")

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from a single 4D niimg.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.

        %(confounds)s

        %(sample_mask)s

                .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s

        """
        check_is_fitted(self)

        # imgs passed at transform time may be different
        # from those passed at fit time.
        # So it may be needed to resample mask and labels,
        # if 'data' is the resampling target.
        # We handle the resampling of labels and mask separately because the
        # affine of the labels and mask images should not impact the extraction
        # of the signal.
        #
        # Any resampling of the mask or labels is not 'kept' after transform,
        # to avoid modifying the masker after fit.
        #
        # If the resampling target is different,
        # then resampling was already done at fit time
        # (e.g resampling of the mask image to the labels image
        # if the target was 'labels'),
        # or resampling of the data will be done at extract time.
        labels_img_ = self.labels_img_
        mask_img_ = self.mask_img_
        if self.resampling_target == "data":
            imgs_ = check_niimg(imgs, atleast_4d=True)
            if not check_same_fov(imgs_, labels_img_):
                warnings.warn(
                    (
                        "Resampling labels at transform time...\n"
                        "To avoid this warning, make sure to pass the images "
                        "you want to transform to fit() first, "
                        "or directly use fit_transform()."
                    ),
                    stacklevel=find_stack_level(),
                )
                labels_img_ = self._resample_labels(imgs_)

            if (mask_img_ is not None) and (
                not check_same_fov(imgs_, mask_img_)
            ):
                warnings.warn(
                    (
                        "Resampling mask at transform time...\n"
                        "To avoid this warning, make sure to pass the images "
                        "you want to transform to fit() first, "
                        "or directly use fit_transform()."
                    ),
                    stacklevel=find_stack_level(),
                )
                mask_img_ = self._resample_mask_img(mask_img_, imgs_)

            # Remove imgs_ from memory before loading the same image
            # in filter_and_extract.
            del imgs_

        target_shape = None
        target_affine = None
        if self.resampling_target == "labels":
            target_shape = labels_img_.shape[:3]
            target_affine = labels_img_.affine

        params = get_params(
            NiftiLabelsMasker,
            self,
            ignore=["resampling_target"],
        )
        params["target_shape"] = target_shape
        params["target_affine"] = target_affine
        params["clean_kwargs"] = self.clean_args_
        # TODO remove in 0.13.2
        if self.clean_kwargs:
            params["clean_kwargs"] = self.clean_kwargs_

        region_signals, (ids, masked_atlas) = self._cache(
            filter_and_extract,
            ignore=["verbose", "memory", "memory_level"],
        )(
            # Images
            imgs,
            _ExtractionFunctor(
                labels_img_,
                self.background_label,
                self.strategy,
                self.keep_masked_labels,
                mask_img_,
            ),
            # Pre-processing
            params,
            confounds=confounds,
            sample_mask=sample_mask,
            dtype=self.dtype,
            # Caching
            memory=self.memory,
            memory_level=self.memory_level,
            verbose=self.verbose,
        )

        # Restrict the LUT to the background plus the regions that were
        # actually extracted (some may have been removed by masking).
        kept_ids = [self.background_label, *ids]
        lut = self.lut_
        self._lut_ = sanitize_look_up_table(
            lut[lut["index"].isin(kept_ids)].copy(),
            atlas=np.array(kept_ids),
        )

        self.region_atlas_ = masked_atlas

        return region_signals

    def _resample_labels(self, imgs_):
        logger.log(
            "Resampling labels",
            self.verbose,
        )
        labels_before_resampling = set(
            np.unique(safe_get_data(self.labels_img_))
        )
        labels_img_ = self._cache(resample_img, func_memory_level=2)(
            self.labels_img_,
            interpolation="nearest",
            target_shape=imgs_.shape[:3],
            target_affine=imgs_.affine,
            copy_header=True,
            force_resample=False,
        )
        labels_after_resampling = set(np.unique(safe_get_data(labels_img_)))
        if labels_diff := labels_before_resampling.difference(
            labels_after_resampling
        ):
            warnings.warn(
                "After resampling the label image to the data image, "
                f"the following labels were removed: {labels_diff}. "
                "Label image only contains "
                f"{len(labels_after_resampling)} labels "
                "(including background).",
                stacklevel=find_stack_level(),
            )

        return labels_img_

    @fill_doc
    def inverse_transform(self, signals):
        """Compute :term:`voxel` signals from region signals.

        Any mask given at initialization is taken into account.

        .. versionchanged:: 0.9.2

            This method now supports 1D arrays, which will produce 3D images.

        Parameters
        ----------
        %(signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_nifti)s

        """
        from ..regions import signal_extraction

        check_is_fitted(self)

        signals = self._check_array(signals)

        logger.log("computing image from signals", verbose=self.verbose)
        return signal_extraction.signals_to_img_labels(
            signals,
            self.labels_img_,
            self.mask_img_,
            background_label=self.background_label,
        )