Coverage for src/dataknobs_fsm/patterns/etl.py: 0%
197 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
"""AsyncDatabase ETL (Extract, Transform, Load) pattern implementation.

This module provides pre-configured FSM patterns for ETL operations,
including data extraction from source databases, transformation pipelines,
and loading into target systems.
"""
8from typing import Any, Dict, List, Union, Callable, AsyncIterator
9from dataclasses import dataclass
10from enum import Enum
11from dataknobs_data import AsyncDatabase, Record, Query
13from ..api.simple import SimpleFSM
14from dataknobs_fsm.core.data_modes import DataHandlingMode
15from ..functions.library.database import DatabaseFetch, DatabaseUpsert
16from ..functions.library.transformers import (
17 FieldMapper, DataEnricher
18)
19from ..functions.library.validators import SchemaValidator
class ETLMode(Enum):
    """Strategies for moving data from source to target."""

    FULL_REFRESH = "full"        # Replace all data in the target
    INCREMENTAL = "incremental"  # Only pick up new or changed records
    UPSERT = "upsert"            # Update matches, insert the rest
    APPEND = "append"            # Insert-only; existing rows untouched
@dataclass
class ETLConfig:
    """Configuration for ETL pipeline.

    The first two fields identify the source and target databases; the
    remaining fields tune batching, error tolerance, checkpointing and
    the optional transform/validate/enrich stages.
    """

    source_db: Dict[str, Any]  # Source database config
    target_db: Dict[str, Any]  # Target database config
    mode: ETLMode = ETLMode.FULL_REFRESH  # Processing strategy (see ETLMode)
    batch_size: int = 1000  # Records per batch pushed through the FSM
    parallel_workers: int = 4  # Max workers used when processing a batch
    error_threshold: float = 0.05  # Max 5% errors before the run aborts
    checkpoint_interval: int = 10000  # Checkpoint every N records

    # Optional configurations
    source_query: str | None = "SELECT * FROM source_table"  # Extraction query
    target_table: str = "target_table"  # Table records are loaded into
    key_columns: List[str] | None = None  # Upsert keys; defaults to ['id'] downstream
    field_mappings: Dict[str, str] | None = None  # old_name -> new_name renames
    transformations: List[Callable] | None = None  # Extra transform callables
    validation_schema: Dict[str, Any] | None = None  # Schema for the validate stage
    enrichment_sources: List[Dict[str, Any]] | None = None  # 'database'/'api' sources
