Coverage for src/dataknobs_fsm/functions/library/transformers.py: 0%

225 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Built-in transformer functions for FSM. 

2 

3This module provides commonly used transformation functions that can be 

4referenced in FSM configurations. 

5""" 

6 

7import copy 

8import json 

9import re 

10from datetime import datetime 

11from typing import Any, Callable, Dict, List, Union 

12 

13from dataknobs_fsm.functions.base import ITransformFunction, TransformError 

14 

15 

class FieldMapper(ITransformFunction):
    """Map fields from source names to target names.

    Supports dot-notation paths on both sides of the mapping:
    a dotted source such as ``"user.name"`` reads from a nested dict,
    and a dotted target writes into one (creating intermediate dicts).
    """

    # Sentinel distinguishing "path not present" from a stored None value.
    _MISSING = object()

    def __init__(
        self,
        field_map: Dict[str, str],
        drop_unmapped: bool = False,
        copy_unmapped: bool = True,
    ):
        """Initialize the field mapper.

        Args:
            field_map: Dictionary mapping source field names (optionally
                dotted paths) to target names (optionally dotted paths).
            drop_unmapped: If True, drop fields not in the mapping.
            copy_unmapped: If True, copy unmapped fields as-is.
        """
        self.field_map = field_map
        self.drop_unmapped = drop_unmapped
        self.copy_unmapped = copy_unmapped

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by mapping field names.

        Args:
            data: Input data.

        Returns:
            Transformed data with mapped field names.
        """
        result: Dict[str, Any] = {}

        # Map specified fields.
        for source, target in self.field_map.items():
            # BUG FIX: previously a dotted source was only followed when the
            # dotted string itself was a top-level key (``source in data``),
            # so nested lookups never fired for real nested dicts.
            if "." in source:
                value = self._get_nested(data, source)
                if value is self._MISSING:
                    # Backward-compat: a literal dotted top-level key wins
                    # when no nested path exists.
                    if source in data:
                        value = data[source]
                    else:
                        continue
            elif source in data:
                value = data[source]
            else:
                continue

            if "." in target:
                self._set_nested(result, target, value)
            else:
                result[target] = value

        # Copy through any top-level fields the mapping did not consume.
        if not self.drop_unmapped and self.copy_unmapped:
            for key, value in data.items():
                if key not in self.field_map and key not in result:
                    result[key] = value

        return result

    def _get_nested(self, data: Dict, path: str) -> Any:
        """Get a value via dot notation; return ``_MISSING`` if absent."""
        value: Any = data
        for part in path.split("."):
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return self._MISSING
        return value

    def _set_nested(self, data: Dict, path: str, value: Any) -> None:
        """Set a value via dot notation, creating intermediate dicts."""
        parts = path.split(".")
        current = data
        for part in parts[:-1]:
            current = current.setdefault(part, {})
        current[parts[-1]] = value

89 

90 

