Coverage for arrakis_server/server.py: 68.3%

224 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-08-12 16:39 -0700

# Copyright (c) 2022, California Institute of Technology and contributors
#
# You should have received a copy of the licensing terms for this
# software included in the file "LICENSE" located in the top-level
# directory of this package. If you did not, you can view a copy at
# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE

7 

8import logging 

9import threading 

10from collections.abc import Iterable, Iterator 

11from functools import wraps 

12from urllib.parse import urlparse 

13 

14import numpy 

15import pyarrow 

16from arrakis import SeriesBlock 

17from arrakis.flight import RequestType, RequestValidator, create_command, parse_command 

18from pyarrow import flight 

19 

20from . import arrow, constants, schemas, traits 

21from .channel import Channel 

22from .scope import ScopeMap 

23 

# Module-level logger shared by the server implementation.
logger = logging.getLogger("arrakis")

25 

26 

def exception_catcher(func):
    """Decorator to catch uncaught exceptions in FlightServer methods.

    Known Flight errors (``flight.FlightError``) propagate unchanged so
    the client receives the intended status.  Any other exception is
    logged server-side (with the function and arguments that triggered
    it) and re-raised as a generic ``FlightInternalError`` so internal
    details are not leaked to the client.

    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except flight.FlightError:
            # deliberate Flight errors carry their own client-facing message
            raise
        except Exception:
            logger.exception(
                "internal server error: %s, %s, %s",
                func,
                args,
                kwargs,
            )
            # FIXME: provide better admin contact info
            raise flight.FlightInternalError(
                "internal server error, please contact server admin"
            )

    return wrapper

54 

55 

56def parse_url(url: str | tuple[str, int] | None): 

57 """Parse a URL into a valid location for the FlightServer 

58 

59 Returns a tuple of (hostname, port). 

60 

61 """ 

62 if url is None: 

63 return None 

64 if url in ["-", "0"]: 

65 return None 

66 if isinstance(url, tuple): 

67 url = "//%s:%s" % url 

68 parsed = urlparse(url, scheme="grpc") 

69 if parsed.hostname is None: 

70 parsed = urlparse("//" + url, scheme="grpc") 

71 return parsed.hostname, parsed.port 

72 

73 

74class ArrakisFlightServer(flight.FlightServerBase): 

75 """Arrow Flight server implementation to server timeseries. 

76 

77 Parameters 

78 ---------- 

79 url : str, optional 

80 The URL at which to serve Flight requests. Either URI 

