Coverage for nilearn/maskers/nifti_maps_masker.py: 12%
210 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Transformer for computing ROI signals."""
3import warnings
4from copy import deepcopy
6import numpy as np
7from sklearn.utils.estimator_checks import check_is_fitted
9from nilearn._utils import repr_niimgs
10from nilearn._utils.class_inspect import get_params
11from nilearn._utils.docs import fill_doc
12from nilearn._utils.helpers import is_matplotlib_installed
13from nilearn._utils.logger import find_stack_level, log
14from nilearn._utils.niimg_conversions import check_niimg, check_same_fov
15from nilearn._utils.param_validation import check_params
16from nilearn.image import clean_img, get_data, index_img, resample_img
17from nilearn.maskers._utils import compute_middle_image
18from nilearn.maskers.base_masker import BaseMasker, filter_and_extract
19from nilearn.masking import load_mask_img
22class _ExtractionFunctor:
23 func_name = "nifti_maps_masker_extractor"
25 def __init__(self, maps_img_, mask_img_, keep_masked_maps):
26 self.maps_img_ = maps_img_
27 self.mask_img_ = mask_img_
28 self.keep_masked_maps = keep_masked_maps
30 def __call__(self, imgs):
31 from ..regions import signal_extraction
33 return signal_extraction.img_to_signals_maps(
34 imgs,
35 self.maps_img_,
36 mask_img=self.mask_img_,
37 keep_masked_maps=self.keep_masked_maps,
38 )
@fill_doc
class NiftiMapsMasker(BaseMasker):
    """Class for extracting data from Niimg-like objects \
    using maps of potentially overlapping brain regions.

    NiftiMapsMasker is useful when data from overlapping volumes should be
    extracted (contrarily to :class:`nilearn.maskers.NiftiLabelsMasker`).

    Use case:
    summarize brain signals from large-scale networks
    obtained by prior PCA or :term:`ICA`.

    .. note::
        Inf or NaN present in the given input images are automatically
        put to zero rather than considered as missing data.

    For more details on the definitions of maps in Nilearn,
    see the :ref:`region` section.

    Parameters
    ----------
    maps_img : 4D niimg-like object or None, default=None
        See :ref:`extracting_data`.
        Set of continuous maps. One representative time course per map is
        extracted using least square regression.

    mask_img : 3D niimg-like object, optional
        See :ref:`extracting_data`.
        Mask to apply to regions before extracting signals.

    allow_overlap : :obj:`bool`, default=True
        If False, an error is raised if the maps overlap (i.e. at least two
        maps have a non-zero value for the same voxel).

    %(smoothing_fwhm)s

    %(standardize_maskers)s

    %(standardize_confounds)s

    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.

    %(detrend)s

    %(low_pass)s

    %(high_pass)s

    %(t_r)s

    %(dtype)s

    resampling_target : {"data", "mask", "maps", None}, default="data"
        Defines which image gives the final shape/size. For example, if
        `resampling_target` is "mask" then maps_img and images provided to
        fit() are resampled to the shape and affine of mask_img. "None" means
        no resampling: if shapes and affines do not match, a ValueError is
        raised.

    %(memory)s

    %(memory_level)s

    %(verbose0)s

    %(keep_masked_maps)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(cmap)s
        default="CMRmap_r"
        Only relevant for the report figures.

    %(clean_args)s
        .. versionadded:: 0.11.2dev

    %(masker_kwargs)s

    Attributes
    ----------
    maps_img_ : :obj:`nibabel.nifti1.Nifti1Image`
        The maps image used for signal extraction, after cleaning of
        non-finite values and any resampling performed at fit time.

    %(nifti_mask_img_)s

    n_elements_ : :obj:`int`
        The number of maps.
        This is equivalent to the number of volumes in the maps image.

        .. versionadded:: 0.9.2

    Notes
    -----
    If resampling_target is set to "maps", every 3D image processed by
    transform() will be resampled to the shape of maps_img. It may lead to a
    very large memory consumption if the voxel number in maps_img is large.

    See Also
    --------
    nilearn.maskers.NiftiMasker
    nilearn.maskers.NiftiLabelsMasker

    """

    # memory and memory_level are used by CacheMixin.

    def __init__(
        self,
        maps_img=None,
        mask_img=None,
        allow_overlap=True,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        high_variance_confounds=False,
        detrend=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        dtype=None,
        resampling_target="data",
        keep_masked_maps=True,
        memory=None,
        memory_level=0,
        verbose=0,
        reports=True,
        cmap="CMRmap_r",
        clean_args=None,
        **kwargs,  # TODO remove when bumping to nilearn >0.13
    ):
        self.maps_img = maps_img
        self.mask_img = mask_img

        # Maps Masker parameter
        self.allow_overlap = allow_overlap

        # Parameters for image.smooth
        self.smoothing_fwhm = smoothing_fwhm

        # Parameters for clean()
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.dtype = dtype
        self.clean_args = clean_args

        # TODO remove when bumping to nilearn >0.13
        self.clean_kwargs = kwargs

        # Parameters for resampling
        self.resampling_target = resampling_target

        # Parameters for joblib
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose

        self.reports = reports
        self.cmap = cmap

        self.keep_masked_maps = keep_masked_maps

    def generate_report(self, displayed_maps=10):
        """Generate an HTML report for the current ``NiftiMapsMasker`` object.

        .. note::
            This functionality requires to have ``Matplotlib`` installed.

        Parameters
        ----------
        displayed_maps : :obj:`int`, or :obj:`list`, \
                         or :class:`~numpy.ndarray`, or "all", default=10
            Indicates which maps will be displayed in the HTML report.

                - If "all": All maps will be displayed in the report.

                .. code-block:: python

                    masker.generate_report("all")

                .. warning::
                    If there are too many maps, this might be time and
                    memory consuming, and will result in very heavy
                    reports.

                - If a :obj:`list` or :class:`~numpy.ndarray`: This indicates
                  the indices of the maps to be displayed in the report. For
                  example, the following code will generate a report with maps
                  6, 3, and 12, displayed in this specific order:

                .. code-block:: python

                    masker.generate_report([6, 3, 12])

                - If an :obj:`int`: This will only display the first n maps,
                  n being the value of the parameter. By default, the report
                  will only contain the first 10 maps. Example to display the
                  first 16 maps:

                .. code-block:: python

                    masker.generate_report(16)

        Returns
        -------
        report : `nilearn.reporting.html_report.HTMLReport`
            HTML report for the masker.

        Raises
        ------
        TypeError
            If ``displayed_maps`` is neither "all", an int, nor a list or
            array of integers.
        """
        from nilearn.reporting.html_report import generate_report

        if not is_matplotlib_installed():
            return generate_report(self)

        incorrect_type = not isinstance(
            displayed_maps, (list, np.ndarray, int, str)
        )
        incorrect_string = (
            isinstance(displayed_maps, str) and displayed_maps != "all"
        )
        # Accept any integer dtype (e.g. int32 index arrays), not only the
        # platform default integer; bool, float and object dtypes still fail.
        not_integer = not isinstance(displayed_maps, str) and not (
            np.issubdtype(np.array(displayed_maps).dtype, np.integer)
        )
        if incorrect_type or incorrect_string or not_integer:
            raise TypeError(
                "Parameter ``displayed_maps`` of "
                "``generate_report()`` should be either 'all' or "
                "an int, or a list/array of ints. You provided a "
                f"{type(displayed_maps)}"
            )
        self.displayed_maps = displayed_maps

        return generate_report(self)

    def _reporting(self):
        """Return a list of all displays to be rendered.

        Returns
        -------
        displays : list
            A list of all displays to be rendered.

        Raises
        ------
        ValueError
            If ``displayed_maps`` lists a map index that does not exist.
        """
        from nilearn import plotting
        from nilearn.reporting.html_report import embed_img

        if self._reporting_data is not None:
            maps_image = self._reporting_data["maps_image"]
        else:
            maps_image = None

        if maps_image is None:
            return [None]

        n_maps = get_data(maps_image).shape[-1]

        maps_to_be_displayed = range(n_maps)
        if isinstance(self.displayed_maps, int):
            if n_maps < self.displayed_maps:
                msg = (
                    "`generate_report()` received "
                    f"{self.displayed_maps} to be displayed. "
                    f"But masker only has {n_maps} maps. "
                    f"Setting number of displayed maps to {n_maps}."
                )
                warnings.warn(
                    category=UserWarning,
                    message=msg,
                    stacklevel=find_stack_level(),
                )
                self.displayed_maps = n_maps
            maps_to_be_displayed = range(self.displayed_maps)

        elif isinstance(self.displayed_maps, (list, np.ndarray)):
            # Map indices are 0-based: valid indices go up to n_maps - 1.
            # An empty selection has no index to check.
            requested = np.asarray(self.displayed_maps)
            if requested.size and requested.max() >= n_maps:
                raise ValueError(
                    "Report cannot display the following maps "
                    f"{self.displayed_maps} because "
                    f"masker only has {n_maps} maps."
                )
            maps_to_be_displayed = self.displayed_maps

        self._report_content["number_of_maps"] = n_maps
        self._report_content["displayed_maps"] = list(maps_to_be_displayed)

        img = self._reporting_data["img"]
        embedded_images = []

        if img is None:
            msg = (
                "No image provided to fit in NiftiMapsMasker. "
                "Plotting only spatial maps for reporting."
            )
            warnings.warn(msg, stacklevel=find_stack_level())
            self._report_content["warning_message"] = msg
            for component in maps_to_be_displayed:
                display = plotting.plot_stat_map(
                    index_img(maps_image, component)
                )
                embedded_images.append(embed_img(display))
                display.close()
            return embedded_images

        if self._reporting_data["dim"] == 5:
            msg = (
                "A list of 4D subject images were provided to fit. "
                "Only first subject is shown in the report."
            )
            warnings.warn(msg, stacklevel=find_stack_level())
            self._report_content["warning_message"] = msg

        for component in maps_to_be_displayed:
            # Find the cut coordinates
            cut_coords = plotting.find_xyz_cut_coords(
                index_img(maps_image, component)
            )
            display = plotting.plot_img(
                img,
                cut_coords=cut_coords,
                black_bg=False,
                cmap=self.cmap,
            )
            display.add_overlay(
                index_img(maps_image, component),
                cmap=plotting.cm.black_blue,
            )
            embedded_images.append(embed_img(display))
            display.close()
        return embedded_images

    @fill_doc
    def fit(self, imgs=None, y=None):
        """Prepare signal extraction from regions.

        Parameters
        ----------
        imgs : :obj:`list` of Niimg-like objects or None, default=None
            See :ref:`extracting_data`.
            Image data passed to the reporter.

        %(y_dummy)s

        Returns
        -------
        self : NiftiMapsMasker
            The fitted masker.

        Raises
        ------
        ValueError
            If ``resampling_target`` is invalid, if it is "mask" while no
            mask is given, or (with ``resampling_target=None``) if the
            images do not share the same field of view.
        """
        del y
        check_params(self.__dict__)
        if self.resampling_target not in ("mask", "maps", "data", None):
            raise ValueError(
                "invalid value for 'resampling_target' "
                f"parameter: {self.resampling_target}"
            )

        if self.mask_img is None and self.resampling_target == "mask":
            raise ValueError(
                "resampling_target has been set to 'mask' but no mask "
                "has been provided.\n"
                "Set resampling_target to something else or provide a mask."
            )

        self._sanitize_cleaning_parameters()
        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        self._report_content = {
            "description": (
                "This reports shows the spatial maps provided to the mask."
            ),
            "warning_message": None,
        }

        # Load images
        maps_img = self.maps_img
        if hasattr(self, "_maps_img"):
            # This is for RegionExtractor that first modifies
            # maps_img before passing to its parent fit method.
            maps_img = self._maps_img
        # Named to avoid shadowing the builtin ``repr``.
        maps_repr = repr_niimgs(maps_img, shorten=(not self.verbose))
        msg = f"loading regions from {maps_repr}"
        log(msg=msg, verbose=self.verbose)
        self.maps_img_ = deepcopy(maps_img)
        self.maps_img_ = check_niimg(
            self.maps_img_, dtype=self.dtype, atleast_4d=True
        )
        self.maps_img_ = clean_img(
            self.maps_img_,
            detrend=False,
            standardize=False,
            ensure_finite=True,
        )

        if imgs is not None:
            imgs_ = check_niimg(imgs)

        self.mask_img_ = self._load_mask(imgs)

        # Check shapes and affines for resample.
        if self.resampling_target is None:
            images = {"maps": self.maps_img_}
            if self.mask_img_ is not None:
                images["mask"] = self.mask_img_
            if imgs is not None:
                images["data"] = imgs_
            check_same_fov(raise_error=True, **images)

        ref_img = None
        if self.resampling_target == "data" and imgs is not None:
            ref_img = imgs_
        elif self.resampling_target == "mask":
            ref_img = self.mask_img_
        elif self.resampling_target == "maps":
            ref_img = self.maps_img_

        if ref_img is not None:
            if self.resampling_target != "maps" and not check_same_fov(
                ref_img, self.maps_img_
            ):
                log("Resampling maps...", self.verbose)
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                self.maps_img_ = self._cache(resample_img)(
                    self.maps_img_,
                    interpolation="continuous",
                    target_shape=ref_img.shape[:3],
                    target_affine=ref_img.affine,
                    copy_header=True,
                    force_resample=False,
                )
            if self.mask_img_ is not None and not check_same_fov(
                ref_img, self.mask_img_
            ):
                log("Resampling mask...", self.verbose)
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                self.mask_img_ = resample_img(
                    self.mask_img_,
                    target_affine=ref_img.affine,
                    target_shape=ref_img.shape[:3],
                    interpolation="nearest",
                    copy=True,
                    copy_header=True,
                    force_resample=False,
                )

                # Just check that the mask is valid
                load_mask_img(self.mask_img_)

        if self.reports:
            self._reporting_data = {
                "maps_image": self.maps_img_,
                "mask": self.mask_img_,
                "dim": None,
                "img": imgs,
            }
            if imgs is not None:
                imgs, dims = compute_middle_image(imgs)
                self._reporting_data["img"] = imgs
                self._reporting_data["dim"] = dims
        else:
            self._reporting_data = None

        # The number of elements is equal to the number of volumes
        self.n_elements_ = self.maps_img_.shape[3]

        return self

    def __sklearn_is_fitted__(self):
        """Return True once ``fit`` has set ``maps_img_`` and ``n_elements_``."""
        return hasattr(self, "maps_img_") and hasattr(self, "n_elements_")

    @fill_doc
    def fit_transform(self, imgs, y=None, confounds=None, sample_mask=None):
        """Prepare and perform signal extraction.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.
            If a 3D niimg is provided, a 1D array is returned.

        %(y_dummy)s

        %(confounds)s

        %(sample_mask)s

                .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s
        """
        del y
        return self.fit(imgs).transform(
            imgs, confounds=confounds, sample_mask=sample_mask
        )

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from a single 4D niimg.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.

        confounds : CSV file or array-like, default=None
            This parameter is passed to :func:`nilearn.signal.clean`.
            Please see the related documentation for details.
            shape: (number of scans, number of confounds)

        %(sample_mask)s

                .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s

        Raises
        ------
        ValueError
            If ``allow_overlap`` is False and at least two maps share a
            voxel, or (with ``resampling_target=None``) if the images do not
            share the same field of view.
        """
        check_is_fitted(self)

        # imgs passed at transform time may be different
        # from those passed at fit time.
        # So it may be needed to resample mask and maps,
        # if 'data' is the resampling target.
        # We handle the resampling of maps and mask separately because the
        # affine of the maps and mask images should not impact the extraction
        # of the signal.
        #
        # Any resampling of the mask or maps is not 'kept' after transform,
        # to avoid modifying the masker after fit.
        #
        # If the resampling target is different,
        # then resampling was already done at fit time
        # (e.g resampling of the mask image to the maps image
        # if the target was 'maps'),
        # or resampling of the data will be done at extract time.

        mask_img_ = self.mask_img_
        maps_img_ = self.maps_img_

        imgs_ = check_niimg(imgs, atleast_4d=True)

        if self.resampling_target is None:
            images = {"maps": maps_img_, "data": imgs_}
            if mask_img_ is not None:
                images["mask"] = mask_img_
            check_same_fov(raise_error=True, **images)
        elif self.resampling_target == "data":
            ref_img = imgs_

            if not check_same_fov(ref_img, maps_img_):
                warnings.warn(
                    (
                        "Resampling maps at transform time...\n"
                        "To avoid this warning, make sure to pass the images "
                        "you want to transform to fit() first, "
                        "or directly use fit_transform()."
                    ),
                    stacklevel=find_stack_level(),
                )
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                maps_img_ = self._cache(resample_img)(
                    self.maps_img_,
                    interpolation="continuous",
                    target_shape=ref_img.shape[:3],
                    target_affine=ref_img.affine,
                    copy_header=True,
                    force_resample=False,
                )

            if self.mask_img_ is not None and not check_same_fov(
                ref_img,
                self.mask_img_,
            ):
                warnings.warn(
                    (
                        "Resampling mask at transform time...\n"
                        "To avoid this warning, make sure to pass the images "
                        "you want to transform to fit() first, "
                        "or directly use fit_transform()."
                    ),
                    stacklevel=find_stack_level(),
                )
                # TODO switch to force_resample=True
                # when bumping to version > 0.13
                mask_img_ = self._cache(resample_img)(
                    self.mask_img_,
                    interpolation="nearest",
                    target_shape=ref_img.shape[:3],
                    target_affine=ref_img.affine,
                    copy_header=True,
                    force_resample=False,
                )

        # Remove imgs_ from memory before loading the same image
        # in filter_and_extract.
        del imgs_

        if not self.allow_overlap:
            # Check if there is an overlap between maps.
            # The data is only read here, never written: the array returned
            # by get_data may be the maps image's own data (possibly the
            # fitted self.maps_img_), and those same maps are used for the
            # extraction below.
            data = get_data(maps_img_)
            if data.dtype.kind == "f":
                # For floats, values below machine epsilon count as zero.
                in_map = data >= np.finfo(data.dtype).eps
            else:
                in_map = data > 0

            # Check the overlaps
            if np.any(np.sum(in_map, axis=3) > 1):
                raise ValueError(
                    "Overlap detected in the maps. The overlap may be "
                    "due to the atlas itself or possibly introduced by "
                    "resampling."
                )

        target_shape = None
        target_affine = None
        if self.resampling_target != "data":
            target_shape = maps_img_.shape[:3]
            target_affine = maps_img_.affine

        params = get_params(
            NiftiMapsMasker,
            self,
            ignore=["resampling_target"],
        )
        params["target_shape"] = target_shape
        params["target_affine"] = target_affine
        params["clean_kwargs"] = self.clean_args_
        # TODO remove in 0.13.2
        if self.clean_kwargs:
            params["clean_kwargs"] = self.clean_kwargs_

        region_signals, _ = self._cache(
            filter_and_extract,
            ignore=["verbose", "memory", "memory_level"],
        )(
            # Images
            imgs,
            _ExtractionFunctor(
                maps_img_,
                mask_img_,
                self.keep_masked_maps,
            ),
            # Pre-treatments
            params,
            confounds=confounds,
            sample_mask=sample_mask,
            dtype=self.dtype,
            # Caching
            memory=self.memory,
            memory_level=self.memory_level,
            # kwargs
            verbose=self.verbose,
        )
        return region_signals

    @fill_doc
    def inverse_transform(self, region_signals):
        """Compute :term:`voxel` signals from region signals.

        Any mask given at initialization is taken into account.

        Parameters
        ----------
        %(region_signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_nifti)s

        """
        from ..regions import signal_extraction

        check_is_fitted(self)

        region_signals = self._check_array(region_signals)

        log("computing image from signals", verbose=self.verbose)
        return signal_extraction.signals_to_img_maps(
            region_signals,
            self.maps_img_,
            mask_img=self.mask_img_,
        )