Coverage for arrakis_server/backends/mock/__init__.py: 61.7%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-08-12 16:39 -0700

1# Copyright (c) 2022, California Institute of Technology and contributors 

2# 

3# You should have received a copy of the licensing terms for this 

4# software included in the file "LICENSE" located in the top-level 

5# directory of this package. If you did not, you can view a copy at 

6# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE 

7 

8from __future__ import annotations 

9 

10import argparse 

11import logging 

12import time 

13from collections.abc import Callable, Iterable, Iterator 

14from importlib import resources 

15from pathlib import Path 

16from typing import TypeAlias 

17 

18import gpstime 

19import numpy 

20import pyarrow 

21from arrakis import SeriesBlock, Time 

22from pyarrow import flight 

23from sympy import lambdify, parse_expr 

24from sympy.abc import t 

25 

26from ...channel import Channel 

27from ...metadata import ChannelConfigBackend 

28from ...scope import Retention, ScopeInfo 

29from ...traits import ServerBackend 

30from . import channels as channel_lists 

31 

32logger = logging.getLogger("arrakis") 

33 

34 

35ArrayTransform: TypeAlias = Callable[[numpy.ndarray], numpy.ndarray] 

36 

37 

38def _func_random_normal(t): 

39 return numpy.random.normal(size=len(t)) 

40 

41 

def load_channel_funcs(metadata: ChannelConfigBackend) -> dict[str, ArrayTransform]:
    """Build the data-generating function for every configured channel.

    Iterates over ``metadata.extra`` (presumably each channel name mapped to
    the table of fields from the channel definition files -- confirm against
    ChannelConfigBackend) and returns a dictionary mapping each channel name
    to a callable. The callable takes the block's time array (in seconds) as
    its single argument and returns the channel's samples. Constant
    expressions may evaluate to a scalar; the caller broadcasts the result
    to the time array's shape.

    Channels should be defined in tables, with the channel name in the
    header. The table should include:

    `rate` in samples per second
    `dtype` as a python dtype
    `func` an optional function to generate the data, which will be given
           the block time array as its single argument. Any sympy
           expression in the variable ``t`` is valid. If omitted,
           standard-normal noise (numpy.random.normal) is generated.

    example:

    ["MY:CHANNEL-NAME"]
    rate = 16384
    dtype = "float32"
    func = "3*t + cos(t)"

    """
    channel_func_map: dict[str, ArrayTransform] = {}
    for channel_name, meta in metadata.extra.items():
        if "func" in meta:
            # NOTE: sympy's parse_expr evaluates its input with eval(), so
            # `func` strings must only come from trusted channel files.
            expr = parse_expr(meta["func"])
            func = lambdify(t, expr, "numpy")
        else:
            func = _func_random_normal
        channel_func_map[channel_name] = func
    return channel_func_map

74 

75 

