Coverage for nilearn/regions/parcellations.py: 18%

148 statements  


1"""Parcellation tools such as KMeans or Ward for fMRI images.""" 

2 

3import warnings 

4from typing import ClassVar 

5 

6import numpy as np 

7from joblib import Memory, Parallel, delayed 

8from scipy.sparse import coo_matrix 

9from sklearn.base import clone 

10from sklearn.feature_extraction import image 

11from sklearn.utils.estimator_checks import check_is_fitted 

12 

13from nilearn._utils import fill_doc, logger, stringify_path 

14from nilearn._utils.logger import find_stack_level 

15from nilearn._utils.niimg import safe_get_data 

16from nilearn._utils.niimg_conversions import iter_check_niimg 

17from nilearn.decomposition._multi_pca import _MultiPCA 

18from nilearn.maskers import NiftiLabelsMasker, SurfaceLabelsMasker 

19from nilearn.maskers.surface_labels_masker import signals_to_surf_img_labels 

20from nilearn.regions.hierarchical_kmeans_clustering import HierarchicalKMeans 

21from nilearn.regions.rena_clustering import ( 

22 ReNA, 

23 make_edges_surface, 

24) 

25from nilearn.surface import SurfaceImage 

26 

27 

def _connectivity_surface(mask_img):
    """Compute connectivity matrix for surface data, used for Agglomerative
    Clustering method.

    Based on the surface part of
    :func:`~nilearn.regions.rena_clustering._weighted_connectivity_graph`.
    The difference is that this function returns a non-weighted connectivity
    matrix with the diagonal set to 1 (because that is what is used with
    volumes).

    Parameters
    ----------
    mask_img : :class:`~nilearn.surface.SurfaceImage` object
        Mask image provided to the Parcellation object.

    Returns
    -------
    connectivity : a sparse matrix
        Connectivity or adjacency matrix for the mask.

    """
    # total True vertices in the mask
    n_vertices = (
        mask_img.data.parts["left"].sum() + mask_img.data.parts["right"].sum()
    )
    connectivity = coo_matrix((n_vertices, n_vertices))
    len_previous_mask = 0
    for part in mask_img.mesh.parts:
        face_part = mask_img.mesh.parts[part].faces
        mask_part = mask_img.data.parts[part]
        edges, edge_mask = make_edges_surface(face_part, mask_part)
        # keep only the edges that are in the mask
        edges = edges[:, edge_mask]
        # Reorder the indices of the graph
        max_index = edges.max()
        order = np.searchsorted(
            np.unique(edges.ravel()), np.arange(max_index + 1)
        )
        # increasing the order by the number of vertices in the previous mask
        # to avoid overlapping indices
        order += len_previous_mask
        # reorder the edges such that the first True edge in the mask is the
        # first edge in the matrix (even if it is not the first edge in the
        # mask) and so on...
        edges = order[edges]
        len_previous_mask += mask_part.sum()
        # update the connectivity matrix
        conn_temp = coo_matrix(
            (np.ones((edges.shape[1])), edges),
            (n_vertices, n_vertices),
        ).tocsr()
        connectivity += conn_temp
    # make symmetric
    connectivity = connectivity + connectivity.T
    # set diagonal to 1 for connectivity matrix
    connectivity[np.diag_indices_from(connectivity)] = 1
    return connectivity
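
# Usage sketch (illustrative only, assuming ``surf_mask`` is a hypothetical
# binary SurfaceImage mask): the helper above returns a sparse adjacency
# matrix over the True vertices of both hemispheres, symmetric with ones on
# the diagonal, suitable for AgglomerativeClustering(connectivity=...).
#
#     connectivity = _connectivity_surface(surf_mask)
#     n_true = int(surf_mask.data.parts["left"].sum()
#                  + surf_mask.data.parts["right"].sum())
#     assert connectivity.shape == (n_true, n_true)
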

def _estimator_fit(data, estimator, method=None):
    """Estimator to fit on the data matrix.

    Mostly just choosing for which methods to transpose the data, because
    KMeans and AgglomerativeClustering cluster the first dimension of the
    data (samples) but we want to cluster features (voxels).

    Parameters
    ----------
    data : numpy array
        Data matrix.

    estimator : instance of estimator from sklearn
        MiniBatchKMeans or AgglomerativeClustering.

    method : {'kmeans', 'ward', 'complete', 'average', 'rena', \
        'hierarchical_kmeans'}, optional
        A method to choose between for brain parcellations.

    Returns
    -------
    labels_ : numpy.ndarray
        labels_ estimated from estimator.

    """
    estimator = clone(estimator)
    if method in ["rena", "hierarchical_kmeans"]:
        estimator.fit(data)
    # transpose data for KMeans, AgglomerativeClustering because
    # they cluster the first dimension of data (samples) but we want to
    # cluster features (voxels)
    else:
        estimator.fit(data.T)
    labels_ = estimator.labels_

    return labels_
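
# Usage sketch (illustrative only): fitting a scikit-learn clusterer through
# the helper above on a random (n_samples, n_voxels) matrix; ``rng`` and
# ``data`` are made up here.
#
#     from sklearn.cluster import MiniBatchKMeans
#
#     rng = np.random.default_rng(0)
#     data = rng.standard_normal((40, 500))      # 40 samples, 500 voxels
#     labels = _estimator_fit(data, MiniBatchKMeans(n_clusters=5, n_init=3))
#     # data is transposed internally, so one label per voxel:
#     assert labels.shape == (500,)
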

def _check_parameters_transform(imgs, confounds):
    """Check the parameters and prepare for processing as a list."""
    imgs = stringify_path(imgs)
    confounds = stringify_path(confounds)
    if not isinstance(imgs, (list, tuple)) or isinstance(imgs, str):
        imgs = [imgs]
        single_subject = True
    elif len(imgs) == 1:
        single_subject = True
    else:
        single_subject = False

    if confounds is None and isinstance(imgs, (list, tuple)):
        confounds = [None] * len(imgs)

    if confounds is not None and (
        not isinstance(confounds, (list, tuple)) or isinstance(confounds, str)
    ):
        confounds = [confounds]

    if len(confounds) != len(imgs):
        raise ValueError(
            "Number of confounds given does not match with "
            "the given number of images."
        )
    return imgs, confounds, single_subject
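
# Usage sketch (illustrative only, with made-up file names): how inputs are
# normalized to matched lists before signal extraction.
#
#     imgs, confounds, single_subject = _check_parameters_transform(
#         ["sub-01_bold.nii.gz", "sub-02_bold.nii.gz"], confounds=None
#     )
#     # imgs -> list of 2 paths, confounds -> [None, None],
#     # single_subject -> False (more than one image was given)
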

def _labels_masker_extraction(img, masker, confound):
    """Parallelize NiftiLabelsMasker extractor on a list of Nifti images.

    Parameters
    ----------
    img : 4D Nifti image like object
        Image to process.

    masker : instance of NiftiLabelsMasker
        Used for extracting signals with fit_transform.

    confound : csv file, numpy ndarray or pandas DataFrame
        Confound used for signal cleaning during extraction.
        Passed to signal.clean.

    Returns
    -------
    signals : numpy array
        Signals extracted on the given img.

    """
    masker = clone(masker)
    signals = masker.fit_transform(img, confounds=confound)
    return signals
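
# Usage sketch (illustrative only): this helper exists so signal extraction
# can be parallelized with joblib; ``labels_masker``, ``imgs`` and
# ``confounds`` are hypothetical here.
#
#     signals = Parallel(n_jobs=2)(
#         delayed(_labels_masker_extraction)(img, labels_masker, conf)
#         for img, conf in zip(imgs, confounds)
#     )
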

def _get_unique_labels(labels_img):
    """Get unique labels from labels image."""
    # remove singleton dimension if present
    for part in labels_img.data.parts:
        if (
            labels_img.data.parts[part].ndim == 2
            and labels_img.data.parts[part].shape[-1] == 1
        ):
            labels_img.data.parts[part] = labels_img.data.parts[part].squeeze()
    labels_data = np.concatenate(list(labels_img.data.parts.values()), axis=0)
    return np.unique(labels_data)
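
# Usage sketch (illustrative only): the helper above squeezes each hemisphere
# array of a labels SurfaceImage, concatenates them and de-duplicates,
# roughly np.unique(np.concatenate([left_labels, right_labels])).
#
#     unique_labels = _get_unique_labels(labels_img)  # labels_img: SurfaceImage
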

@fill_doc
class Parcellations(_MultiPCA):
    """Learn :term:`parcellations<parcellation>` \
    on :term:`fMRI` images.

    Six different types of clustering methods can be used:
    kmeans, hierarchical_kmeans, ward, complete, average and rena.
    kmeans will call MiniBatchKMeans and hierarchical_kmeans will call
    HierarchicalKMeans, whereas ward, complete and average are used within
    Agglomerative Clustering and rena will call ReNA.
    kmeans, ward, complete and average are leveraged from scikit-learn;
    hierarchical_kmeans and rena are built into nilearn.

    .. versionadded:: 0.4.1

    Parameters
    ----------
    method : {'kmeans', 'ward', 'complete', 'average', 'rena', \
        'hierarchical_kmeans'}
        A method to choose between for brain parcellations.
        For a small number of parcels, kmeans is usually advisable.
        For a large number of parcels (several hundreds, or thousands),
        ward and rena are the best options. Ward will give higher quality
        parcels, but with increased computation time. ReNA is most useful as a
        fast data-reduction step, typically dividing the signal size by ten.

    n_parcels : :obj:`int`, default=50
        Number of parcels to divide the data into.

    %(random_state)s
        Default=0.

    mask : Niimg-like object or :class:`~nilearn.surface.SurfaceImage`,\
        or :class:`nilearn.maskers.NiftiMasker`,\
        :class:`nilearn.maskers.MultiNiftiMasker` or \
        :class:`nilearn.maskers.SurfaceMasker`, optional
        Mask/Masker used for masking the data.
        If a mask image is provided, it will be used in the MultiNiftiMasker
        or SurfaceMasker (depending on the type of mask image).
        If an instance of either masker is provided, then the parameters of
        this instance will be used in masking the data, overriding the
        default masker parameters.
        If None, the mask will be automatically computed by a
        MultiNiftiMasker with default parameters for Nifti images, and no
        mask will be used for a SurfaceImage.
    %(smoothing_fwhm)s
        Default=4.0.
    %(standardize_false)s
    %(detrend)s

        .. note::
            This parameter is passed to :func:`nilearn.signal.clean`.
            Please see the related documentation for details.

        Default=False.
    %(low_pass)s

        .. note::
            This parameter is passed to :func:`nilearn.signal.clean`.
            Please see the related documentation for details.

    %(high_pass)s

        .. note::
            This parameter is passed to :func:`nilearn.signal.clean`.
            Please see the related documentation for details.

    %(t_r)s

        .. note::
            This parameter is passed to :func:`nilearn.signal.clean`.
            Please see the related documentation for details.

    %(target_affine)s

        .. note::
            This parameter is passed to :func:`nilearn.image.resample_img`.
            Please see the related documentation for details.

        .. note::
            The given affine will be considered the same for all
            images in the given list.

    %(target_shape)s

        .. note::
            This parameter is passed to :func:`nilearn.image.resample_img`.
            Please see the related documentation for details.

    %(mask_strategy)s

        .. note::
            Depending on this value, the mask will be computed from
            :func:`nilearn.masking.compute_background_mask`,
            :func:`nilearn.masking.compute_epi_mask`, or
            :func:`nilearn.masking.compute_brain_mask`.

        Default='epi'.

    mask_args : :obj:`dict`, optional
        If mask is None, these are additional parameters passed to
        :func:`nilearn.masking.compute_background_mask`,
        or :func:`nilearn.masking.compute_epi_mask`
        to fine-tune mask computation.
        Please see the related documentation for details.

    scaling : :obj:`bool`, default=False
        Used only when the method selected is 'rena'. If scaling is True, each
        cluster is scaled by the square root of its size, preserving the
        l2-norm of the image.

    n_iter : :obj:`int`, default=10
        Used only when the method selected is 'rena'. Number of iterations of
        the recursive neighbor agglomeration.
    %(memory)s
    %(memory_level)s
    %(n_jobs)s
    %(verbose0)s

    Attributes
    ----------
    labels_img_ : :class:`nibabel.nifti1.Nifti1Image`
        Labels image for the parcellation learned on the fMRI images.

    masker_ : :class:`nilearn.maskers.NiftiMasker` or \
        :class:`nilearn.maskers.MultiNiftiMasker`
        The masker used to mask the data.

    connectivity_ : :class:`numpy.ndarray`
        Voxel-to-voxel connectivity matrix computed from a mask.
        Note that this attribute is only present if the selected method is
        of the Agglomerative Clustering type: 'ward', 'complete' or 'average'.

    Notes
    -----
    * Transforming a list of images to a data matrix takes a few steps:
      the data dimensionality is reduced using a randomized SVD, then brain
      parcellations are built using KMeans or various agglomerative methods.

    * This object uses spatially-constrained AgglomerativeClustering for
      method='ward', 'complete' or 'average', and spatially-constrained
      ReNA clustering for method='rena'. The spatial connectivity matrix
      (voxel-to-voxel) is built within the object, so there is no need to
      provide it explicitly.

    """
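
    # Usage sketch (illustrative only, not executed): a typical ward
    # parcellation; ``func_filenames`` stands in for a hypothetical list of
    # 4D fMRI files.
    #
    #     parcellation = Parcellations(method="ward", n_parcels=100,
    #                                  smoothing_fwhm=2.0, standardize=False)
    #     parcellation.fit(func_filenames)
    #     ward_labels_img = parcellation.labels_img_
    #     signals = parcellation.transform(func_filenames)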

    VALID_METHODS: ClassVar[tuple[str, ...]] = (
        "kmeans",
        "ward",
        "complete",
        "average",
        "rena",
        "hierarchical_kmeans",
    )

    def __init__(
        self,
        method,
        n_parcels=50,
        random_state=0,
        mask=None,
        smoothing_fwhm=4.0,
        standardize=False,
        detrend=False,
        low_pass=None,
        high_pass=None,
        t_r=None,
        target_affine=None,
        target_shape=None,
        mask_strategy="epi",
        mask_args=None,
        scaling=False,
        n_iter=10,
        memory=None,
        memory_level=0,
        n_jobs=1,
        verbose=1,
    ):
        if memory is None:
            memory = Memory(location=None)
        self.method = method
        self.n_parcels = n_parcels
        self.scaling = scaling
        self.n_iter = n_iter

        _MultiPCA.__init__(
            self,
            n_components=200,
            random_state=random_state,
            mask=mask,
            memory=memory,
            smoothing_fwhm=smoothing_fwhm,
            standardize=standardize,
            detrend=detrend,
            low_pass=low_pass,
            high_pass=high_pass,
            t_r=t_r,
            target_affine=target_affine,
            target_shape=target_shape,
            mask_strategy=mask_strategy,
            mask_args=mask_args,
            memory_level=memory_level,
            n_jobs=n_jobs,
            verbose=verbose,
        )

    def _raw_fit(self, data):
        """Fit the parcellation method on this reduced data.

        Data are coming from a base decomposition estimator which computes
        the mask and reduces the dimensionality of images using
        randomized_svd.

        Parameters
        ----------
        data : :class:`numpy.ndarray`
            Shape (n_samples, n_features)

        Returns
        -------
        self : object
            The object itself, with ``labels_img_`` set as an attribute
            (and ``connectivity_`` as well for the agglomerative methods
            'ward', 'complete', 'average').

        """
        valid_methods = self.VALID_METHODS
        if self.method is None:
            raise ValueError(
                "Parcellation method is specified as None. "
                f"Please select one of the methods in {valid_methods}"
            )
        if self.method not in valid_methods:
            raise ValueError(
                f"The method you have selected is not implemented "
                f"'{self.method}'. Valid methods are in {valid_methods}"
            )

        # we delay importing Ward or AgglomerativeClustering and at the same
        # time import the plotting module before that.

        components = _MultiPCA._raw_fit(self, data)

        mask_img_ = self.masker_.mask_img_

        logger.log(
            f"computing {self.method}",
            verbose=self.verbose,
        )

        if self.method == "kmeans":
            from sklearn.cluster import MiniBatchKMeans

            kmeans = MiniBatchKMeans(
                n_clusters=self.n_parcels,
                init="k-means++",
                n_init=3,
                random_state=self.random_state,
                verbose=max(0, self.verbose - 1),
            )
            labels = self._cache(_estimator_fit, func_memory_level=1)(
                components.T, kmeans
            )
        elif self.method == "hierarchical_kmeans":
            hkmeans = HierarchicalKMeans(
                self.n_parcels,
                init="k-means++",
                batch_size=1000,
                n_init=10,
                max_no_improvement=10,
                random_state=self.random_state,
                verbose=max(0, self.verbose - 1),
            )
            # data or data.T
            labels = self._cache(_estimator_fit, func_memory_level=1)(
                components.T, hkmeans, self.method
            )
        elif self.method == "rena":
            rena = ReNA(
                mask_img_,
                n_clusters=self.n_parcels,
                scaling=self.scaling,
                n_iter=self.n_iter,
                memory=self.memory,
                memory_level=self.memory_level,
                verbose=max(0, self.verbose - 1),
            )
            method = "rena"
            labels = self._cache(_estimator_fit, func_memory_level=1)(
                components.T, rena, method
            )

        else:
            if isinstance(mask_img_, SurfaceImage):
                connectivity = _connectivity_surface(mask_img_)
            else:
                mask_ = safe_get_data(mask_img_).astype(bool)
                shape = mask_.shape
                connectivity = image.grid_to_graph(
                    n_x=shape[0], n_y=shape[1], n_z=shape[2], mask=mask_
                )
            from sklearn.cluster import AgglomerativeClustering

            agglomerative = AgglomerativeClustering(
                n_clusters=self.n_parcels,
                connectivity=connectivity,
                linkage=self.method,
                memory=self.memory,
            )

            labels = self._cache(_estimator_fit, func_memory_level=1)(
                components.T, agglomerative
            )

            self.connectivity_ = connectivity
        # Avoid 0 label
        labels = labels + 1
        unique_labels = np.unique(labels)

        # Check that appropriate number of labels were created
        if len(unique_labels) != self.n_parcels:
            n_parcels_warning = (
                "The number of generated labels does not "
                "match the requested number of parcels."
            )
            warnings.warn(
                message=n_parcels_warning,
                category=UserWarning,
                stacklevel=find_stack_level(),
            )
        self.labels_img_ = self.masker_.inverse_transform(
            labels.astype(np.int32)
        )

        return self
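
    # Sketch of the spatially-constrained agglomerative branch above, for a
    # toy volume (illustrative only; ``mask_`` is a hypothetical 3D boolean
    # array and ``X`` a matching (n_voxels, n_features) matrix):
    #
    #     from sklearn.cluster import AgglomerativeClustering
    #     from sklearn.feature_extraction import image
    #
    #     connectivity = image.grid_to_graph(
    #         n_x=mask_.shape[0], n_y=mask_.shape[1], n_z=mask_.shape[2],
    #         mask=mask_,
    #     )
    #     ward = AgglomerativeClustering(
    #         n_clusters=50, connectivity=connectivity, linkage="ward"
    #     ).fit(X)
    #     labels = ward.labels_ + 1   # shift so that 0 is not a parcel label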

    def __sklearn_is_fitted__(self):
        return hasattr(self, "labels_img_")

    @fill_doc
    def transform(self, imgs, confounds=None):
        """Extract signals from :term:`parcellations<parcellation>` learned \
        on :term:`fMRI` images.

        Parameters
        ----------
        %(imgs)s
            Images to process.

        confounds : :obj:`list` of CSV files, arrays-like,\
            or :class:`pandas.DataFrame`, default=None
            Each file or numpy array in a list should have shape
            (number of scans, number of confounds).
            Must be of the same length as imgs.

            .. note::
                This parameter is passed to :func:`nilearn.signal.clean`.
                Please see the related documentation for details.

        Returns
        -------
        region_signals : 2D :class:`numpy.ndarray` or :obj:`list` of \
            2D :class:`numpy.ndarray`
            Signals extracted for each label for each image.
            For example, for a single image the shape will be
            (number of scans, number of labels).

        """
        check_is_fitted(self)
        imgs, confounds, single_subject = _check_parameters_transform(
            imgs, confounds
        )
        # Required for special cases like extracting signals on a list of
        # 3D images or SurfaceImages.
        if isinstance(self.masker_.mask_img_, SurfaceImage):
            imgs_list = imgs.copy()
            masker = SurfaceLabelsMasker(
                self.labels_img_,
                mask_img=self.masker_.mask_img_,
                smoothing_fwhm=self.smoothing_fwhm,
                standardize=self.standardize,
                detrend=self.detrend,
                low_pass=self.low_pass,
                high_pass=self.high_pass,
                t_r=self.t_r,
                memory=self.memory,
                memory_level=self.memory_level,
                verbose=self.verbose,
            )
        else:
            imgs_list = iter_check_niimg(imgs, atleast_4d=True)
            masker = NiftiLabelsMasker(
                self.labels_img_,
                mask_img=self.masker_.mask_img_,
                smoothing_fwhm=self.smoothing_fwhm,
                standardize=self.standardize,
                detrend=self.detrend,
                low_pass=self.low_pass,
                high_pass=self.high_pass,
                t_r=self.t_r,
                resampling_target="data",
                memory=self.memory,
                memory_level=self.memory_level,
                verbose=self.verbose,
            )

        region_signals = Parallel(n_jobs=self.n_jobs)(
            delayed(
                self._cache(_labels_masker_extraction, func_memory_level=2)
            )(img, masker, confound)
            for img, confound in zip(imgs_list, confounds)
        )

        return region_signals[0] if single_subject else region_signals
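
    # Usage sketch (illustrative only): once fitted, signals can be extracted
    # per parcel; ``func_filenames`` and ``confounds_files`` are made up.
    #
    #     parcellation = Parcellations(method="ward", n_parcels=100)
    #     parcellation.fit(func_filenames)
    #     signals = parcellation.transform(func_filenames,
    #                                      confounds=confounds_files)
    #     # for several images: a list of arrays of shape (n_scans, n_labels);
    #     # for a single image: one such array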

    @fill_doc
    def fit_transform(self, imgs, confounds=None):
        """Fit the images to :term:`parcellations<parcellation>` and \
        then transform them.

        Parameters
        ----------
        %(imgs)s
            Images to process, both for fitting and for transforming to
            signals.

        confounds : :obj:`list` of CSV files, arrays-like or\
            :class:`pandas.DataFrame`, default=None
            Each file or numpy array in a list should have shape
            (number of scans, number of confounds).
            Given confounds should have the same length as the images if
            given as a list.

            .. note::
                This parameter is passed to :func:`nilearn.signal.clean`.
                Please see the related documentation for details.

            .. note::
                Confounds will be used for cleaning signals before
                learning parcellations.

        Returns
        -------
        region_signals : 2D :class:`numpy.ndarray` or :obj:`list` of \
            2D :class:`numpy.ndarray`
            Signals extracted for each label for each image.
            For example, for a single image the shape will be
            (number of scans, number of labels).

        """
        return self.fit(imgs, confounds=confounds).transform(imgs, confounds)
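
    # Usage sketch (illustrative only): fit and extraction in a single call,
    # on a hypothetical list of functional images.
    #
    #     signals = Parcellations(method="rena", n_parcels=500).fit_transform(
    #         func_filenames
    #     )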

    @fill_doc
    def inverse_transform(self, signals):
        """Transform signals extracted \
        from :term:`parcellations<parcellation>` back to brain images.

        Uses `labels_img_` (parcellations) built during fit().

        Parameters
        ----------
        signals : :obj:`list` of 2D :class:`numpy.ndarray`
            Each 2D array with shape (number of scans, number of regions).

        Returns
        -------
        %(imgs)s
            Brain image(s).

        """
        from .signal_extraction import signals_to_img_labels

        check_is_fitted(self)

        if not isinstance(signals, (list, tuple)) or isinstance(
            signals, np.ndarray
        ):
            signals = [signals]
            single_subject = True
        elif len(signals) == 1:
            single_subject = True
        else:
            single_subject = False

        if isinstance(self.mask_img_, SurfaceImage):
            labels = _get_unique_labels(self.labels_img_)
            imgs = Parallel(n_jobs=self.n_jobs)(
                delayed(
                    self._cache(
                        signals_to_surf_img_labels, func_memory_level=2
                    )
                )(each_signal, labels, self.labels_img_, self.mask_img_)
                for each_signal in signals
            )
        else:
            imgs = Parallel(n_jobs=self.n_jobs)(
                delayed(
                    self._cache(signals_to_img_labels, func_memory_level=2)
                )(each_signal, self.labels_img_, self.mask_img_)
                for each_signal in signals
            )

        return imgs[0] if single_subject else imgs
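
    # Usage sketch (illustrative only): projecting parcel-wise signals back
    # into the space of the learned parcellation; ``signals`` would come from
    # transform() above.
    #
    #     imgs_back = parcellation.inverse_transform(signals)
    #     # one image per subject, or a single image for a single subject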