Coverage for fastblocks/actions/gather/strategies.py: 59%
158 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-09 00:47 -0700
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-09 00:47 -0700
1"""Gather strategies for error handling, caching, and parallelization."""
3import asyncio
4import typing as t
5from enum import Enum
6from pathlib import Path
8from acb.debug import debug
class ErrorStrategy(Enum):
    """Policies describing how a gather run reacts to task failures.

    How each member is interpreted is decided by the code that consumes a
    ``GatherStrategy``; the comments below describe what this module does.
    """

    # Sequential execution stops after the first failing task and the first
    # collected error is re-raised to the caller.
    FAIL_FAST = "fail_fast"
    # Errors are collected; an exception is raised only if nothing succeeded.
    COLLECT_ERRORS = "collect_errors"
    # No special handling in this module: errors are returned on the result.
    IGNORE_ERRORS = "ignore_errors"
    # No special handling in this module: errors are returned on the result.
    PARTIAL_SUCCESS = "partial_success"
class CacheStrategy(Enum):
    """Policies describing whether gather results are cached."""

    # Cache lookups are skipped and results are not stored.
    NO_CACHE = "no_cache"
    # Results are looked up in, and stored to, the in-process cache.
    MEMORY_CACHE = "memory_cache"
    # NOTE(review): this module looks up the in-process cache for this value
    # but never stores a result for it -- confirm where persistence happens.
    PERSISTENT = "persistent"
class GatherStrategy:
    """Keyword-only bundle of settings that control a gather run.

    Attributes:
        parallel: Run tasks concurrently when more than one task is given.
        max_concurrent: Size of the semaphore limiting concurrent tasks.
        timeout: Timeout in seconds passed to ``asyncio.wait_for``.
        error_strategy: How failures are treated once tasks have run.
        cache_strategy: Whether results are looked up in / stored to the cache.
        retry_attempts: Number of additional attempts after the first failure.
        retry_delay: Base delay in seconds; multiplied by the attempt number.
    """

    def __init__(
        self,
        *,
        parallel: bool = True,
        max_concurrent: int = 10,
        timeout: float = 30.0,
        error_strategy: ErrorStrategy = ErrorStrategy.PARTIAL_SUCCESS,
        cache_strategy: CacheStrategy = CacheStrategy.MEMORY_CACHE,
        retry_attempts: int = 2,
        retry_delay: float = 0.1,
    ) -> None:
        # Concurrency controls.
        self.parallel, self.max_concurrent = parallel, max_concurrent
        self.timeout = timeout
        # Failure and caching policies.
        self.error_strategy, self.cache_strategy = error_strategy, cache_strategy
        # Retry behaviour.
        self.retry_attempts, self.retry_delay = retry_attempts, retry_delay
# Module-level, in-process cache of gather results keyed by the caller-supplied
# cache key. Shared by every gather call in this module; never evicted
# automatically (see clear_cache).
_memory_cache: dict[str, t.Any] = {}
48class GatherResult:
49 def __init__(
50 self,
51 *,
52 success: list[t.Any] | None = None,
53 errors: list[Exception] | None = None,
54 cache_key: str | None = None,
55 ) -> None:
56 self.success = success if success is not None else []
57 self.errors = errors if errors is not None else []
58 self.cache_key = cache_key
59 self.total_attempts = len(self.success) + len(self.errors)
61 @property
62 def is_success(self) -> bool:
63 return len(self.success) > 0
65 @property
66 def is_partial(self) -> bool:
67 return len(self.success) > 0 and len(self.errors) > 0
69 @property
70 def is_failure(self) -> bool:
71 return len(self.success) == 0 and len(self.errors) > 0
async def gather_with_strategy(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
    cache_key: str | None = None,
) -> GatherResult:
    """Run ``tasks`` according to ``strategy`` and return a ``GatherResult``.

    Args:
        tasks: Coroutine objects to execute.
        strategy: Execution, error-handling and caching settings.
        cache_key: Optional key used to look up and store the result.

    Returns:
        The cached result when one exists for ``cache_key``; otherwise a new
        result holding the successful values and the collected errors.

    Raises:
        Exception: Whatever ``_handle_gather_errors`` raises for the configured
            error strategy. In that case nothing is cached.
    """
    if cached_result := _check_cache(cache_key, strategy):
        # On a cache hit the supplied coroutine objects are never awaited.
        # Close them explicitly so Python does not emit
        # "RuntimeWarning: coroutine ... was never awaited" when they are
        # garbage-collected.
        for task in tasks:
            if asyncio.iscoroutine(task):
                task.close()
        return cached_result

    success_results, error_results = await _execute_tasks_with_strategy(tasks, strategy)

    # May raise depending on the error strategy; must run before caching so a
    # failed run is never stored.
    _handle_gather_errors(error_results, success_results, strategy)

    result = GatherResult(
        success=success_results,
        errors=error_results,
        cache_key=cache_key,
    )

    _cache_result_if_needed(result, cache_key, strategy)

    return result
