Coverage for src/dataknobs_fsm/io/utils.py: 0%
139 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""Utility functions for I/O operations.
3This module provides utility functions for common I/O patterns.
4"""
6import asyncio
7from typing import (
8 Any, Dict, List, Union, AsyncIterator, Iterator,
9 Callable, TypeVar, Awaitable
10)
11from functools import reduce
13from .base import IOConfig, IOFormat, IOProvider
14from .adapters import (
15 FileIOAdapter, DatabaseIOAdapter, HTTPIOAdapter
16)
18T = TypeVar('T')
def create_io_provider(
    config: IOConfig,
    is_async: bool = True
) -> IOProvider:
    """Create the I/O provider that matches a configuration.

    Args:
        config: I/O configuration describing format and source.
        is_async: Whether the provider should be asynchronous.

    Returns:
        A provider produced by the adapter best suited to ``config``.
    """
    source = config.source
    looks_like_url = isinstance(source, str) and source.startswith(('http://', 'https://'))

    if config.format == IOFormat.DATABASE:
        # An explicit database format always wins.
        adapter = DatabaseIOAdapter()
    elif config.format == IOFormat.API or looks_like_url:
        adapter = HTTPIOAdapter()
    elif isinstance(source, dict):
        # A dict-shaped source is treated as database connection info.
        adapter = DatabaseIOAdapter()
    else:
        adapter = FileIOAdapter()

    return adapter.create_provider(config, is_async)
def batch_iterator(
    iterable: Iterator[T],
    batch_size: int
) -> Iterator[List[T]]:
    """Create batches from an iterator.

    Args:
        iterable: Source iterator.
        batch_size: Size of each batch; must be at least 1.

    Returns:
        An iterator of lists, each holding up to ``batch_size`` items;
        the final batch may be shorter.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Validate eagerly so a bad size fails at call time rather than
    # silently producing degenerate batches on first iteration.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def _generate() -> Iterator[List[T]]:
        batch: List[T] = []
        for item in iterable:
            batch.append(item)
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:
            # Emit the trailing partial batch.
            yield batch

    return _generate()
async def async_batch_iterator(
    iterable: AsyncIterator[T],
    batch_size: int
) -> AsyncIterator[List[T]]:
    """Create batches from an async iterator.

    Args:
        iterable: Source async iterator.
        batch_size: Size of each batch.

    Yields:
        Lists of up to ``batch_size`` items; the final list may be shorter.
    """
    bucket: List[T] = []
    async for element in iterable:
        bucket.append(element)
        if len(bucket) < batch_size:
            continue
        yield bucket
        bucket = []
    if bucket:
        # Flush whatever remains after the source is exhausted.
        yield bucket
def transform_pipeline(
    *transforms: Callable[[Any], Any]
) -> Callable[[Any], Any]:
    """Compose synchronous transformations into one callable.

    Args:
        *transforms: Functions applied left-to-right.

    Returns:
        A function that threads its argument through every transform;
        with no transforms it returns its input unchanged.
    """
    def pipeline(data: Any) -> Any:
        result = data
        for step in transforms:
            result = step(result)
        return result
    return pipeline
def async_transform_pipeline(
    *transforms: Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]]
) -> Callable[[Any], Awaitable[Any]]:
    """Compose sync and async transformations into one awaitable callable.

    Args:
        *transforms: Transformation functions (sync or async) applied in order.

    Returns:
        An async function threading its argument through every transform.
    """
    async def pipeline(data: Any) -> Any:
        current = data
        for fn in transforms:
            produced = fn(current)
            # Coroutine functions hand back an awaitable; sync ones a value.
            current = await produced if asyncio.iscoroutinefunction(fn) else produced
        return current
    return pipeline
class IORouter:
    """Routes data between multiple I/O providers based on conditions."""

    def __init__(self):
        # Each entry: {'condition', 'provider', 'transform'}.
        self.routes = []

    def add_route(
        self,
        condition: Callable[[Any], bool],
        provider: IOProvider,
        transform: Callable[[Any], Any] | None = None
    ):
        """Add a routing rule.

        Args:
            condition: Predicate deciding whether this route handles a datum.
            provider: I/O provider for this route.
            transform: Optional transformation applied before delivery.
        """
        entry = {
            'condition': condition,
            'provider': provider,
            # Fall back to an identity transform when none is supplied.
            'transform': transform or (lambda x: x),
        }
        self.routes.append(entry)

    async def route(self, data: Any) -> List[Any]:
        """Route data to every provider whose condition matches.

        Args:
            data: Data to route.

        Returns:
            The transformed payload for each matching route.
        """
        delivered: List[Any] = []
        for entry in self.routes:
            if not entry['condition'](data):
                continue
            payload = entry['transform'](data)
            writer = getattr(entry['provider'], 'write', None)
            if writer is not None:
                # Support both async and sync provider write methods.
                if asyncio.iscoroutinefunction(writer):
                    await writer(payload)
                else:
                    writer(payload)
            delivered.append(payload)
        return delivered
