Coverage for fastblocks/actions/sync/strategies.py: 33%

177 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-21 04:50 -0700

1"""Sync strategies for handling synchronization operations with conflict resolution.""" 

2 

3import asyncio 

4import typing as t 

5from enum import Enum 

6from pathlib import Path 

7 

8from acb.debug import debug 

9 

10 

class SyncDirection(Enum):
    """Which way content is allowed to flow during a sync."""

    # A file missing locally may be created from the remote copy.
    PULL = "pull"
    # A file missing remotely may be created from the local copy.
    PUSH = "push"
    # Either side may be filled in from the other.
    BIDIRECTIONAL = "bidirectional"

15 

16 

class ConflictStrategy(Enum):
    """Policy for choosing content when local and remote copies disagree."""

    # Always take the remote content.
    REMOTE_WINS = "remote_wins"
    # Always keep the local content.
    LOCAL_WINS = "local_wins"
    # Compare modification times and take the more recent side.
    NEWEST_WINS = "newest_wins"
    # Refuse to choose automatically; the conflict must be settled by hand.
    MANUAL = "manual"
    # NOTE(review): presumably both versions are preserved by the caller
    # before one is applied - confirm against the call sites.
    BACKUP_BOTH = "backup_both"

23 

24 

class SyncStrategy:
    """Keyword-only bundle of options that steer a sync run.

    Args:
        direction: Which way content may flow.
        conflict_strategy: Policy applied when both sides differ.
        dry_run: Stored as given; not consulted by the helpers in this module.
        backup_on_conflict: Stored as given; not consulted by the helpers in
            this module.
        parallel: When true (and more than one task is supplied), tasks are
            run concurrently instead of one after another.
        max_concurrent: Size of the semaphore limiting concurrent tasks.
        timeout: Seconds passed to ``asyncio.wait_for``.
        retry_attempts: Number of retries after the first failed attempt.
        retry_delay: Base delay in seconds; multiplied by the attempt number
            before each retry.
    """

    def __init__(
        self,
        *,
        direction: SyncDirection = SyncDirection.BIDIRECTIONAL,
        conflict_strategy: ConflictStrategy = ConflictStrategy.NEWEST_WINS,
        dry_run: bool = False,
        backup_on_conflict: bool = True,
        parallel: bool = True,
        max_concurrent: int = 5,
        timeout: float = 30.0,
        retry_attempts: int = 2,
        retry_delay: float = 0.5,
    ) -> None:
        # What to sync and how to settle disagreements.
        self.direction = direction
        self.conflict_strategy = conflict_strategy

        # Safety switches.
        self.dry_run = dry_run
        self.backup_on_conflict = backup_on_conflict

        # Scheduling limits.
        self.parallel = parallel
        self.max_concurrent = max_concurrent
        self.timeout = timeout

        # Retry policy.
        self.retry_attempts = retry_attempts
        self.retry_delay = retry_delay

48 

49 