class ValueNormalizer(ITransformFunction):
    """Normalize string values in data fields."""

    def __init__(
        self,
        normalizations: Dict[str, Union[str, List[str]]],
        fields: List[str] | None = None,
    ):
        """Initialize the value normalizer.

        Args:
            normalizations: Mapping of *field name* to one normalization name
                or a list of names applied in order. The key "*" provides a
                default for fields without their own entry.
                Supported normalization names:
                - "lowercase": Convert to lowercase
                - "uppercase": Convert to uppercase
                - "trim": Remove leading/trailing whitespace
                - "snake_case": Convert to snake_case
                - "camel_case": Convert to camelCase
                - "pascal_case": Convert to PascalCase
                - "remove_special": Remove special characters
                - "normalize_spaces": Replace multiple spaces with single space
            fields: List of fields to normalize. If None, apply to all string
                fields.
        """
        self.normalizations = normalizations
        self.fields = fields

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by normalizing values.

        Args:
            data: Input data.

        Returns:
            Transformed data with normalized values.
        """
        result = copy.deepcopy(data)

        # Restrict processing to the configured fields, or all top-level keys.
        fields_to_process = self.fields if self.fields else list(result.keys())

        for field in fields_to_process:
            if field not in result:
                continue

            value = result[field]
            # Only string values are normalized; others pass through untouched.
            if not isinstance(value, str):
                continue

            # Field-specific normalizations, falling back to the "*" default.
            field_normalizations = self.normalizations.get(
                field, self.normalizations.get("*", [])
            )

            # Accept a single name as shorthand for a one-element list.
            if isinstance(field_normalizations, str):
                field_normalizations = [field_normalizations]

            for normalization in field_normalizations:
                value = self._apply_normalization(value, normalization)

            result[field] = value

        return result

    def _apply_normalization(self, value: str, normalization: str) -> str:
        """Apply one named normalization; unknown names are a no-op."""
        if normalization == "lowercase":
            return value.lower()
        elif normalization == "uppercase":
            return value.upper()
        elif normalization == "trim":
            return value.strip()
        elif normalization == "snake_case":
            # Insert underscores at case boundaries, then lowercase.
            s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', value)
            return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
        elif normalization == "camel_case":
            # Treat '-' and '_' as word separators.
            parts = value.replace("-", "_").split("_")
            return parts[0].lower() + "".join(p.capitalize() for p in parts[1:])
        elif normalization == "pascal_case":
            parts = value.replace("-", "_").split("_")
            return "".join(p.capitalize() for p in parts)
        elif normalization == "remove_special":
            return re.sub(r'[^a-zA-Z0-9\s]', '', value)
        elif normalization == "normalize_spaces":
            return re.sub(r'\s+', ' ', value).strip()
        else:
            return value

179 

180 

class TypeConverter(ITransformFunction):
    """Convert field types in data."""

    # String type names resolvable to converters; unknown names fall back
    # to ``str`` at lookup time.
    _NAMED_CONVERTERS = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
        "datetime": datetime.fromisoformat,
        "json": json.loads,
    }

    def __init__(
        self,
        conversions: Dict[str, Union[str, type, Callable]],
        strict: bool = False,
    ):
        """Initialize the type converter.

        Args:
            conversions: Dictionary mapping field names to target types.
                Can be type names (str, int, float, bool, list, dict),
                type objects, or callable converters.
            strict: If True, raise error on conversion failure.
        """
        self.conversions = conversions
        self.strict = strict

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by converting field types.

        Args:
            data: Input data.

        Returns:
            Transformed data with converted types.
        """
        result = copy.deepcopy(data)

        for field, target_type in self.conversions.items():
            if field not in result:
                continue

            try:
                result[field] = self._convert_value(result[field], target_type)
            except Exception as exc:
                if self.strict:
                    raise TransformError(
                        f"Failed to convert field '{field}': {exc}"
                    ) from exc
                # Non-strict mode: leave the original value untouched.

        return result

    def _convert_value(self, value: Any, target_type: Union[str, type, Callable]) -> Any:
        """Convert a single value, dispatching on the target-type spec."""
        # None is never converted.
        if value is None:
            return None

        # A non-type callable is used directly as the converter.
        if callable(target_type) and not isinstance(target_type, type):
            return target_type(value)

        # Resolve string names through the lookup table.
        if isinstance(target_type, str):
            target_type = self._NAMED_CONVERTERS.get(target_type, str)

        if isinstance(value, str):
            # Strings get truthy-word semantics rather than Python's bool().
            if target_type is bool:
                return value.lower() in ["true", "yes", "1", "on"]
            # ISO-8601 parsing for datetime targets.
            if target_type == datetime.fromisoformat:
                return datetime.fromisoformat(value)

        # Standard type conversion.
        return target_type(value)  # type: ignore

260 

261 

