Coverage for src/dataknobs_fsm/execution/batch.py: 22%

203 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Batch executor for parallel record processing.""" 

2 

3import asyncio 

4import threading 

5import time 

6from concurrent.futures import ThreadPoolExecutor, as_completed 

7from dataclasses import dataclass, field 

8from typing import Any, Callable, Dict, List, Union 

9 

10from dataknobs_fsm.core.fsm import FSM 

11from dataknobs_fsm.core.modes import ProcessingMode, TransactionMode 

12from dataknobs_fsm.execution.context import ExecutionContext 

13from dataknobs_fsm.execution.engine import ExecutionEngine 

14 

15 

@dataclass
class BatchResult:
    """Result from batch processing of one item."""
    # Position of the item within the submitted batch.
    index: int
    # True when the FSM run for this item reported success.
    success: bool
    # Value produced by the execution engine for this item.
    result: Any
    # Exception raised during processing, if any (None on success paths).
    error: Exception | None = None
    # Wall-clock seconds spent processing this item.
    processing_time: float = 0.0
    # Extra execution details (e.g. final state, transition path).
    metadata: Dict[str, Any] = field(default_factory=dict)

25 

26 

@dataclass
class BatchProgress:
    """Progress tracking for batch processing."""
    # Total number of items in the batch.
    total: int
    # Items finished so far (successes plus failures).
    completed: int = 0
    # Items that finished successfully.
    succeeded: int = 0
    # Items that failed.
    failed: int = 0
    # Wall-clock time when tracking began.
    start_time: float = field(default_factory=time.time)

    @property
    def progress(self) -> float:
        """Fraction of items completed, in [0, 1] (0.0 for an empty batch)."""
        return self.completed / self.total if self.total else 0.0

    @property
    def elapsed_time(self) -> float:
        """Seconds elapsed since tracking started."""
        return time.time() - self.start_time

    @property
    def items_per_second(self) -> float:
        """Average completion rate so far (0.0 before any time has passed)."""
        elapsed = self.elapsed_time
        return self.completed / elapsed if elapsed else 0.0

    @property
    def estimated_time_remaining(self) -> float:
        """Projected seconds left at the current rate (inf while rate is 0)."""
        rate = self.items_per_second
        if not rate:
            return float('inf')
        return (self.total - self.completed) / rate

64 

65 

class BatchExecutor:
    """Executor for batch processing with parallelism.

    This executor handles:
    - Parallel record processing (thread pool, see ``_execute_parallel``)
    - Resource pooling and management
    - Progress tracking and reporting (via ``BatchProgress``)
    - Error aggregation and handling
    - Performance optimization (see ``create_benchmark``)
    """

76 

    def __init__(
        self,
        fsm: FSM,
        parallelism: int = 4,
        batch_size: int = 100,
        enable_resource_pooling: bool = True,
        progress_callback: Union[Callable, None] = None
    ):
        """Initialize batch executor.

        Args:
            fsm: FSM to execute.
            parallelism: Number of parallel workers; values <= 1 make
                ``execute_batch`` fall back to sequential processing.
            batch_size: Size of each batch used by ``execute_batches``.
            enable_resource_pooling: Enable resource pooling around each
                parallel item execution.
            progress_callback: Callback for progress updates; invoked with
                the shared ``BatchProgress`` after each completed item.
        """
        self.fsm = fsm
        self.parallelism = parallelism
        self.batch_size = batch_size
        self.enable_resource_pooling = enable_resource_pooling
        self.progress_callback = progress_callback

        # Create execution engine
        self.engine = ExecutionEngine(fsm)

        # Resource pool shared across worker threads.
        # NOTE(review): workers run on a ThreadPoolExecutor, so
        # threading.Lock looks like the matching primitive here, not
        # asyncio.Lock -- confirm intent before relying on these locks.
        self._resource_pool: Dict[str, List[Any]] = {}
        self._resource_locks: Dict[str, asyncio.Lock] = {}

106 

107 def execute_batch( 

108 self, 

109 items: List[Any], 

110 context_template: ExecutionContext | None = None, 

111 max_transitions: int = 1000 

112 ) -> List[BatchResult]: 

113 """Execute batch of items. 

