Coverage for src/dataknobs_fsm/llm/base.py: 0%

104 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""Base LLM abstraction components. 

2 

3This module provides the base abstractions for unified LLM operations. 

4""" 

5 

6from abc import ABC, abstractmethod 

7from dataclasses import dataclass, field 

8from enum import Enum 

9from typing import ( 

10 Any, Dict, List, Union, AsyncIterator, Iterator, 

11 Callable, Protocol 

12) 

13from datetime import datetime 

14 

15 

class CompletionMode(Enum):
    """Enumerates the ways a completion request can be issued to an LLM.

    Each member's value is the wire-level string identifier for the mode.
    """

    # Multi-turn chat completion driven by a message history.
    CHAT = "chat"
    # Plain text continuation of a prompt.
    TEXT = "text"
    # Single-shot instruction following.
    INSTRUCT = "instruct"
    # Vector embedding generation.
    EMBEDDING = "embedding"
    # Structured function/tool calling.
    FUNCTION = "function"

23 

24 

class ModelCapability(Enum):
    """Feature flags a model/provider combination may support.

    Values are stable string identifiers suitable for serialization.
    """

    TEXT_GENERATION = "text_generation"   # freeform text output
    CHAT = "chat"                         # chat-style message exchange
    EMBEDDINGS = "embeddings"             # embedding vector generation
    FUNCTION_CALLING = "function_calling" # structured tool/function calls
    VISION = "vision"                     # image inputs
    CODE = "code"                         # code-oriented generation
    JSON_MODE = "json_mode"               # guaranteed-JSON responses
    STREAMING = "streaming"               # incremental token streaming

35 

36 

37@dataclass 

38class LLMMessage: 

39 """Represents a message in LLM conversation.""" 

40 role: str # 'system', 'user', 'assistant', 'function' 

41 content: str 

42 name: str | None = None # For function messages 

43 function_call: Dict[str, Any] | None = None # For function calling 

44 metadata: Dict[str, Any] = field(default_factory=dict) 

45 

46 

47@dataclass 

48class LLMResponse: 

49 """Response from LLM.""" 

50 content: str 

51 model: str 

52 finish_reason: str | None = None # 'stop', 'length', 'function_call' 

53 usage: Dict[str, int] | None = None # tokens used 

54 function_call: Dict[str, Any] | None = None 

55 metadata: Dict[str, Any] = field(default_factory=dict) 

56 created_at: datetime = field(default_factory=datetime.now) 

57 

58 

59@dataclass 

60class LLMStreamResponse: 

61 """Streaming response from LLM.""" 

62 delta: str # Incremental content 

63 is_final: bool = False 

64 finish_reason: str | None = None 

65 usage: Dict[str, int] | None = None 

66 metadata: Dict[str, Any] = field(default_factory=dict) 

67 

68 

@dataclass
class LLMConfig:
    """Configuration for LLM operations.

    Groups connection, generation, mode, function-calling, streaming,
    rate-limiting, and provider-specific settings into a single object.
    """
    # --- Connection ---
    provider: str                       # e.g. 'openai', 'anthropic', 'ollama'
    model: str                          # model name / identifier
    api_key: str | None = None
    api_base: str | None = None         # custom API endpoint

    # --- Generation parameters ---
    temperature: float = 0.7
    max_tokens: int | None = None
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    stop_sequences: list[str] | None = None

    # --- Mode settings ---
    mode: CompletionMode = CompletionMode.CHAT
    system_prompt: str | None = None
    response_format: str | None = None  # 'text' or 'json'

    # --- Function calling ---
    functions: list[dict[str, Any]] | None = None
    # 'auto', 'none', or a dict naming a specific function.
    function_call: str | dict[str, str] | None = None

    # --- Streaming ---
    stream: bool = False
    stream_callback: Callable[[LLMStreamResponse], None] | None = None

    # --- Rate limiting / retries ---
    rate_limit: int | None = None       # requests per minute
    retry_count: int = 3
    retry_delay: float = 1.0
    timeout: float = 60.0

    # --- Advanced settings ---
    seed: int | None = None             # for reproducibility
    logit_bias: dict[str, float] | None = None
    user_id: str | None = None

    # --- Provider-specific options ---
    options: dict[str, Any] = field(default_factory=dict)

111 

112 

class LLMProvider(ABC):
    """Abstract base for all LLM provider implementations.

    Concrete subclasses wire up a client for a specific backend and
    report which capabilities the configured model supports.  The class
    also acts as a synchronous context manager: it initializes on entry
    and closes on exit.
    """

    def __init__(self, config: LLMConfig):
        """Store *config*; the client is created lazily by ``initialize``."""
        self.config = config
        self._client = None
        self._is_initialized = False

    @property
    def is_initialized(self) -> bool:
        """Whether ``initialize`` has completed (and ``close`` has not)."""
        return self._is_initialized

    @abstractmethod
    def initialize(self) -> None:
        """Set up the underlying LLM client."""
        pass

    @abstractmethod
    def close(self) -> None:
        """Tear down the underlying LLM client."""
        pass

    @abstractmethod
    def validate_model(self) -> bool:
        """Return True if the configured model is available."""
        pass

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities supported by the configured model."""
        pass

    def __enter__(self):
        """Initialize the provider on context entry and return it."""
        self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the provider on context exit."""
        self.close()

155 

156 

class AsyncLLMProvider(LLMProvider):
    """Async LLM provider interface.

    Subclasses implement the four asynchronous operations below.
    Lifecycle must be managed asynchronously — via ``async with`` or by
    awaiting ``initialize()``/``close()`` explicitly.  The synchronous
    ``with`` statement inherited from ``LLMProvider`` is disabled here
    (see ``__enter__``).
    """

    @abstractmethod
    async def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion asynchronously.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters

        Returns:
            LLM response
        """
        pass

    @abstractmethod
    async def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> AsyncIterator[LLMStreamResponse]:
        """Generate streaming completion asynchronously.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters

        Yields:
            Streaming response chunks
        """
        pass

    @abstractmethod
    async def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings asynchronously.

        Args:
            texts: Input text(s)
            **kwargs: Additional parameters

        Returns:
            Embedding vector(s): one vector for a single input, a list
            of vectors for a list of inputs.
        """
        pass

    @abstractmethod
    async def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling asynchronously.

        Args:
            messages: Conversation messages
            functions: Available functions
            **kwargs: Additional parameters

        Returns:
            Response with function call
        """
        pass

    async def initialize(self) -> None:
        """Initialize the async LLM client."""
        self._is_initialized = True

    async def close(self) -> None:
        """Close the async LLM client."""
        self._is_initialized = False

    def __enter__(self):
        """Disallow the synchronous ``with`` statement.

        The inherited ``LLMProvider.__enter__`` would call
        ``self.initialize()`` without awaiting it, leaking an un-awaited
        coroutine and leaving the provider uninitialized.  Fail loudly
        instead of silently misbehaving.

        Raises:
            TypeError: Always. Use ``async with`` instead.
        """
        raise TypeError(
            f"{type(self).__name__} must be used with 'async with', not 'with'"
        )

    async def __aenter__(self):
        """Async context manager entry: initialize and return self."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the client."""
        await self.close()

