Coverage for src/dataknobs_fsm/core/state.py: 60%

173 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Core state definitions and instances for FSM. 

2 

3This module provides: 

4- StateDefinition: Blueprint for states with schema, functions, etc. 

5- StateInstance: Runtime instance of a state with data 

6""" 

7 

8from dataclasses import dataclass, field as dataclass_field 

9from datetime import datetime 

10from enum import Enum 

11from typing import Any, Dict, List, Set, Tuple, TYPE_CHECKING 

12from uuid import uuid4 

13 

14from dataknobs_data.fields import Field 

15from dataknobs_fsm.core.data_modes import DataHandlingMode, DataModeManager 

16from dataknobs_fsm.core.transactions import Transaction 

17from dataknobs_fsm.functions.base import ( 

18 IValidationFunction, 

19 ITransformFunction, 

20 IEndStateTestFunction, 

21 ResourceConfig, 

22) 

23 

24if TYPE_CHECKING: 

25 from dataknobs_fsm.core.arc import ArcDefinition 

26 

27 

class StateType(Enum):
    """Classifies the role a state plays within an FSM.

    Each member's value is the lowercase form of its name.
    """

    # Ordinary, intermediate processing state.
    NORMAL = "normal"
    # Entry point of the FSM.
    START = "start"
    # Terminal state.
    END = "end"
    # Entry point and terminal state at once (for simple FSMs).
    START_END = "start_end"
    # State dedicated to error handling.
    ERROR = "error"
    # Decision / branching point.
    CHOICE = "choice"
    # Waiting / paused state.
    WAIT = "wait"
    # Parallel execution state.
    PARALLEL = "parallel"

39 

40 

class StateStatus(Enum):
    """Lifecycle status of a state instance."""

    # Not entered yet.
    PENDING = "pending"
    # Currently being processed.
    ACTIVE = "active"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
    # Bypassed during execution.
    SKIPPED = "skipped"
    # Execution is on hold.
    PAUSED = "paused"

50 

51 

@dataclass
class StateSchema:
    """Schema definition for state data.

    Attributes (documented here because dataclass fields take no docstrings):
        fields: Field definitions. The ``name`` and ``type`` of each are used
            to type-check the matching key of validated data.
        required_fields: Keys that must be present in validated data.
        constraints: Extra constraints. NOTE(review): ``validate`` does not
            read this — confirm whether it is enforced elsewhere.
        allow_extra_fields: When False, keys with no entry in ``fields``
            are reported as errors.
    """

    fields: List[Field]
    required_fields: Set[str] = dataclass_field(default_factory=set)
    constraints: Dict[str, Any] = dataclass_field(default_factory=dict)
    allow_extra_fields: bool = True

    def validate(self, data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Validate data against schema.

        Missing required fields are reported first, in sorted name order.
        Then each key of ``data`` is visited in the dict's iteration order:
        keys with a field definition are type-checked, and keys without one
        are reported only when ``allow_extra_fields`` is False.

        Args:
            data: Data to validate.

        Returns:
            Tuple of (is_valid, error_messages); ``is_valid`` is True exactly
            when ``error_messages`` is empty.
        """
        # Sorted: iterating a set of str follows hash order, which changes
        # between interpreter runs (hash randomization), so the unsorted
        # error list was not reproducible.
        errors: List[str] = [
            f"Required field '{field_name}' is missing"
            for field_name in sorted(self.required_fields)
            if field_name not in data
        ]

        # Check field types
        field_map = {f.name: f for f in self.fields}
        for field_name, value in data.items():
            if field_name in field_map:
                field_def = field_map[field_name]
                # Wrap the value in a Field of the declared type and ask it to
                # validate. Presumably Field.validate() checks the value
                # against that type — the error message below assumes so.
                test_field = Field(field_name, value, field_def.type)
                if not test_field.validate():
                    errors.append(
                        f"Field '{field_name}' has invalid type. "
                        f"Expected {field_def.type}, got {type(value).__name__}"
                    )
            elif not self.allow_extra_fields:
                errors.append(f"Unexpected field '{field_name}'")

        return len(errors) == 0, errors

92 

93 

