Coverage for C:\Users\t590r\Documents\GitHub\suppy\suppy\projections\_projections.py: 87%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2025-02-05 10:12 +0100

1"""Base classes for all projection objects.""" 

2from abc import ABC, abstractmethod 

3from typing import List 

4import numpy as np 

5import numpy.typing as npt 

6 

# Optional GPU backend. When CuPy cannot be imported, NumPy is bound under the
# ``cp`` name so that code such as ``isinstance(x, cp.ndarray)`` keeps working
# on CPU-only installations; ``NO_GPU`` records which branch was taken.
try:
    import cupy as cp
except ImportError:
    cp = np
    NO_GPU = True
else:
    NO_GPU = False

14 

15 

class Projection(ABC):
    """
    Abstract base class for projections used in feasibility algorithms.

    Subclasses implement :meth:`_project` (the unrelaxed projection onto the
    set) and :meth:`_proximity` (the raw proximity measures). This class adds
    relaxation on top of ``_project`` and lets ``proximity_flag`` switch the
    proximity contribution off.

    Parameters
    ----------
    relaxation : float, optional
        Relaxation parameter for the projection, by default 1.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when
        calculating proximity, by default True.
    _use_gpu : bool, optional
        Internal flag, by default False. It is stored as ``self._use_gpu``
        and is not read by this base class.

    Attributes
    ----------
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when
        calculating proximity.
    """

    def __init__(self, relaxation=1, proximity_flag=True, _use_gpu=False):
        self.relaxation = relaxation
        self.proximity_flag = proximity_flag
        self._use_gpu = _use_gpu

    # NOTE: an ``@ensure_float_array`` decorator used to wrap ``step``; it was
    # removed because it led to unwanted behavior.
    def step(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform the (possibly relaxed) projection of input array 'x' onto
        the constraint.

        This currently delegates directly to :meth:`project`.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The (possibly relaxed) projection of 'x' onto the constraint.
        """
        return self.project(x)

    def project(self, x: npt.NDArray) -> npt.NDArray:
        """
        Perform the (possibly relaxed) projection of input array 'x' onto
        the constraint.

        Parameters
        ----------
        x : npt.NDArray
            The input array to be projected.

        Returns
        -------
        npt.NDArray
            The (possibly relaxed) projection of 'x' onto the constraint.

        Notes
        -----
        For ``relaxation == 1`` the result of :meth:`_project` is returned
        unchanged. Otherwise the result is
        ``(1 - relaxation) * x + relaxation * _project(x)``.
        """
        if self.relaxation == 1:
            return self._project(x)

        # Compute the "stay" term first: the multiplication allocates a new
        # array from the untouched input. It therefore stays correct even if
        # a subclass' ``_project`` modifies ``x`` in place, and no explicit
        # copy of ``x`` is needed.
        kept = x * (1 - self.relaxation)
        return kept + self.relaxation * self._project(x)

    @abstractmethod
    def _project(self, x: npt.NDArray) -> npt.NDArray:
        """Internal method to project the point x onto the set."""

    def proximity(self, x: npt.NDArray, proximity_measures: List) -> npt.NDArray:
        """
        Calculate proximity measures of point `x` to the set.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : List
            Proximity measures to calculate; they are interpreted by
            :meth:`_proximity`.

        Returns
        -------
        npt.NDArray
            One entry per requested measure, created with the same array
            library as `x` (CuPy for CuPy input, NumPy otherwise). If
            ``proximity_flag`` is False, every entry is zero.
        """
        # Keep the result in the same array library as the input.
        xp = cp if isinstance(x, cp.ndarray) else np
        if not self.proximity_flag:
            return xp.zeros(len(proximity_measures))
        return xp.array(self._proximity(x, proximity_measures))

    @abstractmethod
    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Calculate proximity measures of point `x` to set.

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : List
            List of proximity measures to calculate.

        Returns
        -------
        List[float]
            The proximity measures of the input array `x`.
        """

121 

122 

class BasicProjection(Projection, ABC):
    """
    BasicProjection is an abstract base class that extends the Projection
    class.

    It allows for projecting onto a subset of the input array based on
    provided indices. Proximity is measured only over ``x[idx]``.

    Parameters
    ----------
    relaxation : float, optional
        Relaxation parameter for the projection, by default 1.
    idx : npt.NDArray or None, optional
        Indices to apply the projection, by default None. ``None`` is
        replaced by ``np.s_[:]``, which selects the whole input.
    proximity_flag : bool, optional
        Flag to indicate whether to take this object into account when
        calculating proximity, by default True.
    _use_gpu : bool, optional
        Internal flag forwarded to :class:`Projection`, by default False.

    Attributes
    ----------
    relaxation : float
        Relaxation parameter for the projection.
    proximity_flag : bool
        Flag to indicate whether to take this object into account when
        calculating proximity.
    idx : npt.NDArray or slice
        Subset of the input vector to apply the projection on.
    """

    def __init__(
        self,
        relaxation=1,
        idx: npt.NDArray | None = None,
        proximity_flag=True,
        _use_gpu=False,
    ):
        super().__init__(relaxation, proximity_flag, _use_gpu)
        self.idx = idx if idx is not None else np.s_[:]

    # NOTE: no ``project`` override is needed; the base-class implementation
    # (relaxation applied to the full array) is sufficient.

    def _proximity(self, x: npt.NDArray, proximity_measures: List) -> List[float]:
        """
        Compute proximity measures from the distance of ``x[idx]`` to its
        projection.

        The Euclidean distance ``d = ||x[idx] - _project(x)[idx]||_2`` is
        computed once, and every requested measure is derived from it:

        * ``("p_norm", p)`` -> ``d ** p``
        * ``"max_norm"``    -> ``d``

        Parameters
        ----------
        x : npt.NDArray
            Input array for which the proximity measures are to be calculated.
        proximity_measures : List
            Measures to calculate: ``("p_norm", p)`` tuples and/or the string
            ``"max_norm"``.

        Returns
        -------
        List[float]
            One value per requested measure, in the same order.

        Raises
        ------
        ValueError
            If a measure is neither a ``("p_norm", p)`` tuple nor
            ``"max_norm"``.
        """
        # ``_project`` receives a copy so that ``x`` itself is left untouched
        # in case a subclass projects in place.
        # TODO: allow choosing the distance (currently always Euclidean).
        residual = x[self.idx] - self._project(x.copy())[self.idx]
        dist = (residual**2).sum() ** (1 / 2)

        measures = []
        for measure in proximity_measures:
            if isinstance(measure, tuple):
                if measure[0] != "p_norm":
                    raise ValueError("Invalid proximity measure")
                measures.append(dist ** measure[1])
            # The isinstance guard avoids element-wise ``==`` on array-like inputs.
            elif isinstance(measure, str) and measure == "max_norm":
                measures.append(dist)
            else:
                raise ValueError("Invalid proximity measure")
        return measures