Coverage for src/dataknobs_fsm/config/schema.py: 82%

172 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:51 -0600

1"""Configuration schema definitions for FSM using Pydantic. 

2 

3This module defines the schema for FSM configuration files, including: 

4- Data mode configuration 

5- Transaction configuration 

6- Resource definitions 

7- Streaming configuration 

8- FSM definition 

9- Network definition 

10- State definition 

11- Arc definition 

12""" 

13 

14from enum import Enum 

15from typing import Any, Dict, List, Literal, Union 

16 

17from pydantic import BaseModel, Field, field_validator, model_validator 

18 

19from dataknobs_fsm.core.data_modes import DataHandlingMode 

20from dataknobs_fsm.core.transactions import TransactionStrategy 

21 

22 

class ResourceType(str, Enum):
    """Available resource types.

    Subclasses ``str`` so members compare equal to their string values and
    can be written directly in JSON/YAML configuration files.
    """

    DATABASE = "database"
    FILESYSTEM = "filesystem"
    HTTP = "http"
    LLM = "llm"
    VECTOR_STORE = "vector_store"
    # Escape hatch for resource implementations not covered by the
    # built-in types above.
    CUSTOM = "custom"

32 

33 

class ExecutionStrategy(str, Enum):
    """Available execution strategies.

    Subclasses ``str`` so members serialize as plain strings in config
    files.  Chosen via ``FSMConfig.execution_strategy``; the semantics of
    each strategy are defined by the execution engine, not here.
    """

    DEPTH_FIRST = "depth_first"
    BREADTH_FIRST = "breadth_first"
    RESOURCE_OPTIMIZED = "resource_optimized"
    STREAM_OPTIMIZED = "stream_optimized"

41 

42 

class FunctionReference(BaseModel):
    """Reference to a function.

    The fields that must be populated depend on ``type``:

    - ``builtin`` / ``registered``: ``name``
    - ``custom``: ``module`` and ``name``
    - ``inline``: ``code``

    ``params`` carries arbitrary keyword arguments for the referenced
    function.
    """

    type: Literal["builtin", "custom", "inline", "registered"]
    name: str | None = None
    module: str | None = None
    code: str | None = None
    params: Dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="after")
    def validate_reference(self) -> "FunctionReference":
        """Validate that the reference has required fields based on type."""
        # Table of (requirement-satisfied, error message) per reference type.
        # Truthiness (not None-ness) is checked, matching the original
        # contract: empty strings are rejected too.
        checks = {
            "builtin": (self.name, "Builtin functions require a 'name'"),
            "custom": (
                self.module and self.name,
                "Custom functions require both 'module' and 'name'",
            ),
            "inline": (self.code, "Inline functions require 'code'"),
            "registered": (self.name, "Registered functions require a 'name'"),
        }
        satisfied, message = checks[self.type]
        if not satisfied:
            raise ValueError(message)
        return self

64 

65 

class DataModeConfig(BaseModel):
    """Configuration for data handling modes.

    Selects a default :class:`DataHandlingMode` plus per-state overrides,
    and carries free-form option dictionaries for each mode.
    """

    # Mode applied when no override matches.
    default: DataHandlingMode = DataHandlingMode.COPY
    # Per-state mode overrides; presumably keyed by state name — confirm
    # against the consumer of this config.
    state_overrides: Dict[str, DataHandlingMode] = Field(default_factory=dict)
    # Mode-specific options; contents are interpreted by the data-mode
    # implementations, not validated here.
    copy_config: Dict[str, Any] = Field(default_factory=dict)
    reference_config: Dict[str, Any] = Field(default_factory=dict)
    direct_config: Dict[str, Any] = Field(default_factory=dict)

74 

75 

class TransactionConfig(BaseModel):
    """Configuration for transaction management.

    Numeric bounds are enforced by pydantic ``Field`` constraints
    (``batch_size >= 1``, ``timeout_seconds >= 1`` when set).
    """

    # How commits are grouped; semantics defined by TransactionStrategy.
    strategy: TransactionStrategy = TransactionStrategy.SINGLE
    # Number of operations per commit for batch-style strategies.
    batch_size: int = Field(default=100, ge=1)
    # Event names that force a commit — interpretation is up to the
    # transaction manager; confirm expected values with the consumer.
    commit_triggers: List[str] = Field(default_factory=list)
    rollback_on_error: bool = True
    # None means no timeout.
    timeout_seconds: int | None = Field(default=None, ge=1)

