Coverage for src/dataknobs_fsm/io/utils.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Utility functions for I/O operations. 

2 

3This module provides utility functions for common I/O patterns. 

4""" 

5 

6import asyncio 

7from typing import ( 

8 Any, Dict, List, Union, AsyncIterator, Iterator, 

9 Callable, TypeVar, Awaitable 

10) 

11from functools import reduce 

12 

13from .base import IOConfig, IOFormat, IOProvider 

14from .adapters import ( 

15 FileIOAdapter, DatabaseIOAdapter, HTTPIOAdapter 

16) 

17 

T = TypeVar('T')  # Generic element type used by the batching and retry helpers below.

19 

20 

def create_io_provider(
    config: IOConfig,
    is_async: bool = True
) -> IOProvider:
    """Create the appropriate I/O provider for a configuration.

    Adapter selection: a DATABASE format (or a dict source) maps to the
    database adapter; an API format or an http(s) URL source maps to the
    HTTP adapter; everything else is treated as a file.

    Args:
        config: I/O configuration
        is_async: Whether to create an async provider

    Returns:
        Provider instance produced by the selected adapter
    """
    source = config.source
    looks_like_url = (
        isinstance(source, str)
        and source.startswith(('http://', 'https://'))
    )

    if config.format == IOFormat.DATABASE:
        adapter = DatabaseIOAdapter()
    elif config.format == IOFormat.API or looks_like_url:
        adapter = HTTPIOAdapter()
    elif isinstance(source, dict):
        # A dict source (e.g. connection parameters) implies a database.
        adapter = DatabaseIOAdapter()
    else:
        adapter = FileIOAdapter()

    return adapter.create_provider(config, is_async)

45 

46 

def batch_iterator(
    iterable: Iterator[T],
    batch_size: int
) -> Iterator[List[T]]:
    """Create batches from an iterator.

    Args:
        iterable: Source iterator
        batch_size: Size of each batch; must be at least 1

    Yields:
        Lists of up to ``batch_size`` items; the final batch may be shorter.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised on first
            iteration; previously a non-positive size silently produced
            single-item batches).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    batch: List[T] = []
    for item in iterable:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    # Emit any trailing partial batch.
    if batch:
        yield batch

68 

69 

async def async_batch_iterator(
    iterable: AsyncIterator[T],
    batch_size: int
) -> AsyncIterator[List[T]]:
    """Create batches from an async iterator.

    Args:
        iterable: Source async iterator
        batch_size: Size of each batch; must be at least 1

    Yields:
        Lists of up to ``batch_size`` items; the final batch may be shorter.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised on first
            iteration; previously a non-positive size silently produced
            single-item batches).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    batch: List[T] = []
    async for item in iterable:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    # Emit any trailing partial batch.
    if batch:
        yield batch

91 

92 

def transform_pipeline(
    *transforms: Callable[[Any], Any]
) -> Callable[[Any], Any]:
    """Create a synchronous transformation pipeline.

    Args:
        *transforms: Transformation functions to apply in sequence

    Returns:
        Combined transformation function; with no transforms it is the
        identity function.
    """
    def pipeline(data: Any) -> Any:
        result = data
        for step in transforms:
            result = step(result)
        return result
    return pipeline

107 

108 

def async_transform_pipeline(
    *transforms: Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]]
) -> Callable[[Any], Awaitable[Any]]:
    """Create an asynchronous transformation pipeline.

    Args:
        *transforms: Transformation functions (sync or async) to apply in sequence

    Returns:
        Combined async transformation function

    Note:
        Each transform is called and, if the call produces a coroutine, it
        is awaited.  This also supports async callables that
        ``asyncio.iscoroutinefunction`` cannot detect (e.g. objects whose
        ``__call__`` is a coroutine function) — cases the previous
        implementation would have passed through un-awaited.
    """
    async def pipeline(data: Any) -> Any:
        result = data
        for transform in transforms:
            result = transform(result)
            if asyncio.iscoroutine(result):
                # Async transform: resolve the coroutine before the next step.
                result = await result
        return result
    return pipeline

129 

130 

class IORouter:
    """Routes data between multiple I/O providers based on conditions."""

    def __init__(self):
        # Each route is a dict with 'condition', 'provider' and 'transform' keys.
        self.routes = []

    def add_route(
        self,
        condition: Callable[[Any], bool],
        provider: IOProvider,
        transform: Callable[[Any], Any] | None = None
    ):
        """Register a routing rule.

        Args:
            condition: Predicate deciding whether this route applies
            provider: I/O provider that receives matching data
            transform: Optional transformation applied before delivery
        """
        entry = {
            'condition': condition,
            'provider': provider,
            'transform': transform or (lambda x: x),
        }
        self.routes.append(entry)

    async def route(self, data: Any) -> List[Any]:
        """Send ``data`` through every matching route.

        Args:
            data: Data to route

        Returns:
            The transformed payload from each route whose condition matched.
        """
        delivered = []
        for entry in self.routes:
            if not entry['condition'](data):
                continue
            payload = entry['transform'](data)
            writer = getattr(entry['provider'], 'write', None)
            if writer is not None:
                # Support both async and sync provider write methods.
                if asyncio.iscoroutinefunction(writer):
                    await writer(payload)
                else:
                    writer(payload)
            delivered.append(payload)
        return delivered

176 

177 

178class IOBuffer: 

179 """Buffer for I/O operations with overflow handling.""" 

180 

181 def __init__( 

182 self, 

183 max_size: int = 10000, 

184 overflow_handler: Callable[[List[Any]], None] | None = None 

185 ): 

186 """Initialize buffer. 

187  

188 Args: 

189 max_size: Maximum buffer size 