def _check_cache(
    cache_key: str | None,
    strategy: GatherStrategy,
) -> GatherResult | None:
    """Return the cached result for ``cache_key``, or ``None`` on a miss.

    The lookup is skipped when no key is given or when the strategy's cache
    policy is ``NO_CACHE``.
    """
    if not cache_key or strategy.cache_strategy == CacheStrategy.NO_CACHE:
        return None

    hit = _memory_cache.get(cache_key)
    if not hit:
        return None

    debug(f"Cache hit for {cache_key}")
    return t.cast(GatherResult, hit)
async def _execute_tasks_with_strategy(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
) -> tuple[list[t.Any], list[Exception]]:
    """Dispatch ``tasks`` to the parallel or the sequential executor.

    Parallel execution is used only when the strategy asks for it and there is
    more than one task; otherwise tasks run one after another.

    Returns:
        A ``(successes, errors)`` pair from the chosen executor.
    """
    use_parallel = strategy.parallel and len(tasks) > 1
    runner = _execute_tasks_parallel if use_parallel else _execute_tasks_sequential
    return await runner(tasks, strategy)
async def _execute_tasks_parallel(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
) -> tuple[list[t.Any], list[Exception]]:
    """Run ``tasks`` concurrently, bounded by ``strategy.max_concurrent``.

    The whole batch shares a single ``strategy.timeout``. If it elapses, the
    batch is abandoned and the only thing reported is the timeout error;
    results of tasks that had already finished are not returned.

    Args:
        tasks: Coroutine objects to execute.
        strategy: Supplies ``max_concurrent``, ``timeout`` and retry settings.

    Returns:
        A ``(successes, errors)`` pair. Any exception instance produced by a
        task is placed in ``errors``; everything else is a success value.
    """
    success_results = []
    error_results = []
    semaphore = asyncio.Semaphore(strategy.max_concurrent)

    async def execute_with_semaphore(task: t.Coroutine[t.Any, t.Any, t.Any]) -> t.Any:
        # The semaphore caps how many tasks are in flight at once.
        async with semaphore:
            return await _execute_with_retry(task, strategy)

    try:
        results = await asyncio.wait_for(
            asyncio.gather(
                *[execute_with_semaphore(task) for task in tasks],
                return_exceptions=True,
            ),
            timeout=strategy.timeout,
        )
    except TimeoutError as e:
        error_results.append(e)
    else:
        for result in results:
            # With return_exceptions=True, gather also hands back
            # asyncio.CancelledError, which derives from BaseException rather
            # than Exception. Checking BaseException keeps a cancelled task
            # from being reported as a successful value.
            if isinstance(result, BaseException):
                error_results.append(result)
            else:
                success_results.append(result)

    return success_results, error_results
async def _execute_tasks_sequential(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
) -> tuple[list[t.Any], list[Exception]]:
    """Run ``tasks`` one at a time, each under its own ``strategy.timeout``.

    Under ``ErrorStrategy.FAIL_FAST`` the loop stops at the first failure and
    the remaining tasks are not run.

    Returns:
        A ``(successes, errors)`` pair in execution order.
    """
    succeeded = []
    failed = []

    for pending in tasks:
        try:
            value = await asyncio.wait_for(
                _execute_with_retry(pending, strategy),
                timeout=strategy.timeout,
            )
        except Exception as exc:
            failed.append(exc)
            if strategy.error_strategy == ErrorStrategy.FAIL_FAST:
                break
        else:
            succeeded.append(value)

    return succeeded, failed
