Coverage for aipyapp/aipy/cache.py: 37%

148 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1import sqlite3 

2import time 

3import json 

4import threading 

5import hashlib 

6import functools 

7from typing import Any, Optional, Callable, Union 

8from pathlib import Path 

9from .config import CONFIG_DIR 

10 

11CACHE_FILE = CONFIG_DIR / "cache.db" 

12 

13 

14class KVCache: 

15 """基于SQLite的KV缓存类""" 

16 

17 def __init__(self, db_path: str = "cache.db", default_ttl: int = 3600): 

18 """ 

19 初始化缓存 

20 

21 Args: 

22 db_path: SQLite数据库文件路径 

23 default_ttl: 默认过期时间(秒) 

24 """ 

25 self.db_path = db_path 

26 self.default_ttl = default_ttl 

27 self._lock = threading.RLock() 

28 self._init_db() 

29 

30 def _init_db(self): 

31 """初始化数据库表""" 

32 with sqlite3.connect(self.db_path) as conn: 

33 conn.execute(''' 

34 CREATE TABLE IF NOT EXISTS cache ( 

35 key TEXT PRIMARY KEY, 

36 value TEXT, 

37 expire_time REAL, 

38 created_time REAL 

39 ) 

40 ''') 

41 conn.execute( 

42 'CREATE INDEX IF NOT EXISTS idx_expire_time ON cache(expire_time)' 

43 ) 

44 conn.commit() 

45 

46 def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: 

47 """ 

48 设置缓存 

49 

50 Args: 

51 key: 缓存键 

52 value: 缓存值 

53 ttl: 过期时间(秒),None表示使用默认TTL 

54 """ 

55 with self._lock: 

56 current_time = time.time() 

57 expire_time = current_time + (ttl or self.default_ttl) 

58 

59 # 序列化值 

60 try: 

61 serialized_value = json.dumps(value, ensure_ascii=False) 

62 except (TypeError, ValueError) as e: 

63 raise ValueError(f"无法序列化值: {e}") 

64 

65 with sqlite3.connect(self.db_path) as conn: 

66 conn.execute( 

67 'INSERT OR REPLACE INTO cache (key, value, expire_time, created_time) VALUES (?, ?, ?, ?)', 

68 (key, serialized_value, expire_time, current_time), 

69 ) 

70 conn.commit() 

71 

72 def get(self, key: str, default: Any = None) -> Any: 

73 """ 

74 获取缓存 

75 

76 Args: 

77 key: 缓存键 

78 default: 默认值 

79 

80 Returns: 

81 缓存值或默认值 

82 """ 

83 with self._lock: 

84 current_time = time.time() 

85 

86 with sqlite3.connect(self.db_path) as conn: 

87 cursor = conn.execute( 

88 'SELECT value, expire_time FROM cache WHERE key = ?', (key,) 

89 ) 

90 row = cursor.fetchone() 

91 

92 if row is None: 

93 return default 

94 

95 value, expire_time = row 

96 if current_time > expire_time: 

97 # 删除过期缓存 

98 conn.execute('DELETE FROM cache WHERE key = ?', (key,)) 

99 conn.commit() 

100 return default 

101 

102 try: 

103 return json.loads(value) 

104 except (json.JSONDecodeError, TypeError): 

105 return default 

106 

107 def delete(self, key: str) -> bool: 

108 """ 

109 删除缓存 

110 

111 Args: 

112 key: 缓存键 

113 

114 Returns: 

115 是否删除成功 

116 """ 

117 with self._lock: 

118 with sqlite3.connect(self.db_path) as conn: 

119 cursor = conn.execute('DELETE FROM cache WHERE key = ?', (key,)) 

120 conn.commit() 

121 return cursor.rowcount > 0 

122 

123 def exists(self, key: str) -> bool: 

124 """ 

125 检查缓存是否存在且未过期 

126 

127 Args: 

128 key: 缓存键 

129 

130 Returns: 

131 是否存在 

132 """ 

133 return self.get(key) is not None 

134 

135 def expire(self, key: str, ttl: int) -> bool: 

136 """ 

137 设置缓存过期时间 

138 

139 Args: 

140 key: 缓存键 

141 ttl: 过期时间(秒) 

142 

143 Returns: 

144 是否设置成功 

145 """ 

146 with self._lock: 

147 current_time = time.time() 

148 expire_time = current_time + ttl 

149 

150 with sqlite3.connect(self.db_path) as conn: 

151 cursor = conn.execute( 

152 'UPDATE cache SET expire_time = ? WHERE key = ?', (expire_time, key) 

153 ) 

154 conn.commit() 

155 return cursor.rowcount > 0 

156 

157 def ttl(self, key: str) -> int: 

158 """ 

159 获取缓存剩余过期时间 

160 

161 Args: 

162 key: 缓存键 

163 

164 Returns: 

165 剩余时间(秒),-1表示不存在,-2表示永不过期 

166 """ 

167 with self._lock: 

168 current_time = time.time() 

169 

170 with sqlite3.connect(self.db_path) as conn: 

171 cursor = conn.execute( 

172 'SELECT expire_time FROM cache WHERE key = ?', (key,) 

173 ) 

174 row = cursor.fetchone() 

175 

176 if row is None: 

177 return -1 

178 

179 expire_time = row[0] 

180 if current_time > expire_time: 

181 return -1 

182 

183 return int(expire_time - current_time) 

184 

