Coverage for nilearn/decoding/_proximal_operators.py: 7%

114 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 13:00 +0200

1"""Implementation of multiple proximal operators for TV-L1, Graph-Net, etc.""" 

2 

3from math import sqrt 

4 

5import numpy as np 

6 

7from nilearn._utils import logger 

8 

9from ._objective_functions import ( 

10 divergence_id, 

11 gradient_id, 

12 tv_l1_from_gradient, 

13) 

14 

15 

def prox_l1(y, alpha, copy=True):
    """Apply the proximal operator of the L1 norm (soft-thresholding).

    Each non-zero entry of ``y`` is multiplied by
    ``max(1 - alpha / |y|, 0)``; zero entries stay at zero.

    Parameters
    ----------
    y : ndarray
        Point at which the operator is evaluated.

    alpha : float
        Threshold level.

    copy : bool, default=True
        If True, the computation is done on a copy of ``y``. If False,
        ``y`` itself is rescaled in place.

    Returns
    -------
    ndarray
        The shrunk array (``y`` itself when ``copy`` is False).
    """
    target = y.copy() if copy else y
    support = target.nonzero()
    # Entries outside the support keep a factor of 0, which also means
    # we never divide by a zero magnitude.
    factors = np.zeros(target.shape)
    factors[support] = np.maximum(1 - alpha / np.abs(target[support]), 0)
    target *= factors
    return target

25 

26 

def prox_l1_with_intercept(x, tau):
    """Apply ``prox_l1`` to every component of ``x`` except the last one.

    The last entry of ``x`` is left untouched. ``x`` is updated in place
    and also returned.

    Parameters
    ----------
    x : ndarray
        Vector whose leading ``n - 1`` entries are shrunk.

    tau : float
        Threshold passed to ``prox_l1``.

    Returns
    -------
    ndarray
        ``x`` itself, after the in-place update.
    """
    leading = x[:-1]
    # prox_l1 copies by default, so the result is written back explicitly.
    x[:-1] = prox_l1(leading, tau)
    return x

31 

32 

def _projector_on_tvl1_dual(grad, l1_ratio):
    """Project a stacked "gradient + identity" array for the TV-L1 dual.

    Modifies ``grad`` IN PLACE: the gradient components are projected on
    the l21 unit ball and the last (identity) component on the L1 ball.

    Parameters
    ----------
    grad : ndarray
        Array whose first axis indexes the components; the last component
        is the identity part.

    l1_ratio : float
        The gradient part is processed when ``l1_ratio < 1`` and the
        identity part when ``l1_ratio > 0``.

    Returns
    -------
    ndarray
        ``grad`` itself, after the in-place projection.
    """
    tv_active = l1_ratio < 1.0
    l1_active = l1_ratio > 0.0

    if tv_active:
        # The identity slot (last component) is left out of the l21 norm
        # whenever the L1 part is active.
        n_grad_axes = len(grad) - int(l1_active)
        spatial = grad[:n_grad_axes]
        magnitude = np.sqrt(np.sum(spatial * spatial, 0))
        # Points already inside the unit ball must not be rescaled.
        magnitude.clip(1.0, out=magnitude)
        # ``spatial`` is a view on ``grad``: this divides every gradient
        # component in place.
        spatial /= magnitude

    if l1_active:
        magnitude = np.abs(grad[-1])
        magnitude.clip(1.0, out=magnitude)
        grad[-1] /= magnitude

    return grad

56 

57 

def _dual_gap_prox_tvl1(input_img_norm, new, gap, weight, l1_ratio=1.0):
    """Evaluate the dual gap of the TV-L1 denoising problem.

    A derivation of the dual gap is given in "Total variation
    regularization for fMRI-based prediction of behavior",
    by Michel et al. (2011).

    Parameters
    ----------
    input_img_norm : float
        Term subtracted in the gap formula; ``prox_tvl1`` passes the
        squared l2 norm of the input image.

    new : ndarray
        Current primal estimate.

    gap : ndarray
        Array whose squared l2 norm enters the gap formula (it is
        flattened before use).

    weight : float
        Weight of the TV-L1 term.

    l1_ratio : float, default=1.0
        Forwarded to ``gradient_id``.

    Returns
    -------
    float
        Half of ``|gap|^2 + 2 * weight * TVl1(new) - input_img_norm
        + |new|^2``.
    """
    penalty = tv_l1_from_gradient(gradient_id(new, l1_ratio=l1_ratio))
    flat_gap = gap.ravel()
    dual_energy = np.dot(flat_gap, flat_gap)
    primal_energy = (new * new).sum()
    total = dual_energy + 2 * weight * penalty - input_img_norm + primal_energy
    return 0.5 * total

73 

74 

def _objective_function_prox_tvl1(input_img, output_img, gradient, weight):
    """Evaluate the TV-L1 denoising objective.

    Returns ``0.5 * ||input_img - output_img||^2`` plus ``weight`` times
    the value of ``tv_l1_from_gradient(gradient)``.
    """
    residual = (input_img - output_img).ravel()
    data_fit = 0.5 * (residual * residual).sum()
    return data_fit + weight * tv_l1_from_gradient(gradient)

