Coverage for aipyapp/aipy/cache.py: 37%
148 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
import contextlib
import functools
import hashlib
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Optional, Callable, Union

from .config import CONFIG_DIR
# Path of the SQLite file backing the shared default cache (used by get_default_cache).
CACHE_FILE = CONFIG_DIR / "cache.db"
class KVCache:
    """Key-value cache persisted in an SQLite table.

    Values are stored as JSON text (``json.dumps`` on write, ``json.loads`` on
    read), so only JSON-serializable values can be cached, and they come back
    in their JSON round-trip form (tuples as lists, dict keys as strings).

    Every entry carries an absolute ``expire_time`` (seconds since the epoch).
    An entry counts as expired once ``expire_time <= time.time()``; every
    method below applies that same rule.

    Each operation opens its own short-lived connection and closes it again.
    Consequently ``db_path=":memory:"`` is not usable: each new connection
    would see a fresh, empty in-memory database without the ``cache`` table.

    The instance's re-entrant lock serializes operations issued through this
    object only; other instances or processes using the same file are
    coordinated solely by SQLite's own locking.
    """

    # Private sentinel distinguishing "no entry" from a stored ``None`` value.
    _MISSING = object()

    def __init__(self, db_path: Union[str, Path] = "cache.db", default_ttl: int = 3600):
        """Create the cache and make sure its table exists.

        Args:
            db_path: Path (``str`` or ``pathlib.Path``) of the SQLite file.
            default_ttl: Lifetime in seconds used when ``set`` gets no ``ttl``.
        """
        self.db_path = db_path
        self.default_ttl = default_ttl
        self._lock = threading.RLock()
        self._init_db()

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection inside a transaction and always close it.

        ``sqlite3.Connection`` used as a context manager only commits (or rolls
        back on error); it does not close the connection. ``closing`` is layered
        on top so no connection is leaked per call.
        """
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            with conn:
                yield conn

    def _init_db(self) -> None:
        """Create the ``cache`` table and its expiry index if they are missing."""
        with self._connect() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expire_time REAL,
                    created_time REAL
                )
            ''')
            conn.execute(
                'CREATE INDEX IF NOT EXISTS idx_expire_time ON cache(expire_time)'
            )

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store ``value`` under ``key``, replacing any existing entry.

        Args:
            key: Cache key.
            value: JSON-serializable value to store.
            ttl: Lifetime in seconds. ``None`` means ``default_ttl``; ``0`` or a
                negative value makes the entry expire immediately.

        Raises:
            ValueError: If ``value`` cannot be serialized to JSON.
        """
        # Serialize first so an unserializable value never touches the database.
        try:
            serialized_value = json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            raise ValueError(f"无法序列化值: {e}") from e

        # Only None selects the default; an explicit 0 must not be swallowed.
        lifetime = self.default_ttl if ttl is None else ttl
        with self._lock:
            current_time = time.time()
            expire_time = current_time + lifetime
            with self._connect() as conn:
                conn.execute(
                    'INSERT OR REPLACE INTO cache (key, value, expire_time, created_time) VALUES (?, ?, ?, ?)',
                    (key, serialized_value, expire_time, current_time),
                )

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key``.

        An expired entry is deleted on read. A payload that cannot be decoded
        as JSON is treated as a miss.

        Args:
            key: Cache key.
            default: Returned when the key is missing, expired or undecodable.

        Returns:
            The cached value (JSON round-trip of what was stored) or ``default``.
        """
        with self._lock:
            current_time = time.time()
            with self._connect() as conn:
                row = conn.execute(
                    'SELECT value, expire_time FROM cache WHERE key = ?', (key,)
                ).fetchone()
                if row is None:
                    return default

                value, expire_time = row
                if expire_time <= current_time:
                    # Lazily evict the expired entry; committed when the block exits.
                    conn.execute('DELETE FROM cache WHERE key = ?', (key,))
                    return default

        try:
            return json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return default

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if a row was deleted."""
        with self._lock:
            with self._connect() as conn:
                cursor = conn.execute('DELETE FROM cache WHERE key = ?', (key,))
                return cursor.rowcount > 0

    def exists(self, key: str) -> bool:
        """Return True if ``key`` is present and unexpired.

        A stored ``None`` counts as present (a sentinel default is used rather
        than comparing the value with ``None``).
        """
        return self.get(key, self._MISSING) is not self._MISSING

    def expire(self, key: str, ttl: int) -> bool:
        """Reset the lifetime of an existing, unexpired entry.

        Args:
            key: Cache key.
            ttl: New lifetime in seconds, counted from now.

        Returns:
            True if a live entry was updated. False if the key is missing or
            already expired; an expired entry is not revived.
        """
        with self._lock:
            current_time = time.time()
            expire_time = current_time + ttl
            with self._connect() as conn:
                cursor = conn.execute(
                    'UPDATE cache SET expire_time = ? WHERE key = ? AND expire_time > ?',
                    (expire_time, key, current_time),
                )
                return cursor.rowcount > 0

    def ttl(self, key: str) -> int:
        """Return the remaining lifetime of ``key`` in whole seconds.

        Returns:
            Seconds left, truncated toward zero, or -1 if the key does not exist
            or has expired. Every entry has an expiry time, so there is no
            "never expires" value.
        """
        with self._lock:
            current_time = time.time()
            with self._connect() as conn:
                row = conn.execute(
                    'SELECT expire_time FROM cache WHERE key = ?', (key,)
                ).fetchone()

        if row is None:
            return -1
        expire_time = row[0]
        if expire_time <= current_time:
            return -1
        return int(expire_time - current_time)

    def cleanup(self) -> int:
        """Delete all expired entries and return how many were removed."""
        with self._lock:
            current_time = time.time()
            with self._connect() as conn:
                cursor = conn.execute(
                    'DELETE FROM cache WHERE expire_time <= ?', (current_time,)
                )
                return cursor.rowcount

    def clear(self) -> None:
        """Delete every entry, expired or not."""
        with self._lock:
            with self._connect() as conn:
                conn.execute('DELETE FROM cache')

    def size(self) -> int:
        """Return the number of stored rows, including expired but not yet purged ones."""
        with self._lock:
            with self._connect() as conn:
                return conn.execute('SELECT COUNT(*) FROM cache').fetchone()[0]

    def keys(self) -> list:
        """Return the keys of all unexpired entries."""
        with self._lock:
            current_time = time.time()
            with self._connect() as conn:
                cursor = conn.execute(
                    'SELECT key FROM cache WHERE expire_time > ?', (current_time,)
                )
                return [row[0] for row in cursor.fetchall()]

    def stats(self) -> dict:
        """Return counts of total, valid (unexpired) and expired rows.

        A single timestamp is used so ``valid + expired == total``.
        """
        with self._lock:
            now = time.time()
            with self._connect() as conn:
                row = conn.execute(
                    '''
                    SELECT
                        COUNT(*) as total,
                        COUNT(CASE WHEN expire_time > ? THEN 1 END) as valid,
                        COUNT(CASE WHEN expire_time <= ? THEN 1 END) as expired
                    FROM cache
                    ''',
                    (now, now),
                ).fetchone()
        return {'total': row[0], 'valid': row[1], 'expired': row[2]}
