Coverage for arrakis_server/backends/mock/__init__.py: 61.7%
94 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-12 16:39 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-12 16:39 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE
8from __future__ import annotations
10import argparse
11import logging
12import time
13from collections.abc import Callable, Iterable, Iterator
14from importlib import resources
15from pathlib import Path
16from typing import TypeAlias
18import gpstime
19import numpy
20import pyarrow
21from arrakis import SeriesBlock, Time
22from pyarrow import flight
23from sympy import lambdify, parse_expr
24from sympy.abc import t
26from ...channel import Channel
27from ...metadata import ChannelConfigBackend
28from ...scope import Retention, ScopeInfo
29from ...traits import ServerBackend
30from . import channels as channel_lists
# Logger named "arrakis" (a package-level name rather than this module's
# __name__).
logger = logging.getLogger("arrakis")

# Signature of a per-channel data generator: it is called with the block's
# time array (timestamps converted to seconds) and returns the sample values
# for that block.
ArrayTransform: TypeAlias = Callable[[numpy.ndarray], numpy.ndarray]
def _func_random_normal(t):
    """Return standard-normal noise with one sample per element of ``t``.

    Only the length of ``t`` matters; its values are ignored. Draws come
    from numpy's global random state.
    """
    sample_count = len(t)
    return numpy.random.normal(loc=0.0, scale=1.0, size=sample_count)
def load_channel_funcs(metadata: ChannelConfigBackend) -> dict[str, ArrayTransform]:
    """Build a data-generating function for every channel in ``metadata``.

    Returns a dictionary mapping channel name to a callable that is given the
    block time array as its single argument and returns the sample values
    for that block.

    Channels are defined in TOML tables (already loaded into ``metadata``),
    with the channel name in the table header. The table should include:

        `rate` in samples per second
        `dtype` as a python dtype
        `func` an optional function to generate the data; it will be given
            the block time array as its single argument. Any sympy
            expression whose only free symbol is `t` is valid.
            (standard-normal noise from numpy.random.normal is used by
            default)

    example:

        ["MY:CHANNEL-NAME"]
        rate = 16384
        dtype = "float32"
        func = "3*t + cos(t)"

    Raises ValueError if a channel's `func` uses a free symbol other than `t`.
    """
    channel_func_map: dict[str, ArrayTransform] = {}
    for channel_name, meta in metadata.extra.items():
        if "func" not in meta:
            channel_func_map[channel_name] = _func_random_normal
            continue
        # NOTE: sympy's parse_expr evaluates its input with eval(), so channel
        # definition files must come from a trusted source.
        expr = parse_expr(meta["func"])
        # lambdify would accept other symbols, but the generated function
        # would then raise NameError only once streaming starts; reject them
        # here so the failure names the offending channel.
        unknown = expr.free_symbols - {t}
        if unknown:
            names = ", ".join(sorted(str(symbol) for symbol in unknown))
            raise ValueError(
                f"func for channel {channel_name!r} uses unknown symbol(s): {names}"
            )
        channel_func_map[channel_name] = lambdify(t, expr, "numpy")
    return channel_func_map
