Coverage for src/dataknobs_fsm/config/schema.py: 82%
172 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:51 -0600
"""Configuration schema definitions for FSM using Pydantic.

This module defines the schema for FSM configuration files, including:
- Data mode configuration
- Transaction configuration
- Resource definitions
- Streaming configuration
- FSM definition
- Network definition
- State definition
- Arc definition
"""
import copy
from enum import Enum
from typing import Any, Dict, List, Literal, Union

from pydantic import BaseModel, Field, field_validator, model_validator

from dataknobs_fsm.core.data_modes import DataHandlingMode
from dataknobs_fsm.core.transactions import TransactionStrategy
23class ResourceType(str, Enum):
24 """Available resource types."""
26 DATABASE = "database"
27 FILESYSTEM = "filesystem"
28 HTTP = "http"
29 LLM = "llm"
30 VECTOR_STORE = "vector_store"
31 CUSTOM = "custom"
34class ExecutionStrategy(str, Enum):
35 """Available execution strategies."""
37 DEPTH_FIRST = "depth_first"
38 BREADTH_FIRST = "breadth_first"
39 RESOURCE_OPTIMIZED = "resource_optimized"
40 STREAM_OPTIMIZED = "stream_optimized"
43class FunctionReference(BaseModel):
44 """Reference to a function."""
46 type: Literal["builtin", "custom", "inline", "registered"]
47 name: str | None = None
48 module: str | None = None
49 code: str | None = None
50 params: Dict[str, Any] = Field(default_factory=dict)
52 @model_validator(mode="after")
53 def validate_reference(self) -> "FunctionReference":
54 """Validate that the reference has required fields based on type."""
55 if self.type == "builtin" and not self.name:
56 raise ValueError("Builtin functions require a 'name'")
57 if self.type == "custom" and not (self.module and self.name):
58 raise ValueError("Custom functions require both 'module' and 'name'")
59 if self.type == "inline" and not self.code:
60 raise ValueError("Inline functions require 'code'")
61 if self.type == "registered" and not self.name:
62 raise ValueError("Registered functions require a 'name'")
63 return self
66class DataModeConfig(BaseModel):
67 """Configuration for data handling modes."""
69 default: DataHandlingMode = DataHandlingMode.COPY
70 state_overrides: Dict[str, DataHandlingMode] = Field(default_factory=dict)
71 copy_config: Dict[str, Any] = Field(default_factory=dict)
72 reference_config: Dict[str, Any] = Field(default_factory=dict)
73 direct_config: Dict[str, Any] = Field(default_factory=dict)
76class TransactionConfig(BaseModel):
77 """Configuration for transaction management."""
79 strategy: TransactionStrategy = TransactionStrategy.SINGLE
80 batch_size: int = Field(default=100, ge=1)
81 commit_triggers: List[str] = Field(default_factory=list)
82 rollback_on_error: bool = True
83 timeout_seconds: int | None = Field(default=None, ge=1)
86class StreamConfig(BaseModel):
87 """Configuration for streaming support."""
89 enabled: bool = False
90 chunk_size: int = Field(default=1000, ge=1)
91 parallelism: int = Field(default=1, ge=1)
92 memory_limit_mb: int | None = Field(default=None, ge=1)
93 backpressure_threshold: float = Field(default=0.8, ge=0, le=1)
94 format: str | None = None
97class ResourceConfig(BaseModel):
98 """Configuration for a resource."""
100 name: str
101 type: ResourceType
102 config: Dict[str, Any] = Field(default_factory=dict)
103 connection_pool_size: int = Field(default=10, ge=1)
104 timeout_seconds: int = Field(default=30, ge=1)
105 retry_attempts: int = Field(default=3, ge=0)
106 retry_delay_seconds: float = Field(default=1.0, ge=0)
107 health_check_interval: int | None = Field(default=None, ge=1)
110class ArcConfig(BaseModel):
111 """Configuration for an arc."""
113 target: str
114 condition: FunctionReference | None = None
115 transform: FunctionReference | None = None
116 resources: List[str] = Field(default_factory=list)
117 priority: int = Field(default=0)
118 metadata: Dict[str, Any] = Field(default_factory=dict)
121class PushArcConfig(ArcConfig):
122 """Configuration for a push arc to another network."""
124 target_network: str
125 return_state: str | None = None
126 data_isolation: DataHandlingMode = DataHandlingMode.COPY
129class StateConfig(BaseModel):
130 """Configuration for a state."""
132 name: str
133 data_schema: Dict[str, Any] | None = Field(default=None, alias="schema")
134 pre_validators: List[FunctionReference] = Field(default_factory=list)
135 validators: List[FunctionReference] = Field(default_factory=list)
136 transforms: List[FunctionReference] = Field(default_factory=list)
137 arcs: List[Union[ArcConfig, PushArcConfig]] = Field(default_factory=list)
138 resources: List[str] = Field(default_factory=list)
139 data_mode: DataHandlingMode | None = None
140 is_start: bool = False
141 is_end: bool = False
142 metadata: Dict[str, Any] = Field(default_factory=dict)
144 model_config = {"populate_by_name": True} # Allow both 'schema' and 'data_schema'
146 @classmethod
147 @field_validator("arcs", mode="before")
148 def validate_arcs(cls, v: List[Any]) -> List[Union[ArcConfig, PushArcConfig]]:
149 """Validate and convert arc configurations."""
150 result = []
151 for arc in v:
152 if isinstance(arc, dict):
153 if "target_network" in arc:
154 result.append(PushArcConfig(**arc))
155 else:
156 result.append(ArcConfig(**arc)) # type: ignore
157 else:
158 result.append(arc)
159 return result # type: ignore
162class NetworkConfig(BaseModel):
163 """Configuration for a state network."""
165 name: str
166 states: List[StateConfig]
167 resources: List[str] = Field(default_factory=list)
168 streaming: StreamConfig | None = None
169 metadata: Dict[str, Any] = Field(default_factory=dict)
171 @model_validator(mode="after")
172 def validate_network(self) -> "NetworkConfig":
173 """Validate network consistency."""
174 state_names = {state.name for state in self.states}
176 # Validate arc targets exist
177 for state in self.states:
178 for arc in state.arcs:
179 if isinstance(arc, ArcConfig) and arc.target not in state_names:
180 raise ValueError(f"Arc target '{arc.target}' not found in network")
182 # Validate at least one start state
183 start_states = [s for s in self.states if s.is_start]
184 if not start_states:
185 raise ValueError("Network must have at least one start state")
187 return self
190class FSMConfig(BaseModel):
191 """Complete FSM configuration."""
193 name: str
194 version: str = "1.0.0"
195 description: str | None = None
197 # Data handling
198 data_mode: DataModeConfig = Field(default_factory=DataModeConfig)
199 transaction: TransactionConfig = Field(default_factory=TransactionConfig)
201 # Resources
202 resources: List[ResourceConfig] = Field(default_factory=list)
204 # Networks
205 networks: List[NetworkConfig]
206 main_network: str
208 # Execution
209 execution_strategy: ExecutionStrategy = ExecutionStrategy.DEPTH_FIRST
210 max_transitions: int = Field(default=1000, ge=1)
211 timeout_seconds: int | None = Field(default=None, ge=1)
213 # Metadata
214 metadata: Dict[str, Any] = Field(default_factory=dict)
216 @model_validator(mode="after")
217 def validate_fsm(self) -> "FSMConfig":
218 """Validate FSM configuration consistency."""
219 # Validate main network exists
220 network_names = {net.name for net in self.networks}
221 if self.main_network not in network_names:
222 raise ValueError(f"Main network '{self.main_network}' not found")
224 # Validate resource references
225 resource_names = {res.name for res in self.resources}
226 for network in self.networks:
227 for res_name in network.resources:
228 if res_name not in resource_names:
229 raise ValueError(f"Resource '{res_name}' not found in FSM resources")
231 for state in network.states:
232 for res_name in state.resources:
233 if res_name not in resource_names:
234 raise ValueError(f"Resource '{res_name}' not found in FSM resources")
236 return self
239class UseCaseTemplate(str, Enum):
240 """Pre-defined use case templates."""
242 DATABASE_ETL = "database_etl"
243 FILE_PROCESSING = "file_processing"
244 API_ORCHESTRATION = "api_orchestration"
245 LLM_WORKFLOW = "llm_workflow"
246 DATA_VALIDATION = "data_validation"
247 STREAM_PROCESSING = "stream_processing"
250class TemplateConfig(BaseModel):
251 """Configuration for using a template."""
253 template: UseCaseTemplate
254 params: Dict[str, Any] = Field(default_factory=dict)
255 overrides: Dict[str, Any] = Field(default_factory=dict)
258def generate_json_schema() -> Dict[str, Any]:
259 """Generate JSON schema for FSM configuration.
261 Returns:
262 JSON schema as a dictionary.
263 """
264 return FSMConfig.model_json_schema()
267def validate_config(config: Dict[str, Any]) -> FSMConfig:
268 """Validate a configuration dictionary.
270 Args:
271 config: Configuration dictionary.
273 Returns:
274 Validated FSMConfig instance.
276 Raises:
277 ValidationError: If configuration is invalid.
278 """
279 return FSMConfig(**config)
282# Template definitions
283TEMPLATES: Dict[UseCaseTemplate, Dict[str, Any]] = {
284 UseCaseTemplate.DATABASE_ETL: {
285 "data_mode": {
286 "default": DataHandlingMode.COPY,
287 },
288 "transaction": {
289 "strategy": TransactionStrategy.BATCH,
290 "batch_size": 1000,
291 },
292 "execution_strategy": ExecutionStrategy.RESOURCE_OPTIMIZED,
293 },
294 UseCaseTemplate.FILE_PROCESSING: {
295 "data_mode": {
296 "default": DataHandlingMode.REFERENCE,
297 },
298 "transaction": {
299 "strategy": TransactionStrategy.SINGLE,
300 },
301 "execution_strategy": ExecutionStrategy.STREAM_OPTIMIZED,
302 },
303 UseCaseTemplate.API_ORCHESTRATION: {
304 "data_mode": {
305 "default": DataHandlingMode.COPY,
306 },
307 "transaction": {
308 "strategy": TransactionStrategy.MANUAL,
309 },
310 "execution_strategy": ExecutionStrategy.DEPTH_FIRST,
311 },
312 UseCaseTemplate.LLM_WORKFLOW: {
313 "data_mode": {
314 "default": DataHandlingMode.COPY,
315 },
316 "transaction": {
317 "strategy": TransactionStrategy.SINGLE,
318 },
319 "execution_strategy": ExecutionStrategy.RESOURCE_OPTIMIZED,
320 },
321 UseCaseTemplate.DATA_VALIDATION: {
322 "data_mode": {
323 "default": DataHandlingMode.DIRECT,
324 },
325 "transaction": {
326 "strategy": TransactionStrategy.SINGLE,
327 },
328 "execution_strategy": ExecutionStrategy.DEPTH_FIRST,
329 },
330 UseCaseTemplate.STREAM_PROCESSING: {
331 "data_mode": {
332 "default": DataHandlingMode.REFERENCE,
333 },
334 "transaction": {
335 "strategy": TransactionStrategy.BATCH,
336 "batch_size": 5000,
337 },
338 "execution_strategy": ExecutionStrategy.STREAM_OPTIMIZED,
339 },
340}
343def apply_template(
344 template: UseCaseTemplate,
345 params: Dict[str, Any] | None = None,
346 overrides: Dict[str, Any] | None = None,
347) -> Dict[str, Any]:
348 """Apply a use case template to generate configuration.
350 Args:
351 template: Template to apply.
352 params: Template parameters.
353 overrides: Configuration overrides.
355 Returns:
356 Configuration dictionary.
357 """
358 import copy
360 config = copy.deepcopy(TEMPLATES[template])
362 # Apply parameters (template-specific logic can go here)
363 if params:
364 # This would contain template-specific parameter application
365 pass
367 # Apply overrides
368 if overrides:
369 def deep_merge(base: Dict, updates: Dict) -> Dict:
370 for key, value in updates.items():
371 if key in base and isinstance(base[key], dict) and isinstance(value, dict):
372 deep_merge(base[key], value)
373 else:
374 base[key] = value
375 return base
377 deep_merge(config, overrides)
379 return config