Coverage for aipyapp/aipy/llm.py: 24%
177 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3from collections import defaultdict, namedtuple
5from loguru import logger
7from .. import T, __respath__
8from ..llm import CLIENTS, ModelRegistry, ModelCapability
9from .multimodal import LLMContext
class LineReceiver(list):
    """Accumulates streamed text and splits it into complete lines.

    The instance itself is the list of completed (non-empty) lines; any
    trailing partial line is held in ``buffer`` until more data arrives.
    """

    def __init__(self):
        super().__init__()
        self.buffer = ""

    @property
    def content(self):
        """All completed lines joined with newlines."""
        return '\n'.join(self)

    def feed(self, data: str):
        """Append *data* to the buffer; return the list of newly completed lines.

        Empty lines are dropped; text after the last newline stays buffered.
        """
        self.buffer += data
        fresh = []
        while '\n' in self.buffer:
            head, _, self.buffer = self.buffer.partition('\n')
            if head:
                self.append(head)
                fresh.append(head)
        return fresh

    def empty(self):
        """True when no lines have been collected and nothing is buffered."""
        return not (self or self.buffer)

    def done(self):
        """Flush any buffered partial line into the list and return it."""
        leftover, self.buffer = self.buffer, ""
        if leftover:
            self.append(leftover)
        return leftover
class StreamProcessor:
    """Streaming-response handler: splits LLM chunks into lines and emits task events."""

    def __init__(self, task, name):
        self.task = task
        self.name = name
        self.lr = LineReceiver()          # main content lines
        self.lr_reason = LineReceiver()   # reasoning ("thinking") lines

    @property
    def content(self):
        return self.lr.content

    @property
    def reason(self):
        return self.lr_reason.content

    def __enter__(self):
        # Context-manager protocol: announce that streaming has started.
        self.task.emit('stream_start', llm=self.name)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Flush any partial content line, then announce the end of the stream.
        if self.lr.buffer:
            self.process_chunk('\n')
        self.task.emit('stream_end', llm=self.name)

    def process_chunk(self, content, *, reason=False):
        """Feed one streamed chunk and emit 'stream' events for completed lines."""
        if not content:
            return

        # First regular chunk arriving after reasoning output: close the
        # reasoning section by flushing its tail line plus a separator.
        if not reason and self.lr.empty() and not self.lr_reason.empty():
            tail = self.lr_reason.done()
            if tail:
                self.task.emit('stream', llm=self.name, lines=[tail, "\n\n----\n\n"], reason=True)

        receiver = self.lr_reason if reason else self.lr
        completed = receiver.feed(content)
        if not completed:
            return

        # Hide internal marker comments from the emitted stream.
        visible = [ln for ln in completed
                   if not ln.startswith(('<!-- Block-', '<!-- Cmd-'))]
        if visible:
            self.task.emit('stream', llm=self.name, lines=visible, reason=reason)
