Coverage for src/dataknobs_fsm/execution/base_engine.py: 56%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Base execution engine with shared logic for sync and async engines. 

2 

3This module provides a base class that contains common logic shared between 

4the synchronous (ExecutionEngine) and asynchronous (AsyncExecutionEngine) 

5implementations, reducing code duplication and ensuring feature parity. 

6""" 

7 

8from abc import ABC, abstractmethod 

9from typing import Any, Dict, List, Tuple, TYPE_CHECKING 

10from types import SimpleNamespace 

11 

12from dataknobs_fsm.core.data_wrapper import ( 

13 ensure_dict, 

14 wrap_for_lambda 

15) 

16from dataknobs_fsm.core.arc import ArcDefinition 

17from dataknobs_fsm.core.fsm import FSM 

18from dataknobs_fsm.core.network import StateNetwork 

19from dataknobs_fsm.core.state import StateType 

20from dataknobs_fsm.execution.context import ExecutionContext 

21from dataknobs_fsm.functions.base import FunctionContext 

22from dataknobs_fsm.execution.common import ( 

23 NetworkSelector, 

24 TransitionSelector, 

25 TransitionSelectionMode 

26) 

27 

28if TYPE_CHECKING: 

29 from dataknobs_fsm.execution.engine import TraversalStrategy 

30 

31 

class BaseExecutionEngine(ABC):
    """Base class for execution engines with shared logic.

    This class provides common functionality for both sync and async engines:
    - Initial state finding
    - Network selection
    - State transform preparation
    - Arc evaluation logic
    - Error handling patterns
    - Statistics tracking

    Subclasses must implement ``execute``, ``_execute_single``,
    ``_execute_batch`` and ``_execute_stream``.
    """

    def __init__(
        self,
        fsm: FSM,
        strategy: 'TraversalStrategy',
        selection_mode: TransitionSelectionMode = TransitionSelectionMode.HYBRID,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """Initialize base execution engine.

        Args:
            fsm: FSM instance to execute.
            strategy: Traversal strategy to use; also passed to the
                transition selector as its default strategy.
            selection_mode: Transition selection mode.
            max_retries: Maximum retry attempts for failures. Only stored
                here; presumably consumed by subclasses.
            retry_delay: Delay between retries in seconds. Only stored
                here; presumably consumed by subclasses.
        """
        self.fsm = fsm
        self.strategy = strategy
        self.selection_mode = selection_mode
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Transition selector configured with the same mode and strategy.
        self.transition_selector = TransitionSelector(
            mode=selection_mode,
            default_strategy=strategy
        )

        # Execution statistics. This base class only reads these (in
        # get_execution_statistics); presumably subclasses update them.
        self._execution_count = 0
        self._transition_count = 0
        self._error_count = 0
        self._total_execution_time = 0.0

    @staticmethod
    def _first_initial_state(network: Any) -> str | None:
        """Return the first element of ``network.initial_states``, or None.

        None is returned when ``network`` has no ``initial_states``
        attribute or when that attribute is empty/falsy.
        """
        initial_states = getattr(network, 'initial_states', None)
        if initial_states:
            return next(iter(initial_states))
        return None

    def find_initial_state_common(self) -> str | None:
        """Find the initial state in the FSM (common logic).

        Shared by both sync and async engines. Candidate networks are tried
        in this order, returning the first initial state found:

        1. ``fsm.main_network`` -- either a network name (looked up in
           ``fsm.networks``) or a network object.
        2. The network registered in ``fsm.networks`` under ``fsm.name``.
        3. Every network in ``fsm.networks``, in iteration order.

        Returns:
            Name of initial state or None if no network declares one.
        """
        networks = self.fsm.networks
        main_network = getattr(self.fsm, 'main_network', None)

        # 1. main_network may be a string reference or a network object.
        if isinstance(main_network, str):
            if main_network in networks:
                state = self._first_initial_state(networks[main_network])
                if state is not None:
                    return state
        elif main_network:
            state = self._first_initial_state(main_network)
            if state is not None:
                return state

        # 2. Fallback to fsm.name for compatibility.
        if hasattr(self.fsm, 'name') and self.fsm.name in networks:
            state = self._first_initial_state(networks[self.fsm.name])
            if state is not None:
                return state

        # 3. Last resort: any network with an initial state.
        for network in networks.values():
            state = self._first_initial_state(network)
            if state is not None:
                return state

        return None

    def is_final_state_common(self, state_name: str | None) -> bool:
        """Check if state is a final state (common logic).

        A state counts as final if, in any network, it is listed in
        ``final_states`` or its state definition has type ``StateType.END``.

        Args:
            state_name: Name of state to check.

        Returns:
            True if state is final; False for an empty/None name.
        """
        if not state_name:
            return False

        for network in self.fsm.networks.values():
            if hasattr(network, 'final_states') and state_name in network.final_states:
                return True
            # Also check the state definition's declared type.
            states = getattr(network, 'states', None)
            if states is not None and state_name in states:
                if getattr(states[state_name], 'type', None) == StateType.END:
                    return True

        return False

    def get_current_network_common(self, context: ExecutionContext) -> StateNetwork | None:
        """Get current network using common selection logic.

        Delegates to ``NetworkSelector.get_current_network`` with
        intelligent selection enabled.

        Args:
            context: Execution context.

        Returns:
            Current network or None.
        """
        return NetworkSelector.get_current_network(
            self.fsm,
            context,
            enable_intelligent_selection=True
        )

    def prepare_state_transform(
        self,
        state_def: Any,
        context: ExecutionContext
    ) -> Tuple[List[Any], SimpleNamespace]:
        """Prepare state transform execution (common logic).

        ``state_def.transform_functions`` (a collection) takes precedence
        over ``state_def.transform_function`` (a single callable).

        Args:
            state_def: State definition.
            context: Execution context.

        Returns:
            Tuple of (transform functions, state object for inline lambdas).
            The list is always a fresh list, so callers may modify it
            without affecting ``state_def``.
        """
        multiple = getattr(state_def, 'transform_functions', None)
        if multiple:
            # Copy so the caller never aliases the state definition's own list.
            transform_functions: List[Any] = list(multiple)
        else:
            single = getattr(state_def, 'transform_function', None)
            transform_functions = [single] if single else []

        # Wrapper for transforms that expect the state.data access pattern;
        # it provides both dict and attribute access.
        state_obj = wrap_for_lambda(context.data)

        return transform_functions, state_obj

    def process_transform_result(
        self,
        result: Any,
        context: ExecutionContext,
        state_name: str
    ) -> None:
        """Process transform result (common logic).

        A None result leaves ``context.data`` unchanged. A successful
        ``ExecutionResult`` or any other non-None value replaces
        ``context.data`` with its plain-dict form (via ``ensure_dict``).
        A failed ``ExecutionResult`` is routed to ``handle_transform_error``.

        Args:
            result: Result from transform function.
            context: Execution context.
            state_name: Name of current state.
        """
        if result is None:
            return

        # Handle ExecutionResult objects from unified function manager.
        from dataknobs_fsm.functions.base import ExecutionResult
        if isinstance(result, ExecutionResult):
            if result.success:
                # Ensure we store plain dict data.
                context.data = ensure_dict(result.data)
            else:
                # Transform failed - handle the error.
                self.handle_transform_error(
                    Exception(result.error or "Transform failed"),
                    context,
                    state_name
                )
        else:
            # Ensure we always store plain dict data, not wrappers.
            context.data = ensure_dict(result)

    def handle_transform_error(
        self,
        error: Exception,
        context: ExecutionContext,
        state_name: str
    ) -> None:
        """Handle transform error (common logic).

        State transforms failing doesn't stop the FSM, but marks the state
        as failed by adding ``state_name`` to ``context.failed_states``.
        That set is created if it is missing or None.

        Args:
            error: Exception that occurred (not inspected by this base
                implementation).
            context: Execution context.
            state_name: Name of current state.
        """
        failed_states = getattr(context, 'failed_states', None)
        if failed_states is None:
            # Covers both a missing attribute and one explicitly set to None.
            failed_states = set()
            context.failed_states = failed_states
        failed_states.add(state_name)

    def evaluate_arc_condition_common(
        self,
        arc: ArcDefinition,
        context: ExecutionContext
    ) -> bool:
        """Evaluate arc condition (common logic).

        The condition is first called as ``condition(data, func_context)``.
        On ``TypeError`` it is called again as ``condition(data)``.

        Args:
            arc: Arc definition.
            context: Execution context.

        Returns:
            True if the arc has no condition or its condition is truthy.
            False if the condition is falsy or raises any exception.
        """
        condition = getattr(arc, 'condition', None)
        # If arc has no condition, it's always valid.
        if not condition:
            return True

        try:
            func_context = FunctionContext(
                state_name=context.current_state or "",
                function_name="arc_condition",
                metadata={'arc': getattr(arc, 'name', None)},
                resources={}
            )

            try:
                # Try with data and context.
                return bool(condition(context.data, func_context))
            except TypeError:
                # Try with just data. NOTE(review): a TypeError raised
                # *inside* a two-argument condition also lands here and
                # re-invokes the condition with one argument.
                return bool(condition(context.data))
        except Exception:
            # Condition evaluation failed - arc is not valid.
            return False

    def get_execution_statistics(self) -> Dict[str, Any]:
        """Get execution statistics (common implementation).

        Returns:
            Dictionary of execution statistics.
            ``average_execution_time`` is always a float: 0.0 when no
            executions have been recorded.
        """
        count = self._execution_count
        total = self._total_execution_time
        return {
            'execution_count': count,
            'transition_count': self._transition_count,
            'error_count': self._error_count,
            'total_execution_time': total,
            'average_execution_time': total / count if count > 0 else 0.0,
        }

    @abstractmethod
    def execute(self, context: ExecutionContext, data: Any = None,
                max_transitions: int = 1000, arc_name: str | None = None) -> Tuple[bool, Any]:
        """Execute the FSM with given context.

        This method must be implemented by sync and async engines.

        Args:
            context: Execution context.
            data: Input data to process.
            max_transitions: Maximum transitions before stopping.
            arc_name: Optional specific arc name to follow.

        Returns:
            Tuple of (success, result).
        """

    @abstractmethod
    def _execute_single(self, context: ExecutionContext,
                        max_transitions: int, arc_name: str | None = None) -> Any:
        """Execute single mode processing.

        Must be implemented by subclasses.
        """

    @abstractmethod
    def _execute_batch(self, context: ExecutionContext, max_transitions: int) -> Any:
        """Execute batch mode processing.

        Must be implemented by subclasses.
        """

    @abstractmethod
    def _execute_stream(self, context: ExecutionContext, max_transitions: int) -> Any:
        """Execute stream mode processing.

        Must be implemented by subclasses.
        """