Coverage for src/dataknobs_fsm/patterns/etl.py: 0%

197 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""AsyncDatabase ETL (Extract, Transform, Load) pattern implementation. 

2 

3This module provides pre-configured FSM patterns for ETL operations, 

4including data extraction from source databases, transformation pipelines, 

5and loading into target systems. 

6""" 

7 

8from typing import Any, Dict, List, Union, Callable, AsyncIterator 

9from dataclasses import dataclass 

10from enum import Enum 

11from dataknobs_data import AsyncDatabase, Record, Query 

12 

13from ..api.simple import SimpleFSM 

14from dataknobs_fsm.core.data_modes import DataHandlingMode 

15from ..functions.library.database import DatabaseFetch, DatabaseUpsert 

16from ..functions.library.transformers import ( 

17 FieldMapper, DataEnricher 

18) 

19from ..functions.library.validators import SchemaValidator 

20 

21 

class ETLMode(Enum):
    """Supported ETL processing modes."""

    FULL_REFRESH = "full"        # replace all data in the target
    INCREMENTAL = "incremental"  # process only new/changed data
    UPSERT = "upsert"            # update existing rows, insert new ones
    APPEND = "append"            # always append; never update

28 

29 

@dataclass
class ETLConfig:
    """Configuration for an ETL pipeline.

    Holds the required source/target database configurations, tuning knobs
    for batching and error handling, and optional transformation,
    validation and enrichment settings.
    """

    # Required database configurations.
    source_db: Dict[str, Any]  # Source database config
    target_db: Dict[str, Any]  # Target database config

    # Processing behaviour.
    mode: ETLMode = ETLMode.FULL_REFRESH
    batch_size: int = 1000
    parallel_workers: int = 4
    error_threshold: float = 0.05      # abort when error rate exceeds 5%
    checkpoint_interval: int = 10000   # checkpoint every N records

    # Optional configurations.
    source_query: str | None = "SELECT * FROM source_table"
    target_table: str = "target_table"
    key_columns: List[str] | None = None
    field_mappings: Dict[str, str] | None = None
    transformations: List[Callable] | None = None
    validation_schema: Dict[str, Any] | None = None
    enrichment_sources: List[Dict[str, Any]] | None = None

49 

50 

class DatabaseETL:
    """AsyncDatabase ETL pipeline using the FSM pattern.

    Builds a SimpleFSM with extract -> validate -> transform -> enrich ->
    load states from an :class:`ETLConfig`, then streams source records
    through it in batches while tracking metrics and (in-memory)
    checkpoints for resume support.
    """

    def __init__(self, config: ETLConfig):
        """Initialize the ETL pipeline.

        Args:
            config: ETL configuration.
        """
        self.config = config
        self._fsm = self._build_fsm()
        # Checkpoint store; in-memory only (see _save_checkpoint).
        self._checkpoint_data: Dict[str, Any] = {}
        # Execution counters, updated as batches complete.
        self._metrics = {
            'extracted': 0,
            'transformed': 0,
            'loaded': 0,
            'errors': 0,
            'skipped': 0
        }

    def _build_fsm(self) -> SimpleFSM:
        """Build the FSM describing the ETL pipeline."""
        # Source and target databases are always required resources.
        resources = [
            {'name': 'source_db', 'type': 'database', 'config': self.config.source_db},
            {'name': 'target_db', 'type': 'database', 'config': self.config.target_db}
        ]

        # Add enrichment resources if configured.
        if self.config.enrichment_sources:
            for i, source in enumerate(self.config.enrichment_sources):
                if 'database' in source:
                    resources.append({
                        'name': f'enrichment_db_{i}',
                        'type': 'database',
                        'config': source['database']
                    })
                elif 'api' in source:
                    resources.append({
                        'name': f'enrichment_api_{i}',
                        'type': 'http',
                        'config': source['api']
                    })

        fsm_config = {
            'name': 'ETL_Pipeline',
            'data_mode': DataHandlingMode.COPY.value,  # COPY isolates each record's data
            'resources': resources,
            'states': [
                {'name': 'extract', 'is_start': True, 'resources': ['source_db']},
                {'name': 'validate', 'resources': []},
                {'name': 'transform', 'resources': self._get_transform_resources()},
                {'name': 'enrich', 'resources': self._get_enrichment_resources()},
                {'name': 'load', 'resources': ['target_db']},
                {'name': 'complete', 'is_end': True},
                {'name': 'error', 'is_end': True}
            ],
            'arcs': [
                {'from': 'extract', 'to': 'validate', 'name': 'extracted'},
                {
                    'from': 'validate',
                    'to': 'transform',
                    'name': 'valid',
                    'pre_test': self._create_validation_test_reference()
                },
                {'from': 'validate', 'to': 'error', 'name': 'invalid'},
                {
                    # Skip the enrich state entirely when there is nothing
                    # to enrich with.
                    'from': 'transform',
                    'to': 'enrich' if self.config.enrichment_sources else 'load',
                    'name': 'transformed',
                    'transform': self._create_transformer_reference()
                },
                {
                    'from': 'enrich',
                    'to': 'load',
                    'name': 'enriched',
                    'transform': self._create_enricher_reference()
                },
                {'from': 'load', 'to': 'complete', 'name': 'loaded'}
            ]
        }

        # Attach the extract/load database functions.
        self._register_functions(fsm_config)

        return SimpleFSM(fsm_config, data_mode=DataHandlingMode.COPY)

    def _get_transform_resources(self) -> List[str]:
        """Names of enrichment-database resources used during transform."""
        if not self.config.enrichment_sources:
            return []
        return [
            f'enrichment_db_{i}'
            for i, source in enumerate(self.config.enrichment_sources)
            if 'database' in source
        ]

    def _get_enrichment_resources(self) -> List[str]:
        """Names of enrichment-API resources used during enrich."""
        if not self.config.enrichment_sources:
            return []
        return [
            f'enrichment_api_{i}'
            for i, source in enumerate(self.config.enrichment_sources)
            if 'api' in source
        ]

    def _create_validation_test(self) -> Callable | None:
        """Create a callable validation test, or None if unconfigured."""
        if not self.config.validation_schema:
            return None

        validator = SchemaValidator(self.config.validation_schema)
        return lambda state: validator.validate(Record(state.data))  # type: ignore

    def _create_validation_test_reference(self) -> Dict[str, Any] | None:
        """Create an inline validation-test reference for the FSM config.

        The generated inline snippet's last expression is its boolean
        result. Returns None when no validation schema is configured.
        """
        if not self.config.validation_schema:
            return None

        import json
        # BUGFIX: the previous generated code placed a bare ``False``
        # expression inside the loop, which is a no-op, so validation
        # always evaluated to True. Track the outcome in a flag whose
        # value is the snippet's final expression instead.
        code_lines = [
            "# Validate data against schema",
            f"schema = {json.dumps(self.config.validation_schema)}",
            "# Basic schema validation: all required fields must be present",
            "ok = True",
            "if schema.get('type') == 'object':",
            "    for field in schema.get('required', []):",
            "        if field not in data:",
            "            ok = False",
            "ok"
        ]

        return {
            'type': 'inline',
            'code': '\n'.join(code_lines)
        }

    def _create_transformer(self) -> Callable:
        """Create a composed async transformation function.

        Chains the configured FieldMapper and any custom transformations.
        Each transformer either exposes an awaitable ``transform`` method
        or is a plain callable applied synchronously.
        """
        transformers = []

        # Field mapping runs first so custom transforms see renamed keys.
        if self.config.field_mappings:
            transformers.append(FieldMapper(self.config.field_mappings))

        if self.config.transformations:
            transformers.extend(self.config.transformations)  # type: ignore

        async def transform(data: Dict[str, Any]) -> Dict[str, Any]:
            result = data
            for transformer in transformers:
                if hasattr(transformer, 'transform'):
                    result = await transformer.transform(result)  # type: ignore
                elif callable(transformer):
                    result = transformer(result)
            return result

        return transform

    def _create_enricher(self) -> Callable | None:
        """Create an enrichment function, or None if unconfigured."""
        if not self.config.enrichment_sources:
            return None

        enricher = DataEnricher(self.config.enrichment_sources)  # type: ignore
        return enricher.transform

    def _create_transformer_reference(self) -> Dict[str, Any] | None:
        """Create an inline transformation reference for the FSM config.

        Returns None when neither field mappings nor custom
        transformations are configured.
        """
        if not self.config.field_mappings and not self.config.transformations:
            return None

        code_lines = [
            "# Apply transformations",
            "result = data"
        ]

        if self.config.field_mappings:
            # Emit one rename per mapping; pop moves the value to the
            # new key and removes the old one.
            for old_name, new_name in self.config.field_mappings.items():
                code_lines.append(f"if '{old_name}' in result:")
                code_lines.append(f"    result['{new_name}'] = result.pop('{old_name}')")

        # NOTE(review): custom callables cannot be serialized into inline
        # code; only a marker flag is set here — confirm this is intended.
        if self.config.transformations:
            code_lines.append("# Apply custom transformations")
            code_lines.append("result['transformed'] = True")

        code_lines.append("result")

        return {
            'type': 'inline',
            'code': '\n'.join(code_lines)
        }

    def _create_enricher_reference(self) -> Dict[str, Any] | None:
        """Create an inline enrichment reference for the FSM config.

        NOTE(review): this is a stub that only sets a marker flag; the
        real enrichment logic lives in DataEnricher (_create_enricher).
        """
        if not self.config.enrichment_sources:
            return None

        code_lines = [
            "# Enrich data",
            "result = data",
            "result['enriched'] = True",
            "result"
        ]

        return {
            'type': 'inline',
            'code': '\n'.join(code_lines)
        }

    def _register_functions(self, config: Dict[str, Any]) -> None:
        """Register the extract/load database functions on the FSM config."""
        config['functions'] = {
            'extract': DatabaseFetch(
                resource_name='source_db',
                query=self.config.source_query  # type: ignore
            ),
            'load': DatabaseUpsert(
                resource_name='target_db',
                table=self.config.target_table,
                # Default to an 'id' primary key when none is configured.
                key_columns=self.config.key_columns or ['id']
            )
        }

    async def run(
        self,
        checkpoint_id: str | None = None
    ) -> Dict[str, Any]:
        """Run the ETL pipeline.

        Args:
            checkpoint_id: Optional checkpoint to resume from.

        Returns:
            ETL execution metrics.

        Raises:
            ETLError: When the configured error threshold is exceeded.
        """
        # Resume from checkpoint if provided.
        if checkpoint_id:
            await self._load_checkpoint(checkpoint_id)

        source_db = await AsyncDatabase.create(
            self.config.source_db['type'],
            self.config.source_db  # type: ignore
        )

        try:
            # Incremental mode filters on the last checkpointed timestamp.
            if self.config.mode == ETLMode.INCREMENTAL:
                query = self._get_incremental_query()
            else:
                query = self.config.source_query or Query()

            # Process in batches.
            async for batch in self._extract_batches(source_db, query):  # type: ignore
                results = self._fsm.process_batch(
                    data=batch,  # type: ignore
                    batch_size=self.config.batch_size,
                    max_workers=self.config.parallel_workers
                )

                self._update_metrics(results)

                # Abort the whole run once too many records have failed.
                if self._check_error_threshold():
                    from ..core.exceptions import ETLError
                    raise ETLError(f"Error threshold exceeded: {self._metrics['errors']} errors")

                if self._should_checkpoint():
                    await self._save_checkpoint()

        finally:
            await source_db.close()

        return self._metrics

    async def _extract_batches(
        self,
        db: AsyncDatabase,
        query: Query
    ) -> AsyncIterator[List[Dict[str, Any]]]:
        """Extract data in batches.

        Args:
            db: Source database.
            query: Extraction query.

        Yields:
            Batches of record dictionaries, each at most
            ``config.batch_size`` long (the final batch may be shorter).
        """
        batch = []
        async for record in db.stream_read(query):
            batch.append(record.to_dict())
            if len(batch) >= self.config.batch_size:
                yield batch
                batch = []

        # Flush the trailing partial batch.
        if batch:
            yield batch

    def _get_incremental_query(self) -> Query:
        """Build the query for incremental extraction.

        Filters on ``updated_at`` newer than the last checkpointed
        timestamp; falls back to a full query when no checkpoint exists.
        """
        last_timestamp = self._checkpoint_data.get('last_timestamp')

        if last_timestamp:
            return Query().filter('updated_at', '>', last_timestamp)
        else:
            return Query()

    def _update_metrics(self, results: List[Dict[str, Any]]) -> None:
        """Update execution metrics from a batch of FSM results.

        NOTE(review): the 'transformed' and 'skipped' counters are never
        updated here — confirm whether they are maintained elsewhere.
        """
        for result in results:
            if result['success']:
                if result['final_state'] == 'complete':
                    self._metrics['loaded'] += 1
                elif result['final_state'] == 'error':
                    self._metrics['errors'] += 1
            else:
                self._metrics['errors'] += 1

        # 'extracted' is derived from the terminal outcomes seen so far.
        self._metrics['extracted'] = self._metrics['loaded'] + self._metrics['errors']

    def _check_error_threshold(self) -> bool:
        """Return True when the error rate exceeds the configured threshold."""
        if self._metrics['extracted'] == 0:
            return False

        error_rate = self._metrics['errors'] / self._metrics['extracted']
        return error_rate > self.config.error_threshold

    def _should_checkpoint(self) -> bool:
        """Return True when a checkpoint is due.

        BUGFIX: guard against ``extracted == 0``, which previously
        satisfied ``0 % interval == 0`` and checkpointed before any work
        had been done.
        """
        extracted = self._metrics['extracted']
        return extracted > 0 and extracted % self.config.checkpoint_interval == 0

    async def _save_checkpoint(self) -> str:
        """Save a checkpoint for resume capability.

        Returns:
            The generated checkpoint id.
        """
        import json
        import hashlib
        from datetime import datetime

        checkpoint = {
            'timestamp': datetime.now().isoformat(),
            'metrics': self._metrics,
            'config': {
                'mode': self.config.mode.value,
                'batch_size': self.config.batch_size
            },
            'position': self._metrics['extracted']
        }

        # MD5 is used only as a short, non-cryptographic content id.
        checkpoint_id = hashlib.md5(
            json.dumps(checkpoint).encode()
        ).hexdigest()[:8]

        # Save to storage (simplified - would use persistent storage).
        self._checkpoint_data[checkpoint_id] = checkpoint

        return checkpoint_id

    async def _load_checkpoint(self, checkpoint_id: str) -> None:
        """Restore metrics from a previously saved checkpoint, if present."""
        if checkpoint_id in self._checkpoint_data:
            checkpoint = self._checkpoint_data[checkpoint_id]
            self._metrics = checkpoint['metrics']

462 

463 

def create_etl_pipeline(
    source: Union[str, Dict[str, Any]],
    target: Union[str, Dict[str, Any]],
    mode: ETLMode = ETLMode.FULL_REFRESH,
    **kwargs
) -> DatabaseETL:
    """Factory function to create an ETL pipeline.

    Args:
        source: Source database configuration or connection string.
        target: Target database configuration or connection string.
        mode: ETL processing mode.
        **kwargs: Additional ETLConfig options (batch_size, field_mappings, ...).

    Returns:
        Configured DatabaseETL instance.
    """
    # Normalize connection strings into config dictionaries.
    source_cfg = _parse_connection_string(source) if isinstance(source, str) else source
    target_cfg = _parse_connection_string(target) if isinstance(target, str) else target

    config = ETLConfig(
        source_db=source_cfg,
        target_db=target_cfg,
        mode=mode,
        **kwargs
    )

    return DatabaseETL(config)

495 

496 

497def _parse_connection_string(conn_str: str) -> Dict[str, Any]: 

498 """Parse database connection string. 

