1"""Dictionary learning estimator.
3Perform a map learning algorithm by learning
4a temporal dense dictionary along with sparse spatial loadings, that
5constitutes output maps
6"""

import warnings

import numpy as np
from sklearn.decomposition import dict_learning_online
from sklearn.linear_model import Ridge

from nilearn._utils import fill_doc, logger
from nilearn._utils.helpers import transfer_deprecated_param_vals

from ._base import _BaseDecomposition
from .canica import CanICA

# check_input=False is an optimization available in sklearn.
sparse_encode_args = {"check_input": False}


def _compute_loadings(components, data):
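    """Compute loadings for the given components via ridge regression.

    Shapes below are inferred from the callers in this module:
    ``components`` is (n_components, n_features), ``data`` is
    (n_samples, n_features), and the returned loadings are
    (n_components, n_samples), with each column scaled to unit l2 norm.
    """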
    ridge = Ridge(fit_intercept=False, alpha=1e-8)
    ridge.fit(components.T, np.asarray(data.T))
    loadings = ridge.coef_.T

    S = np.sqrt(np.sum(loadings**2, axis=0))
    S[S == 0] = 1
    loadings /= S[np.newaxis, :]
    return loadings


@fill_doc
class DictLearning(_BaseDecomposition):
    """Perform a map learning algorithm based on spatial component sparsity, \
    over a :term:`CanICA` initialization.

    This yields more stable maps than :term:`CanICA`.

    See :footcite:t:`Mensch2016`.

    .. versionadded:: 0.2

    Parameters
    ----------
    mask : Niimg-like object, :obj:`~nilearn.maskers.MultiNiftiMasker` or \
           :obj:`~nilearn.surface.SurfaceImage` or \
           :obj:`~nilearn.maskers.SurfaceMasker` object, optional
        Mask to be used on data. If an instance of masker is passed,
        then its mask will be used. If no mask is given, for Nifti images,
        it will be computed automatically by a MultiNiftiMasker with default
        parameters; for surface images, all the vertices will be used.

    n_components : :obj:`int`, default=20
        Number of components to extract.

    batch_size : :obj:`int`, default=20
        The number of samples to take in each batch.

    n_epochs : :obj:`float`, default=1
        Number of epochs the algorithm should run on the data.

    alpha : :obj:`float`, default=10
        Sparsity controlling parameter.

    dict_init : Niimg-like object or \
                :obj:`~nilearn.surface.SurfaceImage`, optional
        Initial estimation of dictionary maps. If not provided, it will be
        computed from a CanICA decomposition.

    reduction_ratio : 'auto' or :obj:`float` between 0. and 1., default='auto'
        - Between 0. and 1.: controls data reduction in the temporal domain.
          1. means no reduction, < 1. calls for an SVD-based reduction.
        - If set to 'auto', the estimator will set the number of components
          per reduced session to be n_components.

    method : {'cd', 'lars'}, default='cd'
        Coding method used by the sklearn backend. Below are the possible
        values.

        - lars: uses the least angle regression method to solve the lasso
          problem (linear_model.lars_path).
        - cd: uses the coordinate descent method to compute the Lasso
          solution (linear_model.Lasso). Lars will be faster if the
          estimated components are sparse.

    %(random_state)s

    %(smoothing_fwhm)s
        Default=4mm.

    standardize : :obj:`bool`, default=True
        If standardize is True, the time-series are centered and normed:
        their variance is set to 1 in the time dimension.

    detrend : :obj:`bool`, default=True
        If detrend is True, the time-series will be detrended before
        components extraction.

    %(target_affine)s

        .. note::
            This parameter is passed to :func:`nilearn.image.resample_img`.

    %(target_shape)s

        .. note::
            This parameter is passed to :func:`nilearn.image.resample_img`.

    %(low_pass)s

        .. note::
            This parameter is passed to :func:`nilearn.signal.clean`.

    %(high_pass)s

        .. note::
            This parameter is passed to :func:`nilearn.signal.clean`.

    %(t_r)s

        .. note::
            This parameter is passed to :func:`nilearn.signal.clean`.

    %(mask_strategy)s

        Default='epi'.

        .. note::
            These strategies are only relevant for Nifti images and the
            parameter is ignored for SurfaceImage objects.

    mask_args : :obj:`dict`, optional
        If mask is None, these are additional parameters passed to
        :func:`nilearn.masking.compute_background_mask`,
        or :func:`nilearn.masking.compute_epi_mask`
        to fine-tune mask computation.
        Please see the related documentation for details.

    %(memory)s

    %(memory_level)s

    %(n_jobs)s

    %(verbose0)s

    %(base_decomposition_attributes)s

    %(multi_pca_attributes)s

    References
    ----------
    .. footbibliography::
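
    Examples
    --------
    A minimal usage sketch; ``func_filenames`` stands in for a list of 4D
    functional images and is not defined here::

        from nilearn.decomposition import DictLearning

        dict_learning = DictLearning(n_components=5, random_state=0)
        dict_learning.fit(func_filenames)

        # Spatial maps, as an image in the space of the input data
        maps_img = dict_learning.components_img_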
156 """

    def __init__(
        self,
        n_components=20,
        n_epochs=1,
        alpha=10,
        reduction_ratio="auto",
        dict_init=None,
        random_state=None,
        batch_size=20,
        method="cd",
        mask=None,
        smoothing_fwhm=4,
        standardize=True,
        detrend=True,
        low_pass=None,
        high_pass=None,
        t_r=None,
        target_affine=None,
        target_shape=None,
        mask_strategy="epi",
        mask_args=None,
        n_jobs=1,
        verbose=0,
        memory=None,
        memory_level=0,
    ):
        super().__init__(
            n_components=n_components,
            random_state=random_state,
            mask=mask,
            smoothing_fwhm=smoothing_fwhm,
            standardize=standardize,
            detrend=detrend,
            low_pass=low_pass,
            high_pass=high_pass,
            t_r=t_r,
            target_affine=target_affine,
            target_shape=target_shape,
            mask_strategy=mask_strategy,
            mask_args=mask_args,
            memory=memory,
            memory_level=memory_level,
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.method = method
        self.alpha = alpha
        self.reduction_ratio = reduction_ratio
        self.dict_init = dict_init

    def _init_dict(self, data):
        if self.dict_init is not None:
            components = self.masker_.transform(self.dict_init)
        else:
            canica = CanICA(
                n_components=self.n_components,
                # CanICA-specific parameters
                do_cca=True,
                threshold=float(self.n_components),
                n_init=1,
                # mask parameter is not useful as we bypass masking
                mask=self.masker_,
                random_state=self.random_state,
                memory=self.memory,
                memory_level=self.memory_level,
                n_jobs=self.n_jobs,
                verbose=self.verbose,
            )
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                # We use the protected function _raw_fit as the data
                # has already been unmasked
                canica._raw_fit(data)
            components = canica.components_
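        # Rescale each initial component (row) by its energy (sum of
        # squares), guarding against all-zero rows.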
        S = (components**2).sum(axis=1)
        S[S == 0] = 1
        components /= S[:, np.newaxis]
        self.components_init_ = components

    def _init_loadings(self, data):
        self.loadings_init_ = self._cache(_compute_loadings)(
            self.components_init_, data
        )

    def _raw_fit(self, data):
        """Process unmasked data directly.

        Parameters
        ----------
        data : ndarray of shape (n_samples, n_features)
            Unmasked data to process.

        """
        logger.log("Learning initial components", self.verbose)
        self._init_dict(data)

        _, n_features = data.shape

        logger.log(
            "Computing initial loadings",
            verbose=self.verbose,
        )
        self._init_loadings(data)

        dict_init = self.loadings_init_
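
        # One epoch over the transposed data uses
        # ceil(n_features / batch_size) minibatches; scale by the number
        # of requested epochs to get the iteration budget.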
        max_iter = ((n_features - 1) // self.batch_size + 1) * self.n_epochs

        logger.log(
            "Learning dictionary",
            verbose=self.verbose,
        )
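
        # sklearn renamed dict_learning_online's ``n_iter`` argument to
        # ``max_iter``; the helper below forwards the value under the name
        # expected by the installed sklearn version.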
        kwargs = transfer_deprecated_param_vals(
            {"n_iter": "max_iter"}, {"max_iter": max_iter}
        )
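
        # The data is transposed to (n_features, n_samples), so the learned
        # dictionary is temporal (dense) while the returned sparse code,
        # kept as ``components_``, holds the spatial maps.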
        self.components_, _ = self._cache(dict_learning_online)(
            data.T,
            self.n_components,
            alpha=self.alpha,
            batch_size=self.batch_size,
            method=self.method,
            dict_init=dict_init,
            verbose=max(0, self.verbose - 1),
            random_state=self.random_state,
            return_code=True,
            shuffle=True,
            n_jobs=1,
            **kwargs,
        )
        self.components_ = self.components_.T
        # Unit-variance scaling
        S = np.sqrt(np.sum(self.components_**2, axis=1))
        S[S == 0] = 1
        self.components_ /= S[:, np.newaxis]

        # Flip signs in each component so that the positive part is
        # l1-larger than the negative part. Empirically, this yields more
        # positive-looking maps than setting the max to be positive.
        for component in self.components_:
            if np.sum(component > 0) < np.sum(component < 0):
                component *= -1
        if hasattr(self, "masker_"):
            self.components_img_ = self.masker_.inverse_transform(
                self.components_
            )

        return self