Coverage for src/dataknobs_fsm/core/state.py: 60%
173 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Core state definitions and instances for FSM.
3This module provides:
4- StateDefinition: Blueprint for states with schema, functions, etc.
5- StateInstance: Runtime instance of a state with data
6"""
8from dataclasses import dataclass, field as dataclass_field
9from datetime import datetime
10from enum import Enum
11from typing import Any, Dict, List, Set, Tuple, TYPE_CHECKING
12from uuid import uuid4
14from dataknobs_data.fields import Field
15from dataknobs_fsm.core.data_modes import DataHandlingMode, DataModeManager
16from dataknobs_fsm.core.transactions import Transaction
17from dataknobs_fsm.functions.base import (
18 IValidationFunction,
19 ITransformFunction,
20 IEndStateTestFunction,
21 ResourceConfig,
22)
24if TYPE_CHECKING:
25 from dataknobs_fsm.core.arc import ArcDefinition
class StateType(Enum):
    """Role a state plays within the FSM.

    Members:
        NORMAL: ordinary processing state.
        START: entry point of the FSM.
        END: terminal state.
        START_END: entry point that is also terminal (for simple FSMs).
        ERROR: state dedicated to error handling.
        CHOICE: decision / branching state.
        WAIT: waiting / pause state.
        PARALLEL: parallel execution state.
    """

    NORMAL = "normal"
    START = "start"
    END = "end"
    START_END = "start_end"
    ERROR = "error"
    CHOICE = "choice"
    WAIT = "wait"
    PARALLEL = "parallel"
class StateStatus(Enum):
    """Lifecycle status of a state instance.

    Members:
        PENDING: created but not yet entered.
        ACTIVE: currently processing.
        COMPLETED: finished successfully.
        FAILED: finished with an error.
        SKIPPED: skipped during execution.
        PAUSED: execution paused.
    """

    PENDING = "pending"
    ACTIVE = "active"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"
    PAUSED = "paused"
@dataclass
class StateSchema:
    """Schema that state data can be checked against.

    Attributes:
        fields: Field definitions. A data key whose name matches one of these
            is type-checked against that field's ``type``.
        required_fields: Keys that must be present in the data.
        constraints: Additional constraint settings. Note: ``validate`` does
            not currently consult this mapping.
        allow_extra_fields: If False, keys without a matching field
            definition are reported as errors.
    """

    fields: List[Field]
    required_fields: Set[str] = dataclass_field(default_factory=set)
    constraints: Dict[str, Any] = dataclass_field(default_factory=dict)
    allow_extra_fields: bool = True

    def validate(self, data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Check ``data`` against this schema.

        Missing required keys are reported first, followed by one message per
        offending key of ``data`` (in ``data``'s iteration order).

        Args:
            data: Mapping of field name to value.

        Returns:
            ``(is_valid, messages)`` where ``is_valid`` is True exactly when
            ``messages`` is empty.
        """
        messages: List[str] = [
            f"Required field '{name}' is missing"
            for name in self.required_fields
            if name not in data
        ]

        by_name = {definition.name: definition for definition in self.fields}
        for key, value in data.items():
            if key not in by_name:
                # Keys without a definition are only an error in strict mode.
                if not self.allow_extra_fields:
                    messages.append(f"Unexpected field '{key}'")
                continue

            expected_type = by_name[key].type
            # Wrap the value in a Field of the expected type and let Field
            # decide whether the value conforms.
            if not Field(key, value, expected_type).validate():
                messages.append(
                    f"Field '{key}' has invalid type. "
                    f"Expected {expected_type}, got {type(value).__name__}"
                )

        return not messages, messages
