Coverage for arrakis_server/traits.py: 89.5%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-08-12 16:39 -0700

1# Copyright (c) 2022, California Institute of Technology and contributors 

2# 

3# You should have received a copy of the licensing terms for this 

4# software included in the file "LICENSE" located in the top-level 

5# directory of this package. If you did not, you can view a copy at 

6# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE 

7 

8from __future__ import annotations 

9 

10import argparse 

11import threading 

12from collections.abc import Iterable, Iterator 

13from typing import Protocol, runtime_checkable 

14 

15from arrakis import SeriesBlock 

16from typing_extensions import TypeGuard 

17 

18from .channel import Channel 

19from .scope import Retention, ScopeInfo 

20 

21 

@runtime_checkable
class ServerBackend(Protocol):
    """Interface provided by Arrakis server backends.

    This is a :class:`typing.Protocol`, so conformance is structural: a
    backend does not have to inherit from this class. Backends that do
    inherit from it pick up the default implementations of ``__str__``,
    ``domains``, ``retention`` and ``count`` defined here. The class is
    decorated with ``runtime_checkable`` so it can be used with
    ``isinstance``.
    """

    # Scope metadata for this backend; ``domains`` and ``retention`` below
    # are read directly from it.
    scope_info: ScopeInfo

    def __str__(self) -> str:
        return f"<{self.__class__.__name__} domains: {self.domains}>"

    @property
    def domains(self) -> set[str]:
        """Domains associated with this backend (``scope_info.domains``)."""
        return self.scope_info.domains

    @property
    def retention(self) -> Retention:
        """Retention associated with this backend (``scope_info.retention``)."""
        return self.scope_info.retention

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> None:
        """Add custom arguments to an argparse ArgumentParser.

        Parameters
        ----------
        parser : argparse.ArgumentParser
            The parser to register backend-specific options on.

        """
        ...

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> ServerBackend:
        """Instantiate a backend from an argparse Namespace.

        Parameters
        ----------
        args : argparse.Namespace
            Parsed command-line arguments.

        Returns
        -------
        ServerBackend
            The configured backend instance.

        """
        ...

    def find(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> Iterable[Channel]:
        """Find channels matching a set of conditions.

        Parameters
        ----------
        pattern : str
            Channel pattern to match channels with, using regular expressions.
        data_type : list[str]
            Data types to match.
        min_rate : int
            Minimum sampling rate for channels.
        max_rate : int
            Maximum sampling rate for channels.
        publisher : list[str]
            Sources to match.

        Returns
        -------
        Iterable[Channel]
            Channel objects for all channels matching query.

        """
        ...

    def count(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> int:
        """Count channels matching a set of conditions.

        The default implementation runs :meth:`find` with the same
        arguments and counts the results, so every matching channel is
        iterated; backends may override this with a cheaper query.

        Parameters
        ----------
        pattern : str
            Channel pattern to match channels with, using regular expressions.
        data_type : list[str]
            Data types to match.
        min_rate : int
            Minimum sampling rate for channels.
        max_rate : int
            Maximum sampling rate for channels.
        publisher : list[str]
            Sources to match.

        Returns
        -------
        int
            The number of channels matching query.

        """
        metadata = self.find(
            pattern=pattern,
            data_type=data_type,
            min_rate=min_rate,
            max_rate=max_rate,
            publisher=publisher,
        )
        # ``find`` returns an Iterable, not necessarily a Sized container,
        # so count by iterating rather than calling ``len``.
        return sum(1 for _ in metadata)

    def describe(self, *, channels: Iterable[str]) -> Iterable[Channel]:
        """Get channel metadata for channels requested.

        Parameters
        ----------
        channels : Iterable[str]
            Channels to request.

        Returns
        -------
        Iterable[Channel]
            Channel objects, one per channel requested.

        """
        ...

    # NOTE(review): the docstring below describes omitting start/end to get a
    # live stream, but both are annotated as required ints; confirm how
    # callers signal "not set" (e.g. None) and align the annotations.
    def stream(
        self, *, channels: Iterable[str], start: int, end: int
    ) -> Iterator[SeriesBlock]:
        """Stream timeseries data.

        Parameters
        ----------
        channels : Iterable[str]
            Channels to request.
        start : int
            GPS start time.
        end : int
            GPS end time.

        Yields
        ------
        SeriesBlock
            Dictionary-like object containing all requested channel data.

        Setting neither start nor end begins a live stream starting
        from now.

        """
        ...

158 

159 

@runtime_checkable
class PublishServerBackend(ServerBackend, Protocol):
    """Server backend that additionally supports publishing data.

    Extends :class:`ServerBackend` with ``publish`` and ``partition``.
    Like its parent, this is a runtime-checkable protocol.
    """

    # One lock object, created once when this class is defined; every class
    # inheriting this attribute (and all of their instances) shares it.
    # NOTE(review): as an annotated class attribute of a Protocol, ``_lock``
    # is also a protocol member, so runtime ``isinstance`` checks against
    # this class require it -- confirm that is intended for backends that
    # do not inherit from this class.
    _lock: threading.Lock = threading.Lock()

    def publish(self, *, publisher_id: str) -> dict[str, str]:
        """Return producer-based connection info needed to publish data.

        Parameters
        ----------
        publisher_id : str
            The ID assigned to the publisher.

        Returns
        -------
        dict[str, str]
            A dictionary containing producer-based connection info.

        """
        ...

    def partition(
        self, *, publisher_id: str, channels: Iterable[Channel]
    ) -> Iterable[Channel]:
        """Assign partitions to a publisher's channels.

        Parameters
        ----------
        publisher_id : str
            The ID assigned to the publisher.
        channels : Iterable[Channel]
            Channel objects, one for each channel needing to have their
            partitions assigned.

        Returns
        -------
        Iterable[Channel]
            Channel objects with their partition IDs set.

        """
        ...

200 

201 

# Convenience alias for "a server backend, or no backend at all".
# NOTE(review): ``ServerBackend | None`` is evaluated at import time (it is
# not an annotation), which requires Python 3.10+. The module imports
# ``TypeGuard`` from ``typing_extensions``, which may indicate older Pythons
# are targeted -- confirm the minimum supported version.
MaybeBackend = ServerBackend | None


def can_publish(backend: MaybeBackend) -> TypeGuard[PublishServerBackend]:
    """Determine if a server backend supports publish-like functionality.

    Parameters
    ----------
    backend : ServerBackend or None
        The backend to inspect. ``None`` is accepted and yields ``False``.

    Returns
    -------
    bool
        True if ``backend`` satisfies the ``PublishServerBackend``
        protocol. Static type checkers narrow ``backend`` to
        ``PublishServerBackend`` when this returns True.

    """
    # Note this is actually a "protocol" check, essentially a duck-type
    # check, rather than a check that the class is explicitly an instance
    # of the specified protocol being compared against. Runtime protocol
    # checks only verify that the protocol's members are present, not
    # their signatures. See:
    # https://typing.readthedocs.io/en/latest/spec/protocol.html
    return isinstance(backend, PublishServerBackend)