Coverage for aipyapp/aipy/task_state.py: 22%

124 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4import json 

5import time 

6from pathlib import Path 

7from typing import Dict, Any, Optional, Union, TYPE_CHECKING 

8from collections import OrderedDict 

9 

10from loguru import logger 

11 

12if TYPE_CHECKING: 

13 from .task import Task 

14 

# Format version stamped into every serialized task state; state carrying a
# different version is rejected when it is restored onto a Task.
TASK_VERSION: int = 20250806

17 

class TaskStateError(Exception):
    """Raised when a task state cannot be saved, loaded, or restored."""

21 

class TaskState:
    """Task state manager.

    Encapsulates serialization/deserialization of a task's state and reading /
    writing it as a JSON file. Basic task fields (version, id, instruction,
    timestamps) are plain attributes; per-component state (steps, context
    manager, runner, code blocks, events, ...) lives in a name -> state map.
    """

    def __init__(self, task: Optional['Task'] = None):
        """Create an empty state, or capture the state of *task* if given."""
        self.log = logger.bind(src='task_state')

        # Basic task information
        self.version: int = TASK_VERSION
        self.task_id: Optional[str] = None
        self.instruction: Optional[str] = None
        self.start_time: Optional[float] = None
        self.done_time: Optional[float] = None

        # Component states, keyed by component name (e.g. 'steps', 'runner')
        self._component_states: Dict[str, Any] = {}

        if task:
            self.from_task(task)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'TaskState':
        """Create a TaskState by loading the JSON file at *path*."""
        instance = cls()
        instance.load_from_file(path)
        return instance

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TaskState':
        """Create a TaskState from a dict shaped like the output of to_dict()."""
        instance = cls()
        instance._load_from_dict(data)
        return instance

    def from_task(self, task: 'Task') -> None:
        """Capture the basic fields and component states of *task*."""
        self.task_id = task.task_id
        self.instruction = task.instruction
        self.start_time = task.start_time
        self.done_time = task.done_time

        # Snapshot each component's state
        self._component_states['steps'] = task.step_manager.get_state()
        self._component_states['context_manager'] = task.context_manager.get_state()
        self._component_states['runner'] = task.runner.get_state()
        self._component_states['blocks'] = task.code_blocks.get_state()

        # Event records are optional: only present when the task has a recorder.
        if task.event_recorder:
            self._component_states['events'] = task.event_recorder.get_events()
        else:
            # Don't keep events captured from an earlier snapshot.
            self._component_states.pop('events', None)

        self.log.debug('Extracted state from task', task_id=self.task_id)

    def restore_to_task(self, task: 'Task') -> None:
        """Apply this state onto *task*.

        Only components present in the state are restored; missing ones are
        left as they are on the task.

        Raises:
            TaskStateError: if the state's version differs from TASK_VERSION.
        """
        # Refuse state written by an incompatible format version
        if self.version != TASK_VERSION:
            raise TaskStateError(f'Task version mismatch: expected {TASK_VERSION}, got {self.version}')

        # Restore basic task information
        task.task_id = self.task_id
        task.instruction = self.instruction
        task.start_time = self.start_time
        task.done_time = self.done_time

        # Restore each component only if its state was saved
        if 'steps' in self._component_states:
            task.step_manager.restore_state(self._component_states['steps'])

        if 'context_manager' in self._component_states:
            task.context_manager.restore_state(self._component_states['context_manager'])

        if 'runner' in self._component_states:
            task.runner.restore_state(self._component_states['runner'])

        if 'blocks' in self._component_states:
            task.code_blocks.restore_state(self._component_states['blocks'])

        # Restore the event recorder (if both saved events and a recorder exist);
        # the recorder is re-enabled unconditionally here.
        if 'events' in self._component_states and task.event_recorder:
            events_data = self._component_states['events']
            task.event_recorder.restore_state({'events': events_data, 'enabled': True})

        self.log.info('Restored state to task', task_id=self.task_id)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict: basic fields first, then one key per component.

        Timestamps are truncated to whole seconds (int); unset ones become None.
        """
        # OrderedDict keeps the field order stable in the written JSON
        data = OrderedDict()
        data['version'] = self.version
        data['task_id'] = self.task_id
        data['instruction'] = self.instruction
        # Compare against None explicitly so a 0 timestamp is not dropped.
        data['start_time'] = int(self.start_time) if self.start_time is not None else None
        data['done_time'] = int(self.done_time) if self.done_time is not None else None

        # Append component states after the basic fields
        for key, state in self._component_states.items():
            data[key] = state

        return data

    def _load_from_dict(self, data: Dict[str, Any]) -> None:
        """Load state from a dict shaped like the output of to_dict().

        Every key that is not a basic field is treated as a component state.
        Existing component states are replaced, not merged, so a reused
        instance does not keep stale components.

        Raises:
            TaskStateError: if *data* is not a dict (e.g. a JSON list root).
        """
        if not isinstance(data, dict):
            raise TaskStateError(
                f'Task state must be a JSON object, got {type(data).__name__}'
            )

        # Load basic fields; a missing version is assumed to be the current one
        basic_fields = {'version', 'task_id', 'instruction', 'start_time', 'done_time'}
        self.version = data.get('version', TASK_VERSION)
        self.task_id = data.get('task_id')
        self.instruction = data.get('instruction')
        self.start_time = data.get('start_time')
        self.done_time = data.get('done_time')

        # All remaining keys become component states
        self._component_states.clear()
        for key, value in data.items():
            if key not in basic_fields:
                self._component_states[key] = value

    def save_to_file(self, path: Union[str, Path]) -> None:
        """Write the state to *path* as indented UTF-8 JSON.

        Parent directories are created as needed. If done_time is unset it is
        first stamped with the current time (this mutates the instance). The
        JSON text is built in memory before the file is opened, so a
        serialization failure leaves any existing file at *path* intact.

        Raises:
            TaskStateError: if creating directories, serializing or writing fails.
        """
        path = Path(path)

        # Stamp the completion time if it was never set
        if not self.done_time:
            self.done_time = time.time()

        try:
            # Serialize first: open(..., 'w') truncates the file, so failing
            # inside json.dump would otherwise leave a half-written file.
            content = json.dumps(self.to_dict(), ensure_ascii=False, indent=4, default=str)
            path.parent.mkdir(parents=True, exist_ok=True)
            with open(path, 'w', encoding='utf-8') as f:
                f.write(content)
            self.log.info('Saved task state to file', path=str(path))
        except Exception as e:
            self.log.exception('Failed to save task state', path=str(path))
            raise TaskStateError(f'Failed to save task state: {e}') from e

    def load_from_file(self, path: Union[str, Path]) -> None:
        """Load state from the JSON file at *path*.

        Raises:
            FileNotFoundError / ValueError: from validate_file() for a missing,
                non-.json, or non-regular-file path.
            TaskStateError: for invalid JSON, a non-object JSON root, or any
                other read failure.
        """
        path = Path(path)
        self.validate_file(path)

        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self._load_from_dict(data)
            self.log.info('Loaded task state from file', path=str(path), task_id=self.task_id)
        except TaskStateError:
            # Already descriptive; don't wrap it a second time.
            raise
        except json.JSONDecodeError as e:
            raise TaskStateError(f'Invalid JSON file: {e}') from e
        except Exception as e:
            self.log.exception('Failed to load task state', path=str(path))
            raise TaskStateError(f'Failed to load task state: {e}') from e

    def validate_file(self, path: Union[str, Path]) -> None:
        """Check that *path* exists, has a .json name and is a regular file.

        Raises:
            FileNotFoundError: if the path does not exist.
            ValueError: if the name does not end with '.json' (case-sensitive)
                or the path is not a regular file.
        """
        path = Path(path)

        if not path.exists():
            raise FileNotFoundError(f"Task file not found: {path}")

        if not path.name.endswith('.json'):
            raise ValueError("Task file must be a .json file")

        if not path.is_file():
            raise ValueError(f"Path is not a file: {path}")

    def get_component_state(self, name: str) -> Any:
        """Return the state stored for component *name*, or None if absent."""
        return self._component_states.get(name)

    def set_component_state(self, name: str, state: Any) -> None:
        """Store (or overwrite) the state for component *name*."""
        self._component_states[name] = state

    def has_component_state(self, name: str) -> bool:
        """Return True if a state is stored for component *name*."""
        return name in self._component_states

    def get_summary(self) -> Dict[str, Any]:
        """Return a short summary; instructions over 50 chars are truncated with '...'."""
        return {
            'version': self.version,
            'task_id': self.task_id,
            'instruction': self.instruction[:50] + '...' if self.instruction and len(self.instruction) > 50 else self.instruction,
            'start_time': self.start_time,
            'done_time': self.done_time,
            'components': list(self._component_states.keys())
        }

    def __repr__(self):
        return f"<TaskState task_id={self.task_id}, version={self.version}, components={list(self._component_states.keys())}>"