Coverage for nilearn/maskers/surface_labels_masker.py: 18%
198 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Extract data from a SurfaceImage, averaging over atlas regions."""
3import warnings
4from copy import deepcopy
5from pathlib import Path
6from typing import Union
8import numpy as np
9import pandas as pd
10from scipy import ndimage
11from sklearn.utils.estimator_checks import check_is_fitted
13from nilearn import DEFAULT_SEQUENTIAL_CMAP, signal
14from nilearn._utils.bids import (
15 generate_atlas_look_up_table,
16 sanitize_look_up_table,
17)
18from nilearn._utils.cache_mixin import cache
19from nilearn._utils.class_inspect import get_params
20from nilearn._utils.docs import fill_doc
21from nilearn._utils.helpers import (
22 constrained_layout_kwargs,
23 rename_parameters,
24)
25from nilearn._utils.logger import find_stack_level
26from nilearn._utils.masker_validation import (
27 check_compatibility_mask_and_images,
28)
29from nilearn._utils.param_validation import (
30 check_params,
31 check_reduction_strategy,
32)
33from nilearn.image import mean_img
34from nilearn.maskers.base_masker import _BaseSurfaceMasker
35from nilearn.surface.surface import (
36 SurfaceImage,
37 at_least_2d,
38 check_surf_img,
39 get_data,
40)
41from nilearn.surface.utils import check_polymesh_equal
def signals_to_surf_img_labels(
    signals: np.ndarray,
    labels: np.ndarray,
    labels_img: SurfaceImage,
    background_label=0,
) -> SurfaceImage:
    """Transform signals to surface image labels.

    Every vertex of ``labels_img`` whose label matches the i-th
    non-background entry of ``labels`` receives ``signals[:, i]``.
    Vertices whose label is not matched are left at zero.

    Parameters
    ----------
    signals : :obj:`numpy.ndarray`
        2D array indexed as ``signals[sample, region]``.
        Column ``i`` belongs to the i-th non-background label.

    labels : :obj:`numpy.ndarray`
        Label values, possibly including ``background_label``.
        The background entries are dropped before matching columns.

    labels_img : :obj:`~nilearn.surface.SurfaceImage`
        Label image providing the mesh and the per-part vertex labels.

    background_label : :obj:`int` or :obj:`float`, default=0
        Label value to ignore.

    Returns
    -------
    :obj:`~nilearn.surface.SurfaceImage`
        Image on the mesh of ``labels_img``.
        Each part has shape (n_vertices_in_part, n_samples) and the same
        dtype as ``signals``.
    """
    foreground = labels[labels != background_label]
    n_samples = signals.shape[0]

    out_parts = {}
    for part_name, part_labels in labels_img.data.parts.items():
        part_values = np.zeros(
            (part_labels.shape[0], n_samples), dtype=signals.dtype
        )
        for column, label_value in enumerate(foreground):
            # a 1D column of length n_samples broadcasts over every
            # vertex row carrying this label
            part_values[part_labels == label_value] = signals[:, column]
        out_parts[part_name] = part_values

    return SurfaceImage(mesh=labels_img.mesh, data=out_parts)
@fill_doc
class SurfaceLabelsMasker(_BaseSurfaceMasker):
    """Extract data from a SurfaceImage, averaging over atlas regions.

    .. versionadded:: 0.11.0

    Parameters
    ----------
    labels_img : :obj:`~nilearn.surface.SurfaceImage` object
        Region definitions, as one image of labels.
        The data for each hemisphere
        is of shape (n_vertices_per_hemisphere, n_regions).

    labels : :obj:`list` of :obj:`str`, default=None
        Mutually exclusive with ``lut``.
        Labels corresponding to the labels image.
        This is used to improve reporting quality if provided.

        .. warning::
            If the labels are not consistent with the label values
            provided through ``labels_img``,
            excess labels will be dropped,
            and missing labels will be labeled ``'unknown'``.

    %(masker_lut)s

    background_label : :obj:`int` or :obj:`float`, default=0
        Label used in labels_img to represent background.

        .. warning::

            This value must be consistent with label values
            and image provided.

    mask_img : :obj:`~nilearn.surface.SurfaceImage` object, optional
        Mask to apply to labels_img before extracting signals. Defines the \
        overall area of the brain to consider. The data for each \
        hemisphere is of shape (n_vertices_per_hemisphere, n_regions).

    %(smoothing_fwhm)s
        This parameter is not implemented yet.

    %(standardize_maskers)s

    %(standardize_confounds)s

    %(detrend)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(memory)s

    %(memory_level1)s

    %(verbose0)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="inferno"
        Only relevant for the report figures.

    %(clean_args)s

    Attributes
    ----------
    labels_img_ : :obj:`~nilearn.surface.SurfaceImage`
        The labels image after fitting (a deep copy of ``labels_img``).
        If a mask_img was used,
        then masked vertices will have the background value.

    mask_img_ : A 1D binary :obj:`~nilearn.surface.SurfaceImage` or None.
        The mask of the data.
        If no ``mask_img`` was passed at masker construction,
        then ``mask_img_`` is ``None``, otherwise
        is the resulting binarized version of ``mask_img``
        where each vertex is ``True`` if all values across samples
        (for example across timepoints) is a finite value different from 0.

    lut_ : :obj:`pandas.DataFrame`
        Look-up table derived from the ``labels`` or ``lut``
        or from the values of the label image.
    """

    def __init__(
        self,
        labels_img=None,
        labels=None,
        lut=None,
        background_label=0,
        mask_img=None,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        detrend=False,
        high_variance_confounds=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        memory=None,
        memory_level=1,
        verbose=0,
        strategy="mean",
        reports=True,
        cmap=DEFAULT_SEQUENTIAL_CMAP,
        clean_args=None,
    ):
        self.labels_img = labels_img
        self.labels = labels
        self.lut = lut
        self.background_label = background_label
        self.mask_img = mask_img
        self.smoothing_fwhm = smoothing_fwhm
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose
        self.reports = reports
        self.strategy = strategy
        self.cmap = cmap
        self.clean_args = clean_args

    @property
    def n_elements_(self) -> int:
        """Return number of regions.

        This is equal to the number of unique values
        in the fitted label image,
        minus the background value.
        """
        check_is_fitted(self)
        lut = self.lut_
        return len(lut[lut["index"] != self.background_label])

    @property
    def labels_(self) -> list[Union[int, float]]:
        """Return list of labels of the regions.

        The list follows the row order of ``lut_``
        and includes the background label if it is present there.
        """
        check_is_fitted(self)
        lut = self.lut_
        return lut["index"].to_list()

    def _foreground_labels(self) -> list[Union[int, float]]:
        """Return the label values without the background label.

        The order is the row order of ``lut_``.
        The i-th entry is the label whose signal ends up in column i
        of the array returned by ``transform``.
        """
        return [
            label for label in self.labels_ if label != self.background_label
        ]

    @property
    def region_names_(self) -> dict[int, str]:
        """Return a dictionary containing the region names corresponding \n
        to each column in the array returned by `transform`.

        The region names correspond to the labels provided
        in labels in input.
        The region name corresponding to ``region_signal[:,i]``
        is ``region_names_[i]``.

        .. versionadded:: 0.11.2dev
        """
        check_is_fitted(self)
        lut = self.lut_
        names = lut.loc[lut["index"] != self.background_label, "name"]
        # Re-key by position so that key i matches column i of the
        # transformed signals, whatever the DataFrame index of lut_ is.
        return names.reset_index(drop=True).to_dict()

    @property
    def region_ids_(self) -> dict[Union[str, int], int]:
        """Return dictionary containing the region ids corresponding \n
        to each column in the array \n
        returned by `transform`.

        The region id corresponding to ``region_signal[:,i]``
        is ``region_ids_[i]``.
        ``region_ids_['background']`` is the background label.

        .. versionadded:: 0.11.2dev
        """
        check_is_fitted(self)
        region_ids: dict[Union[str, int], int] = {}
        if self.background_label in self.labels_:
            region_ids["background"] = self.background_label
        for column, label in enumerate(self._foreground_labels()):
            region_ids[column] = label
        return region_ids

    @fill_doc
    @rename_parameters(
        replacement_params={"img": "imgs"}, end_version="0.13.2"
    )
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or None, \
        default=None

        %(y_dummy)s

        Returns
        -------
        SurfaceLabelsMasker object
        """
        del y
        check_params(self.__dict__)
        if imgs is not None:
            self._check_imgs(imgs)
            check_surf_img(imgs)

        check_reduction_strategy(self.strategy)

        if self.labels_img is None:
            raise ValueError(
                "Please provide a labels_img to the masker. For example, "
                "masker = SurfaceLabelsMasker(labels_img=labels_img)"
            )

        if self.labels and self.lut is not None:
            raise ValueError(
                "Pass either labels or a lookup table (lut) to the masker, "
                "but not both."
            )

        self.labels_img_ = deepcopy(self.labels_img)

        self.mask_img_ = self._load_mask(imgs)
        if self.mask_img_ is not None:
            self._apply_mask_to_labels_img()

        self._shelving = False

        self.lut_ = sanitize_look_up_table(
            self._build_look_up_table(), atlas=self.labels_img_
        )

        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        if not self.reports:
            self._reporting_data = None
            return self

        # content to inject in the HTML template
        self._report_content = {
            "description": (
                "This report shows the input surface image overlaid "
                "with the outlines of the mask. "
                "We recommend to inspect the report for the overlap "
                "between the mask and its input image. "
            ),
            "n_vertices": {
                part: self.labels_img_.mesh.parts[part].n_vertices
                for part in self.labels_img_.data.parts
            },
            "number_of_regions": self.n_elements_,
            "summary": {},
            "warning_message": None,
        }

        self._reporting_data = self._generate_reporting_data()

        return self

    def _apply_mask_to_labels_img(self):
        """Set masked-out vertices of ``labels_img_`` to the background label.

        The labels are modified in place.
        A warning is emitted if masking removed some labels entirely.
        """
        # check against the fitted mask, which is the one applied below
        check_polymesh_equal(self.labels_img_.mesh, self.mask_img_.mesh)

        for part_name in self.labels_img_.data.parts:
            mask = self.mask_img_.data.parts[part_name]
            self.labels_img_.data.parts[part_name][np.logical_not(mask)] = (
                self.background_label
            )

        labels_before_mask = {
            int(x) for x in np.unique(get_data(self.labels_img))
        }
        labels_after_mask = {
            int(x) for x in np.unique(get_data(self.labels_img_))
        }
        labels_diff = labels_before_mask - labels_after_mask
        if labels_diff:
            warnings.warn(
                "After applying mask to the labels image, "
                "the following labels were "
                f"removed: {labels_diff}. "
                f"Out of {len(labels_before_mask)} labels, the "
                "masked labels image only contains "
                f"{len(labels_after_mask)} labels "
                "(including background).",
                stacklevel=find_stack_level(),
            )

    def _build_look_up_table(self):
        """Return the look-up table before sanitization.

        The sources are used in this order of precedence:
        ``lut`` (a DataFrame, or a path to a delimited table file),
        then ``labels``, then the label values of ``labels_img_``.
        """
        if self.lut is not None:
            if isinstance(self.lut, (str, Path)):
                # sep=None sniffs the delimiter, which only the python
                # engine supports; naming it avoids pandas' fallback warning
                return pd.read_table(self.lut, sep=None, engine="python")
            return self.lut
        if self.labels:
            return generate_atlas_look_up_table(
                function=None,
                name=self.labels,
                index=self.labels_img_,
            )
        return generate_atlas_look_up_table(
            function=None, index=self.labels_img_
        )

    def _generate_reporting_data(self):
        """Fill the per-part region summaries and return the reporting data.

        For each mesh part, a copy of ``lut_`` gets a ``"size"`` column
        (number of vertices carrying each label) and a ``"relative size"``
        column (that count as a percentage of the part's vertices, as a
        string with 2 decimals).
        Each table is stored in ``self._report_content["summary"]``.
        """
        for part_name, part_labels in self.labels_img_.data.parts.items():
            n_vertices_total = self.labels_img_.mesh.parts[part_name].n_vertices
            table = self.lut_.copy()

            size = []
            relative_size = []
            for label in table["index"]:
                count = (part_labels == label).sum()
                size.append(count)
                # fixed-point formatting: 45.678 -> "45.68", whereas the
                # general ".2" spec would give "4.6e+01"
                relative_size.append(f"{count / n_vertices_total * 100:.2f}")

            table["size"] = size
            table["relative size"] = relative_size

            self._report_content["summary"][part_name] = table

        return {
            "labels_image": self.labels_img_,
            "images": None,
        }

    def __sklearn_is_fitted__(self):
        return hasattr(self, "lut_") and hasattr(self, "mask_img_")

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from surface object.

        Parameters
        ----------
        imgs : :obj:`~nilearn.surface.SurfaceImage` object or \
        iterable of :obj:`~nilearn.surface.SurfaceImage`
            Images to process.
            Mesh and data for both hemispheres.

        %(confounds)s

        %(sample_mask)s

        Returns
        -------
        %(signals_transform_surface)s
        """
        check_is_fitted(self)

        check_compatibility_mask_and_images(self.labels_img_, imgs)
        check_polymesh_equal(self.labels_img_.mesh, imgs.mesh)

        imgs = at_least_2d(imgs)
        img_data = get_data(imgs)

        # keep float32 inputs in float32, upcast everything else to float64
        target_datatype = (
            np.float32 if img_data.dtype == np.float32 else np.float64
        )
        img_data = img_data.astype(target_datatype)

        n_samples = 1 if img_data.ndim == 1 else img_data.shape[1]

        region_signals = np.empty(
            (n_samples, self.n_elements_), dtype=target_datatype
        )
        # adapted from nilearn.regions.signal_extraction.img_to_signals_labels
        # iterate over time points and apply reduction function over labels.
        labels_data = get_data(self.labels_img_)
        index = self._foreground_labels()
        reduction_function = getattr(ndimage, self.strategy)

        for n, sample in enumerate(np.rollaxis(img_data, -1)):
            region_signals[n] = np.asarray(
                reduction_function(sample, labels=labels_data, index=index)
            )

        parameters = get_params(
            self.__class__,
            self,
            ignore=[
                "mask_img",
            ],
        )
        parameters["clean_args"] = self.clean_args_

        # signal cleaning here
        region_signals = cache(
            signal.clean,
            memory=self.memory,
            func_memory_level=2,
            memory_level=self.memory_level,
            shelve=self._shelving,
        )(
            region_signals,
            detrend=parameters["detrend"],
            standardize=parameters["standardize"],
            standardize_confounds=parameters["standardize_confounds"],
            t_r=parameters["t_r"],
            low_pass=parameters["low_pass"],
            high_pass=parameters["high_pass"],
            confounds=confounds,
            sample_mask=sample_mask,
            **parameters["clean_args"],
        )

        return region_signals

    @fill_doc
    def inverse_transform(self, signals):
        """Transform extracted signal back to surface image.

        Parameters
        ----------
        %(signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_surface)s
        """
        check_is_fitted(self)

        # remember the input rank before _check_array normalizes it
        return_1D = signals.ndim < 2

        signals = self._check_array(signals)

        imgs = signals_to_surf_img_labels(
            signals,
            np.asarray(self.labels_),
            self.labels_img_,
            self.background_label,
        )

        if return_1D:
            for k, v in imgs.data.parts.items():
                imgs.data.parts[k] = v.squeeze()

        return imgs

    def generate_report(self):
        """Generate a report."""
        from nilearn.reporting.html_report import generate_report

        return generate_report(self)

    def _reporting(self):
        """Load displays needed for report.

        Returns
        -------
        displays : list
            A list of all displays to be rendered.
        """
        import matplotlib.pyplot as plt

        from nilearn.reporting.utils import figure_to_png_base64

        # Handle the edge case where this function is
        # called with a masker having report capabilities disabled
        if self._reporting_data is None:
            return [None]

        fig = self._create_figure_for_report()

        plt.close()

        init_display = figure_to_png_base64(fig)

        return [init_display]

    def _create_figure_for_report(self):
        """Create a figure of the contours of label image.

        If transform() was applied to an image,
        this image is used as background
        on which the contours are drawn.
        """
        import matplotlib.pyplot as plt

        from nilearn.plotting import plot_surf, plot_surf_contours

        labels_img = self._reporting_data["labels_image"]

        img = self._reporting_data["images"]
        if img:
            img = mean_img(img)
            vmin, vmax = img.data._get_min_max()

        # TODO: possibly allow to generate a report with other views
        views = ["lateral", "medial"]
        hemispheres = ["left", "right"]

        fig, axes = plt.subplots(
            len(views),
            len(hemispheres),
            subplot_kw={"projection": "3d"},
            figsize=(20, 20),
            **constrained_layout_kwargs(),
        )
        axes = np.atleast_2d(axes)

        for ax_row, view in zip(axes, views):
            for ax, hemi in zip(ax_row, hemispheres):
                if img:
                    plot_surf(
                        surf_map=img,
                        hemi=hemi,
                        view=view,
                        figure=fig,
                        axes=ax,
                        cmap=self.cmap,
                        vmin=vmin,
                        vmax=vmax,
                    )
                plot_surf_contours(
                    roi_map=labels_img,
                    hemi=hemi,
                    view=view,
                    figure=fig,
                    axes=ax,
                )

        return fig