Coverage for arrakis_server/arrow.py: 58.1%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-08-12 16:39 -0700

1# Copyright (c) 2022, California Institute of Technology and contributors 

2# 

3# You should have received a copy of the licensing terms for this 

4# software included in the file "LICENSE" located in the top-level 

5# directory of this package. If you did not, you can view a copy at 

6# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE 

7 

8import itertools 

9import typing 

10from collections.abc import Generator, Iterable, Iterator 

11 

12import pyarrow 

13from arrakis import Channel 

14from pyarrow import flight 

15 

16from . import schemas 

17 

18T = typing.TypeVar("T") 

19 

20 

def batched(iterable: Iterable[T], n: int) -> Generator[Iterable[T], None, None]:
    """Yield successive tuples of at most ``n`` items from *iterable*.

    Backport of Python 3.12's ``itertools.batched``, following the
    itertools recipes in the Python documentation.  Every batch is a
    tuple of exactly ``n`` items except possibly the final one, which
    may be shorter.  An empty iterable yields nothing.

    Raises:
        ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError("n must be greater than or equal to 1")
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk

33 

34 

def create_partition_batches(channels: Iterable[Channel]) -> list[pyarrow.RecordBatch]:
    """Create record batches from channel metadata for partitioning.

    Channels are serialized in chunks of up to 1000 rows against the
    ``schemas.partition()`` schema (channel, data_type, sample_rate,
    partition_id).  An empty iterable produces an empty list.
    """
    schema = schemas.partition()
    batches = []
    for channel_batch in batched(channels, 1000):
        metadata = [
            (
                channel.name,
                channel.data_type.name,
                channel.sample_rate,
                channel.partition_id,
            )
            for channel in channel_batch
        ]
        # batched() never yields an empty chunk, so zip(*metadata) always
        # receives at least one row; the previous empty-metadata fallback
        # branch was unreachable and has been removed.
        names, dtypes, rates, partitions = map(list, zip(*metadata))
        batches.append(
            pyarrow.RecordBatch.from_arrays(
                [
                    pyarrow.array(names, type=schema.field("channel").type),
                    pyarrow.array(dtypes, type=schema.field("data_type").type),
                    pyarrow.array(rates, type=schema.field("sample_rate").type),
                    pyarrow.array(partitions, type=schema.field("partition_id").type),
                ],
                schema=schema,
            )
        )
    return batches

64 

65 

def create_metadata_stream(channels: Iterable[Channel]) -> flight.RecordBatchStream:
    """Create a record batch stream from channel metadata.

    Channels are serialized in chunks of up to 1000 rows against the
    ``schemas.find()`` schema (channel, data_type, sample_rate,
    partition_id, publisher).  An empty iterable yields a stream
    containing no batches.
    """
    schema = schemas.find()
    batches = []
    for channel_batch in batched(channels, 1000):
        metadata = [
            (
                channel.name,
                channel.data_type.name,
                channel.sample_rate,
                channel.partition_id,
                channel.publisher,
            )
            for channel in channel_batch
        ]
        # batched() never yields an empty chunk, so zip(*metadata) always
        # receives at least one row; the previous empty-metadata fallback
        # branch was unreachable and has been removed.
        names, dtypes, rates, partitions, publishers = map(list, zip(*metadata))
        batches.append(
            pyarrow.RecordBatch.from_arrays(
                [
                    pyarrow.array(names, type=schema.field("channel").type),
                    pyarrow.array(dtypes, type=schema.field("data_type").type),
                    pyarrow.array(rates, type=schema.field("sample_rate").type),
                    pyarrow.array(partitions, type=schema.field("partition_id").type),
                    pyarrow.array(publishers, type=schema.field("publisher").type),
                ],
                schema=schema,
            )
        )
    return flight.RecordBatchStream(
        pyarrow.RecordBatchReader.from_batches(schema, batches)
    )

100 

101 

def read_all_chunks(
    reader: flight.MetadataRecordBatchReader,
) -> Iterator[pyarrow.RecordBatch]:
    """Yield every record batch from *reader* until it is exhausted.

    ``reader.read_chunk()`` returns a ``(batch, app_metadata)`` pair and
    raises ``StopIteration`` when no chunks remain; the app-metadata
    half of each chunk is discarded.
    """
    while True:
        # Keep the try body minimal: only read_chunk() should end the
        # loop.  Yielding inside the try (as before) would also swallow
        # a StopIteration thrown into the generator by its consumer.
        try:
            batch, _ = reader.read_chunk()
        except StopIteration:
            return
        yield batch