1"""Generic FISTA for solving TV-L1, Graph-Net, etc., problems. 

2 

3For problems on which the prox of the nonsmooth term \ 

4cannot be computed closed-form (e.g TV-L1), \ 

5we approximate the prox using an inner FISTA loop. 

6""" 


from math import sqrt

import numpy as np
from scipy import linalg

from nilearn._utils import logger


def _check_lipschitz_continuous(
    f, ndim, lipschitz_constant, n_trials=10, random_state=42
):
19 """Empirically check Lipschitz continuity of a function. 

20 

21 If this test is passed, then we are empirically confident in the 

22 Lipschitz continuity of the function with respect to the given 

23 constant `L`. This confidence increases with the `n_trials` parameter. 

24 

25 Parameters 

26 ---------- 

27 f : callable, 

28 The function to be checked for Lipschitz continuity. 

29 `f` takes a vector of float as unique argument. 

30 The size of the input vector is determined by `ndim`. 

31 

32 ndim : int, 

33 Dimension of the input of the function to be checked for Lipschitz 

34 continuity (i.e. it corresponds to the size of the vector that `f` 

35 takes as an argument). 

36 

37 lispchitz_constant : float, 

38 Constant associated to the Lipschitz continuity. 

39 

40 n_trials : int, 

41 Number of tests performed when assessing the Lipschitz continuity of 

42 function `f`. The more tests, the more confident we are in the 

43 Lipschitz continuity of `f` if the test passes. 

44 

45 %(random_state)s 

46 default 42 

47 

48 Raises 

49 ------ 

50 RuntimeError 
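
    Examples
    --------
    A minimal sketch of the intended use, assuming a linear map
    ``f(w) = A.T @ A @ w``, which is Lipschitz continuous with any
    constant at least its spectral norm (2.0 is a safe bound for
    ``A = np.eye(4)``):

    >>> import numpy as np
    >>> A = np.eye(4)
    >>> _check_lipschitz_continuous(
    ...     lambda w: A.T @ A @ w, ndim=4, lipschitz_constant=2.0
    ... )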
    """
    rng = np.random.default_rng(random_state)
    for x in rng.standard_normal((n_trials, ndim)):
        for y in rng.standard_normal((n_trials, ndim)):
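            # Lipschitz continuity with constant L means
            # ||f(x) - f(y)|| <= L * ||x - y|| for every pair (x, y);
            # a single observed violation disproves it for this constant.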
            a = linalg.norm(f(x).ravel() - f(y).ravel(), 2)
            b = lipschitz_constant * linalg.norm(x - y, 2)
            if a > b:
                raise RuntimeError(f"Counter example: ({x}, {y})")


def mfista(
    f1_grad,
    f2_prox,
    total_energy,
    lipschitz_constant,
    w_size,
    dgap_tol=None,
    init=None,
    max_iter=1000,
    tol=1e-4,
    check_lipschitz=False,
    dgap_factor=None,
    callback=None,
    verbose=2,
):
76 """Solve FISTA in a generic way. 

77 

78 Minimizes the a sum `f + g` of two convex functions f (smooth) 

79 and g (proximable nonsmooth). 

80 

81 Parameters 

82 ---------- 

83 f1_grad : callable(w) -> np.array 

84 Gradient of smooth part of energy 

85 

86 f2_prox : callable(w, stepsize, dgap_tol, init?) -> float, dict 

87 Proximal operator of non-smooth part of energy (f2). 

88 The returned dict should have a key "converged", whose value 

89 indicates whether the prox computation converged. 

90 

91 total_energy : callable(w) -> float 

92 total energy (i.e smooth (f1) + nonsmooth (f2) parts) 

93 

94 lipschitz_constant : :obj:`float` 

95 Lipschitz constant of gradient of f1_grad. 

96 

97 check_lipschitz : :obj:`bool`, default=False 

98 If True, check Lipschitz continuity of gradient of smooth part. 

99 

100 w_size : :obj:`int` 

101 Size of the solution. f1, f2, f1_grad, f2_prox (fixed l, tol) must 

102 accept a w such that w.shape = (w_size,). 

103 

104 tol : :obj:`float`, default=1e-4 

105 Tolerance on the (primal) cost function. 

106 

107 dgap_tol : :obj:`float`, default=None 

108 If None, the nonsmooth_prox argument returns a float, with the value, 

109 if not 0, the nonsmooth_prox accepts a third parameter tol, which is 

110 the tolerance on the computation of the proximal operator and returns a 

111 float, and a dict with the key "converged", that says if the method to 

112 compute f2_prox converged or not. 

113 

114 dgap_factor : :obj:`float`, default=None 

115 Dual gap factor. Used for debugging purpose (control the convergence). 

116 

117 init : dict-like, default=None 

118 Dictionary of initialization parameters. Possible keys are 'w', 

119 'stepsize', 'z', 't', 'dgap_factor', etc. 

120 

121 callback : callable(dict) -> bool 

122 Function called on every iteration. If it returns True, then the loop 

123 breaks. 

124 

125 max_iter : :obj:`int`, default=1000 

126 Maximum number of iterations for the solver. 

127 

128 verbose : :obj:`int`, default=2 

129 Indicate the level of verbosity. 

130 

131 Returns 

132 ------- 

133 w : ndarray, shape (w_size,) 

134 A minimizer for `f + g`. 

135 

136 solver_info : float 

137 Solver information, for warm starting. 

138 

139 cost : array of floats 

140 Cost function (fval) computed on every iteration. 

141 

142 Notes 

143 ----- 

144 A motivation for the choice of FISTA as a solver for the TV-L1 

145 penalized problems emerged in the paper: Elvis Dohmatob, 

146 Alexandre Gramfort, Bertrand Thirion, Gael Varoquaux, 

147 "Benchmarking solvers for TV-L1 least-squares and logistic regression 

148 in brain imaging". Pattern Recognition in Neuroimaging (PRNI), 

149 Jun 2014, Tubingen, Germany. IEEE 

150 
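
    Examples
    --------
    A minimal sketch, assuming a toy problem with smooth part
    f1(w) = 0.5 * ||w - b||^2 (gradient w - b, Lipschitz constant 1)
    and nonsmooth part f2(w) = ||w||_1, whose prox is closed-form
    soft-thresholding (the names `b` and `soft_prox` below are
    illustrative, not part of this module):

    >>> import numpy as np
    >>> b = np.array([3.0, -0.5, 0.0])
    >>> def soft_prox(w, stepsize, dgap_tol, init=None):
    ...     out = np.sign(w) * np.maximum(np.abs(w) - stepsize, 0.0)
    ...     return out, {"converged": True}
    >>> w, history, info = mfista(
    ...     lambda w: w - b,
    ...     soft_prox,
    ...     lambda w: 0.5 * np.sum((w - b) ** 2) + np.abs(w).sum(),
    ...     lipschitz_constant=1.0,
    ...     w_size=3,
    ...     verbose=0,
    ... )
    >>> np.allclose(w, [2.0, 0.0, 0.0])
    True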
    """
    # initialization
    if init is None:
        init = {}
    w = init.get("w", np.zeros(w_size))
    z = init.get("z", w.copy())
    t = init.get("t", 1.0)
    stepsize = init.get("stepsize", 1.0 / lipschitz_constant)
    if dgap_tol is None:
        dgap_tol = init.get("dgap_tol", np.inf)
    if dgap_factor is None:
        dgap_factor = init.get("dgap_factor", 1.0)

    # check Lipschitz continuity of gradient of smooth part
    if check_lipschitz:
        _check_lipschitz_continuous(f1_grad, w_size, lipschitz_constant)

    # aux variables (the redundant reassignment of `stepsize` that clobbered
    # a warm-start value from `init` has been removed)
    old_energy = total_energy(w)
    energy_delta = np.inf
    best_w = w.copy()
    best_energy = old_energy
    best_dgap_tol = dgap_tol
    ista_step = False
    best_z = z.copy()
    best_t = t
    prox_info = {"converged": True}
    history = []
    w_old = w.copy()

    # FISTA loop
    for i in range(max_iter):
        history.append(old_energy)
        w_old[:] = w

        # log progress and invoke callback
        logger.log(
            f"mFISTA: Iteration {i + 1: 2}/{max_iter:2}: "
            f"E = {old_energy:7.4e}, dE {energy_delta: 4.4e}",
            verbose,
        )
        if callback and callback(locals()):
            break
        if np.abs(energy_delta) < tol:
            logger.log(f"\tConverged (|dE| < {tol:g})", verbose)
            break

        # forward (gradient) step
        gradient_buffer = f1_grad(z)

        # backward (prox) step
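        # The prox may be inexact (e.g. an inner FISTA loop for TV-L1).
        # During ISTA steps, if the prox converged but the energy still
        # did not decrease, retry (at most 10 times) with a tighter
        # dual-gap tolerance; see the line-search comment below.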
        for _ in range(10):
            w, prox_info = f2_prox(
                z - stepsize * gradient_buffer,
                stepsize,
                dgap_factor * dgap_tol,
                init=w,
            )
            energy = total_energy(w)
            if (
                not ista_step
                or not prox_info["converged"]
                or old_energy > energy
            ):
                break

            # Even when doing ISTA steps we are not decreasing.
            # Thus we need a tighter dual_gap on the prox_tv.
            # This corresponds to a line search on the dual_gap
            # tolerance.
            dgap_factor *= 0.2

            logger.log("decreased dgap_tol", verbose)
        # energy house-keeping
        energy_delta = old_energy - energy
        old_energy = energy

        # z update
        if energy_delta < 0.0:
            # M-FISTA strategy: rewind and switch temporarily to an ISTA step
            z[:] = w_old
            w[:] = w_old
            ista_step = True
            logger.log("Monotone FISTA: Switching to ISTA", verbose)
        else:
            if ista_step:
                z = w
            else:
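                # standard FISTA momentum update:
                #   t_{k+1} = (1 + sqrt(1 + 4 * t_k**2)) / 2
                #   z = w + ((t_k - 1) / t_{k+1}) * (w - w_old)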
                t0 = t
                t = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t * t))
                z = w + ((t0 - 1.0) / t) * (w - w_old)
            ista_step = False

        # misc
        if energy_delta != 0.0:
            # We need to decrease the tolerance on the dual_gap as 1/i**4
            # (see Mark Schmidt, Nicolas Le Roux and Francis Bach, NIPS
            # 2011), thus we need to count how many times we are called,
            # hence the callable class. In practice, empirically I (Gael)
            # have found that such a sharp decrease was counterproductive
            # in terms of computation time, as it leads to too much time
            # spent in the prox_tvl1 calls.
            #
            # For this reason, we rely more on the linesearch-like
            # strategy to set the dgap_tol.
            dgap_tol = abs(energy_delta) / (i + 1.0)

        # best-solution house-keeping
        if energy < best_energy:
            best_energy = energy
            best_w[:] = w
            best_z[:] = z
            best_t = t
            best_dgap_tol = dgap_tol

    init = {
        "w": best_w.copy(),
        "z": best_z,
        "t": best_t,
        "dgap_tol": best_dgap_tol,
        "stepsize": stepsize,
    }
    return best_w, history, init