Coverage for nilearn/maskers/surface_maps_masker.py: 12%

178 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Extract data from a SurfaceImage, using maps of potentially overlapping 

2brain regions. 

3""" 

4 

5import warnings 

6 

7import numpy as np 

8from scipy import linalg 

9from sklearn.utils.estimator_checks import check_is_fitted 

10 

11from nilearn import DEFAULT_SEQUENTIAL_CMAP, signal 

12from nilearn._utils import fill_doc, logger 

13from nilearn._utils.cache_mixin import cache 

14from nilearn._utils.class_inspect import get_params 

15from nilearn._utils.helpers import ( 

16 constrained_layout_kwargs, 

17 is_matplotlib_installed, 

18 is_plotly_installed, 

19 rename_parameters, 

20) 

21from nilearn._utils.logger import find_stack_level 

22from nilearn._utils.masker_validation import ( 

23 check_compatibility_mask_and_images, 

24) 

25from nilearn._utils.param_validation import check_params 

26from nilearn.image import index_img, mean_img 

27from nilearn.maskers.base_masker import _BaseSurfaceMasker 

28from nilearn.surface.surface import ( 

29 SurfaceImage, 

30 at_least_2d, 

31 check_surf_img, 

32 get_data, 

33) 

34from nilearn.surface.utils import check_polymesh_equal 

35 

36 

@fill_doc
class SurfaceMapsMasker(_BaseSurfaceMasker):
    """Extract data from a SurfaceImage, using maps of potentially overlapping
    brain regions.

    .. versionadded:: 0.11.1

    Parameters
    ----------
    maps_img : :obj:`~nilearn.surface.SurfaceImage`
        Set of maps that define the regions. A representative time course \
        per map is extracted using least square regression. The data for \
        each hemisphere is of shape (n_vertices_per_hemisphere, n_regions).

    mask_img : :obj:`~nilearn.surface.SurfaceImage`, optional, default=None
        Mask to apply to regions before extracting signals. Defines the \
        overall area of the brain to consider. The data for each \
        hemisphere is 1D, of shape (n_vertices_per_hemisphere,).

    allow_overlap : :obj:`bool`, default=True
        If False, an error is raised if the maps overlap (i.e. at least two
        maps have a non-zero value for the same vertex).

    %(smoothing_fwhm)s
        This parameter is not implemented yet.

    %(standardize_maskers)s

    %(standardize_confounds)s

    %(detrend)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image \
        with :func:`nilearn.image.high_variance_confounds` and default \
        parameters and regressed out.

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(memory)s

    %(memory_level1)s

    %(verbose0)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="inferno"
        Only relevant for the report figures.

    %(clean_args)s

    Attributes
    ----------
    maps_img_ : :obj:`~nilearn.surface.SurfaceImage`
        The same as the input `maps_img`, kept solely for consistency
        across maskers.

    mask_img_ : A 1D binary :obj:`~nilearn.surface.SurfaceImage` or None.
        The mask of the data.
        If no ``mask_img`` was passed at masker construction,
        then ``mask_img_`` is ``None``, otherwise
        is the resulting binarized version of ``mask_img``
        where each vertex is ``True`` if all values across samples
        (for example across timepoints) are finite values different from 0.

    n_elements_ : :obj:`int`
        The number of regions in the maps image.

    clean_args_ : :obj:`dict`
        The ``clean_args`` passed at construction, or an empty dict if
        ``clean_args`` was None.

    See Also
    --------
    nilearn.maskers.SurfaceMasker
    nilearn.maskers.SurfaceLabelsMasker

    """

    def __init__(
        self,
        maps_img=None,
        mask_img=None,
        allow_overlap=True,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        detrend=False,
        high_variance_confounds=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        memory=None,
        memory_level=1,
        verbose=0,
        reports=True,
        cmap=DEFAULT_SEQUENTIAL_CMAP,
        clean_args=None,
    ):
        self.maps_img = maps_img
        self.mask_img = mask_img
        self.allow_overlap = allow_overlap
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose
        self.reports = reports
        self.cmap = cmap
        self.clean_args = clean_args

    @fill_doc
    @rename_parameters(
        replacement_params={"img": "imgs"}, end_version="0.13.2"
    )
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or None, \
               default=None
            Optional images. If provided, they are validated and passed
            to the mask loading step.

        %(y_dummy)s

        Returns
        -------
        SurfaceMapsMasker object

        Raises
        ------
        ValueError
            If ``maps_img`` was not provided at initialization, or if
            ``allow_overlap`` is False and at least two maps have a
            non-zero value on the same vertex.
        """
        del y
        check_params(self.__dict__)
        if imgs is not None:
            self._check_imgs(imgs)

        if self.maps_img is None:
            raise ValueError(
                "Please provide a maps_img during initialization. "
                "For example, masker = SurfaceMapsMasker(maps_img=maps_img)"
            )

        if imgs is not None:
            check_surf_img(imgs)

        logger.log(
            msg=f"loading regions from {self.maps_img.__repr__()}",
            verbose=self.verbose,
        )
        # check maps_img data is 2D
        self.maps_img.data._check_ndims(2, "maps_img")

        # Honor the documented ``allow_overlap=False`` contract. Done before
        # any fitted attribute is set so a failed fit does not look fitted.
        if not self.allow_overlap:
            # concatenated data of all parts: (n_vertices, n_regions)
            all_maps_data = get_data(self.maps_img)
            if np.any(np.count_nonzero(all_maps_data, axis=1) > 1):
                raise ValueError(
                    "Overlap detected in the maps: at least two maps have "
                    "a non-zero value for the same vertex. "
                    "Set allow_overlap=True to allow overlapping maps."
                )

        self.maps_img_ = self.maps_img

        self.n_elements_ = self.maps_img.shape[1]

        self.mask_img_ = self._load_mask(imgs)
        if self.mask_img_ is not None:
            check_polymesh_equal(self.maps_img.mesh, self.mask_img_.mesh)

        self._shelving = False

        # Must be set before the early return below: transform needs
        # clean_args_ even when reports are disabled.
        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        # initialize reporting content and data
        if not self.reports:
            self._reporting_data = None
            return self

        # content to inject in the HTML template
        self._report_content = {
            "description": (
                "This report shows the input surface image "
                "(if provided via img) overlaid with the regions provided "
                "via maps_img."
            ),
            "n_vertices": {},
            "number_of_regions": self.n_elements_,
            "summary": {},
            "warning_message": None,
        }

        for part in self.maps_img.data.parts:
            self._report_content["n_vertices"][part] = (
                self.maps_img.mesh.parts[part].n_vertices
            )

        self._reporting_data = {
            "maps_img": self.maps_img_,
            "mask": self.mask_img_,
            "images": None,  # we will update image in transform
        }

        return self

    def __sklearn_is_fitted__(self):
        return hasattr(self, "n_elements_")

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from surface object.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or \
              iterable of :obj:`~nilearn.surface.SurfaceImage`
            Images to process.
            Mesh and data for both hemispheres/parts.

        %(confounds)s

        %(sample_mask)s

        Returns
        -------
        %(signals_transform_surface)s
        """
        check_is_fitted(self)

        check_compatibility_mask_and_images(self.maps_img, imgs)

        check_polymesh_equal(self.maps_img.mesh, imgs.mesh)

        imgs = at_least_2d(imgs)

        # concatenate all parts along the vertex axis:
        # (n_vertices, n_samples)
        img_data = np.concatenate(
            list(imgs.data.parts.values()), axis=0
        ).astype(np.float32)

        # get concatenated hemispheres/parts data from maps_img and mask_img
        maps_data = get_data(self.maps_img)
        mask_data = (
            get_data(self.mask_img_) if self.mask_img_ is not None else None
        )

        # restrict both maps and data to the vertices inside the mask
        if mask_data is not None:
            in_mask = mask_data.flatten()
            maps_data = maps_data[in_mask, :]
            img_data = img_data[in_mask, :]

        # least squares: maps_data @ X ~= img_data, X is
        # (n_regions, n_samples); transpose to (n_samples, n_regions)
        region_signals = cache(
            linalg.lstsq,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(maps_data, img_data)[0].T

        parameters = get_params(
            self.__class__,
            self,
        )
        parameters["clean_args"] = self.clean_args_

        # signal cleaning here
        region_signals = cache(
            signal.clean,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(
            region_signals,
            detrend=parameters["detrend"],
            standardize=parameters["standardize"],
            standardize_confounds=parameters["standardize_confounds"],
            t_r=parameters["t_r"],
            low_pass=parameters["low_pass"],
            high_pass=parameters["high_pass"],
            confounds=confounds,
            sample_mask=sample_mask,
            **parameters["clean_args"],
        )

        return region_signals

    @fill_doc
    def inverse_transform(self, region_signals):
        """Compute :term:`vertex` signals from region signals.

        Parameters
        ----------
        %(region_signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_surface)s

        Raises
        ------
        ValueError
            If the number of columns of ``region_signals`` does not match
            the number of regions in the maps.
        """
        check_is_fitted(self)

        # np.ndim also works for array-likes (e.g. lists)
        return_1D = np.ndim(region_signals) < 2

        region_signals = self._check_array(region_signals)

        # get concatenated hemispheres/parts data from maps_img and the
        # fitted (binarized) mask, consistently with transform_single_imgs
        maps_data = get_data(self.maps_img)
        mask_data = (
            get_data(self.mask_img_) if self.mask_img_ is not None else None
        )
        if region_signals.shape[1] != self.n_elements_:
            raise ValueError(
                f"Expected {self.n_elements_} regions, "
                f"but got {region_signals.shape[1]}."
            )

        logger.log("computing image from signals", verbose=self.verbose)
        # project region signals back to vertices
        if mask_data is not None:
            in_mask = mask_data.flatten()
            # vertices that are not in the mask will have a signal of 0
            # so we initialize the vertex signals with 0
            # and shape (n_timepoints, n_vertices)
            vertex_signals = np.zeros(
                (region_signals.shape[0], self.maps_img.mesh.n_vertices)
            )
            # dot product between (n_timepoints, n_regions) and
            # (n_regions, n_vertices)
            vertex_signals[:, in_mask] = np.dot(
                region_signals, maps_data[in_mask, :].T
            )
        else:
            vertex_signals = np.dot(region_signals, maps_data.T)

        # we need the data to be of shape (n_vertices, n_timepoints)
        # because the SurfaceImage object expects it
        vertex_signals = vertex_signals.T

        # split the signal back into parts, in the same order the parts
        # were concatenated, instead of hard-coding "left"/"right"
        split_signals = {}
        start = 0
        for part, part_data in self.maps_img.data.parts.items():
            stop = start + part_data.shape[0]
            split_signals[part] = vertex_signals[start:stop, :]
            start = stop

        imgs = SurfaceImage(mesh=self.maps_img.mesh, data=split_signals)

        if return_1D:
            for k, v in imgs.data.parts.items():
                # drop only the singleton sample axis
                imgs.data.parts[k] = np.squeeze(v, axis=-1)

        return imgs

    def generate_report(self, displayed_maps=10, engine="matplotlib"):
        """Generate an HTML report for the current ``SurfaceMapsMasker``
        object.

        .. note::
            This functionality requires to have ``Matplotlib`` installed.

        Parameters
        ----------
        displayed_maps : :obj:`int`, or :obj:`list`, \
                         or :class:`~numpy.ndarray`, or "all", default=10
            Indicates which maps will be displayed in the HTML report.

                - If "all": All maps will be displayed in the report.

                .. code-block:: python

                    masker.generate_report("all")

                .. warning::
                    If there are too many maps, this might be time and
                    memory consuming, and will result in very heavy
                    reports.

                - If a :obj:`list` or :class:`~numpy.ndarray`: This indicates
                  the indices of the maps to be displayed in the report. For
                  example, the following code will generate a report with maps
                  6, 3, and 12, displayed in this specific order:

                .. code-block:: python

                    masker.generate_report([6, 3, 12])

                - If an :obj:`int`: This will only display the first n maps,
                  n being the value of the parameter. By default, the report
                  will only contain the first 10 maps. Example to display the
                  first 16 maps:

                .. code-block:: python

                    masker.generate_report(16)

        engine : :obj:`str`, default="matplotlib"
            The plotting engine to use for the report. Can be either
            "matplotlib" or "plotly". If "matplotlib" is selected, the report
            will be static. If "plotly" is selected, the report
            will be interactive. If the selected engine is not installed, the
            report will use the available plotting engine. If none of the
            engines are installed, no report will be generated.

        Returns
        -------
        report : `nilearn.reporting.html_report.HTMLReport`
            HTML report for the masker.

        Raises
        ------
        ValueError
            If ``engine`` is neither "matplotlib" nor "plotly".
        TypeError
            If ``displayed_maps`` is not "all", an int, or a list/array
            of ints.
        """
        # need to have matplotlib installed to generate reports no matter what
        # engine is selected
        from nilearn.reporting.html_report import generate_report

        if not is_matplotlib_installed():
            return generate_report(self)

        if engine not in ["plotly", "matplotlib"]:
            raise ValueError(
                "Parameter ``engine`` should be either 'matplotlib' or "
                "'plotly'."
            )

        # switch to matplotlib if plotly is selected but not installed
        if engine == "plotly" and not is_plotly_installed():
            engine = "matplotlib"
            warnings.warn(
                "Plotly is not installed. "
                "Switching to matplotlib for report generation.",
                stacklevel=find_stack_level(),
            )
        if hasattr(self, "_report_content"):
            self._report_content["engine"] = engine

        incorrect_type = not isinstance(
            displayed_maps, (list, np.ndarray, int, str)
        )
        incorrect_string = (
            isinstance(displayed_maps, str) and displayed_maps != "all"
        )
        # accept any integer dtype (e.g. int32 arrays), but not bools
        # or floats
        not_integer = not isinstance(displayed_maps, str) and not (
            np.issubdtype(np.array(displayed_maps).dtype, np.integer)
        )
        if incorrect_type or incorrect_string or not_integer:
            raise TypeError(
                "Parameter ``displayed_maps`` of "
                "``generate_report()`` should be either 'all' or "
                "an int, or a list/array of ints. You provided a "
                f"{type(displayed_maps)}"
            )

        self.displayed_maps = displayed_maps

        return generate_report(self)

    def _reporting(self):
        """Load displays needed for report.

        Returns
        -------
        displays : list
            A list of all displays to be rendered.

        Raises
        ------
        ValueError
            If ``displayed_maps`` is a list/array containing an index
            that is out of range for the maps image.
        """
        import matplotlib.pyplot as plt

        from nilearn.reporting.utils import figure_to_png_base64

        # Handle the edge case where this function is
        # called with a masker having report capabilities disabled
        if self._reporting_data is None:
            return [None]

        maps_img = self._reporting_data["maps_img"]

        img = self._reporting_data["images"]
        if img:
            img = mean_img(img)

        n_maps = self.maps_img_.shape[1]
        maps_to_be_displayed = range(n_maps)
        if isinstance(self.displayed_maps, int):
            if n_maps < self.displayed_maps:
                msg = (
                    "`generate_report()` received "
                    f"{self.displayed_maps} maps to be displayed. "
                    f"But masker only has {n_maps} maps. "
                    f"Setting number of displayed maps to {n_maps}."
                )
                warnings.warn(
                    category=UserWarning,
                    message=msg,
                    stacklevel=find_stack_level(),
                )
                self.displayed_maps = n_maps
            maps_to_be_displayed = range(self.displayed_maps)

        elif isinstance(self.displayed_maps, (list, np.ndarray)):
            # indices are 0-based, so n_maps itself is already out of range
            if max(self.displayed_maps) >= n_maps:
                raise ValueError(
                    "Report cannot display the following maps "
                    f"{self.displayed_maps} because "
                    f"masker only has {n_maps} maps."
                )
            maps_to_be_displayed = self.displayed_maps

        self._report_content["number_of_maps"] = n_maps
        self._report_content["displayed_maps"] = list(maps_to_be_displayed)
        embeded_images = []

        if img is None:
            msg = (
                "SurfaceMapsMasker has not been transformed (via transform() "
                "method) on any image yet. Plotting only maps for reporting."
            )
            warnings.warn(msg, stacklevel=find_stack_level())

        for roi in maps_to_be_displayed:
            roi = index_img(maps_img, roi)
            fig = self._create_figure_for_report(roi=roi, bg_img=img)
            if self._report_content["engine"] == "plotly":
                embeded_images.append(fig)
            elif self._report_content["engine"] == "matplotlib":
                embeded_images.append(figure_to_png_base64(fig))
                plt.close()

        return embeded_images

    def _create_figure_for_report(self, roi, bg_img):
        """Create a figure of maps image, one region at a time.

        If transform() was applied to an image, this image is used as
        background on which the maps are plotted.

        Returns an iframe (plotly engine) or a matplotlib figure
        (matplotlib engine), depending on ``self._report_content["engine"]``.
        """
        import matplotlib.pyplot as plt

        from nilearn.plotting import plot_surf, view_surf

        # very low threshold to only make 0 values transparent
        threshold = 1e-6
        if self._report_content["engine"] == "plotly":
            # squeeze the last dimension
            for part in roi.data.parts:
                roi.data.parts[part] = np.squeeze(
                    roi.data.parts[part], axis=-1
                )
            fig = view_surf(
                surf_map=roi,
                bg_map=bg_img,
                bg_on_data=True,
                threshold=threshold,
                hemi="both",
                cmap=self.cmap,
            ).get_iframe(width=500)
        elif self._report_content["engine"] == "matplotlib":
            # TODO: possibly allow to generate a report with other views
            views = ["lateral", "medial"]
            hemispheres = ["left", "right"]
            fig, axes = plt.subplots(
                len(views),
                len(hemispheres),
                subplot_kw={"projection": "3d"},
                figsize=(20, 20),
                **constrained_layout_kwargs(),
            )
            axes = np.atleast_2d(axes)
            for ax_row, view in zip(axes, views):
                for ax, hemi in zip(ax_row, hemispheres):
                    plot_surf(
                        surf_map=roi,
                        bg_map=bg_img,
                        hemi=hemi,
                        view=view,
                        figure=fig,
                        axes=ax,
                        cmap=self.cmap,
                        colorbar=False,
                        threshold=threshold,
                        bg_on_data=True,
                    )
        return fig