Coverage for aipyapp/llm/base.py: 47%

86 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-11 12:02 +0200

1#! /usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3 

import time
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from loguru import logger

from .. import T

13 

@dataclass
class ChatMessage:
    """A single chat message exchanged with an LLM, with usage accounting."""

    # Message role, e.g. 'system' / 'user' / 'assistant' (or 'error' for failed calls).
    role: str
    # Either plain text or a list of structured content parts.
    content: Union[str, List[Dict[str, Any]]]
    # Optional model "reasoning" text; None when the provider returned none.
    # (Was annotated `str` with a None default — fixed to Optional[str].)
    reason: Optional[str] = None
    # Token/time usage counters for this message.
    usage: Counter = field(default_factory=Counter)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict for serialization."""
        return {
            '__type__': 'ChatMessage',
            'role': self.role,
            'content': self.content,
            'reason': self.reason,
            'usage': dict(self.usage) if self.usage else {}
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ChatMessage':
        """Restore a ChatMessage from a dict produced by :meth:`to_dict`.

        Missing keys fall back to sensible defaults; `reason` defaults to
        None (not '') so that a to_dict/from_dict round trip is lossless
        and matches the dataclass field default.
        """
        return cls(
            role=data.get('role', 'assistant'),
            content=data.get('content', ''),
            reason=data.get('reason'),
            usage=Counter(data.get('usage', {}))
        )

40 

class BaseClient(ABC):
    """Abstract base for LLM provider clients.

    Holds the shared configuration parsed from a config mapping and drives
    a single request/response cycle via :meth:`__call__`. Subclasses supply
    the provider-specific pieces: :meth:`get_completion`, :meth:`_parse_usage`,
    :meth:`_parse_stream_response` and :meth:`_parse_response`.
    """

    # Subclass-overridable defaults, used when the config omits a value.
    MODEL = None
    BASE_URL = None
    TEMPERATURE = 0.5

    def __init__(self, config):
        """Initialize from a config mapping.

        Required key: 'name'. Optional keys: 'type' (default "openai"),
        'max_tokens', 'model', 'timeout', 'api_key', 'base_url',
        'stream' (default True), 'tls_verify' (default True), 'params'.
        """
        self.name = config['name']
        self.log = logger.bind(src='llm', name=self.name)
        self.console = None
        self.config = config
        self.kind = config.get("type", "openai")
        self.max_tokens = config.get("max_tokens")
        self._model = config.get("model") or self.MODEL
        self._timeout = config.get("timeout")
        self._api_key = config.get("api_key")
        self._base_url = self.get_base_url()
        self._stream = config.get("stream", True)
        self._tls_verify = bool(config.get("tls_verify", True))
        self._client = None
        # Subclass defaults first, then user-supplied params override them.
        params = self.get_params()
        params.update(config.get("params", {}))
        self._params = params
        # Explicit None check: a configured temperature of 0 (or 0.0) is a
        # valid setting and must not be silently replaced by the class
        # default, as the previous `params.get(...) or self.TEMPERATURE` did.
        temperature = params.get("temperature")
        self._temperature = self.TEMPERATURE if temperature is None else temperature

    @property
    def model(self):
        """The effective model name (config value or class default)."""
        return self._model

    @property
    def base_url(self):
        """The effective API base URL (config value or class default)."""
        return self._base_url

    def __repr__(self):
        return f"{self.name}/{self.kind}:{self._model}"

    def get_params(self):
        """Default request parameters; subclasses may override."""
        return {}

    def get_base_url(self):
        """Return the configured base URL, falling back to BASE_URL."""
        return self.config.get("base_url") or self.BASE_URL

    def usable(self):
        """Truthy when a model is configured (returns the model name itself)."""
        return self._model

    def _get_client(self):
        """Return the lazily-created underlying SDK client, if any."""
        return self._client

    @abstractmethod
    def get_completion(self, messages):
        """Send *messages* to the provider and return its raw response."""
        pass

    def add_system_prompt(self, history, system_prompt):
        """Prepend the system prompt to *history*; subclasses may override."""
        history.add("system", system_prompt)

    @abstractmethod
    def _parse_usage(self, response):
        """Extract a usage Counter from a raw provider response."""
        pass

    @abstractmethod
    def _parse_stream_response(self, response, stream_processor):
        """Consume a streaming response and return a ChatMessage."""
        pass

    @abstractmethod
    def _parse_response(self, response):
        """Parse a non-streaming response and return a ChatMessage."""
        pass

    def __call__(self, history, prompt, system_prompt=None, stream_processor=None):
        """Run one round trip: append *prompt* to *history*, call the API,
        parse the reply, record timing, and return the ChatMessage.

        On API failure, logs the error and returns a ChatMessage with
        role='error' instead of raising.
        """
        # We shall only send system prompt once
        if not history and system_prompt:
            self.add_system_prompt(history, system_prompt)
        history.add("user", prompt)

        start = time.time()
        try:
            response = self.get_completion(history.get_messages())
        except Exception as e:
            self.log.error(f"❌ [bold red]{self.name} API {T('Call failed')}: [yellow]{str(e)}")
            return ChatMessage(role='error', content=str(e))

        if self._stream:
            msg = self._parse_stream_response(response, stream_processor)
        else:
            msg = self._parse_response(response)

        # Wall-clock seconds for this call, stored alongside token usage.
        msg.usage['time'] = round(time.time() - start, 3)
        history.add_message(msg)
        return msg

128