185 def cleanup(self) -> int: 

186 """ 

187 清理过期缓存 

188 

189 Returns: 

190 清理的缓存数量 

191 """ 

192 with self._lock: 

193 current_time = time.time() 

194 with sqlite3.connect(self.db_path) as conn: 

195 cursor = conn.execute( 

196 'DELETE FROM cache WHERE expire_time < ?', (current_time,) 

197 ) 

198 conn.commit() 

199 return cursor.rowcount 

200 

201 def clear(self) -> None: 

202 """清空所有缓存""" 

203 with self._lock: 

204 with sqlite3.connect(self.db_path) as conn: 

205 conn.execute('DELETE FROM cache') 

206 conn.commit() 

207 

208 def size(self) -> int: 

209 """获取缓存数量""" 

210 with sqlite3.connect(self.db_path) as conn: 

211 cursor = conn.execute('SELECT COUNT(*) FROM cache') 

212 return cursor.fetchone()[0] 

213 

214 def keys(self) -> list: 

215 """获取所有有效缓存键""" 

216 with self._lock: 

217 current_time = time.time() 

218 with sqlite3.connect(self.db_path) as conn: 

219 cursor = conn.execute( 

220 'SELECT key FROM cache WHERE expire_time > ?', (current_time,) 

221 ) 

222 return [row[0] for row in cursor.fetchall()] 

223 

224 def stats(self) -> dict: 

225 """获取缓存统计信息""" 

226 with sqlite3.connect(self.db_path) as conn: 

227 cursor = conn.execute( 

228 ''' 

229 SELECT  

230 COUNT(*) as total, 

231 COUNT(CASE WHEN expire_time > ? THEN 1 END) as valid, 

232 COUNT(CASE WHEN expire_time <= ? THEN 1 END) as expired 

233 FROM cache 

234 ''', 

235 (time.time(), time.time()), 

236 ) 

237 

238 row = cursor.fetchone() 

239 return {'total': row[0], 'valid': row[1], 'expired': row[2]} 

240 

241 

242# 全局缓存实例 

243_default_cache = None 

244 

245 

246def get_default_cache() -> KVCache: 

247 """获取默认缓存实例""" 

248 global _default_cache 

249 if _default_cache is None: 

250 _default_cache = KVCache(str(CACHE_FILE)) 

251 return _default_cache 

252 

253 

254def cache_key(*args, **kwargs) -> str: 

255 """生成缓存键""" 

256 key_data = str(args) + str(sorted(kwargs.items())) 

257 return hashlib.md5(key_data.encode()).hexdigest() 

258 

259 

260def cached( 

261 ttl: int = 3600, 

262 key_func: Optional[Callable] = None, 

263 cache_instance: Optional[KVCache] = None, 

264): 

265 """ 

266 缓存装饰器 

267 

268 Args: 

269 ttl: 缓存过期时间(秒) 

270 key_func: 自定义键生成函数 

271 cache_instance: 缓存实例,None表示使用默认实例 

272 """ 

273 

274 def decorator(func: Callable) -> Callable: 

275 @functools.wraps(func) 

276 def wrapper(*args, **kwargs): 

277 # 获取缓存实例 

278 cache = cache_instance or get_default_cache() 

279 

280 # 生成缓存键 

281 if key_func: 

282 key = key_func(*args, **kwargs) 

283 else: 

284 func_name = f"{func.__module__}.{func.__name__}" 

285 key = f"{func_name}:{cache_key(*args, **kwargs)}" 

286 

287 # 尝试从缓存获取 

288 result = cache.get(key) 

289 if result is not None: 

290 return result 

291 

292 # 执行函数并缓存结果 

293 result = func(*args, **kwargs) 

294 cache.set(key, result, ttl) 

295 return result 

296 

297 # 添加缓存操作方法 

298 setattr( 

299 wrapper, 

300 'cache_clear', 

301 lambda: cache_instance.clear() 

302 if cache_instance 

303 else get_default_cache().clear(), 

304 ) 

305 setattr( 

306 wrapper, 

307 'cache_info', 

308 lambda: cache_instance.stats() 

309 if cache_instance 

310 else get_default_cache().stats(), 

311 ) 

312 

313 return wrapper 

314 

315 return decorator 

316 

317 

318# 便捷函数 

319def set_cache(key: str, value: Any, ttl: Optional[int] = None) -> None: 

320 """设置缓存""" 

321 get_default_cache().set(key, value, ttl) 

322 

323 

324def get_cache(key: str, default: Any = None) -> Any: 

325 """获取缓存""" 

326 return get_default_cache().get(key, default) 

327 

328 

329def delete_cache(key: str) -> bool: 

330 """删除缓存""" 

331 return get_default_cache().delete(key) 

332 

333 

334def clear_cache() -> None: 

335 """清空缓存""" 

336 get_default_cache().clear() 

337 

338 

339def cache_exists(key: str) -> bool: 

340 """检查缓存是否存在""" 

341 return get_default_cache().exists(key) 

342 

343 

344def cache_ttl(key: str) -> int: 

345 """获取缓存TTL""" 

346 return get_default_cache().ttl(key) 

347 

348 

349def cleanup_cache() -> int: 

350 """清理过期缓存""" 

351 return get_default_cache().cleanup() 

352 

353 

354def cache_stats() -> dict: 

355 """获取缓存统计""" 

356 return get_default_cache().stats() 

357 

358 

359cleanup_cache()