246 

247 

class SyncLLMProvider(LLMProvider):
    """Synchronous LLM provider interface.

    Provides default no-client lifecycle handling; subclasses implement
    the four blocking operations (``complete``, ``stream_complete``,
    ``embed``, ``function_call``).
    """

    def initialize(self) -> None:
        """Initialize the sync LLM client."""
        self._is_initialized = True

    def close(self) -> None:
        """Close the sync LLM client."""
        self._is_initialized = False

    @abstractmethod
    def complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> LLMResponse:
        """Generate completion synchronously.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters

        Returns:
            LLM response
        """
        pass

    @abstractmethod
    def stream_complete(
        self,
        messages: Union[str, List[LLMMessage]],
        **kwargs
    ) -> Iterator[LLMStreamResponse]:
        """Generate streaming completion synchronously.

        Args:
            messages: Input messages or prompt
            **kwargs: Additional parameters

        Yields:
            Streaming response chunks
        """
        pass

    @abstractmethod
    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> Union[List[float], List[List[float]]]:
        """Generate embeddings synchronously.

        Args:
            texts: Input text(s)
            **kwargs: Additional parameters

        Returns:
            Embedding vector(s): one vector for a single input, a list
            of vectors for a list of inputs.
        """
        pass

    @abstractmethod
    def function_call(
        self,
        messages: List[LLMMessage],
        functions: List[Dict[str, Any]],
        **kwargs
    ) -> LLMResponse:
        """Execute function calling synchronously.

        Args:
            messages: Conversation messages
            functions: Available functions
            **kwargs: Additional parameters

        Returns:
            Response with function call
        """
        pass

328 

329 

class LLMAdapter(ABC):
    """Converts between this library's message/response/config types and
    a specific provider's native formats.
    """

    @abstractmethod
    def adapt_messages(self, messages: List[LLMMessage]) -> Any:
        """Translate standard messages into the provider's request format."""
        pass

    @abstractmethod
    def adapt_response(self, response: Any) -> LLMResponse:
        """Translate a provider-native response into an ``LLMResponse``."""
        pass

    @abstractmethod
    def adapt_config(self, config: LLMConfig) -> Dict[str, Any]:
        """Translate an ``LLMConfig`` into provider-specific parameters."""
        pass

356 

357 

class LLMMiddleware(Protocol):
    """Structural interface for request/response middleware.

    Any object providing these two async methods can be used as
    middleware — no inheritance required.
    """

    async def process_request(
        self,
        messages: List[LLMMessage],
        config: LLMConfig
    ) -> List[LLMMessage]:
        """Transform the outgoing messages before they reach the LLM."""
        ...

    async def process_response(
        self,
        response: LLMResponse,
        config: LLMConfig
    ) -> LLMResponse:
        """Transform the LLM's response before it reaches the caller."""
        ...