114  

115 Args: 

116 items: Items to process. 

117 context_template: Template context to clone. 

118 max_transitions: Maximum transitions per item. 

119  

120 Returns: 

121 List of batch results. 

122 """ 

123 if not items: 

124 return [] 

125 

126 # Create progress tracker 

127 progress = BatchProgress(total=len(items)) 

128 

129 # Create base context if not provided 

130 if context_template is None: 

131 context_template = ExecutionContext( 

132 data_mode=ProcessingMode.SINGLE, 

133 transaction_mode=TransactionMode.PER_RECORD 

134 ) 

135 

136 # Process based on parallelism setting 

137 if self.parallelism <= 1: 

138 return self._execute_sequential( 

139 items, 

140 context_template, 

141 max_transitions, 

142 progress 

143 ) 

144 else: 

145 return self._execute_parallel( 

146 items, 

147 context_template, 

148 max_transitions, 

149 progress 

150 ) 

151 

    def _execute_sequential(
        self,
        items: List[Any],
        context_template: ExecutionContext,
        max_transitions: int,
        progress: BatchProgress
    ) -> List[BatchResult]:
        """Execute items one at a time on the calling thread.

        Args:
            items: Items to process.
            context_template: Template context cloned for every item.
            max_transitions: Maximum transitions per item.
            progress: Shared progress tracker, updated after each item.

        Returns:
            List of results, in input order.
        """
        results = []

        for i, item in enumerate(items):
            start_time = time.time()

            # Create context for this item (presumably an independent copy
            # of the template -- clone() semantics live in ExecutionContext).
            context = context_template.clone()
            # Convert Record to dict if needed; plain values are stored
            # untouched.
            if hasattr(item, 'to_dict'):
                context.data = item.to_dict()
            elif hasattr(item, '__dict__'):
                context.data = dict(item.__dict__)
            else:
                context.data = item

            # Add batch tracking metadata so downstream actions can see
            # where this item sits in the batch.
            context.batch_id = i
            context.metadata['batch_info'] = {
                'batch_id': i,
                'total_items': len(items),
                'item_index': i,
                'processing_mode': 'sequential'
            }

            # Reset to the FSM's initial state before each item.
            initial_state = self._find_initial_state()
            if initial_state:
                context.set_state(initial_state)

            # Execute
            try:
                success, result = self.engine.execute(
                    context,
                    None,  # Data is already in context
                    max_transitions
                )

                # Store final state and path in metadata for the caller.
                metadata = context.metadata.copy() if context.metadata else {}
                metadata['final_state'] = context.current_state
                metadata['path'] = context.history if hasattr(context, 'history') else []

                batch_result = BatchResult(
                    index=i,
                    success=success,
                    result=result,
                    processing_time=time.time() - start_time,
                    metadata=metadata
                )

                if success:
                    progress.succeeded += 1
                else:
                    progress.failed += 1

            except Exception as e:
                # Capture the exception in the result instead of aborting
                # the rest of the batch.
                batch_result = BatchResult(
                    index=i,
                    success=False,
                    result=None,
                    error=e,
                    processing_time=time.time() - start_time
                )
                progress.failed += 1

            results.append(batch_result)
            progress.completed += 1

            # Fire progress callback after every item.
            if self.progress_callback:
                self.progress_callback(progress)

        return results

243 

    def _execute_parallel(
        self,
        items: List[Any],
        context_template: ExecutionContext,
        max_transitions: int,
        progress: BatchProgress
    ) -> List[BatchResult]:
        """Execute items concurrently on a thread pool.

        Args:
            items: Items to process.
            context_template: Template context cloned per item.
            max_transitions: Maximum transitions per item.
            progress: Shared progress tracker; mutated only from this
                (coordinating) thread as futures complete.

        Returns:
            List of results, in input order (pre-sized so completion order
            does not matter).
        """
        results = [None] * len(items)

        with ThreadPoolExecutor(max_workers=self.parallelism) as executor:
            # Submit all items; map each future back to its input index.
            futures = {}
            for i, item in enumerate(items):
                future = executor.submit(
                    self._process_single_item,
                    i,
                    item,
                    context_template,
                    max_transitions
                )
                futures[future] = i

            # Process completed items as they finish.
            for future in as_completed(futures):
                index = futures[future]

                try:
                    batch_result = future.result()
                    results[index] = batch_result  # type: ignore

                    if batch_result.success:
                        progress.succeeded += 1
                    else:
                        progress.failed += 1

                except Exception as e:
                    # _process_single_item catches its own errors, so this
                    # guards against failures outside it (e.g. in submit
                    # plumbing or result retrieval).
                    results[index] = BatchResult(  # type: ignore
                        index=index,
                        success=False,
                        result=None,
                        error=e
                    )
                    progress.failed += 1

                progress.completed += 1

                # Fire progress callback once per completed item.
                if self.progress_callback:
                    self.progress_callback(progress)

        return results  # type: ignore

306 

    def _process_single_item(
        self,
        index: int,
        item: Any,
        context_template: ExecutionContext,
        max_transitions: int
    ) -> BatchResult:
        """Process a single item (runs on a worker thread).

        Args:
            index: Item index within the batch.
            item: Item to process.
            context_template: Template context cloned for this item.
            max_transitions: Maximum transitions.

        Returns:
            Batch result; exceptions are captured in the result rather
            than propagated.
        """
        start_time = time.time()

        # Create context for this item from the shared template.
        context = context_template.clone()
        # Convert Record to dict if needed; plain values pass through.
        if hasattr(item, 'to_dict'):
            context.data = item.to_dict()
        elif hasattr(item, '__dict__'):
            context.data = dict(item.__dict__)
        else:
            context.data = item

        # Add batch tracking metadata, including which worker thread ran it.
        context.batch_id = index
        context.metadata['batch_info'] = {
            'batch_id': index,
            'item_index': index,
            'processing_mode': 'parallel',
            'worker_thread': threading.current_thread().name
        }

        # Get resource from pool if available.
        if self.enable_resource_pooling:
            self._acquire_resources(context)

        try:
            # Reset to the FSM's initial state for this item.
            initial_state = self._find_initial_state()
            if initial_state:
                context.set_state(initial_state)

            # Execute
            success, result = self.engine.execute(
                context,
                None,  # Data is already in context
                max_transitions
            )

            # Store final state and path in metadata for the caller.
            metadata = context.metadata.copy() if context.metadata else {}
            metadata['final_state'] = context.current_state
            metadata['path'] = context.history if hasattr(context, 'history') else []

            return BatchResult(
                index=index,
                success=success,
                result=result,
                processing_time=time.time() - start_time,
                metadata=metadata
            )

        except Exception as e:
            # Report the failure via the result object; the batch continues.
            return BatchResult(
                index=index,
                success=False,
                result=None,
                error=e,
                processing_time=time.time() - start_time
            )

        finally:
            # Release resources back to pool regardless of outcome.
            if self.enable_resource_pooling:
                self._release_resources(context)

389 

390 def _acquire_resources(self, context: ExecutionContext) -> None: 

391 """Acquire resources from pool for context. 