@dataclass
class StateDefinition:
    """Definition (blueprint) of a state in the FSM.

    Holds the static configuration of a state: its type, optional data schema
    and preferred data mode, resource requirements, validation/transform
    functions, outgoing arcs, and execution settings. Runtime data lives in
    ``StateInstance``, which references a definition.
    """

    name: str
    type: StateType = StateType.NORMAL
    description: str = ""
    metadata: Dict[str, Any] = dataclass_field(default_factory=dict)

    # Schema and data handling
    schema: StateSchema | None = None
    data_mode: DataHandlingMode | None = None  # Preferred data mode

    # Resource requirements
    resource_requirements: List[ResourceConfig] = dataclass_field(default_factory=list)

    # Functions
    pre_validation_functions: List[IValidationFunction] = dataclass_field(default_factory=list)
    validation_functions: List[IValidationFunction] = dataclass_field(default_factory=list)
    transform_functions: List[ITransformFunction] = dataclass_field(default_factory=list)
    end_test_function: IEndStateTestFunction | None = None

    # Arc references (populated when the network is built)
    outgoing_arcs: List["ArcDefinition"] = dataclass_field(default_factory=list)

    # Execution settings
    timeout: float | None = None  # Timeout in seconds
    retry_count: int = 0  # Number of retries on failure
    retry_delay: float = 1.0  # Delay between retries in seconds

    def is_start_state(self) -> bool:
        """Check whether this state is an entry point.

        ``START_END`` states are both entry and terminal states, so they are
        start states too.

        Returns:
            True if the state type is START or START_END.
        """
        return self.type in (StateType.START, StateType.START_END)

    def is_end_state(self) -> bool:
        """Check whether this state is terminal.

        ``START_END`` states are both entry and terminal states, so they are
        end states too.

        Returns:
            True if the state type is END or START_END.
        """
        return self.type in (StateType.END, StateType.START_END)

    @property
    def is_start(self) -> bool:
        """Property alias for is_start_state()."""
        return self.is_start_state()

    @property
    def is_end(self) -> bool:
        """Property alias for is_end_state()."""
        return self.is_end_state()

    @property
    def arcs(self) -> List["ArcDefinition"]:
        """Outgoing arcs from this state (the live ``outgoing_arcs`` list)."""
        return self.outgoing_arcs

    def validate_data(self, data: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Validate data against the state's schema.

        Args:
            data: Data to validate.

        Returns:
            Tuple of (is_valid, error_messages). Without a schema, any data is
            accepted: ``(True, [])``.
        """
        if self.schema is None:
            return True, []
        return self.schema.validate(data)

    def add_pre_validation_function(self, func: IValidationFunction) -> None:
        """Append a pre-validation function.

        Args:
            func: Pre-validation function to add.
        """
        self.pre_validation_functions.append(func)

    def add_validation_function(self, func: IValidationFunction) -> None:
        """Append a validation function.

        Args:
            func: Validation function to add.
        """
        self.validation_functions.append(func)

    def add_transform_function(self, func: ITransformFunction) -> None:
        """Append a transform function.

        Args:
            func: Transform function to add.
        """
        self.transform_functions.append(func)

    def add_outgoing_arc(self, arc: "ArcDefinition") -> None:
        """Append an outgoing arc.

        Args:
            arc: Arc definition to add.
        """
        self.outgoing_arcs.append(arc)
@dataclass
class StateInstance:
    """Runtime instance of a state.

    Pairs an optional StateDefinition with per-execution information:
    - current status;
    - working data and the data handler chosen for it;
    - transaction and resource handles;
    - timing and counters;
    - the arcs executed from this state.
    """

    id: str = dataclass_field(default_factory=lambda: str(uuid4()))
    definition: StateDefinition | None = None
    status: StateStatus = StateStatus.PENDING

    # Data handling
    data: Dict[str, Any] = dataclass_field(default_factory=dict)
    data_mode_manager: DataModeManager | None = None
    data_handler: Any | None = None  # Handler obtained from data_mode_manager in enter()

    # Transaction participation
    transaction: Transaction | None = None

    # Resource access (name -> resource handle/connection)
    acquired_resources: Dict[str, Any] = dataclass_field(default_factory=dict)

    # Execution tracking
    entry_time: datetime | None = None
    exit_time: datetime | None = None
    execution_count: int = 0  # Incremented on every enter()
    error_count: int = 0  # Incremented on every fail()
    last_error: str | None = None

    # Arc execution history
    executed_arcs: List[str] = dataclass_field(default_factory=list)
    next_state: str | None = None

    def __post_init__(self):
        """Create a DataModeManager when none was supplied.

        The manager's default mode is the definition's ``data_mode`` if set,
        otherwise ``DataHandlingMode.COPY``.
        """
        if self.data_mode_manager is None:
            default_mode = DataHandlingMode.COPY
            if self.definition and self.definition.data_mode:
                default_mode = self.definition.data_mode
            self.data_mode_manager = DataModeManager(default_mode)

    def enter(self, input_data: Dict[str, Any]) -> None:
        """Enter the state with input data.

        Marks the instance ACTIVE, records the entry time, and bumps
        ``execution_count``. When a data mode manager is present, a handler
        is obtained for the definition's data mode (falling back to the
        manager's default), and it prepares ``self.data`` from the input.

        Args:
            input_data: Input data for the state.
        """
        self.status = StateStatus.ACTIVE
        self.entry_time = datetime.now()
        # An instance may be entered more than once (see execution_count);
        # clear the previous exit time so get_duration()/to_dict() do not
        # combine the new entry time with a stale exit time.
        self.exit_time = None
        self.execution_count += 1

        if self.data_mode_manager:
            if self.definition and self.definition.data_mode:
                mode = self.definition.data_mode
            else:
                mode = self.data_mode_manager.default_mode
            self.data_handler = self.data_mode_manager.get_handler(mode)
            self.data = self.data_handler.on_entry(input_data)
        else:
            self.data = input_data

    def exit(self, commit: bool = True) -> Dict[str, Any]:
        """Exit the state.

        Records the exit time and lets the active data handler (if any)
        finalize the data. An ACTIVE instance becomes COMPLETED; any other
        status (e.g. FAILED, PAUSED) is left unchanged.

        Args:
            commit: Whether to commit data changes (passed to the handler).

        Returns:
            The final state data.
        """
        self.exit_time = datetime.now()

        if self.data_handler:
            self.data = self.data_handler.on_exit(self.data, commit)

        if self.status == StateStatus.ACTIVE:
            self.status = StateStatus.COMPLETED

        return self.data

    def fail(self, error: str) -> None:
        """Mark the state as failed.

        Args:
            error: Error message, stored in ``last_error``.
        """
        self.status = StateStatus.FAILED
        self.error_count += 1
        self.last_error = error
        self.exit_time = datetime.now()

    def pause(self) -> None:
        """Pause execution; only has an effect when the state is ACTIVE."""
        if self.status == StateStatus.ACTIVE:
            self.status = StateStatus.PAUSED

    def resume(self) -> None:
        """Resume execution; only has an effect when the state is PAUSED."""
        if self.status == StateStatus.PAUSED:
            self.status = StateStatus.ACTIVE

    def skip(self) -> None:
        """Mark this state as skipped and record the exit time."""
        self.status = StateStatus.SKIPPED
        self.exit_time = datetime.now()

    def modify_data(self, updates: Dict[str, Any]) -> None:
        """Apply updates to the state data in place.

        When a data handler is active, it is notified of the modification and
        its return value becomes the new state data.

        Args:
            updates: Data updates to apply.
        """
        self.data.update(updates)
        if self.data_handler:
            self.data = self.data_handler.on_modification(self.data)

    def add_resource(self, name: str, resource: Any) -> None:
        """Register an acquired resource.

        Args:
            name: Resource name.
            resource: The resource handle/connection.
        """
        self.acquired_resources[name] = resource

    def get_resource(self, name: str) -> Any | None:
        """Look up an acquired resource.

        Args:
            name: Resource name.

        Returns:
            The resource, or None if no resource with that name was added.
        """
        return self.acquired_resources.get(name)

    def release_resources(self) -> None:
        """Forget all acquired resources.

        NOTE(review): this only clears the local mapping; it does not close
        or return the resources themselves. Presumably the owning resource
        manager does that — confirm.
        """
        self.acquired_resources.clear()

    def record_arc_execution(self, arc_id: str) -> None:
        """Record that an arc was executed.

        Args:
            arc_id: ID of the executed arc.
        """
        self.executed_arcs.append(arc_id)

    def get_duration(self) -> float | None:
        """Get execution duration in seconds.

        Returns:
            Seconds between entry and exit. If the state was entered but has
            not exited yet, seconds from entry until now. None if the state
            was never entered.
        """
        if self.entry_time and self.exit_time:
            return (self.exit_time - self.entry_time).total_seconds()
        if self.entry_time:
            return (datetime.now() - self.entry_time).total_seconds()
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary representation.

        ``executed_arcs`` is copied so the result does not alias the
        instance's history list; ``data`` is included as-is (not copied).

        Returns:
            Dictionary with state instance data.
        """
        return {
            "id": self.id,
            "name": self.definition.name if self.definition else None,
            "status": self.status.value,
            "data": self.data,
            "entry_time": self.entry_time.isoformat() if self.entry_time else None,
            "exit_time": self.exit_time.isoformat() if self.exit_time else None,
            "duration": self.get_duration(),
            "execution_count": self.execution_count,
            "error_count": self.error_count,
            "last_error": self.last_error,
            "executed_arcs": list(self.executed_arcs),
            "next_state": self.next_state,
        }
# Lightweight state representation used by state networks.
class State:
    """Minimal named state for use in state networks.

    Unlike ``StateDefinition`` it has no schema or functions. All extra
    keyword arguments are kept verbatim as ``metadata``.
    """

    def __init__(self, name: str, **kwargs):
        """Create a state.

        Args:
            name: State name.
            **kwargs: Arbitrary extra properties, stored as ``metadata``. A
                ``resource_requirements`` entry, if given, is also exposed
                as an attribute (an empty dict otherwise).
        """
        properties = kwargs
        self.name = name
        self.metadata = properties
        if "resource_requirements" in properties:
            self.resource_requirements = properties["resource_requirements"]
        else:
            self.resource_requirements = {}

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this state (shares the metadata dict)."""
        payload: Dict[str, Any] = {}
        payload["name"] = self.name
        payload["metadata"] = self.metadata
        payload["resource_requirements"] = self.resource_requirements
        return payload
# Kept for backwards compatibility.
class StateMode(Enum):
    """Operating mode of a state (retained for backwards compatibility).

    Members:
        NORMAL: default mode.
        PARALLEL: parallel operation.
        SEQUENTIAL: sequential operation.
    """

    NORMAL = "normal"
    PARALLEL = "parallel"
    SEQUENTIAL = "sequential"