81 (e.g. grpc://localhost:port) or (host, port) tuple. If None 

82 then server will be started on localport with a 

83 system-provided random port. Default is to bind to all 

84 available interfaces on port 31206. 

85 backend: ServerBackend, optional 

86 An instantiated backend providing data to serve. 

87 scope_map: ScopeMap, optional 

88 Scope map for available flight endpoints. 

89 

90 """ 

91 

92 def __init__( 

93 self, 

94 url: str | tuple[str, int] | None = constants.DEFAULT_LOCATION, 

95 backend: traits.ServerBackend | traits.PublishServerBackend | None = None, 

96 scope_map: ScopeMap | None = None, 

97 **kwargs, 

98 ): 

99 self._location = parse_url(url) 

100 self._is_stopped = threading.Event() 

101 self._validator = RequestValidator() 

102 

103 if not backend and not scope_map: 

104 raise ValueError("nothing to serve, must specify scope map and/or backend") 

105 

106 self._backend = backend 

107 logger.info("backend: %s", self._backend) 

108 

109 self._scope_map = scope_map or ScopeMap() 

110 self._scope_map.sync_local_map(self._backend, constants.FLIGHT_REUSE_URL) 

111 logger.info("scope map: %s", self._scope_map) 

112 

113 super().__init__(self._location, **kwargs) 

114 if self._location is None: 

115 self._location = ("127.0.0.1", self.port) 

116 else: 

117 self._location = (self._location[0], self.port) 

118 logger.info("URL: %s", self.url) 

119 

120 @property 

121 def url(self): 

122 return "grpc://%s:%s" % self._location 

123 

124 @exception_catcher 

125 def list_flights( 

126 self, context: flight.ServerCallContext, criteria: bytes 

127 ) -> Iterator[flight.FlightInfo]: 

128 """List flights available on this service. 

129 

130 Parameters 

131 ---------- 

132 context : ServerCallContext 

133 Common contextual information. 

134 criteria : bytes 

135 Filter criteria provided by the client. 

136 

137 Yields 

138 ------ 

139 FlightInfo 

140 

141 """ 

142 logger.debug("serving list_flights for %s", context.peer()) 

143 for channel in self._channels: 

144 yield self.make_flight_info( 

145 create_command( 

146 RequestType.Stream, 

147 channels=[channel.name], 

148 validator=self._validator, 

149 ) 

150 ) 

151 

152 @exception_catcher 

153 def get_flight_info( 

154 self, context: flight.ServerCallContext, descriptor: flight.FlightDescriptor 

155 ) -> flight.FlightInfo: 

156 """Get information about a flight. 

157 

158 Parameters 

159 ---------- 

160 context : ServerCallContext 

161 Common contextual information. 

162 descriptor : FlightDescriptor 

163 The descriptor for the flight provided by the client. 

164 

165 Returns 

166 ------- 

167 FlightInfo 

168 

169 """ 

170 logger.debug("serving get_flight_info for %s", context.peer()) 

171 return self.make_flight_info(descriptor.command) 

172 

173 @exception_catcher 

174 def do_exchange( 

175 self, 

176 context: flight.ServerCallContext, 

177 descriptor: flight.FlightDescriptor, 

178 reader: flight.MetadataRecordBatchReader, 

179 writer: flight.MetadataRecordBatchWriter, 

180 ) -> flight.FlightDataStream: 

181 """Write data to a flight. 

182 

183 Parameters 

184 ---------- 

185 context : ServerCallContext 

186 Common contextual information. 

187 ticket : Ticket 

188 The ticket for the flight. 

189 

190 Returns 

191 ------- 

192 FlightDataStream 

193 A stream of data to send back to the client. 

194 

195 """ 

196 if not self._backend: 

197 raise flight.FlightServerError( 

198 "DoExchange requests unavailable from this server" 

199 ) 

200 logger.debug("serving DoExchange request for %s", context.peer()) 

201 return self.process_exchange_request(descriptor, reader, writer) 

202 

203 def process_exchange_request( 

204 self, 

205 descriptor: flight.FlightDescriptor, 

206 reader: flight.MetadataRecordBatchReader, 

207 writer: flight.MetadataRecordBatchWriter, 

208 ) -> flight.FlightDataStream: 

209 """Write data to a flight. 

210 

211 Parameters 

212 ---------- 

213 context : ServerCallContext 

214 Common contextual information. 

215 ticket : Ticket 

216 The ticket for the flight. 

217 

218 Returns 

219 ------- 

220 FlightDataStream 

221 A stream of data to send back to the client. 

222 

223 """ 

224 assert self._backend 

225 request, kwargs = parse_command(descriptor.command, validator=self._validator) 

226 logger.debug("serving DoExchange %s request", request.name) 

227 match request: 

228 case RequestType.Partition: 

229 if not traits.can_publish(self._backend): 

230 raise flight.FlightError( 

231 "partition not supported for server backend" 

232 ) 

233 return self._partition(reader, writer, **kwargs) 

234 case _: 

235 raise flight.FlightError("request type not valid") 

236 

237 @exception_catcher 

238 def do_get( 

239 self, context: flight.ServerCallContext, ticket: flight.Ticket 

240 ) -> flight.FlightDataStream: 

241 """Write data to a flight. 

242 

243 Parameters 

244 ---------- 

245 context : ServerCallContext 

246 Common contextual information. 

247 ticket : Ticket 

248 The ticket for the flight. 

249 

250 Returns 

251 ------- 

252 FlightDataStream 

253 A stream of data to send back to the client. 

254 

255 """ 

256 if not self._backend: 

257 raise flight.FlightServerError( 

258 "DoGet requests unavailable from this server" 

259 ) 

260 logger.debug("serving DoGet request for %s", context.peer()) 

261 return self.process_get_request(context, ticket) 

262 

263 def process_get_request( 

264 self, context: flight.ServerCallContext | None, ticket: flight.Ticket 

265 ) -> flight.FlightDataStream: 

266 """Write data to a flight. 

267 

268 Parameters 

269 ---------- 

270 context : ServerCallContext 

271 Common contextual information. 

272 ticket : Ticket 

273 The ticket for the flight. 

274 

275 Returns 

276 ------- 

277 FlightDataStream 

278 A stream of data to send back to the client. 

279 

280 """ 

281 assert self._backend 

282 request, kwargs = parse_command(ticket.ticket, validator=self._validator) 

283 logger.debug("serving DoGet %s request", request.name) 

284 match request: 

285 case RequestType.Stream: 

286 return self._stream(context, **kwargs) 

287 case RequestType.Describe: 

288 return self._describe(**kwargs) 

289 case RequestType.Find: 

290 return self._find(**kwargs) 

291 case RequestType.Count: 

292 return self._count(**kwargs) 

293 case RequestType.Publish: 

294 if not traits.can_publish(self._backend): 

295 raise flight.FlightError("publish not supported for this backend") 

296 return self._publish(**kwargs) 

297 case _: 

298 raise flight.FlightServerError("request type not valid") 

299 

300 @exception_catcher 

301 def list_actions( 

302 self, context: flight.ServerCallContext 

303 ) -> Iterable[tuple[str, str]]: 

304 """List custom actions available on this server. 

305 

306 Parameters 

307 ---------- 

308 context : ServerCallContext 

309 Common contextual information. 

310 

311 Returns 

312 ------- 

313 Iterable of 2-tuples in the form (command, description). 

314 

315 """ 

316 if traits.can_publish(self._backend): 

317 logger.debug("serving list_actions for %s", context.peer()) 

318 return [("publish", "Request to publish data.")] 

319 else: 

320 return [] 

321 

322 @exception_catcher 

323 def do_action( 

324 self, context: flight.ServerCallContext, action: flight.Action 

325 ) -> Iterator[bytes]: 

326 """Execute a custom action. 

327 

328 Parameters 

329 ---------- 

330 context : ServerCallContext 

331 Common contextual information. 

332 action : Action 

333 The action to execute. 

334 

335 Yields 

336 ------ 

337 bytes 

338 

339 """ 

340 logger.debug("serving %s action for %s", action.type, context.peer()) 

341 match action.type: 

342 case _: 

343 raise flight.FlightError("action not valid") 

344 

345 def _construct_endpoints(self, cmd: bytes) -> list[flight.FlightEndpoint]: 

346 endpoints = [] 

347 request, kwargs = parse_command(cmd, validator=self._validator) 

348 

349 # publish requests can only be served directly to local endpoints, 

350 # and we have validated that this server can accept publish requests 

351 if request in {RequestType.Publish, RequestType.Partition}: 

352 return [flight.FlightEndpoint(cmd, [constants.FLIGHT_REUSE_URL])] 

353 

354 # filter source map by time retention 

355 if request is RequestType.Stream: 

356 scope_map = self._scope_map.filter_by_range( 

357 start=kwargs["start"], 

358 end=kwargs["end"], 

359 ) 

360 else: 

361 scope_map = self._scope_map 

362 

363 # map channels to endpoints 

364 if request in {RequestType.Stream, RequestType.Describe}: 

365 endpoints_for_channels = scope_map.endpoints_for_channels( 

366 kwargs["channels"] 

367 ) 

368 for channels, locations in endpoints_for_channels: 

369 kwargs["channels"] = channels 

370 ticket = create_command(request, validator=self._validator, **kwargs) 

371 endpoints.append(flight.FlightEndpoint(ticket, locations)) 

372 

373 elif request in {RequestType.Find, RequestType.Count}: 

374 ticket = create_command(request, validator=self._validator, **kwargs) 

375 # ultimately we want to look at the request and use that to pair 

376 # down the endpoints we return. 

377 # for now just make sure to only send unique locations back 

378 location_set = set() 

379 for domain in scope_map.domains: 

380 locations = scope_map.endpoints_for_domain(domain) 

381 for entry in locations: 

382 location_set.add(entry) 

383 for entry in location_set: 

384 endpoints.append(flight.FlightEndpoint(ticket, [entry])) 

385 

386 else: 

387 raise flight.FlightError(f"Unknown request: {request}") 

388 

389 if not endpoints: 

390 raise flight.FlightError("Could not find channels on any known endpoints.") 

391 

392 return endpoints 

393 

394 def make_flight_info(self, cmd: bytes) -> flight.FlightInfo: 

395 """Create Arrow Flight stream descriptions from commands. 

396 

397 Parameters 

398 ---------- 

399 cmd : bytes 

400 The opaque command to parse. 

401 

402 Returns 

403 ------- 

404 flight.FlightInfo 

405 The Arrow Flight stream description describing the command. 

406 

407 """ 

408 endpoints = self._construct_endpoints(cmd) 

409 descriptor = flight.FlightDescriptor.for_path(cmd) 

410 

411 request, args = parse_command(cmd, validator=self._validator) 

412 

413 match request: 

414 case RequestType.Stream: 

415 if self._backend: 

416 channels = self._backend.describe(channels=args["channels"]) 

417 else: 

418 # create dummy channel metadata for each channel as the client 

419 # does not use this when making a request. by doing this, 

420 # information servers do not need to store all metadata for the 

421 # channels for a domain it could potentially serve, but instead 

422 # can delegate to the endpoints it does know about 

423 channels = [ 

424 Channel(name, data_type=numpy.dtype("int32"), sample_rate=32) 

425 for name in args["channels"] 

426 ] 

427 schema = schemas.stream(channels) 

428 case RequestType.Describe: 

429 schema = schemas.describe() 

430 case RequestType.Find: 

431 schema = schemas.find() 

432 case RequestType.Count: 

433 schema = schemas.count() 

434 case RequestType.Publish: 

435 if not traits.can_publish(self._backend): 

436 raise flight.FlightError("publish not supported for server backend") 

437 schema = schemas.publish() 

438 case RequestType.Partition: 

439 if not traits.can_publish(self._backend): 

440 raise flight.FlightError( 

441 "partition not supported for server backend" 

442 ) 

443 schema = schemas.partition() 

444 case _: 

445 raise flight.FlightError("command not understood") 

446 

447 return flight.FlightInfo(schema, descriptor, endpoints, -1, -1) 

448 

449 def shutdown(self) -> None: 

450 """Shut down the server.""" 

451 self._is_stopped.set() 

452 return super().shutdown() 

453 

454 def wait_until_shutdown(self) -> None: 

455 """Wait until the server receives a shutdown request.""" 

456 self._is_stopped.wait() 

457 

458 def _find( 

459 self, 

460 *, 

461 pattern: str, 

462 data_type: list[str], 

463 min_rate: int, 

464 max_rate: int, 

465 publisher: list[str], 

466 ) -> flight.FlightDataStream: 

467 """Serve Flight data for the 'find' route.""" 

468 assert isinstance(self._backend, traits.ServerBackend) 

469 metadata = self._backend.find( 

470 pattern=pattern, 

471 data_type=data_type, 

472 min_rate=min_rate, 

473 max_rate=max_rate, 

474 publisher=publisher, 

475 ) 

476 return arrow.create_metadata_stream(metadata) 

477 

478 def _count( 

479 self, 

480 *, 

481 pattern: str, 

482 data_type: list[str], 

483 min_rate: int, 

484 max_rate: int, 

485 publisher: list[str], 

486 ) -> flight.FlightDataStream: 

487 """Serve Flight data for the 'count' route.""" 

488 assert isinstance(self._backend, traits.ServerBackend) 

489 count = self._backend.count( 

490 pattern=pattern, 

491 data_type=data_type, 

492 min_rate=min_rate, 

493 max_rate=max_rate, 

494 publisher=publisher, 

495 ) 

496 schema = schemas.count() 

497 batch = pyarrow.RecordBatch.from_arrays( 

498 [ 

499 pyarrow.array( 

500 [count], 

501 type=schema.field("count").type, 

502 ), 

503 ], 

504 schema=schema, 

505 ) 

506 return flight.RecordBatchStream( 

507 pyarrow.RecordBatchReader.from_batches(schema, [batch]) 

508 ) 

509 

510 def _describe(self, *, channels: Iterable[str]) -> flight.FlightDataStream: 

511 """Serve Flight data for the 'describe' route.""" 

512 assert isinstance(self._backend, traits.ServerBackend) 

513 metadata = self._backend.describe(channels=channels) 

514 return arrow.create_metadata_stream(metadata) 

515 

516 def _stream( 

517 self, 

518 context: flight.ServerCallContext | None, 

519 *, 

520 channels: Iterable[str], 

521 start: int, 

522 end: int, 

523 ) -> flight.FlightDataStream: 

524 """Serve Flight data for the 'stream' route.""" 

525 assert isinstance(self._backend, traits.ServerBackend) 

526 

527 metadata = self._backend.describe(channels=channels) 

528 schema = schemas.stream(metadata) 

529 blocks = self._backend.stream(channels=channels, start=start, end=end) 

530 return flight.GeneratorStream( 

531 schema, 

532 self._generate_stream(context, schema, blocks), 

533 ) 

534 

535 def _publish(self, *, publisher_id: str) -> flight.FlightDataStream: 

536 """Serve Flight data for the 'publish' route.""" 

537 assert traits.can_publish(self._backend) 

538 schema = schemas.publish() 

539 info = self._backend.publish(publisher_id=publisher_id) 

540 batch = pyarrow.RecordBatch.from_arrays( 

541 [ 

542 pyarrow.array( 

543 [info], 

544 type=schema.field("properties").type, 

545 ), 

546 ], 

547 schema=schema, 

548 ) 

549 return flight.RecordBatchStream( 

550 pyarrow.RecordBatchReader.from_batches(schema, [batch]) 

551 ) 

552 

553 def _partition(self, reader, writer, *, publisher_id: str) -> None: 

554 """Exchange Flight data for the 'partition' route.""" 

555 assert traits.can_publish(self._backend) 

556 schema = schemas.partition() 

557 

558 # read metadata from client 

559 channels = [] 

560 for batch in arrow.read_all_chunks(reader): 

561 for meta in batch.to_pylist(): 

562 data_type = numpy.dtype(meta["data_type"]) 

563 channel = Channel( 

564 meta["channel"], 

565 sample_rate=meta["sample_rate"], 

566 data_type=data_type, 

567 publisher=publisher_id, 

568 ) 

569 channels.append(channel) 

570 

571 # partition channels 

572 channels = list( 

573 self._backend.partition(channels=channels, publisher_id=publisher_id) 

574 ) 

575 

576 # prepare the batch with mappings 

577 batches = arrow.create_partition_batches(channels) 

578 

579 # send partitions back to the client 

580 writer.begin(schema) 

581 for batch in batches: 

582 writer.write_batch(batch) 

583 writer.close() 

584 

585 def _generate_stream( 

586 self, 

587 context: flight.ServerCallContext | None, 

588 schema: pyarrow.Schema, 

589 blocks: Iterator[SeriesBlock], 

590 ) -> Iterator[pyarrow.RecordBatch]: 

591 """Generate a record batch stream which can be stopped.""" 

592 for block in blocks: 

593 if self._is_stopped.is_set() or (context and context.is_cancelled()): 

594 return 

595 yield block.to_column_batch(schema)