Coverage for nilearn/decoding/_proximal_operators.py: 7%
114 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 13:00 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 13:00 +0200
1"""Implementation of multiple proximal operators for TV-L1, Graph-Net, etc."""
3from math import sqrt
5import numpy as np
7from nilearn._utils import logger
9from ._objective_functions import (
10 divergence_id,
11 gradient_id,
12 tv_l1_from_gradient,
13)
def prox_l1(y, alpha, copy=True):
    """Apply the proximal operator of the L1 norm (soft-thresholding).

    Every nonzero entry of ``y`` is multiplied by
    ``max(1 - alpha / |y_i|, 0)``; entries that are exactly zero stay zero.

    Parameters
    ----------
    y : ndarray
        Point at which the prox is evaluated.

    alpha : float
        Threshold level.

    copy : bool, default=True
        If True, work on a copy of ``y``; otherwise ``y`` is overwritten
        in place.

    Returns
    -------
    ndarray
        The thresholded array (``y`` itself when ``copy`` is False).
    """
    out = y.copy() if copy else y

    # Shrinkage factors default to 0; only nonzero entries get a
    # computed factor, which avoids dividing by zero.
    factors = np.zeros(out.shape)
    support = out.nonzero()
    factors[support] = np.maximum(1 - alpha / np.abs(out[support]), 0)

    out *= factors
    return out
def prox_l1_with_intercept(x, tau):
    """Apply ``prox_l1`` to all entries of ``x`` except the last one.

    The last entry is left untouched. ``x`` is overwritten in place
    and returned.

    Parameters
    ----------
    x : ndarray
        Vector whose first ``n - 1`` entries are thresholded.

    tau : float
        Threshold passed to ``prox_l1``.

    Returns
    -------
    ndarray
        The same array object ``x``, after the update.
    """
    thresholded = prox_l1(x[:-1], tau)
    x[:-1] = thresholded
    return x
33def _projector_on_tvl1_dual(grad, l1_ratio):
34 """Compute TV-l1 duality gap.
36 Modifies IN PLACE the gradient + id to project it
37 on the l21 unit ball in the gradient direction and the L1 ball in the
38 identity direction.
39 """
40 # The l21 ball for the gradient direction
41 if l1_ratio < 1.0:
42 # infer number of axes and include an additional axis if l1_ratio > 0
43 end = len(grad) - int(l1_ratio > 0.0)
44 norm = np.sqrt(np.sum(grad[:end] * grad[:end], 0))
45 norm.clip(1.0, out=norm) # set everythx < 1 to 1
46 for grad_comp in grad[:end]:
47 grad_comp /= norm
49 # The L1 ball for the identity direction
50 if l1_ratio > 0.0:
51 norm = np.abs(grad[-1])
52 norm.clip(1.0, out=norm)
53 grad[-1] /= norm
55 return grad
def _dual_gap_prox_tvl1(input_img_norm, new, gap, weight, l1_ratio=1.0):
    """Compute the dual gap of the TV-L1 denoising problem.

    See "Total variation regularization for fMRI-based prediction of
    behavior", by Michel et al. (2011) for a derivation of the dual gap.

    Parameters
    ----------
    input_img_norm : float
        Term subtracted in the gap expression (``prox_tvl1`` passes the
        dot product of the flattened input image with itself).

    new : ndarray
        Current primal iterate.

    gap : ndarray
        Current dual term; it is flattened before use.

    weight : float
        Weight applied to the TV-L1 penalty.

    l1_ratio : float, default=1.0
        Forwarded to ``gradient_id``.

    Returns
    -------
    float
        Half of ``||gap||^2 + 2 * weight * TVl1(new) - input_img_norm
        + ||new||^2``.
    """
    flat_gap = gap.ravel()
    penalty = tv_l1_from_gradient(gradient_id(new, l1_ratio=l1_ratio))
    twice_gap = (
        np.dot(flat_gap, flat_gap)
        + 2 * weight * penalty
        - input_img_norm
        + np.sum(new * new)
    )
    return 0.5 * twice_gap
def _objective_function_prox_tvl1(input_img, output_img, gradient, weight):
    """Evaluate the objective minimized by the TV-L1 prox.

    Returns ``0.5 * ||input_img - output_img||^2
    + weight * tv_l1_from_gradient(gradient)``.

    Parameters
    ----------
    input_img : ndarray
        Image being denoised.

    output_img : ndarray
        Candidate denoised image.

    gradient : ndarray
        Array handed to ``tv_l1_from_gradient`` (``prox_tvl1`` passes
        the result of ``gradient_id``).

    weight : float
        Weight of the penalty term.
    """
    residual = np.ravel(input_img - output_img)
    data_fit = 0.5 * (residual * residual).sum()
    penalty = weight * tv_l1_from_gradient(gradient)
    return data_fit + penalty