class DataEnricher(ITransformFunction):
    """Enrich data with additional fields."""

    def __init__(
        self,
        enrichments: Dict[str, Any],
        overwrite: bool = False,
    ):
        """Initialize the data enricher.

        Args:
            enrichments: Dictionary of fields to add/update. A callable
                value is invoked with the original input data; any other
                value is stored as-is.
            overwrite: If True, overwrite existing fields.
        """
        self.enrichments = enrichments
        self.overwrite = overwrite

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by adding enrichment fields.

        Args:
            data: Input data.

        Returns:
            Transformed data with enrichments.
        """
        enriched = copy.deepcopy(data)

        for name, spec in self.enrichments.items():
            # Respect existing fields unless overwriting is enabled.
            if name in enriched and not self.overwrite:
                continue

            if not callable(spec):
                enriched[name] = spec
                continue

            # Computed enrichment: invoked against the original input.
            try:
                enriched[name] = spec(data)
            except Exception as exc:
                raise TransformError(
                    f"Failed to compute enrichment for '{name}': {exc}"
                ) from exc

        return enriched

    def get_transform_description(self) -> str:
        """Get a description of the transformation."""
        names = list(self.enrichments.keys())
        shown = ", ".join(names[:3])
        suffix = "..." if len(names) > 3 else ""
        return f"Enrich data with fields: {shown}{suffix}"

313 

314 

class FieldFilter(ITransformFunction):
    """Filter fields from data."""

    def __init__(
        self,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
    ):
        """Initialize the field filter.

        Args:
            include: List of fields to include (whitelist).
            exclude: List of fields to exclude (blacklist).

        Raises:
            ValueError: If both include and exclude are given.
        """
        # The two modes are mutually exclusive by design.
        if include and exclude:
            raise ValueError("Cannot specify both include and exclude")

        self.include = include
        self.exclude = exclude

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by filtering fields.

        Args:
            data: Input data.

        Returns:
            Transformed data with filtered fields.
        """
        if self.include:
            # Whitelist mode: keep only listed fields.
            kept = {key: val for key, val in data.items() if key in self.include}
        elif self.exclude:
            # Blacklist mode: drop listed fields.
            kept = {key: val for key, val in data.items() if key not in self.exclude}
        else:
            # Neither configured: shallow copy, no filtering.
            kept = data.copy()
        return kept

353 

354 

