Coverage for arrakis_server/arrow.py: 58.1%
43 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-12 16:39 -0700
# Copyright (c) 2022, California Institute of Technology and contributors
#
# You should have received a copy of the licensing terms for this
# software included in the file "LICENSE" located in the top-level
# directory of this package. If you did not, you can view a copy at
# https://git.ligo.org/ngdd/arrakis-server/-/raw/main/LICENSE
import itertools
import typing
from collections.abc import Generator, Iterable, Iterator

import pyarrow
from arrakis import Channel
from pyarrow import flight

from . import schemas
T = typing.TypeVar("T")


def batched(iterable: Iterable[T], n: int) -> Generator[tuple[T, ...], None, None]:
    """An implementation of Python 3.12's ``itertools.batched``.

    Taken from the itertools recipes in the Python documentation.
    Given an iterable, yield successive tuples of at most ``n`` items
    from it. The last batch may be smaller than ``n`` entries.

    Args:
        iterable: the source of items to batch.
        n: maximum number of items per batch; must be >= 1.

    Raises:
        ValueError: if ``n`` is less than 1. Because this is a
            generator function, the error surfaces on first iteration
            rather than at call time.
    """
    if n < 1:
        raise ValueError("n must be greater than or equal to 1")
    iterator = iter(iterable)
    # islice() takes the next n items; the walrus loop stops cleanly
    # on the first empty batch (source exhausted).
    while batch := tuple(itertools.islice(iterator, n)):
        yield batch
def create_partition_batches(channels: Iterable[Channel]) -> list[pyarrow.RecordBatch]:
    """Create record batches from channel metadata for partitioning.

    Channels are grouped 1000 per batch; each batch carries the
    channel name, data type, sample rate, and partition id columns
    described by the partition schema.
    """
    schema = schemas.partition()
    column_names = ("channel", "data_type", "sample_rate", "partition_id")
    batches = []
    for chunk in batched(channels, 1000):
        rows = [
            (chan.name, chan.data_type.name, chan.sample_rate, chan.partition_id)
            for chan in chunk
        ]
        # Transpose row tuples into per-column value lists; fall back to
        # empty columns if there are no rows at all.
        if rows:
            columns = [list(values) for values in zip(*rows)]
        else:
            columns = [[] for _ in column_names]
        arrays = [
            pyarrow.array(values, type=schema.field(name).type)
            for name, values in zip(column_names, columns)
        ]
        batches.append(pyarrow.RecordBatch.from_arrays(arrays, schema=schema))
    return batches
def create_metadata_stream(channels: Iterable[Channel]) -> flight.RecordBatchStream:
    """Create a record batch stream from channel metadata.

    Channels are grouped 1000 per batch; each batch carries the
    channel name, data type, sample rate, partition id, and publisher
    columns described by the find schema.
    """
    schema = schemas.find()
    column_names = (
        "channel",
        "data_type",
        "sample_rate",
        "partition_id",
        "publisher",
    )
    batches = []
    for chunk in batched(channels, 1000):
        rows = [
            (
                chan.name,
                chan.data_type.name,
                chan.sample_rate,
                chan.partition_id,
                chan.publisher,
            )
            for chan in chunk
        ]
        # Transpose row tuples into per-column value lists; fall back to
        # empty columns if there are no rows at all.
        if rows:
            columns = [list(values) for values in zip(*rows)]
        else:
            columns = [[] for _ in column_names]
        arrays = [
            pyarrow.array(values, type=schema.field(name).type)
            for name, values in zip(column_names, columns)
        ]
        batches.append(pyarrow.RecordBatch.from_arrays(arrays, schema=schema))
    reader = pyarrow.RecordBatchReader.from_batches(schema, batches)
    return flight.RecordBatchStream(reader)
def read_all_chunks(
    reader: flight.MetadataRecordBatchReader,
) -> Iterator[pyarrow.RecordBatch]:
    """Yield record batches from *reader* until it is exhausted.

    Per-chunk app metadata is discarded; only the data portion of each
    chunk is yielded.
    """
    while True:
        try:
            data, _metadata = reader.read_chunk()
        except StopIteration:
            return
        yield data