78 

79 

def prox_tvl1(
    input_img,
    l1_ratio=0.05,
    weight=50,
    dgap_tol=5.0e-5,
    x_tol=None,
    max_iter=200,
    check_gap_frequency=4,
    val_min=None,
    val_max=None,
    verbose=0,
    fista=True,
    init=None,
):
    """
    Compute the TV-L1 proximal (ie total-variation + l1 denoising) of an image.

    Find the argmin `res` of
        1/2 * ||im - res||^2 + weight * TVl1(res),

    Parameters
    ----------
    input_img : ndarray (2-d or 3-d)
        Input data to be denoised. It can be of any numeric type; data
        that is not already floating point is cast to float64 before any
        computation.

    l1_ratio : float, default=0.05
        Trade-off between the total-variation part and the L1 part of the
        penalty. The gradient (TV) part is active when ``l1_ratio < 1``
        and the identity (L1) part when ``l1_ratio > 0``.

    weight : float, default=50
        Denoising weight. The greater ``weight``, the more denoising (at
        the expense of fidelity to ``input_img``).

    dgap_tol : float, default=5.0e-5
        Precision required. When ``x_tol`` is None, the loop stops as soon
        as the dual gap of the optimization problem falls below this
        value.

    x_tol : float or None, default=None
        If specified, the stopping criterion is on the iterate rather than
        on the dual gap: the loop stops when the maximal absolute change
        since the previous check, divided by the maximal absolute value of
        the current iterate, is below ``x_tol``.

    max_iter : int, default=200
        Maximal number of iterations used for the optimization.

    check_gap_frequency : int, default=4
        The stopping criterion is evaluated at the iterations ``i`` for
        which ``i % check_gap_frequency == 0``.

    val_min : None or float, default=None
        An optional lower bound constraint on the reconstructed image.

    val_max : None or float, default=None
        An optional upper bound constraint on the reconstructed image.

    verbose : int or bool, default=0
        Verbosity level handed to the logger; True / False are mapped
        to 1 / 0.

    fista : bool, default=True
        If True, uses a FISTA loop to perform the optimization (with
        temporary ISTA steps when the dual gap increases).
        If False, uses an ISTA loop.

    init : ndarray or None, default=None
        Starting point for the optimization; when None, ``input_img`` is
        used.

    Returns
    -------
    out : ndarray
        TV-l1-denoised image.

    info : dict
        ``{"converged": bool}``; True when the loop stopped before
        ``max_iter`` iterations were performed.

    Notes
    -----
    The principle of total variation denoising is explained in
    https://en.wikipedia.org/wiki/Total_variation_denoising

    The principle of total variation denoising is to minimize the
    total variation of the image, which can be roughly described as
    the integral of the norm of the image gradient. Total variation
    denoising tends to produce "cartoon-like" images, that is,
    piecewise-constant images.

    This function implements the FISTA (Fast Iterative Shrinkage
    Thresholding Algorithm) algorithm of Beck et Teboulle, adapted to
    total variation denoising in "Fast gradient-based algorithms for
    constrained total variation image denoising and deblurring problems"
    (2009).

    For details on implementing the bound constraints, read the aforementioned
    Beck and Teboulle paper.
    """
    if verbose is False:
        verbose = 0
    if verbose is True:
        verbose = 1

    weight = float(weight)
    # Cast BEFORE computing the squared norm: with an integer dtype the
    # dot product is accumulated in that integer type and can silently
    # overflow.
    if input_img.dtype.kind != "f":
        input_img = input_img.astype(np.float64)
    # ravel() also accepts non-contiguous inputs, whereas assigning to
    # the ``shape`` attribute of a view raises for them.
    input_img_flat = input_img.ravel()
    input_img_norm = np.dot(input_img_flat, input_img_flat)
    shape = [len(input_img.shape) + 1, *input_img.shape]
    grad_im = np.zeros(shape)
    grad_aux = np.zeros(shape)
    t = 1.0
    i = 0
    lipschitz_constant = 1.1 * (
        4 * input_img.ndim * (1 - l1_ratio) ** 2 + l1_ratio**2
    )

    # negated_output is the negated primal variable in the optimization
    # loop
    negated_output = -input_img if init is None else -init

    # Clipping values for the inner loop (bounds of the NEGATED variable,
    # hence the swapped roles and signs)
    negated_val_min = np.inf
    negated_val_max = -np.inf
    if val_min is not None:
        negated_val_min = -val_min
    if val_max is not None:
        negated_val_max = -val_max
    # With the x_tol criterion, the stopping test is on the evolution of
    # the output, so the previous iterate is kept
    negated_output_old = negated_output.copy()
    old_dgap = np.inf
    dgap = np.inf

    # A boolean to control if we are going to do a fista step
    fista_step = fista

    while i < max_iter:
        grad_tmp = gradient_id(negated_output, l1_ratio=l1_ratio)
        grad_tmp *= 1.0 / (lipschitz_constant * weight)
        grad_aux += grad_tmp
        grad_tmp = _projector_on_tvl1_dual(grad_aux, l1_ratio)

        # Careful, in the next few lines, grad_tmp and grad_aux refer to
        # the same array, as _projector_on_tvl1_dual works in place and
        # returns its input array
        t_new = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t * t))
        t_factor = (t - 1.0) / t_new

        grad_aux = grad_tmp
        if fista_step:
            grad_aux = (1 + t_factor) * grad_tmp - t_factor * grad_im

        grad_im = grad_tmp
        t = t_new
        gap = weight * divergence_id(grad_aux, l1_ratio=l1_ratio)

        # Compute the primal variable
        negated_output = gap - input_img
        if val_min is not None or val_max is not None:
            negated_output = negated_output.clip(
                negated_val_max, negated_val_min, out=negated_output
            )
        if (i % check_gap_frequency) == 0:
            if x_tol is None:
                # Stopping criterion based on the dual gap
                if val_min is not None or val_max is not None:
                    # We need to recompute the dual variable
                    gap = negated_output + input_img
                old_dgap = dgap
                dgap = _dual_gap_prox_tvl1(
                    input_img_norm,
                    -negated_output,
                    gap,
                    weight,
                    l1_ratio=l1_ratio,
                )

                logger.log(
                    f"\tProxTVl1: Iteration {i: 2}, dual gap: {dgap: 6.3e}",
                    verbose,
                )

                if dgap < dgap_tol:
                    break

                if old_dgap < dgap:
                    # M-FISTA strategy: switch to an ISTA to have
                    # monotone convergence
                    fista_step = False
                elif fista:
                    fista_step = True
            else:
                # Stopping criterion based on x_tol
                diff = np.max(np.abs(negated_output_old - negated_output))
                diff /= np.max(np.abs(negated_output))

                gid = gradient_id(negated_output, l1_ratio=l1_ratio)
                energy = _objective_function_prox_tvl1(
                    input_img, -negated_output, gid, weight
                )
                logger.log(
                    f"\tProxTVl1 iteration {i: 2}, "
                    f"relative difference: {diff: 6.3e}, "
                    f"energy: {energy: 6.3e}",
                    verbose,
                )

                if diff < x_tol:
                    break
                negated_output_old = negated_output
        i += 1

    # Compute the primal variable, however, here we must use the ista
    # value, not the fista one
    output = input_img - weight * divergence_id(grad_im, l1_ratio=l1_ratio)
    if val_min is not None or val_max is not None:
        output = output.clip(val_min, val_max, out=output)
    # A ``break`` leaves i < max_iter; exhausting the loop leaves
    # i == max_iter.
    return output, {"converged": (i < max_iter)}