84 

85 

class StreamConfig(BaseModel):
    """Configuration for streaming support.

    Disabled by default; when ``enabled`` the remaining knobs tune chunking,
    parallelism, and memory behavior of the streaming engine.
    """

    enabled: bool = False
    # Records (presumably) per chunk — confirm unit with streaming engine.
    chunk_size: int = Field(default=1000, ge=1)
    parallelism: int = Field(default=1, ge=1)
    # None means unbounded.
    memory_limit_mb: int | None = Field(default=None, ge=1)
    # Fraction (0..1) at which backpressure kicks in.
    backpressure_threshold: float = Field(default=0.8, ge=0, le=1)
    # Optional stream format identifier; values are consumer-defined.
    format: str | None = None

95 

96 

class ResourceConfig(BaseModel):
    """Configuration for a resource.

    ``name`` is the identifier referenced by networks/states (validated in
    ``FSMConfig.validate_fsm``); ``config`` holds type-specific settings
    that are not validated here.
    """

    name: str
    type: ResourceType
    # Free-form, resource-type-specific configuration.
    config: Dict[str, Any] = Field(default_factory=dict)
    connection_pool_size: int = Field(default=10, ge=1)
    timeout_seconds: int = Field(default=30, ge=1)
    # 0 disables retries.
    retry_attempts: int = Field(default=3, ge=0)
    retry_delay_seconds: float = Field(default=1.0, ge=0)
    # None disables periodic health checks; unit presumably seconds —
    # confirm with the resource manager.
    health_check_interval: int | None = Field(default=None, ge=1)

108 

109 

class ArcConfig(BaseModel):
    """Configuration for an arc.

    ``target`` names a state in the same network (checked by
    ``NetworkConfig.validate_network``).
    """

    target: str
    # Optional guard; arc is presumably taken only when it passes — confirm
    # with the execution engine.
    condition: FunctionReference | None = None
    # Optional data transform applied on traversal.
    transform: FunctionReference | None = None
    # Names of resources this arc needs (validated against FSM resources).
    resources: List[str] = Field(default_factory=list)
    # Relative ordering among competing arcs; engine-defined semantics.
    priority: int = Field(default=0)
    metadata: Dict[str, Any] = Field(default_factory=dict)

119 

120 

class PushArcConfig(ArcConfig):
    """Configuration for a push arc to another network.

    Distinguished from a plain ``ArcConfig`` by the presence of
    ``target_network`` (see ``StateConfig.validate_arcs``).
    """

    target_network: str
    # State to resume in after the pushed network completes; None presumably
    # means return to the pushing state — confirm with the engine.
    return_state: str | None = None
    # How data is shared with the pushed network.
    data_isolation: DataHandlingMode = DataHandlingMode.COPY

127 

128 

class StateConfig(BaseModel):
    """Configuration for a state.

    Accepts the alias ``schema`` for ``data_schema`` (``populate_by_name``
    allows either).  Arc entries given as plain dicts are coerced to
    ``PushArcConfig`` when they contain ``target_network``, otherwise to
    ``ArcConfig``.
    """

    name: str
    data_schema: Dict[str, Any] | None = Field(default=None, alias="schema")
    pre_validators: List[FunctionReference] = Field(default_factory=list)
    validators: List[FunctionReference] = Field(default_factory=list)
    transforms: List[FunctionReference] = Field(default_factory=list)
    arcs: List[Union[ArcConfig, PushArcConfig]] = Field(default_factory=list)
    resources: List[str] = Field(default_factory=list)
    # Per-state override of the FSM-level data mode; None means inherit.
    data_mode: DataHandlingMode | None = None
    is_start: bool = False
    is_end: bool = False
    metadata: Dict[str, Any] = Field(default_factory=dict)

    model_config = {"populate_by_name": True}  # Allow both 'schema' and 'data_schema'

    # BUG FIX: the decorators were ordered '@classmethod' outside
    # '@field_validator'.  Per the pydantic v2 docs, '@field_validator' must
    # be the outermost decorator; with the order reversed pydantic does not
    # collect the validator, so dicts were never explicitly routed to
    # PushArcConfig here.
    @field_validator("arcs", mode="before")
    @classmethod
    def validate_arcs(cls, v: List[Any]) -> List[Union[ArcConfig, PushArcConfig]]:
        """Validate and convert arc configurations.

        Dicts containing 'target_network' become PushArcConfig, other dicts
        become ArcConfig, and already-built models pass through unchanged.
        """
        result: List[Union[ArcConfig, PushArcConfig]] = []
        for arc in v:
            if isinstance(arc, dict):
                if "target_network" in arc:
                    result.append(PushArcConfig(**arc))
                else:
                    result.append(ArcConfig(**arc))
            else:
                result.append(arc)
        return result