392  

393 Args: 

394 context: Execution context. 

395 """ 

396 # Initialize resource pools if needed 

397 for resource_type, limit in context.resource_limits.items(): 

398 if resource_type not in self._resource_pool: 

399 self._resource_pool[resource_type] = [] 

400 self._resource_locks[resource_type] = asyncio.Lock() 

401 

402 # Track batch-specific resource allocation 

403 if hasattr(context, 'batch_id'): 

404 context.metadata[f'batch_{context.batch_id}_resources'] = { 

405 'resource_type': resource_type, 

406 'limit': limit, 

407 'acquired_at': context.metadata.get('start_time'), 

408 'pool_size': len(self._resource_pool[resource_type]) 

409 } 

410 

411 def _release_resources(self, context: ExecutionContext) -> None: 

412 """Release resources back to pool. 

413  

414 Args: 

415 context: Execution context. 

416 """ 

417 # Release allocated resources back to pool 

418 for allocation in context.resources.values(): 

419 if allocation.status == 'allocated': 

420 resource_type = allocation.resource_type 

421 if resource_type in self._resource_pool: 

422 self._resource_pool[resource_type].append( 

423 allocation.resource_id 

424 ) 

425 

426 # Track batch-specific resource release 

427 if hasattr(context, 'batch_id'): 

428 batch_key = f'batch_{context.batch_id}_resources' 

429 if batch_key in context.metadata: 

430 context.metadata[batch_key]['released_at'] = context.metadata.get('end_time') 

431 context.metadata[batch_key]['final_pool_size'] = len(self._resource_pool[resource_type]) 

432 

433 # Mark as released 

434 allocation.status = 'released' 

435 

436 def _find_initial_state(self) -> str | None: 

437 """Find initial state in FSM. 

