Coverage for aipyapp/aipy/context_manager.py: 23%
317 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 13:03 +0200
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import time
5from collections import Counter
6from dataclasses import dataclass
7from typing import List, Dict, Any, Optional, Union, Tuple
8from enum import Enum
10from loguru import logger
12from ..llm import ChatMessage
13from ..interface import Trackable
class ContextStrategy(Enum):
    """How the conversation context is shrunk once it grows too large.

    Members are looked up by their string value, e.g. ``ContextStrategy("hybrid")``;
    an unknown value raises ``ValueError``.
    """

    # Sliding window over the most recent messages
    SLIDING_WINDOW = "sliding_window"
    # Keep the messages that score as most important
    IMPORTANCE_FILTER = "importance_filter"
    # Replace older messages with a short summary
    SUMMARY_COMPRESSION = "summary_compression"
    # Combination of the strategies above
    HYBRID = "hybrid"
@dataclass
class ContextConfig:
    """Settings that control how the conversation context is managed and compressed."""

    max_tokens: int = 100000  # estimated-token budget; the compressor leaves messages alone at or below it
    max_rounds: int = 10  # max conversation rounds (ContextManager compares its message count to max_rounds * 2)
    auto_compress: bool = True  # let ContextManager.get_messages() compress automatically
    strategy: ContextStrategy = ContextStrategy.HYBRID  # compression strategy
    compression_ratio: float = 0.3  # compression ratio -- NOTE(review): not read elsewhere in the visible module
    importance_threshold: float = 0.5  # importance threshold -- NOTE(review): not read elsewhere in the visible module
    summary_max_length: int = 200  # max length (characters) of a generated history summary
    preserve_system: bool = True  # keep system messages -- NOTE(review): not consulted by the compressor in this module
    preserve_recent: int = 3  # number of most recent rounds (user + assistant pairs) to keep

    def set_strategy(self, strategy: Union[str, ContextStrategy]) -> Union[ContextStrategy, None]:
        """Set ``self.strategy`` from a strategy value string (or an existing member).

        Returns:
            The new ``ContextStrategy``, or ``None`` when *strategy* is not a valid
            value -- in that case ``self.strategy`` is left unchanged.
        """
        try:
            parsed = ContextStrategy(strategy)
        except ValueError:
            return None
        self.strategy = parsed
        return parsed

    @classmethod
    def from_dict(cls, data: dict) -> 'ContextConfig':
        """Build a config from a plain dict such as the one returned by :meth:`to_dict`.

        ``'strategy'`` may be a value string or a ``ContextStrategy`` member. An
        unrecognised value falls back to the default (HYBRID) rather than being
        stored verbatim as a string. Keys that are not config fields are skipped
        with a warning instead of making the constructor raise ``TypeError``.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        unknown = sorted(str(k) for k in data if k not in known)
        if unknown:
            logger.warning(f"Ignoring unknown context config keys: {unknown}")

        # 'strategy' is applied separately so that it is always parsed into an enum member.
        kwargs = {k: v for k, v in data.items() if k in known and k != 'strategy'}
        obj = cls(**kwargs)
        obj.set_strategy(data.get('strategy', 'hybrid'))
        return obj

    def to_dict(self) -> dict:
        """Return the fields as a new dict (``strategy`` stays a ``ContextStrategy`` member).

        A copy is returned so that callers editing the dict cannot silently mutate
        this config object.
        """
        return dict(self.__dict__)
class TokenCounter:
    """Running totals of the token usage reported on chat messages."""

    def __init__(self):
        self.reset()

    def add_message(self, message: ChatMessage):
        """Add the ``usage`` figures carried by *message* to the running totals.

        Messages without a ``usage`` attribute, or with an empty/falsy one, are
        ignored; keys missing from ``usage`` count as zero.
        """
        usage = getattr(message, 'usage', None)
        if not usage:
            return
        self.total_tokens += usage.get('total_tokens', 0)
        self.input_tokens += usage.get('input_tokens', 0)
        self.output_tokens += usage.get('output_tokens', 0)

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of *text* (about 4 characters per token)."""
        char_count = len(text)
        return char_count // 4

    def reset(self):
        """Zero all counters."""
        self.total_tokens = self.input_tokens = self.output_tokens = 0
