Coverage for aipyapp/aipy/llm.py: 24%

177 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3from collections import defaultdict, namedtuple 

4 

5from loguru import logger 

6 

7from .. import T, __respath__ 

8from ..llm import CLIENTS, ModelRegistry, ModelCapability 

9from .multimodal import LLMContext 

10 

class LineReceiver(list):
    """Accumulate streamed text and split it into complete lines.

    The list itself holds every complete non-empty line received so far,
    while ``buffer`` keeps the trailing partial line until more data
    arrives or ``done()`` flushes it.
    """

    def __init__(self):
        super().__init__()
        self.buffer = ""

    @property
    def content(self):
        """All collected lines joined with newlines."""
        return '\n'.join(self)

    def feed(self, data: str):
        """Absorb *data* and return the list of complete lines it produced.

        Empty lines are dropped; only non-empty lines are collected.
        """
        self.buffer += data
        fresh = []
        while True:
            head, sep, tail = self.buffer.partition('\n')
            if not sep:
                break
            self.buffer = tail
            if head:
                self.append(head)
                fresh.append(head)
        return fresh

    def empty(self):
        """True when nothing has been collected and no partial data is pending."""
        return not (self or self.buffer)

    def done(self):
        """Flush the pending partial line (if any) into the list and return it."""
        pending = self.buffer
        if pending:
            self.append(pending)
            self.buffer = ""
        return pending

41 

