Coverage for src/dataknobs_fsm/config/loader.py: 30%

281 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:51 -0600

1"""Configuration loader for FSM configurations. 

2 

3This module provides functionality to load FSM configurations from various sources: 

4- Files (JSON, YAML) 

5- Dictionaries 

6- Templates 

7- Environment variables 

8""" 

9 

10import json 

11import os 

12from pathlib import Path 

13from typing import Any, Dict, List, Set, Union 

14 

15import yaml 

16from dataknobs_config import Config as DataknobsConfig 

17 

18from dataknobs_fsm.config.schema import ( 

19 FSMConfig, 

20 TemplateConfig, 

21 UseCaseTemplate, 

22 apply_template, 

23 validate_config, 

24) 

25 

26 

27class ConfigLoader: 

28 """Load and process FSM configurations from various sources.""" 

29 

30 def __init__(self, use_dataknobs_config: bool = False): 

31 """Initialize the ConfigLoader. 

32  

33 Args: 

34 use_dataknobs_config: Whether to use dataknobs_config for advanced features. 

35 """ 

36 self.use_dataknobs_config = use_dataknobs_config 

37 self._env_prefix = "FSM_" 

38 self._included_configs: Dict[str, Dict[str, Any]] = {} 

39 self._registered_functions: Set[str] = set() 

40 

41 def add_registered_function(self, name: str) -> None: 

42 """Add a function name to the set of registered functions. 

43 

44 Args: 

45 name: Function name that has been registered. 

46 """ 

47 self._registered_functions.add(name) 

48 

49 def _convert_to_function_reference(self, value: Any) -> Dict[str, Any]: 

50 """Convert a value to a function reference dictionary. 

51 

52 Args: 

53 value: The value to convert (string, dict, etc.) 

54 

55 Returns: 

56 Function reference dictionary with 'type' and appropriate fields. 

57 """ 

58 if isinstance(value, dict): 

59 # Already a function reference 

60 return value 

61 elif isinstance(value, str): 

62 # Check if it's a registered function 

63 if value in self._registered_functions: 

64 return { 

65 'type': 'registered', 

66 'name': value 

67 } 

68 else: 

69 # Treat as inline code 

70 return { 

71 'type': 'inline', 

72 'code': value 

73 } 

74 else: 

75 # Convert to string and treat as inline code 

76 return { 

77 'type': 'inline', 

78 'code': str(value) 

79 } 

80 

81 def load_from_file( 

82 self, 

83 file_path: Union[str, Path], 

84 resolve_env: bool = True, 

85 resolve_references: bool = True, 

86 ) -> FSMConfig: 

87 """Load configuration from a file. 

88  

89 Args: 

90 file_path: Path to configuration file (JSON or YAML). 

91 resolve_env: Whether to resolve environment variables. 

92 resolve_references: Whether to resolve file references. 

93  

94 Returns: 

95 Validated FSMConfig instance. 

96  

97 Raises: 

98 FileNotFoundError: If file doesn't exist. 

99 ValueError: If file format is not supported. 

100 """ 

101 file_path = Path(file_path) 

102 

103 if not file_path.exists(): 

104 raise FileNotFoundError(f"Configuration file not found: {file_path}") 

105 

106 # Load raw configuration 

107 raw_config = self._load_file(file_path) 

108 

109 # Process with dataknobs_config if enabled 

110 if self.use_dataknobs_config: 

111 config_obj = DataknobsConfig(raw_config) 

112 processed_config = config_obj.to_dict() 

113 else: 

114 processed_config = raw_config 

115 

116 # Resolve environment variables 

117 if resolve_env: 

118 processed_config = self._resolve_environment_vars(processed_config) 

119 

120 # Resolve file references (includes/imports) 

121 if resolve_references: 

122 processed_config = self._resolve_references(processed_config, file_path.parent) 

123 

124 # Apply common transformations and validate 

125 return self._finalize_config(processed_config) 

126 

