Coverage for fastblocks/actions/gather/strategies.py: 59%

158 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-09 00:47 -0700

1"""Gather strategies for error handling, caching, and parallelization.""" 

2 

3import asyncio 

4import typing as t 

5from enum import Enum 

6from pathlib import Path 

7 

8from acb.debug import debug 

9 

10 

class ErrorStrategy(Enum):
    """How a gather reacts when individual tasks fail.

    In this module, FAIL_FAST stops sequential execution at the first error
    and re-raises it; COLLECT_ERRORS raises only when no task succeeded;
    IGNORE_ERRORS and PARTIAL_SUCCESS never raise and simply report the
    errors alongside the successes.
    """

    # Stop at the first failure and re-raise it.
    FAIL_FAST = "fail_fast"
    # Keep going; raise an aggregate error only if everything failed.
    COLLECT_ERRORS = "collect_errors"
    # Keep going and never raise.
    IGNORE_ERRORS = "ignore_errors"
    # Keep going and never raise (default used by GatherStrategy).
    PARTIAL_SUCCESS = "partial_success"

16 

17 

class CacheStrategy(Enum):
    """Whether and how gather results are cached under a ``cache_key``."""

    # Never consult or populate the cache.
    NO_CACHE = "no_cache"
    # Serve from and store into the module-level in-memory cache.
    MEMORY_CACHE = "memory_cache"
    # NOTE(review): in this module PERSISTENT only enables cache *lookups*;
    # nothing stores results under it — confirm whether a backend is missing.
    PERSISTENT = "persistent"

22 

23 

class GatherStrategy:
    """Configuration for how a batch of gather tasks is executed.

    Attributes (mirroring the keyword-only constructor arguments):

    * ``parallel`` - run tasks concurrently (used only when there is more
      than one task).
    * ``max_concurrent`` - upper bound on tasks running at once in parallel
      mode; must be at least 1.
    * ``timeout`` - seconds; bounds the whole batch in parallel mode and each
      task individually in sequential mode.
    * ``error_strategy`` - an ErrorStrategy controlling when failures raise.
    * ``cache_strategy`` - a CacheStrategy controlling result caching.
    * ``retry_attempts`` / ``retry_delay`` - retry settings read by
      ``_execute_with_retry``: extra attempts after a failure (must be >= 0)
      and the base delay in seconds, scaled by the attempt number.

    Raises:
        ValueError: if ``max_concurrent`` < 1 or ``retry_attempts`` < 0.
    """

    def __init__(
        self,
        *,
        parallel: bool = True,
        max_concurrent: int = 10,
        timeout: float = 30.0,
        error_strategy: ErrorStrategy = ErrorStrategy.PARTIAL_SUCCESS,
        cache_strategy: CacheStrategy = CacheStrategy.MEMORY_CACHE,
        retry_attempts: int = 2,
        retry_delay: float = 0.1,
    ) -> None:
        # A semaphore of size 0 can never be acquired, so parallel gathers
        # would silently hang until the timeout; fail loudly instead.
        if max_concurrent < 1:
            msg = f"max_concurrent must be at least 1, got {max_concurrent}"
            raise ValueError(msg)
        # range(retry_attempts + 1) is empty for negative values, so the task
        # would never be awaited at all.
        if retry_attempts < 0:
            msg = f"retry_attempts must be non-negative, got {retry_attempts}"
            raise ValueError(msg)

        self.parallel = parallel
        self.max_concurrent = max_concurrent
        self.timeout = timeout
        self.error_strategy = error_strategy
        self.cache_strategy = cache_strategy
        self.retry_attempts = retry_attempts
        self.retry_delay = retry_delay

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"parallel={self.parallel!r}, "
            f"max_concurrent={self.max_concurrent!r}, "
            f"timeout={self.timeout!r}, "
            f"error_strategy={self.error_strategy!r}, "
            f"cache_strategy={self.cache_strategy!r}, "
            f"retry_attempts={self.retry_attempts!r}, "
            f"retry_delay={self.retry_delay!r})"
        )

43 

44 

# Process-wide, in-memory store of gather results keyed by cache_key.
# Written by _cache_result_if_needed, read by _check_cache and emptied by
# clear_cache; this module applies no size limit or expiry to it.
_memory_cache: dict[str, t.Any] = dict()

46 

47 

48class GatherResult: 

49 def __init__( 

50 self, 

51 *, 

52 success: list[t.Any] | None = None, 

53 errors: list[Exception] | None = None, 

54 cache_key: str | None = None, 

55 ) -> None: 

56 self.success = success if success is not None else [] 

57 self.errors = errors if errors is not None else [] 

58 self.cache_key = cache_key 