# Process-wide cache instance; stays None until get_default_cache() creates it.
_default_cache: "Optional[KVCache]" = None
def get_default_cache() -> KVCache:
    """Return the shared KVCache stored at CACHE_FILE, creating it on first call.

    NOTE(review): the lazy creation is not guarded by a lock, so two threads
    racing on the very first call could each build a KVCache; the last
    assignment wins.
    """
    global _default_cache
    instance = _default_cache
    if instance is None:
        instance = KVCache(str(CACHE_FILE))
        _default_cache = instance
    return instance
def cache_key(*args, **kwargs) -> str:
    """Return an MD5 hex digest identifying a set of call arguments.

    The digest covers ``str(args)`` followed by ``str()`` of the kwargs items
    sorted by name, so keyword order does not matter. Arguments whose ``str()``
    embeds an object id produce keys that differ from process to process.
    """
    positional_part = str(args)
    keyword_part = str(sorted(kwargs.items()))
    digest = hashlib.md5("".join((positional_part, keyword_part)).encode())
    return digest.hexdigest()
def cached(
    ttl: int = 3600,
    key_func: Optional[Callable] = None,
    cache_instance: Optional["KVCache"] = None,
):
    """Decorator factory that memoizes a function's results in a KVCache.

    Args:
        ttl: Lifetime in seconds of each cached result.
        key_func: Optional callable that receives the decorated function's
            arguments and returns the cache key. When omitted, the key is
            ``"<module>.<qualname-less name>:<cache_key(*args, **kwargs)>"``.
        cache_instance: Cache to use. ``None`` means ``get_default_cache()``,
            resolved on every call rather than at decoration time.

    The returned wrapper also exposes:
        ``cache_clear()``: clears the WHOLE underlying cache (every key, not
            only this function's entries).
        ``cache_info()``: returns the underlying cache's ``stats()``.

    Notes:
        * Any stored value counts as a hit, including ``None`` and other falsy
          results.
        * Results are persisted via ``set``. For a KVCache that JSON-encodes
          them, so a hit returns the JSON round-trip (e.g. tuples as lists), and
          an unserializable result raises ``ValueError`` after the function has
          already run.
    """
    # Unique sentinel so a cached ``None`` can be told apart from a miss.
    miss = object()

    def resolve_cache():
        # Explicit ``is not None`` check: never fall back just because the
        # supplied cache object happens to be falsy.
        return cache_instance if cache_instance is not None else get_default_cache()

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            cache = resolve_cache()

            if key_func:
                key = key_func(*args, **kwargs)
            else:
                func_name = f"{func.__module__}.{func.__name__}"
                key = f"{func_name}:{cache_key(*args, **kwargs)}"

            result = cache.get(key, miss)
            if result is not miss:
                return result

            result = func(*args, **kwargs)
            cache.set(key, result, ttl)
            return result

        wrapper.cache_clear = lambda: resolve_cache().clear()
        wrapper.cache_info = lambda: resolve_cache().stats()
        return wrapper

    return decorator