@dataclass
class StateDefinition:
    """Definition of a state in the FSM.

    A static blueprint holding the state's type, optional data schema,
    preferred data-handling mode, resource requirements, attached
    validation/transform functions, outgoing arcs and execution settings.
    Per-run data lives in ``StateInstance``.
    """

    name: str
    type: StateType = StateType.NORMAL
    description: str = ""
    metadata: Dict[str, Any] = dataclass_field(default_factory=dict)

    # Schema and data handling
    schema: StateSchema | None = None  # None: validate_data() accepts anything
    data_mode: DataHandlingMode | None = None  # Preferred data mode

    # Resource requirements
    resource_requirements: List[ResourceConfig] = dataclass_field(default_factory=list)

    # Functions
    pre_validation_functions: List[IValidationFunction] = dataclass_field(default_factory=list)
    validation_functions: List[IValidationFunction] = dataclass_field(default_factory=list)
    transform_functions: List[ITransformFunction] = dataclass_field(default_factory=list)
    end_test_function: IEndStateTestFunction | None = None

    # Arc references (will be populated when building network)
    outgoing_arcs: List["ArcDefinition"] = dataclass_field(default_factory=list)

    # Execution settings
    timeout: float | None = None  # Timeout in seconds
    retry_count: int = 0  # Number of retries on failure
    retry_delay: float = 1.0  # Delay between retries in seconds

    def is_start_state(self) -> bool:
        """Check if this is a start state.

        Returns:
            True for ``StateType.START`` and also for ``StateType.START_END``,
            which is both an entry and a terminal state.
        """
        return self.type in (StateType.START, StateType.START_END)

    def is_end_state(self) -> bool:
        """Check if this is an end state.

        Returns:
            True for ``StateType.END`` and also for ``StateType.START_END``,
            which is both an entry and a terminal state.
        """
        return self.type in (StateType.END, StateType.START_END)

    @property
    def is_start(self) -> bool:
        """Property alias for is_start_state()."""
        return self.is_start_state()

    @property
    def is_end(self) -> bool:
        """Property alias for is_end_state()."""
        return self.is_end_state()

    @property
    def arcs(self) -> List["ArcDefinition"]:
        """Get the outgoing arcs from this state (the live list, not a copy)."""
        return self.outgoing_arcs

    def validate_data(self, data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Validate data against state schema.

        Args:
            data: Data to validate.

        Returns:
            Tuple of (is_valid, error_messages). Always ``(True, [])`` when
            no schema is set.
        """
        if self.schema is None:
            return True, []
        return self.schema.validate(data)

    def add_pre_validation_function(self, func: IValidationFunction) -> None:
        """Add a pre-validation function.

        Args:
            func: Pre-validation function to add.
        """
        self.pre_validation_functions.append(func)

    def add_validation_function(self, func: IValidationFunction) -> None:
        """Add a validation function.

        Args:
            func: Validation function to add.
        """
        self.validation_functions.append(func)

    def add_transform_function(self, func: ITransformFunction) -> None:
        """Add a transform function.

        Args:
            func: Transform function to add.
        """
        self.transform_functions.append(func)

    def add_outgoing_arc(self, arc: "ArcDefinition") -> None:
        """Add an outgoing arc.

        Args:
            arc: Arc definition to add.
        """
        self.outgoing_arcs.append(arc)

199 

200 

@dataclass
class StateInstance:
    """Runtime instance of a state.

    Tracks status, working data, the active data handler, acquired
    resources, entry/exit timing, error counts and executed arcs for a
    ``StateDefinition``.
    """

    id: str = dataclass_field(default_factory=lambda: str(uuid4()))
    # Optional: the default is None, so the annotation must allow it.
    definition: StateDefinition | None = None
    status: StateStatus = StateStatus.PENDING

    # Data handling
    data: Dict[str, Any] = dataclass_field(default_factory=dict)
    data_mode_manager: DataModeManager | None = None
    data_handler: Any | None = None  # Active data handler, set by enter()

    # Transaction participation
    transaction: Transaction | None = None

    # Resource access (resource name -> handle)
    acquired_resources: Dict[str, Any] = dataclass_field(default_factory=dict)

    # Execution tracking
    entry_time: datetime | None = None
    exit_time: datetime | None = None
    execution_count: int = 0
    error_count: int = 0
    last_error: str | None = None

    # Arc execution history
    executed_arcs: List[str] = dataclass_field(default_factory=list)
    next_state: str | None = None

    def _preferred_data_mode(self) -> DataHandlingMode | None:
        """Return the definition's preferred data mode, or None if there is none."""
        # Compare against None explicitly rather than relying on truthiness,
        # so that a falsy enum member (e.g. an IntEnum 0) is still honored.
        if self.definition is not None and self.definition.data_mode is not None:
            return self.definition.data_mode
        return None

    def __post_init__(self):
        """Initialize data mode manager if not provided.

        Its default mode is the definition's ``data_mode`` when set,
        otherwise ``DataHandlingMode.COPY``.
        """
        if self.data_mode_manager is None:
            preferred = self._preferred_data_mode()
            self.data_mode_manager = DataModeManager(
                preferred if preferred is not None else DataHandlingMode.COPY
            )

    def enter(self, input_data: Dict[str, Any]) -> None:
        """Enter the state with input data.

        Marks the instance ACTIVE, records the entry time and increments
        ``execution_count``. With a data mode manager, the handler for the
        resolved mode (the definition's preference, else the manager's
        ``default_mode``) becomes ``data_handler``, and its ``on_entry``
        result becomes ``data``. Without one, ``input_data`` is stored as-is
        (not copied).

        Args:
            input_data: Input data for the state.
        """
        self.status = StateStatus.ACTIVE
        self.entry_time = datetime.now()
        self.execution_count += 1

        if self.data_mode_manager is None:
            self.data = input_data
            return

        mode = self._preferred_data_mode()
        if mode is None:
            mode = self.data_mode_manager.default_mode
        self.data_handler = self.data_mode_manager.get_handler(mode)
        self.data = self.data_handler.on_entry(input_data)

    def exit(self, commit: bool = True) -> Dict[str, Any]:
        """Exit the state.

        Records the exit time and, if a data handler is active, replaces
        ``data`` with the handler's ``on_exit`` result. Only an ACTIVE
        instance moves to COMPLETED; other statuses are left unchanged.
        NOTE(review): ``data_handler`` is not cleared, so calling exit()
        twice invokes ``on_exit`` twice.

        Args:
            commit: Whether to commit data changes (forwarded to the handler).

        Returns:
            The final state data.
        """
        self.exit_time = datetime.now()

        if self.data_handler is not None:
            self.data = self.data_handler.on_exit(self.data, commit)

        if self.status == StateStatus.ACTIVE:
            self.status = StateStatus.COMPLETED

        return self.data

    def fail(self, error: str) -> None:
        """Mark the state as failed.

        Args:
            error: Error message, stored in ``last_error``.
        """
        self.status = StateStatus.FAILED
        self.error_count += 1
        self.last_error = error
        self.exit_time = datetime.now()

    def pause(self) -> None:
        """Pause state execution (no-op unless ACTIVE)."""
        if self.status == StateStatus.ACTIVE:
            self.status = StateStatus.PAUSED

    def resume(self) -> None:
        """Resume paused state execution (no-op unless PAUSED)."""
        if self.status == StateStatus.PAUSED:
            self.status = StateStatus.ACTIVE

    def skip(self) -> None:
        """Skip this state and record the exit time."""
        self.status = StateStatus.SKIPPED
        self.exit_time = datetime.now()

    def modify_data(self, updates: Dict[str, Any]) -> None:
        """Modify state data.

        Applies ``updates`` in place. If a data handler is active, its
        ``on_modification`` result then replaces ``data``.

        Args:
            updates: Data updates to apply.
        """
        self.data.update(updates)
        if self.data_handler is not None:
            # Let the data handler manage modifications
            self.data = self.data_handler.on_modification(self.data)

    def add_resource(self, name: str, resource: Any) -> None:
        """Add an acquired resource.

        Args:
            name: Resource name.
            resource: The resource handle/connection.
        """
        self.acquired_resources[name] = resource

    def get_resource(self, name: str) -> Any | None:
        """Get an acquired resource.

        Args:
            name: Resource name.

        Returns:
            The resource if available, else None.
        """
        return self.acquired_resources.get(name)

    def release_resources(self) -> None:
        """Release all acquired resources.

        Only forgets the references; no close/release call is made on them.
        """
        self.acquired_resources.clear()

    def record_arc_execution(self, arc_id: str) -> None:
        """Record that an arc was executed.

        Args:
            arc_id: ID of the executed arc.
        """
        self.executed_arcs.append(arc_id)

    def get_duration(self) -> float | None:
        """Get execution duration in seconds.

        Returns:
            Exit minus entry time once exited; elapsed time until now while
            still in progress; None if never entered.
        """
        if self.entry_time and self.exit_time:
            return (self.exit_time - self.entry_time).total_seconds()
        elif self.entry_time:
            return (datetime.now() - self.entry_time).total_seconds()
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation.

        Returns:
            Dictionary with state instance data. ``data`` and
            ``executed_arcs`` are the live objects, not copies.
        """
        return {
            "id": self.id,
            "name": self.definition.name if self.definition else None,
            "status": self.status.value,
            "data": self.data,
            "entry_time": self.entry_time.isoformat() if self.entry_time else None,
            "exit_time": self.exit_time.isoformat() if self.exit_time else None,
            "duration": self.get_duration(),
            "execution_count": self.execution_count,
            "error_count": self.error_count,
            "last_error": self.last_error,
            "executed_arcs": self.executed_arcs,
            "next_state": self.next_state,
        }

381 

382 

class State:
    """Simplified, name-keyed state used when assembling state networks.

    Every keyword argument is retained verbatim in ``metadata``.
    ``resource_requirements`` is also surfaced as its own attribute and
    defaults to an empty dict when not supplied.
    """

    def __init__(self, name: str, **kwargs):
        """Create a state.

        Args:
            name: State name.
            **kwargs: Extra state properties, kept as ``metadata``.
        """
        self.name = name
        self.metadata = kwargs
        requirements = kwargs.get("resource_requirements", {})
        self.resource_requirements = requirements

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the name, metadata and requirements."""
        payload: Dict[str, Any] = {"name": self.name}
        payload["metadata"] = self.metadata
        payload["resource_requirements"] = self.resource_requirements
        return payload

405 

406 

class StateMode(Enum):
    """Mode of state operation.

    Retained for backwards compatibility.
    """

    # Default, single-path operation.
    NORMAL = "normal"
    # Parallel operation.
    PARALLEL = "parallel"
    # Sequential operation.
    SEQUENTIAL = "sequential"