class MockBackend(ServerBackend):
    """Mock server backend that generates synthetic timeseries data.

    If channel definition files are not provided then the Mock backend
    will serve a pre-defined set of H1: and L1: channels.

    Data are produced in blocks of 1/16 s. Each channel's samples come from
    the function configured for it in the channel definition files (see
    ``load_channel_funcs``).

    """

    def __init__(self, channel_files: list[Path] | None = None):
        """Initialize mock server with list of channel definition files.

        Parameters
        ----------
        channel_files : list[Path], optional
            Channel definition TOML files to load. If None or empty, the
            H1 and L1 channel lists bundled with this package are used.

        """
        self.metadata = ChannelConfigBackend()
        if channel_files:
            for channel_file in channel_files:
                self.metadata.load(channel_file)
        else:
            # resources.as_file() may return a temporary copy of the resource
            # (e.g. when the package is installed as a zip). That copy is
            # removed when the context exits, so each file must be loaded
            # while its context is still open.
            package_files = resources.files(channel_lists)
            for name in ("H1_channels.toml", "L1_channels.toml"):
                with resources.as_file(package_files.joinpath(name)) as path:
                    self.metadata.load(path)
        self._channel_func_map = load_channel_funcs(self.metadata)

        self.scope_info = ScopeInfo(self.metadata.scopes, Retention())

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """add arguments for this backend to an argparse subparser"""
        parser.add_argument(
            "channel_files",
            metavar="CHANNELS.toml",
            nargs="*",
            type=Path,
            help="Channel definition TOML file (may be specified multiple times).",
        )

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> MockBackend:
        """initialize class from argparse namespace"""
        return cls(args.channel_files)

    def find(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> Iterable[Channel]:
        """Return the channels matching the given filters.

        This delegates directly to the channel metadata backend.
        """
        assert isinstance(self.metadata, ChannelConfigBackend)
        return self.metadata.find(
            pattern=pattern,
            data_type=data_type,
            min_rate=min_rate,
            max_rate=max_rate,
            publisher=publisher,
        )

    def describe(self, *, channels: Iterable[str]) -> Iterable[Channel]:
        """Return metadata for the named channels.

        Raises flight.FlightServerError if any channel is unknown.
        """
        # Materialize once: the names are iterated both by the validity
        # check and by the lookup, so a one-shot iterator would otherwise
        # be exhausted before the lookup.
        channels = list(channels)
        self._check_channels(channels)
        assert isinstance(self.metadata, ChannelConfigBackend)
        return self.metadata.describe(channels=channels)

    def stream(
        self, *, channels: Iterable[str], start: int, end: int
    ) -> Iterator[SeriesBlock]:
        """Return an iterator of synthetic data blocks for ``channels``.

        Raises flight.FlightServerError if any channel is unknown.
        """
        # Materialize once: the names are checked here and then iterated
        # again for every generated block.
        channels = list(channels)
        self._check_channels(channels)
        return self._generate_series(channels, start, end)

    def _check_channels(self, channels: Iterable[str]):
        """Raise flight.FlightServerError listing any unknown channel names."""
        bad_channels = [
            channel for channel in channels if channel not in self.metadata.metadata
        ]
        if bad_channels:
            # FIXME: is this the correct error to return?
            raise flight.FlightServerError(
                f"the following channels are not available on this server: {bad_channels}"  # noqa E501
            )

    def _generate_block(self, channels: Iterable[str], timestamp: int) -> SeriesBlock:
        """Generate one 1/16 s block for ``channels`` beginning at ``timestamp``.

        Each channel gets ``rate // 16`` samples. The samples are produced by
        the channel's configured function, evaluated at the sample times in
        seconds (``timestamp / Time.SECONDS`` plus the sample offsets), and
        cast to the channel's data type.
        """
        assert isinstance(self.metadata, ChannelConfigBackend)
        channel_data = {}
        channel_dict = {}
        for channel in channels:
            metadata = self.metadata.metadata[channel]
            rate = metadata.sample_rate
            assert rate is not None
            # NOTE(review): integer division truncates, so rates that are not
            # a multiple of 16 (or are below 16) lose samples in every block.
            size = rate // 16
            dtype = metadata.data_type
            time_array = (timestamp / Time.SECONDS) + numpy.arange(size) / rate
            func = self._channel_func_map[channel]
            # Broadcast so that constant expressions, which evaluate to a
            # scalar, still produce one value per sample.
            data = numpy.array(
                numpy.broadcast_to(func(time_array), time_array.shape),
                dtype=dtype,
            )
            channel_data[channel] = data
            channel_dict[channel] = metadata

        return SeriesBlock(timestamp, channel_data, channel_dict)

    def _generate_series(
        self,
        channels: Iterable[str],
        start: int | None,
        end: int | None,
    ) -> Iterator[SeriesBlock]:
        """Yield consecutive 1/16 s blocks from ``start`` until ``end``.

        A falsy ``start`` (None or 0) begins at the current GPS time, rounded
        down to a block boundary. A falsy ``end`` streams indefinitely. Once
        the generated timestamps catch up with the current GPS time, the
        generator sleeps between blocks to pace output like a live stream.
        ``channels`` must be re-iterable because it is walked once per block.
        """
        dt = Time.SECONDS // 16

        if start:
            current = start
        else:
            current = (int(gpstime.gpsnow() * Time.SECONDS) // dt) * dt

        while not end or current < end:
            yield self._generate_block(channels, current)
            current += dt
            now = int(gpstime.gpsnow() * Time.SECONDS)
            if current >= now:
                # Wait until the next block's start time to simulate a
                # live stream.
                time.sleep(max((current - now) / Time.SECONDS, 0))