Coverage for nilearn/maskers/nifti_spheres_masker.py: 13%
235 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 10:58 +0200
1"""Transformer for computing seeds signals.
3Mask nifti images by spherical volumes for seed-region analyses
4"""
6import contextlib
7import warnings
9import numpy as np
10from joblib import Memory
11from scipy import sparse
12from sklearn import neighbors
13from sklearn.utils.estimator_checks import check_is_fitted
15from nilearn._utils import logger
16from nilearn._utils.class_inspect import get_params
17from nilearn._utils.docs import fill_doc
18from nilearn._utils.helpers import (
19 is_matplotlib_installed,
20 rename_parameters,
21)
22from nilearn._utils.logger import find_stack_level
23from nilearn._utils.niimg import img_data_dtype
24from nilearn._utils.niimg_conversions import (
25 check_niimg_3d,
26 check_niimg_4d,
27 safe_get_data,
28)
29from nilearn.datasets import load_mni152_template
30from nilearn.image import resample_img
31from nilearn.image.resampling import coord_transform
32from nilearn.maskers._utils import compute_middle_image
33from nilearn.maskers.base_masker import BaseMasker, filter_and_extract
34from nilearn.masking import apply_mask_fmri, load_mask_img, unmask
def apply_mask_and_get_affinity(
    seeds, niimg, radius, allow_overlap, mask_img=None
):
    """Extract voxel signals and the seed-to-voxel sphere membership matrix.

    Columns of the returned matrices index the voxels of ``niimg`` (restricted
    to ``mask_img`` when one is given), or the in-mask voxels of ``mask_img``
    when ``niimg`` is None.

    Parameters
    ----------
    seeds : List of triplets of coordinates in native space
        Seed definitions. List of coordinates of the seeds in the same space
        as the affine of ``niimg`` (or of ``mask_img`` if ``niimg`` is None).

    niimg : 4D Niimg-like object or None
        See :ref:`extracting_data`.
        Images to process. The object must expose ``affine`` and ``shape``;
        when no mask is given ``niimg.shape[3]`` is read, so the image has to
        be 4D in that case.
        If None, no signal is extracted and ``mask_img`` defines the space.

    radius : float
        Indicates, in millimeters, the radius for the sphere around the seed.

    allow_overlap : boolean
        If False, a ValueError is raised if VOIs overlap

    mask_img : Niimg-like object, optional
        Mask to apply to regions before extracting signals. If niimg is None,
        mask_img is used as a reference space in which the spheres' indices
        are placed.

    Returns
    -------
    X : numpy.ndarray or None
        Signal for each brain voxel in the (masked) niimgs.
        shape: (number of scans, number of voxels)
        None when ``niimg`` is None.

    A : scipy.sparse.lil_matrix
        Contains the boolean indices for each sphere.
        shape: (number of seeds, number of voxels)

    Raises
    ------
    ValueError
        If at least one sphere contains no voxel, or if spheres overlap
        while ``allow_overlap`` is False.

    """
    seeds = list(seeds)

    # Collect the (i, j, k) indices of all candidate voxels.
    # They are stored as tuples in every branch: the nearest-voxel lookup
    # below searches for a tuple, and a list never compares equal to a tuple.
    if niimg is None:
        mask, affine = load_mask_img(mask_img)
        # Get coordinate for all voxels inside of mask
        mask_coords = list(zip(*np.nonzero(mask)))
        X = None

    elif mask_img is not None:
        affine = niimg.affine
        mask_img = check_niimg_3d(mask_img)
        # TODO switch to force_resample=True
        # when bumping to version > 0.13
        mask_img = resample_img(
            mask_img,
            target_affine=affine,
            target_shape=niimg.shape[:3],
            interpolation="nearest",
            copy_header=True,
            force_resample=False,
        )
        mask, _ = load_mask_img(mask_img)
        mask_coords = list(zip(*np.where(mask != 0)))

        X = apply_mask_fmri(niimg, mask_img)

    else:
        affine = niimg.affine
        if np.isnan(np.sum(safe_get_data(niimg))):
            warnings.warn(
                "The imgs you have fed into fit_transform() contains NaN "
                "values which will be converted to zeroes.",
                stacklevel=find_stack_level(),
            )
            X = safe_get_data(niimg, True).reshape([-1, niimg.shape[3]]).T
        else:
            X = safe_get_data(niimg).reshape([-1, niimg.shape[3]]).T

        mask_coords = list(np.ndindex(niimg.shape[:3]))

    # For each seed, get coordinates of nearest voxel.
    # The inverse affine does not depend on the seed: compute it only once.
    inv_affine = np.linalg.inv(affine)
    nearests = []
    for sx, sy, sz in seeds:
        nearest = np.round(coord_transform(sx, sy, sz, inv_affine))
        nearest = nearest.astype(int)
        nearest = (nearest[0], nearest[1], nearest[2])
        try:
            nearests.append(mask_coords.index(nearest))
        except ValueError:
            # The nearest voxel is outside of the candidate voxels
            nearests.append(None)

    # Convert voxel indices to world coordinates, one row per voxel
    mask_coords = np.asarray(list(zip(*mask_coords)))
    mask_coords = coord_transform(
        mask_coords[0], mask_coords[1], mask_coords[2], affine
    )
    mask_coords = np.asarray(mask_coords).T

    # Voxels whose world coordinates lie within `radius` of each seed
    clf = neighbors.NearestNeighbors(radius=radius)
    A = clf.fit(mask_coords).radius_neighbors_graph(seeds)
    A = A.tolil()
    for i, nearest in enumerate(nearests):
        if nearest is None:
            continue

        A[i, nearest] = True

    # Include the voxel containing the seed itself if not masked
    mask_coords = mask_coords.astype(int).tolist()
    for i, seed in enumerate(seeds):
        with contextlib.suppress(ValueError):  # if seed is not in the mask
            A[i, mask_coords.index(list(map(int, seed)))] = True

    sphere_sizes = np.asarray(A.tocsr().sum(axis=1)).ravel()
    empty_spheres = np.nonzero(sphere_sizes == 0)[0]
    if len(empty_spheres) != 0:
        raise ValueError(f"These spheres are empty: {empty_spheres}")

    if (not allow_overlap) and np.any(A.sum(axis=0) >= 2):
        raise ValueError("Overlap detected between spheres")

    return X, A