499  

500 Args: 

501 conn_str: Connection string 

502  

503 Returns: 

504 AsyncDatabase configuration dictionary 

505 """ 

506 # Simplified parsing - real implementation would be more robust 

507 if conn_str.startswith('postgresql://'): 

508 return { 

509 'type': 'postgres', 

510 'connection_string': conn_str 

511 } 

512 elif conn_str.startswith('mongodb://'): 

513 return { 

514 'type': 'mongodb', 

515 'connection_string': conn_str 

516 } 

517 elif conn_str.startswith('sqlite://'): 

518 return { 

519 'type': 'sqlite', 

520 'path': conn_str.replace('sqlite://', '') 

521 } 

522 else: 

523 raise ValueError(f"Unsupported connection string: {conn_str}") 

524 

525 

526# Pre-configured ETL patterns 

527 

def create_database_sync(
    source: Dict[str, Any],
    target: Dict[str, Any],
    sync_interval: int = 300  # 5 minutes
) -> DatabaseETL:
    """Create a database synchronization pipeline.

    Args:
        source: Source database config.
        target: Target database config.
        sync_interval: Sync interval in seconds.
            NOTE(review): currently unused by the pipeline itself —
            presumably scheduling is the caller's responsibility; confirm.

    Returns:
        AsyncDatabase sync ETL pipeline.
    """
    # Incremental mode with frequent checkpoints suits continuous sync.
    pipeline = create_etl_pipeline(
        source=source,
        target=target,
        mode=ETLMode.INCREMENTAL,
        checkpoint_interval=1000
    )
    return pipeline

549 

550 

def create_data_migration(
    source: Dict[str, Any],
    target: Dict[str, Any],
    field_mappings: Dict[str, str] | None = None,
    transformations: List[Callable] | None = None
) -> DatabaseETL:
    """Create a data migration pipeline.

    Args:
        source: Source database config.
        target: Target database config.
        field_mappings: Field name mappings.
        transformations: Data transformation functions.

    Returns:
        Data migration ETL pipeline.
    """
    # Full refresh with large batches and high parallelism for one-shot moves.
    migration_options = {
        'mode': ETLMode.FULL_REFRESH,
        'field_mappings': field_mappings,
        'transformations': transformations,
        'batch_size': 5000,
        'parallel_workers': 8,
    }
    return create_etl_pipeline(source=source, target=target, **migration_options)

577 

578 

def create_data_warehouse_load(
    sources: List[Dict[str, Any]],
    warehouse: Dict[str, Any],
    aggregations: List[Callable] | None = None
) -> List[DatabaseETL]:
    """Create data warehouse loading pipelines.

    Args:
        sources: List of source database configs.
        warehouse: Data warehouse config.
        aggregations: Aggregation functions.

    Returns:
        List of ETL pipelines, one per source, all targeting the warehouse.
    """
    # Append-only loads with large batches; one pipeline per source.
    return [
        create_etl_pipeline(
            source=src,
            target=warehouse,
            mode=ETLMode.APPEND,
            transformations=aggregations,
            batch_size=10000
        )
        for src in sources
    ]