def prox_tvl1(
    input_img,
    l1_ratio=0.05,
    weight=50,
    dgap_tol=5.0e-5,
    x_tol=None,
    max_iter=200,
    check_gap_frequency=4,
    val_min=None,
    val_max=None,
    verbose=0,
    fista=True,
    init=None,
):
    """
    Compute the TV-L1 proximal (ie total-variation +l1 denoising) on 3d images.

    Find the argmin `res` of
        1/2 * ||input_img - res||^2 + weight * TVl1(res),

    Parameters
    ----------
    input_img : ndarray of floats (2-d or 3-d)
        Input data to be denoised. `input_img` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.

    l1_ratio : float, default=0.05
        Balance between the gradient (TV) part and the identity (L1) part
        of the penalty. The gradient part is active when ``l1_ratio < 1``
        and the identity part when ``l1_ratio > 0``; presumably expected
        to lie in [0, 1].

    weight : float, optional
        Denoising weight. The greater ``weight``, the more denoising (at
        the expense of fidelity to ``input_img``)

    dgap_tol : float, optional
        Precision required. The iterations stop as soon as the dual gap
        of the optimization problem falls below this value.
        NOTE(review): earlier documentation said the gap is rescaled by
        the squared l2 norm of the image; no such rescaling is visible
        in this function -- confirm.

    x_tol : float or None, optional
        If specified, the stopping criterion is on the iterate rather
        than on the dual gap: iterations stop when the maximal absolute
        change between two successive convergence checks, divided by the
        maximal absolute value of the current iterate, is below `x_tol`.

    max_iter : int, optional
        Maximal number of iterations used for the optimization.

    check_gap_frequency : int, default=4
        Frequency (in iterations) at which the stopping criterion is
        evaluated.

    val_min : None or float, optional
        An optional lower bound constraint on the reconstructed image.

    val_max : None or float, optional
        An optional upper bound constraint on the reconstructed image.

    verbose : int or bool, optional
        If True or 1, print the dual gap of the optimization

    fista : bool, optional
        If True, uses a FISTA loop to perform the optimization.
        if False, uses an ISTA loop.

    init : array of shape as input_img
        Starting point for the optimization.

    Returns
    -------
    out : ndarray
        TV-l1-denoised image.

    info : dict
        Dictionary with a single key ``"converged"``, which is True when
        the loop stopped before reaching `max_iter` iterations.

    Notes
    -----
    The principle of total variation denoising is explained in
    https://en.wikipedia.org/wiki/Total_variation_denoising

    The principle of total variation denoising is to minimize the
    total variation of the image, which can be roughly described as
    the integral of the norm of the image gradient. Total variation
    denoising tends to produce "cartoon-like" images, that is,
    piecewise-constant images.

    This function implements the FISTA (Fast Iterative Shrinkage
    Thresholding Algorithm) algorithm of Beck et Teboulle, adapted to
    total variation denoising in "Fast gradient-based algorithms for
    constrained total variation image denoising and deblurring problems"
    (2009).

    For details on implementing the bound constraints, read the aforementioned
    Beck and Teboulle paper.
    """
    # Normalize boolean verbosity to the integer levels passed to logger.log
    if verbose is False:
        verbose = 0
    if verbose is True:
        verbose = 1

    weight = float(weight)
    # Flatten through a view (reshaped in place) to compute the squared
    # l2 norm of the input; this happens before the cast to float below.
    input_img_flat = input_img.view()
    input_img_flat.shape = input_img.size
    input_img_norm = np.dot(input_img_flat, input_img_flat)
    if input_img.dtype.kind != "f":
        input_img = input_img.astype(np.float64)
    # One component per image axis plus one extra component
    shape = [len(input_img.shape) + 1, *input_img.shape]
    grad_im = np.zeros(shape)
    grad_aux = np.zeros(shape)
    t = 1.0
    i = 0
    # Step-size constant; the 1.1 factor is presumably a safety margin
    # on the Lipschitz bound -- TODO confirm.
    lipschitz_constant = 1.1 * (
        4 * input_img.ndim * (1 - l1_ratio) ** 2 + l1_ratio**2
    )

    # negated_output is the negated primal variable in the optimization
    # loop
    negated_output = -input_img if init is None else -init

    # Clipping values for the inner loop. Because the loop works on the
    # negated variable, the bounds are negated and swap roles:
    # -val_max is the lower bound and -val_min the upper bound.
    negated_val_min = np.inf
    negated_val_max = -np.inf
    if val_min is not None:
        negated_val_min = -val_min
    if val_max is not None:
        negated_val_max = -val_max
    # With bound constraints, the stopping criterion is on the
    # evolution of the output
    negated_output_old = negated_output.copy()
    grad_tmp = None
    old_dgap = np.inf
    dgap = np.inf

    # A boolean to control if we are going to do a fista step
    fista_step = fista

    while i < max_iter:
        # Scaled gradient step on the dual variable, then projection
        grad_tmp = gradient_id(negated_output, l1_ratio=l1_ratio)
        grad_tmp *= 1.0 / (lipschitz_constant * weight)
        grad_aux += grad_tmp
        grad_tmp = _projector_on_tvl1_dual(grad_aux, l1_ratio)

        # Careful, in the next few lines, grad_tmp and grad_aux are a
        # view on the same array, as _projector_on_tvl1_dual returns a view
        # on the input array
        t_new = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t * t))
        t_factor = (t - 1.0) / t_new

        # ISTA update by default; overridden below by the FISTA
        # extrapolation, which allocates a new array.
        grad_aux = grad_tmp
        if fista_step:
            grad_aux = (1 + t_factor) * grad_tmp - t_factor * grad_im

        # NOTE(review): after an ISTA step grad_im and grad_aux are the
        # same array, so the first FISTA step following it combines an
        # array with itself -- confirm this is intended.
        grad_im = grad_tmp
        t = t_new
        gap = weight * divergence_id(grad_aux, l1_ratio=l1_ratio)

        # Compute the primal variable
        negated_output = gap - input_img
        if val_min is not None or val_max is not None:
            negated_output = negated_output.clip(
                negated_val_max, negated_val_min, out=negated_output
            )
        # The stopping criterion is only evaluated every
        # check_gap_frequency iterations.
        if (i % check_gap_frequency) == 0:
            if x_tol is None:
                # Stopping criterion based on the dual gap
                if val_min is not None or val_max is not None:
                    # We need to recompute the dual variable
                    gap = negated_output + input_img
                old_dgap = dgap
                dgap = _dual_gap_prox_tvl1(
                    input_img_norm,
                    -negated_output,
                    gap,
                    weight,
                    l1_ratio=l1_ratio,
                )

                logger.log(
                    f"\tProxTVl1: Iteration {i: 2}, dual gap: {dgap: 6.3e}",
                    verbose,
                )

                if dgap < dgap_tol:
                    break

                if old_dgap < dgap:
                    # M-FISTA strategy: switch to an ISTA to have
                    # monotone convergence
                    fista_step = False
                elif fista:
                    fista_step = True
            else:
                # Stopping criterion based on x_tol
                diff = np.max(np.abs(negated_output_old - negated_output))
                diff /= np.max(np.abs(negated_output))

                # The energy is computed for the log message only
                gid = gradient_id(negated_output, l1_ratio=l1_ratio)
                energy = _objective_function_prox_tvl1(
                    input_img, -negated_output, gid, weight
                )
                logger.log(
                    f"\tProxTVl1 iteration {i: 2}, "
                    f"relative difference: {diff: 6.3e}, "
                    f"energy: {energy: 6.3e}",
                    verbose,
                )

                if diff < x_tol:
                    break
                # Reference iterate for the next convergence check
                negated_output_old = negated_output
        i += 1

    # Compute the primal variable, however, here we must use the ista
    # value, not the fista one
    output = input_img - weight * divergence_id(grad_im, l1_ratio=l1_ratio)
    if val_min is not None or val_max is not None:
        output = output.clip(val_min, val_max, out=output)
    # i only reaches max_iter when no break occurred inside the loop
    return output, {"converged": (i < max_iter)}