127 def load_from_dict( 

128 self, 

129 config_dict: Dict[str, Any], 

130 resolve_env: bool = True, 

131 ) -> FSMConfig: 

132 """Load configuration from a dictionary. 

133  

134 Args: 

135 config_dict: Configuration dictionary. 

136 resolve_env: Whether to resolve environment variables. 

137  

138 Returns: 

139 Validated FSMConfig instance. 

140 """ 

141 processed_config = config_dict.copy() 

142 

143 # Process with dataknobs_config if enabled 

144 if self.use_dataknobs_config: 

145 config_obj = DataknobsConfig(processed_config) 

146 processed_config = config_obj.to_dict() 

147 

148 # Resolve environment variables 

149 if resolve_env: 

150 processed_config = self._resolve_environment_vars(processed_config) 

151 

152 # Apply common transformations and validate 

153 return self._finalize_config(processed_config) 

154 

155 def load_from_template( 

156 self, 

157 template: Union[UseCaseTemplate, str], 

158 params: Dict[str, Any] | None = None, 

159 overrides: Dict[str, Any] | None = None, 

160 ) -> FSMConfig: 

161 """Load configuration from a template. 

162  

163 Args: 

164 template: Template name or enum value. 

165 params: Template parameters. 

166 overrides: Configuration overrides. 

167  

168 Returns: 

169 Validated FSMConfig instance. 

170 """ 

171 if isinstance(template, str): 

172 template = UseCaseTemplate(template) 

173 

174 # Apply template 

175 config_dict = apply_template(template, params, overrides) 

176 

177 # Load from dictionary 

178 return self.load_from_dict(config_dict) 

179 

180 def load_template_config(self, template_config: TemplateConfig) -> FSMConfig: 

181 """Load configuration from a template configuration object. 

182  

183 Args: 

184 template_config: Template configuration. 

185  

186 Returns: 

187 Validated FSMConfig instance. 

188 """ 

189 return self.load_from_template( 

190 template_config.template, 

191 template_config.params, 

192 template_config.overrides, 

193 ) 

194 

195 def _load_file(self, file_path: Path) -> Dict[str, Any]: 

196 """Load raw configuration from a file. 

197  

198 Args: 

199 file_path: Path to configuration file. 

200  

201 Returns: 

202 Raw configuration dictionary. 

203  

204 Raises: 

205 ValueError: If file format is not supported. 

206 """ 

207 suffix = file_path.suffix.lower() 

208 

209 with open(file_path) as f: 

210 if suffix == ".json": 

211 return json.load(f) 

212 elif suffix in [".yaml", ".yml"]: 

213 return yaml.safe_load(f) 

214 else: 

215 raise ValueError(f"Unsupported file format: {suffix}") 

216 

217 def _resolve_environment_vars(self, config: Any) -> Any: 

218 """Resolve environment variables in configuration. 

219  

220 Supports: 

221 - ${VAR_NAME} - Required variable 

222 - ${VAR_NAME:-default} - Variable with default value 

223 - ${VAR_NAME:?error message} - Required with custom error 

224  

225 Args: 

226 config: Configuration to process. 

227  

228 Returns: 

229 Configuration with resolved environment variables. 

230 """ 

231 if isinstance(config, str): 

232 # Check for environment variable pattern 

233 if config.startswith("${") and config.endswith("}"): 

234 var_expr = config[2:-1] 

235 

236 # Handle default value 

237 if ":-" in var_expr: 

238 var_name, default_value = var_expr.split(":-", 1) 

239 return os.environ.get(var_name, default_value) 

240 

241 # Handle error message 

242 elif ":?" in var_expr: 

243 var_name, error_msg = var_expr.split(":?", 1) 

244 if var_name not in os.environ: 

245 raise ValueError(f"Required environment variable: {error_msg}") 

246 return os.environ[var_name] 

247 

248 # Simple variable 

249 else: 

250 if var_expr not in os.environ: 

