1"""Recursive Neighbor Agglomeration (ReNA).
3Fastclustering for approximation of structured signals
4"""

import itertools
import warnings

import numpy as np
from joblib import Memory
from nibabel import Nifti1Image
from scipy.sparse import coo_matrix, csgraph, dia_matrix
from sklearn.base import BaseEstimator, ClusterMixin, TransformerMixin
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

from nilearn._utils import fill_doc, logger
from nilearn._utils.logger import find_stack_level
from nilearn._utils.param_validation import check_params
from nilearn._utils.tags import SKLEARN_LT_1_6
from nilearn.image import get_data
from nilearn.maskers import SurfaceMasker
from nilearn.masking import unmask_from_to_3d_array
from nilearn.surface import SurfaceImage


def _compute_weights(X, mask_img):
    """Compute the weights in direction of each axis using Euclidean distance.

    i.e. weights = (weights_deep, weights_right, weights_down).

    Notes
    -----
    Here we assume a square lattice (no diagonal connections).

    Parameters
    ----------
    X : ndarray, shape = [n_samples, n_features]
        Training data.

    mask_img : Niimg-like object
        Object used for masking the data.

    Returns
    -------
    weights : ndarray
        Weights corresponding to all edges in the mask.
        shape: (n_edges,).
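
    Examples
    --------
    A minimal sketch on a toy 2x2x2 image with a full mask (hypothetical
    inputs, not a nilearn fixture); each of the three axes contributes
    four edges, hence twelve weights:

    >>> import numpy as np
    >>> from nibabel import Nifti1Image
    >>> mask_img = Nifti1Image(np.ones((2, 2, 2), dtype=np.int8), np.eye(4))
    >>> X = np.arange(16, dtype=float).reshape(2, 8)
    >>> _compute_weights(X, mask_img).shape
    (12,)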
50 """
    n_samples, n_features = X.shape

    mask = get_data(mask_img).astype("bool")
    shape = mask.shape

    data = np.empty((shape[0], shape[1], shape[2], n_samples))
    for sample in range(n_samples):
        data[:, :, :, sample] = unmask_from_to_3d_array(X[sample].copy(), mask)

    weights_deep = np.sum(np.diff(data, axis=2) ** 2, axis=-1).ravel()
    weights_right = np.sum(np.diff(data, axis=1) ** 2, axis=-1).ravel()
    weights_down = np.sum(np.diff(data, axis=0) ** 2, axis=-1).ravel()

    weights = np.hstack([weights_deep, weights_right, weights_down])

    return weights


def _make_3d_edges(vertices, is_mask):
    """Create the edges set: Returns a list of edges for a 3D image.

    Parameters
    ----------
    vertices : ndarray
        The indices of the voxels.

    is_mask : boolean
        If is_mask is True, the mask of edges is returned:
        True if the edge is contained in the mask, False otherwise.

    Returns
    -------
    edges : ndarray
        Edges corresponding to the image or mask.
        shape: (n_edges,) if is_mask,
        (2, n_edges) otherwise.
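
    Examples
    --------
    A small sketch on a fully indexed 2x2x2 grid (toy input); the grid
    has twelve lattice edges:

    >>> import numpy as np
    >>> vertices = np.arange(8).reshape(2, 2, 2)
    >>> _make_3d_edges(vertices, is_mask=False).shape
    (2, 12)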
    """
    if is_mask:
        edges_deep = np.logical_and(
            vertices[:, :, :-1].ravel(), vertices[:, :, 1:].ravel()
        )
        edges_right = np.logical_and(
            vertices[:, :-1].ravel(), vertices[:, 1:].ravel()
        )
        edges_down = np.logical_and(
            vertices[:-1].ravel(), vertices[1:].ravel()
        )
    else:
        edges_deep = np.vstack(
            [vertices[:, :, :-1].ravel(), vertices[:, :, 1:].ravel()]
        )
        edges_right = np.vstack(
            [vertices[:, :-1].ravel(), vertices[:, 1:].ravel()]
        )
        edges_down = np.vstack([vertices[:-1].ravel(), vertices[1:].ravel()])

    edges = np.hstack([edges_deep, edges_right, edges_down])

    return edges


def _make_edges_and_weights(X, mask_img):
    """Compute the weights for all edges in the mask.

    Parameters
    ----------
    X : ndarray, shape = [n_samples, n_features]
        Training data.

    mask_img : Niimg-like object
        Object used for masking the data.

    Returns
    -------
    edges : ndarray
        Array containing [edges_deep, edges_right, edges_down].

    weights : ndarray
        Weights corresponding to all edges in the mask.
        shape: (n_edges,).
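
    Examples
    --------
    A minimal sketch on a toy 2x2x2 image with a full mask (hypothetical
    inputs); all twelve lattice edges survive the masking:

    >>> import numpy as np
    >>> from nibabel import Nifti1Image
    >>> mask_img = Nifti1Image(np.ones((2, 2, 2), dtype=np.int8), np.eye(4))
    >>> X = np.arange(16, dtype=float).reshape(2, 8)
    >>> edges, weights = _make_edges_and_weights(X, mask_img)
    >>> edges.shape, weights.shape
    ((2, 12), (12,))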
    """
    mask = get_data(mask_img)
    shape = mask.shape
    n_vertices = np.prod(shape)

    # Indexing each voxel
    vertices = np.arange(n_vertices).reshape(shape)

    weights_unmasked = _compute_weights(X, mask_img)

    edges_unmasked = _make_3d_edges(vertices, is_mask=False)
    edges_mask = _make_3d_edges(mask, is_mask=True)

    # Apply mask to edges and weights
    weights = np.copy(weights_unmasked[edges_mask])
    edges = np.copy(edges_unmasked[:, edges_mask])

    # Reorder the indices of the graph
    max_index = edges.max()
    order = np.searchsorted(np.unique(edges.ravel()), np.arange(max_index + 1))
    edges = order[edges]

    return edges, weights


def _compute_weights_surface(X, mask, edges):
    """Compute the weights for each edge using squared Euclidean distance.

    Parameters
    ----------
    X : ndarray, shape = [n_samples, n_features]
        Masked training data, where some vertices were removed during
        masking. So n_features is only the number of vertices that were
        kept after masking.

    mask : boolean ndarray, shape = [n_vertices]
        Initial mask used to obtain X. So n_vertices is the total number
        of vertices in the mesh.

    edges : ndarray, shape = [2, n_edges]
        Edges between all the vertices in the mesh before masking.

    Returns
    -------
    weights : ndarray
        Weights corresponding to all edges.
        shape: (n_edges,).
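
    Examples
    --------
    A minimal sketch on a toy mesh with three vertices, one of them masked
    out (hypothetical inputs); the masked vertex is treated as 0:

    >>> import numpy as np
    >>> X = np.array([[1.0, 2.0], [3.0, 5.0]])
    >>> mask = np.array([True, True, False])
    >>> edges = np.array([[0, 1], [1, 2]])
    >>> _compute_weights_surface(X, mask, edges)
    array([ 5., 29.])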
    """
    n_samples, _ = X.shape
    shape = mask.shape

    data = np.empty((shape[0], n_samples))
    # Unmasking X gives us back the transpose of the original data,
    # with the masked vertices set to 0;
    # data will be of shape (n_vertices, n_samples)
    for sample in range(n_samples):
        data[:, sample] = unmask_from_to_3d_array(X[sample].copy(), mask)

    data_i = data[edges[0]]
    data_j = data[edges[1]]
    weights = np.sum((data_i - data_j) ** 2, axis=-1).ravel()

    return weights


def _circular_pairwise(iterable):
    """Pairwise iterator with the first element reused as the last one.

    Return successive overlapping pairs taken from the input `iterable`.
    The number of 2-tuples in the `output` iterator equals the number of
    inputs.

    Parameters
    ----------
    iterable : iterable

    Returns
    -------
    output : iterable
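
    Examples
    --------
    The last pair wraps around to the first element:

    >>> list(_circular_pairwise([1, 2, 3]))
    [(1, 2), (2, 3), (3, 1)]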
    """
    a, b = itertools.tee(iterable)
    # next(b, None) advances the second iterator by one element and, at the
    # same time, supplies that first element as the fill value, which closes
    # the cycle on the last pair.
    return itertools.zip_longest(a, b, fillvalue=next(b, None))


def make_edges_surface(faces, mask):
    """Create the edges set: Returns a list of edges for a surface mesh.

    Parameters
    ----------
    faces : ndarray
        The vertex indices corresponding to the mesh triangles.

    mask : boolean ndarray
        Mask over the mesh vertices; an edge is kept in the masked set
        only if both its vertices are True in the mask.

    Returns
    -------
    edges : ndarray
        Edges corresponding to the image with shape: (2, n_edges).

    edges_masked : boolean ndarray
        Mask over the edges with shape: (n_edges,).
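
    Examples
    --------
    A minimal sketch on a toy mesh with two triangles and one masked-out
    vertex (hypothetical inputs); the two triangles share one edge, giving
    five unique edges, three of which avoid the masked vertex:

    >>> import numpy as np
    >>> faces = np.array([[0, 1, 2], [1, 2, 3]])
    >>> mask = np.array([True, True, True, False])
    >>> edges, edges_masked = make_edges_surface(faces, mask)
    >>> edges.shape
    (2, 5)
    >>> int(edges_masked.sum())
    3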
    """
    mesh_edges = {
        tuple(sorted(pair))
        for face in faces
        for pair in _circular_pairwise(face)
    }
    edges = np.array(list(mesh_edges))
    false_indices = np.where(~mask)[0]
    edges_masked = ~np.isin(edges, false_indices).any(axis=1)

    return edges.T, edges_masked


def _make_edges_and_weights_surface(X, mask_img):
    """Compute the weights for all edges in the mask.

    Parameters
    ----------
    X : ndarray, shape = [n_samples, n_features]
        Training data.

    mask_img : :obj:`~nilearn.surface.SurfaceImage` object
        Object used for masking the data.

    Returns
    -------
    edges : dict[str, np.array]
        Dictionary mapping each mesh part to its edges.

    weights : dict[str, np.array]
        Weights corresponding to all edges in the mask.
        shape: (n_edges,).
    """
    weights = {}
    edges = {}
    len_previous_mask = 0
    for part in mask_img.mesh.parts:
        face_part = mask_img.mesh.parts[part].faces

        if len(mask_img.shape) == 1:
            mask_part = mask_img.data.parts[part]
        else:
            mask_part = mask_img.data.parts[part][:, 0]

        edges_unmasked, edges_mask = make_edges_surface(face_part, mask_part)

        idxs = np.array(range(mask_part.sum())) + len_previous_mask
        weights_unmasked = _compute_weights_surface(
            X[:, idxs], mask_part.astype("bool"), edges_unmasked
        )
        # Apply mask to edges and weights
        weights[part] = np.copy(weights_unmasked[edges_mask])
        edges_ = np.copy(edges_unmasked[:, edges_mask])

        # Reorder the indices of the graph
        max_index = edges_.max()
        order = np.searchsorted(
            np.unique(edges_.ravel()), np.arange(max_index + 1)
        )
        # increase the order by the number of vertices in the previous mask
        # to avoid overlapping indices
        order += len_previous_mask
        edges[part] = order[edges_]

        len_previous_mask += mask_part.sum()

    return edges, weights


def _weighted_connectivity_graph(X, mask_img):
    """Create a symmetric weighted graph.

    Data and topology are encoded by a connectivity matrix.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Training data. shape = [n_samples, n_features]

    mask_img : Niimg-like object or :obj:`~nilearn.surface.SurfaceImage` object
        Object used for masking the data.

    Returns
    -------
    connectivity : a CSR matrix
        Sparse matrix representation of the weighted adjacency graph.
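
    Examples
    --------
    A minimal sketch on a toy 2x2x2 image with a full mask (hypothetical
    inputs); the resulting graph has one node per feature:

    >>> import numpy as np
    >>> from nibabel import Nifti1Image
    >>> mask_img = Nifti1Image(np.ones((2, 2, 2), dtype=np.int8), np.eye(4))
    >>> X = np.random.RandomState(0).rand(3, 8)
    >>> _weighted_connectivity_graph(X, mask_img).shape
    (8, 8)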
    """
    n_features = X.shape[1]

    if isinstance(mask_img, SurfaceImage):
        edges, weight = _make_edges_and_weights_surface(X, mask_img)
        connectivity = coo_matrix((n_features, n_features))
        for part in mask_img.mesh.parts:
            conn_temp = coo_matrix(
                (weight[part], edges[part]), (n_features, n_features)
            ).tocsr()
            connectivity += conn_temp
    else:
        edges, weight = _make_edges_and_weights(X, mask_img)

        connectivity = coo_matrix(
            (weight, edges), (n_features, n_features)
        ).tocsr()

    # Making it symmetrical
    connectivity = (connectivity + connectivity.T) / 2
    return connectivity


def _nn_connectivity(connectivity, threshold=1e-7):
    """Fast implementation of nearest neighbor connectivity.

    Parameters
    ----------
    connectivity : a sparse matrix in COOrdinate format.
        Sparse matrix representation of the weighted adjacency graph.

    threshold : float in the closed interval [0, 1], default=1e-7
        The threshold is set to handle eccentricities.

    Returns
    -------
    nn_connectivity : a sparse matrix in COOrdinate format.
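
    Examples
    --------
    A minimal sketch on a toy 3-node chain graph (hypothetical weights):
    node 1 is the nearest neighbor of both 0 and 2, and 0 is the nearest
    neighbor of 1, giving three directed nearest-neighbor edges:

    >>> import numpy as np
    >>> from scipy.sparse import coo_matrix
    >>> weights = np.array([1.0, 1.0, 4.0, 4.0])
    >>> rows, cols = np.array([0, 1, 1, 2]), np.array([1, 0, 2, 1])
    >>> conn = coo_matrix((weights, (rows, cols)), shape=(3, 3)).tocsr()
    >>> _nn_connectivity(conn).nnz
    3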
    """
    n_features = connectivity.shape[0]

    connectivity_ = coo_matrix(
        (1.0 / connectivity.data, connectivity.nonzero()),
        (n_features, n_features),
    ).tocsr()

    # maximum on the axis = 0
    max_connectivity = connectivity_.max(axis=0).toarray()[0]
    inv_max = dia_matrix(
        (1.0 / max_connectivity, 0), shape=(n_features, n_features)
    )

    connectivity_ = inv_max * connectivity_

    # Dealing with eccentricities, there are probably many nearest neighbors
    edge_mask = connectivity_.data > 1 - threshold

    j_idx = connectivity_.nonzero()[1][edge_mask]
    i_idx = connectivity_.nonzero()[0][edge_mask]

    # Set weights to 1
    weight = np.ones_like(j_idx)
    edges = np.array([i_idx, j_idx])

    nn_connectivity = coo_matrix((weight, edges), (n_features, n_features))

    return nn_connectivity


def _reduce_data_and_connectivity(
    X, labels, n_components, connectivity, threshold=1e-7
):
    """Perform feature grouping and reduce the connectivity matrix.

    During the reduction step, the value of each cluster is replaced by
    the mean of its members.
    In addition, connected nodes are merged.

    Parameters
    ----------
    X : ndarray, shape = [n_samples, n_features]
        Training data.

    labels : ndarray
        Contains the label assignment for each voxel.

    n_components : int
        The number of clusters in the current iteration.

    connectivity : a sparse matrix in COOrdinate format.
        Sparse matrix representation of the weighted adjacency graph.

    threshold : float in the closed interval [0, 1], default=1e-7
        The threshold is set to handle eccentricities.

    Returns
    -------
    reduced_connectivity : a sparse matrix in COOrdinate format.

    reduced_X : ndarray
        Data reduced with agglomerated signal for each cluster.
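
    Examples
    --------
    A minimal sketch grouping three features into two clusters
    (hypothetical inputs); the signal of cluster 0 is the mean of its two
    members:

    >>> import numpy as np
    >>> from scipy.sparse import coo_matrix
    >>> X = np.array([[1.0, 1.0, 4.0]])
    >>> labels = np.array([0, 0, 1])
    >>> conn = coo_matrix(
    ...     (np.array([1.0, 1.0]), (np.array([0, 1]), np.array([1, 0]))),
    ...     shape=(3, 3),
    ... ).tocsr()
    >>> _, reduced_X = _reduce_data_and_connectivity(X, labels, 2, conn)
    >>> reduced_X
    array([[1., 4.]])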
    """
    n_features = len(labels)

    incidence = coo_matrix(
        (np.ones(n_features), (labels, np.arange(n_features))),
        shape=(n_components, n_features),
        dtype=np.float32,
    ).tocsc()

    inv_sum_col = dia_matrix(
        (np.array(1.0 / incidence.sum(axis=1)).squeeze(), 0),
        shape=(n_components, n_components),
    )

    incidence = inv_sum_col * incidence

    reduced_X = (incidence * X.T).T
    reduced_connectivity = (incidence * connectivity) * incidence.T

    reduced_connectivity = reduced_connectivity - dia_matrix(
        (reduced_connectivity.diagonal(), 0),
        shape=(reduced_connectivity.shape),
    )

    i_idx, j_idx = reduced_connectivity.nonzero()

    weights_ = np.sum((reduced_X[:, i_idx] - reduced_X[:, j_idx]) ** 2, axis=0)
    weights_ = np.maximum(threshold, weights_)
    reduced_connectivity.data = weights_

    return reduced_connectivity, reduced_X


def _nearest_neighbor_grouping(X, connectivity, n_clusters, threshold=1e-7):
    """Cluster using nearest neighbor agglomeration.

    Merge clusters according to their nearest neighbors,
    then the data and the connectivity are reduced.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Training data. shape = [n_samples, n_features]

    connectivity : a sparse matrix in COOrdinate format.
        Sparse matrix representation of the weighted adjacency graph.

    n_clusters : :obj:`int`
        The number of clusters to find.

    threshold : :obj:`float` in the closed interval [0, 1], default=1e-7
        The threshold is set to handle eccentricities.

    Returns
    -------
    reduced_connectivity : a sparse matrix in COOrdinate format.

    reduced_X : :class:`numpy.ndarray`
        Data reduced with agglomerated signal for each cluster.

    labels : :class:`numpy.ndarray`, shape = [n_features]
        It contains the cluster assignments.
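
    Examples
    --------
    A minimal sketch on a toy 3-node chain graph (hypothetical inputs);
    asking for two clusters merges the two closest features and leaves
    the third alone:

    >>> import numpy as np
    >>> from scipy.sparse import coo_matrix
    >>> X = np.array([[1.0, 1.0, 4.0]])
    >>> weights = np.array([1.0, 1.0, 4.0, 4.0])
    >>> rows, cols = np.array([0, 1, 1, 2]), np.array([1, 0, 2, 1])
    >>> conn = coo_matrix((weights, (rows, cols)), shape=(3, 3)).tocsr()
    >>> _, _, labels = _nearest_neighbor_grouping(X, conn, n_clusters=2)
    >>> len(np.unique(labels))
    2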
    """
    # Nearest neighbor connectivity
    nn_connectivity = _nn_connectivity(connectivity, threshold)
    n_features = connectivity.shape[0]
    n_components = n_features - (nn_connectivity + nn_connectivity.T).nnz / 2

    if n_components < n_clusters:
        # remove edges so that the final number of clusters is not less than
        # n_clusters (to achieve the desired number of clusters)
        n_edges = n_features - n_clusters
        nn_connectivity = nn_connectivity + nn_connectivity.T

        i_idx, j_idx = nn_connectivity.nonzero()
        edges = np.array([i_idx, j_idx])

        # select n_edges to merge.
        edge_mask = np.argsort(i_idx - j_idx)[:n_edges]
        # Set weights to 1, and make the connectivity matrix symmetrical.
        weight = np.ones(2 * n_edges)
        edges = np.hstack([edges[:, edge_mask], edges[::-1, edge_mask]])

        nn_connectivity = coo_matrix((weight, edges), (n_features, n_features))

    # Clustering step: getting the connected components of the nn matrix
    n_components, labels = csgraph.connected_components(nn_connectivity)

    # Reduction step: reduction by averaging
    reduced_connectivity, reduced_X = _reduce_data_and_connectivity(
        X, labels, n_components, connectivity, threshold
    )

    return reduced_connectivity, reduced_X, labels


@fill_doc
def recursive_neighbor_agglomeration(
    X, mask_img, n_clusters, n_iter=10, threshold=1e-7, verbose=0
):
    """Recursive neighbor agglomeration (:term:`ReNA`).

    Iteratively performs nearest neighbor grouping.
    See :footcite:t:`Hoyos2019`.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Training data. shape = [n_samples, n_features]

    mask_img : Niimg-like object or :obj:`~nilearn.surface.SurfaceImage` object
        Object used for masking the data.

    n_clusters : :obj:`int`
        The number of clusters to find.

    n_iter : :obj:`int`, default=10
        Number of iterations.

    threshold : :obj:`float` in the closed interval [0, 1], default=1e-07
        The threshold is set to handle eccentricities.

    %(verbose0)s

    Returns
    -------
    n_components : :obj:`int`
        Number of clusters.

    labels : :class:`numpy.ndarray`
        Cluster assignments. shape = [n_features]

    References
    ----------
    .. footbibliography::
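
    Examples
    --------
    A minimal sketch on a toy 3x3x3 image with a full mask (hypothetical
    inputs, not a nilearn fixture):

    >>> import numpy as np
    >>> from nibabel import Nifti1Image
    >>> mask_img = Nifti1Image(np.ones((3, 3, 3), dtype=np.int8), np.eye(4))
    >>> X = np.random.RandomState(0).rand(5, 27)
    >>> n_components, labels = recursive_neighbor_agglomeration(
    ...     X, mask_img, n_clusters=4
    ... )
    >>> labels.shape
    (27,)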
    """
    connectivity = _weighted_connectivity_graph(X, mask_img)

    # Initialization
    labels = np.arange(connectivity.shape[0])
    n_components = connectivity.shape[0]

    for i in range(n_iter):
        connectivity, X, reduced_labels = _nearest_neighbor_grouping(
            X, connectivity, n_clusters, threshold
        )

        labels = reduced_labels[labels]
        n_components = connectivity.shape[0]

        logger.log(
            f"After iteration number {i + 1}, features are "
            f"grouped into {n_components} clusters",
            verbose,
        )

        if n_components <= n_clusters:
            break

    return n_components, labels


@fill_doc
class ReNA(ClusterMixin, TransformerMixin, BaseEstimator):
    """Recursive Neighbor Agglomeration (:term:`ReNA`).

    Recursively merges pairs of clusters according to a 1-nearest neighbor
    criterion.
    See :footcite:t:`Hoyos2019`.

    Parameters
    ----------
    mask_img : Niimg-like object or :obj:`~nilearn.surface.SurfaceImage` \
        or :obj:`~nilearn.maskers.SurfaceMasker` object \
        or None, default=None
        Object used for masking the data.

    n_clusters : :obj:`int`, default=2
        The number of clusters to find.

    scaling : :obj:`bool`, default=False
        If scaling is True, each cluster is scaled by the square root of its
        size, preserving the l2-norm of the image.

    n_iter : :obj:`int`, default=10
        Number of iterations of the recursive neighbor agglomeration.

    threshold : :obj:`float` in the open interval (0., 1.), default=1e-7
        Threshold used to handle eccentricities.
    %(memory)s
    %(memory_level1)s
    %(verbose0)s

    Attributes
    ----------
    labels_ : :class:`numpy.ndarray`, shape = [n_features]
        Cluster labels for each feature.

    n_clusters_ : :obj:`int`
        Number of clusters.

    sizes_ : :class:`numpy.ndarray`, shape = [n_clusters_]
        Size of each cluster.

    References
    ----------
    .. footbibliography::
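
    Examples
    --------
    A minimal sketch on a toy 4x4x4 image with a full mask (hypothetical
    inputs, not a nilearn fixture):

    >>> import numpy as np
    >>> from nibabel import Nifti1Image
    >>> mask_img = Nifti1Image(np.ones((4, 4, 4), dtype=np.int8), np.eye(4))
    >>> X = np.random.RandomState(42).rand(10, 64)
    >>> rena = ReNA(mask_img=mask_img, n_clusters=8)
    >>> X_red = rena.fit_transform(X)
    >>> X_red.shape == (10, rena.n_clusters_)
    True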
    """

    def __init__(
        self,
        mask_img=None,
        n_clusters=2,
        scaling=False,
        n_iter=10,
        threshold=1e-7,
        memory=None,
        memory_level=1,
        verbose=0,
    ):
        self.mask_img = mask_img
        self.n_clusters = n_clusters
        self.scaling = scaling
        self.n_iter = n_iter
        self.threshold = threshold
        self.memory = memory
        self.memory_level = memory_level
        self.verbose = verbose

    def _more_tags(self):
        """Return estimator tags.

        TODO: remove when bumping sklearn_version > 1.5.
        """
        return self.__sklearn_tags__()

    def __sklearn_tags__(self):
        """Return estimator tags.

        See the sklearn documentation for more details on tags:
        https://scikit-learn.org/1.6/developers/develop.html#estimator-tags
        """
        # TODO: get rid of this if block when bumping sklearn_version > 1.5
        if SKLEARN_LT_1_6:
            from nilearn._utils.tags import tags

            return tags()

        from nilearn._utils.tags import InputTags

        tags = super().__sklearn_tags__()
        tags.input_tags = InputTags(niimg_like=False)
        return tags

    @fill_doc
    def fit(self, X, y=None):
        """Compute clustering of the data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`, shape = [n_samples, n_features]
            Training data.

        %(y_dummy)s

        Returns
        -------
        self : `ReNA` object
        """
        del y
        check_params(self.__dict__)
        X = check_array(
            X, ensure_min_features=2, ensure_min_samples=2, estimator=self
        )
        n_features = X.shape[1]

        if not isinstance(
            self.mask_img, (str, Nifti1Image, SurfaceImage, SurfaceMasker)
        ):
            raise TypeError(
                "The mask image should be a Niimg-like object, "
                "a SurfaceImage object or a SurfaceMasker. "
                f"Instead a {type(self.mask_img)} object was provided."
            )

        # If mask_img is a SurfaceMasker, we need to extract the mask_img
        if isinstance(self.mask_img, SurfaceMasker):
            self.mask_img = self.mask_img.mask_img_

        if self.memory is None or isinstance(self.memory, str):
            self.memory_ = Memory(
                location=self.memory, verbose=max(0, self.verbose - 1)
            )
        else:
            self.memory_ = self.memory

        if self.n_clusters <= 0:
            raise ValueError(
                "n_clusters should be an integer greater than 0."
                f" {self.n_clusters} was provided."
            )

        if self.n_iter <= 0:
            raise ValueError(
                "n_iter should be an integer greater than 0."
                f" {self.n_iter} was provided."
            )

        if self.n_clusters > n_features:
            self.n_clusters = n_features
            warnings.warn(
                "n_clusters should be at most the number of features. "
                f"Taking n_clusters = {n_features} instead.",
                stacklevel=find_stack_level(),
            )

        n_components, labels = self.memory_.cache(
            recursive_neighbor_agglomeration
        )(
            X,
            self.mask_img,
            self.n_clusters,
            n_iter=self.n_iter,
            threshold=self.threshold,
            verbose=self.verbose,
        )

        sizes = np.bincount(labels)
        sizes = sizes[sizes > 0]

        self.labels_ = labels
        self.n_clusters_ = np.unique(self.labels_).shape[0]
        self.sizes_ = sizes

        return self

    def __sklearn_is_fitted__(self):
        return hasattr(self, "labels_")

    @fill_doc
    def transform(
        self,
        X,
        y=None,  # noqa: ARG002
    ):
        """Apply clustering, reduce the dimensionality of the data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`, shape = [n_samples, n_features]
            Data to transform with the fitted clustering.

        %(y_dummy)s

        Returns
        -------
        X_red : :class:`numpy.ndarray`, shape = [n_samples, n_clusters]
            Data reduced with agglomerated signal for each cluster.
        """
        check_is_fitted(self)

        unique_labels = np.unique(self.labels_)

        mean_cluster = [
            np.mean(X[:, self.labels_ == label], axis=1)
            for label in unique_labels
        ]
        X_red = np.array(mean_cluster).T

        if self.scaling:
            X_red = X_red * np.sqrt(self.sizes_)

        return X_red

    def inverse_transform(self, X_red):
        """Send the reduced 2D data matrix back to the original feature \
        space (:term:`voxels<voxel>`).

        Parameters
        ----------
        X_red : :class:`numpy.ndarray`, shape = [n_samples, n_clusters]
            Data reduced with agglomerated signal for each cluster.

        Returns
        -------
        X_inv : :class:`numpy.ndarray`, shape = [n_samples, n_features]
            Reduced data expanded back to the original feature space.
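
        Examples
        --------
        A round-trip sketch with a fitted estimator (``rena`` and ``X`` are
        hypothetical, carried over from a prior call to :meth:`fit`):

        >>> X_red = rena.transform(X)  # doctest: +SKIP
        >>> rena.inverse_transform(X_red).shape == X.shape  # doctest: +SKIP
        True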
        """
        check_is_fitted(self)

        _, inverse = np.unique(self.labels_, return_inverse=True)

        if self.scaling:
            X_red = X_red / np.sqrt(self.sizes_)
        X_inv = X_red[..., inverse]

        return X_inv

    def set_output(self, *, transform=None):
        """Set the output container when ``"transform"`` is called.

        .. warning::

            This has not been implemented yet.
        """
        raise NotImplementedError()