Coverage for C:\Users\t590r\Documents\GitHub\suppy\suppy\projections\_projections.py: 87%
47 statements
« prev ^ index » next coverage.py v7.6.4, created at 2025-02-05 10:12 +0100
« prev ^ index » next coverage.py v7.6.4, created at 2025-02-05 10:12 +0100
1"""Base classes for all projection objects."""
2from abc import ABC, abstractmethod
3from typing import List
4import numpy as np
5import numpy.typing as npt
try:
    import cupy as cp
except ImportError:
    # CuPy is optional: without it, alias `cp` to NumPy so that code such as
    # `isinstance(x, cp.ndarray)` keeps working on CPU-only installs.
    NO_GPU = True
    cp = np
else:
    # CuPy imported successfully, so GPU arrays are available.
    NO_GPU = False
class Projection(ABC):
    """
    Abstract base class for projections used in feasibility algorithms.

    Parameters
    ----------
    relaxation : float, optional
        Relaxation parameter for the projection, by default 1.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when
        calculating proximity, by default True.
    _use_gpu : bool, optional
        Private flag stored on the instance as ``_use_gpu``, by default False.

    Attributes
    ----------
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when
        calculating proximity.
    """

    def __init__(
        self,
        relaxation: float = 1,
        proximity_flag: bool = True,
        _use_gpu: bool = False,
    ):
        self.relaxation = relaxation
        self.proximity_flag = proximity_flag
        self._use_gpu = _use_gpu

    # NOTE: an `@ensure_float_array` decorator used to wrap `step`; it was
    # removed because it led to unwanted behavior.
    def step(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform the (possibly relaxed) projection of input array `x` onto
        the constraint.

        This simply delegates to :meth:`project`.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The (possibly relaxed) projection of `x` onto the constraint.
        """
        return self.project(x)

    def project(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform the (possibly relaxed) projection of input array `x` onto
        the constraint.

        With relaxation ``r`` and ``P`` denoting :meth:`_project`, the result
        is ``(1 - r) * x + r * P(x)``. When ``r == 1`` the result of
        ``P(x)`` is returned directly.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The (possibly relaxed) projection of `x` onto the constraint.
        """
        if self.relaxation == 1:
            return self._project(x)
        # Evaluate the (1 - r) * x term first: `x * scalar` allocates a new
        # array, so it captures x's current values before `_project` is
        # called. This matters if a subclass's `_project` modifies `x` in
        # place (presumably some do, given `x.copy()` is passed to
        # `_project` in proximity code — TODO confirm per subclass).
        scaled_x = x * (1 - self.relaxation)
        return scaled_x + self.relaxation * self._project(x)

    @abstractmethod
    def _project(self, x: npt.NDArray) -> npt.NDArray:
        """Internal method to project the point x onto the set."""

    def proximity(self, x: npt.NDArray, proximity_measures: List) -> npt.NDArray:
        """
        Calculate proximity measures of point `x` to the set.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : List
            List of proximity measures to calculate; forwarded unchanged to
            :meth:`_proximity`.

        Returns
        -------
        npt.NDArray
            Array with one entry per requested measure, built with ``cp``
            when `x` is a ``cp.ndarray`` and with NumPy otherwise. All
            entries are zero when `proximity_flag` is False.
        """
        xp = cp if isinstance(x, cp.ndarray) else np
        if not self.proximity_flag:
            # This set is excluded from proximity: contribute zero per measure.
            return xp.zeros(len(proximity_measures))
        return xp.array(self._proximity(x, proximity_measures))

    @abstractmethod
    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Calculate proximity measures of point `x` to set.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : List
            List of proximity measures to calculate.

        Returns
        -------
        List[float]
            The proximity measures of the input array `x`.
        """
class BasicProjection(Projection, ABC):
    """
    BasicProjection is an abstract base class that extends the Projection
    class.
    It allows for projecting onto a subset of the input array based on
    provided indices.

    In this class `idx` restricts the residual used by :meth:`_proximity`;
    presumably subclasses also use it in their ``_project`` — confirm per
    subclass.

    Parameters
    ----------
    relaxation : float, optional
        Relaxation parameter for the projection, by default 1.
    idx : npt.NDArray or None, optional
        Indices to apply the projection, by default None (all entries).
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when
        calculating proximity, by default True.
    _use_gpu : bool, optional
        Private flag forwarded to :class:`Projection`, by default False.

    Attributes
    ----------
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when
        calculating proximity.
    idx : npt.NDArray or slice
        Subset of the input vector to apply the projection on; ``np.s_[:]``
        (i.e. ``slice(None)``, selecting everything) when no indices were
        given.
    """

    def __init__(
        self,
        relaxation: float = 1,
        idx: npt.NDArray | None = None,
        proximity_flag: bool = True,
        _use_gpu: bool = False,
    ):
        super().__init__(relaxation, proximity_flag, _use_gpu)
        self.idx = idx if idx is not None else np.s_[:]

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Compute the requested proximity measures of `x` to the set.

        The distance ``dist`` is the Euclidean norm of the residual
        ``x - P(x)`` restricted to the entries selected by `idx`, where
        ``P`` is :meth:`_project` applied to a copy of `x`.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : List
            Each entry is either a tuple ``("p_norm", p)``, which yields
            ``dist ** p``, or the string ``"max_norm"``, which yields
            ``dist`` (presumably callers reduce it with a maximum over
            sets — confirm).

        Returns
        -------
        List[float]
            One value per entry of `proximity_measures`, in the same order.

        Raises
        ------
        ValueError
            If an entry of `proximity_measures` is not one of the supported
            forms (including tuples too short to carry an exponent).
        """
        # TODO: allow choosing the distance used here (currently Euclidean).
        # `_project` receives a copy so that `x` itself is never handed to it.
        residual = x[self.idx] - self._project(x.copy())[self.idx]
        dist = (residual**2).sum() ** (1 / 2)
        measures = []
        for measure in proximity_measures:
            # The length check turns malformed tuples such as () or
            # ("p_norm",) into the same ValueError as other invalid measures
            # instead of an IndexError.
            if (
                isinstance(measure, tuple)
                and len(measure) >= 2
                and measure[0] == "p_norm"
            ):
                measures.append(dist ** measure[1])
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(dist)
            else:
                raise ValueError("Invalid proximity measure")
        return measures