190 overflow_handler: Function to handle overflow 

191 """ 

192 self.max_size = max_size 

193 self.overflow_handler = overflow_handler 

194 self.buffer = [] 

195 self._lock = asyncio.Lock() 

196 

197 async def add(self, item: Any) -> None: 

198 """Add item to buffer. 

199  

200 Args: 

201 item: Item to add 

202 """ 

203 async with self._lock: 

204 self.buffer.append(item) 

205 if len(self.buffer) >= self.max_size: 

206 await self._handle_overflow() 

207 

208 async def flush(self) -> List[Any]: 

209 """Flush and return buffer contents. 

210  

211 Returns: 

212 Buffer contents 

213 """ 

214 async with self._lock: 

215 items = self.buffer.copy() 

216 self.buffer.clear() 

217 return items 

218 

219 async def _handle_overflow(self) -> None: 

220 """Handle buffer overflow.""" 

221 if self.overflow_handler: 

222 overflow_items = self.buffer[:self.max_size // 2] 

223 self.buffer = self.buffer[self.max_size // 2:] 

224 if asyncio.iscoroutinefunction(self.overflow_handler): 

225 await self.overflow_handler(overflow_items) 

226 else: 

227 self.overflow_handler(overflow_items) 

228 

229 

class IOMetrics:
    """Track metrics for I/O operations."""

    def __init__(self):
        # All counters start at zero; 'duration_ms' is reserved for timing data.
        self.metrics = dict.fromkeys(
            (
                'read_count',
                'write_count',
                'bytes_read',
                'bytes_written',
                'errors',
                'retries',
                'duration_ms',
            ),
            0,
        )

    def record_read(self, bytes_read: int = 0):
        """Record a read operation and the number of bytes read."""
        self.metrics['read_count'] += 1
        self.metrics['bytes_read'] += bytes_read

    def record_write(self, bytes_written: int = 0):
        """Record a write operation and the number of bytes written."""
        self.metrics['write_count'] += 1
        self.metrics['bytes_written'] += bytes_written

    def record_error(self):
        """Record a single error occurrence."""
        self.metrics['errors'] += 1

    def record_retry(self):
        """Record a single retry attempt."""
        self.metrics['retries'] += 1

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot copy of the current metrics."""
        return dict(self.metrics)

    def reset(self):
        """Zero out every metric."""
        self.metrics.update(dict.fromkeys(self.metrics, 0))

270 

271 

async def retry_io_operation(
    operation: Callable[[], Awaitable[T]],
    max_retries: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: tuple = (Exception,)
) -> T:
    """Retry an I/O operation with exponential backoff.

    Args:
        operation: Zero-argument async operation to retry
        max_retries: Maximum number of retries after the first attempt;
            values below 0 are treated as 0 (a single attempt)
        delay: Initial delay in seconds between retries
        backoff: Multiplier applied to the delay after each failure
        exceptions: Exception types to catch and retry

    Returns:
        Result of the first successful attempt

    Raises:
        The last caught exception if all attempts fail.
    """
    # NOTE(fix): the previous version kept an unreachable trailing
    # `raise last_exception` and, for a negative max_retries, reached it
    # with last_exception still None (a TypeError). Clamping the attempt
    # count removes both issues.
    current_delay = delay
    total_attempts = max(max_retries, 0) + 1
    for attempt in range(total_attempts):
        try:
            return await operation()
        except exceptions:
            if attempt + 1 >= total_attempts:
                raise  # Final attempt: propagate the failure.
            await asyncio.sleep(current_delay)
            current_delay *= backoff

309 

310 

def parallel_io_executor(
    providers: List[IOProvider],
    max_workers: int = 4
) -> 'ParallelIOExecutor':
    """Build a :class:`ParallelIOExecutor` over the given providers.

    Args:
        providers: I/O providers to fan out across
        max_workers: Maximum concurrent workers

    Returns:
        Configured parallel I/O executor
    """
    return ParallelIOExecutor(providers, max_workers)

325 

326 

class ParallelIOExecutor:
    """Execute I/O operations across multiple providers in parallel.

    At most ``max_workers`` operations run concurrently, enforced with an
    ``asyncio.Semaphore`` (fixes the previous behavior where ``max_workers``
    was stored but never applied, so all providers ran at once). Providers
    lacking the relevant method, or whose method is synchronous, are
    skipped — matching the original selection logic.
    """

    def __init__(self, providers: "List[IOProvider]", max_workers: int = 4):
        self.providers = providers
        self.max_workers = max_workers

    async def read_all(self, **kwargs) -> List[Any]:
        """Read from all async-capable providers in parallel.

        Returns:
            Results from each provider's ``read``, in provider order.
        """
        readers = [
            provider.read
            for provider in self.providers
            if hasattr(provider, 'read')
            and asyncio.iscoroutinefunction(provider.read)
        ]
        if not readers:
            return []
        semaphore = asyncio.Semaphore(self.max_workers)

        async def bounded_read(read):
            # Cap concurrency at max_workers.
            async with semaphore:
                return await read(**kwargs)

        return await asyncio.gather(*(bounded_read(r) for r in readers))

    async def write_all(self, data: Any, **kwargs) -> None:
        """Write ``data`` to all async-capable providers in parallel.

        Args:
            data: Data to write
        """
        writers = [
            provider.write
            for provider in self.providers
            if hasattr(provider, 'write')
            and asyncio.iscoroutinefunction(provider.write)
        ]
        if not writers:
            return
        semaphore = asyncio.Semaphore(self.max_workers)

        async def bounded_write(write):
            # Cap concurrency at max_workers.
            async with semaphore:
                await write(data, **kwargs)

        await asyncio.gather(*(bounded_write(w) for w in writers))