# nilearn/decoding/fista.py
1"""Generic FISTA for solving TV-L1, Graph-Net, etc., problems.
3For problems on which the prox of the nonsmooth term \
4cannot be computed closed-form (e.g TV-L1), \
5we approximate the prox using an inner FISTA loop.
6"""
from math import sqrt

import numpy as np
from scipy import linalg

from nilearn._utils import logger


def _check_lipschitz_continuous(
    f, ndim, lipschitz_constant, n_trials=10, random_state=42
):
    """Empirically check Lipschitz continuity of a function.

    If this test is passed, then we are empirically confident in the
    Lipschitz continuity of the function with respect to the given
    constant `L`. This confidence increases with the `n_trials` parameter.

    Parameters
    ----------
    f : callable,
        The function to be checked for Lipschitz continuity.
        `f` takes a vector of float as unique argument.
        The size of the input vector is determined by `ndim`.

    ndim : int,
        Dimension of the input of the function to be checked for Lipschitz
        continuity (i.e. it corresponds to the size of the vector that `f`
        takes as an argument).

    lipschitz_constant : float,
        Constant associated to the Lipschitz continuity.

    n_trials : int,
        Number of tests performed when assessing the Lipschitz continuity of
        function `f`. The more tests, the more confident we are in the
        Lipschitz continuity of `f` if the test passes.

    %(random_state)s
        default 42

    Raises
    ------
    RuntimeError
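
    Examples
    --------
    A minimal, illustrative check; the linear map and the slack on the
    constant below are assumptions for demonstration only:

    >>> import numpy as np
    >>> A = 2.0 * np.eye(4)
    >>> # ||A @ x - A @ y|| = 2 * ||x - y||, so 2.5 is a valid constant
    >>> _check_lipschitz_continuous(lambda v: A.dot(v), 4, 2.5)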
    """
    rng = np.random.default_rng(random_state)
    for x in rng.standard_normal((n_trials, ndim)):
        for y in rng.standard_normal((n_trials, ndim)):
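            # Lipschitz continuity requires ||f(x) - f(y)|| <= L * ||x - y||
            # for every pair (x, y); check it on randomly sampled pairs.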
            a = linalg.norm(f(x).ravel() - f(y).ravel(), 2)
            b = lipschitz_constant * linalg.norm(x - y, 2)
            if a > b:
                raise RuntimeError(f"Counter example: ({x}, {y})")


def mfista(
    f1_grad,
    f2_prox,
    total_energy,
    lipschitz_constant,
    w_size,
    dgap_tol=None,
    init=None,
    max_iter=1000,
    tol=1e-4,
    check_lipschitz=False,
    dgap_factor=None,
    callback=None,
    verbose=2,
):
    """Solve FISTA in a generic way.

    Minimizes a sum `f + g` of two convex functions f (smooth)
    and g (proximable nonsmooth).

    Parameters
    ----------
    f1_grad : callable(w) -> np.array
        Gradient of the smooth part of the energy.

    f2_prox : callable(w, stepsize, dgap_tol, init?) -> float, dict
        Proximal operator of non-smooth part of energy (f2).
        The returned dict should have a key "converged", whose value
        indicates whether the prox computation converged.

    total_energy : callable(w) -> float
        Total energy (i.e. smooth (f1) + nonsmooth (f2) parts).

    lipschitz_constant : :obj:`float`
        Lipschitz constant of `f1_grad`, i.e. of the gradient of the
        smooth part.

    check_lipschitz : :obj:`bool`, default=False
        If True, check Lipschitz continuity of gradient of smooth part.

    w_size : :obj:`int`
        Size of the solution. f1, f2, f1_grad, f2_prox (fixed l, tol) must
        accept a w such that w.shape = (w_size,).

    tol : :obj:`float`, default=1e-4
        Tolerance on the (primal) cost function.

    dgap_tol : :obj:`float`, default=None
        Tolerance on the dual gap passed to `f2_prox` (its third argument)
        when computing the proximal operator of the nonsmooth part. If None,
        it is taken from `init` when available, and set to `np.inf`
        otherwise.

    dgap_factor : :obj:`float`, default=None
        Dual gap factor. Used for debugging purposes (to control
        convergence).

    init : dict-like, default=None
        Dictionary of initialization parameters. Possible keys are 'w',
        'stepsize', 'z', 't', 'dgap_factor', etc.

    callback : callable(dict) -> bool
        Function called on every iteration. If it returns True, then the loop
        breaks.

    max_iter : :obj:`int`, default=1000
        Maximum number of iterations for the solver.

    verbose : :obj:`int`, default=2
        Indicate the level of verbosity.

    Returns
    -------
    w : ndarray, shape (w_size,)
        A minimizer for `f + g`.

    cost : list of float
        Cost function (fval) computed on every iteration.

    solver_info : dict
        Solver information (best iterate, stepsize, etc.), for warm starting.

    Notes
    -----
    A motivation for the choice of FISTA as a solver for TV-L1
    penalized problems is given in the paper: Elvis Dohmatob,
    Alexandre Gramfort, Bertrand Thirion, Gael Varoquaux,
    "Benchmarking solvers for TV-L1 least-squares and logistic regression
    in brain imaging". Pattern Recognition in Neuroimaging (PRNI),
    Jun 2014, Tubingen, Germany. IEEE.
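
    Examples
    --------
    A minimal sketch on a small lasso-type problem; the quadratic data-fit
    term and the soft-thresholding prox below are illustrative assumptions,
    not functions provided by this module:

    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> A = rng.standard_normal((20, 5))
    >>> b = rng.standard_normal(20)
    >>> alpha = 0.1
    >>> def f1_grad(w):
    ...     return A.T.dot(A.dot(w) - b)
    >>> def f2_prox(w, stepsize, dgap_tol, init=None):
    ...     # exact prox of the l1 penalty: soft-thresholding
    ...     out = np.sign(w) * np.maximum(np.abs(w) - alpha * stepsize, 0.0)
    ...     return out, {"converged": True}
    >>> def total_energy(w):
    ...     return 0.5 * np.sum((A.dot(w) - b) ** 2) + alpha * np.abs(w).sum()
    >>> lipschitz_constant = np.linalg.norm(A, 2) ** 2
    >>> w, cost, info = mfista(
    ...     f1_grad, f2_prox, total_energy, lipschitz_constant, 5, verbose=0
    ... )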
    """
    # initialization
    if init is None:
        init = {}
    w = init.get("w", np.zeros(w_size))
    z = init.get("z", w.copy())
    t = init.get("t", 1.0)
    stepsize = init.get("stepsize", 1.0 / lipschitz_constant)
    if dgap_tol is None:
        dgap_tol = init.get("dgap_tol", np.inf)
    if dgap_factor is None:
        dgap_factor = init.get("dgap_factor", 1.0)

    # check Lipschitz continuity of gradient of smooth part
    if check_lipschitz:
        _check_lipschitz_continuous(f1_grad, w_size, lipschitz_constant)

    # aux variables
    old_energy = total_energy(w)
    energy_delta = np.inf
    best_w = w.copy()
    best_energy = old_energy
    best_dgap_tol = dgap_tol
    ista_step = False
    best_z = z.copy()
    best_t = t
    prox_info = {"converged": True}
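    # NOTE: any "stepsize" passed through `init` above is overridden here by
    # 1 / lipschitz_constant.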
    stepsize = 1.0 / lipschitz_constant
    history = []
    w_old = w.copy()

    # FISTA loop
    for i in range(max_iter):
        history.append(old_energy)
        w_old[:] = w

        # invoke callback
        logger.log(
            f"mFISTA: Iteration {i + 1: 2}/{max_iter:2}: "
            f"E = {old_energy:7.4e}, dE {energy_delta: 4.4e}",
            verbose,
        )
        if callback and callback(locals()):
            break
        if np.abs(energy_delta) < tol:
            logger.log(f"\tConverged (|dE| < {tol:g})", verbose)
            break

        # forward (gradient) step
        gradient_buffer = f1_grad(z)

        # backward (prox) step
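        # During monotone (ISTA) steps, if the prox converged but the energy
        # did not decrease, tighten the dual-gap tolerance and retry
        # (up to 10 times).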
        for _ in range(10):
            w, prox_info = f2_prox(
                z - stepsize * gradient_buffer,
                stepsize,
                dgap_factor * dgap_tol,
                init=w,
            )
            energy = total_energy(w)
            if (
                not ista_step
                or not prox_info["converged"]
                or old_energy > energy
            ):
                break

            # Even when doing ISTA steps we are not decreasing.
            # Thus we need a tighter dual_gap on the prox_tv.
            # This corresponds to a line search on the dual_gap
            # tolerance.
            dgap_factor *= 0.2

            logger.log("decreased dgap_tol", verbose)
        # energy house-keeping
        energy_delta = old_energy - energy
        old_energy = energy

        # z update
        if energy_delta < 0.0:
            # M-FISTA strategy: rewind and switch temporarily to an ISTA step
            z[:] = w_old
            w[:] = w_old
            ista_step = True
            logger.log("Monotonous FISTA: Switching to ISTA", verbose)
        else:
            if ista_step:
                z = w
            else:
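                # FISTA momentum: update the Nesterov sequence
                # t_{k+1} = (1 + sqrt(1 + 4 * t_k**2)) / 2 and extrapolate
                # z = w + ((t_k - 1) / t_{k+1}) * (w - w_old).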
                t0 = t
                t = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t * t))
                z = w + ((t0 - 1.0) / t) * (w - w_old)
            ista_step = False

        # misc
        if energy_delta != 0.0:
            # We need to decrease the tolerance on the dual_gap as 1/i**4
            # (see Mark Schmidt, Nicolas le Roux and Francis Bach, NIPS
            # 2011), thus we need to count how many times we are called,
            # hence the callable class. In practice, empirically I (Gael)
            # have found that such a sharp decrease was counter
            # productive in terms of computation time, as it leads to too
            # much time spent in the prox_tvl1 calls.
            #
            # For this reason, we rely more on the linesearch-like
            # strategy to set the dgap_tol
            dgap_tol = abs(energy_delta) / (i + 1.0)

        # dgap_tol house-keeping
        if energy < best_energy:
            best_energy = energy
            best_w[:] = w
            best_z[:] = z
            best_t = t
            best_dgap_tol = dgap_tol

    init = {
        "w": best_w.copy(),
        "z": best_z,
        "t": best_t,
        "dgap_tol": best_dgap_tol,
        "stepsize": stepsize,
    }
    return best_w, history, init