Coverage for nilearn/maskers/surface_masker.py: 16%

157 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 10:58 +0200

1"""Masker for surface objects.""" 

2 

3from __future__ import annotations 

4 

5from copy import deepcopy 

6from warnings import warn 

7 

8import numpy as np 

9from sklearn.utils.estimator_checks import check_is_fitted 

10 

11from nilearn import DEFAULT_SEQUENTIAL_CMAP, signal 

12from nilearn._utils import constrained_layout_kwargs, fill_doc 

13from nilearn._utils.cache_mixin import cache 

14from nilearn._utils.class_inspect import get_params 

15from nilearn._utils.helpers import ( 

16 rename_parameters, 

17) 

18from nilearn._utils.logger import find_stack_level 

19from nilearn._utils.masker_validation import ( 

20 check_compatibility_mask_and_images, 

21) 

22from nilearn._utils.param_validation import check_params 

23from nilearn.image import concat_imgs, mean_img 

24from nilearn.maskers.base_masker import _BaseSurfaceMasker 

25from nilearn.surface.surface import SurfaceImage, at_least_2d, check_surf_img 

26from nilearn.surface.utils import check_polymesh_equal 

27 

28 

@fill_doc
class SurfaceMasker(_BaseSurfaceMasker):
    """Extract data from a :obj:`~nilearn.surface.SurfaceImage`.

    .. versionadded:: 0.11.0

    Parameters
    ----------
    mask_img : :obj:`~nilearn.surface.SurfaceImage` or None, default=None
        Mask defining the vertices from which data is extracted.
        If ``None``, a mask is computed from the images passed to
        :meth:`fit`.

    %(smoothing_fwhm)s
        This parameter is not implemented yet.

    %(standardize_maskers)s

    %(standardize_confounds)s

    %(detrend)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(memory)s

    %(memory_level1)s

    %(verbose0)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="inferno"
        Only relevant for the report figures.

    %(clean_args)s

    Attributes
    ----------
    mask_img_ : A 1D binary :obj:`~nilearn.surface.SurfaceImage`
        The mask of the data.
        If a ``mask_img`` is passed at masker construction,
        then ``mask_img_`` is the resulting binarized version of it.
        Otherwise it is computed from the ``imgs`` passed to :meth:`fit`:
        each vertex is ``True`` if all its values across samples
        (for example across timepoints) are finite.

    n_elements_ : :obj:`int` or None
        Number of vertices included in the mask.

    """

    def __init__(
        self,
        mask_img=None,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        detrend=False,
        high_variance_confounds=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        memory=None,
        memory_level=1,
        verbose=0,
        reports=True,
        cmap=DEFAULT_SEQUENTIAL_CMAP,
        clean_args=None,
    ):
        self.mask_img = mask_img
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose
        self.reports = reports
        self.cmap = cmap
        self.clean_args = clean_args
        self._shelving = False
        # content to inject in the HTML template
        self._report_content = {
            "description": (
                "This report shows the input surface image overlaid "
                "with the outlines of the mask. "
                "We recommend to inspect the report for the overlap "
                "between the mask and its input image. "
            ),
            "n_vertices": {},
            # unused but required in HTML template
            "number_of_regions": None,
            "summary": None,
            "warning_message": None,
            "n_elements": 0,
            "coverage": 0,
        }
        # data necessary to construct figure for the report
        self._reporting_data = None

    def __sklearn_is_fitted__(self):
        """Return True once ``fit`` has set both the mask and its size."""
        return (
            hasattr(self, "mask_img_")
            and hasattr(self, "n_elements_")
            and self.mask_img_ is not None
            and self.n_elements_ is not None
        )

    def _fit_mask_img(self, img):
        """Get mask passed during init or compute one from input image.

        Sets ``self.mask_img_``.

        Parameters
        ----------
        img : SurfaceImage object or :obj:`list` of SurfaceImage or None
            Image(s) used to compute the mask when no ``mask_img``
            was given at construction.

        Raises
        ------
        ValueError
            If no mask was given at construction and ``img`` is None.
        """
        self.mask_img_ = self._load_mask(img)

        if self.mask_img_ is not None:
            if img is not None:
                warn(
                    f"[{self.__class__.__name__}.fit] "
                    "Generation of a mask has been"
                    " requested (imgs != None) while a mask was"
                    " given at masker creation. Given mask"
                    " will be used.",
                    stacklevel=find_stack_level(),
                )
            return

        if img is None:
            raise ValueError(
                "Parameter 'imgs' must be provided to "
                f"{self.__class__.__name__}.fit() "
                "if no mask is passed to mask_img."
            )

        # Deep-copy the input so the following steps work on a copy
        # rather than on the caller's object(s).
        img = deepcopy(img)
        if not isinstance(img, list):
            img = [img]
        img = concat_imgs(img)

        img = at_least_2d(img)

        check_surf_img(img)

        mask_data = {}
        for part, values in img.data.parts.items():
            values = np.asarray(values)
            # NaN / inf can only occur in inexact (float / complex) dtypes.
            # Other dtypes (int, bool, object) are cast to float64 so that
            # isfinite accepts them. Floating data is checked as-is: a
            # downcast (e.g. to float32) would turn large finite values
            # into inf and wrongly mask those vertices out.
            if not np.issubdtype(values.dtype, np.inexact):
                values = values.astype(np.float64)
            # mask out vertices with NaN or infinite values in any sample
            mask_data[part] = np.isfinite(values).all(axis=1)
            if not mask_data[part].all():
                warn(
                    "Non-finite values detected in the input image. "
                    "The computed mask will mask out these vertices.",
                    stacklevel=find_stack_level(),
                )
        self.mask_img_ = SurfaceImage(mesh=img.mesh, data=mask_data)

    @rename_parameters(
        replacement_params={"img": "imgs"}, end_version="0.13.2"
    )
    @fill_doc
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` or \
               :obj:`list` of :obj:`~nilearn.surface.SurfaceImage` or \
               :obj:`tuple` of :obj:`~nilearn.surface.SurfaceImage` or None, \
               default = None
            Mesh and data for both hemispheres.

        %(y_dummy)s

        Returns
        -------
        SurfaceMasker object
        """
        del y
        check_params(self.__dict__)
        if imgs is not None:
            self._check_imgs(imgs)

        self._fit_mask_img(imgs)
        assert self.mask_img_ is not None

        # Each mesh part owns a contiguous [start, stop) slice of the
        # flattened output, sized by the number of True vertices in its mask.
        start, stop = 0, 0
        self._slices = {}
        for part_name, mask in self.mask_img_.data.parts.items():
            stop = start + mask.sum()
            self._slices[part_name] = start, stop
            start = stop
        self.n_elements_ = int(stop)

        if self.reports:
            self._report_content["n_elements"] = self.n_elements_
            for part in self.mask_img_.data.parts:
                self._report_content["n_vertices"][part] = (
                    self.mask_img_.mesh.parts[part].n_vertices
                )
            self._report_content["coverage"] = (
                self.n_elements_ / self.mask_img_.mesh.n_vertices * 100
            )
            self._reporting_data = {
                "mask": self.mask_img_,
                "images": imgs,
            }

        if self.clean_args is None:
            self.clean_args_ = {}
        else:
            self.clean_args_ = self.clean_args

        return self

    @fill_doc
    def transform_single_imgs(
        self,
        imgs,
        confounds=None,
        sample_mask=None,
    ):
        """Extract signals from fitted surface object.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or \
              iterable of :obj:`~nilearn.surface.SurfaceImage`
            Images to process.
            Mesh and data for both hemispheres/parts.

        %(confounds)s

        %(sample_mask)s

        Returns
        -------
        %(signals_transform_surface)s

        """
        check_is_fitted(self)

        parameters = get_params(
            self.__class__,
            self,
            ignore=[
                "mask_img",
            ],
        )

        parameters["clean_args"] = self.clean_args_

        check_compatibility_mask_and_images(self.mask_img_, imgs)

        check_polymesh_equal(self.mask_img_.mesh, imgs.mesh)

        # _reporting_data is only created by fit() when reports=True; guard
        # against reports being switched on after a fit without reports.
        if self.reports and self._reporting_data is not None:
            self._reporting_data["images"] = imgs

        # A 1D image holds a single sample; a 2D image has samples on axis 1.
        n_samples = imgs.shape[1] if len(imgs.shape) == 2 else 1
        output = np.empty((n_samples, self.n_elements_))
        for part_name, (start, stop) in self._slices.items():
            mask = self.mask_img_.data.parts[part_name].ravel()
            output[:, start:stop] = imgs.data.parts[part_name][mask].T

        # signal cleaning here
        output = cache(
            signal.clean,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(
            output,
            detrend=parameters["detrend"],
            standardize=parameters["standardize"],
            standardize_confounds=parameters["standardize_confounds"],
            t_r=parameters["t_r"],
            low_pass=parameters["low_pass"],
            high_pass=parameters["high_pass"],
            confounds=confounds,
            sample_mask=sample_mask,
            **parameters["clean_args"],
        )

        return output

    @fill_doc
    def inverse_transform(self, signals):
        """Transform extracted signal back to surface object.

        Parameters
        ----------
        %(signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_surface)s
        """
        check_is_fitted(self)

        # np.ndim also accepts array-likes (e.g. lists) that lack ``.ndim``.
        return_1D = np.ndim(signals) < 2

        # do not run sklearn_check as they may cause some failure
        # with some GLM inputs
        signals = self._check_array(signals, sklearn_check=False)

        data = {}
        for part_name, mask in self.mask_img_.data.parts.items():
            # vertices outside the mask are filled with zeros
            data[part_name] = np.zeros(
                (mask.shape[0], signals.shape[0]),
                dtype=signals.dtype,
            )
            start, stop = self._slices[part_name]
            data[part_name][mask.ravel()] = signals[:, start:stop].T
            if return_1D:
                data[part_name] = data[part_name].squeeze()

        return SurfaceImage(mesh=self.mask_img_.mesh, data=data)

    def generate_report(self):
        """Generate a report for the SurfaceMasker.

        Returns
        -------
        list(None) or HTMLReport
        """
        from nilearn.reporting.html_report import generate_report

        return generate_report(self)

    def _reporting(self):
        """Load displays needed for report.

        Returns
        -------
        displays : :obj:`list` of None or bytes
            A list of all displays figures encoded as bytes to be rendered.
            Or a list with a single None element.
        """
        # avoid circular import
        import matplotlib.pyplot as plt

        from nilearn.reporting.utils import figure_to_png_base64

        # Handle the edge case where this function is
        # called with a masker having report capabilities disabled
        if self._reporting_data is None:
            return [None]

        fig = self._create_figure_for_report()

        if fig is None:
            return [None]

        # Close this specific figure (not merely the "current" one) so it
        # is released from pyplot's figure manager; the Figure object
        # itself can still be rendered afterwards.
        plt.close(fig)

        init_display = figure_to_png_base64(fig)

        return [init_display]

    def _create_figure_for_report(self):
        """Generate figure to include in the report.

        Returns
        -------
        None, :class:`~matplotlib.figure.Figure` or\
              :class:`~nilearn.plotting.displays.PlotlySurfaceFigure`
            Returns ``None`` in case the masker was not fitted.
        """
        # avoid circular import
        import matplotlib.pyplot as plt

        from nilearn.plotting import plot_surf, plot_surf_contours

        if not self._reporting_data["images"] and not getattr(
            self, "mask_img_", None
        ):
            return None

        # Without input images, the mask itself is used as background.
        background_data = self.mask_img_
        vmin = None
        vmax = None
        if self._reporting_data["images"]:
            background_data = self._reporting_data["images"]
            background_data = mean_img(background_data)
            vmin, vmax = background_data.data._get_min_max()

        views = ["lateral", "medial"]
        hemispheres = ["left", "right"]

        fig, axes = plt.subplots(
            len(views),
            len(hemispheres),
            subplot_kw={"projection": "3d"},
            figsize=(20, 20),
            **constrained_layout_kwargs(),
        )
        axes = np.atleast_2d(axes)

        for ax_row, view in zip(axes, views):
            for ax, hemi in zip(ax_row, hemispheres):
                plot_surf(
                    surf_map=background_data,
                    hemi=hemi,
                    view=view,
                    figure=fig,
                    axes=ax,
                    cmap=self.cmap,
                    vmin=vmin,
                    vmax=vmax,
                )

                # pick contour colors from the number of distinct mask values
                colors = None
                n_regions = len(np.unique(self.mask_img_.data.parts[hemi]))
                if n_regions == 1:
                    colors = "b"
                elif n_regions == 2:
                    colors = ["w", "b"]

                plot_surf_contours(
                    roi_map=self.mask_img_,
                    hemi=hemi,
                    view=view,
                    figure=fig,
                    axes=ax,
                    colors=colors,
                )

        return fig