Coverage for aipyapp/aipy/step_manager.py: 40%

62 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import time 

5from datetime import datetime 

6from dataclasses import dataclass 

7from collections import namedtuple 

8from typing import List, Dict, Any, Optional 

9 

10from ..interface import Trackable 

11 

@dataclass
class Step:
    """Record of one step: what was asked, what came back, and the checkpoints taken.

    Instances are plain data holders; StepManager creates them in
    ``create_checkpoint`` and ``restore_state``.
    """
    # Instruction text associated with this step.
    instruction: str
    # Round number supplied by the caller when the step was recorded.
    round: int
    # Response text associated with this step.
    response: str
    # Time value for the step; StepManager fills it from time.time().
    timestamp: float
    # Checkpoint values; StepManager keys them by registered trackable name.
    checkpoints: Dict[str, Any]

    def get_summary(self) -> str:
        """Return a one-line summary of the step.

        The instruction is cut to its first 32 characters. The trailing
        ``...`` is added only when text was actually dropped, so a short
        instruction is shown verbatim instead of looking truncated.
        """
        max_chars = 32
        shown = self.instruction
        if len(shown) > max_chars:
            shown = f"{shown[:max_chars]}..."
        return f"Step {self.round}: {shown}"

24 

class StepManager:
    """Step manager: owns the ordered list of steps and the registered trackables.

    Each step stores one checkpoint value per registered trackable, which lets
    ``delete_step`` roll the trackables back to the state recorded by an
    earlier step.

    Annotations that name ``Step`` or ``Trackable`` in method signatures are
    written as strings (forward references), so they are not evaluated when
    the class body is executed.
    """

    # Row type returned by list_steps(); created once so that rows produced by
    # different calls share a single type.
    _StepRecord = namedtuple('StepRecord', ['Index', 'Instruction', 'Round', 'Ended'])

    def __init__(self):
        """Start with no steps and no registered trackables."""
        self.steps: List[Step] = []
        self.trackables: Dict[str, Trackable] = {}

    def register_trackable(self, name: str, obj: "Trackable"):
        """Register a trackable object under ``name``, replacing any previous entry."""
        self.trackables[name] = obj

    def create_checkpoint(self, instruction: str, round: int, response: str) -> "Step":
        """Record a new step and return it.

        ``get_checkpoint()`` is called on every registered trackable and the
        results are stored in the step, keyed by the trackable's name. The
        step's timestamp is taken from ``time.time()``.
        """
        snapshot = {
            name: trackable.get_checkpoint()
            for name, trackable in self.trackables.items()
        }
        step = Step(
            instruction=instruction,
            round=round,
            response=response,
            timestamp=time.time(),
            checkpoints=snapshot,
        )
        self.steps.append(step)
        return step

    def delete_step(self, index: int) -> bool:
        """Delete the step at ``index`` and every step after it.

        Before the steps are removed, each trackable is asked to restore the
        checkpoint stored in the preceding step. When ``index`` is 0 there is
        no preceding step and every trackable receives ``None``. A trackable
        with no entry in the preceding step also receives ``None``.

        Returns:
            True once the steps have been removed.

        Raises:
            ValueError: if ``index`` is outside ``0 <= index < len(self)``.
        """
        if not 0 <= index < len(self.steps):
            raise ValueError("Invalid step index")

        # An empty mapping makes every lookup below yield None for index 0.
        target = self.steps[index - 1].checkpoints if index > 0 else {}
        for name, trackable in self.trackables.items():
            trackable.restore_to_checkpoint(target.get(name))

        # Truncate in place, matching clear_all() and restore_state(), so that
        # existing references to the list keep seeing the current steps.
        del self.steps[index:]
        return True

    def clear_all(self):
        """Remove every step after passing ``None`` to each trackable's restore."""
        for trackable in self.trackables.values():
            trackable.restore_to_checkpoint(None)
        self.steps.clear()

    def list_steps(self):
        """Return one ``StepRecord(Index, Instruction, Round, Ended)`` per step.

        ``Ended`` is the step timestamp formatted as ``YYYY-MM-DD HH:MM:SS``
        via ``datetime.fromtimestamp``.
        """
        record = self._StepRecord
        return [
            record(
                position,
                step.instruction,
                step.round,
                datetime.fromtimestamp(step.timestamp).strftime('%Y-%m-%d %H:%M:%S'),
            )
            for position, step in enumerate(self.steps)
        ]

    def get_step(self, index: int) -> Optional["Step"]:
        """Return the step at ``index``, or ``None`` when the index is out of range.

        Negative indexes are treated as out of range.
        """
        if 0 <= index < len(self.steps):
            return self.steps[index]
        return None

    def __len__(self):
        """Return the number of recorded steps."""
        return len(self.steps)

    def get_state(self):
        """Return the steps as a list of plain dicts, for persistence."""
        return [
            {
                'instruction': step.instruction,
                'round': step.round,
                'response': step.response,
                'timestamp': step.timestamp,
                'checkpoints': step.checkpoints,
            }
            for step in self.steps
        ]

    def restore_state(self, state_data):
        """Replace the current steps with those described by ``state_data``.

        ``state_data`` is a list of dicts in the shape produced by
        ``get_state``; a falsy value simply empties the manager. Only the step
        list is rebuilt here; trackables are not touched.

        The new steps are built completely before the list is replaced, so a
        malformed entry leaves the existing steps untouched. A missing or
        ``None`` ``checkpoints`` value becomes an empty dict, and the mapping
        is copied so that the manager does not share it with ``state_data``.

        Raises:
            KeyError: if an entry lacks ``'instruction'`` or ``'round'``.
        """
        restored = [
            Step(
                instruction=item['instruction'],
                round=item['round'],
                response=item.get('response', ''),
                timestamp=item.get('timestamp', time.time()),
                checkpoints=dict(item.get('checkpoints') or {}),
            )
            for item in state_data or []
        ]
        # Slice assignment keeps the same list object, as clear() did before.
        self.steps[:] = restored