251 # Check with prefix 

252 prefixed_var = f"{self._env_prefix}{var_expr}" 

253 if prefixed_var in os.environ: 

254 return os.environ[prefixed_var] 

255 raise ValueError(f"Environment variable not found: {var_expr}") 

256 return os.environ[var_expr] 

257 

258 # Also support $VAR_NAME format for compatibility 

259 elif config.startswith("$") and not config.startswith("${"): 

260 var_name = config[1:] 

261 if var_name in os.environ: 

262 return os.environ[var_name] 

263 prefixed_var = f"{self._env_prefix}{var_name}" 

264 if prefixed_var in os.environ: 

265 return os.environ[prefixed_var] 

266 

267 return config 

268 

269 elif isinstance(config, dict): 

270 return {key: self._resolve_environment_vars(value) for key, value in config.items()} 

271 

272 elif isinstance(config, list): 

273 return [self._resolve_environment_vars(item) for item in config] 

274 

275 else: 

276 return config 

277 

278 def _finalize_config(self, config: Dict[str, Any]) -> FSMConfig: 

279 """Apply final transformations and validate configuration. 

280  

281 This method applies all common transformations that should happen 

282 regardless of the source of the configuration. 

283  

284 Args: 

285 config: Configuration dictionary. 

286  

287 Returns: 

288 Validated FSMConfig instance. 

289 """ 

290 # Transform simple format to network format if needed 

291 config = self._transform_simple_to_network(config) 

292 

293 # Transform network-level arcs to state-level arcs if present 

294 config = self._transform_network_arcs(config) 

295 

296 # Transform state functions field to transforms list 

297 config = self._transform_state_functions(config) 

298 

299 # Validate and return 

300 return validate_config(config) 

301 

302 def _transform_simple_to_network(self, config: Dict[str, Any]) -> Dict[str, Any]: 

303 """Transform simple format to network format if needed. 

304  

305 Detects if the config is in simple format (has 'states' and 'arcs' at 

306 top level without 'networks') and transforms it to network format. 

307  

308 Args: 

309 config: Configuration dictionary. 

310  

311 Returns: 

312 Transformed configuration. 

313 """ 

314 # Check if this is a simple format config 

315 if 'states' in config and 'networks' not in config: 

316 # Convert states from dict to list format if needed 

317 states = config['states'] 

318 arcs = list(config.get('arcs', [])) # Start with existing arcs if any 

319 

320 if isinstance(states, dict): 

321 # Convert dict-style states to list-style 

322 states_list = [] 

323 initial_state = config.get('initial_state') 

324 

325 for name, state_config in states.items(): 

326 state = state_config.copy() if isinstance(state_config, dict) else {} 

327 state['name'] = name 

328 

329 # Mark initial state if specified 

330 if initial_state and name == initial_state: 

331 state['is_start'] = True 

332 

333 # Extract inline transitions and convert to arcs 

334 for transition_type in ['on_complete', 'on_error', 'on_timeout']: 

335 if transition_type in state: 

336 transition = state.pop(transition_type) 

337 if isinstance(transition, dict) and 'target' in transition: 

338 arc = { 

339 'from': name, 

340 'to': transition['target'], 

341 'type': transition_type.replace('on_', '') # on_complete -> complete 

342 } 

343 # Add any conditions or transforms 

344 if 'condition' in transition: 

345 arc['condition'] = transition['condition'] 

346 if 'transform' in transition: 

347 arc['transform'] = transition['transform'] 

348 arcs.append(arc) 

349 

350 # Mark final states 

351 if state.get('final'): 

352 state['is_end'] = True 

353 state.pop('final') # Remove the 'final' field as we use 'is_end' 

354 

355 states_list.append(state) 

356 states = states_list 

357 

358 # If no initial state was specified, mark the first state as start 

359 if not initial_state and states_list: 

360 states_list[0]['is_start'] = True 

361 

362 # Transform to network format 

