Coverage for src/dataknobs_fsm/functions/library/streaming.py: 0%

263 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Built-in streaming functions for FSM. 

2 

3This module provides streaming-related functions that can be referenced 

4in FSM configurations for processing large data sets efficiently. 

5""" 

6 

7import csv 

8import json 

9from pathlib import Path 

10from typing import Any, Dict, List, Union 

11 

12from dataknobs_fsm.functions.base import ITransformFunction, TransformFunctionError 

13from dataknobs_fsm.streaming.core import IStreamSource 

14 

15 

class ChunkReader(ITransformFunction):
    """Read data in chunks from a file or stream source.

    Progress between calls is carried in the ``_chunk_state`` entry of the
    data dict, so invoking ``transform`` repeatedly walks through the
    source one chunk at a time.
    """

    def __init__(
        self,
        source: Union[str, IStreamSource],
        chunk_size: int = 1000,
        format: str = "auto",  # "auto", "json", "csv", "lines"
    ):
        """Initialize the chunk reader.

        Args:
            source: Data source (file path or stream source).
            chunk_size: Number of records per chunk.
            format: Data format to expect.
        """
        self.source = source
        self.chunk_size = chunk_size
        self.format = format

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Read the next chunk from the source.

        Args:
            data: Input data (may contain ``_chunk_state`` from a prior call).

        Returns:
            Data with ``chunk`` (list of records), ``has_more`` (whether more
            data may remain), and updated ``_chunk_state``.

        Raises:
            TransformFunctionError: If the file is missing or the format is
                unsupported.
        """
        # Resume from the position recorded by the previous call, if any.
        chunk_state = data.get("_chunk_state", {})

        if isinstance(self.source, str):
            # File source
            file_path = Path(self.source)
            if not file_path.exists():
                raise TransformFunctionError(f"File not found: {self.source}")

            # Resolve "auto" to a concrete format via the file extension.
            # (Local name `fmt` avoids shadowing the builtin `format`.)
            fmt = self.format
            if fmt == "auto":
                fmt = self._detect_format(file_path)

            if fmt == "json":
                chunk = await self._read_json_chunk(file_path, chunk_state)
            elif fmt == "csv":
                chunk = await self._read_csv_chunk(file_path, chunk_state)
            elif fmt == "lines":
                chunk = await self._read_lines_chunk(file_path, chunk_state)
            else:
                raise TransformFunctionError(f"Unsupported format: {fmt}")
        else:
            # Stream source
            chunk = await self._read_stream_chunk(self.source, chunk_state)

        return {
            **data,
            "chunk": chunk["records"],
            "has_more": chunk["has_more"],
            "_chunk_state": chunk["state"],
        }

    def _detect_format(self, file_path: Path) -> str:
        """Map a file extension to a supported format name."""
        suffix = file_path.suffix.lower()
        if suffix == ".json":
            return "json"
        elif suffix == ".csv":
            return "csv"
        else:
            return "lines"

    async def _read_json_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk from a JSON file.

        NOTE: the whole document is re-loaded on every call; a streaming
        JSON parser would be needed to avoid that.
        """
        offset = state.get("offset", 0)

        with open(file_path) as f:
            data = json.load(f)

        if isinstance(data, list):
            chunk = data[offset:offset + self.chunk_size]
            has_more = offset + self.chunk_size < len(data)
            new_offset = offset + len(chunk)
        else:
            # A single top-level object is yielded once, on the first call.
            if offset == 0:
                chunk = [data]
                has_more = False
                new_offset = 1
            else:
                chunk = []
                has_more = False
                new_offset = offset

        return {
            "records": chunk,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_csv_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk of rows from a CSV file."""
        offset = state.get("offset", 0)
        records: List[Dict[str, Any]] = []

        # Per the csv module docs, CSV files must be opened with
        # newline="" so quoted fields containing newlines parse correctly.
        with open(file_path, newline="") as f:
            reader = csv.DictReader(f)

            # Skip rows consumed by earlier calls.
            for _ in range(offset):
                try:
                    next(reader)
                except StopIteration:
                    break

            # Read up to chunk_size rows.
            for _ in range(self.chunk_size):
                try:
                    records.append(next(reader))
                except StopIteration:
                    break

        # A full chunk suggests more data may remain (may be a false
        # positive when the file ends exactly on a chunk boundary).
        has_more = len(records) == self.chunk_size
        new_offset = offset + len(records)

        return {
            "records": records,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_lines_chunk(
        self, file_path: Path, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk of lines from a text file."""
        offset = state.get("offset", 0)
        records: List[Dict[str, Any]] = []

        with open(file_path) as f:
            # Skip lines consumed by earlier calls.
            for _ in range(offset):
                if not f.readline():
                    break

            # Read up to chunk_size lines.
            for _ in range(self.chunk_size):
                line = f.readline()
                if not line:
                    break
                records.append({"line": line.strip()})

        has_more = len(records) == self.chunk_size
        new_offset = offset + len(records)

        return {
            "records": records,
            "has_more": has_more,
            "state": {"offset": new_offset},
        }

    async def _read_stream_chunk(
        self, source: IStreamSource, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Read a chunk of records from a stream source."""
        records = []

        async for record in source.read(self.chunk_size):
            records.append(record)

        has_more = len(records) == self.chunk_size

        return {
            "records": records,
            "has_more": has_more,
            # Not every source tracks a position; record it when available.
            "state": {"stream_position": source.position if hasattr(source, "position") else None},
        }

199 

200 

class RecordParser(ITransformFunction):
    """Parse raw records from JSON, CSV, YAML, or XML."""

    def __init__(
        self,
        format: str,
        field: str = "raw",
        output_field: str = "parsed",
        options: Dict[str, Any] | None = None,
    ):
        """Initialize the record parser.

        Args:
            format: Format to parse ("json", "csv", "xml", "yaml").
            field: Field containing raw data to parse.
            output_field: Field to store parsed data.
            options: Format-specific parsing options (used by the CSV parser).
        """
        self.format = format
        self.field = field
        self.output_field = output_field
        self.options = options or {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Parse the raw data found in ``self.field``.

        Args:
            data: Input data containing raw records.

        Returns:
            Data with parsed records under ``self.output_field``; returned
            unchanged if the input field is absent.

        Raises:
            TransformFunctionError: If the format is unsupported or
                parsing fails.
        """
        raw_data = data.get(self.field)
        if raw_data is None:
            return data

        parsers = {
            "json": self._parse_json,
            "csv": self._parse_csv,
            "yaml": self._parse_yaml,
            "xml": self._parse_xml,
        }
        parser = parsers.get(self.format)
        # Resolve the format *before* the try block so the "unsupported
        # format" error is not re-caught and re-wrapped as a parse failure.
        if parser is None:
            raise TransformFunctionError(f"Unsupported format: {self.format}")

        try:
            parsed = parser(raw_data)
        except Exception as e:
            raise TransformFunctionError(f"Failed to parse {self.format}: {e}") from e

        return {
            **data,
            self.output_field: parsed,
        }

    def _parse_json(self, raw: Union[str, bytes]) -> Any:
        """Parse a JSON document."""
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return json.loads(raw)

    def _parse_csv(self, raw: Union[str, bytes]) -> List[Dict[str, Any]]:
        """Parse CSV text into a list of row dicts."""
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")

        import io
        reader = csv.DictReader(io.StringIO(raw), **self.options)
        return list(reader)

    def _parse_yaml(self, raw: Union[str, bytes]) -> Any:
        """Parse a YAML document (requires PyYAML; imported lazily)."""
        import yaml
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        # safe_load: never construct arbitrary Python objects from input.
        return yaml.safe_load(raw)

    def _parse_xml(self, raw: Union[str, bytes]) -> Dict[str, Any]:
        """Parse an XML document into a nested dictionary."""
        import xml.etree.ElementTree as ET
        if isinstance(raw, str):
            raw = raw.encode("utf-8")

        root = ET.fromstring(raw)
        return self._xml_to_dict(root)

    def _xml_to_dict(self, element) -> Dict[str, Any]:
        """Recursively convert an XML element to a dictionary.

        Attributes go under "@attributes", stripped text under "text";
        repeated child tags collapse into a list.
        """
        result: Dict[str, Any] = {}

        if element.attrib:
            result["@attributes"] = element.attrib

        if element.text and element.text.strip():
            result["text"] = element.text.strip()

        for child in element:
            child_data = self._xml_to_dict(child)
            if child.tag in result:
                # Promote to a list when the same tag appears more than once.
                if not isinstance(result[child.tag], list):
                    result[child.tag] = [result[child.tag]]
                result[child.tag].append(child_data)
            else:
                result[child.tag] = child_data

        return result

312 

313 

class FileAppender(ITransformFunction):
    """Buffered appender that writes records to a file.

    Records accumulate in an in-memory buffer and are flushed to disk once
    the buffer reaches ``buffer_size``; call ``flush()`` to write any
    remainder.
    """

    def __init__(
        self,
        file_path: str,
        format: str = "json",  # "json", "csv", "lines"
        field: str = "data",
        buffer_size: int = 100,
        create_if_missing: bool = True,
    ):
        """Initialize the file appender.

        Args:
            file_path: Path to file to append to.
            format: Format to write data in.
            field: Field containing data to append.
            buffer_size: Number of records to buffer before writing.
            create_if_missing: Create file (and parent dirs) if it doesn't exist.
        """
        self.file_path = Path(file_path)
        self.format = format
        self.field = field
        self.buffer_size = buffer_size
        self.create_if_missing = create_if_missing
        self._buffer: List[Any] = []

    async def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Buffer incoming records, flushing to disk when the buffer fills.

        Args:
            data: Input data containing records under ``self.field``.

        Returns:
            Data with ``appended_count`` (records written this call) and
            ``buffer_size`` (records still buffered).
        """
        records = data.get(self.field)
        if records is None:
            return data

        # Accept either a batch (list) or a single record.
        if isinstance(records, list):
            self._buffer.extend(records)
        else:
            self._buffer.append(records)

        written = 0
        if len(self._buffer) >= self.buffer_size:
            written = await self._write_buffer()

        return {
            **data,
            "appended_count": written,
            "buffer_size": len(self._buffer),
        }

    async def _write_buffer(self) -> int:
        """Write buffered records to the file and clear the buffer.

        Returns:
            Number of records written.

        Raises:
            TransformFunctionError: If the configured format is unsupported.
        """
        if not self._buffer:
            return 0

        # Create file if needed
        if self.create_if_missing and not self.file_path.exists():
            self.file_path.parent.mkdir(parents=True, exist_ok=True)
            self.file_path.touch()

        count = len(self._buffer)

        if self.format == "json":
            # JSON cannot be appended in place: read the existing array,
            # extend it, and rewrite the whole file. O(file size) per flush.
            existing: List[Any] = []
            if self.file_path.exists() and self.file_path.stat().st_size > 0:
                with open(self.file_path) as f:
                    loaded = json.load(f)
                # Tolerate a file holding a single object rather than a list.
                existing = loaded if isinstance(loaded, list) else [loaded]

            existing.extend(self._buffer)

            with open(self.file_path, "w") as f:
                json.dump(existing, f, indent=2)

        elif self.format == "csv":
            # Uses the module-level csv import; a header row is written
            # only when starting a fresh (empty) file.
            file_exists = self.file_path.exists() and self.file_path.stat().st_size > 0

            with open(self.file_path, "a", newline="") as f:
                if self._buffer and isinstance(self._buffer[0], dict):
                    writer = csv.DictWriter(f, fieldnames=self._buffer[0].keys())
                    if not file_exists:
                        writer.writeheader()
                    writer.writerows(self._buffer)
                else:
                    writer = csv.writer(f)
                    writer.writerows(self._buffer)

        elif self.format == "lines":
            # One record per line; dicts are serialized as JSON.
            with open(self.file_path, "a") as f:
                for record in self._buffer:
                    if isinstance(record, dict):
                        f.write(json.dumps(record) + "\n")
                    else:
                        f.write(str(record) + "\n")

        else:
            raise TransformFunctionError(f"Unsupported format: {self.format}")

        self._buffer.clear()
        return count

    async def flush(self) -> int:
        """Flush any remaining buffered records to the file."""
        return await self._write_buffer()

429 

430 

class StreamAggregator(ITransformFunction):
    """Aggregate streaming data using various functions."""

    def __init__(
        self,
        aggregations: Dict[str, Dict[str, Any]],
        group_by: List[str] | None = None,
        window_size: int | None = None,
    ):
        """Initialize the stream aggregator.

        Args:
            aggregations: Mapping of output field name to an aggregation
                spec of the form
                {"function": "sum|avg|min|max|count", "field": "source_field"}.
            group_by: Fields to group by before aggregating.
            window_size: Number of records kept in the sliding window.
        """
        self.aggregations = aggregations
        self.group_by = group_by
        self.window_size = window_size
        self._window: List[Dict[str, Any]] = []
        self._groups: Dict[tuple, List[Dict[str, Any]]] = {}

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Fold incoming records into the running aggregation state.

        Args:
            data: Input data; a batch is read from ``records``, otherwise
                the data dict itself is treated as a single record.

        Returns:
            Data with ``aggregations`` (one result per group) when grouping,
            or ``aggregation`` (a single result dict) otherwise.
        """
        batch = data.get("records", [data])

        if self.group_by:
            return {**data, "aggregations": self._aggregate_grouped(batch)}
        return {**data, "aggregation": self._aggregate_global(batch)}

    def _aggregate_grouped(
        self, batch: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Update per-group windows with *batch* and aggregate every group."""
        for record in batch:
            key = tuple(record.get(f) for f in self.group_by)
            bucket = self._groups.setdefault(key, [])
            bucket.append(record)
            # Trim this group's history to the sliding window, if set.
            if self.window_size and len(bucket) > self.window_size:
                self._groups[key] = bucket[-self.window_size:]

        results = []
        for key, members in self._groups.items():
            row = dict(zip(self.group_by, key, strict=False))
            for name, spec in self.aggregations.items():
                row[name] = self._compute_aggregation(members, spec)
            results.append(row)
        return results

    def _aggregate_global(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extend the global window with *batch* and aggregate it."""
        self._window.extend(batch)

        # Trim the window to the configured size, if set.
        if self.window_size and len(self._window) > self.window_size:
            self._window = self._window[-self.window_size:]

        return {
            name: self._compute_aggregation(self._window, spec)
            for name, spec in self.aggregations.items()
        }

    def _compute_aggregation(
        self, records: List[Dict[str, Any]], spec: Dict[str, Any]
    ) -> Any:
        """Apply one aggregation spec to a list of records."""
        func = spec["function"]
        field = spec.get("field")

        # "count" needs no source field.
        if func == "count":
            return len(records)

        if not field:
            raise TransformFunctionError(f"Field required for {func} aggregation")

        # Collect the non-None values of the target field.
        values = [v for r in records if (v := r.get(field)) is not None]
        if not values:
            return None

        reducers = {
            "sum": sum,
            "avg": lambda vs: sum(vs) / len(vs),
            "min": min,
            "max": max,
        }
        reducer = reducers.get(func)
        if reducer is None:
            raise TransformFunctionError(f"Unknown aggregation function: {func}")
        return reducer(values)  # type: ignore

532 

533 

534# Convenience functions for creating streaming functions 

def read_chunks(source: str, size: int = 1000, **kwargs) -> ChunkReader:
    """Build a ChunkReader over *source* yielding *size* records per chunk."""
    return ChunkReader(source, chunk_size=size, **kwargs)

538 

539 

def parse(format: str, **kwargs) -> RecordParser:
    """Build a RecordParser for the given *format*."""
    return RecordParser(format=format, **kwargs)

543 

544 

def append_to_file(path: str, **kwargs) -> FileAppender:
    """Build a FileAppender targeting *path*."""
    return FileAppender(file_path=path, **kwargs)

548 

549 

def aggregate(**aggregations: Dict[str, Any]) -> StreamAggregator:
    """Build a StreamAggregator; each keyword maps an output field to its spec."""
    specs = dict(aggregations)
    return StreamAggregator(specs)