def _handle_gather_errors(
    error_results: list[Exception],
    success_results: list[t.Any],
    strategy: GatherStrategy,
) -> None:
    """Apply the strategy's error policy to the collected results.

    Does nothing when there are no errors. Otherwise:

    * ``FAIL_FAST`` re-raises the first collected error.
    * ``COLLECT_ERRORS`` raises a generic ``Exception`` only when no task
      succeeded.
    * Any other policy leaves the errors for the caller to inspect.

    Raises:
        Exception: As described above.
    """
    if not error_results:
        return

    debug(f"Gathering completed with {len(error_results)} errors")

    policy = strategy.error_strategy
    if policy == ErrorStrategy.FAIL_FAST:
        raise error_results[0]
    if policy == ErrorStrategy.COLLECT_ERRORS and not success_results:
        msg = f"All gathering operations failed: {error_results}"
        raise Exception(msg)
def _cache_result_if_needed(
    result: GatherResult,
    cache_key: str | None,
    strategy: GatherStrategy,
) -> None:
    """Store ``result`` in the in-process cache when caching applies.

    A result is stored only when a non-empty key is given and the strategy's
    cache policy is ``MEMORY_CACHE``.
    """
    if not cache_key:
        return
    if strategy.cache_strategy != CacheStrategy.MEMORY_CACHE:
        return

    _memory_cache[cache_key] = result
    debug(f"Cached result for {cache_key}")
async def _execute_with_retry(
    task: t.Coroutine[t.Any, t.Any, t.Any] | t.Callable[[], t.Awaitable[t.Any]],
    strategy: GatherStrategy,
) -> t.Any:
    """Await ``task``, retrying only when it can actually be re-run.

    A coroutine object (or any other one-shot awaitable) can be awaited exactly
    once: awaiting it again raises ``RuntimeError: cannot reuse already awaited
    coroutine``. Retrying such an object would therefore replace the real
    failure with that RuntimeError after sleeping through every back-off delay,
    so one-shot awaitables are awaited a single time and their original
    exception propagates unchanged.

    To get real retries, pass a zero-argument callable that returns a fresh
    awaitable on every call (for example an ``async def`` function or a
    ``functools.partial`` around one).

    Args:
        task: A coroutine object / awaitable, or a zero-argument factory
            returning an awaitable.
        strategy: Supplies ``retry_attempts`` (extra attempts after the first)
            and ``retry_delay`` (base delay in seconds, multiplied by the
            attempt number).

    Returns:
        Whatever the awaited task returns.

    Raises:
        Exception: The task's own exception -- immediately for one-shot
            awaitables, or after the final attempt for factories.
        RuntimeError: If ``retry_attempts`` is negative, so that a factory is
            never attempted at all.
    """
    if not callable(task) or hasattr(task, "__await__"):
        # One-shot awaitable: it cannot be restarted, so do not pretend to retry.
        return await task

    for attempt in range(strategy.retry_attempts + 1):
        try:
            # Each attempt gets a brand-new awaitable from the factory.
            return await task()
        except Exception as e:
            if attempt == strategy.retry_attempts:
                raise

            debug(f"Retry attempt {attempt + 1} after error: {e}")
            # Linear back-off: delay grows with the attempt number.
            await asyncio.sleep(strategy.retry_delay * (attempt + 1))

    msg = "Should not reach here"
    raise RuntimeError(msg)
async def gather_modules(
    module_patterns: list[str],
    strategy: GatherStrategy | None = None,
    cache_key: str | None = None,
) -> GatherResult:
    """Import every module named in ``module_patterns`` and gather the results.

    Args:
        module_patterns: Module paths; a leading ``.`` is resolved by the
            import helper.
        strategy: Gather settings; a default ``GatherStrategy`` is used when
            omitted.
        cache_key: Optional key for caching the combined result.

    Returns:
        The ``GatherResult`` produced by ``gather_with_strategy``.
    """
    effective_strategy = GatherStrategy() if strategy is None else strategy

    # One import coroutine per requested module, in the order given.
    import_jobs: list[t.Coroutine[t.Any, t.Any, t.Any]] = list(
        map(_import_module_safe, module_patterns)
    )

    return await gather_with_strategy(import_jobs, effective_strategy, cache_key)