59 self.total_attempts = len(self.success) + len(self.errors) 

60 

61 @property 

62 def is_success(self) -> bool: 

63 return len(self.success) > 0 

64 

65 @property 

66 def is_partial(self) -> bool: 

67 return len(self.success) > 0 and len(self.errors) > 0 

68 

69 @property 

70 def is_failure(self) -> bool: 

71 return len(self.success) == 0 and len(self.errors) > 0 

72 

73 

async def gather_with_strategy(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
    cache_key: str | None = None,
) -> GatherResult:
    """Run ``tasks`` according to ``strategy`` and wrap the outcome.

    When ``cache_key`` is given and the strategy allows caching, a previously
    stored GatherResult for that key is returned without running ``tasks``.

    Args:
        tasks: Coroutine objects to await (a coroutine can be awaited once).
        strategy: Execution, error-handling and caching configuration.
        cache_key: Optional key for the in-memory result cache.

    Returns:
        GatherResult holding the successful values and collected errors.

    Raises:
        Exception: propagated from ``_handle_gather_errors`` under the
            FAIL_FAST and COLLECT_ERRORS error strategies; nothing is cached
            in that case.
    """
    cached_result = _check_cache(cache_key, strategy)
    if cached_result is not None:
        # Cache hit: the caller's coroutines will never run. Close them so
        # Python does not warn "coroutine ... was never awaited".
        for task in tasks:
            close = getattr(task, "close", None)
            if callable(close):
                close()
        return cached_result

    success_results, error_results = await _execute_tasks_with_strategy(tasks, strategy)

    # May raise depending on strategy.error_strategy.
    _handle_gather_errors(error_results, success_results, strategy)

    result = GatherResult(
        success=success_results,
        errors=error_results,
        cache_key=cache_key,
    )

    _cache_result_if_needed(result, cache_key, strategy)

    return result

95 

96 

def _check_cache(
    cache_key: str | None,
    strategy: GatherStrategy,
) -> GatherResult | None:
    """Look up a previously cached GatherResult.

    Returns None when no (truthy) ``cache_key`` is given, when the strategy
    is NO_CACHE, or when the key is not present in ``_memory_cache``.
    """
    if not cache_key:
        return None
    if strategy.cache_strategy == CacheStrategy.NO_CACHE:
        return None

    hit = _memory_cache.get(cache_key)
    if not hit:
        return None

    debug(f"Cache hit for {cache_key}")
    return t.cast(GatherResult, hit)

106 

107 

async def _execute_tasks_with_strategy(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
) -> tuple[list[t.Any], list[Exception]]:
    """Dispatch ``tasks`` to the parallel or sequential executor.

    Parallel execution is chosen only when ``strategy.parallel`` is set and
    there is more than one task; zero or one task always runs sequentially.

    Returns:
        ``(success_results, error_results)`` from the chosen executor.
    """
    use_parallel = strategy.parallel and len(tasks) > 1
    executor = _execute_tasks_parallel if use_parallel else _execute_tasks_sequential
    return await executor(tasks, strategy)

115 

116 

async def _execute_tasks_parallel(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
) -> tuple[list[t.Any], list[Exception]]:
    """Await all ``tasks`` concurrently, at most ``max_concurrent`` at a time.

    ``strategy.timeout`` bounds the whole batch. If it elapses,
    ``asyncio.wait_for`` cancels the in-flight work and any partial results
    are discarded; the timeout exception becomes the sole reported error.

    Returns:
        ``(success_results, error_results)``; results keep task order within
        each list.
    """
    success_results: list[t.Any] = []
    error_results: list[Exception] = []
    semaphore = asyncio.Semaphore(strategy.max_concurrent)

    async def execute_with_semaphore(task: t.Coroutine[t.Any, t.Any, t.Any]) -> t.Any:
        async with semaphore:
            return await _execute_with_retry(task, strategy)

    try:
        results = await asyncio.wait_for(
            asyncio.gather(
                *[execute_with_semaphore(task) for task in tasks],
                return_exceptions=True,
            ),
            timeout=strategy.timeout,
        )
    # asyncio.TimeoutError is an alias of the builtin TimeoutError on 3.11+,
    # but a distinct class on 3.10, where a bare `except TimeoutError` misses it.
    except asyncio.TimeoutError as e:
        error_results.append(e)
        return success_results, error_results

    for result in results:
        # BaseException, not Exception: with return_exceptions=True, gather
        # also returns CancelledError (a BaseException since Python 3.8),
        # which must not be counted as a successful value.
        if isinstance(result, BaseException):
            error_results.append(t.cast(Exception, result))
        else:
            success_results.append(result)

    return success_results, error_results