class ValueReplacer(ITransformFunction):
    """Replace specific values in data fields."""

    def __init__(
        self,
        replacements: Dict[str, Dict[Any, Any]],
        default_replacements: Dict[Any, Any] | None = None,
    ):
        """Initialize the value replacer.

        Args:
            replacements: Dictionary mapping field names to replacement
                mappings (old value -> new value).
            default_replacements: Default replacements for all fields.
        """
        self.replacements = replacements
        self.default_replacements = default_replacements or {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by replacing values.

        Args:
            data: Input data.

        Returns:
            Transformed data with replaced values.
        """
        result = copy.deepcopy(data)

        for field, value in result.items():
            # Field-specific replacements win over the defaults.
            field_replacements = self.replacements.get(field, self.default_replacements)

            # BUG FIX: ``value in field_replacements`` raised TypeError for
            # unhashable values (lists/dicts); such values can never match a
            # replacement key, so skip them instead of crashing.
            try:
                replacement = field_replacements[value]
            except (KeyError, TypeError):
                continue
            result[field] = replacement

        return result

391 

392 

class ArrayFlattener(ITransformFunction):
    """Flatten nested arrays in data."""

    def __init__(
        self,
        fields: List[str],
        depth: int = 1,
    ):
        """Initialize the array flattener.

        Args:
            fields: List of fields containing arrays to flatten.
            depth: Number of levels to flatten (0 = fully flatten).
        """
        self.fields = fields
        self.depth = depth

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by flattening arrays.

        Args:
            data: Input data.

        Returns:
            Transformed data with flattened arrays.
        """
        result = copy.deepcopy(data)

        for field in self.fields:
            if field not in result:
                continue

            candidate = result[field]
            # Non-list values are left untouched.
            if isinstance(candidate, list):
                result[field] = self._flatten(candidate, self.depth)

        return result

    def _flatten(self, arr: List, depth: int) -> List:
        """Recursively flatten ``arr``; depth 0 means flatten completely."""
        flattened: List = []
        for item in arr:
            if not isinstance(item, list):
                flattened.append(item)
            elif depth == 0:
                # Full flattening: recurse without a depth limit.
                flattened.extend(self._flatten(item, 0))
            elif depth > 1:
                # More levels remain: recurse with one fewer level.
                flattened.extend(self._flatten(item, depth - 1))
            else:
                # Last level: splice the sublist in as-is.
                flattened.extend(item)
        return flattened

453 

454 

class DataSplitter(ITransformFunction):
    """Split data into multiple records based on a field."""

    def __init__(
        self,
        split_field: str,
        output_field: str = "records",
    ):
        """Initialize the data splitter.

        Args:
            split_field: Field containing array to split on.
            output_field: Name of output field containing split records.
        """
        self.split_field = split_field
        self.output_field = output_field

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform data by splitting into multiple records.

        Args:
            data: Input data.

        Returns:
            Transformed data with split records.

        Raises:
            TransformError: If the split field is missing or not a list.
        """
        if self.split_field not in data:
            raise TransformError(f"Split field '{self.split_field}' not found")

        values = data[self.split_field]
        if not isinstance(values, list):
            raise TransformError("Split field must be a list")

        # Shared fields for every output record (everything but the split field).
        template = {key: val for key, val in data.items() if key != self.split_field}

        # One independent record per element of the split field.
        records = []
        for element in values:
            record = copy.deepcopy(template)
            record[self.split_field] = element
            records.append(record)

        return {self.output_field: records}

498 

499 

class ChainTransformer(ITransformFunction):
    """Chain multiple transformers together."""

    def __init__(self, transformers: List[ITransformFunction]):
        """Initialize the chain transformer.

        Args:
            transformers: List of transformers to apply in sequence.
        """
        self.transformers = transformers

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply all transformers in sequence.

        Args:
            data: Input data.

        Returns:
            Transformed data after all transformers.
        """
        # Each step consumes the previous step's output.
        for step in self.transformers:
            data = step.transform(data)
        return data

524 

525 

526# Convenience functions for creating transformers 

def map_fields(mapping: Dict[str, str], **kwargs) -> FieldMapper:
    """Create a FieldMapper for the given source-to-target mapping.

    Extra keyword arguments are forwarded to ``FieldMapper``.
    """
    return FieldMapper(mapping, **kwargs)

530 

531 

def normalize(**normalizations: str) -> ValueNormalizer:
    """Create a ValueNormalizer from field=normalization keyword arguments."""
    mapping = dict(normalizations)
    return ValueNormalizer(mapping)

535 

536 

def convert_types(**conversions: Union[str, type, Callable]) -> TypeConverter:
    """Create a TypeConverter from field=target-type keyword arguments."""
    spec = dict(conversions)
    return TypeConverter(spec)

540 

541 

def enrich(**enrichments: Any) -> DataEnricher:
    """Create a DataEnricher from field=value keyword arguments."""
    fields = dict(enrichments)
    return DataEnricher(fields)

545 

546 

def filter_fields(include: List[str] | None = None, exclude: List[str] | None = None) -> FieldFilter:
    """Create a FieldFilter with an include whitelist or exclude blacklist."""
    return FieldFilter(include, exclude)

550 

551 

def replace_values(**replacements: Dict[Any, Any]) -> ValueReplacer:
    """Create a ValueReplacer from field=replacement-map keyword arguments."""
    per_field = dict(replacements)
    return ValueReplacer(per_field)

555 

556 

def flatten(*fields: str, depth: int = 1) -> ArrayFlattener:
    """Create an ArrayFlattener for the named fields at the given depth."""
    field_names = list(fields)
    return ArrayFlattener(field_names, depth)

560 

561 

def split_on(field: str, output: str = "records") -> DataSplitter:
    """Create a DataSplitter that splits on ``field`` into ``output``."""
    return DataSplitter(field, output)

565 

566 

def chain(*transformers: ITransformFunction) -> ChainTransformer:
    """Create a ChainTransformer applying the given transformers in order."""
    sequence = list(transformers)
    return ChainTransformer(sequence)