class SyncResult:
    """Accumulated outcome of a sync run.

    Every attribute is a list owned by the instance:

    * ``synced_items`` - identifiers of items that were synced.
    * ``conflicts`` - one dict per detected conflict.
    * ``errors`` - exceptions raised while syncing.
    * ``skipped`` - identifiers of items that were skipped.
    * ``backed_up`` - identifiers of backups that were created.

    Args:
        synced_items: Initial synced items, or ``None`` for an empty list.
        conflicts: Initial conflicts, or ``None`` for an empty list.
        errors: Initial errors, or ``None`` for an empty list.
        skipped: Initial skipped items, or ``None`` for an empty list.
        backed_up: Initial backups, or ``None`` for an empty list.
    """

    def __init__(
        self,
        *,
        synced_items: list[str] | None = None,
        conflicts: list[dict[str, t.Any]] | None = None,
        errors: list[Exception] | None = None,
        skipped: list[str] | None = None,
        backed_up: list[str] | None = None,
    ) -> None:
        # ``None`` defaults (rather than ``[]``) keep instances from sharing
        # one mutable list.
        self.synced_items = [] if synced_items is None else synced_items
        self.conflicts = [] if conflicts is None else conflicts
        self.errors = [] if errors is None else errors
        self.skipped = [] if skipped is None else skipped
        self.backed_up = [] if backed_up is None else backed_up

    def __repr__(self) -> str:
        """Return a compact count-based summary for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"synced={len(self.synced_items)}, "
            f"conflicts={len(self.conflicts)}, "
            f"errors={len(self.errors)}, "
            f"skipped={len(self.skipped)}, "
            f"backed_up={len(self.backed_up)})"
        )

    @property
    def total_processed(self) -> int:
        """Number of synced, conflicting, failed and skipped entries.

        ``backed_up`` is not part of this total.
        """
        buckets = (self.synced_items, self.conflicts, self.errors, self.skipped)
        return sum(len(bucket) for bucket in buckets)

    @property
    def success_count(self) -> int:
        """Number of items that were synced."""
        return len(self.synced_items)

    @property
    def has_conflicts(self) -> bool:
        """Whether at least one conflict was recorded."""
        return bool(self.conflicts)

    @property
    def has_errors(self) -> bool:
        """Whether at least one error was recorded."""
        return bool(self.errors)

    @property
    def is_success(self) -> bool:
        """Whether the run finished with neither errors nor conflicts."""
        return not (self.has_errors or self.has_conflicts)

90 

91 

async def sync_with_strategy(
    sync_tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: SyncStrategy,
) -> SyncResult:
    """Run ``sync_tasks`` according to ``strategy`` and collect the outcome.

    Tasks run concurrently only when ``strategy.parallel`` is set and there
    is more than one task; otherwise they run one after another.

    Args:
        sync_tasks: Awaitables to execute.
        strategy: Options controlling concurrency, timeout and retries.

    Returns:
        A fresh ``SyncResult`` holding everything the tasks reported.
    """
    outcome = SyncResult()

    run_concurrently = strategy.parallel and len(sync_tasks) > 1
    runner = _execute_parallel_sync if run_concurrently else _execute_sequential_sync
    await runner(sync_tasks, strategy, outcome)

    debug(
        f"Sync completed: {outcome.success_count} synced, {len(outcome.conflicts)} conflicts, {len(outcome.errors)} errors",
    )

    return outcome

108 

109 

async def _execute_parallel_sync(
    sync_tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: SyncStrategy,
    result: SyncResult,
) -> None:
    """Run all sync tasks concurrently and fold their outcomes into ``result``.

    At most ``strategy.max_concurrent`` tasks run at the same time, and each
    one goes through ``_execute_with_retry``. ``strategy.timeout`` bounds the
    whole batch, not each individual task.

    Args:
        sync_tasks: Awaitables to execute.
        strategy: Supplies ``max_concurrent``, ``timeout`` and retry options.
        result: Accumulator that is mutated in place.
    """
    semaphore = asyncio.Semaphore(strategy.max_concurrent)

    async def execute_with_semaphore(
        task: t.Coroutine[t.Any, t.Any, t.Any],
    ) -> t.Any:
        async with semaphore:
            return await _execute_with_retry(task, strategy)

    try:
        # return_exceptions=True keeps one failing task from aborting the rest;
        # failures come back as values and are sorted out afterwards.
        results = await asyncio.wait_for(
            asyncio.gather(
                *(execute_with_semaphore(task) for task in sync_tasks),
                return_exceptions=True,
            ),
            timeout=strategy.timeout,
        )
    except asyncio.TimeoutError as e:
        # asyncio.TimeoutError is only an alias of the builtin TimeoutError
        # from Python 3.11 on; naming the asyncio class catches the timeout
        # on older interpreters as well.
        # NOTE(review): when the batch times out, outcomes of tasks that had
        # already finished are not merged into ``result``.
        result.errors.append(e)
        debug(f"Sync timeout after {strategy.timeout}s")
        return

    _process_parallel_results(results, result)

137 

138 

def _process_parallel_results(results: list[t.Any], result: SyncResult) -> None:
    """Sort the values returned by a parallel batch into ``result``.

    Exceptions are appended to ``result.errors``, dicts are merged through
    ``_merge_sync_result`` and any other value is ignored.

    Args:
        results: Values produced by ``asyncio.gather(..., return_exceptions=True)``.
        result: Accumulator that is mutated in place.
    """
    for task_result in results:
        # BaseException rather than Exception: asyncio.CancelledError derives
        # from BaseException, and a cancelled task must not vanish silently.
        if isinstance(task_result, BaseException):
            result.errors.append(task_result)
        elif isinstance(task_result, dict):
            _merge_sync_result(result, task_result)

145 

146 

async def _execute_sequential_sync(
    sync_tasks: list[t.Coroutine[t.Any, t.Any, t.Any]],
    strategy: SyncStrategy,
    result: SyncResult,
) -> None:
    """Run the sync tasks one after another, folding outcomes into ``result``.

    Each task gets its own ``strategy.timeout`` budget and goes through
    ``_execute_with_retry``. A failing task is recorded in ``result.errors``
    and does not stop the remaining tasks.

    Args:
        sync_tasks: Awaitables to execute in order.
        strategy: Supplies ``timeout`` and retry options.
        result: Accumulator that is mutated in place.
    """
    for pending in sync_tasks:
        try:
            attempt = _execute_with_retry(pending, strategy)
            outcome = await asyncio.wait_for(attempt, timeout=strategy.timeout)

            # Only dict outcomes carry mergeable data; anything else is ignored.
            if isinstance(outcome, dict):
                _merge_sync_result(result, outcome)

        except Exception as e:
            result.errors.append(e)
            debug(f"Sync task failed: {e}")

165 

166 

async def _execute_with_retry(
    task: t.Coroutine[t.Any, t.Any, t.Any] | t.Callable[[], t.Awaitable[t.Any]],
    strategy: SyncStrategy,
) -> t.Any:
    """Await ``task``, retrying with a growing delay where that is possible.

    A coroutine object can be awaited only once, so it cannot be retried:
    awaiting it again raises ``RuntimeError`` and hides the real failure.
    One-shot awaitables are therefore awaited a single time and their
    original exception propagates. To get retries, pass a zero-argument
    callable that returns a fresh awaitable on every call.

    Args:
        task: A one-shot awaitable, or a zero-argument factory producing one.
        strategy: Supplies ``retry_attempts`` (retries after the first try)
            and ``retry_delay`` (base delay in seconds).

    Returns:
        Whatever the awaited task returns.

    Raises:
        Exception: The error from the final (or only) attempt.
    """
    if not callable(task):
        # One-shot awaitable: a second await is impossible, so fail fast with
        # the genuine error instead of sleeping and re-awaiting.
        return await task

    # Guarantee at least one attempt even if retry_attempts is negative.
    retries = max(strategy.retry_attempts, 0)

    for attempt in range(retries + 1):
        try:
            return await task()
        except Exception as e:
            if attempt == retries:
                raise

            debug(f"Sync retry attempt {attempt + 1} after error: {e}")
            # Linear backoff: delay, 2 * delay, 3 * delay, ...
            await asyncio.sleep(strategy.retry_delay * (attempt + 1))

    msg = "Should not reach here"
    raise RuntimeError(msg)

183 

184 

def _merge_sync_result(main_result: SyncResult, task_result: dict[str, t.Any]) -> None:
    """Append the lists reported by one task onto the shared ``main_result``.

    Recognised keys in ``task_result`` are ``"synced"``, ``"conflicts"``,
    ``"errors"``, ``"skipped"`` and ``"backed_up"``. A key that is absent,
    or present with a ``None``/empty value, contributes nothing.

    Args:
        main_result: Accumulator that is mutated in place.
        task_result: Mapping returned by a single sync task.
    """
    destinations = (
        ("synced", main_result.synced_items),
        ("conflicts", main_result.conflicts),
        ("errors", main_result.errors),
        ("skipped", main_result.skipped),
        ("backed_up", main_result.backed_up),
    )
    for key, bucket in destinations:
        # ``or ()`` also covers an explicit None value, which ``.get(key, [])``
        # would hand straight to extend() and crash with TypeError.
        bucket.extend(task_result.get(key) or ())

191 

192 

async def resolve_conflict(
    local_path: Path,
    remote_content: bytes,
    local_content: bytes,
    strategy: ConflictStrategy,
    local_mtime: float | None = None,
    remote_mtime: float | None = None,
) -> tuple[bytes, str]:
    """Pick the content that should win a local/remote conflict.

    Args:
        local_path: Path of the conflicting file; used only in error text.
        remote_content: Bytes of the remote version.
        local_content: Bytes of the local version.
        strategy: Policy used to choose the winner.
        local_mtime: Modification time of the local version, if known.
        remote_mtime: Modification time of the remote version, if known.

    Returns:
        ``(content, reason)`` - the winning bytes and a short tag explaining
        the choice.

    Raises:
        ValueError: If ``strategy`` is ``MANUAL`` or is not a known strategy.
    """
    if strategy == ConflictStrategy.REMOTE_WINS:
        return remote_content, "remote_wins"

    if strategy == ConflictStrategy.LOCAL_WINS:
        return local_content, "local_wins"

    if strategy == ConflictStrategy.NEWEST_WINS:
        # Compare against None explicitly: 0.0 is a valid timestamp and must
        # not be mistaken for "unknown" by a truthiness test.
        if local_mtime is not None and remote_mtime is not None:
            if remote_mtime > local_mtime:
                return remote_content, f"remote_newer({remote_mtime} > {local_mtime})"
            # Ties go to the local copy.
            return local_content, f"local_newer({local_mtime} >= {remote_mtime})"
        # Without both timestamps there is nothing to compare; prefer remote.
        return remote_content, "newest_wins_fallback_remote"

    if strategy == ConflictStrategy.BACKUP_BOTH:
        # Only the winning content is chosen here; no backup is written by
        # this function.
        return remote_content, "backup_both"

    if strategy == ConflictStrategy.MANUAL:
        msg = f"Manual conflict resolution required for {local_path}"
        raise ValueError(msg)

    msg = f"Unknown conflict strategy: {strategy}"
    raise ValueError(msg)

223 

224 

async def create_backup(file_path: Path, suffix: str | None = None) -> Path:
    """Copy ``file_path`` to a sibling backup file and return the backup path.

    The backup name is the original name with ``.<suffix>`` appended. When
    ``suffix`` is omitted it becomes ``backup_<unix-seconds>``.

    Args:
        file_path: File to back up.
        suffix: Text appended after a dot to form the backup name.

    Returns:
        The backup path. It is returned even when ``file_path`` does not
        exist, in which case nothing is copied.

    Raises:
        Exception: Any error raised while checking or copying the file is
            logged and re-raised.
    """
    if suffix is None:
        import time

        suffix = f"backup_{int(time.time())}"

    backup_path = file_path.with_suffix(f"{file_path.suffix}.{suffix}")

    try:
        if not file_path.exists():
            # Nothing to copy; the caller still receives the would-be path.
            return backup_path

        import shutil

        shutil.copy2(file_path, backup_path)
        debug(f"Created backup: {backup_path}")
    except Exception as e:
        debug(f"Error creating backup for {file_path}: {e}")
        raise

    return backup_path

243 

244 

def compare_content(
    content1: bytes,
    content2: bytes,
    use_hash: bool = True,
) -> bool:
    """Return whether two byte strings hold exactly the same data.

    Args:
        content1: First buffer.
        content2: Second buffer.
        use_hash: Accepted for backward compatibility and ignored. Hashing
            both buffers has to read every byte of each, so it can never be
            cheaper than a direct comparison, and a direct comparison is
            exact rather than collision-resistant.

    Returns:
        ``True`` when the buffers are byte-for-byte identical.
    """
    # Cheap early exit: buffers of different length can never match.
    if len(content1) != len(content2):
        return False

    return content1 == content2

261 

262 

async def get_file_info(file_path: Path) -> dict[str, t.Any]:
    """Describe ``file_path`` as a plain dict.

    Args:
        file_path: File to inspect.

    Returns:
        For an existing, readable file: ``exists`` (``True``), ``size``,
        ``mtime``, ``content_hash`` (BLAKE2b hex digest) and ``content``
        (the raw bytes). For a missing file: ``exists`` (``False``),
        ``size`` and ``mtime`` of ``0`` and ``content_hash`` of ``None``.
        If inspection fails, the missing-file shape plus an ``error`` string.
    """
    absent: dict[str, t.Any] = {
        "exists": False,
        "size": 0,
        "mtime": 0,
        "content_hash": None,
    }

    try:
        if file_path.exists():
            import hashlib

            file_stat = file_path.stat()
            data = file_path.read_bytes()
            digest = hashlib.blake2b(data).hexdigest()

            return {
                "exists": True,
                "size": file_stat.st_size,
                "mtime": file_stat.st_mtime,
                "content_hash": digest,
                "content": data,
            }
        return absent
    except Exception as e:
        # Failures are reported as "missing" with the reason attached.
        debug(f"Error getting file info for {file_path}: {e}")
        return {**absent, "error": str(e)}

294 

295 

def should_sync(
    local_info: dict[str, t.Any],
    remote_info: dict[str, t.Any],
    direction: SyncDirection,
) -> tuple[bool, str]:
    """Decide whether a file needs syncing and say why.

    Args:
        local_info: Info dict for the local copy; must contain ``"exists"``
            and, when both copies exist, ``"content_hash"``.
        remote_info: Info dict for the remote copy, same shape.
        direction: Direction in which content may flow.

    Returns:
        ``(needs_sync, reason)``. When exactly one copy exists but
        ``direction`` does not allow creating the missing one, the reason is
        ``"local_missing_skipped"`` or ``"remote_missing_skipped"``.
    """
    local_exists = local_info["exists"]
    remote_exists = remote_info["exists"]

    missing_result = _check_missing_files(local_exists, remote_exists, direction)
    if missing_result is not None:
        return missing_result

    if local_exists and remote_exists:
        return _check_content_differences(local_info, remote_info, direction)

    # Exactly one copy exists and ``direction`` rules out creating the other.
    # Reporting "content_identical" here would be wrong: there is nothing to
    # compare, so name the side that is missing instead.
    if not local_exists:
        return False, "local_missing_skipped"
    return False, "remote_missing_skipped"

311 

312 

def _check_missing_files(
    local_exists: bool,
    remote_exists: bool,
    direction: SyncDirection,
) -> tuple[bool, str] | None:
    """Handle the cases where at least one copy of the file is absent.

    Args:
        local_exists: Whether the local copy exists.
        remote_exists: Whether the remote copy exists.
        direction: Direction in which content may flow.

    Returns:
        ``(True, "local_missing")`` or ``(True, "remote_missing")`` when the
        missing side may be created under ``direction``;
        ``(False, "both_missing")`` when neither copy exists; ``None`` when
        this check does not settle the question.
    """
    if not local_exists and not remote_exists:
        return False, "both_missing"

    if local_exists and remote_exists:
        return None

    if remote_exists:
        # Only the remote copy exists: pulling it down must be allowed.
        if direction in (SyncDirection.PULL, SyncDirection.BIDIRECTIONAL):
            return True, "local_missing"
        return None

    # Only the local copy exists: pushing it up must be allowed.
    if direction in (SyncDirection.PUSH, SyncDirection.BIDIRECTIONAL):
        return True, "remote_missing"
    return None

330 

331 

def _check_content_differences(
    local_info: dict[str, t.Any],
    remote_info: dict[str, t.Any],
    direction: SyncDirection,
) -> tuple[bool, str]:
    """Compare the content hashes of two existing copies.

    Args:
        local_info: Info dict for the local copy; must contain ``"content_hash"``.
        remote_info: Info dict for the remote copy; must contain ``"content_hash"``.
        direction: Direction in which content may flow; only affects the
            reason text.

    Returns:
        ``(False, "content_identical")`` when the hashes match, otherwise
        ``(True, reason)`` with a reason naming the direction.
    """
    hashes_match = local_info["content_hash"] == remote_info["content_hash"]
    if hashes_match:
        return False, "content_identical"

    reason_by_direction = {
        SyncDirection.PULL: "content_differs_pull",
        SyncDirection.PUSH: "content_differs_push",
        SyncDirection.BIDIRECTIONAL: "content_differs_bidirectional",
    }
    # Unrecognised directions still report a difference, with a generic tag.
    reason = reason_by_direction.get(direction, "content_differs")
    return True, reason

347 

348 

def get_sync_summary(result: SyncResult) -> dict[str, t.Any]:
    """Flatten a ``SyncResult`` into a dict of counts and flags.

    Args:
        result: The sync outcome to summarise.

    Returns:
        A dict with per-bucket counts, a ``success_rate`` (synced divided by
        total processed) and the ``has_conflicts`` / ``has_errors`` /
        ``is_success`` flags.
    """
    processed = result.total_processed
    synced = result.success_count

    # Clamp the denominator to 1 so an empty result yields a rate of zero
    # instead of a division error.
    rate = synced / max(processed, 1)

    return {
        "total_processed": processed,
        "synced": synced,
        "conflicts": len(result.conflicts),
        "errors": len(result.errors),
        "skipped": len(result.skipped),
        "backed_up": len(result.backed_up),
        "success_rate": rate,
        "has_conflicts": result.has_conflicts,
        "has_errors": result.has_errors,
        "is_success": result.is_success,
    }