Coverage for nilearn/maskers/surface_maps_masker.py: 12%
178 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Extract data from a SurfaceImage, using maps of potentially overlapping
2brain regions.
3"""
5import warnings
7import numpy as np
8from scipy import linalg
9from sklearn.utils.estimator_checks import check_is_fitted
11from nilearn import DEFAULT_SEQUENTIAL_CMAP, signal
12from nilearn._utils import fill_doc, logger
13from nilearn._utils.cache_mixin import cache
14from nilearn._utils.class_inspect import get_params
15from nilearn._utils.helpers import (
16 constrained_layout_kwargs,
17 is_matplotlib_installed,
18 is_plotly_installed,
19 rename_parameters,
20)
21from nilearn._utils.logger import find_stack_level
22from nilearn._utils.masker_validation import (
23 check_compatibility_mask_and_images,
24)
25from nilearn._utils.param_validation import check_params
26from nilearn.image import index_img, mean_img
27from nilearn.maskers.base_masker import _BaseSurfaceMasker
28from nilearn.surface.surface import (
29 SurfaceImage,
30 at_least_2d,
31 check_surf_img,
32 get_data,
33)
34from nilearn.surface.utils import check_polymesh_equal
@fill_doc
class SurfaceMapsMasker(_BaseSurfaceMasker):
    """Extract data from a SurfaceImage, using maps of potentially overlapping
    brain regions.

    .. versionadded:: 0.11.1

    Parameters
    ----------
    maps_img : :obj:`~nilearn.surface.SurfaceImage`
        Set of maps that define the regions. A representative time course \
        per map is extracted using least square regression. The data for \
        each hemisphere is of shape (n_vertices_per_hemisphere, n_regions).

    mask_img : :obj:`~nilearn.surface.SurfaceImage`, optional, default=None
        Mask to apply to regions before extracting signals. Defines the \
        overall area of the brain to consider. The data for each \
        hemisphere is of shape (n_vertices_per_hemisphere,).

    allow_overlap : :obj:`bool`, default=True
        If False, an error is raised if the maps overlap (i.e. at least two
        maps have a non-zero value for the same vertex).

    %(smoothing_fwhm)s
        This parameter is not implemented yet.

    %(standardize_maskers)s

    %(standardize_confounds)s

    %(detrend)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image \
        with :func:`nilearn.image.high_variance_confounds` and default \
        parameters and regressed out.

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(memory)s

    %(memory_level1)s

    %(verbose0)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="inferno"
        Only relevant for the report figures.

    %(clean_args)s

    Attributes
    ----------
    maps_img_ : :obj:`~nilearn.surface.SurfaceImage`
        The same as the input `maps_img`, kept solely for consistency
        across maskers.

    mask_img_ : A 1D binary :obj:`~nilearn.surface.SurfaceImage` or None.
        The mask of the data.
        If no ``mask_img`` was passed at masker construction,
        then ``mask_img_`` is ``None``, otherwise
        it is the resulting binarized version of ``mask_img``
        where each vertex is ``True`` if all values across samples
        (for example across timepoints) are finite values different from 0.

    n_elements_ : :obj:`int`
        The number of regions in the maps image.

    clean_args_ : :obj:`dict`
        Extra keyword arguments forwarded to :func:`nilearn.signal.clean`
        during transform: ``clean_args``, or an empty dict if it was None.

    See Also
    --------
    nilearn.maskers.SurfaceMasker
    nilearn.maskers.SurfaceLabelsMasker

    """

    def __init__(
        self,
        maps_img=None,
        mask_img=None,
        allow_overlap=True,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        detrend=False,
        high_variance_confounds=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        memory=None,
        memory_level=1,
        verbose=0,
        reports=True,
        cmap=DEFAULT_SEQUENTIAL_CMAP,
        clean_args=None,
    ):
        # scikit-learn convention: only store constructor arguments here;
        # all validation happens in fit().
        self.maps_img = maps_img
        self.mask_img = mask_img
        # NOTE(review): allow_overlap is stored but never read in this
        # class; confirm whether the base class enforces it.
        self.allow_overlap = allow_overlap
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose
        self.reports = reports
        self.cmap = cmap
        self.clean_args = clean_args

    @fill_doc
    @rename_parameters(
        replacement_params={"img": "imgs"}, end_version="0.13.2"
    )
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or None, \
               default=None
            Optional image(s). When given, they are validated and passed
            on to the mask loading step.

        %(y_dummy)s

        Returns
        -------
        SurfaceMapsMasker object

        Raises
        ------
        ValueError
            If no ``maps_img`` was given at construction.
        """
        del y
        check_params(self.__dict__)
        if imgs is not None:
            self._check_imgs(imgs)

        if self.maps_img is None:
            raise ValueError(
                "Please provide a maps_img during initialization. "
                "For example, masker = SurfaceMapsMasker(maps_img=maps_img)"
            )

        if imgs is not None:
            check_surf_img(imgs)

        logger.log(
            msg=f"loading regions from {self.maps_img.__repr__()}",
            verbose=self.verbose,
        )
        # maps_img data must be 2D: one column per region.
        self.maps_img.data._check_ndims(2, "maps_img")
        self.maps_img_ = self.maps_img

        self.n_elements_ = self.maps_img.shape[1]

        self.mask_img_ = self._load_mask(imgs)
        if self.mask_img_ is not None:
            check_polymesh_equal(self.maps_img.mesh, self.mask_img_.mesh)

        # Must be set before the reports early-return below:
        # transform_single_imgs reads clean_args_ unconditionally.
        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        self._shelving = False

        # initialize reporting content and data
        if not self.reports:
            self._reporting_data = None
            return self

        # content to inject in the HTML template
        self._report_content = {
            "description": (
                "This report shows the input surface image "
                "(if provided via img) overlaid with the regions provided "
                "via maps_img."
            ),
            "n_vertices": {},
            "number_of_regions": self.n_elements_,
            "summary": {},
            "warning_message": None,
        }

        for part in self.maps_img.data.parts:
            self._report_content["n_vertices"][part] = (
                self.maps_img.mesh.parts[part].n_vertices
            )

        self._reporting_data = {
            "maps_img": self.maps_img_,
            "mask": self.mask_img_,
            "images": None,  # we will update image in transform
        }

        return self

    def __sklearn_is_fitted__(self):
        """Return True once fit() has set ``n_elements_``."""
        return hasattr(self, "n_elements_")

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from surface object.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or \
              iterable of :obj:`~nilearn.surface.SurfaceImage`
            Images to process.
            Mesh and data for both hemispheres/parts.

        %(confounds)s

        %(sample_mask)s

        Returns
        -------
        %(signals_transform_surface)s
        """
        check_is_fitted(self)

        check_compatibility_mask_and_images(self.maps_img, imgs)

        check_polymesh_equal(self.maps_img.mesh, imgs.mesh)

        imgs = at_least_2d(imgs)

        # Stack all parts vertically: rows are vertices, columns samples.
        img_data = np.concatenate(
            list(imgs.data.parts.values()), axis=0
        ).astype(np.float32)

        # concatenated hemispheres/parts data of the maps, rows = vertices
        maps_data = get_data(self.maps_img)

        # Restrict both the maps and the image to the in-mask vertices.
        if self.mask_img_ is not None:
            in_mask = get_data(self.mask_img_).flatten()
            maps_data = maps_data[in_mask, :]
            img_data = img_data[in_mask, :]

        # Solve maps_data @ X = img_data in the least-squares sense;
        # X.T holds one column per region and one row per sample.
        region_signals = cache(
            linalg.lstsq,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(maps_data, img_data)[0].T

        parameters = get_params(
            self.__class__,
            self,
        )
        parameters["clean_args"] = self.clean_args_

        # signal cleaning here
        region_signals = cache(
            signal.clean,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(
            region_signals,
            detrend=parameters["detrend"],
            standardize=parameters["standardize"],
            standardize_confounds=parameters["standardize_confounds"],
            t_r=parameters["t_r"],
            low_pass=parameters["low_pass"],
            high_pass=parameters["high_pass"],
            confounds=confounds,
            sample_mask=sample_mask,
            **parameters["clean_args"],
        )

        return region_signals

    @fill_doc
    def inverse_transform(self, region_signals):
        """Compute :term:`vertex` signals from region signals.

        Parameters
        ----------
        %(region_signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_surface)s

        Raises
        ------
        ValueError
            If the number of columns of ``region_signals`` does not match
            the number of regions in ``maps_img``.
        """
        check_is_fitted(self)

        # remember whether the caller passed a single sample (1D input)
        return_1D = region_signals.ndim < 2

        # presumably promotes 1D input to 2D (n_samples, n_regions)
        region_signals = self._check_array(region_signals)

        # concatenated hemispheres/parts data from maps_img, rows = vertices
        maps_data = get_data(self.maps_img)
        # Use the fitted, binarized mask (as transform does) so that it can
        # be used as a boolean index.
        mask_data = (
            get_data(self.mask_img_) if self.mask_img_ is not None else None
        )
        if region_signals.shape[1] != self.n_elements_:
            raise ValueError(
                f"Expected {self.n_elements_} regions, "
                f"but got {region_signals.shape[1]}."
            )

        logger.log("computing image from signals", verbose=self.verbose)
        # project region signals back to vertices
        if mask_data is not None:
            in_mask = mask_data.flatten()
            # vertices outside the mask get a signal of 0,
            # shape (n_timepoints, n_vertices)
            vertex_signals = np.zeros(
                (region_signals.shape[0], self.maps_img.mesh.n_vertices)
            )
            # (n_timepoints, n_regions) @ (n_regions, n_in_mask_vertices)
            vertex_signals[:, in_mask] = np.dot(
                region_signals, maps_data[in_mask, :].T
            )
        else:
            vertex_signals = np.dot(region_signals, maps_data.T)

        # SurfaceImage expects data of shape (n_vertices, n_timepoints)
        vertex_signals = vertex_signals.T

        # Split the rows back into parts in the same order in which the
        # parts were concatenated, instead of assuming "left"/"right".
        parts_data = {}
        start = 0
        for part_name, part_values in self.maps_img.data.parts.items():
            stop = start + part_values.shape[0]
            parts_data[part_name] = vertex_signals[start:stop, :]
            start = stop

        imgs = SurfaceImage(mesh=self.maps_img.mesh, data=parts_data)

        if return_1D:
            for k, v in imgs.data.parts.items():
                imgs.data.parts[k] = v.squeeze()

        return imgs

    def generate_report(self, displayed_maps=10, engine="matplotlib"):
        """Generate an HTML report for the current ``SurfaceMapsMasker``
        object.

        .. note::
            This functionality requires to have ``Matplotlib`` installed.

        Parameters
        ----------
        displayed_maps : :obj:`int`, or :obj:`list`, \
                         or :class:`~numpy.ndarray`, or "all", default=10
            Indicates which maps will be displayed in the HTML report.

            - If "all": All maps will be displayed in the report.

            .. code-block:: python

                masker.generate_report("all")

            .. warning::
                If there are too many maps, this might be time and
                memory consuming, and will result in very heavy
                reports.

            - If a :obj:`list` or :class:`~numpy.ndarray`: This indicates
              the indices of the maps to be displayed in the report. For
              example, the following code will generate a report with maps
              6, 3, and 12, displayed in this specific order:

            .. code-block:: python

                masker.generate_report([6, 3, 12])

            - If an :obj:`int`: This will only display the first n maps,
              n being the value of the parameter. By default, the report
              will only contain the first 10 maps. Example to display the
              first 16 maps:

            .. code-block:: python

                masker.generate_report(16)

        engine : :obj:`str`, default="matplotlib"
            The plotting engine to use for the report. Can be either
            "matplotlib" or "plotly". If "matplotlib" is selected, the report
            will be static. If "plotly" is selected, the report
            will be interactive. If the selected engine is not installed, the
            report will use the available plotting engine. If none of the
            engines are installed, no report will be generated.

        Returns
        -------
        report : `nilearn.reporting.html_report.HTMLReport`
            HTML report for the masker.

        Raises
        ------
        ValueError
            If ``engine`` is neither "matplotlib" nor "plotly".
        TypeError
            If ``displayed_maps`` is not "all", an int, or a list/array
            of integers.
        """
        # need to have matplotlib installed to generate reports no matter what
        # engine is selected
        from nilearn.reporting.html_report import generate_report

        if not is_matplotlib_installed():
            return generate_report(self)

        if engine not in ["plotly", "matplotlib"]:
            raise ValueError(
                "Parameter ``engine`` should be either 'matplotlib' or "
                "'plotly'."
            )

        # switch to matplotlib if plotly is selected but not installed
        if engine == "plotly" and not is_plotly_installed():
            engine = "matplotlib"
            warnings.warn(
                "Plotly is not installed. "
                "Switching to matplotlib for report generation.",
                stacklevel=find_stack_level(),
            )
        if hasattr(self, "_report_content"):
            self._report_content["engine"] = engine

        incorrect_type = not isinstance(
            displayed_maps, (list, np.ndarray, int, str)
        )
        incorrect_string = (
            isinstance(displayed_maps, str) and displayed_maps != "all"
        )
        # Accept any integer dtype (e.g. int32 arrays), not only the
        # platform-default int.
        not_integer = not isinstance(
            displayed_maps, str
        ) and not np.issubdtype(np.array(displayed_maps).dtype, np.integer)
        if incorrect_type or incorrect_string or not_integer:
            raise TypeError(
                "Parameter ``displayed_maps`` of "
                "``generate_report()`` should be either 'all' or "
                "an int, or a list/array of ints. You provided a "
                f"{type(displayed_maps)}"
            )

        self.displayed_maps = displayed_maps

        return generate_report(self)

    def _reporting(self):
        """Load displays needed for report.

        Returns
        -------
        displays : list
            A list of all displays to be rendered: iframes (plotly) or
            base64-encoded PNGs (matplotlib), one per displayed map.

        Raises
        ------
        ValueError
            If ``displayed_maps`` contains an index outside the range of
            available maps.
        """
        import matplotlib.pyplot as plt

        from nilearn.reporting.utils import figure_to_png_base64

        # Handle the edge case where this function is
        # called with a masker having report capabilities disabled
        if self._reporting_data is None:
            return [None]

        maps_img = self._reporting_data["maps_img"]

        img = self._reporting_data["images"]
        if img:
            img = mean_img(img)

        n_maps = self.maps_img_.shape[1]
        maps_to_be_displayed = range(n_maps)
        if isinstance(self.displayed_maps, int):
            if n_maps < self.displayed_maps:
                msg = (
                    "`generate_report()` received "
                    f"{self.displayed_maps} maps to be displayed. "
                    f"But masker only has {n_maps} maps. "
                    f"Setting number of displayed maps to {n_maps}."
                )
                warnings.warn(
                    category=UserWarning,
                    message=msg,
                    stacklevel=find_stack_level(),
                )
                self.displayed_maps = n_maps
            maps_to_be_displayed = range(self.displayed_maps)

        elif isinstance(self.displayed_maps, (list, np.ndarray)):
            # Indices are 0-based, so n_maps itself is already out of range.
            if max(self.displayed_maps) >= n_maps:
                raise ValueError(
                    "Report cannot display the following maps "
                    f"{self.displayed_maps} because "
                    f"masker only has {n_maps} maps."
                )
            maps_to_be_displayed = self.displayed_maps

        self._report_content["number_of_maps"] = n_maps
        self._report_content["displayed_maps"] = list(maps_to_be_displayed)
        embeded_images = []

        if img is None:
            msg = (
                "SurfaceMapsMasker has not been transformed (via transform() "
                "method) on any image yet. Plotting only maps for reporting."
            )
            warnings.warn(msg, stacklevel=find_stack_level())

        for roi_index in maps_to_be_displayed:
            roi = index_img(maps_img, roi_index)
            fig = self._create_figure_for_report(roi=roi, bg_img=img)
            if self._report_content["engine"] == "plotly":
                embeded_images.append(fig)
            elif self._report_content["engine"] == "matplotlib":
                embeded_images.append(figure_to_png_base64(fig))
                # Close only the report figure, never an unrelated
                # "current" figure belonging to the user.
                plt.close(fig)

        return embeded_images

    def _create_figure_for_report(self, roi, bg_img):
        """Create a figure of maps image, one region at a time.

        If transform() was applied to an image, this image is used as
        background on which the maps are plotted.

        Returns an iframe string when the report engine is "plotly", and a
        matplotlib figure (2 views x 2 hemispheres) when it is "matplotlib".
        """
        import matplotlib.pyplot as plt

        from nilearn.plotting import plot_surf, view_surf

        # very low threshold to only make 0 values transparent
        threshold = 1e-6
        if self._report_content["engine"] == "plotly":
            # Drop the trailing axis of each part; np.squeeze(axis=-1)
            # requires that axis to have length 1.
            for part in roi.data.parts:
                roi.data.parts[part] = np.squeeze(
                    roi.data.parts[part], axis=-1
                )
            fig = view_surf(
                surf_map=roi,
                bg_map=bg_img,
                bg_on_data=True,
                threshold=threshold,
                hemi="both",
                cmap=self.cmap,
            ).get_iframe(width=500)
        elif self._report_content["engine"] == "matplotlib":
            # TODO: possibly allow to generate a report with other views
            views = ["lateral", "medial"]
            hemispheres = ["left", "right"]
            fig, axes = plt.subplots(
                len(views),
                len(hemispheres),
                subplot_kw={"projection": "3d"},
                figsize=(20, 20),
                **constrained_layout_kwargs(),
            )
            axes = np.atleast_2d(axes)
            for ax_row, view in zip(axes, views):
                for ax, hemi in zip(ax_row, hemispheres):
                    plot_surf(
                        surf_map=roi,
                        bg_map=bg_img,
                        hemi=hemi,
                        view=view,
                        figure=fig,
                        axes=ax,
                        cmap=self.cmap,
                        colorbar=False,
                        threshold=threshold,
                        bg_on_data=True,
                    )
        return fig