def _iter_signals_from_spheres(
    seeds, niimg, radius, allow_overlap, mask_img=None
):
    """Yield, sphere by sphere, the signals of the voxels in that sphere.

    Parameters
    ----------
    seeds : :obj:`list` of triplets of coordinates in native space
        Seed definitions. List of coordinates of the seeds in the same space
        as the images (typically MNI or TAL).

    niimg : 3D/4D Niimg-like object
        See :ref:`extracting_data`.
        Images to process. Forwarded unchanged to
        ``apply_mask_and_get_affinity``.

    radius : float
        Indicates, in millimeters, the radius for the sphere around the seed.

    allow_overlap : boolean
        If False, an error is raised if the spheres overlap (ie at least two
        spheres contain the same voxel).

    mask_img : Niimg-like object, optional
        See :ref:`extracting_data`.
        Mask to apply to regions before extracting signals.

    Yields
    ------
    Slice ``X[:, rows]`` of the extracted signals, one per seed, where
    ``rows`` are the voxel indices belonging to that seed's sphere.

    """
    voxel_signals, affinity = apply_mask_and_get_affinity(
        seeds, niimg, radius, allow_overlap, mask_img=mask_img
    )
    # `affinity.rows` holds, for each seed, the column indices of its voxels
    yield from (voxel_signals[:, voxel_ids] for voxel_ids in affinity.rows)
