Coverage for arrakis_server/metadata.py: 83.0%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-08-12 16:39 -0700

# Copyright (c) 2022, California Institute of Technology and contributors
#
# You should have received a copy of the licensing terms for this
# software included in the file "LICENSE" located in the top-level
# directory of this package. If you did not, you can view a copy at
# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE

7 

from __future__ import annotations

import logging
import re
import warnings
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Protocol

import numpy
import toml

from .channel import Channel
from .partition import partition_channels

22 

# package-wide logger shared under the "arrakis" namespace
logger = logging.getLogger("arrakis")

24 

25 

class ChannelMetadataBackend(Protocol):
    """Protocol describing a channel metadata backend.

    Implementations store per-channel metadata and expose operations to
    update, load, and query it.

    """

    def update(self, channels: Iterable[Channel]) -> None:
        """Update channel metadata.

        Parameters
        ----------
        channels : Iterable[Channel]
            Channels for which to update metadata with.

        """
        ...

    # NOTE: this is an instance method on implementations (e.g. a
    # config-backed backend loads from a file path), so the first
    # parameter is `self`, not `cls`.
    def load(self, *args, **kwargs) -> list[Channel]:
        """Load channel metadata.

        Returns
        -------
        list[Channel]
            Channels loaded by the backend.

        """
        ...

    def find(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> Iterable[Channel]:
        """Find channels matching a set of conditions.

        Parameters
        ----------
        pattern : str
            Channel pattern to match channels with, using regular expressions.
        data_type : list[str]
            Data types to match.
        min_rate : int
            Minimum sampling rate for channels.
        max_rate : int
            Maximum sampling rate for channels.
        publisher : list[str]
            Publishers to match.

        Returns
        -------
        Iterable[Channel]
            Channel objects for all channels matching query.

        """
        ...

    def describe(self, *, channels: Iterable[str]) -> Iterable[Channel]:
        """Get channel metadata for channels requested.

        Parameters
        ----------
        channels : Iterable[str]
            Channels to request.

        Returns
        -------
        Iterable[Channel]
            Channel objects, one per channel requested.

        """
        ...

89 

90 

class ChannelConfigBackend(ChannelMetadataBackend):
    """A channel metadata backend backed by configuration.

    Channel metadata is stored as a configuration file in the TOML format.

    """

    def __init__(
        self,
        cache_file: Path | None = None,
        enforce: Iterable[str] | None = None,
    ):
        """Initialize backend.

        Parameters
        ----------
        cache_file : Path, optional
            Path to file that will hold channel metadata for
            publishers that are publishing their own channel lists.
        enforce : list[str], optional
            A list of Channel properties to enforce the presence of
            (beyond the default of ["sample_rate", "data_type"]).

        """
        self.cache_file = cache_file
        # always enforced parameters
        self.enforce = {"sample_rate", "data_type"}
        if enforce:
            self.enforce |= set(enforce)

        # channel name -> Channel metadata
        self.metadata: dict[str, Channel] = {}
        # channel name -> pass-through metadata not consumed by Channel
        self.extra: dict[str, dict[str, Any]] = {}
        # used for tracking if channels have been updated
        self._updated: dict[str, bool] = {}

        if self.cache_file is not None and self.cache_file.exists():
            self.load(self.cache_file)
            # reset the load tracking after loading the cache, so that
            # cached entries may be freely overwritten by publishers
            self._updated = {}

    def validate(self, channel: Channel) -> None:
        """Check that a channel contains all enforced properties.

        Raises
        ------
        ValueError
            If any enforced property is missing or falsy (e.g. None).

        """
        for prop in self.enforce:
            # missing attribute and falsy value are treated identically
            if not getattr(channel, prop, None):
                raise ValueError(f"channel '{channel.name}' missing property {prop}")

    def update(self, channels: Iterable[Channel], overwrite: bool = False) -> None:
        """Update channel metadata and sync to cache.

        Channels will be validated according to the enforced property
        list.

        Parameters
        ----------
        channels : list[Channel]
            List of channels to update.
        overwrite : bool
            Whether to allow overwriting existing channels or not
            (default: False)

        Raises
        ------
        ValueError
            If a channel fails validation, or would overwrite a
            previously updated channel while overwrite is False.

        """
        # update in-memory channel map
        for channel in channels:
            # overwrite protection keys off the update tracker (not the
            # metadata map) so cache-loaded channels remain replaceable
            if not overwrite and self._updated.get(channel.name, False):
                raise ValueError(
                    f"attempt to overwrite existing channel: {channel.name}"
                )
            self.validate(channel)
            self.metadata[channel.name] = channel
            self._updated[channel.name] = True

        if not self.cache_file:
            return

        # write updated channel map to disk
        metadata = {}
        for name, meta in self.metadata.items():
            metadata[name] = {
                "sample_rate": meta.sample_rate,
                "data_type": numpy.dtype(meta.data_type).name,
                "partition_id": meta.partition_id,
                "publisher": meta.publisher,
            }
            # preserve any pass-through metadata for this channel
            if name in self.extra:
                metadata[name].update(self.extra[name])

        with self.cache_file.open("w") as f:
            toml.dump(metadata, f)

    def load(
        self,
        path: Path,
        publisher: str | None = None,
        overwrite: bool = False,
    ) -> list[Channel]:
        """Load channel metadata from TOML file.

        If a publisher is specified this will also handle assigning
        Kafka partition IDs to the loaded channels.
        Returns the list of Channel objects loaded from the file.

        Parameters
        ----------
        path : Path
            Path to channel metadata toml file.
        publisher : str
            Publisher ID to apply to all channels being loaded from
            this file.
        overwrite : bool
            Whether to allow overwriting existing channels or not
            (default: False)

        Returns
        -------
        list[Channel]
            List of channels loaded from file.

        """
        logger.info("loading channel description file: %s", path)
        with path.open("r") as f:
            metadata = toml.load(f)

        # common metadata for all channels defined in a "common" block
        common = metadata.pop("common", {})

        channels = []
        extra = {}
        for name, meta in metadata.items():
            # deprecated attribute names: accept them, but warn
            if "rate" in meta and "sample_rate" not in meta:
                warnings.warn(
                    "the 'rate' channel attribute is deprecated, "
                    "use 'sample_rate' instead",
                    DeprecationWarning,
                    stacklevel=2,
                )
                meta["sample_rate"] = meta.pop("rate")
            if "dtype" in meta and "data_type" not in meta:
                warnings.warn(
                    "the 'dtype' channel attribute is deprecated, "
                    "use 'data_type' instead",
                    DeprecationWarning,
                    stacklevel=2,
                )
                meta["data_type"] = meta.pop("dtype")
            # pull Channel attributes out of the per-channel table,
            # falling back to the "common" block for any not set
            cmeta = {
                key: meta.pop(key, common.get(key, None))
                for key in [
                    "sample_rate",
                    "data_type",
                    "publisher",
                    "partition_id",
                    "expected_latency",
                ]
            }
            channel = Channel(name, **cmeta)
            channels.append(channel)
            # whatever remains is pass-through extra metadata
            extra[name] = meta

        if publisher is not None:
            channels = partition_channels(
                channels,
                metadata=self.metadata,
                publisher=publisher,
            )

        self.update(channels, overwrite=overwrite)
        self.extra.update(extra)

        return channels

    @property
    def scopes(self) -> dict[str, list[dict[str, Any]]]:
        """The scopes that the set of channels span."""
        scopes: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for channel in self.metadata.values():
            scopes[channel.domain].append({"subsystem": channel.subsystem})
        return scopes

    def find(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> Iterable[Channel]:
        """Find channels matching a set of conditions.

        Empty data_type/publisher lists match all data types/publishers.

        """
        expr = re.compile(pattern)
        channels = []
        dtypes = {numpy.dtype(dtype) for dtype in data_type}
        publishers = set(publisher)
        for channel in self.metadata.values():
            if not expr.match(channel.name):
                continue
            if not (min_rate <= channel.sample_rate <= max_rate):
                continue
            if dtypes and channel.data_type not in dtypes:
                continue
            if publishers and channel.publisher not in publishers:
                continue
            channels.append(channel)

        return channels

    def describe(self, *, channels: Iterable[str]) -> Iterable[Channel]:
        """Get channel metadata for channels requested.

        Raises
        ------
        KeyError
            If a requested channel is unknown.

        """
        return [self.metadata[channel] for channel in channels]