Coverage for nilearn/decoding/space_net_solvers.py: 11% (152 statements)
1"""Regression with spatial priors like TV-L1 and Graph-Net."""
3from math import sqrt
5import numpy as np
7from nilearn.masking import unmask_from_to_3d_array
9from ._objective_functions import (
10 divergence,
11 gradient,
12 gradient_id,
13 logistic_loss,
14 logistic_loss_grad,
15 logistic_loss_lipschitz_constant,
16 spectral_norm_squared,
17 squared_loss,
18 squared_loss_grad,
19)
20from ._proximal_operators import (
21 prox_l1,
22 prox_l1_with_intercept,
23 prox_tvl1,
24 prox_tvl1_with_intercept,
25)
26from .fista import mfista


def _squared_loss_and_spatial_grad(X, y, w, mask, grad_weight):
    """Compute the squared loss (data fidelity term) + squared l2 norm \
    of gradient (penalty term).

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Design matrix.

    y : ndarray, shape (n_samples,)
        Target / response vector.

    w : ndarray, shape (n_features,)
        Unmasked, ravelized weights map.

    mask : ndarray of booleans, shape (nx, ny, nz)
        The support of this mask defines the spatial support of the
        weights map.

    grad_weight : float
        l1_ratio * alpha.

    Returns
    -------
    float
        Value of Graph-Net objective.
    """
    data_section = np.dot(X, w) - y
    grad_buffer = np.zeros(mask.shape)
    grad_buffer[mask] = w
    grad_mask = np.tile(mask, [mask.ndim] + [1] * mask.ndim)
    grad_section = gradient(grad_buffer)[grad_mask]
    return 0.5 * (
        np.dot(data_section, data_section)
        + grad_weight * np.dot(grad_section, grad_section)
    )


def _squared_loss_and_spatial_grad_derivative(X, y, w, mask, grad_weight):
    """Compute the derivative of _squared_loss_and_spatial_grad.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Design matrix.

    y : ndarray, shape (n_samples,)
        Target / response vector.

    w : ndarray, shape (n_features,)
        Unmasked, ravelized weights map.

    mask : ndarray of booleans, shape (nx, ny, nz)
        The support of this mask defines the spatial support of the
        weights map.

    grad_weight : float
        l1_ratio * alpha.

    Returns
    -------
    ndarray, shape (n_features,)
        Derivative of _squared_loss_and_spatial_grad function.
    """
    data_section = np.dot(X, w) - y
    image_buffer = np.zeros(mask.shape)
    image_buffer[mask] = w
    return (
        np.dot(X.T, data_section)
        - grad_weight * divergence(gradient(image_buffer))[mask]
    )
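

# Illustrative sketch, not part of nilearn's API: a finite-difference check
# of _squared_loss_and_spatial_grad_derivative against
# _squared_loss_and_spatial_grad.  The problem sizes, the all-True mask and
# the tolerances below are arbitrary assumptions chosen to keep the check
# fast and exact up to rounding (the objective is quadratic in w).
def _check_graph_net_smooth_gradient(epsilon=1e-5, seed=0):
    """Sanity-check sketch: analytic vs. numerical gradient."""
    rng = np.random.RandomState(seed)
    mask = np.ones((3, 3, 3), dtype=bool)
    n_samples, n_features = 10, int(mask.sum())
    X = rng.randn(n_samples, n_features)
    y = rng.randn(n_samples)
    w = rng.randn(n_features)
    grad_weight = 0.3
    analytic = _squared_loss_and_spatial_grad_derivative(
        X, y, w, mask, grad_weight
    )
    numeric = np.zeros_like(w)
    for j in range(n_features):
        w_plus, w_minus = w.copy(), w.copy()
        w_plus[j] += epsilon
        w_minus[j] -= epsilon
        numeric[j] = (
            _squared_loss_and_spatial_grad(X, y, w_plus, mask, grad_weight)
            - _squared_loss_and_spatial_grad(X, y, w_minus, mask, grad_weight)
        ) / (2 * epsilon)
    np.testing.assert_allclose(analytic, numeric, rtol=1e-4, atol=1e-8)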


def _graph_net_data_function(X, w, mask, grad_weight):
    """Compute dot([X; grad_weight * grad], w).

    This function is made for the Lasso-like interpretation of the
    Graph-Net.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Design matrix.

    w : ndarray, shape (n_features,)
        Unmasked, ravelized weights map.

    mask : ndarray of booleans, shape (nx, ny, nz)
        The support of this mask defines the spatial support of the
        weights map.

    grad_weight : float
        l1_ratio * alpha.

    Returns
    -------
    ndarray, shape (n_samples + mask.ndim * n_features,)
        Design matrix augmented with the nabla operator (for the spatial
        gradient), applied to w.
    """
    data_buffer = np.zeros(mask.shape)
    data_buffer[mask] = w
    w_g = grad_weight * gradient(data_buffer)
    out = np.ndarray(X.shape[0] + mask.ndim * X.shape[1])
    out[: X.shape[0]] = X.dot(w)
    out[X.shape[0] :] = np.concatenate(
        tuple(w_g[i][mask] for i in range(mask.ndim))
    )
    return out
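

# Illustrative sketch, not part of nilearn's API: the Lasso-like operator
# above stacks the usual data term X.dot(w) on top of the weighted spatial
# gradient of the unmasked w, so its output has
# n_samples + mask.ndim * n_features entries.  Sizes below are arbitrary
# assumptions for a quick shape check.
def _check_graph_net_data_function_shape(seed=0):
    """Sanity-check sketch for the augmented-operator output."""
    rng = np.random.RandomState(seed)
    mask = np.ones((2, 3, 4), dtype=bool)
    n_samples, n_features = 5, int(mask.sum())
    X = rng.randn(n_samples, n_features)
    w = rng.randn(n_features)
    out = _graph_net_data_function(X, w, mask, grad_weight=0.5)
    assert out.shape == (n_samples + mask.ndim * n_features,)
    # The first block is exactly the plain data term X.dot(w).
    np.testing.assert_allclose(out[:n_samples], X.dot(w))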


def _graph_net_adjoint_data_function(X, w, adjoint_mask, grad_weight):
    """Compute the adjoint of the _graph_net_data_function.

    That is:

        np.dot([X.T; grad_weight * div], w).

    This function is made for the Lasso-like interpretation of the Graph-Net.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Design matrix.

    w : ndarray, shape (n_samples + adjoint_mask.shape[0] * n_features,)
        Point in the output space of _graph_net_data_function at which
        the adjoint is evaluated.

    adjoint_mask : ndarray of booleans, shape (mask.ndim, nx, ny, nz)
        Spatial mask tiled along each gradient direction, i.e.
        np.tile(mask, [mask.ndim] + [1] * mask.ndim).

    grad_weight : float
        l1_ratio * alpha.

    Returns
    -------
    ndarray, shape (n_features,)
        Value of adjoint.
    """
    n_samples, _ = X.shape
    out = X.T.dot(w[:n_samples])
    div_buffer = np.zeros(adjoint_mask.shape)
    div_buffer[adjoint_mask] = w[n_samples:]
    out -= grad_weight * divergence(div_buffer)[adjoint_mask[0]]
    return out
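

# Illustrative sketch, not part of nilearn's API: _graph_net_data_function
# and _graph_net_adjoint_data_function implement a linear operator K and its
# adjoint, so <K w, v> should equal <w, K.T v> for any w and v.  This is the
# property the power method below relies on.  Sizes and the mask are
# arbitrary assumptions.
def _check_graph_net_adjoint_consistency(seed=0):
    """Sanity-check sketch: <K w, v> == <w, K.T v>."""
    rng = np.random.RandomState(seed)
    mask = np.ones((3, 3, 3), dtype=bool)
    adjoint_mask = np.tile(mask, [mask.ndim] + [1] * mask.ndim)
    n_samples, n_features = 7, int(mask.sum())
    X = rng.randn(n_samples, n_features)
    grad_weight = 0.4
    w = rng.randn(n_features)
    v = rng.randn(n_samples + mask.ndim * n_features)
    lhs = np.dot(_graph_net_data_function(X, w, mask, grad_weight), v)
    rhs = np.dot(
        w, _graph_net_adjoint_data_function(X, v, adjoint_mask, grad_weight)
    )
    np.testing.assert_allclose(lhs, rhs, rtol=1e-7)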


def _squared_loss_derivative_lipschitz_constant(
    X, mask, grad_weight, n_iterations=100
):
    """Compute the Lipschitz constant of the gradient of the smooth part \
    of the Graph-Net regression problem (squared_loss + grad_weight * grad) \
    via the power method.
    """
    rng = np.random.RandomState(42)
    a = rng.randn(X.shape[1])
    a /= sqrt(np.dot(a, a))
    adjoint_mask = np.tile(mask, [mask.ndim] + [1] * mask.ndim)

    # Since we are putting the coefficient into the matrix, which
    # is squared in the data loss function, it must be the
    # square root of the desired weight
    actual_grad_weight = sqrt(grad_weight)
    for _ in range(n_iterations):
        a = _graph_net_adjoint_data_function(
            X,
            _graph_net_data_function(X, a, mask, actual_grad_weight),
            adjoint_mask,
            actual_grad_weight,
        )
        a /= sqrt(np.dot(a, a))

    lipschitz_constant = np.dot(
        _graph_net_adjoint_data_function(
            X,
            _graph_net_data_function(X, a, mask, actual_grad_weight),
            adjoint_mask,
            actual_grad_weight,
        ),
        a,
    ) / np.dot(a, a)

    return lipschitz_constant
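

# Illustrative sketch, not part of nilearn's API: with grad_weight=0 the
# augmented operator reduces to X itself, so the power-method estimate above
# should approach spectral_norm_squared(X) (assumed to be the largest squared
# singular value of X, as in its use by tvl1_solver below).  Sizes and the
# tolerance are arbitrary assumptions.
def _check_squared_loss_lipschitz_constant(seed=0):
    """Sanity-check sketch for the power-method Lipschitz estimate."""
    rng = np.random.RandomState(seed)
    mask = np.ones((3, 3, 3), dtype=bool)
    X = rng.randn(20, int(mask.sum()))
    estimate = _squared_loss_derivative_lipschitz_constant(
        X, mask, grad_weight=0.0, n_iterations=200
    )
    np.testing.assert_allclose(estimate, spectral_norm_squared(X), rtol=1e-3)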


def _logistic_derivative_lipschitz_constant(
    X, mask, grad_weight, n_iterations=100
):
    """Compute the Lipschitz constant of the gradient of the smooth part \
    of the Graph-Net classification problem (logistic_loss + \
    grad_weight * grad) via an analytical formula for the logistic loss \
    and the power method for the spatial penalty part.
    """
    # Lipschitz constant for the data term (logistic)
    # data_constant = sp.linalg.norm(X, 2) ** 2
    data_constant = logistic_loss_lipschitz_constant(X)

    rng = np.random.RandomState(42)
    a = rng.randn(X.shape[1])
    a /= sqrt(np.dot(a, a))
    grad_buffer = np.zeros(mask.shape)
    for _ in range(n_iterations):
        grad_buffer[mask] = a
        a = -divergence(gradient(grad_buffer))[mask] / sqrt(np.dot(a, a))

    grad_buffer[mask] = a
    grad_constant = -np.dot(
        divergence(gradient(grad_buffer))[mask], a
    ) / np.dot(a, a)

    return data_constant + grad_weight * grad_constant


def _logistic_data_loss_and_spatial_grad(X, y, w, mask, grad_weight):
    """Compute the smooth part of the Graph-Net objective, \
    with logistic loss.
    """
    grad_buffer = np.zeros(mask.shape)
    grad_buffer[mask] = w[:-1]
    grad_mask = np.array([mask for _ in range(mask.ndim)])
    grad_section = gradient(grad_buffer)[grad_mask]
    return logistic_loss(X, y, w) + 0.5 * grad_weight * np.dot(
        grad_section, grad_section
    )


def _logistic_data_loss_and_spatial_grad_derivative(
    X, y, w, mask, grad_weight
):
    """Compute the derivative of _logistic_data_loss_and_spatial_grad."""
    image_buffer = np.zeros(mask.shape)
    image_buffer[mask] = w[:-1]
    data_section = logistic_loss_grad(X, y, w)
    data_section[:-1] = (
        data_section[:-1]
        - grad_weight * divergence(gradient(image_buffer))[mask]
    )
    return data_section


def graph_net_squared_loss(
    X,
    y,
    alpha,
    l1_ratio,
    mask,
    init=None,
    max_iter=1000,
    tol=1e-4,
    callback=None,
    lipschitz_constant=None,
    verbose=0,
):
    """Compute a solution for the Graph-Net regression problem.

    This function invokes the mfista backend (from fista.py) to solve the
    underlying optimization problem.

    Returns
    -------
    w : ndarray, shape (n_features,)
        Solution vector.

    objective : array of floats
        Objective function (fval) computed on every iteration.

    solver_info : dict
        Solver information, for warm start.
    """
    _, n_features = X.shape

    # misc
    model_size = n_features
    l1_weight = alpha * l1_ratio
    grad_weight = alpha * (1.0 - l1_ratio)

    if lipschitz_constant is None:
        lipschitz_constant = _squared_loss_derivative_lipschitz_constant(
            X, mask, grad_weight
        )

        # it's always a good idea to use something a bit bigger
        lipschitz_constant *= 1.05

    # smooth part of energy, and gradient thereof
    def f1(w):
        return _squared_loss_and_spatial_grad(X, y, w, mask, grad_weight)

    def f1_grad(w):
        return _squared_loss_and_spatial_grad_derivative(
            X, y, w, mask, grad_weight
        )

    # prox of nonsmooth part of energy (account for the intercept)
    def f2(w):
        return np.sum(np.abs(w)) * l1_weight

    def f2_prox(w, step_size, *args, **kwargs):  # noqa: ARG001
        return prox_l1(w, step_size * l1_weight), {"converged": True}

    # total energy (smooth + nonsmooth)
    def total_energy(w):
        return f1(w) + f2(w)

    return mfista(
        f1_grad,
        f2_prox,
        total_energy,
        lipschitz_constant,
        model_size,
        dgap_factor=(0.1 + l1_ratio) ** 2,
        callback=callback,
        tol=tol,
        max_iter=max_iter,
        verbose=verbose,
        init=init,
    )
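

# Illustrative usage sketch, not part of nilearn's public API (these solvers
# are normally driven by the SpaceNet estimators).  The data sizes and
# penalty values are arbitrary assumptions, and the three-way unpacking
# assumes mfista returns (w, objective history, solver info), as in
# tvl1_solver below.  Note that mask.sum() must equal X.shape[1].
def _demo_graph_net_squared_loss(seed=0):
    """Run the Graph-Net regression solver on tiny synthetic data."""
    rng = np.random.RandomState(seed)
    mask = np.ones((4, 4, 4), dtype=bool)
    n_samples, n_features = 30, int(mask.sum())
    X = rng.randn(n_samples, n_features)
    w_true = np.zeros(n_features)
    w_true[:5] = 1.0
    y = X.dot(w_true) + 0.1 * rng.randn(n_samples)
    w, objective, solver_info = graph_net_squared_loss(
        X, y, alpha=0.1, l1_ratio=0.5, mask=mask, max_iter=200, tol=1e-4
    )
    return w, objective, solver_info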


def graph_net_logistic(
    X,
    y,
    alpha,
    l1_ratio,
    mask,
    init=None,
    max_iter=1000,
    tol=1e-4,
    callback=None,
    verbose=0,
    lipschitz_constant=None,
):
    """Compute a solution for the Graph-Net classification problem, \
    with response vector in {-1, 1}^n_samples.

    This function invokes the mfista backend (from fista.py) to solve the
    underlying optimization problem.

    Returns
    -------
    w : ndarray, shape (n_features + 1,)
        The solution vector (where `n_features` is the size of the support
        of the mask); the last entry is the intercept.

    objective : array of floats
        Cost function (fval) computed on every iteration.

    solver_info : dict
        Solver information for warm starting. See fista.py.mfista(...)
        function for detailed documentation.
    """
    _, n_features = X.shape

    # misc
    model_size = n_features + 1
    l1_weight = alpha * l1_ratio
    grad_weight = alpha * (1 - l1_ratio)

    if lipschitz_constant is None:
        lipschitz_constant = _logistic_derivative_lipschitz_constant(
            X, mask, grad_weight
        )

        # it's always a good idea to use something a bit bigger
        lipschitz_constant *= 1.1

    # smooth part of energy, and gradient thereof
    def f1(w):
        return _logistic_data_loss_and_spatial_grad(X, y, w, mask, grad_weight)

    def f1_grad(w):
        return _logistic_data_loss_and_spatial_grad_derivative(
            X, y, w, mask, grad_weight
        )

    # prox of nonsmooth part of energy (account for the intercept)
    def f2(w):
        return np.sum(np.abs(w[:-1])) * l1_weight

    def f2_prox(w, step_size, *args, **kwargs):  # noqa: ARG001
        return prox_l1_with_intercept(w, step_size * l1_weight), {
            "converged": True
        }

    # total energy (smooth + nonsmooth)
    def total_energy(w):
        return f1(w) + f2(w)

    # finally, run the solver proper
    return mfista(
        f1_grad,
        f2_prox,
        total_energy,
        lipschitz_constant,
        model_size,
        dgap_factor=(0.1 + l1_ratio) ** 2,
        callback=callback,
        tol=tol,
        max_iter=max_iter,
        verbose=verbose,
        init=init,
    )
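

# Illustrative usage sketch, not part of nilearn's public API.  Labels must
# live in {-1, 1}, and the returned weight vector carries one extra entry for
# the intercept.  Sizes, penalties and the unpacking convention are the same
# assumptions as in _demo_graph_net_squared_loss above.
def _demo_graph_net_logistic(seed=0):
    """Run the Graph-Net classification solver on tiny synthetic data."""
    rng = np.random.RandomState(seed)
    mask = np.ones((4, 4, 4), dtype=bool)
    n_samples, n_features = 40, int(mask.sum())
    X = rng.randn(n_samples, n_features)
    y = np.sign(X[:, 0] + 0.1 * rng.randn(n_samples))  # labels in {-1, 1}
    w, objective, solver_info = graph_net_logistic(
        X, y, alpha=0.05, l1_ratio=0.5, mask=mask, max_iter=200
    )
    coefficients, intercept = w[:-1], w[-1]
    return coefficients, intercept, objective, solver_info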


def _tvl1_objective_from_gradient(gradient):
    """Compute the TV-L1 penalty from a precomputed gradient.

    Parameters
    ----------
    gradient : ndarray, shape (4, nx, ny, nz)
        Precomputed "gradient + id" array.

    Returns
    -------
    float
        Value of TV-L1 penalty.
    """
    tv_term = np.sum(np.sqrt(np.sum(gradient[:-1] * gradient[:-1], axis=0)))
    l1_term = np.abs(gradient[-1]).sum()
    return l1_term + tv_term
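

# Illustrative sketch, not part of nilearn's API: the penalty above splits
# into an isotropic TV term (l2 norm of the first components of the
# "gradient + id" array, summed over voxels) and an l1 term on the last,
# identity component.  The toy image below is an arbitrary assumption.
def _demo_tvl1_penalty(seed=0):
    """Evaluate the TV-L1 penalty on a tiny random image."""
    rng = np.random.RandomState(seed)
    img = rng.randn(4, 4, 4)
    # "gradient + id" array, shape (4, nx, ny, nz) for a 3-D image
    grad_id = gradient_id(img, l1_ratio=0.5)
    penalty = _tvl1_objective_from_gradient(grad_id)
    # Same value, recomputed from the two terms separately.
    tv_term = np.sqrt((grad_id[:-1] ** 2).sum(axis=0)).sum()
    l1_term = np.abs(grad_id[-1]).sum()
    np.testing.assert_allclose(penalty, tv_term + l1_term)
    return penalty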


def _tvl1_objective(X, y, w, alpha, l1_ratio, mask, loss="mse"):
    """Compute the TV-L1 regression or classification objective function.

    Returns
    -------
    float
        Value of the TV-L1 objective (data loss + penalty).
    """
    loss = loss.lower()
    if loss not in ["mse", "logistic"]:
        raise ValueError(
            f"loss must be one of 'mse' or 'logistic'; got '{loss}'"
        )

    if loss == "mse":
        out = squared_loss(X, y, w)
    else:
        out = logistic_loss(X, y, w)
        w = w[:-1]

    grad_id = gradient_id(unmask_from_to_3d_array(w, mask), l1_ratio=l1_ratio)
    out += alpha * _tvl1_objective_from_gradient(grad_id)

    return out


def tvl1_solver(
    X,
    y,
    alpha,
    l1_ratio,
    mask,
    loss=None,
    max_iter=100,
    lipschitz_constant=None,
    init=None,
    prox_max_iter=5000,
    tol=1e-4,
    callback=None,
    verbose=1,
):
    """Minimize empirical risk for TV-L1 penalized models.

    Can handle least squares (mean squared error --a.k.a mse) or logistic
    regression. The same solver works for both of these losses.

    This function invokes the mfista backend (from fista.py) to solve the
    underlying optimization problem.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Design matrix.

    y : ndarray, shape (n_samples,)
        Target / response vector.

    alpha : :obj:`float`
        Constant that scales the overall regularization term.

    l1_ratio : :obj:`float` in the interval [0, 1]
        Constant that mixes L1 and TV penalization.
        l1_ratio == 0 : just smooth. l1_ratio == 1 : just lasso.

    mask : ndarray, shape (nx, ny, nz)
        The support of this mask defines the ROIs being considered in
        the problem.

    max_iter : :obj:`int`, default=100
        Maximum number of iterations for the solver.

    prox_max_iter : :obj:`int`, default=5000
        Maximum number of iterations for the inner FISTA loop in which
        the prox of TV is approximated.

    tol : :obj:`float`, default=1e-4
        Defines the tolerance for convergence.

    loss : :obj:`str` or None
        Loss model for regression. Can be "mse" (for squared loss) or
        "logistic" (for logistic loss).

    lipschitz_constant : :obj:`float`, default=None
        Lipschitz constant (i.e. an upper bound) of the gradient of the
        smooth part of the energy being minimized. If no value is
        specified (None), then it will be calculated.

    callback : callable(dict) -> bool, default=None
        Function called at the end of every energy-decreasing iteration of
        the solver. If it returns True, the loop breaks.

    Returns
    -------
    w : ndarray, shape (w_size,)
        The solution vector, where `w_size` is the size of the support of
        the mask (plus one for the intercept when loss == "logistic").

    objective : array of floats
        Objective function (fval) computed on every iteration.

    solver_info : dict
        Solver information, for warm start.
    """
    # sanitize loss
    if loss not in ["mse", "logistic"]:
        raise ValueError(
            f"'{loss}' loss not implemented. Should be 'mse' or 'logistic'"
        )

    # shape of image box
    flat_mask = mask.ravel()
    volume_shape = mask.shape

    # in logistic regression, we fit the intercept explicitly
    w_size = X.shape[1] + int(loss == "logistic")

    def unmaskvec(w):
        if loss == "mse":
            return unmask_from_to_3d_array(w, mask)
        else:
            return np.append(unmask_from_to_3d_array(w[:-1], mask), w[-1])

    def maskvec(w):
        return (
            w[flat_mask]
            if loss == "mse"
            else np.append(w[:-1][flat_mask], w[-1])
        )

    # function to compute derivative of f1
    def f1_grad(w):
        if loss == "logistic":
            return logistic_loss_grad(X, y, w)
        else:
            return squared_loss_grad(X, y, w)

    # function to compute total energy (i.e. smooth (f1) + nonsmooth (f2))
    def total_energy(w):
        return _tvl1_objective(X, y, w, alpha, l1_ratio, mask, loss=loss)

    # Lipschitz constant of f1_grad
    if lipschitz_constant is None:
        if loss == "mse":
            lipschitz_constant = 1.05 * spectral_norm_squared(X)
        else:
            lipschitz_constant = 1.1 * logistic_loss_lipschitz_constant(X)

    # proximal operator of nonsmooth proximable part of energy (f2)
    if loss == "mse":

        def f2_prox(w, stepsize, dgap_tol, init=None):
            out, info = prox_tvl1(
                unmaskvec(w),
                weight=alpha * stepsize,
                l1_ratio=l1_ratio,
                dgap_tol=dgap_tol,
                init=unmaskvec(init),
                max_iter=prox_max_iter,
                verbose=verbose,
            )
            return maskvec(out.ravel()), info

    else:

        def f2_prox(w, stepsize, dgap_tol, init=None):
            out, info = prox_tvl1_with_intercept(
                unmaskvec(w),
                volume_shape,
                l1_ratio,
                alpha * stepsize,
                dgap_tol,
                prox_max_iter,
                init=(
                    unmask_from_to_3d_array(init[:-1], mask)
                    if init is not None
                    else None
                ),
                verbose=verbose,
            )
            return maskvec(out.ravel()), info

    # invoke m-FISTA solver
    w, obj, init = mfista(
        f1_grad,
        f2_prox,
        total_energy,
        lipschitz_constant,
        w_size,
        dgap_factor=(0.1 + l1_ratio) ** 2,
        tol=tol,
        init=init,
        verbose=verbose,
        max_iter=max_iter,
        callback=callback,
    )

    return w, obj, init
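

# Illustrative usage sketch, not part of nilearn's public API (TV-L1 models
# are normally fitted through the SpaceNet estimators).  The data sizes and
# regularization values are arbitrary assumptions; mask.sum() must equal
# X.shape[1].
def _demo_tvl1_solver(seed=0):
    """Run the TV-L1 solver on tiny synthetic regression data."""
    rng = np.random.RandomState(seed)
    mask = np.ones((4, 4, 4), dtype=bool)
    n_samples, n_features = 30, int(mask.sum())
    X = rng.randn(n_samples, n_features)
    y = X.dot(rng.randn(n_features)) + 0.1 * rng.randn(n_samples)
    w, objective, solver_info = tvl1_solver(
        X,
        y,
        alpha=0.1,
        l1_ratio=0.5,
        mask=mask,
        loss="mse",
        max_iter=100,
        tol=1e-4,
        verbose=0,
    )
    return w, objective, solver_info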