# Convenience wrappers that operate on the shared default cache instance.
def set_cache(key: str, value: Any, ttl: Optional[int] = None) -> None:
    """Store ``value`` under ``key`` in the default cache via KVCache.set."""
    cache = get_default_cache()
    cache.set(key, value, ttl)
def get_cache(key: str, default: Any = None) -> Any:
    """Look up ``key`` in the default cache via KVCache.get, returning ``default`` on a miss."""
    cache = get_default_cache()
    value = cache.get(key, default)
    return value
def delete_cache(key: str) -> bool:
    """Remove ``key`` from the default cache via KVCache.delete; True if a row was removed."""
    cache = get_default_cache()
    removed = cache.delete(key)
    return removed
def clear_cache() -> None:
    """Delete every entry in the default cache via KVCache.clear."""
    cache = get_default_cache()
    cache.clear()
def cache_exists(key: str) -> bool:
    """Delegate to KVCache.exists on the default cache."""
    cache = get_default_cache()
    present = cache.exists(key)
    return present
def cache_ttl(key: str) -> int:
    """Return the remaining lifetime of ``key`` in the default cache via KVCache.ttl."""
    cache = get_default_cache()
    remaining = cache.ttl(key)
    return remaining
def cleanup_cache() -> int:
    """Purge expired entries from the default cache; return the number removed."""
    cache = get_default_cache()
    purged = cache.cleanup()
    return purged
def cache_stats() -> dict:
    """Return the default cache's total/valid/expired counts via KVCache.stats."""
    cache = get_default_cache()
    summary = cache.stats()
    return summary
# Best-effort housekeeping at import time: purge expired entries from the
# default cache. This is optional maintenance, so an SQLite failure here
# (e.g. the database file cannot be opened or is locked) must not make
# importing this module fail. Such errors are deliberately ignored; the next
# explicit cache operation will surface them.
try:
    cleanup_cache()
except sqlite3.Error:
    pass