Coverage for nilearn/maskers/surface_labels_masker.py: 18% (198 statements)


"""Extract data from a SurfaceImage, averaging over atlas regions."""

import warnings
from copy import deepcopy
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
from scipy import ndimage
from sklearn.utils.validation import check_is_fitted

from nilearn import DEFAULT_SEQUENTIAL_CMAP, signal
from nilearn._utils.bids import (
    generate_atlas_look_up_table,
    sanitize_look_up_table,
)
from nilearn._utils.cache_mixin import cache
from nilearn._utils.class_inspect import get_params
from nilearn._utils.docs import fill_doc
from nilearn._utils.helpers import (
    constrained_layout_kwargs,
    rename_parameters,
)
from nilearn._utils.logger import find_stack_level
from nilearn._utils.masker_validation import (
    check_compatibility_mask_and_images,
)
from nilearn._utils.param_validation import (
    check_params,
    check_reduction_strategy,
)
from nilearn.image import mean_img
from nilearn.maskers.base_masker import _BaseSurfaceMasker
from nilearn.surface.surface import (
    SurfaceImage,
    at_least_2d,
    check_surf_img,
    get_data,
)
from nilearn.surface.utils import check_polymesh_equal


def signals_to_surf_img_labels(
    signals: np.ndarray,
    labels: np.ndarray,
    labels_img: SurfaceImage,
    background_label=0,
) -> SurfaceImage:
    """Transform region signals back into a surface image using the labels."""
    labels = labels[labels != background_label]

    data = {}
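    # The loop below allocates one output array per mesh part, shaped
    # (n_vertices_in_part, n_samples), and broadcasts each region's signal
    # to every vertex carrying that region's label; background vertices
    # keep their initial zeros.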

    for part_name, labels_part in labels_img.data.parts.items():
        data[part_name] = np.zeros(
            (labels_part.shape[0], signals.shape[0]),
            dtype=signals.dtype,
        )
        for label_idx, label in enumerate(labels):
            data[part_name][labels_part == label] = signals[:, label_idx].T
    return SurfaceImage(mesh=labels_img.mesh, data=data)


@fill_doc
class SurfaceLabelsMasker(_BaseSurfaceMasker):
    """Extract data from a SurfaceImage, averaging over atlas regions.

    .. versionadded:: 0.11.0

    Parameters
    ----------
    labels_img : :obj:`~nilearn.surface.SurfaceImage` object
        Region definitions, as one image of labels.
        The data for each hemisphere
        is of shape (n_vertices_per_hemisphere, n_regions).

    labels : :obj:`list` of :obj:`str`, default=None
        Mutually exclusive with ``lut``.
        Labels corresponding to the labels image.
        This is used to improve reporting quality if provided.

        .. warning::
            If the labels are not consistent with the label values
            provided through ``labels_img``,
            excess labels will be dropped,
            and missing labels will be labeled ``'unknown'``.

    %(masker_lut)s

    background_label : :obj:`int` or :obj:`float`, default=0
        Label used in labels_img to represent background.

        .. warning::

            This value must be consistent with the label values
            and image provided.

    mask_img : :obj:`~nilearn.surface.SurfaceImage` object, optional
        Mask to apply to labels_img before extracting signals. Defines the \
        overall area of the brain to consider. The data for each \
        hemisphere is of shape (n_vertices_per_hemisphere, n_regions).

    %(smoothing_fwhm)s
        This parameter is not implemented yet.

    %(standardize_maskers)s

    %(standardize_confounds)s

    %(detrend)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(memory)s

    %(memory_level1)s

    %(verbose0)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="inferno"
        Only relevant for the report figures.

    %(clean_args)s

    Attributes
    ----------
    labels_img_ : :obj:`~nilearn.surface.SurfaceImage`
        The labels image after fitting.
        If a mask_img was used,
        then masked vertices will have the background value.

    mask_img_ : A 1D binary :obj:`~nilearn.surface.SurfaceImage` or None.
        The mask of the data.
        If no ``mask_img`` was passed at masker construction,
        then ``mask_img_`` is ``None``; otherwise
        it is the binarized version of ``mask_img``
        where each vertex is ``True`` if all its values across samples
        (for example across timepoints) are finite and different from 0.

    lut_ : :obj:`pandas.DataFrame`
        Look-up table derived from ``labels`` or ``lut``
        or from the values of the label image.

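    Examples
    --------
    A minimal usage sketch. Here ``atlas_img`` and ``func_img`` stand for
    existing :obj:`~nilearn.surface.SurfaceImage` objects sharing the same
    mesh; they are placeholders, not objects provided by this module::

        masker = SurfaceLabelsMasker(labels_img=atlas_img).fit()
        # region_signals has shape (n_samples, n_regions)
        region_signals = masker.transform(func_img)
        # project the region-level signals back onto the surface mesh
        regions_img = masker.inverse_transform(region_signals)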
    """

    def __init__(
        self,
        labels_img=None,
        labels=None,
        lut=None,
        background_label=0,
        mask_img=None,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        detrend=False,
        high_variance_confounds=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        memory=None,
        memory_level=1,
        verbose=0,
        strategy="mean",
        reports=True,
        cmap=DEFAULT_SEQUENTIAL_CMAP,
        clean_args=None,
    ):
        self.labels_img = labels_img
        self.labels = labels
        self.lut = lut
        self.background_label = background_label
        self.mask_img = mask_img
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose
        self.reports = reports
        self.strategy = strategy
        self.cmap = cmap
        self.clean_args = clean_args

    @property
    def n_elements_(self) -> int:
        """Return the number of regions.

        This is equal to the number of unique values
        in the fitted label image,
        minus the background value.
        """
        check_is_fitted(self)
        lut = self.lut_
        return len(lut[lut["index"] != self.background_label])

    @property
    def labels_(self) -> list[Union[int, float]]:
        """Return the list of labels of the regions."""
        check_is_fitted(self)
        lut = self.lut_
        return lut["index"].to_list()

    @property
    def region_names_(self) -> dict[int, str]:
        """Return a dictionary of the region names corresponding \
        to each column in the array returned by `transform`.

        The region names correspond to the ``labels`` provided as input.
        The region name corresponding to ``region_signal[:,i]``
        is ``region_names_[i]``.

        .. versionadded:: 0.11.2dev
        """
        check_is_fitted(self)
        lut = self.lut_
        return lut.loc[lut["index"] != self.background_label, "name"].to_dict()

    @property
    def region_ids_(self) -> dict[Union[str, int], int]:
        """Return a dictionary of the region ids corresponding \
        to each column in the array \
        returned by `transform`.

        The region id corresponding to ``region_signal[:,i]``
        is ``region_ids_[i]``.
        ``region_ids_['background']`` is the background label.

        .. versionadded:: 0.11.2dev
        """
        check_is_fitted(self)
        lut = self.lut_
        return lut["index"].to_dict()

    @fill_doc
    @rename_parameters(
        replacement_params={"img": "imgs"}, end_version="0.13.2"
    )
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or None, \
               default=None

        %(y_dummy)s

        Returns
        -------
        SurfaceLabelsMasker object
        """
        del y
        check_params(self.__dict__)
        if imgs is not None:
            self._check_imgs(imgs)
            check_surf_img(imgs)

        check_reduction_strategy(self.strategy)

        if self.labels_img is None:
            raise ValueError(
                "Please provide a labels_img to the masker. For example, "
                "masker = SurfaceLabelsMasker(labels_img=labels_img)"
            )

        if self.labels and self.lut is not None:
            raise ValueError(
                "Pass either labels or a lookup table (lut) to the masker, "
                "but not both."
            )

        self.labels_img_ = deepcopy(self.labels_img)

        self.mask_img_ = self._load_mask(imgs)
        if self.mask_img_ is not None:
            check_polymesh_equal(self.labels_img_.mesh, self.mask_img_.mesh)

            # apply mask to label image
            for k in self.labels_img_.data.parts:
                mask = self.mask_img_.data.parts[k]
                self.labels_img_.data.parts[k][np.logical_not(mask)] = (
                    self.background_label
                )

            labels_before_mask = {
                int(x) for x in np.unique(get_data(self.labels_img))
            }
            labels_after_mask = {
                int(x) for x in np.unique(get_data(self.labels_img_))
            }
            labels_diff = labels_before_mask - labels_after_mask
            if labels_diff:
                warnings.warn(
                    "After applying mask to the labels image, "
                    "the following labels were "
                    f"removed: {labels_diff}. "
                    f"Out of {len(labels_before_mask)} labels, the "
                    "masked labels image only contains "
                    f"{len(labels_after_mask)} labels "
                    "(including background).",
                    stacklevel=find_stack_level(),
                )

        self._shelving = False

        # generate a look-up table if one was not provided
        if self.lut is not None:
            if isinstance(self.lut, (str, Path)):
                lut = pd.read_table(self.lut, sep=None)
            else:
                lut = self.lut
        elif self.labels:
            lut = generate_atlas_look_up_table(
                function=None,
                name=self.labels,
                index=self.labels_img_,
            )
        else:
            lut = generate_atlas_look_up_table(
                function=None, index=self.labels_img_
            )
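        # The look-up table is then reconciled with the labels actually
        # present in the (possibly masked) label image; see the ``labels``
        # warning in the class docstring for how mismatches are handled.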

        self.lut_ = sanitize_look_up_table(lut, atlas=self.labels_img_)

        if self.clean_args is None:
            self.clean_args_ = {}
        else:
            self.clean_args_ = self.clean_args

        if not self.reports:
            self._reporting_data = None
            return self

        # content to inject in the HTML template
        self._report_content = {
            "description": (
                "This report shows the input surface image overlaid "
                "with the outlines of the mask. "
                "We recommend inspecting the report for the overlap "
                "between the mask and its input image. "
            ),
            "n_vertices": {},
            "number_of_regions": self.n_elements_,
            "summary": {},
            "warning_message": None,
        }

        for part in self.labels_img_.data.parts:
            self._report_content["n_vertices"][part] = (
                self.labels_img_.mesh.parts[part].n_vertices
            )

        self._reporting_data = self._generate_reporting_data()

        return self

    def _generate_reporting_data(self):
        for part in self.labels_img_.data.parts:
            size = []
            relative_size = []

            table = self.lut_.copy()

            for _, row in table.iterrows():
                n_vertices = self.labels_img_.data.parts[part] == row["index"]
                size.append(n_vertices.sum())
                tmp = (
                    n_vertices.sum()
                    / self.labels_img_.mesh.parts[part].n_vertices
                    * 100
                )
                relative_size.append(f"{tmp:.2}")

            table["size"] = size
            table["relative size"] = relative_size

            self._report_content["summary"][part] = table

        return {
            "labels_image": self.labels_img_,
            "images": None,
        }

    def __sklearn_is_fitted__(self):
        return hasattr(self, "lut_") and hasattr(self, "mask_img_")

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from a surface object.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or \
               iterable of :obj:`~nilearn.surface.SurfaceImage`
            Images to process.
            Mesh and data for both hemispheres.

        %(confounds)s

        %(sample_mask)s

        Returns
        -------
        %(signals_transform_surface)s
        """
        check_is_fitted(self)

        check_compatibility_mask_and_images(self.labels_img_, imgs)
        check_polymesh_equal(self.labels_img_.mesh, imgs.mesh)

        imgs = at_least_2d(imgs)
        img_data = get_data(imgs)
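        # Work in the input's precision: float32 data stays float32,
        # anything else is promoted to float64.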

        target_datatype = (
            np.float32 if img_data.dtype == np.float32 else np.float64
        )

        img_data = img_data.astype(target_datatype)

        n_samples = 1 if len(img_data.shape) == 1 else img_data.shape[1]

        region_signals = np.ndarray(
            (n_samples, self.n_elements_), dtype=target_datatype
        )
        # adapted from nilearn.regions.signal_extraction.img_to_signals_labels
        # iterate over time points and apply reduction function over labels.
        labels_data = get_data(self.labels_img_)

        index = self.labels_
        if self.background_label in index:
            index.pop(index.index(self.background_label))

        reduction_function = getattr(ndimage, self.strategy)
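        # ``strategy`` is expected to name a labelled reduction available in
        # scipy.ndimage (for example ``ndimage.mean``); it is applied below to
        # every region label in ``index``, one sample (time point) at a time.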

        for n, sample in enumerate(np.rollaxis(img_data, -1)):
            tmp = np.asarray(
                reduction_function(sample, labels=labels_data, index=index)
            )
            region_signals[n] = tmp

        parameters = get_params(
            self.__class__,
            self,
            ignore=[
                "mask_img",
            ],
        )
        parameters["clean_args"] = self.clean_args_

        # signal cleaning here
        region_signals = cache(
            signal.clean,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(
            region_signals,
            detrend=parameters["detrend"],
            standardize=parameters["standardize"],
            standardize_confounds=parameters["standardize_confounds"],
            t_r=parameters["t_r"],
            low_pass=parameters["low_pass"],
            high_pass=parameters["high_pass"],
            confounds=confounds,
            sample_mask=sample_mask,
            **parameters["clean_args"],
        )

        return region_signals

    @fill_doc
    def inverse_transform(self, signals):
        """Transform extracted signals back into a surface image.

        Parameters
        ----------
        %(signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_surface)s
        """
        check_is_fitted(self)

        return_1D = signals.ndim < 2

        signals = self._check_array(signals)

        imgs = signals_to_surf_img_labels(
            signals,
            np.asarray(self.labels_),
            self.labels_img_,
            self.background_label,
        )
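        # If the input was a single 1D sample, drop the trailing sample
        # dimension so that each hemisphere's data is returned as 1D again.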

        if return_1D:
            for k, v in imgs.data.parts.items():
                imgs.data.parts[k] = v.squeeze()

        return imgs

    def generate_report(self):
        """Generate a report."""
        from nilearn.reporting.html_report import generate_report

        return generate_report(self)

    def _reporting(self):
        """Load displays needed for the report.

        Returns
        -------
        displays : list
            A list of all displays to be rendered.
        """
        import matplotlib.pyplot as plt

        from nilearn.reporting.utils import figure_to_png_base64

        # Handle the edge case where this function is
        # called with a masker having report capabilities disabled
        if self._reporting_data is None:
            return [None]

        fig = self._create_figure_for_report()

        plt.close()

        init_display = figure_to_png_base64(fig)

        return [init_display]

    def _create_figure_for_report(self):
        """Create a figure of the contours of the label image.

        If transform() was applied to an image,
        that image is used as background
        on which the contours are drawn.
        """
        import matplotlib.pyplot as plt

        from nilearn.plotting import plot_surf, plot_surf_contours

        labels_img = self._reporting_data["labels_image"]

        img = self._reporting_data["images"]
        if img:
            img = mean_img(img)
            vmin, vmax = img.data._get_min_max()

        # TODO: possibly allow generating a report with other views
        views = ["lateral", "medial"]
        hemispheres = ["left", "right"]

        fig, axes = plt.subplots(
            len(views),
            len(hemispheres),
            subplot_kw={"projection": "3d"},
            figsize=(20, 20),
            **constrained_layout_kwargs(),
        )
        axes = np.atleast_2d(axes)

        for ax_row, view in zip(axes, views):
            for ax, hemi in zip(ax_row, hemispheres):
                if img:
                    plot_surf(
                        surf_map=img,
                        hemi=hemi,
                        view=view,
                        figure=fig,
                        axes=ax,
                        cmap=self.cmap,
                        vmin=vmin,
                        vmax=vmax,
                    )
                plot_surf_contours(
                    roi_map=labels_img,
                    hemi=hemi,
                    view=view,
                    figure=fig,
                    axes=ax,
                )

        return fig