160 

161 

class NetworkConfig(BaseModel):
    """Configuration for a state network.

    After field validation, cross-checks that every arc target names a
    state in this network and that at least one state is marked as a
    start state.
    """

    name: str
    states: List[StateConfig]
    resources: List[str] = Field(default_factory=list)
    streaming: StreamConfig | None = None
    metadata: Dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="after")
    def validate_network(self) -> "NetworkConfig":
        """Validate network consistency."""
        known_states = {s.name for s in self.states}

        # Every arc must point at a state defined in this network.
        # NOTE(review): PushArcConfig subclasses ArcConfig, so push-arc
        # 'target' values are also checked against *this* network's state
        # names — confirm that is the intended contract.
        for state in self.states:
            for arc in state.arcs:
                if isinstance(arc, ArcConfig) and arc.target not in known_states:
                    raise ValueError(f"Arc target '{arc.target}' not found in network")

        # The network needs an entry point.
        if not any(s.is_start for s in self.states):
            raise ValueError("Network must have at least one start state")

        return self

188 

189 

class FSMConfig(BaseModel):
    """Complete FSM configuration.

    Top-level model tying together data handling, transactions, resources,
    networks, and execution settings.  After field validation, cross-checks
    that ``main_network`` exists and that every resource referenced by a
    network or state is declared in ``resources``.
    """

    name: str
    version: str = "1.0.0"
    description: str | None = None

    # Data handling
    data_mode: DataModeConfig = Field(default_factory=DataModeConfig)
    transaction: TransactionConfig = Field(default_factory=TransactionConfig)

    # Resources
    resources: List[ResourceConfig] = Field(default_factory=list)

    # Networks
    networks: List[NetworkConfig]
    main_network: str

    # Execution
    execution_strategy: ExecutionStrategy = ExecutionStrategy.DEPTH_FIRST
    max_transitions: int = Field(default=1000, ge=1)
    timeout_seconds: int | None = Field(default=None, ge=1)

    # Metadata
    metadata: Dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="after")
    def validate_fsm(self) -> "FSMConfig":
        """Validate FSM configuration consistency."""
        # The entry network must be one of the defined networks.
        if self.main_network not in {net.name for net in self.networks}:
            raise ValueError(f"Main network '{self.main_network}' not found")

        # Every resource referenced anywhere must be declared at FSM level.
        # Network-level references are checked before state-level ones so
        # the first offending name raises, matching the original order.
        declared = {res.name for res in self.resources}
        for network in self.networks:
            referenced: List[str] = list(network.resources)
            for state in network.states:
                referenced.extend(state.resources)
            for res_name in referenced:
                if res_name not in declared:
                    raise ValueError(f"Resource '{res_name}' not found in FSM resources")

        return self

237 

238 

class UseCaseTemplate(str, Enum):
    """Pre-defined use case templates.

    Each member keys a baseline configuration fragment in ``TEMPLATES``;
    see ``apply_template`` for how fragments are expanded.
    """

    DATABASE_ETL = "database_etl"
    FILE_PROCESSING = "file_processing"
    API_ORCHESTRATION = "api_orchestration"
    LLM_WORKFLOW = "llm_workflow"
    DATA_VALIDATION = "data_validation"
    STREAM_PROCESSING = "stream_processing"

248 

249 

class TemplateConfig(BaseModel):
    """Configuration for using a template.

    Mirrors the arguments of ``apply_template``: a template identifier,
    template-specific parameters, and deep-merge overrides.
    """

    template: UseCaseTemplate
    # Template-specific parameters (currently unused by apply_template).
    params: Dict[str, Any] = Field(default_factory=dict)
    # Values deep-merged over the template's baseline configuration.
    overrides: Dict[str, Any] = Field(default_factory=dict)

