Coverage for src/dataknobs_fsm/execution/async_batch.py: 15%
98 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Asynchronous batch executor for parallel processing."""
3import asyncio
4import time
5from typing import Any, Callable, Dict, List, Union
7from dataknobs_fsm.core.fsm import FSM
8from dataknobs_fsm.core.modes import ProcessingMode, TransactionMode
9from dataknobs_fsm.execution.batch import BatchResult, BatchProgress
10from dataknobs_fsm.execution.context import ExecutionContext
11from dataknobs_fsm.execution.engine import ExecutionEngine
class AsyncBatchExecutor:
    """Asynchronous executor for batch processing.

    This executor handles:
    - True async parallel execution (bounded by a semaphore)
    - Resource pooling
    - Progress reporting
    - Error recovery (per-item failures become error results, not crashes)
    - Transaction management
    """

    def __init__(
        self,
        fsm: FSM,
        parallelism: int = 10,
        batch_size: int = 100,
        enable_transactions: bool = False,
        progress_callback: Union[Callable, None] = None
    ):
        """Initialize async batch executor.

        Args:
            fsm: FSM to execute.
            parallelism: Maximum number of items processed concurrently.
            batch_size: Number of items per chunk in ``execute_batches()``.
            enable_transactions: Use PER_RECORD transactions when no
                context template is supplied to ``execute_batch()``.
            progress_callback: Sync or async callable invoked with a
                ``BatchProgress`` snapshot.
        """
        self.fsm = fsm
        self.parallelism = parallelism
        self.batch_size = batch_size
        self.enable_transactions = enable_transactions
        self.progress_callback = progress_callback

        # One shared engine; each item still runs with its own cloned context.
        self.engine = ExecutionEngine(fsm)

        # Semaphore caps the number of items executing at any one time.
        self._semaphore = asyncio.Semaphore(parallelism)

    async def execute_batch(
        self,
        items: List[Any],
        context_template: ExecutionContext | None = None,
        max_transitions: int = 1000
    ) -> List[BatchResult]:
        """Execute batch of items asynchronously.

        Args:
            items: Items to process.
            context_template: Template context to clone for each item.
                When omitted, a SINGLE-mode context is created (with
                PER_RECORD transactions if ``enable_transactions`` is set).
            max_transitions: Maximum transitions per item.

        Returns:
            List of batch results, one per item, in input order.
        """
        if not items:
            return []

        # Create progress tracker
        progress = BatchProgress(total=len(items))

        # Create base context if not provided
        if context_template is None:
            context_template = ExecutionContext(
                data_mode=ProcessingMode.SINGLE,
                transaction_mode=TransactionMode.PER_RECORD if self.enable_transactions else TransactionMode.NONE
            )

        # Fan out one task per item; _process_item throttles via the semaphore.
        tasks = [
            asyncio.create_task(
                self._process_item(i, item, context_template, max_transitions, progress)
            )
            for i, item in enumerate(items)
        ]

        # _process_item never raises (it converts failures into error
        # BatchResults), so return_exceptions=False is safe here.
        results = await asyncio.gather(*tasks, return_exceptions=False)

        # Fire final progress callback
        if self.progress_callback:
            await self._fire_progress_callback(progress)

        return results

    async def _process_item(
        self,
        index: int,
        item: Any,
        context_template: ExecutionContext,
        max_transitions: int,
        progress: BatchProgress
    ) -> BatchResult:
        """Process a single item asynchronously.

        Args:
            index: Item index within the batch.
            item: Item to process.
            context_template: Template context (cloned, never mutated).
            max_transitions: Maximum transitions.
            progress: Shared progress tracker (mutated in place).

        Returns:
            Batch result; failures are captured in ``result.error``
            rather than raised, so a single bad item cannot abort the
            whole ``asyncio.gather`` in ``execute_batch``.
        """
        async with self._semaphore:  # Control parallelism
            start_time = time.time()

            # The try covers the ENTIRE per-item body (including context
            # setup), so any failure yields a per-item error result instead
            # of propagating out of gather() and killing the batch.
            try:
                # Create context for this item
                context = context_template.clone()
                # Convert Record-like objects to plain dicts.
                if hasattr(item, 'to_dict'):
                    context.data = item.to_dict()
                elif hasattr(item, '__dict__'):
                    context.data = dict(item.__dict__)
                else:
                    context.data = item

                # Reset to initial state
                initial_state = self._find_initial_state()
                if initial_state:
                    context.set_state(initial_state)
                elif self.fsm.networks:
                    # Fallback: first state flagged is_start in the first network.
                    first_network = next(iter(self.fsm.networks.values()))
                    if hasattr(first_network, 'states') and first_network.states:
                        for state in first_network.states.values():
                            if getattr(state, 'is_start', False):
                                context.set_state(state.name)
                                break

                # Run the (synchronous) engine in the default thread pool so
                # it does not block the event loop. get_running_loop() is the
                # non-deprecated call inside a coroutine (3.10+).
                loop = asyncio.get_running_loop()
                success, result = await loop.run_in_executor(
                    None,
                    self.engine.execute,
                    context,
                    None,  # Data is already in context
                    max_transitions
                )

                # Store final state and path in metadata
                metadata = context.metadata.copy() if context.metadata else {}
                metadata['final_state'] = context.current_state
                metadata['path'] = context.history if hasattr(context, 'history') else []

                batch_result = BatchResult(
                    index=index,
                    success=success,
                    result=result,
                    processing_time=time.time() - start_time,
                    metadata=metadata
                )

                # Update progress
                progress.completed += 1
                if success:
                    progress.succeeded += 1
                else:
                    progress.failed += 1

                # Fire progress callback every 10th completion to limit overhead.
                if self.progress_callback and progress.completed % 10 == 0:
                    await self._fire_progress_callback(progress)

                return batch_result

            except Exception as e:
                progress.completed += 1
                progress.failed += 1

                return BatchResult(
                    index=index,
                    success=False,
                    result=None,
                    error=e,
                    processing_time=time.time() - start_time
                )

    async def execute_batches(
        self,
        items: List[Any],
        context_template: ExecutionContext | None = None,
        max_transitions: int = 1000
    ) -> Dict[str, Any]:
        """Execute items in multiple batches of ``self.batch_size``.

        Args:
            items: All items to process.
            context_template: Template context.
            max_transitions: Maximum transitions per item.

        Returns:
            Execution statistics dict with keys: 'total', 'successful',
            'failed', 'duration', 'throughput', 'results'.
        """
        all_results = []
        total_start = time.time()

        # Process sequential chunks; parallelism applies within each chunk.
        for i in range(0, len(items), self.batch_size):
            batch = items[i:i + self.batch_size]
            batch_results = await self.execute_batch(
                batch,
                context_template,
                max_transitions
            )
            all_results.extend(batch_results)

        # Calculate statistics
        total_time = time.time() - total_start
        successful = sum(1 for r in all_results if r.success)
        failed = sum(1 for r in all_results if not r.success)

        return {
            'total': len(all_results),
            'successful': successful,
            'failed': failed,
            'duration': total_time,
            'throughput': len(all_results) / total_time if total_time > 0 else 0,
            'results': all_results
        }

    def _find_initial_state(self) -> str | None:
        """Find initial state in FSM.

        Checks the main network first (via its ``initial_states`` attribute
        or ``get_initial_states()`` method), then falls back to scanning all
        networks for one with ``initial_states``.

        Returns:
            Initial state name or None if none could be found.
        """
        # Get main network
        main_network_name = getattr(self.fsm, 'main_network', None)
        if main_network_name and main_network_name in self.fsm.networks:
            network = self.fsm.networks[main_network_name]
            # Check for initial_states (set) or get_initial_states() method
            if hasattr(network, 'initial_states') and network.initial_states:
                return next(iter(network.initial_states))
            elif hasattr(network, 'get_initial_states'):
                initial_states = network.get_initial_states()
                if initial_states:
                    return next(iter(initial_states))

        # Fallback: check all networks
        for network in self.fsm.networks.values():
            if hasattr(network, 'initial_states') and network.initial_states:
                return next(iter(network.initial_states))

        return None

    async def _fire_progress_callback(self, progress: BatchProgress):
        """Fire progress callback asynchronously.

        Awaits async callbacks directly; runs sync callbacks in the
        default executor so they cannot block the event loop.

        Args:
            progress: Progress information.
        """
        if asyncio.iscoroutinefunction(self.progress_callback):
            await self.progress_callback(progress)
        else:
            # get_running_loop() is the non-deprecated call inside a coroutine.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.progress_callback, progress)  # type: ignore