Coverage for src/dataknobs_fsm/streaming/db_stream.py: 12%
203 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Database streaming implementation for FSM."""
3import logging
4import time
5from typing import Any, Callable, Dict, Iterator, List, Union
7from dataknobs_data.database import AsyncDatabase, SyncDatabase
8from dataknobs_data.query import Query
9from dataknobs_data.records import Record
11from dataknobs_fsm.streaming.core import (
12 IStreamSink,
13 IStreamSource,
14 StreamChunk,
15)
17logger = logging.getLogger(__name__)
class DatabaseStreamSource(IStreamSource):
    """Database-based stream source with cursor iteration.

    Streams records from a database in batches using cursor (keyset)
    pagination: each batch is fetched with ``cursor_field > last_seen``
    ordered ascending, so no OFFSET scan is needed on large tables.
    """

    def __init__(
        self,
        database: Union[SyncDatabase, AsyncDatabase],
        query: Query | None = None,
        batch_size: int = 1000,
        cursor_field: str | None = None,
        start_cursor: Any | None = None
    ):
        """Initialize database stream source.

        Args:
            database: Database instance to stream from.
            query: Query to filter records (None for all).
            batch_size: Number of records per batch.
            cursor_field: Field to use for cursor pagination (defaults to 'id').
            start_cursor: Starting cursor value; only records strictly
                after it are streamed.
        """
        self.database = database
        self.query = query or Query()
        self.batch_size = batch_size
        self.cursor_field = cursor_field or 'id'
        self.current_cursor = start_cursor

        self._chunk_count = 0
        self._record_count = 0
        self._exhausted = False

        # Total count enables progress reporting; not every backend
        # supports counting, so fall back to None (progress stays 0.0).
        try:
            self._total_records = database.count(self.query)
        except Exception:
            self._total_records = None

    def read_chunk(self) -> StreamChunk | None:
        """Read next chunk of records from database.

        Returns:
            StreamChunk with records, or None once the source is exhausted.
            On error, returns a final empty chunk carrying the error message
            in its metadata; the source is then exhausted.
        """
        if self._exhausted:
            return None

        try:
            # Fetch the next batch using the cursor-aware query.
            records = self.database.search(self._build_batch_query())

            if not records:
                self._exhausted = True
                return None

            # Remember where this batch ended so the next query resumes there.
            self._advance_cursor(records[-1])  # type: ignore

            progress = 0.0
            if self._total_records and self._total_records > 0:  # type: ignore
                progress = min(1.0, (self._record_count + len(records)) / self._total_records)  # type: ignore

            # A short batch means there are no more matching rows.
            is_last = len(records) < self.batch_size  # type: ignore
            if is_last:
                self._exhausted = True

            chunk_data = [self._serialize_record(record) for record in records]

            chunk = StreamChunk(
                data=chunk_data,
                sequence_number=self._chunk_count,
                metadata={
                    'database_type': type(self.database).__name__,
                    'query': str(self.query),
                    'batch_size': len(chunk_data),
                    'progress': progress,
                    'cursor_field': self.cursor_field,
                    'cursor_value': self.current_cursor
                },
                is_last=is_last
            )

            self._chunk_count += 1
            self._record_count += len(records)  # type: ignore

            return chunk

        except Exception as e:
            # Log before converting the failure into a terminal error chunk;
            # the original behavior (error chunk, no raise) is preserved.
            logger.error(f"Error reading chunk from database: {e}")
            self._exhausted = True
            return StreamChunk(
                data=[],
                sequence_number=self._chunk_count,
                metadata={'error': str(e)},
                is_last=True
            )

    def _advance_cursor(self, last_record: Any) -> None:
        """Update ``current_cursor`` from the last record of a batch.

        Handles Record instances, plain dicts, and arbitrary objects with
        a matching attribute; leaves the cursor unchanged otherwise.
        """
        if not self.cursor_field:
            return
        if isinstance(last_record, Record):
            # Use Record's API to get the field value.
            if self.cursor_field == 'id':
                self.current_cursor = last_record.id
            elif last_record.has_field(self.cursor_field):
                self.current_cursor = last_record.get_value(self.cursor_field)
        elif isinstance(last_record, dict) and self.cursor_field in last_record:
            self.current_cursor = last_record[self.cursor_field]
        elif hasattr(last_record, self.cursor_field):
            self.current_cursor = getattr(last_record, self.cursor_field)

    @staticmethod
    def _serialize_record(record: Any) -> Any:
        """Convert a record to a serializable form (dict where possible)."""
        if isinstance(record, Record):
            # Use Record's built-in serialization.
            return record.to_dict(include_metadata=True)
        if hasattr(record, 'to_dict'):
            return record.to_dict()
        if hasattr(record, '__dict__'):
            return record.__dict__
        return record

    def _build_batch_query(self) -> Query:
        """Build query for the next batch with cursor, limit, and ordering.

        Returns:
            Query for the next batch.
        """
        batch_query = Query()

        # Copy original query conditions if provided.
        if self.query and hasattr(self.query, 'filters') and self.query.filters:
            batch_query.filters = self.query.filters.copy()

        # Add cursor condition if we have a cursor value.
        if self.current_cursor is not None and self.cursor_field:
            # Local import keeps the module dependency surface minimal.
            from dataknobs_data.query import Operator
            batch_query = batch_query.filter(self.cursor_field, Operator.GT, self.current_cursor)

        # Limit is critical for batching.
        batch_query = batch_query.limit(self.batch_size)

        # Order by cursor field for deterministic, gap-free pagination.
        if self.cursor_field:
            batch_query = batch_query.sort_by(self.cursor_field, "asc")

        return batch_query

    def __iter__(self) -> Iterator[StreamChunk]:
        """Iterate over all chunks until the source is exhausted."""
        while (chunk := self.read_chunk()) is not None:
            yield chunk

    def close(self) -> None:
        """Close the stream source (database connections are managed separately)."""
        pass
class DatabaseStreamSink(IStreamSink):
    """Database-based stream sink with batch operations.

    Buffers incoming chunk records and writes them in batches, committing
    a transaction (when the backend exposes ``commit``) roughly every
    ``transaction_batch`` records.
    """

    def __init__(
        self,
        database: Union[SyncDatabase, AsyncDatabase],
        table_name: str | None = None,
        batch_size: int = 1000,
        upsert: bool = False,
        transaction_batch: int = 10000,
        on_conflict_update: List[str] | None = None
    ):
        """Initialize database stream sink.

        Args:
            database: Database instance to write to.
            table_name: Target table name (optional).
            batch_size: Records per batch insert.
            upsert: Use upsert (update-if-exists) instead of insert.
            transaction_batch: Records per transaction.
            on_conflict_update: Fields to update on conflict.
        """
        self.database = database
        self.table_name = table_name
        self.batch_size = batch_size
        self.upsert = upsert
        self.transaction_batch = transaction_batch
        self.on_conflict_update = on_conflict_update or []

        self._buffer: List[Dict[str, Any]] = []
        self._chunk_count = 0
        self._record_count = 0
        self._transaction_count = 0
        self._current_transaction_size = 0

    def write_chunk(self, chunk: StreamChunk) -> bool:
        """Write chunk to database.

        Args:
            chunk: Chunk containing records to write.

        Returns:
            True if successful; False on any write failure (logged).
        """
        if not chunk.data:
            return True

        try:
            # Accumulate records; lists are flattened into the buffer.
            if isinstance(chunk.data, list):
                self._buffer.extend(chunk.data)
            else:
                self._buffer.append(chunk.data)

            # Drain the buffer one full batch at a time.
            while len(self._buffer) >= self.batch_size:
                batch = self._buffer[:self.batch_size]
                self._buffer = self._buffer[self.batch_size:]

                if not self._write_batch(batch):
                    return False

                self._current_transaction_size += len(batch)

                # Commit once enough records have accumulated.
                if self._current_transaction_size >= self.transaction_batch:
                    self._commit_transaction()

            # The final chunk forces any partial batch out.
            if chunk.is_last:
                self.flush()

            self._chunk_count += 1
            return True

        except Exception as e:
            logger.error(f"Error writing chunk to database: {e}")
            return False

    def _write_batch(self, batch: List[Dict[str, Any]]) -> bool:
        """Write a batch of records to database.

        Args:
            batch: Records to write.

        Returns:
            True if successful; False on error (logged).
        """
        try:
            for record_data in batch:
                record = self._coerce_record(record_data)

                # Perform database operation.
                if self.upsert:
                    # Update if the record exists, create otherwise.
                    record_id = record.id if hasattr(record, 'id') else None
                    if record_id and self.database.read(record_id):
                        self.database.update(record_id, record)
                    else:
                        self.database.create(record)
                else:
                    # Simple insert.
                    self.database.create(record)

                self._record_count += 1

            return True

        except Exception as e:
            logger.error(f"Error in batch write: {e}")
            return False

    @staticmethod
    def _coerce_record(record_data: Any) -> Any:
        """Convert a plain dict into a Record; pass other types through.

        Copies the dict before extracting the id: the previous
        implementation popped keys from the caller's dict in place,
        corrupting any data the caller still held. It also used
        ``pop('id') or pop('_id')``, which dropped both keys and lost
        the id whenever the id was falsy (e.g. 0).
        """
        if not isinstance(record_data, dict):
            return record_data  # type: ignore[unreachable]

        payload = dict(record_data)
        record_id = payload.pop('id', None)
        if record_id is None:
            record_id = payload.pop('_id', None)

        if record_id is not None:
            return Record(id=record_id, data=payload)
        return Record(data=payload)

    def _commit_transaction(self) -> None:
        """Commit current transaction if the backend supports it."""
        try:
            # Check if database supports transactions.
            if hasattr(self.database, 'commit'):
                self.database.commit()

            self._transaction_count += 1
            self._current_transaction_size = 0

        except Exception:
            # Not all backends support transactions.
            pass

    def flush(self) -> None:
        """Flush any buffered records and commit pending work.

        On write failure the buffer is retained so a later flush can retry.
        """
        if self._buffer:
            # Write remaining records in buffer.
            if self._write_batch(self._buffer):
                self._buffer = []

        # Commit any pending transaction.
        if self._current_transaction_size > 0:
            self._commit_transaction()

    def close(self) -> None:
        """Close the sink and ensure all data is written."""
        self.flush()
        # Database connection is managed separately.
class DatabaseBulkLoader:
    """Utility for efficient bulk loading into databases.

    Provides load (source -> database) and export (database -> sink)
    in terms of one shared streaming loop.

    NOTE: statistics accumulate across calls on the same loader
    instance; create a fresh loader for independent counts.
    """

    def __init__(
        self,
        database: Union[SyncDatabase, AsyncDatabase],
        table_name: str | None = None
    ):
        """Initialize bulk loader.

        Args:
            database: Target database.
            table_name: Target table name.
        """
        self.database = database
        self.table_name = table_name
        self._stats = {
            'records_loaded': 0,
            'batches_processed': 0,
            'errors': 0,
            'start_time': None,
            'end_time': None
        }

    def _pump(
        self,
        source: IStreamSource,
        sink: IStreamSink,
        progress_callback: Union[Callable, None]
    ) -> None:
        """Stream every chunk from *source* into *sink*, updating stats.

        Shared loop for load_from_source and export_to_sink (previously
        duplicated). Both endpoints are always closed and ``end_time``
        is recorded, even when an error propagates.
        """
        try:
            for chunk in source:
                # A failed write is counted but does not stop the stream.
                if not sink.write_chunk(chunk):
                    self._stats['errors'] += 1  # type: ignore

                self._stats['batches_processed'] += 1  # type: ignore

                if chunk.data:
                    self._stats['records_loaded'] += len(chunk.data)  # type: ignore

                # Report progress (0.0 when the source can't estimate it).
                if progress_callback:
                    progress_callback(chunk.metadata.get('progress', 0.0), self._stats)

                if chunk.is_last:
                    break

            sink.flush()

        finally:
            sink.close()
            source.close()
            self._stats['end_time'] = time.time()

    def load_from_source(
        self,
        source: IStreamSource,
        batch_size: int = 1000,
        progress_callback: Union[Callable, None] = None
    ) -> Dict[str, Any]:
        """Load data from stream source into database.

        Args:
            source: Stream source to read from.
            batch_size: Batch size for inserts.
            progress_callback: Optional callback for progress updates;
                called with (progress, stats) after each chunk.

        Returns:
            Loading statistics.
        """
        self._stats['start_time'] = time.time()

        sink = DatabaseStreamSink(
            self.database,
            table_name=self.table_name,
            batch_size=batch_size
        )
        self._pump(source, sink, progress_callback)

        return self._stats

    def export_to_sink(
        self,
        sink: IStreamSink,
        query: Query | None = None,
        batch_size: int = 1000,
        progress_callback: Union[Callable, None] = None
    ) -> Dict[str, Any]:
        """Export data from database to stream sink.

        Args:
            sink: Stream sink to write to.
            query: Query to filter records.
            batch_size: Batch size for reading.
            progress_callback: Optional callback for progress updates;
                called with (progress, stats) after each chunk.

        Returns:
            Export statistics.
        """
        self._stats['start_time'] = time.time()

        source = DatabaseStreamSource(
            self.database,
            query=query,
            batch_size=batch_size
        )
        self._pump(source, sink, progress_callback)

        return self._stats