class MockBackend(ServerBackend):
    """Mock server backend that generates synthetic timeseries data.

    If channel definition files are not provided then the Mock backend
    will serve a pre-defined set of H1: and L1: channels.

    Data is produced in consecutive blocks lasting 1/_BLOCKS_PER_SECOND
    seconds; each channel's samples come from the function built for it by
    ``load_channel_funcs``.
    """

    # Number of blocks per second. It sets both the timestamp step between
    # blocks and the samples per block (rate // _BLOCKS_PER_SECOND), so the
    # two must come from the same constant.
    _BLOCKS_PER_SECOND = 16

    # Channel definition files bundled in the `channels` subpackage, served
    # when no channel files are given.
    _DEFAULT_CHANNEL_FILES = ("H1_channels.toml", "L1_channels.toml")

    def __init__(self, channel_files: list[Path] | None = None):
        """Initialize mock server with list of channel definition files.

        When ``channel_files`` is None or empty, the bundled H1 and L1
        channel lists are loaded instead.
        """
        self.metadata = ChannelConfigBackend()
        if channel_files:
            for channel_file in channel_files:
                self.metadata.load(channel_file)
        else:
            self._load_default_channels()
        self._channel_func_map = load_channel_funcs(self.metadata)

        self.scope_info = ScopeInfo(self.metadata.scopes, Retention())

    def _load_default_channels(self) -> None:
        """Load the channel lists bundled with this package into metadata."""
        package_files = resources.files(channel_lists)
        for file_name in self._DEFAULT_CHANNEL_FILES:
            # as_file() may provide a temporary copy (e.g. for zipped
            # installs) that is deleted when the context exits, so the file
            # must be loaded while the context is still open.
            with resources.as_file(package_files.joinpath(file_name)) as path:
                self.metadata.load(path)

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """add arguments for this backend to an argparse subparser"""
        parser.add_argument(
            "channel_files",
            metavar="CHANNELS.toml",
            nargs="*",
            type=Path,
            help="Channel definition TOML file (may be specified multiple times).",
        )

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> MockBackend:
        """initialize class from argparse namespace"""
        return cls(args.channel_files)

    def find(
        self,
        *,
        pattern: str,
        data_type: list[str],
        min_rate: int,
        max_rate: int,
        publisher: list[str],
    ) -> Iterable[Channel]:
        """Return channels matching the given filters (delegates to metadata)."""
        assert isinstance(self.metadata, ChannelConfigBackend)
        return self.metadata.find(
            pattern=pattern,
            data_type=data_type,
            min_rate=min_rate,
            max_rate=max_rate,
            publisher=publisher,
        )

    def describe(self, *, channels: Iterable[str]) -> Iterable[Channel]:
        """Return metadata for ``channels``.

        Raises flight.FlightServerError if any channel is unknown.
        """
        # Materialize once: the channels are iterated by both the
        # availability check and the metadata lookup, which a one-shot
        # iterator could only serve once.
        channels = list(channels)
        self._check_channels(channels)
        assert isinstance(self.metadata, ChannelConfigBackend)
        return self.metadata.describe(channels=channels)

    def stream(
        self, *, channels: Iterable[str], start: int, end: int
    ) -> Iterator[SeriesBlock]:
        """Return an iterator of synthetic blocks for ``channels``.

        Unknown channels raise flight.FlightServerError immediately, before
        any block is generated.
        """
        # Materialize once: the channels are checked here and then iterated
        # again for every generated block.
        channels = list(channels)
        self._check_channels(channels)
        return self._generate_series(channels, start, end)

    def _check_channels(self, channels: Iterable[str]):
        """Raise flight.FlightServerError listing any unknown channels."""
        bad_channels = [
            channel for channel in channels if channel not in self.metadata.metadata
        ]
        if bad_channels:
            # FIXME: is this the correct error to return?
            raise flight.FlightServerError(
                f"the following channels are not available on this server: {bad_channels}"  # noqa E501
            )

    def _generate_block(self, channels: Iterable[str], timestamp: int) -> SeriesBlock:
        """Generate one block of data for ``channels`` starting at ``timestamp``.

        ``timestamp`` is an integer in units of 1/Time.SECONDS of a second.
        Each channel gets ``rate // _BLOCKS_PER_SECOND`` samples, evaluated
        at evenly spaced times (in seconds) from the block start.
        """
        assert isinstance(self.metadata, ChannelConfigBackend)
        channel_data = {}
        channel_dict = {}
        start_seconds = timestamp / Time.SECONDS
        for channel in channels:
            metadata = self.metadata.metadata[channel]
            rate = metadata.sample_rate
            assert rate is not None
            size = rate // self._BLOCKS_PER_SECOND
            time_array = start_seconds + numpy.arange(size) / rate
            func = self._channel_func_map[channel]
            # A func that does not depend on t (e.g. "5") returns a scalar;
            # broadcast so every channel yields `size` samples.
            data = numpy.array(
                numpy.broadcast_to(func(time_array), time_array.shape),
                dtype=metadata.data_type,
            )
            channel_data[channel] = data
            channel_dict[channel] = metadata

        return SeriesBlock(timestamp, channel_data, channel_dict)

    def _generate_series(
        self,
        channels: Iterable[str],
        start: int | None,
        end: int | None,
    ) -> Iterator[SeriesBlock]:
        """Yield consecutive blocks from ``start`` up to (excluding) ``end``.

        A falsy ``start`` (None or 0) begins at the current GPS time rounded
        down to a block boundary; a falsy ``end`` (None or 0) streams
        indefinitely. Once generation reaches the current time, the generator
        sleeps between blocks to simulate a live stream.
        """
        dt = Time.SECONDS // self._BLOCKS_PER_SECOND

        if start:
            current = start
        else:
            current = (int(gpstime.gpsnow() * Time.SECONDS) // dt) * dt

        while not end or current < end:
            yield self._generate_block(channels, current)
            current += dt
            now = int(gpstime.gpsnow() * Time.SECONDS)
            if current >= now:
                # sleep until the next block's start time to simulate a
                # live stream
                time.sleep((current - now) / Time.SECONDS)