class DatabaseETL:
    """AsyncDatabase ETL pipeline using FSM pattern.

    Models the pipeline as a state machine
    (extract -> validate -> transform -> [enrich] -> load) with terminal
    'complete' and 'error' states.  Batches of source records are pushed
    through the FSM while metrics and (in-memory) checkpoints are kept
    so an interrupted run can be resumed.
    """

    def __init__(self, config: ETLConfig):
        """Initialize ETL pipeline.

        Args:
            config: ETL configuration
        """
        self.config = config
        self._fsm = self._build_fsm()
        # In-memory checkpoint store (checkpoint_id -> checkpoint dict).
        # A production deployment would persist these externally.
        self._checkpoint_data = {}
        # Running counters, updated after every processed batch.
        self._metrics = {
            'extracted': 0,
            'transformed': 0,
            'loaded': 0,
            'errors': 0,
            'skipped': 0
        }

    def _build_fsm(self) -> SimpleFSM:
        """Build FSM for ETL pipeline.

        Returns:
            A SimpleFSM wired with the ETL states, arcs, resources and
            functions derived from the current configuration.
        """
        # Build resources list
        resources = [
            {'name': 'source_db', 'type': 'database', 'config': self.config.source_db},
            {'name': 'target_db', 'type': 'database', 'config': self.config.target_db}
        ]

        # Add enrichment resources if configured
        if self.config.enrichment_sources:
            for i, source in enumerate(self.config.enrichment_sources):
                if 'database' in source:
                    resources.append({
                        'name': f'enrichment_db_{i}',
                        'type': 'database',
                        'config': source['database']
                    })
                elif 'api' in source:
                    resources.append({
                        'name': f'enrichment_api_{i}',
                        'type': 'http',
                        'config': source['api']
                    })

        # Create FSM configuration
        fsm_config = {
            'name': 'ETL_Pipeline',
            'data_mode': DataHandlingMode.COPY.value,  # Use COPY for data isolation
            'resources': resources,
            'states': [
                {
                    'name': 'extract',
                    'is_start': True,
                    'resources': ['source_db']
                },
                {
                    'name': 'validate',
                    'resources': []
                },
                {
                    'name': 'transform',
                    'resources': self._get_transform_resources()
                },
                {
                    'name': 'enrich',
                    'resources': self._get_enrichment_resources()
                },
                {
                    'name': 'load',
                    'resources': ['target_db']
                },
                {
                    'name': 'complete',
                    'is_end': True
                },
                {
                    'name': 'error',
                    'is_end': True
                }
            ],
            'arcs': [
                {
                    'from': 'extract',
                    'to': 'validate',
                    'name': 'extracted'
                },
                {
                    'from': 'validate',
                    'to': 'transform',
                    'name': 'valid',
                    'pre_test': self._create_validation_test_reference()
                },
                {
                    'from': 'validate',
                    'to': 'error',
                    'name': 'invalid'
                },
                {
                    # Skip the enrich state entirely when there is nothing
                    # to enrich with.
                    'from': 'transform',
                    'to': 'enrich' if self.config.enrichment_sources else 'load',
                    'name': 'transformed',
                    'transform': self._create_transformer_reference()
                },
                {
                    'from': 'enrich',
                    'to': 'load',
                    'name': 'enriched',
                    'transform': self._create_enricher_reference()
                },
                {
                    'from': 'load',
                    'to': 'complete',
                    'name': 'loaded'
                }
            ]
        }

        # Add functions
        self._register_functions(fsm_config)

        return SimpleFSM(fsm_config, data_mode=DataHandlingMode.COPY)

    def _get_transform_resources(self) -> List[str]:
        """Get resource names needed for the transform state."""
        resources = []
        if self.config.enrichment_sources:
            for i, source in enumerate(self.config.enrichment_sources):
                if 'database' in source:
                    resources.append(f'enrichment_db_{i}')
        return resources

    def _get_enrichment_resources(self) -> List[str]:
        """Get resource names needed for the enrich state."""
        resources = []
        if self.config.enrichment_sources:
            for i, source in enumerate(self.config.enrichment_sources):
                if 'api' in source:
                    resources.append(f'enrichment_api_{i}')
        return resources

    def _create_validation_test(self) -> Callable | None:
        """Create validation test function.

        Returns:
            A callable validating a state's data against the configured
            schema, or None when no schema is configured.
        """
        if not self.config.validation_schema:
            return None

        validator = SchemaValidator(self.config.validation_schema)
        return lambda state: validator.validate(Record(state.data))  # type: ignore

    def _create_validation_test_reference(self) -> Dict[str, Any] | None:
        """Create validation test function reference for FSM config.

        Returns:
            An inline-code reference whose final expression evaluates to
            the boolean test result, or None when no validation schema
            is configured.
        """
        if not self.config.validation_schema:
            return None

        import json
        # Build validation code based on schema.  The inline executor
        # takes the value of the last expression as the result, so the
        # check must be a single boolean expression.  (The previous
        # version placed a bare `False` statement inside a loop — that
        # value was discarded and the trailing `True` made validation
        # always pass.)
        code_lines = [
            "# Validate data against schema",
            f"schema = {json.dumps(self.config.validation_schema)}",
            "# Only object schemas declare required fields",
            "required = schema.get('required', []) if schema.get('type') == 'object' else []",
            "all(field in data for field in required)"
        ]

        return {
            'type': 'inline',
            'code': '\n'.join(code_lines)
        }

    def _create_transformer(self) -> Callable:
        """Create transformation function.

        Composes the configured field mapper and custom transformations
        into a single async callable applied to each record dict.
        """
        transformers = []

        # Add field mapping
        if self.config.field_mappings:
            transformers.append(FieldMapper(self.config.field_mappings))

        # Add custom transformations
        if self.config.transformations:
            transformers.extend(self.config.transformations)  # type: ignore

        # Compose transformers: objects exposing .transform are awaited,
        # plain callables are invoked synchronously.
        async def transform(data: Dict[str, Any]) -> Dict[str, Any]:
            result = data
            for transformer in transformers:
                if hasattr(transformer, 'transform'):
                    result = await transformer.transform(result)  # type: ignore
                elif callable(transformer):
                    result = transformer(result)
            return result

        return transform

    def _create_enricher(self) -> Callable | None:
        """Create enrichment function, or None if nothing to enrich with."""
        if not self.config.enrichment_sources:
            return None

        enricher = DataEnricher(self.config.enrichment_sources)  # type: ignore
        return enricher.transform

    def _create_transformer_reference(self) -> Dict[str, Any] | None:
        """Create transformation function reference for FSM config.

        Returns:
            An inline-code reference applying the configured field
            renames, or None when no transformation is configured.
        """
        if not self.config.field_mappings and not self.config.transformations:
            return None

        # Build transformation code; the final bare `result` expression
        # is the inline function's return value.
        code_lines = [
            "# Apply transformations",
            "result = data"
        ]

        if self.config.field_mappings:
            # Add field mapping code (rename old_name -> new_name)
            for old_name, new_name in self.config.field_mappings.items():
                code_lines.append(f"if '{old_name}' in result:")
                code_lines.append(f"    result['{new_name}'] = result.pop('{old_name}')")

        # For custom transformations, we'll apply them as simple dict updates
        if self.config.transformations:
            code_lines.append("# Apply custom transformations")
            code_lines.append("result['transformed'] = True")

        code_lines.append("result")

        return {
            'type': 'inline',
            'code': '\n'.join(code_lines)
        }

    def _create_enricher_reference(self) -> Dict[str, Any] | None:
        """Create enrichment function reference for FSM config.

        Returns:
            An inline-code reference flagging the record as enriched, or
            None when no enrichment sources are configured.
        """
        if not self.config.enrichment_sources:
            return None

        # Build enrichment code; final `result` is the return value.
        code_lines = [
            "# Enrich data",
            "result = data",
            "result['enriched'] = True",
            "result"
        ]

        return {
            'type': 'inline',
            'code': '\n'.join(code_lines)
        }

    def _register_functions(self, config: Dict[str, Any]) -> None:
        """Register ETL-specific functions on the FSM config (in place)."""
        # Register database functions for the extract and load states.
        config['functions'] = {
            'extract': DatabaseFetch(
                resource_name='source_db',
                query=self.config.source_query  # type: ignore
            ),
            'load': DatabaseUpsert(
                resource_name='target_db',
                table=self.config.target_table,
                key_columns=self.config.key_columns or ['id']
            )
        }

    async def run(
        self,
        checkpoint_id: str | None = None
    ) -> Dict[str, Any]:
        """Run ETL pipeline.

        Args:
            checkpoint_id: Optional checkpoint to resume from

        Returns:
            ETL execution metrics

        Raises:
            ETLError: If the configured error threshold is exceeded.
        """
        # Resume from checkpoint if provided
        if checkpoint_id:
            await self._load_checkpoint(checkpoint_id)

        # Extract data
        source_db = await AsyncDatabase.create(
            self.config.source_db['type'],
            self.config.source_db  # type: ignore
        )

        try:
            # Determine extraction strategy
            if self.config.mode == ETLMode.INCREMENTAL:
                query = self._get_incremental_query()
            else:
                query = self.config.source_query or Query()

            # Process in batches
            async for batch in self._extract_batches(source_db, query):  # type: ignore
                # Process batch through FSM
                results = self._fsm.process_batch(
                    data=batch,  # type: ignore
                    batch_size=self.config.batch_size,
                    max_workers=self.config.parallel_workers
                )

                # Update metrics
                self._update_metrics(results)

                # Check error threshold
                if self._check_error_threshold():
                    from ..core.exceptions import ETLError
                    raise ETLError(f"Error threshold exceeded: {self._metrics['errors']} errors")

                # Checkpoint if needed
                if self._should_checkpoint():
                    await self._save_checkpoint()

        finally:
            # Always release the source connection, even on failure.
            await source_db.close()

        return self._metrics

    async def _extract_batches(
        self,
        db: AsyncDatabase,
        query: Query
    ) -> AsyncIterator[List[Dict[str, Any]]]:
        """Extract data in batches.

        Args:
            db: Source database
            query: Extraction query

        Yields:
            Batches of at most ``config.batch_size`` record dicts; the
            final batch may be smaller.
        """
        batch = []
        async for record in db.stream_read(query):
            batch.append(record.to_dict())
            if len(batch) >= self.config.batch_size:
                yield batch
                batch = []

        # Flush the trailing partial batch, if any.
        if batch:
            yield batch

    def _get_incremental_query(self) -> Query:
        """Get query for incremental extraction.

        Falls back to an unfiltered query when no checkpoint timestamp
        is available (first run or unknown checkpoint).
        """
        # Get last processed timestamp from checkpoint
        last_timestamp = self._checkpoint_data.get('last_timestamp')

        if last_timestamp:
            return Query().filter('updated_at', '>', last_timestamp)
        else:
            return Query()

    def _update_metrics(self, results: List[Dict[str, Any]]) -> None:
        """Update execution metrics from one batch of FSM results.

        'extracted' is derived as loaded + errors after each batch.
        """
        for result in results:
            if result['success']:
                if result['final_state'] == 'complete':
                    self._metrics['loaded'] += 1
                elif result['final_state'] == 'error':
                    self._metrics['errors'] += 1
            else:
                self._metrics['errors'] += 1

        self._metrics['extracted'] = self._metrics['loaded'] + self._metrics['errors']

    def _check_error_threshold(self) -> bool:
        """Return True when the observed error rate exceeds the threshold."""
        if self._metrics['extracted'] == 0:
            return False

        error_rate = self._metrics['errors'] / self._metrics['extracted']
        return error_rate > self.config.error_threshold

    def _should_checkpoint(self) -> bool:
        """Check if a checkpoint should be saved.

        Requires at least one extracted record: without the guard,
        ``0 % interval == 0`` would trigger a checkpoint before any
        work had been done.
        """
        extracted = self._metrics['extracted']
        return extracted > 0 and extracted % self.config.checkpoint_interval == 0

    async def _save_checkpoint(self) -> str:
        """Save checkpoint for resume capability.

        Returns:
            The generated checkpoint id (first 8 hex chars of the MD5
            of the checkpoint payload; non-cryptographic use).
        """
        import json
        import hashlib
        from datetime import datetime

        checkpoint = {
            'timestamp': datetime.now().isoformat(),
            'metrics': self._metrics,
            'config': {
                'mode': self.config.mode.value,
                'batch_size': self.config.batch_size
            },
            'position': self._metrics['extracted']
        }

        # Generate checkpoint ID
        checkpoint_id = hashlib.md5(
            json.dumps(checkpoint).encode()
        ).hexdigest()[:8]

        # Save to storage (simplified - would use persistent storage)
        self._checkpoint_data[checkpoint_id] = checkpoint

        return checkpoint_id

    async def _load_checkpoint(self, checkpoint_id: str) -> None:
        """Load checkpoint data; unknown ids are silently ignored."""
        if checkpoint_id in self._checkpoint_data:
            checkpoint = self._checkpoint_data[checkpoint_id]
            self._metrics = checkpoint['metrics']