def prox_tvl1_with_intercept(
    w,
    shape,
    l1_ratio,
    weight,
    dgap_tol,
    max_iter=5000,
    init=None,
    verbose=0,
):
    """Compute the TV-L1 prox while leaving the intercept untouched.

    The last entry of ``w`` is treated as the intercept: the prox is
    applied to ``w[:-1]`` reshaped to ``shape`` and the intercept is
    appended back unchanged.

    Parameters
    ----------
    w : ndarray, shape (w_size,)
        The point at which the prox is being computed.

    shape : tuple of int
        Shape to which ``w[:-1]`` (and ``init``, when given) is
        reshaped before calling ``prox_tvl1``.

    l1_ratio : float
        Forwarded to ``prox_tvl1``.

    weight : float
        Weight in prox. This would be something like `alpha_ * stepsize`,
        where `alpha_` is the effective (i.e. re-scaled) alpha.

    dgap_tol : float
        Dual-gap tolerance for TV-L1 prox operator approximation loop.

    max_iter : int, default=5000
        Maximum number of iterations for the solver.

    init : ndarray, shape (w_size - 1,), default=None
        Initialization vector for the prox.

    verbose : int or bool, optional
        If True or 1, print the dual gap of the optimization.

    Returns
    -------
    ndarray
        Flattened output of ``prox_tvl1`` followed by ``w[-1]``.

    dict
        The information dictionary returned by ``prox_tvl1``.
    """
    # Map boolean verbosity onto integer levels
    if verbose is True:
        verbose = 1
    elif verbose is False:
        verbose = 0

    if init is not None:
        init = init.reshape(shape)

    denoised, prox_info = prox_tvl1(
        w[:-1].reshape(shape),
        weight=weight,
        l1_ratio=l1_ratio,
        dgap_tol=dgap_tol,
        init=init,
        max_iter=max_iter,
        verbose=verbose,
    )

    # np.append flattens the denoised array before adding the intercept
    return np.append(denoised, w[-1]), prox_info