Coverage for aipyapp/aipy/task_state.py: 22%
124 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import json
5import time
6from pathlib import Path
7from typing import Dict, Any, Optional, Union, TYPE_CHECKING
8from collections import OrderedDict
10from loguru import logger
12if TYPE_CHECKING:
13 from .task import Task
# Task file format version. Written into every serialized state and checked
# by TaskState.restore_to_task, which rejects any other value.
TASK_VERSION: int = 20250806
class TaskStateError(Exception):
    """Error raised when task state cannot be saved, loaded or restored."""
class TaskState:
    """Task state manager.

    Captures the persistent state of a ``Task`` (basic metadata plus the
    per-component states returned by each component's ``get_state()``),
    converts it to and from a plain dict, and saves/loads it as a JSON file.
    """

    def __init__(self, task: Optional['Task'] = None):
        """Create an empty state, or one extracted from ``task`` if given."""
        self.log = logger.bind(src='task_state')

        # Basic task information
        self.version: int = TASK_VERSION
        self.task_id: Optional[str] = None
        self.instruction: Optional[str] = None
        self.start_time: Optional[float] = None
        self.done_time: Optional[float] = None

        # Component states keyed by component name ('steps', 'runner', ...).
        # They are serialized as top-level keys next to the basic fields.
        self._component_states: Dict[str, Any] = {}

        if task:
            self.from_task(task)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'TaskState':
        """Create a TaskState from a JSON file (see :meth:`load_from_file`)."""
        instance = cls()
        instance.load_from_file(path)
        return instance

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TaskState':
        """Create a TaskState from a dict as produced by :meth:`to_dict`."""
        instance = cls()
        instance._load_from_dict(data)
        return instance

    def from_task(self, task: 'Task') -> None:
        """Extract basic fields and component states from ``task``."""
        self.task_id = task.task_id
        self.instruction = task.instruction
        self.start_time = task.start_time
        self.done_time = task.done_time

        # Per-component states
        self._component_states['steps'] = task.step_manager.get_state()
        self._component_states['context_manager'] = task.context_manager.get_state()
        self._component_states['runner'] = task.runner.get_state()
        self._component_states['blocks'] = task.code_blocks.get_state()

        # Event records, only when the task has an event recorder
        if task.event_recorder:
            self._component_states['events'] = task.event_recorder.get_events()

        self.log.debug('Extracted state from task', task_id=self.task_id)

    def restore_to_task(self, task: 'Task') -> None:
        """Restore this state into ``task``.

        Basic fields are always copied. Each component is restored only if
        its state is present. Events are restored only if the task has an
        event recorder; the recorder is re-enabled.

        Raises:
            TaskStateError: if ``version`` differs from ``TASK_VERSION``.
        """
        # Refuse states written by a different task format version.
        if self.version != TASK_VERSION:
            raise TaskStateError(f'Task version mismatch: expected {TASK_VERSION}, got {self.version}')

        # Basic task information
        task.task_id = self.task_id
        task.instruction = self.instruction
        task.start_time = self.start_time
        task.done_time = self.done_time

        # Per-component states
        if 'steps' in self._component_states:
            task.step_manager.restore_state(self._component_states['steps'])

        if 'context_manager' in self._component_states:
            task.context_manager.restore_state(self._component_states['context_manager'])

        if 'runner' in self._component_states:
            task.runner.restore_state(self._component_states['runner'])

        if 'blocks' in self._component_states:
            task.code_blocks.restore_state(self._component_states['blocks'])

        # Event recorder state (if present).
        # NOTE(review): truthiness check; if the recorder type defines
        # __len__/__bool__, an empty recorder would be skipped here. Confirm
        # that is intended.
        if 'events' in self._component_states and task.event_recorder:
            events_data = self._component_states['events']
            task.event_recorder.restore_state({'events': events_data, 'enabled': True})

        self.log.info('Restored state to task', task_id=self.task_id)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to an ordered dict.

        Timestamps are truncated to whole seconds. Component states are
        added as top-level keys after the basic fields.
        """
        # OrderedDict keeps the basic fields first in the output file.
        data = OrderedDict()
        data['version'] = self.version
        data['task_id'] = self.task_id
        data['instruction'] = self.instruction
        # Compare against None so a legitimate 0 timestamp is not lost.
        data['start_time'] = int(self.start_time) if self.start_time is not None else None
        data['done_time'] = int(self.done_time) if self.done_time is not None else None

        # Component states
        for key, state in self._component_states.items():
            data[key] = state

        return data

    def _load_from_dict(self, data: Dict[str, Any]) -> None:
        """Load state from a dict as produced by :meth:`to_dict`.

        Replaces all current state. ``version`` defaults to ``TASK_VERSION``
        and the other basic fields default to ``None``. Every remaining
        top-level key becomes a component state; values are referenced, not
        copied.
        """
        basic_fields = {'version', 'task_id', 'instruction', 'start_time', 'done_time'}
        self.version = data.get('version', TASK_VERSION)
        self.task_id = data.get('task_id')
        self.instruction = data.get('instruction')
        self.start_time = data.get('start_time')
        self.done_time = data.get('done_time')

        # Replace rather than merge, so components left over from an earlier
        # load/extraction do not leak into this state. The basic fields above
        # are fully reset in the same way.
        self._component_states = {
            key: value for key, value in data.items() if key not in basic_fields
        }

    def save_to_file(self, path: Union[str, Path]) -> None:
        """Save the state to ``path`` as indented UTF-8 JSON.

        If ``done_time`` is unset, it is stamped with the current time first.
        Missing parent directories are created. Values that JSON cannot
        encode are written via ``str()``.

        Raises:
            TaskStateError: if the directory cannot be created, the state
                cannot be serialized, or the file cannot be written.
        """
        path = Path(path)

        # Stamp the completion time if none was recorded.
        if not self.done_time:
            self.done_time = time.time()

        try:
            # Serialize before touching the file. A serialization failure
            # (e.g. a circular reference) then leaves an existing save intact
            # instead of truncating it.
            content = json.dumps(self.to_dict(), ensure_ascii=False, indent=4, default=str)
            # Inside the try so directory errors also surface as TaskStateError.
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(content, encoding='utf-8')
            self.log.info('Saved task state to file', path=str(path))
        except Exception as e:
            self.log.exception('Failed to save task state', path=str(path))
            raise TaskStateError(f'Failed to save task state: {e}') from e

    def load_from_file(self, path: Union[str, Path]) -> None:
        """Load state from a JSON file, replacing the current state.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            ValueError: if ``path`` is not a ``.json`` file.
            TaskStateError: if the file is not valid JSON or cannot be loaded.
        """
        path = Path(path)
        self.validate_file(path)

        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self._load_from_dict(data)
            self.log.info('Loaded task state from file', path=str(path), task_id=self.task_id)
        except json.JSONDecodeError as e:
            raise TaskStateError(f'Invalid JSON file: {e}') from e
        except Exception as e:
            self.log.exception('Failed to load task state', path=str(path))
            raise TaskStateError(f'Failed to load task state: {e}') from e

    def validate_file(self, path: Union[str, Path]) -> None:
        """Check that ``path`` exists, ends in ``.json`` and is a regular file.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            ValueError: if the name does not end in ``.json`` or it is not a file.
        """
        path = Path(path)

        if not path.exists():
            raise FileNotFoundError(f"Task file not found: {path}")

        if not path.name.endswith('.json'):
            raise ValueError("Task file must be a .json file")

        if not path.is_file():
            raise ValueError(f"Path is not a file: {path}")

    def get_component_state(self, name: str) -> Any:
        """Return the state of component ``name``, or ``None`` if absent."""
        return self._component_states.get(name)

    def set_component_state(self, name: str, state: Any) -> None:
        """Set (or replace) the state of component ``name``."""
        self._component_states[name] = state

    def has_component_state(self, name: str) -> bool:
        """Return whether a state is stored for component ``name``."""
        return name in self._component_states

    def get_summary(self) -> Dict[str, Any]:
        """Return a short summary; the instruction is cut to 50 chars + '...'."""
        instruction = self.instruction
        if instruction and len(instruction) > 50:
            instruction = instruction[:50] + '...'
        return {
            'version': self.version,
            'task_id': self.task_id,
            'instruction': instruction,
            'start_time': self.start_time,
            'done_time': self.done_time,
            'components': list(self._component_states.keys()),
        }

    def __repr__(self):
        return f"<TaskState task_id={self.task_id}, version={self.version}, components={list(self._component_states.keys())}>"