295 

296 

def prox_tvl1_with_intercept(
    w,
    shape,
    l1_ratio,
    weight,
    dgap_tol,
    max_iter=5000,
    init=None,
    verbose=0,
):
    """Compute the TV-L1 prox of ``w`` while leaving its intercept alone.

    All entries of ``w`` but the last are reshaped to ``shape`` and sent
    through ``prox_tvl1``; the last entry (the intercept) is appended
    unchanged to the flattened result.

    Parameters
    ----------
    w : ndarray, shape (w_size,)
        The point at which the prox is being computed.

    shape : tuple of int
        Shape to which ``w[:-1]`` (and ``init``, if given) is reshaped
        before calling ``prox_tvl1``.

    l1_ratio : float
        Forwarded to ``prox_tvl1``.

    weight : float
        Weight in prox. This would be something like `alpha_ * stepsize`,
        where `alpha_` is the effective (i.e. re-scaled) alpha.

    dgap_tol : float
        Dual-gap tolerance for TV-L1 prox operator approximation loop.

    max_iter : int, default=5000
        Maximum number of iterations for the solver.

    init : ndarray, shape (w_size - 1,), default=None
        Initialization vector for the prox.

    verbose : int or bool, default=0
        If True or 1, print the dual gap of the optimization.

    Returns
    -------
    ndarray
        Flattened output of ``prox_tvl1`` with ``w[-1]`` appended.

    dict
        Information dictionary returned by ``prox_tvl1``.
    """
    # Normalize a plain boolean flag to the matching integer level.
    if isinstance(verbose, bool):
        verbose = int(verbose)

    start = None if init is None else init.reshape(shape)
    coefs, intercept = w[:-1], w[-1]
    denoised, prox_info = prox_tvl1(
        coefs.reshape(shape),
        l1_ratio=l1_ratio,
        weight=weight,
        dgap_tol=dgap_tol,
        max_iter=max_iter,
        init=start,
        verbose=verbose,
    )

    return np.append(denoised, intercept), prox_info