438  

439 Returns: 

440 Initial state name or None. 

441 """ 

442 # Get main network 

443 if self.fsm.name in self.fsm.networks: 

444 network = self.fsm.networks[self.fsm.name] 

445 if network.initial_states: 

446 return next(iter(network.initial_states)) 

447 return None 

448 

449 def execute_batches( 

450 self, 

451 items: List[Any], 

452 context_template: ExecutionContext | None = None, 

453 max_transitions: int = 1000 

454 ) -> Dict[str, Any]: 

455 """Execute items in batches. 

456  

457 Args: 

458 items: All items to process. 

459 context_template: Template context. 

460 max_transitions: Maximum transitions. 

461  

462 Returns: 

463 Aggregated results. 

464 """ 

465 all_results = [] 

466 total_batches = (len(items) + self.batch_size - 1) // self.batch_size 

467 

468 for batch_num in range(total_batches): 

469 start_idx = batch_num * self.batch_size 

470 end_idx = min(start_idx + self.batch_size, len(items)) 

471 batch = items[start_idx:end_idx] 

472 

473 # Process batch 

474 batch_results = self.execute_batch( 

475 batch, 

476 context_template, 

477 max_transitions 

478 ) 

479 

480 all_results.extend(batch_results) 

481 

482 # Aggregate results 

483 total = len(all_results) 

484 succeeded = sum(1 for r in all_results if r.success) 

485 failed = total - succeeded 

486 

487 total_time = sum(r.processing_time for r in all_results) 

488 avg_time = total_time / total if total > 0 else 0 

489 

490 errors_by_type = {} 

491 for result in all_results: 

492 if result.error: 

493 error_type = type(result.error).__name__ 

494 errors_by_type[error_type] = errors_by_type.get(error_type, 0) + 1 

495 

496 return { 

497 'total': total, 

498 'succeeded': succeeded, 

499 'failed': failed, 

500 'success_rate': succeeded / total if total > 0 else 0, 

501 'total_processing_time': total_time, 

502 'average_processing_time': avg_time, 

503 'errors_by_type': errors_by_type, 

504 'results': all_results 

505 } 

506 

507 def create_benchmark( 

508 self, 

509 items: List[Any], 

510 configurations: List[Dict[str, Any]] 

511 ) -> Dict[str, Any]: 

512 """Run performance benchmark with different configurations. 

513  

514 Args: 

515 items: Items to process. 

516 configurations: List of configuration dicts with: 

517 - 'name': Configuration name 

518 - 'parallelism': Parallelism level 

519 - 'batch_size': Batch size 

520 - 'strategy': Traversal strategy 

521  

522 Returns: 

523 Benchmark results. 

524 """ 

525 benchmark_results = {} 

526 

527 for config in configurations: 

528 name = config.get('name', 'unnamed') 

529 

530 # Update executor settings 

531 self.parallelism = config.get('parallelism', self.parallelism) 

532 self.batch_size = config.get('batch_size', self.batch_size) 

533 

534 if 'strategy' in config: 

535 self.engine.strategy = config['strategy'] 

536 

537 # Run benchmark 

538 start_time = time.time() 

539 results = self.execute_batches(items) 

540 elapsed_time = time.time() - start_time 

541 

542 # Calculate metrics 

543 throughput = len(items) / elapsed_time if elapsed_time > 0 else 0 

544 

545 benchmark_results[name] = { 

546 'configuration': config, 

547 'elapsed_time': elapsed_time, 

548 'throughput': throughput, 

549 'success_rate': results['success_rate'], 

550 'average_processing_time': results['average_processing_time'] 

551 } 

552 

553 return benchmark_results