def create_etl_pipeline(
    source: Union[str, Dict[str, Any]],
    target: Union[str, Dict[str, Any]],
    mode: ETLMode = ETLMode.FULL_REFRESH,
    **kwargs
) -> DatabaseETL:
    """Factory function to create ETL pipeline.

    Args:
        source: Source database configuration or connection string
        target: Target database configuration or connection string
        mode: ETL mode
        **kwargs: Additional configuration options

    Returns:
        Configured DatabaseETL instance
    """
    # Accept either ready-made config dicts or URL-style strings.
    source_cfg = _parse_connection_string(source) if isinstance(source, str) else source
    target_cfg = _parse_connection_string(target) if isinstance(target, str) else target

    return DatabaseETL(
        ETLConfig(
            source_db=source_cfg,
            target_db=target_cfg,
            mode=mode,
            **kwargs
        )
    )
497def _parse_connection_string(conn_str: str) -> Dict[str, Any]:
498 """Parse database connection string.
500 Args:
501 conn_str: Connection string
503 Returns:
504 AsyncDatabase configuration dictionary
505 """
506 # Simplified parsing - real implementation would be more robust
507 if conn_str.startswith('postgresql://'):
508 return {
509 'type': 'postgres',
510 'connection_string': conn_str
511 }
512 elif conn_str.startswith('mongodb://'):
513 return {
514 'type': 'mongodb',
515 'connection_string': conn_str
516 }
517 elif conn_str.startswith('sqlite://'):
518 return {
519 'type': 'sqlite',
520 'path': conn_str.replace('sqlite://', '')
521 }
522 else:
523 raise ValueError(f"Unsupported connection string: {conn_str}")
526# Pre-configured ETL patterns
def create_database_sync(
    source: Dict[str, Any],
    target: Dict[str, Any],
    sync_interval: int = 300  # 5 minutes
) -> DatabaseETL:
    """Create database synchronization pipeline.

    Args:
        source: Source database config
        target: Target database config
        sync_interval: Sync interval in seconds.
            NOTE(review): not referenced in this body — scheduling is
            presumably handled by the caller; confirm intended use.

    Returns:
        AsyncDatabase sync ETL pipeline
    """
    # Incremental mode with frequent checkpoints suits repeated syncs.
    return create_etl_pipeline(
        source=source,
        target=target,
        mode=ETLMode.INCREMENTAL,
        checkpoint_interval=1000,
    )
