Coverage for src/dataknobs_fsm/config/loader.py: 30%
281 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:51 -0600
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:51 -0600
1"""Configuration loader for FSM configurations.
3This module provides functionality to load FSM configurations from various sources:
4- Files (JSON, YAML)
5- Dictionaries
6- Templates
7- Environment variables
8"""
10import json
11import os
12from pathlib import Path
13from typing import Any, Dict, List, Set, Union
15import yaml
16from dataknobs_config import Config as DataknobsConfig
18from dataknobs_fsm.config.schema import (
19 FSMConfig,
20 TemplateConfig,
21 UseCaseTemplate,
22 apply_template,
23 validate_config,
24)
27class ConfigLoader:
28 """Load and process FSM configurations from various sources."""
30 def __init__(self, use_dataknobs_config: bool = False):
31 """Initialize the ConfigLoader.
33 Args:
34 use_dataknobs_config: Whether to use dataknobs_config for advanced features.
35 """
36 self.use_dataknobs_config = use_dataknobs_config
37 self._env_prefix = "FSM_"
38 self._included_configs: Dict[str, Dict[str, Any]] = {}
39 self._registered_functions: Set[str] = set()
41 def add_registered_function(self, name: str) -> None:
42 """Add a function name to the set of registered functions.
44 Args:
45 name: Function name that has been registered.
46 """
47 self._registered_functions.add(name)
49 def _convert_to_function_reference(self, value: Any) -> Dict[str, Any]:
50 """Convert a value to a function reference dictionary.
52 Args:
53 value: The value to convert (string, dict, etc.)
55 Returns:
56 Function reference dictionary with 'type' and appropriate fields.
57 """
58 if isinstance(value, dict):
59 # Already a function reference
60 return value
61 elif isinstance(value, str):
62 # Check if it's a registered function
63 if value in self._registered_functions:
64 return {
65 'type': 'registered',
66 'name': value
67 }
68 else:
69 # Treat as inline code
70 return {
71 'type': 'inline',
72 'code': value
73 }
74 else:
75 # Convert to string and treat as inline code
76 return {
77 'type': 'inline',
78 'code': str(value)
79 }
81 def load_from_file(
82 self,
83 file_path: Union[str, Path],
84 resolve_env: bool = True,
85 resolve_references: bool = True,
86 ) -> FSMConfig:
87 """Load configuration from a file.
89 Args:
90 file_path: Path to configuration file (JSON or YAML).
91 resolve_env: Whether to resolve environment variables.
92 resolve_references: Whether to resolve file references.
94 Returns:
95 Validated FSMConfig instance.
97 Raises:
98 FileNotFoundError: If file doesn't exist.
99 ValueError: If file format is not supported.
100 """
101 file_path = Path(file_path)
103 if not file_path.exists():
104 raise FileNotFoundError(f"Configuration file not found: {file_path}")
106 # Load raw configuration
107 raw_config = self._load_file(file_path)
109 # Process with dataknobs_config if enabled
110 if self.use_dataknobs_config:
111 config_obj = DataknobsConfig(raw_config)
112 processed_config = config_obj.to_dict()
113 else:
114 processed_config = raw_config
116 # Resolve environment variables
117 if resolve_env:
118 processed_config = self._resolve_environment_vars(processed_config)
120 # Resolve file references (includes/imports)
121 if resolve_references:
122 processed_config = self._resolve_references(processed_config, file_path.parent)
124 # Apply common transformations and validate
125 return self._finalize_config(processed_config)
127 def load_from_dict(
128 self,
129 config_dict: Dict[str, Any],
130 resolve_env: bool = True,
131 ) -> FSMConfig:
132 """Load configuration from a dictionary.
134 Args:
135 config_dict: Configuration dictionary.
136 resolve_env: Whether to resolve environment variables.
138 Returns:
139 Validated FSMConfig instance.
140 """
141 processed_config = config_dict.copy()
143 # Process with dataknobs_config if enabled
144 if self.use_dataknobs_config:
145 config_obj = DataknobsConfig(processed_config)
146 processed_config = config_obj.to_dict()
148 # Resolve environment variables
149 if resolve_env:
150 processed_config = self._resolve_environment_vars(processed_config)
152 # Apply common transformations and validate
153 return self._finalize_config(processed_config)
155 def load_from_template(
156 self,
157 template: Union[UseCaseTemplate, str],
158 params: Dict[str, Any] | None = None,
159 overrides: Dict[str, Any] | None = None,
160 ) -> FSMConfig:
161 """Load configuration from a template.
163 Args:
164 template: Template name or enum value.
165 params: Template parameters.
166 overrides: Configuration overrides.
168 Returns:
169 Validated FSMConfig instance.
170 """
171 if isinstance(template, str):
172 template = UseCaseTemplate(template)
174 # Apply template
175 config_dict = apply_template(template, params, overrides)
177 # Load from dictionary
178 return self.load_from_dict(config_dict)
180 def load_template_config(self, template_config: TemplateConfig) -> FSMConfig:
181 """Load configuration from a template configuration object.
183 Args:
184 template_config: Template configuration.
186 Returns:
187 Validated FSMConfig instance.
188 """
189 return self.load_from_template(
190 template_config.template,
191 template_config.params,
192 template_config.overrides,
193 )
195 def _load_file(self, file_path: Path) -> Dict[str, Any]:
196 """Load raw configuration from a file.
198 Args:
199 file_path: Path to configuration file.
201 Returns:
202 Raw configuration dictionary.
204 Raises:
205 ValueError: If file format is not supported.
206 """
207 suffix = file_path.suffix.lower()
209 with open(file_path) as f:
210 if suffix == ".json":
211 return json.load(f)
212 elif suffix in [".yaml", ".yml"]:
213 return yaml.safe_load(f)
214 else:
215 raise ValueError(f"Unsupported file format: {suffix}")
217 def _resolve_environment_vars(self, config: Any) -> Any:
218 """Resolve environment variables in configuration.
220 Supports:
221 - ${VAR_NAME} - Required variable
222 - ${VAR_NAME:-default} - Variable with default value
223 - ${VAR_NAME:?error message} - Required with custom error
225 Args:
226 config: Configuration to process.
228 Returns:
229 Configuration with resolved environment variables.
230 """
231 if isinstance(config, str):
232 # Check for environment variable pattern
233 if config.startswith("${") and config.endswith("}"):
234 var_expr = config[2:-1]
236 # Handle default value
237 if ":-" in var_expr:
238 var_name, default_value = var_expr.split(":-", 1)
239 return os.environ.get(var_name, default_value)
241 # Handle error message
242 elif ":?" in var_expr:
243 var_name, error_msg = var_expr.split(":?", 1)
244 if var_name not in os.environ:
245 raise ValueError(f"Required environment variable: {error_msg}")
246 return os.environ[var_name]
248 # Simple variable
249 else:
250 if var_expr not in os.environ:
251 # Check with prefix
252 prefixed_var = f"{self._env_prefix}{var_expr}"
253 if prefixed_var in os.environ:
254 return os.environ[prefixed_var]
255 raise ValueError(f"Environment variable not found: {var_expr}")
256 return os.environ[var_expr]
258 # Also support $VAR_NAME format for compatibility
259 elif config.startswith("$") and not config.startswith("${"):
260 var_name = config[1:]
261 if var_name in os.environ:
262 return os.environ[var_name]
263 prefixed_var = f"{self._env_prefix}{var_name}"
264 if prefixed_var in os.environ:
265 return os.environ[prefixed_var]
267 return config
269 elif isinstance(config, dict):
270 return {key: self._resolve_environment_vars(value) for key, value in config.items()}
272 elif isinstance(config, list):
273 return [self._resolve_environment_vars(item) for item in config]
275 else:
276 return config
278 def _finalize_config(self, config: Dict[str, Any]) -> FSMConfig:
279 """Apply final transformations and validate configuration.
281 This method applies all common transformations that should happen
282 regardless of the source of the configuration.
284 Args:
285 config: Configuration dictionary.
287 Returns:
288 Validated FSMConfig instance.
289 """
290 # Transform simple format to network format if needed
291 config = self._transform_simple_to_network(config)
293 # Transform network-level arcs to state-level arcs if present
294 config = self._transform_network_arcs(config)
296 # Transform state functions field to transforms list
297 config = self._transform_state_functions(config)
299 # Validate and return
300 return validate_config(config)
302 def _transform_simple_to_network(self, config: Dict[str, Any]) -> Dict[str, Any]:
303 """Transform simple format to network format if needed.
305 Detects if the config is in simple format (has 'states' and 'arcs' at
306 top level without 'networks') and transforms it to network format.
308 Args:
309 config: Configuration dictionary.
311 Returns:
312 Transformed configuration.
313 """
314 # Check if this is a simple format config
315 if 'states' in config and 'networks' not in config:
316 # Convert states from dict to list format if needed
317 states = config['states']
318 arcs = list(config.get('arcs', [])) # Start with existing arcs if any
320 if isinstance(states, dict):
321 # Convert dict-style states to list-style
322 states_list = []
323 initial_state = config.get('initial_state')
325 for name, state_config in states.items():
326 state = state_config.copy() if isinstance(state_config, dict) else {}
327 state['name'] = name
329 # Mark initial state if specified
330 if initial_state and name == initial_state:
331 state['is_start'] = True
333 # Extract inline transitions and convert to arcs
334 for transition_type in ['on_complete', 'on_error', 'on_timeout']:
335 if transition_type in state:
336 transition = state.pop(transition_type)
337 if isinstance(transition, dict) and 'target' in transition:
338 arc = {
339 'from': name,
340 'to': transition['target'],
341 'type': transition_type.replace('on_', '') # on_complete -> complete
342 }
343 # Add any conditions or transforms
344 if 'condition' in transition:
345 arc['condition'] = transition['condition']
346 if 'transform' in transition:
347 arc['transform'] = transition['transform']
348 arcs.append(arc)
350 # Mark final states
351 if state.get('final'):
352 state['is_end'] = True
353 state.pop('final') # Remove the 'final' field as we use 'is_end'
355 states_list.append(state)
356 states = states_list
358 # If no initial state was specified, mark the first state as start
359 if not initial_state and states_list:
360 states_list[0]['is_start'] = True
362 # Transform to network format
363 network_config = {
364 'name': config.get('name', 'default_fsm'),
365 'networks': [{
366 'name': 'main',
367 'states': states,
368 'arcs': self._add_type_to_transforms(arcs)
369 }],
370 'main_network': 'main'
371 }
373 # Handle data_mode transformation
374 if 'data_mode' in config:
375 mode = config['data_mode']
376 if isinstance(mode, str):
377 # Convert string to proper data_mode config
378 network_config['data_mode'] = {
379 'default': mode.lower() if mode.lower() in ['copy', 'reference', 'direct'] else 'copy'
380 }
381 else:
382 network_config['data_mode'] = mode
383 else:
384 network_config['data_mode'] = {'default': 'copy'}
386 # Handle initial_state if present
387 if 'initial_state' in config:
388 network_config['networks'][0]['initial_state'] = config['initial_state']
390 # Copy over other top-level fields
391 for key in ['resources', 'templates', 'functions', 'execution', 'metadata']:
392 if key in config:
393 network_config[key] = config[key]
395 return network_config
397 return config
399 def _add_type_to_transforms(self, arcs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
400 """Add type field to arc transforms if missing.
402 Args:
403 arcs: List of arc configurations.
405 Returns:
406 Updated arc configurations.
407 """
408 updated_arcs = []
409 for arc in arcs:
410 arc_copy = arc.copy()
411 if 'transform' in arc_copy and isinstance(arc_copy['transform'], dict):
412 if 'type' not in arc_copy['transform']:
413 # Infer type from the content
414 if 'code' in arc_copy['transform']:
415 arc_copy['transform']['type'] = 'inline'
416 elif 'lambda' in arc_copy['transform']:
417 arc_copy['transform']['type'] = 'lambda'
418 elif 'module' in arc_copy['transform']:
419 arc_copy['transform']['type'] = 'module'
420 else:
421 arc_copy['transform']['type'] = 'inline'
422 updated_arcs.append(arc_copy)
423 return updated_arcs
425 def _transform_network_arcs(self, config: Dict[str, Any]) -> Dict[str, Any]:
426 """Transform network-level arcs format to state-level arcs format.
428 This allows for a more intuitive configuration format where arcs
429 are defined at the network level with 'from' and 'to' fields,
430 rather than attached to the source state.
432 Args:
433 config: Configuration dictionary.
435 Returns:
436 Transformed configuration.
437 """
438 config = config.copy()
440 # Process each network
441 if 'networks' in config:
442 for network in config['networks']:
443 if 'arcs' in network and isinstance(network['arcs'], list):
444 # Build a map of state name to state config
445 state_map = {}
446 for state in network.get('states', []):
447 state_map[state['name']] = state
448 # Ensure each state has an arcs list
449 if 'arcs' not in state:
450 state['arcs'] = []
452 # Transform network-level arcs to state-level arcs
453 for arc in network['arcs']:
454 if 'from' in arc and 'to' in arc:
455 from_state = arc['from']
456 to_state = arc['to']
458 # Create state-level arc config
459 state_arc = {
460 'target': to_state
461 }
463 # Handle legacy pre_test format
464 if 'pre_test' in arc and 'condition' not in arc:
465 pre_test = arc['pre_test']
466 if isinstance(pre_test, dict) and 'test' in pre_test:
467 # Convert pre_test.test to condition
468 state_arc['condition'] = self._convert_to_function_reference(pre_test['test'])
469 elif isinstance(pre_test, str):
470 # Direct function reference or inline code
471 state_arc['condition'] = self._convert_to_function_reference(pre_test)
473 # Copy optional fields
474 for field in ['name', 'condition', 'transform', 'priority', 'metadata']:
475 if field in arc:
476 if field == 'name':
477 # Store arc name in metadata
478 if 'metadata' not in state_arc:
479 state_arc['metadata'] = {}
480 state_arc['metadata']['name'] = arc[field]
481 elif field == 'condition':
482 # Handle condition field
483 condition = arc[field]
484 if isinstance(condition, dict):
485 # Check for simple condition types
486 if condition.get('type') == 'success':
487 # Transform to check validation success in data
488 state_arc['condition'] = {
489 'type': 'inline',
490 'code': 'data.get("valid", True)'
491 }
492 elif condition.get('type') == 'failure':
493 # Transform to check validation failure in data
494 state_arc['condition'] = {
495 'type': 'inline',
496 'code': 'not data.get("valid", True)'
497 }
498 else:
499 # Keep as is
500 state_arc[field] = condition
501 elif isinstance(condition, str):
502 # Simple string condition
503 if condition == 'success':
504 state_arc['condition'] = {
505 'type': 'inline',
506 'code': 'data.get("valid", True)'
507 }
508 elif condition == 'failure':
509 state_arc['condition'] = {
510 'type': 'inline',
511 'code': 'not data.get("valid", True)'
512 }
513 else:
514 # Check if registered function or inline code
515 state_arc['condition'] = self._convert_to_function_reference(condition)
516 else:
517 state_arc[field] = condition
518 else:
519 state_arc[field] = arc[field]
521 # Add arc to the source state
522 if from_state in state_map:
523 state_map[from_state]['arcs'].append(state_arc)
525 # Remove network-level arcs since they've been transformed
526 del network['arcs']
529 return config
531 def _transform_state_functions(self, config: Dict[str, Any]) -> Dict[str, Any]:
532 """Transform state 'functions' field to proper schema format.
534 This converts the legacy 'functions' format to the proper schema:
535 - functions.validate -> validators (state validation)
536 - functions.transform -> transforms (state transformation when entering state)
538 Args:
539 config: Configuration dictionary.
541 Returns:
542 Transformed configuration.
543 """
544 config = config.copy()
546 # Process each network
547 if 'networks' in config:
548 for network in config['networks']:
549 if 'states' in network:
550 for state in network['states']:
551 # Check if state has 'functions' field
552 if 'functions' in state and isinstance(state['functions'], dict):
553 functions = state['functions']
555 # Convert validate function to validators list
556 if 'validate' in functions:
557 validate_func = functions['validate']
558 state['validators'] = [self._convert_to_function_reference(validate_func)]
560 # Convert transform function to transforms list (StateTransform)
561 if 'transform' in functions:
562 transform_func = functions['transform']
563 state['transforms'] = [self._convert_to_function_reference(transform_func)]
565 # Remove the functions field as it's not in the schema
566 del state['functions']
568 # Also handle direct 'transform' field (singular) for convenience
569 if 'transform' in state and 'transforms' not in state:
570 transform = state['transform']
571 if isinstance(transform, list):
572 state['transforms'] = transform
573 else:
574 state['transforms'] = [self._convert_to_function_reference(transform)]
575 del state['transform']
577 # Similarly handle direct 'validator' field (singular)
578 if 'validator' in state and 'validators' not in state:
579 validator = state['validator']
580 if isinstance(validator, list):
581 state['validators'] = validator
582 else:
583 state['validators'] = [self._convert_to_function_reference(validator)]
584 del state['validator']
586 return config
588 def _resolve_references(self, config: Dict[str, Any], base_path: Path) -> Dict[str, Any]:
589 """Resolve file references (includes/imports) in configuration.
591 Supports:
592 - $include: path/to/file.yaml
593 - $import: { file: path/to/file.yaml, path: some.nested.path }
595 Args:
596 config: Configuration dictionary.
597 base_path: Base path for resolving relative paths.
599 Returns:
600 Configuration with resolved references.
601 """
602 processed = {}
604 for key, value in config.items():
605 if key == "$include" and isinstance(value, str):
606 # Load and merge included file
607 include_path = base_path / value
608 if include_path.as_posix() not in self._included_configs:
609 included = self._load_file(include_path)
610 self._included_configs[include_path.as_posix()] = included
611 else:
612 included = self._included_configs[include_path.as_posix()]
614 # Recursively resolve references in included content
615 included = self._resolve_references(included, include_path.parent)
617 # Merge with current config
618 for inc_key, inc_value in included.items():
619 if inc_key not in processed:
620 processed[inc_key] = inc_value
622 elif key == "$import" and isinstance(value, dict):
623 # Import specific path from file
624 file_path = base_path / value["file"]
625 path_expr = value.get("path", "")
627 if file_path.as_posix() not in self._included_configs:
628 imported = self._load_file(file_path)
629 self._included_configs[file_path.as_posix()] = imported
630 else:
631 imported = self._included_configs[file_path.as_posix()]
633 # Navigate to specified path
634 if path_expr:
635 for part in path_expr.split("."):
636 imported = imported.get(part, {})
638 # Recursively resolve references
639 if isinstance(imported, dict):
640 imported = self._resolve_references(imported, file_path.parent)
642 return imported
644 elif isinstance(value, dict):
645 processed[key] = self._resolve_references(value, base_path)
647 elif isinstance(value, list):
648 processed[key] = [
649 self._resolve_references(item, base_path) if isinstance(item, dict) else item
650 for item in value
651 ]
653 else:
654 processed[key] = value
656 return processed
658 def validate_file(self, file_path: Union[str, Path]) -> bool:
659 """Validate a configuration file without fully loading it.
661 Args:
662 file_path: Path to configuration file.
664 Returns:
665 True if valid, False otherwise.
666 """
667 try:
668 self.load_from_file(file_path)
669 return True
670 except Exception:
671 return False
673 def merge_configs(self, *configs: FSMConfig) -> FSMConfig:
674 """Merge multiple FSM configurations.
676 Later configurations override earlier ones.
678 Args:
679 *configs: FSMConfig instances to merge.
681 Returns:
682 Merged FSMConfig instance.
683 """
684 merged_dict = {}
686 for config in configs:
687 config_dict = config.model_dump()
688 self._deep_merge(merged_dict, config_dict)
690 return validate_config(merged_dict)
692 def _deep_merge(self, base: Dict, updates: Dict) -> Dict:
693 """Deep merge two dictionaries.
695 Args:
696 base: Base dictionary (modified in place).
697 updates: Updates to apply.
699 Returns:
700 Merged dictionary.
701 """
702 for key, value in updates.items():
703 if key in base and isinstance(base[key], dict) and isinstance(value, dict):
704 self._deep_merge(base[key], value)
705 elif key in base and isinstance(base[key], list) and isinstance(value, list):
706 # For lists, we extend rather than replace
707 base[key].extend(value)
708 else:
709 base[key] = value
711 return base