class StreamProcessor:
    """Consume LLM streaming chunks and forward complete lines as task events.

    Acts as a context manager: emits ``stream_start`` on entry, flushes any
    trailing partial response line and emits ``stream_end`` on exit.
    """

    def __init__(self, task, name):
        self.task = task
        self.name = name
        self.lr = LineReceiver()          # normal response lines
        self.lr_reason = LineReceiver()   # "thinking"/reasoning lines

    @property
    def content(self):
        """Accumulated response text."""
        return self.lr.content

    @property
    def reason(self):
        """Accumulated reasoning text."""
        return self.lr_reason.content

    def __enter__(self):
        """Announce the start of the stream."""
        self.task.emit('stream_start', llm=self.name)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Flush any dangling partial line, then announce the end of the stream."""
        if self.lr.buffer:
            self.process_chunk('\n')
        self.task.emit('stream_end', llm=self.name)

    def process_chunk(self, content, *, reason=False):
        """Feed one streamed chunk; emit a 'stream' event per batch of new lines."""
        if not content:
            return

        # First non-reasoning chunk after reasoning output: close the
        # reasoning section by emitting its final line plus a separator.
        if not reason and self.lr.empty() and not self.lr_reason.empty():
            tail = self.lr_reason.done()
            if tail:
                self.task.emit('stream', llm=self.name,
                               lines=[tail, "\n\n----\n\n"], reason=True)

        # Route the chunk to the matching line receiver.
        receiver = self.lr_reason if reason else self.lr
        fresh = receiver.feed(content)

        # Drop internal marker comment lines before publishing.
        visible = [line for line in fresh
                   if not (line.startswith('<!-- Block-') or line.startswith('<!-- Cmd-'))]
        if visible:
            self.task.emit('stream', llm=self.name, lines=visible, reason=reason)

91 

92 

class ClientManager(object):
    """Create and own all configured LLM clients.

    Builds one client per enabled entry in ``settings.llm``, picks a default
    (the first entry flagged ``default: true``, else the first usable client)
    and tracks which client is currently selected.
    """

    # Fallback max_tokens applied when neither the client config nor the
    # global settings provide one.
    MAX_TOKENS = 8192

    def __init__(self, settings):
        self.clients = {}
        self.default = None
        self.current = None
        self.log = logger.bind(src='client_manager')
        # Maps status ('enabled'/'disabled'/'error') -> set of client names,
        # plus 'default' -> the chosen default client's name.
        self.names = self._init_clients(settings)
        self.model_registry = ModelRegistry(__respath__ / "models.yaml")

    def _create_client(self, config):
        """Instantiate the client class registered for ``config['type']``.

        Defaults to the 'openai' provider; returns None for unknown kinds.
        """
        kind = config.get("type", "openai")
        client_class = CLIENTS.get(kind.lower())
        if not client_class:
            self.log.error('Unsupported LLM provider', kind=kind)
            return None
        return client_class(config)

    def _init_clients(self, settings):
        """Build every configured client and return the status-name mapping.

        Raises:
            ValueError: when no usable client could be created (previously
                this path crashed with an opaque IndexError).
        """
        names = defaultdict(set)
        max_tokens = settings.get('max_tokens', self.MAX_TOKENS)
        for name, config in settings.llm.items():
            if not config.get('enable', True):
                names['disabled'].add(name)
                continue

            config['name'] = name
            try:
                client = self._create_client(config)
            except Exception:
                # logger.exception records the traceback; the exception
                # object itself is not needed.
                self.log.exception('Error creating LLM client', config=config)
                names['error'].add(name)
                continue

            if not client or not client.usable():
                names['disabled'].add(name)
                self.log.error('LLM client not usable', name=name, config=config)
                continue

            names['enabled'].add(name)
            if not client.max_tokens:
                client.max_tokens = max_tokens
            self.clients[name] = client

            # First config flagged as default wins.
            if config.get('default', False) and not self.default:
                self.default = client
                names['default'] = name

        if not self.default:
            # BUG FIX: the original indexed list(self.clients.keys())[0],
            # which raised IndexError with no explanation when every client
            # was disabled or failed. Fail loudly with a clear message.
            if not self.clients:
                self.log.error('No usable LLM client configured')
                raise ValueError('No usable LLM client configured')
            name = next(iter(self.clients))
            self.default = self.clients[name]
            names['default'] = name

        self.current = self.default
        return names

    def __len__(self):
        return len(self.clients)

    def __repr__(self):
        return f"Current: {'default' if self.current == self.default else self.current}, Default: {self.default}"

    def __contains__(self, name):
        return name in self.clients

    def use(self, name):
        """Switch the current client to *name*; return True on success."""
        client = self.clients.get(name)
        if client and client.usable():
            self.current = client
            return True
        return False

    def get_client(self, name):
        """Return the client registered under *name*, or None."""
        return self.clients.get(name)

    def Client(self, task, context_manager):
        """Create a per-task Client facade bound to this manager."""
        return Client(self, task, context_manager)

    def to_records(self):
        """Return one LLMRecord row per registered client (for display)."""
        LLMRecord = namedtuple('LLMRecord', ['Name', 'Model', 'Max_Tokens', 'Base_URL'])
        return [LLMRecord(name, client.model, client.max_tokens, client.base_url)
                for name, client in self.clients.items()]

    def get_model_info(self, model: str):
        """Look up capability metadata for *model* in the bundled registry."""
        return self.model_registry.get_model_info(model)

181 

class Client:
    """Per-task facade over the ClientManager's LLM clients.

    Keeps its own "current" client (initialised from the manager's current
    selection) so one task can switch models without affecting others, and
    delegates message history to the externally supplied context manager.
    """

    def __init__(self, manager: ClientManager, task, context_manager):
        self.manager = manager
        self.current = manager.current
        self.task = task
        # Externally-owned conversation context (message history).
        self.context_manager = context_manager
        self.log = logger.bind(src='client', name=self.current.name)

    def add_message(self, message):
        """Append *message* to the conversation context."""
        self.context_manager.add_message(message)

    @property
    def name(self):
        """Name of the currently selected client."""
        return self.current.name

    def use(self, name):
        """Switch this task's client to *name*; return True on success."""
        client = self.manager.get_client(name)
        if client and client.usable():
            self.current = client
            self.log = logger.bind(src='client', name=self.current.name)
            return True
        return False

    def has_capability(self, content: LLMContext) -> bool:
        """Return True if the current model can handle every part of *content*."""
        # Plain-string content needs no special capability.
        if isinstance(content, str):
            return True

        # TODO: the provider kind should not be hard-coded here.
        if self.current.kind == 'trust':
            return True

        model = self.current.model
        model = model.rsplit('/', 1)[-1]  # strip a provider prefix like "org/model"
        model_info = self.manager.get_model_info(model)
        if not model_info:
            self.log.error(f"Model info not found for {model}")
            return False

        # Collect every capability the content requires.
        capabilities = set()
        for item in content:
            if item['type'] == 'image_url':
                capabilities.add(ModelCapability.IMAGE_INPUT)
            if item['type'] == 'file':
                capabilities.add(ModelCapability.FILE_INPUT)
            if item['type'] == 'text':
                capabilities.add(ModelCapability.TEXT)

        # BUG FIX: the model must support ALL capabilities the content
        # requires. The previous any(...) wrongly accepted e.g. an image
        # for a text-only model as long as the message also contained text.
        return all(capability in model_info.capabilities for capability in capabilities)

    def __call__(self, content: LLMContext, *, system_prompt=None):
        """Send *content* to the current client and return its reply message."""
        client = self.current
        stream_processor = StreamProcessor(self.task, client.name)
        # Pass the ContextManager directly; it already implements the
        # interface the client expects.
        msg = client(self.context_manager, content, system_prompt=system_prompt, stream_processor=stream_processor)
        return msg

243