class ChatHistory:
    """Ordered list of ChatMessage objects plus their accumulated token usage."""

    def __init__(self):
        self.messages = []
        # Running sum of every message's ``usage`` mapping
        self._total_tokens = Counter()

    def __len__(self):
        return len(self.messages)

    def get_state(self):
        """Return the persistable state: each message's own ``__dict__`` (not a copy), in order."""
        return [message.__dict__ for message in self.messages]

    def restore_state(self, state_data):
        """Replace the history with messages rebuilt from *state_data*.

        *state_data* is a sequence of dicts like those from :meth:`get_state`; each
        needs ``'role'`` and ``'content'`` and may carry ``'reason'`` and ``'usage'``.
        A falsy *state_data* leaves the history empty.
        """
        self.messages.clear()
        self._total_tokens = Counter()

        for item in state_data or ():
            usage = Counter(item.get('usage', {}))
            self.add_message(ChatMessage(
                role=item['role'],
                content=item['content'],
                reason=item.get('reason'),
                usage=usage,
            ))

    def clear(self):
        """Drop every message and zero the usage totals."""
        self.messages.clear()
        self._total_tokens.clear()

    def delete_range(self, start_index, end_index):
        """Remove ``messages[start_index:end_index]`` and recompute the usage totals.

        Out-of-bounds or empty ranges are ignored without error.
        """
        if not (0 <= start_index < end_index <= len(self.messages)):
            return

        head, tail = self.messages[:start_index], self.messages[end_index:]
        self.messages = head + tail

        self._total_tokens = Counter()
        for message in self.messages:
            self._total_tokens += message.usage

    def add(self, role, content):
        """Append a new message built from *role* and *content*."""
        self.add_message(ChatMessage(role=role, content=content))

    def add_message(self, message: ChatMessage):
        """Append *message* and fold its ``usage`` into the totals."""
        self.messages.append(message)
        self._total_tokens += message.usage

    def get_usage(self):
        """Return a lazy iterator over the ``usage`` of each assistant message."""
        return (message.usage for message in self.messages if message.role == "assistant")

    def get_summary(self):
        """Return accumulated usage (zero defaults) plus ``rounds`` = number of assistant messages."""
        summary = dict(time=0, input_tokens=0, output_tokens=0, total_tokens=0)
        summary.update(self._total_tokens)
        summary['rounds'] = sum(message.role == "assistant" for message in self.messages)
        return summary

    def get_messages(self):
        """Return the history as plain ``{'role', 'content'}`` dicts."""
        return [{"role": m.role, "content": m.content} for m in self.messages]