199class _ExtractionFunctor:
200 func_name = "nifti_spheres_masker_extractor"
202 def __init__(self, seeds_, radius, mask_img, allow_overlap, dtype):
203 self.seeds_ = seeds_
204 self.radius = radius
205 self.mask_img = mask_img
206 self.allow_overlap = allow_overlap
207 self.dtype = dtype
209 def __call__(self, imgs):
210 n_seeds = len(self.seeds_)
212 imgs = check_niimg_4d(imgs, dtype=self.dtype)
214 signals = np.empty(
215 (imgs.shape[3], n_seeds), dtype=img_data_dtype(imgs)
216 )
217 for i, sphere in enumerate(
218 _iter_signals_from_spheres(
219 self.seeds_,
220 imgs,
221 self.radius,
222 self.allow_overlap,
223 mask_img=self.mask_img,
224 )
225 ):
226 signals[:, i] = np.mean(sphere, axis=1)
228 return signals, None
@fill_doc
class NiftiSpheresMasker(BaseMasker):
    """Class for masking of Niimg-like objects using seeds.

    NiftiSpheresMasker is useful when data from given seeds should be
    extracted.

    Use case:
    summarize brain signals from seeds that were obtained from prior knowledge.

    Parameters
    ----------
    seeds : :obj:`list` of triplet of coordinates in native space or None, \
        default=None
        Seed definitions. List of coordinates of the seeds in the same space
        as the images (typically MNI or TAL).

    radius : :obj:`float`, default=None
        Indicates, in millimeters, the radius for the sphere around the seed.
        By default signal is extracted on a single voxel.

    mask_img : Niimg-like object, default=None
        See :ref:`extracting_data`.
        Mask to apply to regions before extracting signals.

    allow_overlap : :obj:`bool`, default=False
        If False, an error is raised if the spheres overlap (ie at least two
        spheres contain the same voxel).
    %(smoothing_fwhm)s
    %(standardize_maskers)s
    %(standardize_confounds)s
    high_variance_confounds : :obj:`bool`, default=False
        If True, high variance confounds are computed on provided image with
        :func:`nilearn.image.high_variance_confounds` and default parameters
        and regressed out.
    %(detrend)s
    %(low_pass)s
    %(high_pass)s
    %(t_r)s

    %(dtype)s

    %(memory)s
    %(memory_level1)s
    %(verbose0)s

    reports : :obj:`bool`, default=True
        If set to True, data is saved in order to produce a report.

    %(clean_args)s
        .. versionadded:: 0.11.2dev

    %(masker_kwargs)s

    Attributes
    ----------
    %(nifti_mask_img_)s

    n_elements_ : :obj:`int`
        The number of seeds in the masker.

        .. versionadded:: 0.9.2

    seeds_ : :obj:`list` of :obj:`list`
        The coordinates of the seeds in the masker.

    See Also
    --------
    nilearn.maskers.NiftiMasker

    """

    # memory and memory_level are used by CacheMixin.
    def __init__(
        self,
        seeds=None,
        radius=None,
        mask_img=None,
        allow_overlap=False,
        smoothing_fwhm=None,
        standardize=False,
        standardize_confounds=True,
        high_variance_confounds=False,
        detrend=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        dtype=None,
        memory=None,
        memory_level=1,
        verbose=0,
        reports=True,
        clean_args=None,
        **kwargs,
    ):
        # The constructor only stores its arguments; validation is in fit().
        self.seeds = seeds
        self.mask_img = mask_img
        self.radius = radius
        self.allow_overlap = allow_overlap

        # Parameters for smooth_array
        self.smoothing_fwhm = smoothing_fwhm

        # Parameters for clean()
        self.standardize = standardize
        self.standardize_confounds = standardize_confounds
        self.high_variance_confounds = high_variance_confounds
        self.detrend = detrend
        self.low_pass = low_pass
        self.high_pass = high_pass
        self.t_r = t_r
        self.dtype = dtype
        self.clean_args = clean_args
        self.clean_kwargs = kwargs

        # Parameters for joblib
        self.memory = memory
        self.memory_level = memory_level

        # Parameters for reporting
        self.reports = reports
        self.verbose = verbose

    def generate_report(self, displayed_spheres="all"):
        """Generate an HTML report for current ``NiftiSpheresMasker`` object.

        .. note::
            This functionality requires to have ``Matplotlib`` installed.

        Parameters
        ----------
        displayed_spheres : :obj:`int`, or :obj:`list`,\
            or :class:`~numpy.ndarray`, or "all", default="all"
            Indicates which spheres will be displayed in the HTML report.

            - If "all": All spheres will be displayed in the report.

            .. code-block:: python

                masker.generate_report("all")

            .. warning::

                If there are too many spheres, this might be time and
                memory consuming, and will result in very heavy
                reports.

            - If a :obj:`list` or :class:`~numpy.ndarray`: This indicates
              the (0-based) indices of the spheres to be displayed in the
              report. For example, the following code will generate a
              report with spheres 6, 3, and 12:

            .. code-block:: python

                masker.generate_report([6, 3, 12])

            - If an :obj:`int`: This will only display the first n
              spheres, n being the value of the parameter.
              Example to display the first 16 spheres:

            .. code-block:: python

                masker.generate_report(16)

        Returns
        -------
        report : `nilearn.reporting.html_report.HTMLReport`
            HTML report for the masker.

        Raises
        ------
        TypeError
            If ``displayed_spheres`` is neither "all", an int, a list
            nor a numpy array.
        """
        from nilearn.reporting.html_report import generate_report

        if not is_matplotlib_installed():
            return generate_report(self)

        # Check the type before comparing with "all": comparing a numpy
        # array with a string gives an element-wise array whose truth value
        # is ambiguous, which would make valid array input fail.
        if isinstance(displayed_spheres, str):
            is_valid = displayed_spheres == "all"
        else:
            is_valid = isinstance(displayed_spheres, (list, np.ndarray, int))
        if not is_valid:
            raise TypeError(
                "Parameter ``displayed_spheres`` of "
                "``generate_report()`` should be either 'all' or "
                "an int, or a list/array of ints. You provided a "
                f"{type(displayed_spheres)}"
            )
        self.displayed_spheres = displayed_spheres

        return generate_report(self)

    def _reporting(self):
        """Return a list of all displays to be rendered.

        Also fills ``self._report_content`` with the number of seeds, the
        list of displayed maps and a per-seed summary table.

        Returns
        -------
        displays : list
            A list of all displays to be rendered.
            ``[None]`` if no reporting data was stored at fit time.

        Raises
        ------
        ValueError
            If ``self.displayed_spheres`` is a list or array containing an
            index that does not correspond to a seed.
        """
        from nilearn import plotting
        from nilearn.reporting.html_report import embed_img

        if self._reporting_data is not None:
            seeds = self._reporting_data["seeds"]
        else:
            self._report_content["summary"] = None

            return [None]

        img = self._reporting_data["img"]
        if img is None:
            img = load_mni152_template()
            positions = seeds
            msg = (
                "No image provided to fit in NiftiSpheresMasker. "
                "Spheres are plotted on top of the MNI152 template."
            )
            warnings.warn(msg, stacklevel=find_stack_level())
            self._report_content["warning_message"] = msg
        else:
            # Voxel position of each seed in the image used for the report;
            # the inverse affine is the same for every seed.
            inv_affine = np.linalg.inv(img.affine)
            positions = [
                np.round(coord_transform(*seed, inv_affine)).astype(int)
                for seed in seeds
            ]

        self._report_content["number_of_seeds"] = len(seeds)

        spheres_to_be_displayed = range(len(seeds))
        if isinstance(self.displayed_spheres, int):
            if len(seeds) < self.displayed_spheres:
                msg = (
                    "generate_report() received "
                    f"{self.displayed_spheres} spheres to be displayed. "
                    f"But masker only has {len(seeds)} seeds. "
                    "Setting number of displayed spheres "
                    f"to {len(seeds)}."
                )
                warnings.warn(
                    category=UserWarning,
                    message=msg,
                    stacklevel=find_stack_level(),
                )
                self.displayed_spheres = len(seeds)
            spheres_to_be_displayed = range(self.displayed_spheres)
        elif isinstance(self.displayed_spheres, (list, np.ndarray)):
            # Sphere indices are 0-based: valid values are
            # 0 .. len(seeds) - 1. Anything else would match no seed below.
            if (
                min(self.displayed_spheres) < 0
                or max(self.displayed_spheres) >= len(seeds)
            ):
                raise ValueError(
                    "Report cannot display the "
                    "following spheres "
                    f"{self.displayed_spheres} because "
                    f"masker only has {len(seeds)} seeds."
                )
            spheres_to_be_displayed = self.displayed_spheres
        # extend spheres_to_be_displayed by 1
        # as the default image is a glass brain with all the spheres
        tmp = [0]
        spheres_to_be_displayed = np.asarray(spheres_to_be_displayed) + 1
        tmp.extend(spheres_to_be_displayed.tolist())
        self._report_content["displayed_maps"] = tmp

        columns = [
            "seed number",
            "coordinates",
            "position",
            "radius",
            "size (in mm^3)",
            "size (in voxels)",
            "relative size (in %)",
        ]
        regions_summary = {c: [] for c in columns}

        radius = 1.0 if self.radius is None else self.radius
        # First image: overview of all the seeds
        display = plotting.plot_markers(
            [1 for _ in seeds], seeds, node_size=20 * radius, colorbar=False
        )
        embedded_images = [embed_img(display)]
        display.close()
        for idx, seed in enumerate(seeds):
            regions_summary["seed number"].append(idx)
            regions_summary["coordinates"].append(str(seed))
            regions_summary["position"].append(positions[idx])
            regions_summary["radius"].append(radius)
            regions_summary["size (in voxels)"].append("not implemented")
            regions_summary["size (in mm^3)"].append(
                round(4.0 / 3.0 * np.pi * radius**3, 2)
            )
            regions_summary["relative size (in %)"].append("not implemented")

            # displayed_maps is shifted by one because of the overview image
            if idx + 1 in self._report_content["displayed_maps"]:
                display = plotting.plot_img(img, cut_coords=seed, cmap="gray")
                display.add_markers(
                    marker_coords=[seed],
                    marker_color="g",
                    marker_size=20 * radius,
                )
                embedded_images.append(embed_img(display))
                display.close()

        assert len(embedded_images) == len(
            self._report_content["displayed_maps"]
        )

        self._report_content["summary"] = regions_summary

        return embedded_images

    @rename_parameters(replacement_params={"X": "imgs"}, end_version="0.13.2")
    def fit(
        self,
        imgs=None,
        y=None,
    ):
        """Prepare signal extraction from regions.

        Validates ``self.seeds`` and stores them in ``seeds_``.

        Parameters
        ----------
        imgs : Niimg-like object or None, default=None
            Passed to ``self._load_mask``. When ``reports`` is True, one
            time point of these images is kept to build the report.

        y : None
            This parameter is unused. It is solely included for scikit-learn
            compatibility.

        Returns
        -------
        self
            The fitted masker.

        Raises
        ------
        ValueError
            If ``seeds`` is not iterable, or if one of the seeds is not a
            sequence of length 3.

        """
        del y
        self._report_content = {
            "description": (
                "This reports shows the regions defined "
                "by the spheres of the masker."
            ),
            "warning_message": None,
        }

        self._sanitize_cleaning_parameters()
        self.clean_args_ = {} if self.clean_args is None else self.clean_args

        error = (
            "Seeds must be a list of triplets of coordinates in "
            "native space.\n"
        )

        self.mask_img_ = self._load_mask(imgs)

        if self.memory is None:
            self.memory = Memory(location=None)

        if imgs is not None:
            if self.reports:
                if self.mask_img_ is not None:
                    # TODO switch to force_resample=True
                    # when bumping to version > 0.13
                    resampl_imgs = self._cache(resample_img)(
                        imgs,
                        target_affine=self.mask_img_.affine,
                        copy=False,
                        interpolation="nearest",
                        copy_header=True,
                        force_resample=False,
                    )
                else:
                    resampl_imgs = imgs
                # Store 1 timepoint to pass to reporter
                resampl_imgs, _ = compute_middle_image(resampl_imgs)
        elif self.reports:  # imgs not provided to fit
            resampl_imgs = None

        if not hasattr(self.seeds, "__iter__"):
            raise ValueError(
                f"{error}Given seed list is of type: {type(self.seeds)}"
            )

        self.seeds_ = []
        # Check seeds and convert them to lists if needed
        for i, seed in enumerate(self.seeds):
            # Check the type first
            if not hasattr(seed, "__len__"):
                raise ValueError(
                    f"{error}Seed #{i} is not a valid triplet of coordinates. "
                    f"It is of type {type(seed)}."
                )
            # Convert to list because it is easier to process
            seed = (
                seed.tolist() if isinstance(seed, np.ndarray) else list(seed)
            )
            # Check the length
            if len(seed) != 3:
                raise ValueError(
                    f"{error}Seed #{i} is of length {len(seed)} instead of 3."
                )

            self.seeds_.append(seed)

        # `resampl_imgs` is only defined (and only read) when reports is True
        self._reporting_data = None
        if self.reports:
            self._reporting_data = {
                "seeds": self.seeds_,
                "mask": self.mask_img_,
                "img": resampl_imgs,
            }

        self.n_elements_ = len(self.seeds_)

        return self

    @fill_doc
    def fit_transform(self, imgs, y=None, confounds=None, sample_mask=None):
        """Prepare and perform signal extraction.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.

        y : None
            This parameter is unused. It is solely included for scikit-learn
            compatibility.

        %(confounds)s

        %(sample_mask)s

            .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s

        """
        del y
        return self.fit(imgs).transform(
            imgs, confounds=confounds, sample_mask=sample_mask
        )

    def __sklearn_is_fitted__(self):
        # Both attributes are set at the end of a successful fit()
        return hasattr(self, "seeds_") and hasattr(self, "n_elements_")

    @fill_doc
    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
        """Extract signals from a single 4D niimg.

        Parameters
        ----------
        imgs : 3D/4D Niimg-like object
            See :ref:`extracting_data`.
            Images to process.

        %(confounds)s

        %(sample_mask)s

            .. versionadded:: 0.8.0

        Returns
        -------
        %(signals_transform_nifti)s

        """
        check_is_fitted(self)

        params = get_params(NiftiSpheresMasker, self)
        params["clean_kwargs"] = self.clean_args_
        # TODO remove in 0.13.2
        if self.clean_kwargs:
            params["clean_kwargs"] = self.clean_kwargs_

        signals, _ = self._cache(
            filter_and_extract, ignore=["verbose", "memory", "memory_level"]
        )(
            imgs,
            _ExtractionFunctor(
                self.seeds_,
                self.radius,
                self.mask_img,
                self.allow_overlap,
                self.dtype,
            ),
            # Pre-processing
            params,
            confounds=confounds,
            sample_mask=sample_mask,
            dtype=self.dtype,
            # Caching
            memory=self.memory,
            memory_level=self.memory_level,
            # kwargs
            verbose=self.verbose,
        )
        return np.atleast_1d(signals)

    @fill_doc
    def inverse_transform(self, region_signals):
        """Compute :term:`voxel` signals from spheres signals.

        Any mask given at initialization is taken into account. Throws an error
        if ``mask_img==None``

        Parameters
        ----------
        %(region_signals_inv_transform)s

        Returns
        -------
        %(img_inv_transform_nifti)s

        """
        check_is_fitted(self)

        region_signals = self._check_array(region_signals)

        logger.log("computing image from signals", verbose=self.verbose)

        if self.mask_img_ is not None:
            mask = check_niimg_3d(self.mask_img_)
        else:
            raise ValueError(
                "Please provide mask_img at initialization to "
                "provide a reference for the inverse_transform."
            )

        # No image is passed: only the sphere membership matrix is needed
        _, adjacency = apply_mask_and_get_affinity(
            self.seeds_, None, self.radius, self.allow_overlap, mask_img=mask
        )
        adjacency = adjacency.tocsr()
        # Compute overlap scaling for mean signal:
        if self.allow_overlap:
            n_adjacent_spheres = np.asarray(adjacency.sum(axis=0)).ravel()
            scale = 1 / np.maximum(1, n_adjacent_spheres)
            adjacency = adjacency.dot(sparse.diags(scale))

        img = adjacency.T.dot(region_signals.T).T
        return unmask(img, self.mask_img_)