148 

149 

async def _execute_tasks_sequential(
    tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: GatherStrategy,
) -> tuple[list[t.Any], list[Exception]]:
    """Await ``tasks`` one after another, collecting results and errors.

    Each task gets its own ``strategy.timeout`` budget. Under FAIL_FAST the
    loop stops at the first error, and the coroutines that will no longer run
    are closed; under every other error strategy all tasks are attempted.

    Returns:
        ``(success_results, error_results)`` in task order.
    """
    success_results: list[t.Any] = []
    error_results: list[Exception] = []

    for index, task in enumerate(tasks):
        try:
            result = await asyncio.wait_for(
                _execute_with_retry(task, strategy),
                timeout=strategy.timeout,
            )
        except Exception as e:
            error_results.append(e)
            if strategy.error_strategy == ErrorStrategy.FAIL_FAST:
                # The remaining coroutines will never be awaited; close them
                # so Python does not warn "coroutine ... was never awaited".
                for pending in tasks[index + 1 :]:
                    close = getattr(pending, "close", None)
                    if callable(close):
                        close()
                break
        else:
            success_results.append(result)

    return success_results, error_results

171 

172 

def _handle_gather_errors(
    error_results: list[Exception],
    success_results: list[t.Any],
    strategy: GatherStrategy,
) -> None:
    """Raise according to ``strategy.error_strategy`` if any task failed.

    Does nothing when ``error_results`` is empty. IGNORE_ERRORS and
    PARTIAL_SUCCESS never raise here.

    Raises:
        Exception: under FAIL_FAST, the first collected error is re-raised
            as-is. Under COLLECT_ERRORS, when nothing succeeded, a generic
            ``Exception`` listing every error is raised, chained to the first
            error so its traceback is preserved.
    """
    if not error_results:
        return

    debug(f"Gathering completed with {len(error_results)} errors")

    if strategy.error_strategy == ErrorStrategy.FAIL_FAST:
        raise error_results[0]
    if strategy.error_strategy == ErrorStrategy.COLLECT_ERRORS and not success_results:
        msg = f"All gathering operations failed: {error_results}"
        raise Exception(msg) from error_results[0]

187 

188 

def _cache_result_if_needed(
    result: GatherResult,
    cache_key: str | None,
    strategy: GatherStrategy,
) -> None:
    """Store ``result`` in ``_memory_cache`` when memory caching applies.

    Nothing is stored without a (truthy) ``cache_key`` or when the strategy
    is anything other than MEMORY_CACHE.
    """
    # NOTE(review): PERSISTENT results are not stored anywhere, although
    # _check_cache still consults _memory_cache for PERSISTENT — confirm.
    # This function does not inspect result.errors, so a result made up
    # entirely of failures is cached and replayed for the same key.
    if not cache_key:
        return
    if strategy.cache_strategy != CacheStrategy.MEMORY_CACHE:
        return

    _memory_cache[cache_key] = result
    debug(f"Cached result for {cache_key}")

197 

198 

async def _execute_with_retry(
    task: t.Coroutine[t.Any, t.Any, t.Any] | t.Callable[[], t.Awaitable[t.Any]],
    strategy: GatherStrategy,
) -> t.Any:
    """Await ``task``, retrying on failure when that is actually possible.

    A coroutine object can only be awaited once; awaiting it again raises
    ``RuntimeError("cannot reuse already awaited coroutine")``, which used to
    replace the task's real error. So:

    * a coroutine (or other non-callable awaitable) is awaited exactly once
      and its exception, if any, propagates unchanged;
    * a zero-argument callable returning an awaitable is invoked afresh on
      each attempt, up to ``strategy.retry_attempts`` extra times, sleeping
      ``strategy.retry_delay * attempt_number`` seconds between attempts.

    Returns:
        The awaited value.

    Raises:
        Exception: the task's own error (the last one, for callables).
    """
    if not callable(task):
        return await task

    # Clamp so at least one attempt is always made, even for negative counts.
    max_retries = max(strategy.retry_attempts, 0)
    for attempt in range(max_retries + 1):
        try:
            return await task()
        except Exception as e:
            if attempt >= max_retries:
                raise

            debug(f"Retry attempt {attempt + 1} after error: {e}")
            await asyncio.sleep(strategy.retry_delay * (attempt + 1))

    # Unreachable: the final attempt either returns or re-raises.
    msg = "Should not reach here"
    raise RuntimeError(msg)

215 

216 