class ClientManager(object):
    """Creates and holds all configured LLM clients; tracks the default/current one."""

    # Fallback per-client token limit when neither the client config nor the
    # global settings provide one.
    MAX_TOKENS = 8192

    def __init__(self, settings):
        self.clients = {}
        self.default = None
        self.current = None
        self.log = logger.bind(src='client_manager')
        self.names = self._init_clients(settings)
        self.model_registry = ModelRegistry(__respath__ / "models.yaml")

    def _create_client(self, config):
        """Instantiate the client class named by ``config['type']`` (default 'openai').

        Returns None (after logging) for an unknown provider type.
        """
        kind = config.get("type", "openai")
        client_class = CLIENTS.get(kind.lower())
        if not client_class:
            self.log.error('Unsupported LLM provider', kind=kind)
            return None
        return client_class(config)

    def _init_clients(self, settings):
        """Build clients from ``settings.llm``; return names grouped by status.

        Returned mapping has set values under 'enabled', 'disabled' and
        'error', plus a single name string under 'default' once one is chosen.
        """
        names = defaultdict(set)
        max_tokens = settings.get('max_tokens', self.MAX_TOKENS)
        for name, config in settings.llm.items():
            if not config.get('enable', True):
                names['disabled'].add(name)
                continue

            config['name'] = name
            try:
                client = self._create_client(config)
            except Exception:
                self.log.exception('Error creating LLM client', config=config)
                names['error'].add(name)
                continue

            if not client or not client.usable():
                names['disabled'].add(name)
                self.log.error('LLM client not usable', name=name, config=config)
                continue

            names['enabled'].add(name)
            if not client.max_tokens:
                client.max_tokens = max_tokens
            self.clients[name] = client

            if config.get('default', False) and not self.default:
                self.default = client
                names['default'] = name

        # No explicit default: fall back to the first usable client. Guard
        # against an empty dict — previously `list(self.clients.keys())[0]`
        # raised IndexError when every configured client was disabled/broken.
        if not self.default and self.clients:
            name = next(iter(self.clients))
            self.default = self.clients[name]
            names['default'] = name

        self.current = self.default
        return names

    def __len__(self):
        return len(self.clients)

    def __repr__(self):
        return f"Current: {'default' if self.current == self.default else self.current}, Default: {self.default}"

    def __contains__(self, name):
        return name in self.clients

    def use(self, name):
        """Switch the current client to *name*; return True on success."""
        client = self.clients.get(name)
        if client and client.usable():
            self.current = client
            return True
        return False

    def get_client(self, name):
        return self.clients.get(name)

    def Client(self, task, context_manager):
        """Create a per-task Client facade bound to this manager."""
        return Client(self, task, context_manager)

    def to_records(self):
        """Summarize all clients as LLMRecord namedtuples for display."""
        LLMRecord = namedtuple('LLMRecord', ['Name', 'Model', 'Max_Tokens', 'Base_URL'])
        return [
            LLMRecord(name, client.model, client.max_tokens, client.base_url)
            for name, client in self.clients.items()
        ]

    def get_model_info(self, model: str):
        return self.model_registry.get_model_info(model)
class Client:
    """Per-task facade over the manager's LLM clients.

    Holds its own "current" client (initialized from the manager's current
    client) so a task can switch models without affecting other tasks.
    """

    def __init__(self, manager: ClientManager, task, context_manager):
        self.manager = manager
        self.current = manager.current
        self.task = task

        # Context manager supplied by the caller; stores the conversation.
        self.context_manager = context_manager

        self.log = logger.bind(src='client', name=self.current.name)

    def add_message(self, message):
        """Append a message to the conversation context."""
        self.context_manager.add_message(message)

    @property
    def name(self):
        return self.current.name

    def use(self, name):
        """Switch this task's current client to *name*; return True on success."""
        client = self.manager.get_client(name)
        if client and client.usable():
            self.current = client
            self.log = logger.bind(src='client', name=self.current.name)
            return True
        return False

    def has_capability(self, content: LLMContext) -> bool:
        """Return True if the current model can handle every part of *content*."""
        # Plain-text content is always supported.
        if isinstance(content, str):
            return True

        # TODO: the provider kind should not be a hard-coded string.
        if self.current.kind == 'trust':
            return True

        model = self.current.model
        model = model.rsplit('/', 1)[-1]
        model_info = self.manager.get_model_info(model)
        if not model_info:
            self.log.error(f"Model info not found for {model}")
            return False

        capabilities = set()
        for item in content:
            if item['type'] == 'image_url':
                capabilities.add(ModelCapability.IMAGE_INPUT)
            if item['type'] == 'file':
                capabilities.add(ModelCapability.FILE_INPUT)
            if item['type'] == 'text':
                capabilities.add(ModelCapability.TEXT)

        # Require the model to support *every* capability the content needs.
        # (Previously used any(), which accepted a model supporting only one
        # of several required capabilities — e.g. text but not images.)
        return all(cap in model_info.capabilities for cap in capabilities)

    def __call__(self, content: LLMContext, *, system_prompt=None):
        """Send *content* to the current client, streaming events via the task."""
        client = self.current
        stream_processor = StreamProcessor(self.task, client.name)
        # Pass the ContextManager directly; it implements the required interface.
        msg = client(self.context_manager, content, system_prompt=system_prompt, stream_processor=stream_processor)
        return msg