class MessageCompressor:
    """Shrink a list of chat-message dicts so that it fits the configured token budget.

    Each message is a dict with a ``'role'`` and a ``'content'`` that is either a
    string or a list of multimodal parts (only ``{'type': 'text', 'text': ...}``
    parts are counted). Token counts are rough estimates of ~4 characters per token.
    """

    def __init__(self, config: ContextConfig):
        self.config = config
        self.log = logger.bind(src='message_compressor')

    def compress_messages(self, messages: List[Dict[str, Any]],
                          current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Compress *messages* with the configured strategy.

        Args:
            messages: the conversation in chronological order.
            current_tokens: estimated token count of *messages*.

        Returns:
            ``(messages, tokens)``. The input is returned unchanged when
            *current_tokens* is within ``config.max_tokens`` or when the strategy
            is not recognised.
        """
        if current_tokens <= self.config.max_tokens:
            return messages, current_tokens

        strategy = self.config.strategy
        if strategy == ContextStrategy.SLIDING_WINDOW:
            return self._sliding_window_compress(messages, current_tokens)
        if strategy == ContextStrategy.IMPORTANCE_FILTER:
            return self._importance_filter_compress(messages, current_tokens)
        if strategy == ContextStrategy.SUMMARY_COMPRESSION:
            return self._summary_compression_compress(messages, current_tokens)
        if strategy == ContextStrategy.HYBRID:
            return self._hybrid_compress(messages, current_tokens)
        return messages, current_tokens

    @staticmethod
    def _split_system(messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Split *messages* into (system messages, all other messages), both in original order."""
        system = [msg for msg in messages if msg['role'] == 'system']
        others = [msg for msg in messages if msg['role'] != 'system']
        return system, others

    def _recent_window(self, conversation: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the last ``preserve_recent`` rounds (two messages per round) of *conversation*.

        Unlike ``conversation[-n:]``, this yields an empty list when n == 0
        instead of the whole conversation.
        """
        size = max(self.config.preserve_recent, 0) * 2
        return conversation[max(len(conversation) - size, 0):]

    def _fit_newest_first(self, candidates: List[Dict[str, Any]],
                          used_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Keep the longest tail of *candidates* that fits in the remaining budget.

        Walks from the newest message backwards and stops at the first one that
        would push ``used_tokens`` + kept tokens past ``config.max_tokens``, so the
        latest messages are never dropped in favour of older ones.

        Returns:
            ``(kept, kept_tokens)`` with *kept* in chronological order.
        """
        kept = []
        kept_tokens = 0
        for msg in reversed(candidates):
            msg_tokens = self._estimate_message_tokens(msg)
            if used_tokens + kept_tokens + msg_tokens > self.config.max_tokens:
                break
            kept.append(msg)
            kept_tokens += msg_tokens
        kept.reverse()
        return kept, kept_tokens

    def _sliding_window_compress(self, messages: List[Dict[str, Any]],
                                 current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Sliding window: keep every system message plus the most recent rounds.

        System messages are always kept (even if they alone exceed the budget) and
        placed first. Within the recent window the newest messages win the budget.
        *current_tokens* is unused; it is accepted for a uniform strategy signature.
        """
        system_messages, conversation = self._split_system(messages)
        preserved_tokens = sum(self._estimate_message_tokens(msg) for msg in system_messages)

        recent, recent_tokens = self._fit_newest_first(self._recent_window(conversation), preserved_tokens)
        preserved_messages = system_messages + recent
        preserved_tokens += recent_tokens

        self.log.info(f"Sliding window compression: {len(messages)} -> {len(preserved_messages)} messages")
        return preserved_messages, preserved_tokens

    def _importance_filter_compress(self, messages: List[Dict[str, Any]],
                                    current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Importance filter: keep the highest-scoring messages that fit the budget.

        Messages are taken in descending score order (ties keep their original
        order) until the first one that does not fit; survivors are returned in
        their original chronological order. *current_tokens* is unused.
        """
        total = len(messages)
        ranked = sorted(
            ((self._calculate_importance_score(msg, i, total), i, msg) for i, msg in enumerate(messages)),
            key=lambda item: item[0],
            reverse=True,
        )

        kept = []
        preserved_tokens = 0
        for _score, index, msg in ranked:
            msg_tokens = self._estimate_message_tokens(msg)
            if preserved_tokens + msg_tokens > self.config.max_tokens:
                break
            kept.append((index, msg))
            preserved_tokens += msg_tokens

        # Restore chronological order from the remembered index; messages.index()
        # would be O(n) per lookup.
        kept.sort(key=lambda item: item[0])
        preserved_messages = [msg for _index, msg in kept]

        self.log.info(f"Importance filter compression: {len(messages)} -> {len(preserved_messages)} messages")
        return preserved_messages, preserved_tokens

    def _summary_compression_compress(self, messages: List[Dict[str, Any]],
                                      current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Summary compression: system messages, a summary of older turns, then recent rounds.

        Non-system turns older than the recent window are collapsed into a single
        system message built by :meth:`_create_summary`; the recent window is then
        filled newest-first within the remaining budget. *current_tokens* is unused.
        """
        system_messages, conversation = self._split_system(messages)
        preserved_messages = list(system_messages)
        preserved_tokens = sum(self._estimate_message_tokens(msg) for msg in system_messages)

        recent_window = self._recent_window(conversation)
        older = conversation[:len(conversation) - len(recent_window)]
        if older:
            summary_msg = {
                'role': 'system',
                'content': f"对话历史摘要:{self._create_summary(older)}"
            }
            preserved_messages.append(summary_msg)
            preserved_tokens += self._estimate_message_tokens(summary_msg)

        recent, recent_tokens = self._fit_newest_first(recent_window, preserved_tokens)
        preserved_messages.extend(recent)
        preserved_tokens += recent_tokens

        self.log.info(f"Summary compression: {len(messages)} -> {len(preserved_messages)} messages")
        return preserved_messages, preserved_tokens

    def _hybrid_compress(self, messages: List[Dict[str, Any]],
                         current_tokens: int) -> Tuple[List[Dict[str, Any]], int]:
        """Hybrid: sliding window first, summary compression if still over budget.

        The sliding window only adds non-system messages that fit, so it can end
        up over budget only when the system messages alone exceed ``max_tokens``.
        """
        compressed_messages, compressed_tokens = self._sliding_window_compress(messages, current_tokens)

        if compressed_tokens > self.config.max_tokens:
            compressed_messages, compressed_tokens = self._summary_compression_compress(messages, current_tokens)

        return compressed_messages, compressed_tokens

    def _calculate_importance_score(self, message: Dict[str, Any], index: int, total: int) -> float:
        """Score a message: role weight + recency (up to 0.4) + content length (up to 0.2).

        Role weights: system 1.0, user 0.8, assistant 0.6, anything else 0.
        *index* is the message's position among *total* messages (later = higher).
        """
        score = 0.0

        role = message.get('role', '')
        if role == 'system':
            score += 1.0
        elif role == 'user':
            score += 0.8
        elif role == 'assistant':
            score += 0.6

        recency = (index + 1) / total
        score += recency * 0.4

        content = message.get('content', '')
        if isinstance(content, str):
            # Normalise length to 0-1, saturating at 1000 characters
            length_score = min(len(content) / 1000, 1.0)
            score += length_score * 0.2

        return score

    def _create_summary(self, messages: List[Dict[str, Any]]) -> str:
        """Build a plain-text summary: "role: first 100 chars" per string message, joined by " | ".

        The result is truncated to ``config.summary_max_length`` characters.
        Messages whose content is not a string are skipped.
        """
        summary_parts = []

        for msg in messages:
            role = msg.get('role', '')
            content = msg.get('content', '')

            if isinstance(content, str):
                content_preview = content[:100] + ('...' if len(content) > 100 else '')
                summary_parts.append(f"{role}: {content_preview}")

        summary = " | ".join(summary_parts)
        return summary[:self.config.summary_max_length]

    def _estimate_message_tokens(self, message: Dict[str, Any]) -> int:
        """Estimate a message's tokens as (characters of text content) // 4.

        String content is measured directly. For list (multimodal) content only
        dict parts with ``type == 'text'`` are counted; other content yields 0.
        """
        content = message.get('content', '')
        if isinstance(content, str):
            return len(content) // 4
        if isinstance(content, list):
            text_length = sum(
                len(part.get('text') or '')
                for part in content
                if isinstance(part, dict) and part.get('type') == 'text'
            )
            return text_length // 4
        return 0
class ContextManager(Trackable):
    """Conversation context for the LLM: full history plus a compressible working copy.

    ``chat_history`` records every ChatMessage that is added. ``_messages_cache``
    holds the ``{'role', 'content'}`` dicts handed out by :meth:`get_messages`;
    it may be shrunk by ``compressor`` while ``chat_history`` stays complete.
    """

    def __init__(self, config: Optional[ContextConfig] = None):
        self.config = config or ContextConfig()
        self.token_counter = TokenCounter()
        self.compressor = MessageCompressor(self.config)
        self.log = logger.bind(src='context_manager')

        # Full, uncompressed message history
        self.chat_history = ChatHistory()

        # Working copy handed to callers; may be compressed
        self._messages_cache: List[Dict[str, Any]] = []
        self._cached_tokens = 0  # estimated tokens in _messages_cache
        self._last_compression_time = 0  # time.time() of the last compression; 0 = never

    def __len__(self):
        """Number of messages in the full history (kept for the legacy interface)."""
        return len(self.chat_history)

    def add(self, role, content):
        """Add a plain role/content message (kept for the legacy interface)."""
        message = ChatMessage(role=role, content=content)
        self.add_message(message)

    def add_message(self, message: ChatMessage):
        """Append *message* to the history, the token counter and the working copy."""
        self.chat_history.add_message(message)

        msg_dict = {
            'role': message.role,
            'content': message.content
        }

        self.token_counter.add_message(message)

        self._messages_cache.append(msg_dict)
        self._cached_tokens += self.compressor._estimate_message_tokens(msg_dict)

        self.log.debug(f"Added message: {message.role}, tokens: {self._cached_tokens}")

    def get_messages(self, force_compress: bool = False) -> List[Dict[str, Any]]:
        """Return a shallow copy of the working message list, compressing it first if needed.

        Compression is attempted when *force_compress* is true, or when
        ``config.auto_compress`` is set and the estimated tokens exceed
        ``max_tokens``, the list holds more than ``max_rounds * 2`` messages, or
        more than 300 s have passed since the last compression. Note that the
        compressor returns the list unchanged while it is within the token budget.
        """
        current_time = time.time()

        if force_compress or self.config.auto_compress:
            should_compress = (
                force_compress or
                self._cached_tokens > self.config.max_tokens or
                len(self._messages_cache) > self.config.max_rounds * 2 or
                (current_time - self._last_compression_time) > 300  # periodic compression every 5 minutes
            )

            if should_compress:
                self._compress_messages()

        return self._messages_cache.copy()

    def get_usage(self):
        """Return an iterator over the usage of each assistant message in the history."""
        return self.chat_history.get_usage()

    def get_summary(self):
        """Return accumulated usage statistics for the history."""
        return self.chat_history.get_summary()

    def json(self):
        """Return the raw history state, for serialisation."""
        return self.chat_history.get_state()

    def _compress_messages(self):
        """Run the compressor over the working copy and record the compression time."""
        if not self._messages_cache:
            return

        original_count = len(self._messages_cache)
        original_tokens = self._cached_tokens

        compressed_messages, compressed_tokens = self.compressor.compress_messages(
            self._messages_cache, self._cached_tokens
        )

        self._messages_cache = compressed_messages
        self._cached_tokens = compressed_tokens
        self._last_compression_time = time.time()

        self.log.info(
            f"Context compressed: {original_count}->{len(compressed_messages)} messages, "
            f"{original_tokens}->{compressed_tokens} tokens"
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics about the working context.

        ``compression_ratio`` is the number of working messages divided by the
        number of messages in the full history (1.0 when they are the same length,
        0 when both are empty).
        """
        cached_count = len(self._messages_cache)
        return {
            'message_count': cached_count,
            'total_tokens': self._cached_tokens,
            'max_tokens': self.config.max_tokens,
            # Previously divided the cache length by itself, which was always 1.0.
            'compression_ratio': cached_count / max(len(self.chat_history), 1),
            'last_compression': self._last_compression_time
        }

    def _clear_messages_cache(self):
        """Trim the working copy to its first two messages plus the last one.

        Lists of three or fewer messages are left untouched. The token estimate is
        recomputed for what remains.
        """
        if not self._messages_cache:
            return

        if len(self._messages_cache) <= 3:
            return

        first_two = self._messages_cache[:2]
        last_one = self._messages_cache[-1:]

        self._messages_cache = first_two + last_one

        self._cached_tokens = sum(self.compressor._estimate_message_tokens(msg) for msg in self._messages_cache)
        self.log.info(f"Context cleaned: {len(self._messages_cache)} messages, {self._cached_tokens} tokens")

    def clear(self):
        """Trim the working context and reset the token counter and compression time.

        Despite the name, this does not empty everything: the working copy keeps
        its first two and last messages, and ``chat_history`` is left intact.
        """
        # NOTE: clearing chat_history / zeroing _cached_tokens was deliberately
        # disabled here; _clear_messages_cache() recomputes _cached_tokens itself.
        self._clear_messages_cache()
        self.token_counter.reset()
        self._last_compression_time = 0
        self.log.info("Context cleared")

    def _rebuild_cache(self):
        """Rebuild the working copy and token counter from the full history (uncompressed)."""
        self._messages_cache.clear()
        self._cached_tokens = 0
        self.token_counter.reset()

        for message in self.chat_history.messages:
            msg_dict = {
                'role': message.role,
                'content': message.content
            }
            self._messages_cache.append(msg_dict)
            self._cached_tokens += self.compressor._estimate_message_tokens(msg_dict)
            self.token_counter.add_message(message)

    def update_config(self, config: ContextConfig):
        """Replace the configuration and create a compressor that uses it."""
        self.config = config
        self.compressor = MessageCompressor(config)
        self.log.info(f"Context config updated: {config.strategy.value}")

    def get_state(self):
        """Return a dict with everything needed to restore this manager via :meth:`restore_state`."""
        return {
            'config': {
                'max_tokens': self.config.max_tokens,
                'max_rounds': self.config.max_rounds,
                'auto_compress': self.config.auto_compress,
                'strategy': self.config.strategy.value,
                'compression_ratio': self.config.compression_ratio,
                'importance_threshold': self.config.importance_threshold,
                'summary_max_length': self.config.summary_max_length,
                'preserve_system': self.config.preserve_system,
                'preserve_recent': self.config.preserve_recent,
            },
            'chat_history': self.chat_history.get_state(),
            'messages_cache': self._messages_cache.copy(),
            'cached_tokens': self._cached_tokens,
            'token_counter': {
                'total_tokens': self.token_counter.total_tokens,
                'input_tokens': self.token_counter.input_tokens,
                'output_tokens': self.token_counter.output_tokens,
            },
            'last_compression_time': self._last_compression_time,
        }

    def restore_state(self, state_data):
        """Restore the manager from a dict produced by :meth:`get_state`.

        A falsy *state_data* is ignored. Missing sections keep their current value,
        except the working copy: without ``'messages_cache'`` it is rebuilt from
        the (restored) history so the context is not silently emptied.

        Raises:
            ValueError: if the stored strategy is not a valid ``ContextStrategy`` value.
        """
        if not state_data:
            return

        if 'config' in state_data:
            config_data = state_data['config']
            self.config = ContextConfig(
                max_tokens=config_data.get('max_tokens', 100000),
                max_rounds=config_data.get('max_rounds', 10),
                # NOTE(review): falls back to False here although ContextConfig defaults to True
                auto_compress=config_data.get('auto_compress', False),
                strategy=ContextStrategy(config_data.get('strategy', 'hybrid')),
                compression_ratio=config_data.get('compression_ratio', 0.3),
                importance_threshold=config_data.get('importance_threshold', 0.5),
                summary_max_length=config_data.get('summary_max_length', 200),
                preserve_system=config_data.get('preserve_system', True),
                preserve_recent=config_data.get('preserve_recent', 3),
            )
            self.compressor = MessageCompressor(self.config)

        if 'chat_history' in state_data:
            self.chat_history.restore_state(state_data['chat_history'])

        if 'messages_cache' in state_data:
            self._messages_cache = list(state_data['messages_cache'] or [])
            self._cached_tokens = state_data.get('cached_tokens', 0)
        else:
            # No stored working copy: derive it from the history instead of leaving it empty.
            self._rebuild_cache()

        if 'token_counter' in state_data:
            token_data = state_data['token_counter']
            self.token_counter.total_tokens = token_data.get('total_tokens', 0)
            self.token_counter.input_tokens = token_data.get('input_tokens', 0)
            self.token_counter.output_tokens = token_data.get('output_tokens', 0)

        self._last_compression_time = state_data.get('last_compression_time', 0)

        self.log.info(f"Context manager state restored: {len(self._messages_cache)} messages, {self._cached_tokens} tokens")

    # Trackable interface
    def get_checkpoint(self) -> int:
        """Return the current checkpoint: the number of messages in the full history."""
        return len(self.chat_history)

    def restore_to_checkpoint(self, checkpoint: Optional[int]):
        """Roll back to *checkpoint* (a history length from :meth:`get_checkpoint`).

        ``None`` calls :meth:`clear`, which trims the working copy but leaves
        ``chat_history`` intact. For an int smaller than the current history
        length, later messages are deleted and the working copy is rebuilt
        uncompressed from the history; otherwise nothing happens.
        """
        if checkpoint is None:
            self.clear()
        else:
            if checkpoint < len(self.chat_history):
                self.chat_history.delete_range(checkpoint, len(self.chat_history))
                self._rebuild_cache()