Coverage for src/dataknobs_fsm/streaming/db_stream.py: 12%

203 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Database streaming implementation for FSM.""" 

2 

3import logging 

4import time 

5from typing import Any, Callable, Dict, Iterator, List, Union 

6 

7from dataknobs_data.database import AsyncDatabase, SyncDatabase 

8from dataknobs_data.query import Query 

9from dataknobs_data.records import Record 

10 

11from dataknobs_fsm.streaming.core import ( 

12 IStreamSink, 

13 IStreamSource, 

14 StreamChunk, 

15) 

16 

17logger = logging.getLogger(__name__) 

18 

19 

class DatabaseStreamSource(IStreamSource):
    """Database-based stream source with cursor iteration.

    Streams records from a database using cursor-based (keyset)
    pagination: each batch is fetched with a ``cursor_field > last_seen``
    filter, ordered ascending by the cursor field, so pagination stays
    consistent across batches.
    """

    def __init__(
        self,
        database: Union[SyncDatabase, AsyncDatabase],
        query: Query | None = None,
        batch_size: int = 1000,
        cursor_field: str | None = None,
        start_cursor: Any | None = None
    ):
        """Initialize database stream source.

        Args:
            database: Database instance to stream from.
            query: Query to filter records (None for all).
            batch_size: Number of records per batch.
            cursor_field: Field to use for cursor pagination (default ``'id'``).
            start_cursor: Starting cursor value; only records strictly
                after this value are streamed.
        """
        self.database = database
        self.query = query or Query()
        self.batch_size = batch_size
        self.cursor_field = cursor_field or 'id'
        self.current_cursor = start_cursor

        self._chunk_count = 0
        self._record_count = 0
        self._exhausted = False

        # Total count enables progress reporting; not every backend
        # supports count(), so fall back to "unknown" (None).
        try:
            self._total_records = database.count(self.query)
        except Exception:
            self._total_records = None

    def read_chunk(self) -> StreamChunk | None:
        """Read next chunk of records from database.

        Returns:
            StreamChunk with records, or None if the stream is exhausted.
            On a database error, a final empty chunk carrying an
            ``'error'`` metadata entry is returned and the source is
            marked exhausted.
        """
        if self._exhausted:
            return None

        try:
            # Build query with cursor and fetch the next batch.
            batch_query = self._build_batch_query()
            records = self.database.search(batch_query)

            if not records:
                self._exhausted = True
                return None

            # Advance cursor so the next batch starts after this one.
            self._advance_cursor(records[-1])

            # Progress is only meaningful when count() succeeded upfront.
            progress = 0.0
            if self._total_records and self._total_records > 0:  # type: ignore
                progress = min(
                    1.0,
                    (self._record_count + len(records)) / self._total_records,  # type: ignore
                )

            # A short batch means the backend ran out of matching rows.
            is_last = len(records) < self.batch_size  # type: ignore
            if is_last:
                self._exhausted = True

            # Convert records to serializable form for transport.
            chunk_data = [self._serialize(record) for record in records]

            chunk = StreamChunk(
                data=chunk_data,
                sequence_number=self._chunk_count,
                metadata={
                    'database_type': type(self.database).__name__,
                    'query': str(self.query),
                    'batch_size': len(chunk_data),
                    'progress': progress,
                    'cursor_field': self.cursor_field,
                    'cursor_value': self.current_cursor
                },
                is_last=is_last
            )

            self._chunk_count += 1
            self._record_count += len(records)  # type: ignore

            return chunk

        except Exception as e:
            # Surface the failure to the consumer as a terminal error
            # chunk instead of raising mid-stream; log it so the error
            # is not silently swallowed.
            logger.error("Error reading chunk from database: %s", e)
            self._exhausted = True
            return StreamChunk(
                data=[],
                sequence_number=self._chunk_count,
                metadata={'error': str(e)},
                is_last=True
            )

    def _advance_cursor(self, last_record: Any) -> None:
        """Update ``current_cursor`` from the last record of a batch."""
        if not self.cursor_field:
            return
        if isinstance(last_record, Record):
            # Use Record's API to get the field value.
            if self.cursor_field == 'id':
                self.current_cursor = last_record.id
            elif last_record.has_field(self.cursor_field):
                self.current_cursor = last_record.get_value(self.cursor_field)
            else:
                # Cursor cannot advance; warn because a full batch with a
                # stale cursor would refetch the same rows forever.
                logger.warning(
                    "Cursor field %r not present on record; cursor not advanced",
                    self.cursor_field,
                )
        elif isinstance(last_record, dict) and self.cursor_field in last_record:
            self.current_cursor = last_record[self.cursor_field]
        elif hasattr(last_record, self.cursor_field):
            self.current_cursor = getattr(last_record, self.cursor_field)

    def _serialize(self, record: Any) -> Any:
        """Convert a record to a JSON-serializable structure."""
        if isinstance(record, Record):
            # Use Record's built-in serialization (keeps metadata).
            return record.to_dict(include_metadata=True)
        if hasattr(record, 'to_dict'):
            return record.to_dict()
        if hasattr(record, '__dict__'):
            return record.__dict__
        return record

    def _build_batch_query(self) -> Query:
        """Build query for next batch with cursor.

        Returns:
            Query filtered past the current cursor, limited to
            ``batch_size`` and ordered ascending by the cursor field.
        """
        batch_query = Query()

        # Copy original query conditions if provided.
        if self.query and hasattr(self.query, 'filters') and self.query.filters:
            batch_query.filters = self.query.filters.copy()

        # Add cursor condition if we have a cursor value.
        if self.current_cursor is not None and self.cursor_field:
            # Use Query API to add filter for pagination.
            from dataknobs_data.query import Operator
            batch_query = batch_query.filter(self.cursor_field, Operator.GT, self.current_cursor)

        # Set limit - this is critical for batching.
        batch_query = batch_query.limit(self.batch_size)

        # Order by cursor field for consistent pagination.
        if self.cursor_field:
            batch_query = batch_query.sort_by(self.cursor_field, "asc")

        return batch_query

    def __iter__(self) -> Iterator[StreamChunk]:
        """Iterate over all chunks until the source is exhausted."""
        while True:
            chunk = self.read_chunk()
            if chunk is None:
                break
            yield chunk

    def close(self) -> None:
        """Close the stream source."""
        # Database connections are managed separately.
        pass

186 

187 

class DatabaseStreamSink(IStreamSink):
    """Database-based stream sink with batch operations.

    This sink buffers incoming chunk data and writes it to a database
    in batch inserts, committing transactions (when the backend supports
    them) every ``transaction_batch`` records.
    """

    def __init__(
        self,
        database: Union[SyncDatabase, AsyncDatabase],
        table_name: str | None = None,
        batch_size: int = 1000,
        upsert: bool = False,
        transaction_batch: int = 10000,
        on_conflict_update: List[str] | None = None
    ):
        """Initialize database stream sink.

        Args:
            database: Database instance to write to.
            table_name: Target table name (optional).
            batch_size: Records per batch insert.
            upsert: Use upsert (update-if-exists) instead of insert.
            transaction_batch: Records per transaction commit.
            on_conflict_update: Fields to update on conflict.
        """
        self.database = database
        self.table_name = table_name
        self.batch_size = batch_size
        self.upsert = upsert
        self.transaction_batch = transaction_batch
        self.on_conflict_update = on_conflict_update or []

        self._buffer: List[Dict[str, Any]] = []
        self._chunk_count = 0
        self._record_count = 0
        self._transaction_count = 0
        self._current_transaction_size = 0

    def write_chunk(self, chunk: StreamChunk) -> bool:
        """Write chunk to database.

        Args:
            chunk: Chunk containing records to write.

        Returns:
            True if successful.
        """
        if not chunk.data:
            # An empty terminal chunk must still trigger a flush so
            # buffered records are not stranded.
            if chunk.is_last:
                self.flush()
            return True

        try:
            # Add to buffer.
            if isinstance(chunk.data, list):
                self._buffer.extend(chunk.data)
            else:
                self._buffer.append(chunk.data)

            # Process buffer in full batches.
            while len(self._buffer) >= self.batch_size:
                batch = self._buffer[:self.batch_size]
                self._buffer = self._buffer[self.batch_size:]

                if not self._write_batch(batch):
                    return False

                self._current_transaction_size += len(batch)

                # Commit once enough records accumulated in this transaction.
                if self._current_transaction_size >= self.transaction_batch:
                    self._commit_transaction()

            # If this is the last chunk, flush remaining buffered records.
            if chunk.is_last:
                self.flush()

            self._chunk_count += 1
            return True

        except Exception as e:
            logger.error(f"Error writing chunk to database: {e}")
            return False

    def _write_batch(self, batch: List[Dict[str, Any]]) -> bool:
        """Write a batch of records to database.

        Args:
            batch: Records to write.

        Returns:
            True if successful.
        """
        try:
            for record_data in batch:
                # Convert dict to Record if needed.
                if isinstance(record_data, dict):
                    # Work on a copy so the caller's payload (and the
                    # shared buffer) is never mutated by the pops below.
                    data = dict(record_data)
                    record_id = data.pop('id', None)
                    if record_id is None:
                        record_id = data.pop('_id', None)
                    # Compare against None so falsy ids (e.g. 0) survive.
                    if record_id is not None:
                        record = Record(id=record_id, data=data)
                    else:
                        record = Record(data=data)
                else:
                    record = record_data  # type: ignore[unreachable]

                # Perform database operation.
                if self.upsert:
                    # Update if the record already exists, create otherwise.
                    record_id = record.id if hasattr(record, 'id') else None
                    if record_id and self.database.read(record_id):
                        self.database.update(record_id, record)
                    else:
                        self.database.create(record)
                else:
                    # Simple insert.
                    self.database.create(record)

                self._record_count += 1

            return True

        except Exception as e:
            logger.error(f"Error in batch write: {e}")
            return False

    def _commit_transaction(self) -> None:
        """Commit current transaction if the backend supports it."""
        try:
            # Check if database supports transactions.
            if hasattr(self.database, 'commit'):
                self.database.commit()

            self._transaction_count += 1
            self._current_transaction_size = 0

        except Exception:
            # Not all backends support transactions.
            pass

    def flush(self) -> None:
        """Flush any buffered records."""
        if self._buffer:
            # Write remaining records in buffer.
            if self._write_batch(self._buffer):
                self._buffer = []
            else:
                # Keep the buffer so a later flush can retry, but make
                # the failure visible instead of dropping it silently.
                logger.error(
                    "Failed to flush %d buffered records", len(self._buffer)
                )

        # Commit any pending transaction.
        if self._current_transaction_size > 0:
            self._commit_transaction()

    def close(self) -> None:
        """Close the sink and ensure all data is written."""
        self.flush()
        # Database connection is managed separately.

344 

345 

class DatabaseBulkLoader:
    """Utility for efficient bulk loading into databases.

    This class provides optimized bulk loading strategies for different
    database backends. Statistics accumulate across calls on the same
    instance (counters are not reset between runs).
    """

    def __init__(
        self,
        database: Union[SyncDatabase, AsyncDatabase],
        table_name: str | None = None
    ):
        """Initialize bulk loader.

        Args:
            database: Target database.
            table_name: Target table name.
        """
        self.database = database
        self.table_name = table_name
        self._stats = {
            'records_loaded': 0,
            'batches_processed': 0,
            'errors': 0,
            'start_time': None,
            'end_time': None
        }

    def _pump(
        self,
        source: IStreamSource,
        sink: IStreamSink,
        progress_callback: Union[Callable, None]
    ) -> Dict[str, Any]:
        """Drive all chunks from ``source`` into ``sink``, tracking stats.

        Shared engine for :meth:`load_from_source` and
        :meth:`export_to_sink`. Closes both endpoints and records the
        end time even on failure.

        Args:
            source: Stream source to read from.
            sink: Stream sink to write to.
            progress_callback: Optional ``callback(progress, stats)``.

        Returns:
            The accumulated statistics dict.
        """
        try:
            for chunk in source:
                if not sink.write_chunk(chunk):
                    self._stats['errors'] += 1  # type: ignore

                self._stats['batches_processed'] += 1  # type: ignore

                if chunk.data:
                    self._stats['records_loaded'] += len(chunk.data)  # type: ignore

                # Report progress if the caller asked for it.
                if progress_callback:
                    progress = chunk.metadata.get('progress', 0.0)
                    progress_callback(progress, self._stats)

                if chunk.is_last:
                    break

            sink.flush()

        finally:
            sink.close()
            source.close()
            self._stats['end_time'] = time.time()

        return self._stats

    def load_from_source(
        self,
        source: IStreamSource,
        batch_size: int = 1000,
        progress_callback: Union[Callable, None] = None
    ) -> Dict[str, Any]:
        """Load data from stream source into database.

        Args:
            source: Stream source to read from.
            batch_size: Batch size for inserts.
            progress_callback: Optional callback for progress updates.

        Returns:
            Loading statistics.
        """
        self._stats['start_time'] = time.time()

        sink = DatabaseStreamSink(
            self.database,
            table_name=self.table_name,
            batch_size=batch_size
        )
        return self._pump(source, sink, progress_callback)

    def export_to_sink(
        self,
        sink: IStreamSink,
        query: Query | None = None,
        batch_size: int = 1000,
        progress_callback: Union[Callable, None] = None
    ) -> Dict[str, Any]:
        """Export data from database to stream sink.

        Args:
            sink: Stream sink to write to.
            query: Query to filter records.
            batch_size: Batch size for reading.
            progress_callback: Optional callback for progress updates.

        Returns:
            Export statistics.
        """
        self._stats['start_time'] = time.time()

        source = DatabaseStreamSource(
            self.database,
            query=query,
            batch_size=batch_size
        )
        return self._pump(source, sink, progress_callback)