Coverage for aipyapp/aipy/step_manager.py: 40%
62 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import time
5from datetime import datetime
6from dataclasses import dataclass
7from collections import namedtuple
8from typing import List, Dict, Any, Optional
10from ..interface import Trackable
@dataclass
class Step:
    """A single recorded step: the instruction, its response and the
    checkpoints captured from every registered trackable at that moment."""

    instruction: str  # instruction text that produced this step
    round: int  # round number recorded for this step (name shadows the builtin; kept for API compatibility)
    response: str  # response text recorded for this step
    timestamp: float  # creation time as a Unix timestamp (seconds since the epoch)
    checkpoints: Dict[str, Any]  # checkpoint data keyed by trackable name

    def get_summary(self, max_len: int = 32) -> str:
        """Return a one-line summary of this step.

        Args:
            max_len: maximum number of instruction characters to include
                (defaults to 32, the previous hard-coded limit).

        Returns:
            ``"Step <round>: <instruction>"``. The instruction is cut to
            ``max_len`` characters, and ``"..."`` is appended only when
            characters were actually dropped.
        """
        text = self.instruction
        if len(text) > max_len:
            # Only signal truncation when the instruction really was shortened.
            text = text[:max_len] + "..."
        return f"Step {self.round}: {text}"
class StepManager:
    """Step manager: keeps the ordered list of steps and the trackable
    objects whose checkpoints are captured with each step."""

    # Record type returned by list_steps(). Built once so that every call
    # returns instances of the same class instead of a fresh namedtuple type.
    StepRecord = namedtuple('StepRecord', ['Index', 'Instruction', 'Round', 'Ended'])

    def __init__(self):
        self.steps: List[Step] = []
        self.trackables: Dict[str, Trackable] = {}

    def register_trackable(self, name: str, obj: Trackable):
        """Register ``obj`` under ``name`` so its checkpoint is captured by
        create_checkpoint(). Registering an existing name replaces the object."""
        self.trackables[name] = obj

    def create_checkpoint(self, instruction: str, round: int, response: str) -> Step:
        """Capture a checkpoint from every registered trackable and append a new step.

        Args:
            instruction: instruction text recorded on the step.
            round: round number recorded on the step.
            response: response text recorded on the step.

        Returns:
            The newly appended Step (timestamped with ``time.time()``).
        """
        checkpoints = {
            name: trackable.get_checkpoint()
            for name, trackable in self.trackables.items()
        }
        step = Step(
            instruction=instruction,
            round=round,
            response=response,
            timestamp=time.time(),
            checkpoints=checkpoints,
        )
        self.steps.append(step)
        return step

    def delete_step(self, index: int) -> bool:
        """Delete the step at ``index`` and every step after it.

        Before truncating, each registered trackable is restored:
        - to checkpoint ``None`` (initial state) when ``index`` is 0;
        - otherwise to the checkpoint stored for it on step ``index - 1``.

        Returns:
            True on success.

        Raises:
            ValueError: if ``index`` is outside ``[0, len(self.steps))``.
        """
        if index < 0 or index >= len(self.steps):
            raise ValueError("Invalid step index")

        if index == 0:
            # Deleting from the first step: restore everything to the initial state.
            for trackable in self.trackables.values():
                trackable.restore_to_checkpoint(None)
        else:
            prev_checkpoints = self.steps[index - 1].checkpoints or {}
            for name, trackable in self.trackables.items():
                # A trackable registered after the previous step was created has
                # no entry there, so it receives None (its initial state).
                trackable.restore_to_checkpoint(prev_checkpoints.get(name))

        # Truncate in place, like clear_all() and restore_state() do.
        del self.steps[index:]
        return True

    def clear_all(self):
        """Restore every trackable to checkpoint ``None`` and drop all steps."""
        for trackable in self.trackables.values():
            trackable.restore_to_checkpoint(None)
        self.steps.clear()

    def list_steps(self):
        """Return one ``StepRecord(Index, Instruction, Round, Ended)`` per step.

        ``Ended`` is the step timestamp rendered in local time as
        ``'%Y-%m-%d %H:%M:%S'``.
        """
        return [
            self.StepRecord(
                index,
                step.instruction,
                step.round,
                datetime.fromtimestamp(step.timestamp).strftime('%Y-%m-%d %H:%M:%S'),
            )
            for index, step in enumerate(self.steps)
        ]

    def get_step(self, index: int) -> Optional[Step]:
        """Return the step at ``index``, or None when out of range.

        Negative indices are treated as out of range (not counted from the end).
        """
        if 0 <= index < len(self.steps):
            return self.steps[index]
        return None

    def __len__(self) -> int:
        """Return the number of recorded steps."""
        return len(self.steps)

    def get_state(self) -> List[Dict[str, Any]]:
        """Return the steps as a list of plain dicts suitable for persistence.

        Each ``checkpoints`` mapping is shallow-copied, so mutating the returned
        data cannot add or remove entries on the live steps. The checkpoint
        values themselves are still shared.
        """
        return [
            {
                'instruction': step.instruction,
                'round': step.round,
                'response': step.response,
                'timestamp': step.timestamp,
                'checkpoints': dict(step.checkpoints or {}),
            }
            for step in self.steps
        ]

    def restore_state(self, state_data):
        """Replace all steps with those described by ``state_data``.

        ``state_data`` is expected to be a list shaped like get_state()'s output.
        A falsy value simply clears the steps.

        Keys per entry:
        - 'instruction' and 'round' are required (a missing key raises KeyError).
        - 'response' defaults to ''.
        - 'timestamp' defaults to the current time.
        - 'checkpoints' defaults to an empty dict.
        These defaults also apply when the stored value is None.

        The new list is built completely before it is installed, so a malformed
        entry leaves the current steps untouched. Trackables are not restored here.
        """
        restored: List[Step] = []
        for step_data in state_data or ():
            timestamp = step_data.get('timestamp')
            restored.append(Step(
                instruction=step_data['instruction'],
                round=step_data['round'],
                response=step_data.get('response') or '',
                timestamp=time.time() if timestamp is None else timestamp,
                # Copy so the manager does not alias the caller's dict.
                checkpoints=dict(step_data.get('checkpoints') or {}),
            ))
        # Slice assignment keeps the same list object, matching the original clear()+append.
        self.steps[:] = restored