Coverage for aipyapp/aipy/context_manager.py: 23%

317 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 13:03 +0200

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import time 

5from collections import Counter 

6from dataclasses import dataclass 

7from typing import List, Dict, Any, Optional, Union, Tuple 

8from enum import Enum 

9 

10from loguru import logger 

11 

12from ..llm import ChatMessage 

13from ..interface import Trackable 

14 

15 

class ContextStrategy(Enum):
    """Context management strategy."""

    # Sliding window
    SLIDING_WINDOW = "sliding_window"
    # Importance filtering
    IMPORTANCE_FILTER = "importance_filter"
    # Summary compression
    SUMMARY_COMPRESSION = "summary_compression"
    # Hybrid strategy
    HYBRID = "hybrid"

22 

23 

@dataclass
class ContextConfig:
    """Configuration for context management."""

    max_tokens: int = 100000  # maximum number of tokens
    max_rounds: int = 10  # maximum number of conversation rounds
    auto_compress: bool = True  # whether to compress automatically
    strategy: ContextStrategy = ContextStrategy.HYBRID
    compression_ratio: float = 0.3  # compression ratio
    importance_threshold: float = 0.5  # importance threshold
    summary_max_length: int = 200  # maximum summary length
    preserve_system: bool = True  # whether to keep system messages
    preserve_recent: int = 3  # number of most recent rounds to keep

    def set_strategy(self, strategy: str) -> Union[ContextStrategy, None]:
        """Switch to the strategy identified by *strategy*.

        Args:
            strategy: A ``ContextStrategy`` value such as ``"hybrid"`` (an
                existing ``ContextStrategy`` member is accepted as well).

        Returns:
            The resolved ``ContextStrategy``, or ``None`` when the value is
            not recognised; in that case the current strategy is unchanged.
        """
        try:
            resolved = ContextStrategy(strategy)
        except ValueError:
            return None
        self.strategy = resolved
        return resolved

    @classmethod
    def from_dict(cls, data: dict) -> 'ContextConfig':
        """Build a configuration from a plain dict of field values.

        The ``strategy`` entry is resolved through ``set_strategy`` rather
        than being handed to the constructor, so an unrecognised value falls
        back to the default strategy instead of leaving a raw string on the
        instance. *data* itself is not modified.

        Raises:
            TypeError: If *data* contains keys that are not config fields.
        """
        field_values = dict(data)
        requested = field_values.pop('strategy', 'hybrid')
        obj = cls(**field_values)
        obj.set_strategy(requested)
        return obj

    def to_dict(self) -> dict:
        """Return the field values as a new dict.

        A shallow copy is returned so that callers cannot alter the
        configuration by mutating the result.
        """
        return dict(self.__dict__)

53 

class TokenCounter:
    """Token counter."""

    def __init__(self):
        self.total_tokens = self.input_tokens = self.output_tokens = 0

    def add_message(self, message: ChatMessage):
        """Add the token usage reported on *message* to the running totals.

        Messages without a ``usage`` attribute, or with an empty one, are
        ignored.
        """
        usage = getattr(message, 'usage', None)
        if not usage:
            return
        self.total_tokens += usage.get('total_tokens', 0)
        self.input_tokens += usage.get('input_tokens', 0)
        self.output_tokens += usage.get('output_tokens', 0)

    def estimate_tokens(self, text: str) -> int:
        """Estimate the token count of *text* (rough rule: 1 token ~ 4 characters)."""
        char_count = len(text)
        return char_count // 4

    def reset(self):
        """Reset all counters to zero."""
        self.total_tokens = self.input_tokens = self.output_tokens = 0

78 