def create_data_migration(
    source: Dict[str, Any],
    target: Dict[str, Any],
    field_mappings: Dict[str, str] | None = None,
    transformations: List[Callable] | None = None
) -> DatabaseETL:
    """Create data migration pipeline.

    Args:
        source: Source database config
        target: Target database config
        field_mappings: Field name mappings
        transformations: Data transformation functions

    Returns:
        Data migration ETL pipeline
    """
    # Full refresh with large batches and extra workers: a migration
    # moves everything once, so throughput matters most.
    options = {
        'mode': ETLMode.FULL_REFRESH,
        'field_mappings': field_mappings,
        'transformations': transformations,
        'batch_size': 5000,
        'parallel_workers': 8,
    }
    return create_etl_pipeline(source=source, target=target, **options)
def create_data_warehouse_load(
    sources: List[Dict[str, Any]],
    warehouse: Dict[str, Any],
    aggregations: List[Callable] | None = None
) -> List[DatabaseETL]:
    """Create data warehouse loading pipelines.

    Args:
        sources: List of source database configs
        warehouse: Data warehouse config
        aggregations: Aggregation functions

    Returns:
        List of ETL pipelines for each source
    """
    # One append-only pipeline per source, all feeding the same warehouse.
    return [
        create_etl_pipeline(
            source=src,
            target=warehouse,
            mode=ETLMode.APPEND,
            transformations=aggregations,
            batch_size=10000,
        )
        for src in sources
    ]