Coverage for nilearn/maskers/surface_masker.py: 16%
157 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Masker for surface objects."""
3from __future__ import annotations
5from copy import deepcopy
6from warnings import warn
8import numpy as np
9from sklearn.utils.estimator_checks import check_is_fitted
11from nilearn import DEFAULT_SEQUENTIAL_CMAP, signal
12from nilearn._utils import constrained_layout_kwargs, fill_doc
13from nilearn._utils.cache_mixin import cache
14from nilearn._utils.class_inspect import get_params
15from nilearn._utils.helpers import (
16 rename_parameters,
17)
18from nilearn._utils.logger import find_stack_level
19from nilearn._utils.masker_validation import (
20 check_compatibility_mask_and_images,
21)
22from nilearn._utils.param_validation import check_params
23from nilearn.image import concat_imgs, mean_img
24from nilearn.maskers.base_masker import _BaseSurfaceMasker
25from nilearn.surface.surface import SurfaceImage, at_least_2d, check_surf_img
26from nilearn.surface.utils import check_polymesh_equal
@fill_doc
class SurfaceMasker(_BaseSurfaceMasker):
    """Extract data from a :obj:`~nilearn.surface.SurfaceImage`.

    .. versionadded:: 0.11.0

    Parameters
    ----------
    mask_img : :obj:`~nilearn.surface.SurfaceImage` or None, default=None
        Mask to apply to the data.
        If None, a mask is computed when calling ``fit`` from the images
        passed to it: vertices having a non-finite value (NaN or infinite)
        in any sample are masked out.

    %(smoothing_fwhm)s
        This parameter is not implemented yet.

    %(standardize_maskers)s

    %(standardize_confounds)s

    %(detrend)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(memory)s

    %(memory_level1)s

    %(verbose0)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="inferno"
        Only relevant for the report figures.

    %(clean_args)s

    Attributes
    ----------
    mask_img_ : A 1D binary :obj:`~nilearn.surface.SurfaceImage`
        The mask of the data, or the one computed from ``imgs`` passed to fit.
        If a ``mask_img`` is passed at masker construction,
        then ``mask_img_`` is the resulting binarized version of it
        where each vertex is ``True`` if all values across samples
        (for example across timepoints) are finite values different from 0.

    n_elements_ : :obj:`int` or None
        Number of vertices included in the mask.

    """

    def __init__(
        self,
        mask_img=None,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        detrend=False,
        high_variance_confounds=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        memory=None,
        memory_level=1,
        verbose=0,
        reports=True,
        cmap=DEFAULT_SEQUENTIAL_CMAP,
        clean_args=None,
    ):
        self.mask_img = mask_img
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose
        self.reports = reports
        self.cmap = cmap
        self.clean_args = clean_args
        self._shelving = False
        # content to inject in the HTML template
        self._report_content = {
            "description": (
                "This report shows the input surface image overlaid "
                "with the outlines of the mask. "
                "We recommend to inspect the report for the overlap "
                "between the mask and its input image. "
            ),
            "n_vertices": {},
            # unused but required in HTML template
            "number_of_regions": None,
            "summary": None,
            "warning_message": None,
            "n_elements": 0,
            "coverage": 0,
        }
        # data necessary to construct figure for the report;
        # only populated by ``fit`` when ``reports`` is True
        self._reporting_data = None

    def __sklearn_is_fitted__(self):
        return (
            hasattr(self, "mask_img_")
            and hasattr(self, "n_elements_")
            and self.mask_img_ is not None
            and self.n_elements_ is not None
        )

    def _fit_mask_img(self, img):
        """Get mask passed during init or compute one from input image.

        If a mask was given at construction it is used as is, and a warning
        is emitted when ``img`` is also provided. Otherwise a mask is
        computed from ``img``: a vertex is kept only if all its values
        across samples are finite.

        Parameters
        ----------
        img : SurfaceImage object or :obj:`list` or :obj:`tuple` of \
              SurfaceImage or None

        Raises
        ------
        ValueError
            If no mask was given at construction and ``img`` is None.
        """
        self.mask_img_ = self._load_mask(img)

        if self.mask_img_ is not None:
            if img is not None:
                warn(
                    f"[{self.__class__.__name__}.fit] "
                    "Generation of a mask has been"
                    " requested (imgs != None) while a mask was"
                    " given at masker creation. Given mask"
                    " will be used.",
                    stacklevel=find_stack_level(),
                )
            return

        if img is None:
            raise ValueError(
                "Parameter 'imgs' must be provided to "
                f"{self.__class__.__name__}.fit() "
                "if no mask is passed to mask_img."
            )

        # Work on a copy so the caller's image(s) cannot be altered by the
        # steps below (presumably at_least_2d may act in place -- verify).
        img = deepcopy(img)
        # fit documents lists AND tuples of images as valid input.
        img = list(img) if isinstance(img, (list, tuple)) else [img]
        img = concat_imgs(img)
        img = at_least_2d(img)
        check_surf_img(img)

        # Data is (n_vertices, n_samples): keep a vertex only if all of its
        # values across samples are finite (neither NaN nor infinite).
        mask_data = {
            part: np.isfinite(values.astype("float32")).all(axis=1)
            for part, values in img.data.parts.items()
        }
        # Warn once, whichever part(s) contain non-finite values.
        if not all(part_mask.all() for part_mask in mask_data.values()):
            warn(
                "Non-finite values detected in the input image. "
                "The computed mask will mask out these vertices.",
                stacklevel=find_stack_level(),
            )
        self.mask_img_ = SurfaceImage(mesh=img.mesh, data=mask_data)

    @rename_parameters(
        replacement_params={"img": "imgs"}, end_version="0.13.2"
    )
    @fill_doc
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from surface data.

        Loads the mask given at construction, or computes one from ``imgs``,
        then records which range of output columns belongs to each part of
        the mesh.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` or \
        :obj:`list` of :obj:`~nilearn.surface.SurfaceImage` or \
        :obj:`tuple` of :obj:`~nilearn.surface.SurfaceImage` or None, \
        default = None
            Mesh and data for both hemispheres.

        %(y_dummy)s

        Returns
        -------
        SurfaceMasker object
        """
        del y
        check_params(self.__dict__)
        if imgs is not None:
            self._check_imgs(imgs)

        self._fit_mask_img(imgs)
        assert self.mask_img_ is not None

        # Each mesh part owns a contiguous [start, stop) range of columns in
        # the (n_samples, n_elements_) array produced by transform.
        start = 0
        self._slices = {}
        for part_name, mask in self.mask_img_.data.parts.items():
            stop = start + mask.sum()
            self._slices[part_name] = start, stop
            start = stop
        self.n_elements_ = int(start)

        if self.reports:
            self._report_content["n_elements"] = self.n_elements_
            for part in self.mask_img_.data.parts:
                self._report_content["n_vertices"][part] = (
                    self.mask_img_.mesh.parts[part].n_vertices
                )
            # share of mesh vertices kept by the mask, in percent
            self._report_content["coverage"] = (
                self.n_elements_ / self.mask_img_.mesh.n_vertices * 100
            )
            self._reporting_data = {
                "mask": self.mask_img_,
                "images": imgs,
            }

        if self.clean_args is None:
            self.clean_args_ = {}
        else:
            self.clean_args_ = self.clean_args

        return self

    @fill_doc
    def transform_single_imgs(
        self,
        imgs,
        confounds=None,
        sample_mask=None,
    ):
        """Extract signals from fitted surface object.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or \
               iterable of :obj:`~nilearn.surface.SurfaceImage`
            Images to process.
            Mesh and data for both hemispheres/parts.

        %(confounds)s

        %(sample_mask)s

        Returns
        -------
        %(signals_transform_surface)s

        """
        check_is_fitted(self)

        parameters = get_params(
            self.__class__,
            self,
            ignore=[
                "mask_img",
            ],
        )

        parameters["clean_args"] = self.clean_args_

        check_compatibility_mask_and_images(self.mask_img_, imgs)

        check_polymesh_equal(self.mask_img_.mesh, imgs.mesh)

        # _reporting_data only exists if reports was enabled during fit;
        # do not crash if reports was switched on afterwards.
        if self.reports and self._reporting_data is not None:
            self._reporting_data["images"] = imgs

        # 1D data is a single sample; 2D data is (n_vertices, n_samples).
        n_samples = imgs.shape[1] if len(imgs.shape) == 2 else 1
        output = np.empty((n_samples, self.n_elements_))
        for part_name, (start, stop) in self._slices.items():
            mask = self.mask_img_.data.parts[part_name].ravel()
            output[:, start:stop] = imgs.data.parts[part_name][mask].T

        # signal cleaning here
        output = cache(
            signal.clean,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(
            output,
            detrend=parameters["detrend"],
            standardize=parameters["standardize"],
            standardize_confounds=parameters["standardize_confounds"],
            t_r=parameters["t_r"],
            low_pass=parameters["low_pass"],
            high_pass=parameters["high_pass"],
            confounds=confounds,
            sample_mask=sample_mask,
            **parameters["clean_args"],
        )

        return output

    @fill_doc
    def inverse_transform(self, signals):
        """Transform extracted signal back to surface object.

        Vertices outside the mask are filled with zeros.

        Parameters
        ----------
        %(signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_surface)s
        """
        check_is_fitted(self)

        # np.ndim also works on array-likes (e.g. lists) that have no
        # ``.ndim`` attribute before _check_array converts them.
        return_1D = np.ndim(signals) < 2

        # do not run sklearn_check as they may cause some failure
        # with some GLM inputs
        signals = self._check_array(signals, sklearn_check=False)

        data = {}
        for part_name, mask in self.mask_img_.data.parts.items():
            start, stop = self._slices[part_name]
            part_data = np.zeros(
                (mask.shape[0], signals.shape[0]),
                dtype=signals.dtype,
            )
            part_data[mask.ravel()] = signals[:, start:stop].T
            if return_1D:
                part_data = part_data.squeeze()
            data[part_name] = part_data

        return SurfaceImage(mesh=self.mask_img_.mesh, data=data)

    def generate_report(self):
        """Generate a report for the SurfaceMasker.

        Returns
        -------
        list(None) or HTMLReport
        """
        from nilearn.reporting.html_report import generate_report

        return generate_report(self)

    def _reporting(self):
        """Load displays needed for report.

        Returns
        -------
        displays : :obj:`list` of None or bytes
            A list of all displays figures encoded as bytes to be rendered.
            Or a list with a single None element.
        """
        # avoid circular import
        import matplotlib.pyplot as plt

        from nilearn.reporting.utils import figure_to_png_base64

        # Handle the edge case where this function is
        # called with a masker having report capabilities disabled
        if self._reporting_data is None:
            return [None]

        fig = self._create_figure_for_report()

        if fig is None:
            return [None]

        # Close the figure we created, not whatever pyplot considers current.
        plt.close(fig)

        init_display = figure_to_png_base64(fig)

        return [init_display]

    def _create_figure_for_report(self):
        """Generate figure to include in the report.

        The mean of the last processed images (or the mask itself if no
        image is available) is plotted for each hemisphere and view, with
        the mask outline drawn on top.

        Returns
        -------
        None, :class:`~matplotlib.figure.Figure` or\
        :class:`~nilearn.plotting.displays.PlotlySurfaceFigure`
            Returns ``None`` in case the masker was not fitted.
        """
        # avoid circular import
        import matplotlib.pyplot as plt

        from nilearn.plotting import plot_surf, plot_surf_contours

        if not self._reporting_data["images"] and not getattr(
            self, "mask_img_", None
        ):
            return None

        background_data = self.mask_img_
        vmin = None
        vmax = None
        if self._reporting_data["images"]:
            background_data = mean_img(self._reporting_data["images"])
            vmin, vmax = background_data.data._get_min_max()

        views = ["lateral", "medial"]
        hemispheres = ["left", "right"]

        # Contour colors only depend on the hemisphere: compute them once.
        contour_colors = {}
        for hemi in hemispheres:
            n_regions = len(np.unique(self.mask_img_.data.parts[hemi]))
            if n_regions == 1:
                contour_colors[hemi] = "b"
            elif n_regions == 2:
                contour_colors[hemi] = ["w", "b"]
            else:
                contour_colors[hemi] = None

        fig, axes = plt.subplots(
            len(views),
            len(hemispheres),
            subplot_kw={"projection": "3d"},
            figsize=(20, 20),
            **constrained_layout_kwargs(),
        )
        axes = np.atleast_2d(axes)

        for ax_row, view in zip(axes, views):
            for ax, hemi in zip(ax_row, hemispheres):
                plot_surf(
                    surf_map=background_data,
                    hemi=hemi,
                    view=view,
                    figure=fig,
                    axes=ax,
                    cmap=self.cmap,
                    vmin=vmin,
                    vmax=vmax,
                )

                plot_surf_contours(
                    roi_map=self.mask_img_,
                    hemi=hemi,
                    view=view,
                    figure=fig,
                    axes=ax,
                    colors=contour_colors[hemi],
                )

        return fig