Coverage for arrakis_server/server.py: 68.3%
224 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-12 16:39 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-12 16:39 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE
8import logging
9import threading
10from collections.abc import Iterable, Iterator
11from functools import wraps
12from urllib.parse import urlparse
14import numpy
15import pyarrow
16from arrakis import SeriesBlock
17from arrakis.flight import RequestType, RequestValidator, create_command, parse_command
18from pyarrow import flight
20from . import arrow, constants, schemas, traits
21from .channel import Channel
22from .scope import ScopeMap
24logger = logging.getLogger("arrakis")
def exception_catcher(func):
    """Decorator catching unhandled exceptions in FlightServer methods.

    Flight-specific errors propagate unchanged.  Any other exception is
    logged with full call context and converted into a FlightInternalError
    so that internal details are not leaked to the client.

    """

    @wraps(func)
    def wrapped(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except flight.FlightError:
            # already a well-formed Flight error; let it pass through
            raise
        except Exception:
            logger.exception(
                "internal server error: %s, %s, %s", func, args, kwargs
            )
            # FIXME: provide better admin contact info
            raise flight.FlightInternalError(
                "internal server error, please contact server admin"
            )

    return wrapped
56def parse_url(url: str | tuple[str, int] | None):
57 """Parse a URL into a valid location for the FlightServer
59 Returns a tuple of (hostname, port).
61 """
62 if url is None:
63 return None
64 if url in ["-", "0"]:
65 return None
66 if isinstance(url, tuple):
67 url = "//%s:%s" % url
68 parsed = urlparse(url, scheme="grpc")
69 if parsed.hostname is None:
70 parsed = urlparse("//" + url, scheme="grpc")
71 return parsed.hostname, parsed.port
class ArrakisFlightServer(flight.FlightServerBase):
    """Arrow Flight server implementation serving timeseries.

    Parameters
    ----------
    url : str, optional
        The URL at which to serve Flight requests.  Either a URI
        (e.g. grpc://localhost:port) or a (host, port) tuple.  If None,
        the server is started on localhost with a system-provided
        random port.  Default is to bind to all available interfaces
        on port 31206.
    backend : ServerBackend, optional
        An instantiated backend providing data to serve.
    scope_map : ScopeMap, optional
        Scope map for available flight endpoints.

    """

    def __init__(
        self,
        url: str | tuple[str, int] | None = constants.DEFAULT_LOCATION,
        backend: traits.ServerBackend | traits.PublishServerBackend | None = None,
        scope_map: ScopeMap | None = None,
        **kwargs,
    ):
        self._location = parse_url(url)
        self._is_stopped = threading.Event()
        self._validator = RequestValidator()

        # with neither a backend nor a scope map there is nothing to serve
        if not backend and not scope_map:
            raise ValueError("nothing to serve, must specify scope map and/or backend")

        self._backend = backend
        logger.info("backend: %s", self._backend)

        # merge any locally-served channels into the scope map so this
        # server advertises itself alongside remote endpoints
        self._scope_map = scope_map or ScopeMap()
        self._scope_map.sync_local_map(self._backend, constants.FLIGHT_REUSE_URL)
        logger.info("scope map: %s", self._scope_map)

        super().__init__(self._location, **kwargs)
        # the base class resolves the port actually bound; record the
        # final (host, port) pair so the .url property reports the live
        # address
        if self._location is None:
            self._location = ("127.0.0.1", self.port)
        else:
            self._location = (self._location[0], self.port)
        logger.info("URL: %s", self.url)
120 @property
121 def url(self):
122 return "grpc://%s:%s" % self._location
124 @exception_catcher
125 def list_flights(
126 self, context: flight.ServerCallContext, criteria: bytes
127 ) -> Iterator[flight.FlightInfo]:
128 """List flights available on this service.
130 Parameters
131 ----------
132 context : ServerCallContext
133 Common contextual information.
134 criteria : bytes
135 Filter criteria provided by the client.
137 Yields
138 ------
139 FlightInfo
141 """
142 logger.debug("serving list_flights for %s", context.peer())
143 for channel in self._channels:
144 yield self.make_flight_info(
145 create_command(
146 RequestType.Stream,
147 channels=[channel.name],
148 validator=self._validator,
149 )
150 )
152 @exception_catcher
153 def get_flight_info(
154 self, context: flight.ServerCallContext, descriptor: flight.FlightDescriptor
155 ) -> flight.FlightInfo:
156 """Get information about a flight.
158 Parameters
159 ----------
160 context : ServerCallContext
161 Common contextual information.
162 descriptor : FlightDescriptor
163 The descriptor for the flight provided by the client.
165 Returns
166 -------
167 FlightInfo
169 """
170 logger.debug("serving get_flight_info for %s", context.peer())
171 return self.make_flight_info(descriptor.command)
173 @exception_catcher
174 def do_exchange(
175 self,
176 context: flight.ServerCallContext,
177 descriptor: flight.FlightDescriptor,
178 reader: flight.MetadataRecordBatchReader,
179 writer: flight.MetadataRecordBatchWriter,
180 ) -> flight.FlightDataStream:
181 """Write data to a flight.
183 Parameters
184 ----------
185 context : ServerCallContext
186 Common contextual information.
187 ticket : Ticket
188 The ticket for the flight.
190 Returns
191 -------
192 FlightDataStream
193 A stream of data to send back to the client.
195 """
196 if not self._backend:
197 raise flight.FlightServerError(
198 "DoExchange requests unavailable from this server"
199 )
200 logger.debug("serving DoExchange request for %s", context.peer())
201 return self.process_exchange_request(descriptor, reader, writer)
203 def process_exchange_request(
204 self,
205 descriptor: flight.FlightDescriptor,
206 reader: flight.MetadataRecordBatchReader,
207 writer: flight.MetadataRecordBatchWriter,
208 ) -> flight.FlightDataStream:
209 """Write data to a flight.
211 Parameters
212 ----------
213 context : ServerCallContext
214 Common contextual information.
215 ticket : Ticket
216 The ticket for the flight.
218 Returns
219 -------
220 FlightDataStream
221 A stream of data to send back to the client.
223 """
224 assert self._backend
225 request, kwargs = parse_command(descriptor.command, validator=self._validator)
226 logger.debug("serving DoExchange %s request", request.name)
227 match request:
228 case RequestType.Partition:
229 if not traits.can_publish(self._backend):
230 raise flight.FlightError(
231 "partition not supported for server backend"
232 )
233 return self._partition(reader, writer, **kwargs)
234 case _:
235 raise flight.FlightError("request type not valid")
237 @exception_catcher
238 def do_get(
239 self, context: flight.ServerCallContext, ticket: flight.Ticket
240 ) -> flight.FlightDataStream:
241 """Write data to a flight.
243 Parameters
244 ----------
245 context : ServerCallContext
246 Common contextual information.
247 ticket : Ticket
248 The ticket for the flight.
250 Returns
251 -------
252 FlightDataStream
253 A stream of data to send back to the client.
255 """
256 if not self._backend:
257 raise flight.FlightServerError(
258 "DoGet requests unavailable from this server"
259 )
260 logger.debug("serving DoGet request for %s", context.peer())
261 return self.process_get_request(context, ticket)
263 def process_get_request(
264 self, context: flight.ServerCallContext | None, ticket: flight.Ticket
265 ) -> flight.FlightDataStream:
266 """Write data to a flight.
268 Parameters
269 ----------
270 context : ServerCallContext
271 Common contextual information.
272 ticket : Ticket
273 The ticket for the flight.
275 Returns
276 -------
277 FlightDataStream
278 A stream of data to send back to the client.
280 """
281 assert self._backend
282 request, kwargs = parse_command(ticket.ticket, validator=self._validator)
283 logger.debug("serving DoGet %s request", request.name)
284 match request:
285 case RequestType.Stream:
286 return self._stream(context, **kwargs)
287 case RequestType.Describe:
288 return self._describe(**kwargs)
289 case RequestType.Find:
290 return self._find(**kwargs)
291 case RequestType.Count:
292 return self._count(**kwargs)
293 case RequestType.Publish:
294 if not traits.can_publish(self._backend):
295 raise flight.FlightError("publish not supported for this backend")
296 return self._publish(**kwargs)
297 case _:
298 raise flight.FlightServerError("request type not valid")
300 @exception_catcher
301 def list_actions(
302 self, context: flight.ServerCallContext
303 ) -> Iterable[tuple[str, str]]:
304 """List custom actions available on this server.
306 Parameters
307 ----------
308 context : ServerCallContext
309 Common contextual information.
311 Returns
312 -------
313 Iterable of 2-tuples in the form (command, description).
315 """
316 if traits.can_publish(self._backend):
317 logger.debug("serving list_actions for %s", context.peer())
318 return [("publish", "Request to publish data.")]
319 else:
320 return []
322 @exception_catcher
323 def do_action(
324 self, context: flight.ServerCallContext, action: flight.Action
325 ) -> Iterator[bytes]:
326 """Execute a custom action.
328 Parameters
329 ----------
330 context : ServerCallContext
331 Common contextual information.
332 action : Action
333 The action to execute.
335 Yields
336 ------
337 bytes
339 """
340 logger.debug("serving %s action for %s", action.type, context.peer())
341 match action.type:
342 case _:
343 raise flight.FlightError("action not valid")
345 def _construct_endpoints(self, cmd: bytes) -> list[flight.FlightEndpoint]:
346 endpoints = []
347 request, kwargs = parse_command(cmd, validator=self._validator)
349 # publish requests can only be served directly to local endpoints,
350 # and we have validated that this server can accept publish requests
351 if request in {RequestType.Publish, RequestType.Partition}:
352 return [flight.FlightEndpoint(cmd, [constants.FLIGHT_REUSE_URL])]
354 # filter source map by time retention
355 if request is RequestType.Stream:
356 scope_map = self._scope_map.filter_by_range(
357 start=kwargs["start"],
358 end=kwargs["end"],
359 )
360 else:
361 scope_map = self._scope_map
363 # map channels to endpoints
364 if request in {RequestType.Stream, RequestType.Describe}:
365 endpoints_for_channels = scope_map.endpoints_for_channels(
366 kwargs["channels"]
367 )
368 for channels, locations in endpoints_for_channels:
369 kwargs["channels"] = channels
370 ticket = create_command(request, validator=self._validator, **kwargs)
371 endpoints.append(flight.FlightEndpoint(ticket, locations))
373 elif request in {RequestType.Find, RequestType.Count}:
374 ticket = create_command(request, validator=self._validator, **kwargs)
375 # ultimately we want to look at the request and use that to pair
376 # down the endpoints we return.
377 # for now just make sure to only send unique locations back
378 location_set = set()
379 for domain in scope_map.domains:
380 locations = scope_map.endpoints_for_domain(domain)
381 for entry in locations:
382 location_set.add(entry)
383 for entry in location_set:
384 endpoints.append(flight.FlightEndpoint(ticket, [entry]))
386 else:
387 raise flight.FlightError(f"Unknown request: {request}")
389 if not endpoints:
390 raise flight.FlightError("Could not find channels on any known endpoints.")
392 return endpoints
394 def make_flight_info(self, cmd: bytes) -> flight.FlightInfo:
395 """Create Arrow Flight stream descriptions from commands.
397 Parameters
398 ----------
399 cmd : bytes
400 The opaque command to parse.
402 Returns
403 -------
404 flight.FlightInfo
405 The Arrow Flight stream description describing the command.
407 """
408 endpoints = self._construct_endpoints(cmd)
409 descriptor = flight.FlightDescriptor.for_path(cmd)
411 request, args = parse_command(cmd, validator=self._validator)
413 match request:
414 case RequestType.Stream:
415 if self._backend:
416 channels = self._backend.describe(channels=args["channels"])
417 else:
418 # create dummy channel metadata for each channel as the client
419 # does not use this when making a request. by doing this,
420 # information servers do not need to store all metadata for the
421 # channels for a domain it could potentially serve, but instead
422 # can delegate to the endpoints it does know about
423 channels = [
424 Channel(name, data_type=numpy.dtype("int32"), sample_rate=32)
425 for name in args["channels"]
426 ]
427 schema = schemas.stream(channels)
428 case RequestType.Describe:
429 schema = schemas.describe()
430 case RequestType.Find:
431 schema = schemas.find()
432 case RequestType.Count:
433 schema = schemas.count()
434 case RequestType.Publish:
435 if not traits.can_publish(self._backend):
436 raise flight.FlightError("publish not supported for server backend")
437 schema = schemas.publish()
438 case RequestType.Partition:
439 if not traits.can_publish(self._backend):
440 raise flight.FlightError(
441 "partition not supported for server backend"
442 )
443 schema = schemas.partition()
444 case _:
445 raise flight.FlightError("command not understood")
447 return flight.FlightInfo(schema, descriptor, endpoints, -1, -1)
    def shutdown(self) -> None:
        """Shut down the server."""
        # wake any wait_until_shutdown() callers and signal in-flight
        # stream generators to stop before the base class tears down
        self._is_stopped.set()
        return super().shutdown()
    def wait_until_shutdown(self) -> None:
        """Wait until the server receives a shutdown request."""
        # blocks until shutdown() sets the event
        self._is_stopped.wait()
458 def _find(
459 self,
460 *,
461 pattern: str,
462 data_type: list[str],
463 min_rate: int,
464 max_rate: int,
465 publisher: list[str],
466 ) -> flight.FlightDataStream:
467 """Serve Flight data for the 'find' route."""
468 assert isinstance(self._backend, traits.ServerBackend)
469 metadata = self._backend.find(
470 pattern=pattern,
471 data_type=data_type,
472 min_rate=min_rate,
473 max_rate=max_rate,
474 publisher=publisher,
475 )
476 return arrow.create_metadata_stream(metadata)
478 def _count(
479 self,
480 *,
481 pattern: str,
482 data_type: list[str],
483 min_rate: int,
484 max_rate: int,
485 publisher: list[str],
486 ) -> flight.FlightDataStream:
487 """Serve Flight data for the 'count' route."""
488 assert isinstance(self._backend, traits.ServerBackend)
489 count = self._backend.count(
490 pattern=pattern,
491 data_type=data_type,
492 min_rate=min_rate,
493 max_rate=max_rate,
494 publisher=publisher,
495 )
496 schema = schemas.count()
497 batch = pyarrow.RecordBatch.from_arrays(
498 [
499 pyarrow.array(
500 [count],
501 type=schema.field("count").type,
502 ),
503 ],
504 schema=schema,
505 )
506 return flight.RecordBatchStream(
507 pyarrow.RecordBatchReader.from_batches(schema, [batch])
508 )
    def _describe(self, *, channels: Iterable[str]) -> flight.FlightDataStream:
        """Serve Flight data for the 'describe' route.

        Streams backend-provided metadata for the requested channels.
        """
        assert isinstance(self._backend, traits.ServerBackend)
        metadata = self._backend.describe(channels=channels)
        return arrow.create_metadata_stream(metadata)
516 def _stream(
517 self,
518 context: flight.ServerCallContext | None,
519 *,
520 channels: Iterable[str],
521 start: int,
522 end: int,
523 ) -> flight.FlightDataStream:
524 """Serve Flight data for the 'stream' route."""
525 assert isinstance(self._backend, traits.ServerBackend)
527 metadata = self._backend.describe(channels=channels)
528 schema = schemas.stream(metadata)
529 blocks = self._backend.stream(channels=channels, start=start, end=end)
530 return flight.GeneratorStream(
531 schema,
532 self._generate_stream(context, schema, blocks),
533 )
535 def _publish(self, *, publisher_id: str) -> flight.FlightDataStream:
536 """Serve Flight data for the 'publish' route."""
537 assert traits.can_publish(self._backend)
538 schema = schemas.publish()
539 info = self._backend.publish(publisher_id=publisher_id)
540 batch = pyarrow.RecordBatch.from_arrays(
541 [
542 pyarrow.array(
543 [info],
544 type=schema.field("properties").type,
545 ),
546 ],
547 schema=schema,
548 )
549 return flight.RecordBatchStream(
550 pyarrow.RecordBatchReader.from_batches(schema, [batch])
551 )
553 def _partition(self, reader, writer, *, publisher_id: str) -> None:
554 """Exchange Flight data for the 'partition' route."""
555 assert traits.can_publish(self._backend)
556 schema = schemas.partition()
558 # read metadata from client
559 channels = []
560 for batch in arrow.read_all_chunks(reader):
561 for meta in batch.to_pylist():
562 data_type = numpy.dtype(meta["data_type"])
563 channel = Channel(
564 meta["channel"],
565 sample_rate=meta["sample_rate"],
566 data_type=data_type,
567 publisher=publisher_id,
568 )
569 channels.append(channel)
571 # partition channels
572 channels = list(
573 self._backend.partition(channels=channels, publisher_id=publisher_id)
574 )
576 # prepare the batch with mappings
577 batches = arrow.create_partition_batches(channels)
579 # send partitions back to the client
580 writer.begin(schema)
581 for batch in batches:
582 writer.write_batch(batch)
583 writer.close()
585 def _generate_stream(
586 self,
587 context: flight.ServerCallContext | None,
588 schema: pyarrow.Schema,
589 blocks: Iterator[SeriesBlock],
590 ) -> Iterator[pyarrow.RecordBatch]:
591 """Generate a record batch stream which can be stopped."""
592 for block in blocks:
593 if self._is_stopped.is_set() or (context and context.is_cancelled()):
594 return
595 yield block.to_column_batch(schema)