async def _import_module_safe(module_path: str) -> t.Any:
    """Import ``module_path`` and return the module, logging any failure.

    A path starting with ``.`` is prefixed with ``fastblocks`` before the
    import. Failures are logged through ``debug`` and then re-raised unchanged.

    Raises:
        ModuleNotFoundError: When the module cannot be found.
        Exception: Any other error raised while importing.
    """
    from importlib import import_module

    if module_path.startswith("."):
        module_path = f"fastblocks{module_path}"
    debug(f"Importing module: {module_path}")

    try:
        return import_module(module_path)
    except Exception as e:
        # Only the log message depends on the error type; the exception itself
        # always propagates to the caller.
        if isinstance(e, ModuleNotFoundError):
            debug(f"Module not found: {module_path} - {e}")
        else:
            debug(f"Error importing {module_path}: {e}")
        raise
async def gather_files(
    file_patterns: list[str],
    base_path: Path | None = None,
    strategy: GatherStrategy | None = None,
    cache_key: str | None = None,
) -> GatherResult:
    """Find files for every pattern in ``file_patterns`` and gather the results.

    Args:
        file_patterns: Patterns or relative paths to look for.
        base_path: Directory to search from; defaults to ``acb``'s
            ``root_path`` when omitted.
        strategy: Gather settings; a default ``GatherStrategy`` is used when
            omitted.
        cache_key: Optional key for caching the combined result.

    Returns:
        The ``GatherResult`` produced by ``gather_with_strategy``; each success
        entry is the list of paths found for one pattern.
    """
    effective_strategy = GatherStrategy() if strategy is None else strategy

    search_root = base_path
    if search_root is None:
        # Imported lazily, only when the caller did not supply a base path.
        from acb.adapters import root_path

        search_root = Path(root_path)

    search_jobs: list[t.Coroutine[t.Any, t.Any, t.Any]] = [
        _find_files_safe(file_pattern, search_root) for file_pattern in file_patterns
    ]

    return await gather_with_strategy(search_jobs, effective_strategy, cache_key)
async def _find_files_safe(pattern: str, base_path: Path) -> list[Path]:
    """Return the files under ``base_path`` that match ``pattern``.

    A pattern containing ``*`` is searched recursively with ``rglob`` and only
    regular files are kept. Any other pattern is treated as a single relative
    path and is returned when it exists. Errors are logged through ``debug``
    and re-raised.
    """
    from anyio import Path as AsyncPath

    root = AsyncPath(base_path)
    try:
        if "*" in pattern:
            matches: list[Path] = []
            async for candidate in root.rglob(pattern):
                if await candidate.is_file():
                    matches.append(Path(candidate))
        else:
            target = root / pattern
            matches = [Path(target)] if await target.exists() else []
        debug(f"Found {len(matches)} files for pattern: {pattern}")
        return matches
    except Exception as e:
        debug(f"Error finding files for pattern {pattern}: {e}")
        raise
def clear_cache(cache_key: str | None = None) -> None:
    """Remove one entry from the in-process cache, or every entry.

    Args:
        cache_key: Key of the entry to drop. When omitted (``None``) the whole
            cache is cleared. An unknown key is ignored.
    """
    # Compare against None rather than relying on truthiness: an empty-string
    # key must not be mistaken for "no key" and wipe the entire cache.
    if cache_key is None:
        _memory_cache.clear()
        debug("Cleared all cache")
        return

    _memory_cache.pop(cache_key, None)
    debug(f"Cleared cache for {cache_key}")
def get_cache_info() -> dict[str, t.Any]:
    """Return summary statistics about the in-process cache.

    Returns:
        A dict with ``total_entries`` (number of cached results),
        ``cache_keys`` (list of keys) and ``memory_usage``. The latter is the
        summed length of ``str(value)`` for each entry -- a rough indicator,
        not a byte count.
    """
    keys = list(_memory_cache)

    approx_size = 0
    for cached_value in _memory_cache.values():
        approx_size += len(str(cached_value))

    return {
        "total_entries": len(_memory_cache),
        "cache_keys": keys,
        "memory_usage": approx_size,
    }