class ChatHistory:
    """Ordered list of chat messages plus the sum of their token usage."""

    def __init__(self):
        self.messages = []
        self._total_tokens = Counter()

    def __len__(self):
        return len(self.messages)

    def _accumulate_usage(self, message):
        """Add one message's usage to the running totals, if it has any."""
        usage = getattr(message, 'usage', None)
        # Messages may carry no usage at all (TokenCounter guards against the
        # same case); skipping them avoids a crash on ``Counter += None``.
        if usage:
            self._total_tokens += usage

    def get_state(self):
        """Return the data that needs to be persisted.

        Returns:
            One dict per message holding that message's attributes. Each dict
            is a shallow copy, so editing it does not alter the message.
        """
        return [dict(msg.__dict__) for msg in self.messages]

    def restore_state(self, state_data):
        """Replace the current history with the messages in *state_data*.

        Args:
            state_data: Iterable of dicts with ``role`` and ``content`` keys
                and optional ``reason`` / ``usage`` keys. A falsy value just
                empties the history.
        """
        self.messages.clear()
        self._total_tokens = Counter()

        for chat_item in state_data or ():
            message = ChatMessage(
                role=chat_item['role'],
                content=chat_item['content'],
                reason=chat_item.get('reason'),
                usage=Counter(chat_item.get('usage') or {}),
            )
            self.add_message(message)

    def clear(self):
        """Remove all messages and reset the usage totals."""
        self.messages.clear()
        self._total_tokens.clear()

    def delete_range(self, start_index, end_index):
        """Delete the messages in ``[start_index, end_index)``.

        An invalid range (negative start, end past the last message, or an
        empty/inverted range) is ignored.
        """
        if start_index < 0 or end_index > len(self.messages) or start_index >= end_index:
            return

        self.messages = self.messages[:start_index] + self.messages[end_index:]

        # Recompute the usage totals from the messages that remain.
        self._total_tokens = Counter()
        for msg in self.messages:
            self._accumulate_usage(msg)

    def add(self, role, content):
        """Append a new message built from *role* and *content*."""
        self.add_message(ChatMessage(role=role, content=content))

    def add_message(self, message: 'ChatMessage'):
        """Append *message* and add its usage to the totals."""
        self.messages.append(message)
        self._accumulate_usage(message)

    def get_usage(self):
        """Return an iterator over the usage of every assistant message."""
        return (row.usage for row in self.messages if row.role == "assistant")

    def get_summary(self):
        """Return the usage totals plus the number of assistant replies.

        The ``time``, ``input_tokens``, ``output_tokens`` and
        ``total_tokens`` keys are always present and default to 0.
        """
        summary = {'time': 0, 'input_tokens': 0, 'output_tokens': 0, 'total_tokens': 0}
        summary.update(dict(self._total_tokens))
        summary['rounds'] = sum(1 for row in self.messages if row.role == "assistant")
        return summary

    def get_messages(self):
        """Return the messages as ``{'role': ..., 'content': ...}`` dicts."""
        return [{"role": msg.role, "content": msg.content} for msg in self.messages]

148 

