Coverage for arrakis_server/metadata.py: 83.0%
94 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-12 16:39 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-12 16:39 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE
8from __future__ import annotations
10import logging
11import re
12from collections import defaultdict
13from collections.abc import Iterable
14from pathlib import Path
15from typing import Any, Protocol
17import numpy
18import toml
20from .channel import Channel
21from .partition import partition_channels
# Module-level logger, registered under the fixed name "arrakis".
23logger = logging.getLogger("arrakis")
class ChannelMetadataBackend(Protocol):
    """Structural interface implemented by channel metadata backends.

    A backend keeps track of channel metadata and supports updating it,
    loading it, searching it by a set of conditions, and looking up
    individual channels by name.

    """

    def update(self, channels: Iterable[Channel]) -> None:
        """Update channel metadata.

        Parameters
        ----------
        channels : Iterable[Channel]
            Channels whose metadata should be updated.

        """
        ...

    def load(self, *args, **kwargs) -> list[Channel]:
        """Load channel metadata.

        The accepted arguments are backend-specific.

        Returns
        -------
        list[Channel]
            The channels that were loaded.

        """
        ...

    def find(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> Iterable[Channel]:
        """Find channels matching a set of conditions.

        Parameters
        ----------
        pattern : str
            Channel pattern to match channels with, using regular expressions.
        data_type : list[str]
            Data types to match.
        min_rate : int
            Minimum sampling rate for channels.
        max_rate : int
            Maximum sampling rate for channels.
        publisher : list[str]
            Publishers to match.

        Returns
        -------
        Iterable[Channel]
            Channel objects for all channels matching query.

        """
        ...

    def describe(self, *, channels: Iterable[str]) -> Iterable[Channel]:
        """Get channel metadata for channels requested.

        Parameters
        ----------
        channels : Iterable[str]
            Names of the channels to request.

        Returns
        -------
        Iterable[Channel]
            Channel objects, one per channel requested.

        """
        ...
class ChannelConfigBackend(ChannelMetadataBackend):
    """A channel metadata backend backed by configuration.

    Channel metadata is stored as a configuration file in the TOML format.

    """

    def __init__(
        self,
        cache_file: Path | None = None,
        enforce: Iterable[str] | None = None,
    ):
        """Initialize the backend.

        If `cache_file` is given and exists, its contents are loaded
        immediately.

        Parameters
        ----------
        cache_file : Path, optional
            Path to file that will hold channel metadata for
            publishers that are publishing their own channel lists.
        enforce : list[str], optional
            A list of Channel properties to enforce the presence of
            (beyond the default of ["sample_rate", "data_type"]).
            A single string is treated as one property name.

        """
        self.cache_file = cache_file
        # always enforced parameters
        self.enforce = {"sample_rate", "data_type"}
        if enforce:
            # a bare string is one property name, not an iterable of
            # single-character names
            if isinstance(enforce, str):
                enforce = [enforce]
            self.enforce |= set(enforce)

        self.metadata: dict[str, Channel] = {}
        self.extra: dict[str, dict[str, Any]] = {}
        # used for tracking if channels have been updated
        self._updated: dict[str, bool] = {}

        if self.cache_file is not None and self.cache_file.exists():
            self.load(self.cache_file)
            # reset the load tracking after loading the cache
            self._updated = {}

    def validate(self, channel: Channel) -> None:
        """Check that a channel carries all enforced properties.

        A property counts as missing when the attribute is absent or
        its value is falsy (e.g. None, 0 or an empty string).

        Parameters
        ----------
        channel : Channel
            The channel to check.

        Raises
        ------
        ValueError
            If any enforced property is missing.

        """
        for prop in self.enforce:
            if not getattr(channel, prop, None):
                raise ValueError(f"channel '{channel.name}' missing property {prop}")

    def update(self, channels: Iterable[Channel], overwrite: bool = False) -> None:
        """Update channel metadata and sync to cache.

        Channels will be validated according to the enforced property
        list.  The whole batch is checked before any of it is applied,
        so a rejected channel leaves the in-memory state untouched.

        Parameters
        ----------
        channels : list[Channel]
            List of channels to update.
        overwrite : bool
            Whether to allow overwriting existing channels or not
            (default: False)

        Raises
        ------
        ValueError
            If a channel was already updated (or appears twice in
            `channels`) and `overwrite` is False, or if a channel is
            missing an enforced property.

        """
        # stage the batch first so that a failure cannot leave a
        # partially applied update behind
        pending: dict[str, Channel] = {}
        for channel in channels:
            name = channel.name
            seen = self._updated.get(name, False) or name in pending
            if not overwrite and seen:
                raise ValueError(f"attempt to overwrite existing channel: {name}")
            self.validate(channel)
            pending[name] = channel

        # update in-memory channel map
        self.metadata.update(pending)
        self._updated.update(dict.fromkeys(pending, True))

        if self.cache_file is None:
            return

        # write updated channel map to disk
        # NOTE(review): "expected_latency" is read by load() but is not
        # persisted here -- confirm whether the cache should carry it.
        metadata = {}
        for name, meta in self.metadata.items():
            entry = {
                "sample_rate": meta.sample_rate,
                "data_type": numpy.dtype(meta.data_type).name,
                "partition_id": meta.partition_id,
                "publisher": meta.publisher,
            }
            entry.update(self.extra.get(name, {}))
            metadata[name] = entry

        # serialize before opening the file so that a serialization
        # error cannot truncate an existing cache
        serialized = toml.dumps(metadata)
        with self.cache_file.open("w") as f:
            f.write(serialized)

    def load(
        self,
        path: Path,
        publisher: str | None = None,
        overwrite: bool = False,
    ) -> list[Channel]:
        """Load channel metadata from TOML file.

        If a publisher is specified this will also handle assigning
        Kafka partition IDs to the loaded channels.
        Returns the list of Channel objects loaded from the file.

        Parameters
        ----------
        path : Path
            Path to channel metadata toml file.
        publisher : str
            Publisher ID to apply to all channels being loaded from
            this file.
        overwrite : bool
            Whether to allow overwriting existing channels or not
            (default: False)

        Returns
        -------
        List[Channel]
            List of channels loaded from file.

        Raises
        ------
        ValueError
            If the loaded channels are rejected by `update`.

        """
        logger.info("loading channel description file: %s", path)
        with path.open("r") as f:
            metadata = toml.load(f)

        # common metadata for all channels defined in a "common" block
        common = metadata.pop("common", {})

        # keys handed to the Channel constructor; anything else found in
        # a channel's table is kept as an "extra" property
        channel_keys = (
            "sample_rate",
            "data_type",
            "publisher",
            "partition_id",
            "expected_latency",
        )

        channels = []
        extra = {}
        for name, meta in metadata.items():
            # FIXME: deprecated attributes, should throw deprecation warning
            if "rate" in meta and "sample_rate" not in meta:
                meta["sample_rate"] = meta.pop("rate")
            if "dtype" in meta and "data_type" not in meta:
                meta["data_type"] = meta.pop("dtype")
            cmeta = {key: meta.pop(key, common.get(key, None)) for key in channel_keys}
            channels.append(Channel(name, **cmeta))
            extra[name] = meta

        if publisher is not None:
            channels = partition_channels(
                channels,
                metadata=self.metadata,
                publisher=publisher,
            )

        # register the extra properties before update() so the cache it
        # writes includes them; restore the previous state if update()
        # rejects the channels
        replaced = {name: self.extra[name] for name in extra if name in self.extra}
        self.extra.update(extra)
        try:
            self.update(channels, overwrite=overwrite)
        except Exception:
            for name in extra:
                if name in replaced:
                    self.extra[name] = replaced[name]
                else:
                    del self.extra[name]
            raise

        return channels

    @property
    def scopes(self) -> dict[str, list[dict[str, Any]]]:
        """The scopes that the set of channels span.

        Maps each channel domain to a list of ``{"subsystem": ...}``
        entries, one per distinct subsystem, in first-seen order.

        """
        scopes: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for channel in self.metadata.values():
            entry = {"subsystem": channel.subsystem}
            # record each (domain, subsystem) pair only once
            if entry not in scopes[channel.domain]:
                scopes[channel.domain].append(entry)
        return scopes

    def find(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> Iterable[Channel]:
        """Find channels matching a set of conditions.

        Parameters
        ----------
        pattern : str
            Regular expression matched against the start of each
            channel name.
        data_type : list[str]
            Data types to match; an empty list matches any data type.
        min_rate : int
            Minimum sampling rate for channels (inclusive).
        max_rate : int
            Maximum sampling rate for channels (inclusive).
        publisher : list[str]
            Publishers to match; an empty list matches any publisher.

        Returns
        -------
        Iterable[Channel]
            Channel objects for all channels matching query.

        """
        expr = re.compile(pattern)
        dtypes = {numpy.dtype(dtype) for dtype in data_type}
        publishers = set(publisher)

        channels = []
        for channel in self.metadata.values():
            if not expr.match(channel.name):
                continue
            if not min_rate <= channel.sample_rate <= max_rate:
                continue
            if dtypes and channel.data_type not in dtypes:
                continue
            if publishers and channel.publisher not in publishers:
                continue
            channels.append(channel)

        return channels

    def describe(self, *, channels: Iterable[str]) -> Iterable[Channel]:
        """Get channel metadata for channels requested.

        Parameters
        ----------
        channels : Iterable[str]
            Names of the channels to request.

        Returns
        -------
        Iterable[Channel]
            Channel objects, one per channel requested, in request order.

        Raises
        ------
        KeyError
            If a requested channel is not known to this backend.

        """
        return [self.metadata[channel] for channel in channels]