256 

257 

def generate_json_schema() -> Dict[str, Any]:
    """Generate JSON schema for FSM configuration.

    Thin wrapper over pydantic's schema generation for ``FSMConfig``.

    Returns:
        JSON schema as a dictionary.
    """
    schema: Dict[str, Any] = FSMConfig.model_json_schema()
    return schema

265 

266 

def validate_config(config: Dict[str, Any]) -> FSMConfig:
    """Validate a configuration dictionary.

    Args:
        config: Configuration dictionary with FSMConfig fields as keys.

    Returns:
        Validated FSMConfig instance.

    Raises:
        ValidationError: If configuration is invalid.
    """
    # Keyword expansion lets pydantic run all field and model validators.
    validated = FSMConfig(**config)
    return validated

280 

281 

# Template definitions
# Baseline configuration fragments keyed by UseCaseTemplate.  Each fragment
# is a partial FSMConfig dict (data_mode / transaction / execution_strategy
# only); apply_template() deep-copies a fragment and layers overrides on top.
TEMPLATES: Dict[UseCaseTemplate, Dict[str, Any]] = {
    # Bulk loads: copy data, commit in large batches, resource-optimized.
    UseCaseTemplate.DATABASE_ETL: {
        "data_mode": {
            "default": DataHandlingMode.COPY,
        },
        "transaction": {
            "strategy": TransactionStrategy.BATCH,
            "batch_size": 1000,
        },
        "execution_strategy": ExecutionStrategy.RESOURCE_OPTIMIZED,
    },
    # Files: pass data by reference, stream-optimized execution.
    UseCaseTemplate.FILE_PROCESSING: {
        "data_mode": {
            "default": DataHandlingMode.REFERENCE,
        },
        "transaction": {
            "strategy": TransactionStrategy.SINGLE,
        },
        "execution_strategy": ExecutionStrategy.STREAM_OPTIMIZED,
    },
    # API calls: copied data, manual transaction control.
    UseCaseTemplate.API_ORCHESTRATION: {
        "data_mode": {
            "default": DataHandlingMode.COPY,
        },
        "transaction": {
            "strategy": TransactionStrategy.MANUAL,
        },
        "execution_strategy": ExecutionStrategy.DEPTH_FIRST,
    },
    # LLM pipelines: copied data, single transactions, resource-optimized.
    UseCaseTemplate.LLM_WORKFLOW: {
        "data_mode": {
            "default": DataHandlingMode.COPY,
        },
        "transaction": {
            "strategy": TransactionStrategy.SINGLE,
        },
        "execution_strategy": ExecutionStrategy.RESOURCE_OPTIMIZED,
    },
    # Validation: operate on data directly, single transactions.
    UseCaseTemplate.DATA_VALIDATION: {
        "data_mode": {
            "default": DataHandlingMode.DIRECT,
        },
        "transaction": {
            "strategy": TransactionStrategy.SINGLE,
        },
        "execution_strategy": ExecutionStrategy.DEPTH_FIRST,
    },
    # Streams: reference data, very large batches, stream-optimized.
    UseCaseTemplate.STREAM_PROCESSING: {
        "data_mode": {
            "default": DataHandlingMode.REFERENCE,
        },
        "transaction": {
            "strategy": TransactionStrategy.BATCH,
            "batch_size": 5000,
        },
        "execution_strategy": ExecutionStrategy.STREAM_OPTIMIZED,
    },
}

341 

342 

def apply_template(
    template: UseCaseTemplate,
    params: Dict[str, Any] | None = None,
    overrides: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """Apply a use case template to generate configuration.

    Args:
        template: Template to apply.
        params: Template parameters.
        overrides: Configuration overrides.

    Returns:
        Configuration dictionary.
    """
    import copy

    def _merge_into(base: Dict, updates: Dict) -> Dict:
        # Recursively merge `updates` into `base` (mutates `base`).
        # Dict-valued keys merge; everything else is replaced outright.
        for key, value in updates.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                _merge_into(base[key], value)
            else:
                base[key] = value
        return base

    # Deep-copy so callers can't mutate the shared template definitions.
    config = copy.deepcopy(TEMPLATES[template])

    if params:
        # Template-specific parameter application would go here.
        pass

    if overrides:
        _merge_into(config, overrides)

    return config