async def gather_modules(
    module_patterns: list[str],
    strategy: GatherStrategy | None = None,
    cache_key: str | None = None,
) -> GatherResult:
    """Import every module in ``module_patterns`` via gather_with_strategy.

    Args:
        module_patterns: Dotted module paths; a leading "." is resolved
            against the ``fastblocks`` package by _import_module_safe.
        strategy: Execution settings; a fresh GatherStrategy() when None.
        cache_key: Optional key for the in-memory result cache.

    Returns:
        GatherResult whose ``success`` list holds the imported module objects.
    """
    effective_strategy = GatherStrategy() if strategy is None else strategy
    import_jobs: list[t.Coroutine[t.Any, t.Any, t.Any]] = list(
        map(_import_module_safe, module_patterns)
    )
    return await gather_with_strategy(import_jobs, effective_strategy, cache_key)

230 

231 

async def _import_module_safe(module_path: str) -> t.Any:
    """Import ``module_path`` and return the module object.

    A leading "." is expanded by prefixing "fastblocks" (".x" becomes
    "fastblocks.x"). Import failures are logged via ``debug`` and re-raised
    unchanged.
    """
    from importlib import import_module

    target = f"fastblocks{module_path}" if module_path.startswith(".") else module_path
    debug(f"Importing module: {target}")
    try:
        return import_module(target)
    except ModuleNotFoundError as exc:
        debug(f"Module not found: {target} - {exc}")
        raise
    except Exception as exc:
        debug(f"Error importing {target}: {exc}")
        raise

247 

248 

async def gather_files(
    file_patterns: list[str],
    base_path: Path | None = None,
    strategy: GatherStrategy | None = None,
    cache_key: str | None = None,
) -> GatherResult:
    """Resolve every pattern in ``file_patterns`` under ``base_path``.

    Args:
        file_patterns: Glob patterns (containing "*") or literal relative paths.
        base_path: Directory to search; defaults to ``acb.adapters.root_path``.
        strategy: Execution settings; a fresh GatherStrategy() when None.
        cache_key: Optional key for the in-memory result cache.

    Returns:
        GatherResult whose ``success`` list holds one list of Paths per pattern.
    """
    chosen_strategy = strategy if strategy is not None else GatherStrategy()

    search_root = base_path
    if search_root is None:
        from acb.adapters import root_path

        search_root = Path(root_path)

    lookups: list[t.Coroutine[t.Any, t.Any, t.Any]] = [
        _find_files_safe(p, search_root) for p in file_patterns
    ]
    return await gather_with_strategy(lookups, chosen_strategy, cache_key)

268 

269 

async def _find_files_safe(pattern: str, base_path: Path) -> list[Path]:
    """Return the paths under ``base_path`` matching ``pattern``.

    A pattern containing "*" is matched recursively with ``rglob`` and only
    regular files are kept; any other pattern is treated as a literal path
    relative to ``base_path``. Errors are logged via ``debug`` and re-raised.
    """
    from anyio import Path as AsyncPath

    root = AsyncPath(base_path)
    try:
        if "*" not in pattern:
            # NOTE(review): exists() also accepts directories, unlike the glob
            # branch, which keeps only regular files — confirm this is intended.
            candidate = root / pattern
            found = [Path(candidate)] if await candidate.exists() else []
        else:
            matches = []
            async for entry in root.rglob(pattern):
                if await entry.is_file():
                    matches.append(entry)
            found = [Path(m) for m in matches]
        debug(f"Found {len(found)} files for pattern: {pattern}")
        return found
    except Exception as e:
        debug(f"Error finding files for pattern {pattern}: {e}")
        raise

288 

289 

def clear_cache(cache_key: str | None = None) -> None:
    """Evict one entry from the in-memory gather cache, or all of them.

    Args:
        cache_key: Key to evict. Only ``None`` (the default) clears the whole
            cache; any string, including "", evicts just that key. Empty keys
            are never stored by the caching path, so "" is a harmless no-op
            rather than an accidental full flush.
    """
    if cache_key is None:
        _memory_cache.clear()
        debug("Cleared all cache")
        return

    _memory_cache.pop(cache_key, None)
    debug(f"Cleared cache for {cache_key}")

297 

298 

def get_cache_info() -> dict[str, t.Any]:
    """Summarize the in-memory gather cache.

    Returns:
        A dict with ``total_entries`` (entry count), ``cache_keys`` (keys in
        insertion order) and ``memory_usage``. NOTE: ``memory_usage`` is the
        summed length of each cached value's ``str()`` form, a rough size
        proxy rather than a byte count.
    """
    keys = list(_memory_cache)
    approx_size = 0
    for value in _memory_cache.values():
        approx_size += len(str(value))
    return {
        "total_entries": len(_memory_cache),
        "cache_keys": keys,
        "memory_usage": approx_size,
    }