363 network_config = { 

364 'name': config.get('name', 'default_fsm'), 

365 'networks': [{ 

366 'name': 'main', 

367 'states': states, 

368 'arcs': self._add_type_to_transforms(arcs) 

369 }], 

370 'main_network': 'main' 

371 } 

372 

373 # Handle data_mode transformation 

374 if 'data_mode' in config: 

375 mode = config['data_mode'] 

376 if isinstance(mode, str): 

377 # Convert string to proper data_mode config 

378 network_config['data_mode'] = { 

379 'default': mode.lower() if mode.lower() in ['copy', 'reference', 'direct'] else 'copy' 

380 } 

381 else: 

382 network_config['data_mode'] = mode 

383 else: 

384 network_config['data_mode'] = {'default': 'copy'} 

385 

386 # Handle initial_state if present 

387 if 'initial_state' in config: 

388 network_config['networks'][0]['initial_state'] = config['initial_state'] 

389 

390 # Copy over other top-level fields 

391 for key in ['resources', 'templates', 'functions', 'execution', 'metadata']: 

392 if key in config: 

393 network_config[key] = config[key] 

394 

395 return network_config 

396 

397 return config 

398 

399 def _add_type_to_transforms(self, arcs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 

400 """Add type field to arc transforms if missing. 

401  

402 Args: 

403 arcs: List of arc configurations. 

404  

405 Returns: 

406 Updated arc configurations. 

407 """ 

408 updated_arcs = [] 

409 for arc in arcs: 

410 arc_copy = arc.copy() 

411 if 'transform' in arc_copy and isinstance(arc_copy['transform'], dict): 

412 if 'type' not in arc_copy['transform']: 

413 # Infer type from the content 

414 if 'code' in arc_copy['transform']: 

415 arc_copy['transform']['type'] = 'inline' 

416 elif 'lambda' in arc_copy['transform']: 

417 arc_copy['transform']['type'] = 'lambda' 

418 elif 'module' in arc_copy['transform']: 

419 arc_copy['transform']['type'] = 'module' 

420 else: 

421 arc_copy['transform']['type'] = 'inline' 

422 updated_arcs.append(arc_copy) 

423 return updated_arcs 

424 

425 def _transform_network_arcs(self, config: Dict[str, Any]) -> Dict[str, Any]: 

426 """Transform network-level arcs format to state-level arcs format. 

427  

428 This allows for a more intuitive configuration format where arcs 

429 are defined at the network level with 'from' and 'to' fields, 

430 rather than attached to the source state. 

431  

432 Args: 

433 config: Configuration dictionary. 

434  

435 Returns: 

436 Transformed configuration. 

437 """ 

438 config = config.copy() 

439 

440 # Process each network 

441 if 'networks' in config: 

442 for network in config['networks']: 

443 if 'arcs' in network and isinstance(network['arcs'], list): 

444 # Build a map of state name to state config 

445 state_map = {} 

446 for state in network.get('states', []): 

447 state_map[state['name']] = state 

448 # Ensure each state has an arcs list 

449 if 'arcs' not in state: 

450 state['arcs'] = [] 

451 

452 # Transform network-level arcs to state-level arcs 

453 for arc in network['arcs']: 

454 if 'from' in arc and 'to' in arc: 

455 from_state = arc['from'] 

456 to_state = arc['to'] 

457 

458 # Create state-level arc config 

459 state_arc = { 

460 'target': to_state 

461 } 

462 

463 # Handle legacy pre_test format 

464 if 'pre_test' in arc and 'condition' not in arc: 

465 pre_test = arc['pre_test'] 

466 if isinstance(pre_test, dict) and 'test' in pre_test: 

467 # Convert pre_test.test to condition 

468 state_arc['condition'] = self._convert_to_function_reference(pre_test['test']) 

469 elif isinstance(pre_test, str): 

470 # Direct function reference or inline code 

471 state_arc['condition'] = self._convert_to_function_reference(pre_test) 

472 

473 # Copy optional fields 

474 for field in ['name', 'condition', 'transform', 'priority', 'metadata']: 

475 if field in arc: 

476 if field == 'name': 

477 # Store arc name in metadata 

478 if 'metadata' not in state_arc: 

479 state_arc['metadata'] = {} 

480 state_arc['metadata']['name'] = arc[field] 

481 elif field == 'condition': 

482 # Handle condition field 

483 condition = arc[field] 

484 if isinstance(condition, dict): 

485 # Check for simple condition types 

486 if condition.get('type') == 'success': 

487 # Transform to check validation success in data 

488 state_arc['condition'] = { 

489 'type': 'inline', 

490 'code': 'data.get("valid", True)' 

491 } 

492 elif condition.get('type') == 'failure': 

493 # Transform to check validation failure in data 

494 state_arc['condition'] = { 

495 'type': 'inline', 

496 'code': 'not data.get("valid", True)' 

497 } 

498 else: 

499 # Keep as is 

500 state_arc[field] = condition 

501 elif isinstance(condition, str): 

502 # Simple string condition 

503 if condition == 'success': 

504 state_arc['condition'] = { 

505 'type': 'inline', 

506 'code': 'data.get("valid", True)' 

507 } 

508 elif condition == 'failure': 

509 state_arc['condition'] = { 

510 'type': 'inline', 

511 'code': 'not data.get("valid", True)' 

512 } 

513 else: 

514 # Check if registered function or inline code 

515 state_arc['condition'] = self._convert_to_function_reference(condition) 

516 else: 

517 state_arc[field] = condition 

518 else: 

519 state_arc[field] = arc[field] 

520 

521 # Add arc to the source state 

522 if from_state in state_map: 

523 state_map[from_state]['arcs'].append(state_arc) 

524 

525 # Remove network-level arcs since they've been transformed 

526 del network['arcs'] 

527 

528 

529 return config 

530 

531 def _transform_state_functions(self, config: Dict[str, Any]) -> Dict[str, Any]: 

532 """Transform state 'functions' field to proper schema format. 

533  

534 This converts the legacy 'functions' format to the proper schema: 

535 - functions.validate -> validators (state validation) 

536 - functions.transform -> transforms (state transformation when entering state) 

537  

538 Args: 

539 config: Configuration dictionary. 

540  

541 Returns: 

542 Transformed configuration. 

543 """ 

544 config = config.copy() 

545 

546 # Process each network 

547 if 'networks' in config: 

548 for network in config['networks']: 

549 if 'states' in network: 

550 for state in network['states']: 

551 # Check if state has 'functions' field 

552 if 'functions' in state and isinstance(state['functions'], dict): 

553 functions = state['functions'] 

554 

555 # Convert validate function to validators list 

556 if 'validate' in functions: 

557 validate_func = functions['validate'] 

558 state['validators'] = [self._convert_to_function_reference(validate_func)] 

559 

560 # Convert transform function to transforms list (StateTransform) 

561 if 'transform' in functions: 

562 transform_func = functions['transform'] 

563 state['transforms'] = [self._convert_to_function_reference(transform_func)] 

564 

565 # Remove the functions field as it's not in the schema 

566 del state['functions'] 

567 

568 # Also handle direct 'transform' field (singular) for convenience 

569 if 'transform' in state and 'transforms' not in state: 

570 transform = state['transform'] 

571 if isinstance(transform, list): 

572 state['transforms'] = transform 

573 else: 

574 state['transforms'] = [self._convert_to_function_reference(transform)] 

575 del state['transform'] 

576 

577 # Similarly handle direct 'validator' field (singular) 

578 if 'validator' in state and 'validators' not in state: 

579 validator = state['validator'] 

580 if isinstance(validator, list): 

581 state['validators'] = validator 

582 else: 

583 state['validators'] = [self._convert_to_function_reference(validator)] 

584 del state['validator'] 

585 

586 return config 

587 

588 def _resolve_references(self, config: Dict[str, Any], base_path: Path) -> Dict[str, Any]: 

589 """Resolve file references (includes/imports) in configuration. 

590  

591 Supports: 

592 - $include: path/to/file.yaml 

593 - $import: { file: path/to/file.yaml, path: some.nested.path } 

594  

595 Args: 

596 config: Configuration dictionary. 

597 base_path: Base path for resolving relative paths. 

598  

599 Returns: 

600 Configuration with resolved references. 

601 """ 

602 processed = {} 

603 

604 for key, value in config.items(): 

605 if key == "$include" and isinstance(value, str): 

606 # Load and merge included file 

607 include_path = base_path / value 

608 if include_path.as_posix() not in self._included_configs: 

609 included = self._load_file(include_path) 

610 self._included_configs[include_path.as_posix()] = included 

611 else: 

612 included = self._included_configs[include_path.as_posix()] 

613 

614 # Recursively resolve references in included content 

615 included = self._resolve_references(included, include_path.parent) 

616 

617 # Merge with current config 

618 for inc_key, inc_value in included.items(): 

619 if inc_key not in processed: 

620 processed[inc_key] = inc_value 

621 

622 elif key == "$import" and isinstance(value, dict): 

623 # Import specific path from file 

624 file_path = base_path / value["file"] 

625 path_expr = value.get("path", "") 

626 

627 if file_path.as_posix() not in self._included_configs: 

628 imported = self._load_file(file_path) 

629 self._included_configs[file_path.as_posix()] = imported 

630 else: 

631 imported = self._included_configs[file_path.as_posix()] 

632 

633 # Navigate to specified path 

634 if path_expr: 

635 for part in path_expr.split("."): 

636 imported = imported.get(part, {}) 

637 

638 # Recursively resolve references 

639 if isinstance(imported, dict): 

640 imported = self._resolve_references(imported, file_path.parent) 

641 

642 return imported 

643 

644 elif isinstance(value, dict): 

645 processed[key] = self._resolve_references(value, base_path) 

646 

647 elif isinstance(value, list): 

648 processed[key] = [ 

649 self._resolve_references(item, base_path) if isinstance(item, dict) else item 

650 for item in value 

651 ] 

652 

653 else: 

654 processed[key] = value 

655 

656 return processed 

657 

658 def validate_file(self, file_path: Union[str, Path]) -> bool: 

659 """Validate a configuration file without fully loading it. 

660  

661 Args: 

662 file_path: Path to configuration file. 

663  

664 Returns: 

665 True if valid, False otherwise. 

666 """ 

667 try: 

668 self.load_from_file(file_path) 

669 return True 

670 except Exception: 

671 return False 

672 

673 def merge_configs(self, *configs: FSMConfig) -> FSMConfig: 

674 """Merge multiple FSM configurations. 

675  

676 Later configurations override earlier ones. 

677  

678 Args: 

679 *configs: FSMConfig instances to merge. 

680  

681 Returns: 

682 Merged FSMConfig instance. 

683 """ 

684 merged_dict = {} 

685 

686 for config in configs: 

687 config_dict = config.model_dump() 

688 self._deep_merge(merged_dict, config_dict) 

689 

690 return validate_config(merged_dict) 

691 

692 def _deep_merge(self, base: Dict, updates: Dict) -> Dict: 

693 """Deep merge two dictionaries. 

694  

695 Args: 

696 base: Base dictionary (modified in place). 

697 updates: Updates to apply. 

698  

699 Returns: 

700 Merged dictionary. 

701 """ 

702 for key, value in updates.items(): 

703 if key in base and isinstance(base[key], dict) and isinstance(value, dict): 

704 self._deep_merge(base[key], value) 

705 elif key in base and isinstance(base[key], list) and isinstance(value, list): 

706 # For lists, we extend rather than replace 

707 base[key].extend(value) 

708 else: 

709 base[key] = value 

710 

711 return base