178class IOBuffer:
179 """Buffer for I/O operations with overflow handling."""
181 def __init__(
182 self,
183 max_size: int = 10000,
184 overflow_handler: Callable[[List[Any]], None] | None = None
185 ):
186 """Initialize buffer.
188 Args:
189 max_size: Maximum buffer size
190 overflow_handler: Function to handle overflow
191 """
192 self.max_size = max_size
193 self.overflow_handler = overflow_handler
194 self.buffer = []
195 self._lock = asyncio.Lock()
197 async def add(self, item: Any) -> None:
198 """Add item to buffer.
200 Args:
201 item: Item to add
202 """
203 async with self._lock:
204 self.buffer.append(item)
205 if len(self.buffer) >= self.max_size:
206 await self._handle_overflow()
208 async def flush(self) -> List[Any]:
209 """Flush and return buffer contents.
211 Returns:
212 Buffer contents
213 """
214 async with self._lock:
215 items = self.buffer.copy()
216 self.buffer.clear()
217 return items
219 async def _handle_overflow(self) -> None:
220 """Handle buffer overflow."""
221 if self.overflow_handler:
222 overflow_items = self.buffer[:self.max_size // 2]
223 self.buffer = self.buffer[self.max_size // 2:]
224 if asyncio.iscoroutinefunction(self.overflow_handler):
225 await self.overflow_handler(overflow_items)
226 else:
227 self.overflow_handler(overflow_items)
class IOMetrics:
    """Track metrics for I/O operations."""

    def __init__(self):
        # All counters start at zero.
        self.metrics = dict.fromkeys(
            (
                'read_count',
                'write_count',
                'bytes_read',
                'bytes_written',
                'errors',
                'retries',
                'duration_ms',
            ),
            0,
        )

    def record_read(self, bytes_read: int = 0):
        """Record one read and the bytes it transferred."""
        self.metrics['read_count'] += 1
        self.metrics['bytes_read'] += bytes_read

    def record_write(self, bytes_written: int = 0):
        """Record one write and the bytes it transferred."""
        self.metrics['write_count'] += 1
        self.metrics['bytes_written'] += bytes_written

    def record_error(self):
        """Record a single error occurrence."""
        self.metrics['errors'] += 1

    def record_retry(self):
        """Record a single retry attempt."""
        self.metrics['retries'] += 1

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot copy of the current metrics."""
        return dict(self.metrics)

    def reset(self):
        """Zero every metric (in place, preserving the dict object)."""
        self.metrics.update(dict.fromkeys(self.metrics, 0))
async def retry_io_operation(
    operation: Callable[[], Awaitable[T]],
    max_retries: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: tuple = (Exception,)
) -> T:
    """Retry an I/O operation with exponential backoff.

    Args:
        operation: Zero-argument async operation to attempt.
        max_retries: Maximum number of retries after the first attempt.
        delay: Initial delay between retries, in seconds.
        backoff: Multiplier applied to the delay after each failure.
        exceptions: Exception types to catch and retry.

    Returns:
        Result of the first successful attempt.

    Raises:
        ValueError: If ``max_retries`` is negative.
        Exception: The last caught exception, re-raised when all
            ``max_retries + 1`` attempts fail.
    """
    # Guard the degenerate case: the original fell through the loop and
    # executed ``raise None`` (a TypeError) for negative max_retries.
    if max_retries < 0:
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    current_delay = delay
    for attempt in range(max_retries + 1):
        try:
            return await operation()
        except exceptions:
            if attempt >= max_retries:
                # Final attempt: propagate the failure unchanged.
                raise
            await asyncio.sleep(current_delay)
            current_delay *= backoff
    # Unreachable: the loop always returns or raises.
    raise AssertionError("unreachable")
def parallel_io_executor(
    providers: List[IOProvider],
    max_workers: int = 4
) -> 'ParallelIOExecutor':
    """Build a :class:`ParallelIOExecutor` over the given providers.

    Args:
        providers: I/O providers to fan operations out to.
        max_workers: Maximum concurrent workers.

    Returns:
        A configured parallel I/O executor instance.
    """
    executor = ParallelIOExecutor(providers, max_workers)
    return executor
class ParallelIOExecutor:
    """Execute I/O operations in parallel."""

    def __init__(self, providers: List[IOProvider], max_workers: int = 4):
        # NOTE(review): max_workers is stored but gather() below is not
        # throttled by it — confirm whether throttling was intended.
        self.providers = providers
        self.max_workers = max_workers

    async def read_all(self, **kwargs) -> List[Any]:
        """Read from every async-capable provider concurrently.

        Returns:
            One result per provider exposing a coroutine ``read``;
            providers with a sync (or missing) ``read`` are skipped.
        """
        pending = [
            provider.read(**kwargs)
            for provider in self.providers
            if hasattr(provider, 'read')
            and asyncio.iscoroutinefunction(provider.read)
        ]
        if not pending:
            return []
        return await asyncio.gather(*pending)

    async def write_all(self, data: Any, **kwargs) -> None:
        """Write *data* to every async-capable provider concurrently.

        Args:
            data: Data to write. Providers with a sync (or missing)
                ``write`` are skipped.
        """
        pending = [
            provider.write(data, **kwargs)
            for provider in self.providers
            if hasattr(provider, 'write')
            and asyncio.iscoroutinefunction(provider.write)
        ]
        if pending:
            await asyncio.gather(*pending)