Coverage for C:\Users\t590r\Documents\GitHub\suppy\suppy\perturbations\_base.py: 93%
56 statements
« prev ^ index » next — coverage.py v7.6.4, created at 2025-02-05 10:12 +0100
1"""Base class for perturbations applied to feasibility seeking algorithms."""
from abc import ABC, abstractmethod
from typing import Callable, List, Optional

import numpy as np
import numpy.typing as npt

from suppy.utils import FuncWrapper
# Optional GPU support: fall back to NumPy when CuPy is unavailable so the
# rest of the module can reference ``cp`` unconditionally.
try:
    import cupy as cp

    NO_GPU = False
except ImportError:
    NO_GPU = True
    # Alias NumPy as ``cp`` so checks like ``isinstance(x, cp.ndarray)``
    # still work on CPU-only installations.
    cp = np
class Perturbation(ABC):
    """
    Abstract interface for perturbations applied to feasibility seeking
    algorithms.
    """

    @abstractmethod
    def perturbation_step(self, x: npt.NDArray) -> npt.NDArray:
        """
        Apply a single perturbation to ``x``.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be perturbed.

        Returns
        -------
        npt.NDArray
            The perturbed array.
        """
class ObjectivePerturbation(Perturbation, ABC):
    """
    Base class for perturbations performed by decreasing an objective
    function.

    Parameters
    ----------
    func : Callable
        The objective function to be perturbed.
    func_args : List
        The arguments to be passed to the objective function.
    n_red : int, optional
        The number of reduction steps performed in one perturbation
        iteration (default is 1).

    Attributes
    ----------
    func : FuncWrapper
        A wrapped version of the objective function with its arguments.
    n_red : int
        The number of reduction steps to perform.
    _k : int
        Number of perturbations performed so far.
    """

    def __init__(self, func: Callable, func_args: List, n_red=1):
        # Bundle the objective with its extra arguments for later calls.
        self.func = FuncWrapper(func, func_args)
        self.n_red = n_red
        self._k = 0  # number of perturbations performed so far

    def perturbation_step(self, x: npt.NDArray) -> npt.NDArray:
        """
        Apply ``n_red`` function reduction steps to the input array.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be perturbed.

        Returns
        -------
        npt.NDArray
            The perturbed array after applying the reduction steps.
        """
        self._k += 1
        for _ in range(self.n_red):
            x = self._function_reduction_step(x)
        return x

    @abstractmethod
    def _function_reduction_step(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform a single function reduction step on the input array.

        Must be implemented by subclasses.

        Parameters
        ----------
        x : npt.NDArray
            Input array on which the reduction step is to be performed.

        Returns
        -------
        npt.NDArray
            The array after the reduction step has been applied.
        """

    def pre_step(self):
        """
        Hook executed before each perturbation in an iteration.

        Intended to be overridden by subclasses that need preparatory work;
        the default implementation does nothing.
        """
class GradientPerturbation(ObjectivePerturbation, ABC):
    """
    Base class for perturbations that decrease an objective function using
    its gradient.

    Parameters
    ----------
    func : Callable
        The objective function to be perturbed.
    grad : Callable
        The gradient function of the objective function.
    func_args : List
        The arguments to be passed to the objective function.
    grad_args : List
        The arguments to be passed to the gradient function.
    n_red : int, optional
        The number of reduction steps per perturbation, by default 1.

    Attributes
    ----------
    func : FuncWrapper
        A wrapped version of the objective function with its arguments.
    grad : FuncWrapper
        A wrapped gradient function with its arguments.
    n_red : int
        The number of reduction steps to perform.
    _k : int
        Number of perturbations performed so far.
    """

    def __init__(self, func: Callable, grad: Callable, func_args: List, grad_args: List, n_red=1):
        # Objective handling (func, n_red, _k) is set up by the parent class.
        super().__init__(func, func_args, n_red)
        # Bundle the gradient with its extra arguments for later calls.
        self.grad = FuncWrapper(grad, grad_args)
class PowerSeriesGradientPerturbation(GradientPerturbation):
    """
    Objective function perturbation using gradient descent with step size
    reduction according to a power series ``step_size**l``.

    The exponent ``l`` grows monotonically across reduction steps; it can
    optionally be "restarted" (reset to a smaller value) after a fixed
    number of perturbations.

    Parameters
    ----------
    func : Callable
        The function to be optimized.
    grad : Callable
        The gradient of the function to be optimized.
    func_args : List, optional
        Additional arguments to be passed to the function (default: none).
    grad_args : List, optional
        Additional arguments to be passed to the gradient function
        (default: none).
    n_red : int, optional
        The number of reduction steps per perturbation, by default 1.
    step_size : float, optional
        Base of the power series governing the step length, by default 0.5.
    n_restart : int, optional
        The number of perturbations after which to restart the power series,
        by default -1 (no restart).
    """

    def __init__(
        self,
        func: Callable,
        grad: Callable,
        func_args: Optional[List] = None,
        grad_args: Optional[List] = None,
        n_red=1,
        step_size=0.5,
        n_restart=-1,
    ):
        # None sentinels replace mutable default arguments ([]), which would
        # be shared across all instances and calls.
        super().__init__(
            func,
            grad,
            [] if func_args is None else func_args,
            [] if grad_args is None else grad_args,
            n_red,
        )
        self.step_size = step_size
        self._l = -1  # power-series exponent; first step uses step_size**0
        # np.inf disables restarting: ``_k % inf`` is never 0 for ``_k > 0``.
        self.n_restart = np.inf if n_restart == -1 else n_restart

    def _function_reduction_step(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform a function reduction step using gradient descent with a
        backtracking power-series step size.

        The exponent ``self._l`` is increased until the candidate point does
        not increase the objective; that candidate is then returned.

        Parameters
        ----------
        x : npt.NDArray
            The current point in the optimization process.

        Returns
        -------
        npt.NDArray
            The updated point after performing the reduction step.
        """
        # Dispatch to CuPy when x lives on the GPU; cp aliases NumPy otherwise.
        xp = cp if isinstance(x, cp.ndarray) else np
        grad_eval = self.grad(x)
        func_eval = self.func(x)
        # The normalized descent direction is loop-invariant; only the step
        # length step_size**l shrinks between candidates.
        direction = grad_eval / xp.linalg.norm(grad_eval)
        while True:
            self._l += 1
            x_ln = x - self.step_size**self._l * direction
            # Accept the first candidate that does not increase the objective.
            # Terminates in floating point: step_size**l -> 0, so eventually
            # x_ln == x and self.func(x_ln) == func_eval.
            if self.func(x_ln) <= func_eval:
                return x_ln
        # NOTE(review): the original had an unreachable duplicate
        # ``return x_ln`` after this loop (dead code); removed.

    def pre_step(self):
        """
        Restart the power series after every ``n_restart`` perturbations.

        Returns
        -------
        None
        """
        if self._k <= 0:
            return
        # Reset the exponent to a slowly growing value so step sizes become
        # large again after each restart.
        if self._k % self.n_restart == 0:
            self._l = self._k // self.n_restart