Coverage for src/dataknobs_fsm/execution/async_batch.py: 15%

98 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Asynchronous batch executor for parallel processing.""" 

2 

3import asyncio 

4import time 

5from typing import Any, Callable, Dict, List, Union 

6 

7from dataknobs_fsm.core.fsm import FSM 

8from dataknobs_fsm.core.modes import ProcessingMode, TransactionMode 

9from dataknobs_fsm.execution.batch import BatchResult, BatchProgress 

10from dataknobs_fsm.execution.context import ExecutionContext 

11from dataknobs_fsm.execution.engine import ExecutionEngine 

12 

13 

class AsyncBatchExecutor:
    """Asynchronous executor for batch processing.

    This executor handles:
    - True async parallel execution
    - Resource pooling
    - Progress reporting
    - Error recovery
    - Transaction management
    """

    def __init__(
        self,
        fsm: FSM,
        parallelism: int = 10,
        batch_size: int = 100,
        enable_transactions: bool = False,
        progress_callback: Callable | None = None
    ):
        """Initialize async batch executor.

        Args:
            fsm: FSM to execute.
            parallelism: Maximum number of items processed concurrently.
            batch_size: Size of each chunk used by ``execute_batches``.
            enable_transactions: Enable per-record transaction support.
            progress_callback: Callback for progress updates; may be a plain
                callable or a coroutine function (see ``_fire_progress_callback``).
        """
        self.fsm = fsm
        self.parallelism = parallelism
        self.batch_size = batch_size
        self.enable_transactions = enable_transactions
        self.progress_callback = progress_callback

        # Create execution engine. The engine's execute() is synchronous;
        # _process_item runs it in a thread pool so the event loop stays free.
        self.engine = ExecutionEngine(fsm)

        # Semaphore bounding the number of items in flight at any moment.
        self._semaphore = asyncio.Semaphore(parallelism)

    async def execute_batch(
        self,
        items: list[Any],
        context_template: ExecutionContext | None = None,
        max_transitions: int = 1000
    ) -> list[BatchResult]:
        """Execute batch of items asynchronously.

        Args:
            items: Items to process.
            context_template: Template context cloned for each item.
            max_transitions: Maximum transitions per item.

        Returns:
            List of batch results, one per item, in input order
            (``asyncio.gather`` preserves task order).
        """
        if not items:
            return []

        # Progress tracker shared by all item tasks.
        progress = BatchProgress(total=len(items))

        # Create base context if not provided.
        if context_template is None:
            context_template = ExecutionContext(
                data_mode=ProcessingMode.SINGLE,
                transaction_mode=TransactionMode.PER_RECORD if self.enable_transactions else TransactionMode.NONE
            )

        # Launch one task per item; self._semaphore caps actual concurrency.
        tasks = [
            asyncio.create_task(
                self._process_item(i, item, context_template, max_transitions, progress)
            )
            for i, item in enumerate(items)
        ]

        # _process_item converts failures into BatchResults and never raises,
        # so gathering without return_exceptions is safe here.
        results = await asyncio.gather(*tasks, return_exceptions=False)

        # Fire final progress callback.
        if self.progress_callback:
            await self._fire_progress_callback(progress)

        return results

    async def _process_item(
        self,
        index: int,
        item: Any,
        context_template: ExecutionContext,
        max_transitions: int,
        progress: BatchProgress
    ) -> BatchResult:
        """Process a single item asynchronously.

        Never raises: any exception from the engine is captured in the
        returned ``BatchResult.error``.

        Args:
            index: Item index within the batch.
            item: Item to process.
            context_template: Template context to clone.
            max_transitions: Maximum transitions.
            progress: Shared progress tracker.

        Returns:
            Batch result.
        """
        async with self._semaphore:  # Control parallelism
            # monotonic() measures elapsed time and is immune to wall-clock
            # adjustments (fix: previously used time.time()).
            start_time = time.monotonic()

            # Create context for this item.
            context = context_template.clone()
            # Convert Record to dict if needed.
            if hasattr(item, 'to_dict'):
                context.data = item.to_dict()
            elif hasattr(item, '__dict__'):
                context.data = dict(item.__dict__)
            else:
                context.data = item

            # Reset to initial state.
            initial_state = self._find_initial_state()
            if initial_state:
                context.set_state(initial_state)
            elif self.fsm.networks:
                # If no initial state found, fall back to the first network
                # and look for a state flagged as is_start.
                first_network = next(iter(self.fsm.networks.values()))
                if hasattr(first_network, 'states') and first_network.states:
                    for state in first_network.states.values():
                        if getattr(state, 'is_start', False):
                            context.set_state(state.name)
                            break

            try:
                # Execute in a thread pool to avoid blocking the event loop.
                # get_running_loop() replaces get_event_loop(), which is
                # deprecated inside coroutines since Python 3.10.
                loop = asyncio.get_running_loop()
                success, result = await loop.run_in_executor(
                    None,
                    self.engine.execute,
                    context,
                    None,  # Data is already in context
                    max_transitions
                )

                # Store final state and path in metadata.
                metadata = context.metadata.copy() if context.metadata else {}
                metadata['final_state'] = context.current_state
                metadata['path'] = context.history if hasattr(context, 'history') else []

                batch_result = BatchResult(
                    index=index,
                    success=success,
                    result=result,
                    processing_time=time.monotonic() - start_time,
                    metadata=metadata
                )
            except Exception as e:
                batch_result = BatchResult(
                    index=index,
                    success=False,
                    result=None,
                    error=e,
                    processing_time=time.monotonic() - start_time
                )

            # Update shared progress once, for both outcomes. (Fix: the
            # periodic callback previously fired only on the success path,
            # so failed items could skip the every-10-completions cadence.)
            progress.completed += 1
            if batch_result.success:
                progress.succeeded += 1
            else:
                progress.failed += 1

            if self.progress_callback and progress.completed % 10 == 0:
                await self._fire_progress_callback(progress)

            return batch_result

    async def execute_batches(
        self,
        items: list[Any],
        context_template: ExecutionContext | None = None,
        max_transitions: int = 1000
    ) -> dict[str, Any]:
        """Execute items in multiple batches.

        Args:
            items: All items to process.
            context_template: Template context.
            max_transitions: Maximum transitions.

        Returns:
            Execution statistics: total/successful/failed counts, duration
            (seconds), throughput (items/sec), and the full results list.
        """
        all_results: list[BatchResult] = []
        total_start = time.monotonic()

        # Process in chunks of self.batch_size.
        for i in range(0, len(items), self.batch_size):
            batch = items[i:i + self.batch_size]
            batch_results = await self.execute_batch(
                batch,
                context_template,
                max_transitions
            )
            all_results.extend(batch_results)

        # Calculate statistics.
        total_time = time.monotonic() - total_start
        successful = sum(1 for r in all_results if r.success)
        failed = sum(1 for r in all_results if not r.success)

        return {
            'total': len(all_results),
            'successful': successful,
            'failed': failed,
            'duration': total_time,
            'throughput': len(all_results) / total_time if total_time > 0 else 0,
            'results': all_results
        }

    def _find_initial_state(self) -> str | None:
        """Find initial state in FSM.

        Checks the main network first (its ``initial_states`` attribute, then
        its ``get_initial_states()`` method), then falls back to scanning all
        networks for a non-empty ``initial_states``.

        Returns:
            Initial state name or None.
        """
        # Get main network.
        main_network_name = getattr(self.fsm, 'main_network', None)
        if main_network_name and main_network_name in self.fsm.networks:
            network = self.fsm.networks[main_network_name]
            # Check for initial_states (set) or get_initial_states() method.
            if hasattr(network, 'initial_states') and network.initial_states:
                return next(iter(network.initial_states))
            elif hasattr(network, 'get_initial_states'):
                initial_states = network.get_initial_states()
                if initial_states:
                    return next(iter(initial_states))

        # Fallback: check all networks.
        for network in self.fsm.networks.values():
            if hasattr(network, 'initial_states') and network.initial_states:
                return next(iter(network.initial_states))

        return None

    async def _fire_progress_callback(self, progress: BatchProgress):
        """Fire progress callback asynchronously.

        Coroutine callbacks are awaited directly; plain callables run in the
        default thread pool so a slow callback cannot block the event loop.

        Args:
            progress: Progress information.
        """
        if asyncio.iscoroutinefunction(self.progress_callback):
            await self.progress_callback(progress)
        else:
            # Run sync callback in executor.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.progress_callback, progress)  # type: ignore