Coverage for aipyapp/llm/base.py: 47%
86 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-11 12:02 +0200
1#! /usr/bin/env python3
2# -*- coding: utf-8 -*-
4import time
5from collections import Counter
6from abc import ABC, abstractmethod
7from dataclasses import dataclass, field
8from typing import Union, List, Dict, Any
10from loguru import logger
12from .. import T
@dataclass
class ChatMessage:
    """A single chat message plus its metadata.

    Instances can be turned into plain dicts with ``to_dict`` and rebuilt
    with ``from_dict``.
    """

    # Message role; values used in this module include 'system', 'user',
    # 'assistant' (the from_dict fallback) and 'error'.
    role: str
    # Either plain text or a list of content-part dicts.
    content: Union[str, List[Dict[str, Any]]]
    # Optional reasoning text attached to the message; None when absent.
    reason: Union[str, None] = None
    # Per-message counters (e.g. usage figures, elapsed 'time').
    usage: Counter = field(default_factory=Counter)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the message to a plain dict for serialization.

        The result is tagged with ``'__type__': 'ChatMessage'``. ``usage`` is
        converted to a plain ``dict`` (``{}`` when empty or unset).
        ``content`` is not copied, so a list payload is shared with the
        message.
        """
        return {
            '__type__': 'ChatMessage',
            'role': self.role,
            'content': self.content,
            'reason': self.reason,
            'usage': dict(self.usage) if self.usage else {},
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ChatMessage':
        """Rebuild a ChatMessage from a dict such as the one ``to_dict`` returns.

        Missing keys fall back to defaults: role ``'assistant'``, content
        ``''``, reason ``None`` (the same as the field default) and an empty
        ``usage``. A ``usage`` value of ``None`` is treated as empty.
        """
        return cls(
            role=data.get('role', 'assistant'),
            content=data.get('content', ''),
            # Keep legacy dicts without 'reason' consistent with new messages.
            reason=data.get('reason'),
            usage=Counter(data.get('usage') or {}),
        )
class BaseClient(ABC):
    """Abstract base class for LLM provider clients.

    Subclasses implement ``get_completion`` (send the message list to the
    backend) and the ``_parse_*`` hooks (turn a raw backend response into a
    message object exposing a ``usage`` mapping). Calling an instance runs
    one user turn against a conversation history.
    """

    # Provider defaults; subclasses are expected to override as needed.
    MODEL = None
    BASE_URL = None
    # Used only when the merged params do not set "temperature".
    TEMPERATURE = 0.5

    def __init__(self, config):
        """Initialise the client from a provider config mapping.

        Args:
            config: mapping with a required ``'name'`` key and optional
                ``'type'`` (default ``'openai'``), ``'model'``, ``'base_url'``,
                ``'api_key'``, ``'timeout'``, ``'max_tokens'``, ``'stream'``
                (default True), ``'tls_verify'`` (default True) and
                ``'params'`` (extra request parameters merged over
                ``get_params()``).

        Raises:
            KeyError: if ``config`` has no ``'name'`` key.
        """
        self.name = config['name']
        self.log = logger.bind(src='llm', name=self.name)
        self.console = None
        self.config = config
        self.kind = config.get("type", "openai")
        self.max_tokens = config.get("max_tokens")
        self._model = config.get("model") or self.MODEL
        self._timeout = config.get("timeout")
        self._api_key = config.get("api_key")
        # Must run after self.config is assigned: get_base_url() reads it.
        self._base_url = self.get_base_url()
        self._stream = config.get("stream", True)
        self._tls_verify = bool(config.get("tls_verify", True))
        # Backend SDK client; None here (presumably set up by subclasses).
        self._client = None
        # Copy so merging user params never mutates a dict that a subclass's
        # get_params() might return from shared state.
        params = dict(self.get_params() or {})
        # An explicit `params = None` in the config must not crash update().
        params.update(config.get("params") or {})
        self._params = params
        # Fall back only when unset: 0 is a valid (but falsy) temperature.
        temperature = params.get("temperature")
        self._temperature = self.TEMPERATURE if temperature is None else temperature

    @property
    def model(self):
        """The configured model name (config "model" or the class MODEL)."""
        return self._model

    @property
    def base_url(self):
        """The base URL resolved by ``get_base_url()`` at construction time."""
        return self._base_url

    def __repr__(self):
        return f"{self.name}/{self.kind}:{self._model}"

    def get_params(self):
        """Return provider default request params; the base class has none."""
        return {}

    def get_base_url(self):
        """Return config "base_url" if set (non-empty), else the class BASE_URL."""
        return self.config.get("base_url") or self.BASE_URL

    def usable(self):
        """Return the model name; truthy only when a model is configured."""
        return self._model

    def _get_client(self):
        """Return the stored backend client (may be None)."""
        return self._client

    @abstractmethod
    def get_completion(self, messages):
        """Send ``messages`` to the backend and return its raw response."""
        pass

    def add_system_prompt(self, history, system_prompt):
        """Record ``system_prompt`` in ``history`` as a "system" message.

        This is an overridable hook for providers that handle system prompts
        differently.
        """
        history.add("system", system_prompt)

    @abstractmethod
    def _parse_usage(self, response):
        """Extract usage figures from a raw backend response."""
        pass

    @abstractmethod
    def _parse_stream_response(self, response, stream_processor):
        """Consume a streamed response and return the resulting message."""
        pass

    @abstractmethod
    def _parse_response(self, response):
        """Convert a non-streamed response into a message."""
        pass

    def __call__(self, history, prompt, system_prompt=None, stream_processor=None):
        """Send ``prompt`` as the next user turn and record the reply.

        Args:
            history: conversation store providing ``add(role, content)``,
                ``get_messages()`` and ``add_message(msg)``. Its truthiness
                tells whether it already holds messages.
            prompt: content of the user message to append.
            system_prompt: added only when ``history`` is empty, so it is
                sent at most once per conversation.
            stream_processor: forwarded to ``_parse_stream_response`` when
                streaming is enabled.

        Returns:
            The parsed message, appended to ``history``, with
            ``usage['time']`` set to the elapsed seconds (rounded to ms). If
            the request or the response parsing raises, an error is logged
            and ``ChatMessage(role='error', content=<error text>)`` is
            returned instead; that error message is not added to ``history``.
        """
        # We shall only send system prompt once
        if not history and system_prompt:
            self.add_system_prompt(history, system_prompt)
        history.add("user", prompt)

        start = time.time()
        try:
            response = self.get_completion(history.get_messages())
            # A streamed response may only be read while parsing, so
            # transport errors can surface here; report them the same way
            # as request failures.
            if self._stream:
                msg = self._parse_stream_response(response, stream_processor)
            else:
                msg = self._parse_response(response)
        except Exception as e:
            self.log.error(f"❌ [bold red]{self.name} API {T('Call failed')}: [yellow]{str(e)}")
            return ChatMessage(role='error', content=str(e))

        msg.usage['time'] = round(time.time() - start, 3)
        history.add_message(msg)
        return msg