Coverage for src/dataknobs_fsm/execution/batch.py: 22%
203 statements
Generated by coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Batch executor for parallel record processing."""
3import asyncio
4import threading
5import time
6from concurrent.futures import ThreadPoolExecutor, as_completed
7from dataclasses import dataclass, field
8from typing import Any, Callable, Dict, List, Union
10from dataknobs_fsm.core.fsm import FSM
11from dataknobs_fsm.core.modes import ProcessingMode, TransactionMode
12from dataknobs_fsm.execution.context import ExecutionContext
13from dataknobs_fsm.execution.engine import ExecutionEngine
@dataclass
class BatchResult:
    """Result from batch processing of a single item."""
    index: int                      # position of the item in the submitted batch
    success: bool                   # True when the engine reported success and no exception escaped
    result: Any                     # value returned by the execution engine (None on error)
    error: Exception | None = None  # exception captured during processing, if any
    processing_time: float = 0.0    # wall-clock seconds spent processing this item
    metadata: Dict[str, Any] = field(default_factory=dict)  # final state, path, batch info, etc.
@dataclass
class BatchProgress:
    """Progress tracking for batch processing."""
    total: int
    completed: int = 0
    succeeded: int = 0
    failed: int = 0
    start_time: float = field(default_factory=time.time)

    @property
    def progress(self) -> float:
        """Fraction of items completed, in [0.0, 1.0]."""
        return self.completed / self.total if self.total else 0.0

    @property
    def elapsed_time(self) -> float:
        """Seconds elapsed since tracking started."""
        return time.time() - self.start_time

    @property
    def items_per_second(self) -> float:
        """Average processing rate so far (0.0 when no time has passed)."""
        seconds = self.elapsed_time
        return self.completed / seconds if seconds else 0.0

    @property
    def estimated_time_remaining(self) -> float:
        """Projected seconds until completion (inf when the rate is zero)."""
        rate = self.items_per_second
        if not rate:
            return float('inf')
        return (self.total - self.completed) / rate
class BatchExecutor:
    """Executor for batch processing with parallelism.

    This executor handles:
    - Parallel record processing
    - Resource pooling and management
    - Progress tracking and reporting
    - Error aggregation and handling
    - Performance optimization
    """

    def __init__(
        self,
        fsm: FSM,
        parallelism: int = 4,
        batch_size: int = 100,
        enable_resource_pooling: bool = True,
        progress_callback: Union[Callable, None] = None
    ):
        """Initialize batch executor.

        Args:
            fsm: FSM to execute.
            parallelism: Number of parallel workers (<= 1 runs sequentially).
            batch_size: Size of each batch used by execute_batches().
            enable_resource_pooling: Enable resource pooling.
            progress_callback: Callable invoked with a BatchProgress after
                each completed item.
        """
        self.fsm = fsm
        self.parallelism = parallelism
        self.batch_size = batch_size
        self.enable_resource_pooling = enable_resource_pooling
        self.progress_callback = progress_callback

        # Create execution engine shared by all workers.
        self.engine = ExecutionEngine(fsm)

        # Resource pool shared across worker threads. Workers run in a
        # ThreadPoolExecutor, so the guards must be threading locks;
        # asyncio locks require an event loop the workers do not have.
        self._resource_pool: Dict[str, List[Any]] = {}
        self._resource_locks: Dict[str, threading.Lock] = {}
107 def execute_batch(
108 self,
109 items: List[Any],
110 context_template: ExecutionContext | None = None,
111 max_transitions: int = 1000
112 ) -> List[BatchResult]:
113 """Execute batch of items.
115 Args:
116 items: Items to process.
117 context_template: Template context to clone.
118 max_transitions: Maximum transitions per item.
120 Returns:
121 List of batch results.
122 """
123 if not items:
124 return []
126 # Create progress tracker
127 progress = BatchProgress(total=len(items))
129 # Create base context if not provided
130 if context_template is None:
131 context_template = ExecutionContext(
132 data_mode=ProcessingMode.SINGLE,
133 transaction_mode=TransactionMode.PER_RECORD
134 )
136 # Process based on parallelism setting
137 if self.parallelism <= 1:
138 return self._execute_sequential(
139 items,
140 context_template,
141 max_transitions,
142 progress
143 )
144 else:
145 return self._execute_parallel(
146 items,
147 context_template,
148 max_transitions,
149 progress
150 )
152 def _execute_sequential(
153 self,
154 items: List[Any],
155 context_template: ExecutionContext,
156 max_transitions: int,
157 progress: BatchProgress
158 ) -> List[BatchResult]:
159 """Execute items sequentially.
161 Args:
162 items: Items to process.
163 context_template: Template context.
164 max_transitions: Maximum transitions.
165 progress: Progress tracker.
167 Returns:
168 List of results.
169 """
170 results = []
172 for i, item in enumerate(items):
173 start_time = time.time()
175 # Create context for this item
176 context = context_template.clone()
177 # Convert Record to dict if needed
178 if hasattr(item, 'to_dict'):
179 context.data = item.to_dict()
180 elif hasattr(item, '__dict__'):
181 context.data = dict(item.__dict__)
182 else:
183 context.data = item
185 # Add batch tracking metadata
186 context.batch_id = i
187 context.metadata['batch_info'] = {
188 'batch_id': i,
189 'total_items': len(items),
190 'item_index': i,
191 'processing_mode': 'sequential'
192 }
194 # Reset to initial state
195 initial_state = self._find_initial_state()
196 if initial_state:
197 context.set_state(initial_state)
199 # Execute
200 try:
201 success, result = self.engine.execute(
202 context,
203 None, # Data is already in context
204 max_transitions
205 )
207 # Store final state and path in metadata
208 metadata = context.metadata.copy() if context.metadata else {}
209 metadata['final_state'] = context.current_state
210 metadata['path'] = context.history if hasattr(context, 'history') else []
212 batch_result = BatchResult(
213 index=i,
214 success=success,
215 result=result,
216 processing_time=time.time() - start_time,
217 metadata=metadata
218 )
220 if success:
221 progress.succeeded += 1
222 else:
223 progress.failed += 1
225 except Exception as e:
226 batch_result = BatchResult(
227 index=i,
228 success=False,
229 result=None,
230 error=e,
231 processing_time=time.time() - start_time
232 )
233 progress.failed += 1
235 results.append(batch_result)
236 progress.completed += 1
238 # Fire progress callback
239 if self.progress_callback:
240 self.progress_callback(progress)
242 return results
244 def _execute_parallel(
245 self,
246 items: List[Any],
247 context_template: ExecutionContext,
248 max_transitions: int,
249 progress: BatchProgress
250 ) -> List[BatchResult]:
251 """Execute items in parallel.
253 Args:
254 items: Items to process.
255 context_template: Template context.
256 max_transitions: Maximum transitions.
257 progress: Progress tracker.
259 Returns:
260 List of results.
261 """
262 results = [None] * len(items)
264 with ThreadPoolExecutor(max_workers=self.parallelism) as executor:
265 # Submit all items
266 futures = {}
267 for i, item in enumerate(items):
268 future = executor.submit(
269 self._process_single_item,
270 i,
271 item,
272 context_template,
273 max_transitions
274 )
275 futures[future] = i
277 # Process completed items
278 for future in as_completed(futures):
279 index = futures[future]
281 try:
282 batch_result = future.result()
283 results[index] = batch_result # type: ignore
285 if batch_result.success:
286 progress.succeeded += 1
287 else:
288 progress.failed += 1
290 except Exception as e:
291 results[index] = BatchResult( # type: ignore
292 index=index,
293 success=False,
294 result=None,
295 error=e
296 )
297 progress.failed += 1
299 progress.completed += 1
301 # Fire progress callback
302 if self.progress_callback:
303 self.progress_callback(progress)
305 return results # type: ignore
307 def _process_single_item(
308 self,
309 index: int,
310 item: Any,
311 context_template: ExecutionContext,
312 max_transitions: int
313 ) -> BatchResult:
314 """Process a single item.
316 Args:
317 index: Item index.
318 item: Item to process.
319 context_template: Template context.
320 max_transitions: Maximum transitions.
322 Returns:
323 Batch result.
324 """
325 start_time = time.time()
327 # Create context for this item
328 context = context_template.clone()
329 # Convert Record to dict if needed
330 if hasattr(item, 'to_dict'):
331 context.data = item.to_dict()
332 elif hasattr(item, '__dict__'):
333 context.data = dict(item.__dict__)
334 else:
335 context.data = item
337 # Add batch tracking metadata
338 context.batch_id = index
339 context.metadata['batch_info'] = {
340 'batch_id': index,
341 'item_index': index,
342 'processing_mode': 'parallel',
343 'worker_thread': threading.current_thread().name
344 }
346 # Get resource from pool if available
347 if self.enable_resource_pooling:
348 self._acquire_resources(context)
350 try:
351 # Reset to initial state
352 initial_state = self._find_initial_state()
353 if initial_state:
354 context.set_state(initial_state)
356 # Execute
357 success, result = self.engine.execute(
358 context,
359 None, # Data is already in context
360 max_transitions
361 )
363 # Store final state and path in metadata
364 metadata = context.metadata.copy() if context.metadata else {}
365 metadata['final_state'] = context.current_state
366 metadata['path'] = context.history if hasattr(context, 'history') else []
368 return BatchResult(
369 index=index,
370 success=success,
371 result=result,
372 processing_time=time.time() - start_time,
373 metadata=metadata
374 )
376 except Exception as e:
377 return BatchResult(
378 index=index,
379 success=False,
380 result=None,
381 error=e,
382 processing_time=time.time() - start_time
383 )
385 finally:
386 # Release resources back to pool
387 if self.enable_resource_pooling:
388 self._release_resources(context)
390 def _acquire_resources(self, context: ExecutionContext) -> None:
391 """Acquire resources from pool for context.
393 Args:
394 context: Execution context.
395 """
396 # Initialize resource pools if needed
397 for resource_type, limit in context.resource_limits.items():
398 if resource_type not in self._resource_pool:
399 self._resource_pool[resource_type] = []
400 self._resource_locks[resource_type] = asyncio.Lock()
402 # Track batch-specific resource allocation
403 if hasattr(context, 'batch_id'):
404 context.metadata[f'batch_{context.batch_id}_resources'] = {
405 'resource_type': resource_type,
406 'limit': limit,
407 'acquired_at': context.metadata.get('start_time'),
408 'pool_size': len(self._resource_pool[resource_type])
409 }
411 def _release_resources(self, context: ExecutionContext) -> None:
412 """Release resources back to pool.
414 Args:
415 context: Execution context.
416 """
417 # Release allocated resources back to pool
418 for allocation in context.resources.values():
419 if allocation.status == 'allocated':
420 resource_type = allocation.resource_type
421 if resource_type in self._resource_pool:
422 self._resource_pool[resource_type].append(
423 allocation.resource_id
424 )
426 # Track batch-specific resource release
427 if hasattr(context, 'batch_id'):
428 batch_key = f'batch_{context.batch_id}_resources'
429 if batch_key in context.metadata:
430 context.metadata[batch_key]['released_at'] = context.metadata.get('end_time')
431 context.metadata[batch_key]['final_pool_size'] = len(self._resource_pool[resource_type])
433 # Mark as released
434 allocation.status = 'released'
436 def _find_initial_state(self) -> str | None:
437 """Find initial state in FSM.
439 Returns:
440 Initial state name or None.
441 """
442 # Get main network
443 if self.fsm.name in self.fsm.networks:
444 network = self.fsm.networks[self.fsm.name]
445 if network.initial_states:
446 return next(iter(network.initial_states))
447 return None
449 def execute_batches(
450 self,
451 items: List[Any],
452 context_template: ExecutionContext | None = None,
453 max_transitions: int = 1000
454 ) -> Dict[str, Any]:
455 """Execute items in batches.
457 Args:
458 items: All items to process.
459 context_template: Template context.
460 max_transitions: Maximum transitions.
462 Returns:
463 Aggregated results.
464 """
465 all_results = []
466 total_batches = (len(items) + self.batch_size - 1) // self.batch_size
468 for batch_num in range(total_batches):
469 start_idx = batch_num * self.batch_size
470 end_idx = min(start_idx + self.batch_size, len(items))
471 batch = items[start_idx:end_idx]
473 # Process batch
474 batch_results = self.execute_batch(
475 batch,
476 context_template,
477 max_transitions
478 )
480 all_results.extend(batch_results)
482 # Aggregate results
483 total = len(all_results)
484 succeeded = sum(1 for r in all_results if r.success)
485 failed = total - succeeded
487 total_time = sum(r.processing_time for r in all_results)
488 avg_time = total_time / total if total > 0 else 0
490 errors_by_type = {}
491 for result in all_results:
492 if result.error:
493 error_type = type(result.error).__name__
494 errors_by_type[error_type] = errors_by_type.get(error_type, 0) + 1
496 return {
497 'total': total,
498 'succeeded': succeeded,
499 'failed': failed,
500 'success_rate': succeeded / total if total > 0 else 0,
501 'total_processing_time': total_time,
502 'average_processing_time': avg_time,
503 'errors_by_type': errors_by_type,
504 'results': all_results
505 }
507 def create_benchmark(
508 self,
509 items: List[Any],
510 configurations: List[Dict[str, Any]]
511 ) -> Dict[str, Any]:
512 """Run performance benchmark with different configurations.
514 Args:
515 items: Items to process.
516 configurations: List of configuration dicts with:
517 - 'name': Configuration name
518 - 'parallelism': Parallelism level
519 - 'batch_size': Batch size
520 - 'strategy': Traversal strategy
522 Returns:
523 Benchmark results.
524 """
525 benchmark_results = {}
527 for config in configurations:
528 name = config.get('name', 'unnamed')
530 # Update executor settings
531 self.parallelism = config.get('parallelism', self.parallelism)
532 self.batch_size = config.get('batch_size', self.batch_size)
534 if 'strategy' in config:
535 self.engine.strategy = config['strategy']
537 # Run benchmark
538 start_time = time.time()
539 results = self.execute_batches(items)
540 elapsed_time = time.time() - start_time
542 # Calculate metrics
543 throughput = len(items) / elapsed_time if elapsed_time > 0 else 0
545 benchmark_results[name] = {
546 'configuration': config,
547 'elapsed_time': elapsed_time,
548 'throughput': throughput,
549 'success_rate': results['success_rate'],
550 'average_processing_time': results['average_processing_time']
551 }
553 return benchmark_results