class MessageCompressor:
    """Message compressor.

    Trims a list of ``{'role': ..., 'content': ...}`` message dicts so that
    its estimated token count fits ``config.max_tokens``, using the strategy
    selected by ``config.strategy``.
    """

    def __init__(self, config: 'ContextConfig'):
        self.config = config
        self.log = logger.bind(src='message_compressor')

    def compress_messages(self, messages: List[Dict[str, Any]],
                          current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Compress *messages* according to the configured strategy.

        Args:
            messages: Message dicts in chronological order.
            current_tokens: Token estimate for *messages*.

        Returns:
            ``(messages, tokens)``. The input is returned unchanged when
            *current_tokens* is within ``config.max_tokens`` or when the
            strategy is not one of the known ``ContextStrategy`` members.
        """
        if current_tokens <= self.config.max_tokens:
            return messages, current_tokens

        handlers = {
            ContextStrategy.SLIDING_WINDOW: self._sliding_window_compress,
            ContextStrategy.IMPORTANCE_FILTER: self._importance_filter_compress,
            ContextStrategy.SUMMARY_COMPRESSION: self._summary_compression_compress,
            ContextStrategy.HYBRID: self._hybrid_compress,
        }
        handler = handlers.get(self.config.strategy)
        if handler is None:
            return messages, current_tokens
        return handler(messages, current_tokens)

    def _recent_window(self, conversation: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the last ``preserve_recent`` rounds of *conversation*."""
        window = self.config.preserve_recent * 2  # one user + one assistant message per round
        # ``conversation[-0:]`` would select the whole list, so a non-positive
        # window has to be handled explicitly.
        return conversation[-window:] if window > 0 else []

    def _fit_newest(self, candidates: List[Dict[str, Any]],
                    used_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Keep as many of the newest *candidates* as the token budget allows.

        Walks from the newest message backwards so that, when not everything
        fits, it is the oldest messages that are dropped. The kept messages
        are returned in chronological order with the updated token total.
        """
        kept = []
        for msg in reversed(candidates):
            cost = self._estimate_message_tokens(msg)
            if used_tokens + cost > self.config.max_tokens:
                break
            kept.append(msg)
            used_tokens += cost
        kept.reverse()
        return kept, used_tokens

    def _sliding_window_compress(self, messages: List[Dict[str, Any]],
                                 current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Sliding-window compression.

        Keeps every system message, followed by the most recent conversation
        messages that fit in the remaining token budget.
        """
        system_messages = [msg for msg in messages if msg['role'] == 'system']
        preserved_tokens = sum(self._estimate_message_tokens(msg) for msg in system_messages)

        conversation = [msg for msg in messages if msg['role'] != 'system']
        recent, preserved_tokens = self._fit_newest(
            self._recent_window(conversation), preserved_tokens
        )
        preserved_messages = system_messages + recent

        self.log.info(f"Sliding window compression: {len(messages)} -> {len(preserved_messages)} messages")
        return preserved_messages, preserved_tokens

    def _importance_filter_compress(self, messages: List[Dict[str, Any]],
                                    current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Importance-filter compression.

        Ranks messages by importance score, keeps them in rank order until
        one does not fit the token budget, and returns the kept messages in
        their original order.
        """
        total = len(messages)
        # Rank positions rather than dicts so the original order can be
        # restored without searching the list for each kept message.
        ranked = sorted(
            range(total),
            key=lambda i: self._calculate_importance_score(messages[i], i, total),
            reverse=True,
        )

        kept_positions = []
        preserved_tokens = 0
        for position in ranked:
            msg_tokens = self._estimate_message_tokens(messages[position])
            if preserved_tokens + msg_tokens > self.config.max_tokens:
                break
            kept_positions.append(position)
            preserved_tokens += msg_tokens

        kept_positions.sort()
        preserved_messages = [messages[position] for position in kept_positions]

        self.log.info(f"Importance filter compression: {len(messages)} -> {len(preserved_messages)} messages")
        return preserved_messages, preserved_tokens

    def _summary_compression_compress(self, messages: List[Dict[str, Any]],
                                      current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Summary compression.

        Keeps every system message, replaces the conversation messages older
        than the recent window with one system-role summary message, then
        appends the most recent messages that fit the token budget.
        """
        preserved_messages = [msg for msg in messages if msg['role'] == 'system']
        preserved_tokens = sum(self._estimate_message_tokens(msg) for msg in preserved_messages)

        conversation = [msg for msg in messages if msg['role'] != 'system']
        recent_window = self._recent_window(conversation)
        older_messages = conversation[:len(conversation) - len(recent_window)]

        if older_messages:
            summary_msg = {
                'role': 'system',
                'content': f"对话历史摘要:{self._create_summary(older_messages)}"
            }
            preserved_messages.append(summary_msg)
            preserved_tokens += self._estimate_message_tokens(summary_msg)

        recent, preserved_tokens = self._fit_newest(recent_window, preserved_tokens)
        preserved_messages.extend(recent)

        self.log.info(f"Summary compression: {len(messages)} -> {len(preserved_messages)} messages")
        return preserved_messages, preserved_tokens

    def _hybrid_compress(self, messages: List[Dict[str, Any]],
                         current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Hybrid strategy: sliding window first, summary compression as fallback."""
        compressed_messages, compressed_tokens = self._sliding_window_compress(messages, current_tokens)

        # NOTE(review): the sliding window only adds conversation messages
        # that fit the budget, so this fallback is reached only when the
        # system messages alone exceed max_tokens.
        if compressed_tokens > self.config.max_tokens:
            compressed_messages, compressed_tokens = self._summary_compression_compress(messages, current_tokens)

        return compressed_messages, compressed_tokens

    def _calculate_importance_score(self, message: Dict[str, Any], index: int, total: int) -> float:
        """Score a message by role, recency and content length.

        Args:
            message: The message dict.
            index: Zero-based position of the message in the list.
            total: Number of messages in the list (must be non-zero).
        """
        score = 0.0

        # Role weight: system > user > assistant; other roles add nothing.
        role = message.get('role', '')
        if role == 'system':
            score += 1.0
        elif role == 'user':
            score += 0.8
        elif role == 'assistant':
            score += 0.6

        # Recency weight: later positions score higher.
        recency = (index + 1) / total
        score += recency * 0.4

        # Length weight, normalised to 0-1 and capped at 1000 characters.
        content = message.get('content', '')
        if isinstance(content, str):
            length_score = min(len(content) / 1000, 1.0)
            score += length_score * 0.2

        return score

    def _create_summary(self, messages: List[Dict[str, Any]]) -> str:
        """Build a short textual summary of *messages*.

        Each message with string content contributes ``"role: preview"``
        using its first 100 characters. The previews are joined with
        ``" | "`` and the result is cut to ``config.summary_max_length``.
        """
        summary_parts = []

        for msg in messages:
            role = msg.get('role', '')
            content = msg.get('content', '')

            if isinstance(content, str):
                content_preview = content[:100] + ('...' if len(content) > 100 else '')
                summary_parts.append(f"{role}: {content_preview}")

        summary = " | ".join(summary_parts)
        return summary[:self.config.summary_max_length]

    def _estimate_message_tokens(self, message: Dict[str, Any]) -> int:
        """Estimate the token count of a message (about 4 characters per token).

        String content is measured directly. For list (multimodal) content
        only the ``text`` of dict items whose ``type`` is ``'text'`` is
        counted; other items are ignored. Any other content type counts as 0.
        """
        content = message.get('content', '')
        if isinstance(content, str):
            return len(content) // 4
        if isinstance(content, list):
            text_length = sum(
                len(item.get('text') or '')
                for item in content
                if isinstance(item, dict) and item.get('type') == 'text'
            )
            return text_length // 4
        return 0

339 

340 

class ContextManager(Trackable):
    """Context manager.

    Keeps the complete chat history alongside a working message cache (the
    list handed out by ``get_messages``) that may be compressed to respect
    the configured token budget.
    """

    def __init__(self, config: Optional[ContextConfig] = None):
        self.config = config or ContextConfig()
        self.token_counter = TokenCounter()
        self.compressor = MessageCompressor(self.config)
        self.log = logger.bind(src='context_manager')

        # Original, uncompressed message history.
        self.chat_history = ChatHistory()

        # Working cache of message dicts and its estimated token count.
        self._messages_cache: List[Dict[str, Any]] = []
        self._cached_tokens = 0
        self._last_compression_time = 0

    def __len__(self):
        """Return the number of messages in the history (legacy interface)."""
        return len(self.chat_history)

    @staticmethod
    def _to_cache_entry(message) -> Dict[str, Any]:
        """Convert a chat message into the dict form stored in the cache."""
        return {'role': message.role, 'content': message.content}

    def add(self, role, content):
        """Add a simple message built from *role* and *content* (legacy interface)."""
        message = ChatMessage(role=role, content=content)
        self.add_message(message)

    def add_message(self, message: ChatMessage):
        """Add *message* to the history, the token counter and the cache."""
        self.chat_history.add_message(message)

        msg_dict = self._to_cache_entry(message)
        self.token_counter.add_message(message)

        self._messages_cache.append(msg_dict)
        self._cached_tokens += self.compressor._estimate_message_tokens(msg_dict)

        self.log.debug(f"Added message: {message.role}, tokens: {self._cached_tokens}")

    def _compression_due(self) -> bool:
        """Tell whether any automatic compression trigger has fired."""
        return (
            self._cached_tokens > self.config.max_tokens
            or len(self._messages_cache) > self.config.max_rounds * 2
            # Re-run compression when the last one is more than 5 minutes old.
            or (time.time() - self._last_compression_time) > 300
        )

    def get_messages(self, force_compress: bool = False) -> List[Dict[str, Any]]:
        """Return the (possibly compressed) list of cached messages.

        Args:
            force_compress: Run the compressor regardless of the automatic
                triggers and of ``config.auto_compress``.

        Returns:
            A shallow copy of the message cache.
        """
        if force_compress or (self.config.auto_compress and self._compression_due()):
            self._compress_messages()

        return self._messages_cache.copy()

    def get_usage(self):
        """Return the usage statistics of the chat history."""
        return self.chat_history.get_usage()

    def get_summary(self):
        """Return the summary statistics of the chat history."""
        return self.chat_history.get_summary()

    def json(self):
        """Return the original messages in serialisable form."""
        return self.chat_history.get_state()

    def _compress_messages(self):
        """Run the compressor over the cache and store the result."""
        if not self._messages_cache:
            return

        original_count = len(self._messages_cache)
        original_tokens = self._cached_tokens

        # NOTE: compress_messages returns its input unchanged while the token
        # estimate is within max_tokens, so a run triggered by the message
        # count or the age check may leave the cache as it was.
        compressed_messages, compressed_tokens = self.compressor.compress_messages(
            self._messages_cache, self._cached_tokens
        )

        self._messages_cache = compressed_messages
        self._cached_tokens = compressed_tokens
        self._last_compression_time = time.time()

        self.log.info(
            f"Context compressed: {original_count}->{len(compressed_messages)} messages, "
            f"{original_tokens}->{compressed_tokens} tokens"
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics about the current context.

        ``compression_ratio`` is the number of cached messages divided by the
        number of messages in the full history (0.0 when both are empty).
        """
        cached_count = len(self._messages_cache)
        return {
            'message_count': cached_count,
            'total_tokens': self._cached_tokens,
            'max_tokens': self.config.max_tokens,
            'compression_ratio': cached_count / max(len(self.chat_history), 1),
            'last_compression': self._last_compression_time
        }

    def _clear_messages_cache(self):
        """Trim the cache to its first two messages and its last message.

        Caches holding three messages or fewer are left unchanged.
        """
        if len(self._messages_cache) <= 3:
            return

        self._messages_cache = self._messages_cache[:2] + self._messages_cache[-1:]

        # Recompute the token estimate for what is left.
        self._cached_tokens = sum(self.compressor._estimate_message_tokens(msg) for msg in self._messages_cache)
        self.log.info(f"Context cleaned: {len(self._messages_cache)} messages, {self._cached_tokens} tokens")

    def clear(self):
        """Reset the working context.

        Trims the message cache (see ``_clear_messages_cache``), resets the
        token counter and the compression timestamp. The chat history itself
        is left untouched.
        """
        self._clear_messages_cache()
        self.token_counter.reset()
        self._last_compression_time = 0
        self.log.info("Context cleared")

    def _rebuild_cache(self):
        """Rebuild the cache and the token counter from the full chat history."""
        self._messages_cache.clear()
        self._cached_tokens = 0
        self.token_counter.reset()

        for message in self.chat_history.messages:
            msg_dict = self._to_cache_entry(message)
            self._messages_cache.append(msg_dict)
            self._cached_tokens += self.compressor._estimate_message_tokens(msg_dict)
            self.token_counter.add_message(message)

    def update_config(self, config: ContextConfig):
        """Replace the configuration and build a matching compressor."""
        self.config = config
        self.compressor = MessageCompressor(config)
        # Tolerate a strategy that is not an enum member when logging.
        strategy_name = getattr(config.strategy, 'value', config.strategy)
        self.log.info(f"Context config updated: {strategy_name}")

    def get_state(self):
        """Return the data that needs to be persisted."""
        return {
            'config': {
                'max_tokens': self.config.max_tokens,
                'max_rounds': self.config.max_rounds,
                'auto_compress': self.config.auto_compress,
                'strategy': self.config.strategy.value,
                'compression_ratio': self.config.compression_ratio,
                'importance_threshold': self.config.importance_threshold,
                'summary_max_length': self.config.summary_max_length,
                'preserve_system': self.config.preserve_system,
                'preserve_recent': self.config.preserve_recent,
            },
            'chat_history': self.chat_history.get_state(),
            'messages_cache': self._messages_cache.copy(),
            'cached_tokens': self._cached_tokens,
            'token_counter': {
                'total_tokens': self.token_counter.total_tokens,
                'input_tokens': self.token_counter.input_tokens,
                'output_tokens': self.token_counter.output_tokens,
            },
            'last_compression_time': self._last_compression_time,
        }

    def restore_state(self, state_data):
        """Restore the manager from data produced by ``get_state``.

        A falsy *state_data* is ignored. An unrecognised persisted strategy
        no longer aborts the restore: the default strategy is kept and a
        warning is logged.
        """
        if not state_data:
            return

        if 'config' in state_data:
            config_data = state_data['config']
            # NOTE(review): auto_compress falls back to False here although
            # the ContextConfig default is True; kept as found.
            self.config = ContextConfig(
                max_tokens=config_data.get('max_tokens', 100000),
                max_rounds=config_data.get('max_rounds', 10),
                auto_compress=config_data.get('auto_compress', False),
                compression_ratio=config_data.get('compression_ratio', 0.3),
                importance_threshold=config_data.get('importance_threshold', 0.5),
                summary_max_length=config_data.get('summary_max_length', 200),
                preserve_system=config_data.get('preserve_system', True),
                preserve_recent=config_data.get('preserve_recent', 3),
            )
            requested_strategy = config_data.get('strategy', 'hybrid')
            if self.config.set_strategy(requested_strategy) is None:
                self.log.warning(f"Unknown context strategy in saved state: {requested_strategy!r}")
            self.compressor = MessageCompressor(self.config)

        # Original chat history.
        if 'chat_history' in state_data:
            self.chat_history.restore_state(state_data['chat_history'])

        # Message cache and its token estimate.
        self._messages_cache = state_data.get('messages_cache', []).copy()
        self._cached_tokens = state_data.get('cached_tokens', 0)

        # Token counter.
        if 'token_counter' in state_data:
            token_data = state_data['token_counter']
            self.token_counter.total_tokens = token_data.get('total_tokens', 0)
            self.token_counter.input_tokens = token_data.get('input_tokens', 0)
            self.token_counter.output_tokens = token_data.get('output_tokens', 0)

        self._last_compression_time = state_data.get('last_compression_time', 0)

        self.log.info(f"Context manager state restored: {len(self._messages_cache)} messages, {self._cached_tokens} tokens")

    # Trackable interface implementation
    def get_checkpoint(self) -> int:
        """Return the current checkpoint: the number of messages in the history."""
        return len(self.chat_history)

    def restore_to_checkpoint(self, checkpoint: Optional[int]):
        """Restore the context to *checkpoint*.

        Args:
            checkpoint: Message count to roll back to. ``None`` calls
                ``clear()``. A value not smaller than the current history
                length leaves everything unchanged.
        """
        if checkpoint is None:
            self.clear()
            return

        if checkpoint < len(self.chat_history):
            # Drop every message after the checkpoint, then rebuild the cache
            # from the remaining history.
            self.chat_history.delete_range(checkpoint, len(self.chat_history))
            self._rebuild_cache()