Coverage for src/dataknobs_fsm/execution/common.py: 24%

176 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Common utilities for sync and async execution engines. 

2 

3This module provides shared logic for both synchronous and asynchronous 

4execution engines, including network selection, arc scoring, and  

5transition selection strategies. 

6""" 

7 

8import math 

9from typing import Any, Dict, List, Optional, TYPE_CHECKING 

10from enum import Enum 

11 

12from dataknobs_fsm.core.arc import ArcDefinition 

13from dataknobs_fsm.core.fsm import FSM, StateNetwork 

14from dataknobs_fsm.core.modes import ProcessingMode 

15from dataknobs_fsm.execution.context import ExecutionContext 

16 

17if TYPE_CHECKING: 

18 from dataknobs_fsm.execution.engine import TraversalStrategy 

19 

20 

class TransitionSelectionMode(Enum):
    """Enumerates the ways a transition can be chosen among candidates."""

    # Defer to the TraversalStrategy (depth-first, breadth-first, etc.).
    STRATEGY_BASED = "strategy"

    # Rank the candidates with the multi-factor scoring.
    PRIORITY_SCORING = "scoring"

    # Combine both approaches.
    HYBRID = "hybrid"

26 

27 

class NetworkSelector:
    """Common network selection logic for execution engines."""

    @staticmethod
    def get_current_network(
        fsm: FSM,
        context: ExecutionContext,
        enable_intelligent_selection: bool = True
    ) -> StateNetwork | None:
        """Resolve the network that execution should currently use.

        Candidates are tried in the order below and the first hit wins:

        1. The network named by the top entry of ``context.network_stack``.
           If that name is not in ``fsm.networks`` a note is stored under
           ``context.metadata['network_selection_warning']`` and the search
           continues.
        2. The network named by ``context.metadata['current_network']``.
        3. The FSM's main network, given either as a network object (anything
           with a ``states`` attribute) or as a name in ``fsm.networks``.
        4. When ``enable_intelligent_selection`` is False: the first network
           in ``fsm.networks`` (or None) - steps 5-8 are skipped.
        5. The first network whose ``type`` equals
           ``context.metadata['preferred_network_type']``.
        6. The first network whose lower-cased name contains one of the
           patterns associated with ``context.data_mode``.
        7. The first network that has a truthy ``initial_state``.
        8. The first network in ``fsm.networks``.

        Args:
            fsm: The FSM instance (either a core FSM or a wrapper exposing
                ``core_fsm``).
            context: Execution context.
            enable_intelligent_selection: Whether to use the hint-based
                steps 5-8 instead of the plain first-network fallback.

        Returns:
            The selected network, or None when nothing matches.
        """
        # A context without (or with an unset) metadata mapping simply skips
        # every metadata-driven step instead of failing.
        metadata = getattr(context, 'metadata', None)

        # Step 1: a pushed network on the stack takes precedence.
        if context.network_stack:
            stacked_name = context.network_stack[-1][0]
            if stacked_name in fsm.networks:
                return fsm.networks[stacked_name]
            # Record the miss but keep searching.
            if metadata is not None:
                metadata['network_selection_warning'] = (
                    f"Network '{stacked_name}' not found in stack"
                )

        # Step 2: explicitly requested network.
        if metadata is not None and 'current_network' in metadata:
            requested_name = metadata['current_network']
            if isinstance(requested_name, str) and requested_name in fsm.networks:
                return fsm.networks[requested_name]

        # Step 3: main network. A wrapper FSM keeps it on ``core_fsm``; both
        # shapes are read defensively so a missing attribute means "no main
        # network" rather than an AttributeError.
        main_owner = fsm.core_fsm if hasattr(fsm, 'core_fsm') else fsm
        main_network_ref = getattr(main_owner, 'main_network', None)
        if main_network_ref:
            if hasattr(main_network_ref, 'states'):
                # Already a network object.
                return main_network_ref
            if isinstance(main_network_ref, str) and main_network_ref in fsm.networks:
                # A name that must be looked up.
                return fsm.networks[main_network_ref]

        # Step 4: plain fallback when the hint-based steps are disabled.
        if not enable_intelligent_selection:
            return next(iter(fsm.networks.values()), None)

        # Step 5: preferred network type hint.
        if metadata is not None:
            preferred_type = metadata.get('preferred_network_type')
            if preferred_type:
                for network in fsm.networks.values():
                    if hasattr(network, 'type') and network.type == preferred_type:
                        return network

        # Step 6: match network names against patterns for the data mode.
        if hasattr(context, 'data_mode'):
            mode_to_pattern = {
                ProcessingMode.BATCH: ['batch', 'parallel', 'bulk'],
                ProcessingMode.STREAM: ['stream', 'flow', 'pipeline'],
                ProcessingMode.SINGLE: ['main', 'default', 'single']
            }
            # Patterns are tried in list order, so an earlier pattern wins
            # over a later one regardless of network order.
            for pattern in mode_to_pattern.get(context.data_mode, []):
                for name, network in fsm.networks.items():
                    if pattern in name.lower():
                        return network

        # Step 7: first network that declares an initial state.
        for network in fsm.networks.values():
            if hasattr(network, 'initial_state') and network.initial_state:
                return network

        # Step 8: anything at all, or None when the FSM has no networks.
        return next(iter(fsm.networks.values()), None)

129 

130 

class ArcScorer:
    """Common arc scoring logic for transition selection."""

    @staticmethod
    def score_arc(
        arc: ArcDefinition,
        context: ExecutionContext,
        source_state: str,
        include_resource_check: bool = True,
        include_history: bool = True,
        include_load_balancing: bool = True
    ) -> float:
        """Score an arc based on multiple factors.

        Factors considered:
        1. Arc priority: ``priority * 1000``.
        2. Resource availability: +100 when every resource the arc lists is
           present in ``context.resources`` with status ``'available'``.
        3. Historical success rate: ``success_rate * 50`` when statistics
           for the arc exist in ``context.metadata``.
        4. Load balancing: ``log(usage + 1) * 10`` is subtracted for arcs
           that have already been used.
        5. Deterministic preference: +25 when ``arc.is_deterministic``.

        Statistics and usage counters are looked up by keys built from
        ``source_state`` and ``arc.target_state`` only, so two arcs between
        the same pair of states share the same entries.

        Args:
            arc: The arc to score.
            context: Execution context.
            source_state: Name of the state the arc leaves from; used to
                build the metadata keys.
            include_resource_check: Whether to include resource availability.
            include_history: Whether to include historical success rate.
            include_load_balancing: Whether to include load balancing.

        Returns:
            Numeric score for the arc.
        """
        score = 0.0
        # Scoring only reads metadata, so a missing mapping is treated as empty.
        metadata = getattr(context, 'metadata', None) or {}

        # Factor 1: explicit priority dominates every other factor.
        score += getattr(arc, 'priority', 0) * 1000

        # Factor 2: bonus only when *all* listed resources are available.
        # NOTE(review): the arc attribute read here is `resources`; confirm
        # that is the attribute ArcDefinition actually exposes.
        required = getattr(arc, 'resources', None)
        if include_resource_check and required:
            # A context without a resource table cannot satisfy the
            # requirement; that means "no bonus", not an AttributeError.
            pool = getattr(context, 'resources', None) or {}
            resources_available = all(
                res in pool and getattr(pool[res], 'status', None) == 'available'
                for res in required
            )
            if resources_available:
                score += 100

        # Factor 3: historical success rate, defaulting to 0.5 when the
        # stats entry exists but carries no rate.
        if include_history:
            arc_key = f"arc_{source_state}_{arc.target_state}_stats"
            if arc_key in metadata:
                success_rate = metadata[arc_key].get('success_rate', 0.5)
                score += success_rate * 50

        # Factor 4: logarithmic penalty for frequently used arcs.
        if include_load_balancing:
            usage_key = f"arc_{source_state}_{arc.target_state}_usage"
            usage_count = metadata.get(usage_key, 0)
            # Only positive counts are penalised: zero contributes nothing
            # and a corrupt negative counter would otherwise hit the log
            # domain error (or turn the penalty into a bonus).
            if usage_count > 0:
                score -= math.log(usage_count + 1) * 10

        # Factor 5: prefer deterministic arcs over non-deterministic ones.
        if getattr(arc, 'is_deterministic', False):
            score += 25

        return score

    @staticmethod
    def update_arc_usage(arc: ArcDefinition, context: ExecutionContext, source_state: str) -> None:
        """Increment the usage counter for an arc in ``context.metadata``.

        The counter is stored under ``arc_<source_state>_<target_state>_usage``
        and starts at zero when absent.

        Args:
            arc: The arc that was used.
            context: Execution context whose metadata holds the counter.
            source_state: The source state of the arc.
        """
        usage_key = f"arc_{source_state}_{arc.target_state}_usage"
        context.metadata[usage_key] = context.metadata.get(usage_key, 0) + 1

212 

213 

class TransitionSelector:
    """Common transition selection logic for execution engines."""

    def __init__(
        self,
        mode: TransitionSelectionMode = TransitionSelectionMode.HYBRID,
        default_strategy: Optional['TraversalStrategy'] = None
    ):
        """Initialize transition selector.

        Args:
            mode: Selection mode to use.
            default_strategy: Traversal strategy used when a call to
                ``select_transition`` does not supply one.
        """
        self.mode = mode
        self.default_strategy = default_strategy

    def select_transition(
        self,
        available: List[ArcDefinition],
        context: ExecutionContext,
        strategy: Optional['TraversalStrategy'] = None
    ) -> ArcDefinition | None:
        """Select which transition to take from available options.

        A single candidate is returned directly. Otherwise the choice is
        delegated according to ``self.mode``: strategy-based, scoring-based,
        or (hybrid) strategy-based when a strategy is known and
        scoring-based when it is not. Whatever arc is chosen has its usage
        counter incremented in ``context.metadata``.

        Args:
            available: Available transitions.
            context: Execution context.
            strategy: Traversal strategy to use; falls back to the
                selector's ``default_strategy`` when None.

        Returns:
            Selected arc, or None when ``available`` is empty.
        """
        if not available:
            return None

        # Nothing to choose between: take the only arc, but still count it.
        if len(available) == 1:
            only_arc = available[0]
            ArcScorer.update_arc_usage(only_arc, context, context.current_state or "")
            return only_arc

        # Explicit None test: "no strategy given" must not be confused with
        # a strategy value that happens to be falsy.
        effective_strategy = strategy if strategy is not None else self.default_strategy

        if self.mode == TransitionSelectionMode.STRATEGY_BASED:
            return self._select_by_strategy(available, context, effective_strategy)
        if self.mode == TransitionSelectionMode.PRIORITY_SCORING:
            return self._select_by_scoring(available, context)

        # HYBRID: use the strategy if one is known, otherwise score.
        if effective_strategy is not None:
            return self._select_by_strategy(available, context, effective_strategy)
        return self._select_by_scoring(available, context)

    def _select_by_strategy(
        self,
        available: List[ArcDefinition],
        context: ExecutionContext,
        strategy: Optional['TraversalStrategy']
    ) -> ArcDefinition | None:
        """Select transition based on traversal strategy.

        Args:
            available: Available transitions.
            context: Execution context.
            strategy: Traversal strategy; None or an unrecognised value
                selects the first available arc.

        Returns:
            Selected arc, or None when ``available`` is empty.
        """
        # Imported at call time: at module level this name is only imported
        # under TYPE_CHECKING (presumably to avoid a circular import with
        # the engine module - confirm before moving it).
        from dataknobs_fsm.execution.engine import TraversalStrategy

        if not available:
            return None

        first = available[0]

        if strategy == TraversalStrategy.BREADTH_FIRST:
            # Prefer the first arc leading to a state not yet in the history.
            visited = context.state_history
            selected = next(
                (arc for arc in available if arc.target_state not in visited),
                first,
            )
        elif strategy == TraversalStrategy.RESOURCE_OPTIMIZED:
            # Fewest resource requirements wins; min() keeps the earliest
            # arc on ties. A missing or None attribute counts as zero.
            selected = min(
                available,
                key=lambda arc: len(getattr(arc, 'resource_requirements', None) or ()),
            )
        elif strategy == TraversalStrategy.STREAM_OPTIMIZED:
            # Prefer the first arc that advertises streaming support.
            selected = next(
                (arc for arc in available if getattr(arc, 'supports_streaming', False)),
                first,
            )
        else:
            # DEPTH_FIRST, None, or any other strategy: take the first arc.
            # NOTE(review): this assumes the caller passes `available`
            # already ordered by priority - confirm against the engines.
            selected = first

        ArcScorer.update_arc_usage(selected, context, context.current_state or "")
        return selected

    def _select_by_scoring(
        self,
        available: List[ArcDefinition],
        context: ExecutionContext
    ) -> ArcDefinition | None:
        """Select transition based on multi-factor scoring.

        Arcs are scored with ``ArcScorer.score_arc``. If several arcs are
        within 0.01 of the best score, a per-state round-robin index stored
        in ``context.metadata`` rotates between them.

        Args:
            available: Available transitions.
            context: Execution context.

        Returns:
            Selected arc, or None when ``available`` is empty.
        """
        if not available:
            return None

        state_name = context.current_state or ""

        # Highest score first; the sort is stable, so equally scored arcs
        # keep their order from `available`.
        arc_scores = [
            (arc, ArcScorer.score_arc(arc, context, state_name))
            for arc in available
        ]
        arc_scores.sort(key=lambda pair: pair[1], reverse=True)

        top_score = arc_scores[0][1]
        tied_arcs = [arc for arc, score in arc_scores if abs(score - top_score) < 0.01]

        if len(tied_arcs) > 1:
            # Rotate through the tied arcs across successive calls.
            round_robin_key = f"state_{state_name}_round_robin"
            current_index = context.metadata.get(round_robin_key, 0)
            selected_arc = tied_arcs[current_index % len(tied_arcs)]
            context.metadata[round_robin_key] = current_index + 1
        else:
            selected_arc = arc_scores[0][0]

        ArcScorer.update_arc_usage(selected_arc, context, state_name)
        return selected_arc

380 

381 

def extract_metrics_from_context(context: ExecutionContext) -> Dict[str, Any]:
    """Extract performance metrics from execution context.

    Reads ``context.metadata`` and groups selected entries:

    * ``arc_usage``: values of keys shaped ``arc_<name>_usage``, keyed by
      ``<name>``.
    * ``warnings``: a one-element list holding the
      ``network_selection_warning`` entry, when present.
    * ``batch``: the ``batch_info`` entry, when present.
    * ``timing``: every entry whose key contains ``execution_time`` or
      ``acquisition_time``, under its original key.

    A group is omitted entirely when it would be empty.

    Args:
        context: Execution context.

    Returns:
        Dictionary of metrics (possibly empty).
    """
    metadata = context.metadata
    metrics: Dict[str, Any] = {}

    # Strip exactly one leading "arc_" and one trailing "_usage". A blanket
    # str.replace would also eat those substrings inside the state names
    # (e.g. "arc_disk_usage_check_usage" must yield "disk_usage_check").
    arc_metrics = {
        key.removeprefix('arc_').removesuffix('_usage'): value
        for key, value in metadata.items()
        if isinstance(key, str) and key.startswith('arc_') and key.endswith('_usage')
    }
    if arc_metrics:
        metrics['arc_usage'] = arc_metrics

    if 'network_selection_warning' in metadata:
        metrics['warnings'] = [metadata['network_selection_warning']]

    if 'batch_info' in metadata:
        metrics['batch'] = metadata['batch_info']

    # Non-string keys cannot carry a timing label, so they are skipped
    # rather than allowed to raise on the substring test.
    exec_times = {
        key: value
        for key, value in metadata.items()
        if isinstance(key, str)
        and ('execution_time' in key or 'acquisition_time' in key)
    }
    if exec_times:
        metrics['timing'] = exec_times

    return metrics