Coverage for src/chat_limiter/batch.py: 77%

298 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-18 21:15 +0100

1""" 

2Batch processing functionality for handling multiple requests efficiently. 

3""" 

4 

5import asyncio 

6import logging 

7import traceback 

8from abc import ABC, abstractmethod 

9from collections.abc import Callable 

10from concurrent.futures import ThreadPoolExecutor, as_completed 

11from dataclasses import dataclass, field 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Generic, 

16 TypeVar, 

17) 

18 

19import httpx 

20from tqdm import tqdm 

21 

22if TYPE_CHECKING: 

23 pass 

24 

25from .limiter import ChatLimiter 

26from .types import ChatCompletionRequest, ChatCompletionResponse 

27 

28logger = logging.getLogger(__name__) 

29 

30# Type variables for generic batch processing 

31BatchItemT = TypeVar("BatchItemT") 

32BatchResultT = TypeVar("BatchResultT") 

33 

34 

35@dataclass 

36class BatchConfig: 

37 """Configuration for batch processing.""" 

38 

39 # Concurrency settings 

40 max_concurrent_requests: int = 10 

41 max_workers: int = 4 # For sync processing 

42 

43 # Retry settings 

44 max_retries_per_item: int = 3 

45 retry_delay: float = 1.0 

46 

47 # Progress tracking 

48 show_progress: bool = True 

49 progress_desc: str = "Processing batch" 

50 

51 # Error handling 

52 stop_on_first_error: bool = False 

53 collect_errors: bool = True 

54 

55 # Fine-grained logging configuration 

56 print_prompts: bool = False 

57 print_responses: bool = False 

58 verbose_timeouts: bool = False 

59 verbose_exceptions: bool = False 

60 print_rate_limits: bool = False 

61 print_request_initiation: bool = False 

62 

63 # Response format configuration 

64 json_mode: bool = False 

65 

66 # Reasoning configuration (for thinking models like o1, o3, o4) 

67 reasoning_effort: str | None = None # None, "low", "medium", or "high" 

68 

69 # Batch size optimization 

70 adaptive_batch_size: bool = True 

71 min_batch_size: int = 1 

72 max_batch_size: int = 100 

73 

74 # Request grouping 

75 group_by_model: bool = True 

76 group_by_provider: bool = True 

77 

78 def __post_init__(self): 

79 """Validate configuration after initialization.""" 

80 if self.reasoning_effort is not None: 

81 valid_efforts = {"low", "medium", "high"} 

82 if self.reasoning_effort not in valid_efforts: 

83 raise ValueError( 

84 f"reasoning_effort must be one of {valid_efforts} or None, " 

85 f"got: {self.reasoning_effort}" 

86 ) 

87 

88 

@dataclass
class BatchItem(Generic[BatchItemT]):
    """A single item in a batch request.

    Carries the caller's payload, the HTTP request shape to send for it,
    and mutable bookkeeping that the retry loop updates in place.
    """

    # Item data (the caller's payload, e.g. a ChatCompletionRequest)
    data: BatchItemT

    # Request configuration
    method: str = "POST"
    url: str = "/chat/completions"
    json_data: dict[str, Any] | None = None  # request body; also used for grouping by model

    # Metadata
    id: str | None = None  # assigned as "item_{index}" by create_batch_items
    metadata: dict[str, Any] = field(default_factory=dict)

    # Processing state (mutated by the retry logic during processing)
    attempt_count: int = 0
    last_error: Exception | None = None

108 

109 

@dataclass
class BatchResult(Generic[BatchResultT]):
    """Result of processing a batch item.

    Produced exactly once per input item, whether processing succeeded or
    exhausted its retries; check ``success`` before reading ``result``.
    """

    # Original item (retains its final attempt_count / last_error state)
    item: "BatchItem[Any]"

    # Result data (None when success is False)
    result: BatchResultT | None = None

    # Processing metadata
    duration: float = 0.0  # wall-clock seconds across all attempts
    attempt_count: int = 0

    # Error information
    success: bool = True
    error_message: str | None = None  # stringified final exception on failure

    # Response metadata
    response_headers: dict[str, str] = field(default_factory=dict)
    status_code: int | None = None

131 

132 

class BatchProcessor(ABC, Generic[BatchItemT, BatchResultT]):
    """Abstract base class for batch processing.

    Subclasses implement :meth:`process_item` (async) and
    :meth:`process_item_sync`; this base supplies item creation, optional
    grouping, bounded concurrency, retry with exponential backoff,
    tqdm progress reporting, and result/statistics collection.
    """

    def __init__(
        self,
        limiter: ChatLimiter,
        config: BatchConfig | None = None,
    ):
        """Store the limiter and config, and wire batch logging flags
        through to the limiter.

        Args:
            limiter: Configured ChatLimiter used to execute requests.
            config: Batch settings; a default BatchConfig is used if None.
        """
        self.limiter = limiter
        self.config = config or BatchConfig()
        self._results: list[BatchResult[BatchResultT]] = []
        self._errors: list[Exception] = []

        # Configure limiter logging based on batch config
        self.limiter.set_print_rate_limit_info(self.config.print_rate_limits)
        self.limiter.set_print_request_initiation(self.config.print_request_initiation)

    @abstractmethod
    async def process_item(self, item: BatchItem[BatchItemT]) -> BatchResultT:
        """Process a single batch item. Must raise on failure so the
        retry wrapper can handle it."""
        pass

    @abstractmethod
    def process_item_sync(self, item: BatchItem[BatchItemT]) -> BatchResultT:
        """Process a single batch item synchronously. Must raise on
        failure so the retry wrapper can handle it."""
        pass

    def create_batch_items(
        self,
        items: list[BatchItemT],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchItem[BatchItemT]]:
        """Create batch items from raw data.

        Args:
            items: Raw payloads, one per batch item.
            request_fn: Optional mapper returning (method, url, json_data)
                for each payload; when omitted, BatchItem defaults apply.

        Returns:
            BatchItem wrappers with ids "item_0", "item_1", ...
        """
        batch_items = []

        for i, item in enumerate(items):
            batch_item = BatchItem(
                data=item,
                id=f"item_{i}",
            )

            # Configure request if function provided
            if request_fn:
                method, url, json_data = request_fn(item)
                batch_item.method = method
                batch_item.url = url
                batch_item.json_data = json_data

            batch_items.append(batch_item)

        return batch_items

    async def process_batch(
        self,
        items: list[BatchItemT] | list[BatchItem[BatchItemT]],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchResult[BatchResultT]]:
        """Process a batch of items asynchronously.

        Accepts either raw payloads (wrapped via create_batch_items) or
        pre-built BatchItems. Items may be grouped by model/provider; each
        group is processed with at most ``max_concurrent_requests``
        in-flight tasks. One BatchResult is returned per input item.
        """
        # Convert to batch items if needed; only the first element is
        # inspected, so lists must be homogeneous.
        if items and not isinstance(items[0], BatchItem):
            batch_items = self.create_batch_items(items, request_fn)  # type: ignore
        else:
            batch_items = items  # type: ignore

        # Group items if configured
        if self.config.group_by_model or self.config.group_by_provider:
            grouped_items = self._group_items(batch_items)
        else:
            grouped_items = {"default": batch_items}

        # Process groups
        all_results = []

        # Calculate total items for progress tracking
        total_items = sum(len(group_items) for group_items in grouped_items.values())

        # Initialize progress bar if enabled (shared across all groups)
        progress_bar = None
        if self.config.show_progress:
            progress_bar = tqdm(
                total=total_items,
                desc=self.config.progress_desc,
                unit="item"
            )

        for group_name, group_items in grouped_items.items():
            logger.info(
                f"Processing group '{group_name}' with {len(group_items)} items"
            )

            # Create semaphore for concurrency control (fresh per group)
            semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)

            # Process items with concurrency control and progress tracking
            tasks = [
                self._process_item_with_retry(item, semaphore, progress_bar) for item in group_items
            ]

            # Wait for all tasks to complete; return_exceptions=True keeps
            # results positionally aligned with group_items.
            group_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Handle exceptions from gather (anything that escaped the
            # retry wrapper itself, e.g. CancelledError)
            for i, result in enumerate(group_results):
                if isinstance(result, Exception):
                    # Create error result
                    error_result: BatchResult[BatchResultT] = BatchResult(
                        item=group_items[i],
                        success=False,
                        error_message=str(result),
                        attempt_count=group_items[i].attempt_count,
                    )
                    all_results.append(error_result)
                else:
                    all_results.append(result)  # type: ignore

        # Close progress bar if it was created
        if progress_bar:
            progress_bar.close()

        self._results = all_results
        return all_results

    def process_batch_sync(
        self,
        items: list[BatchItemT] | list[BatchItem[BatchItemT]],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchResult[BatchResultT]]:
        """Process a batch of items synchronously.

        Mirrors :meth:`process_batch` using a thread pool of
        ``max_workers`` threads per group. NOTE: results are appended in
        completion order (as_completed), so ordering may differ from the
        input order, unlike the async path.
        """
        # Convert to batch items if needed
        if items and not isinstance(items[0], BatchItem):
            batch_items = self.create_batch_items(items, request_fn)  # type: ignore
        else:
            batch_items = items  # type: ignore

        # Group items if configured
        if self.config.group_by_model or self.config.group_by_provider:
            grouped_items = self._group_items(batch_items)
        else:
            grouped_items = {"default": batch_items}

        # Calculate total items for progress tracking
        total_items = sum(len(group_items) for group_items in grouped_items.values())

        # Initialize progress bar if enabled
        progress_bar = None
        if self.config.show_progress:
            progress_bar = tqdm(
                total=total_items,
                desc=self.config.progress_desc,
                unit="item"
            )

        # Process groups
        all_results = []
        for group_name, group_items in grouped_items.items():
            logger.info(
                f"Processing group '{group_name}' with {len(group_items)} items"
            )

            # Use ThreadPoolExecutor for concurrent processing
            with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
                # Submit all tasks
                future_to_item = {
                    executor.submit(self._process_item_sync_with_retry, item, progress_bar): item
                    for item in group_items
                }

                # Collect results as they finish
                for future in as_completed(future_to_item):
                    item = future_to_item[future]
                    try:
                        result = future.result()
                        all_results.append(result)
                    except Exception as e:
                        # Worker raised past its retry wrapper; record as failure
                        error_result: BatchResult[BatchResultT] = BatchResult(
                            item=item,
                            success=False,
                            error_message=str(e),
                            attempt_count=item.attempt_count,
                        )
                        all_results.append(error_result)

        # Close progress bar if it was created
        if progress_bar:
            progress_bar.close()

        self._results = all_results
        return all_results

    def _group_items(
        self, items: list[BatchItem[BatchItemT]]
    ) -> dict[str, list[BatchItem[BatchItemT]]]:
        """Group items by model or provider.

        Model grouping (when enabled and the item's json_data carries a
        "model" key) takes precedence over provider grouping; otherwise
        items fall into the "default" group.
        """
        groups: dict[str, list[BatchItem[BatchItemT]]] = {}

        for item in items:
            # Determine group key
            group_key = "default"

            if (
                self.config.group_by_model
                and item.json_data
                and "model" in item.json_data
            ):
                group_key = item.json_data["model"]
            elif self.config.group_by_provider:
                group_key = self.limiter.provider.value

            # Add to group
            if group_key not in groups:
                groups[group_key] = []
            groups[group_key].append(item)

        return groups

    async def _process_item_with_retry(
        self,
        item: BatchItem[BatchItemT],
        semaphore: asyncio.Semaphore,
        progress_bar: tqdm | None = None,
    ) -> BatchResult[BatchResultT]:
        """Process a single item with retry logic.

        The semaphore is held for the item's entire retry cycle
        (including backoff sleeps), bounding total in-flight items.
        Timeout-like errors back off at 3**attempt, other errors at
        2**attempt, times config.retry_delay.
        """
        async with semaphore:
            import time

            start_time = time.time()

            for attempt in range(self.config.max_retries_per_item + 1):
                item.attempt_count = attempt + 1

                try:
                    # Print request initiation if enabled
                    if self.config.print_request_initiation:
                        print(f"Sent request for batch item {item.id} (attempt {attempt + 1})")

                    # Process the item
                    result = await self.process_item(item)

                    # Update progress bar on success
                    if progress_bar:
                        progress_bar.update(1)

                    # Success
                    return BatchResult(
                        item=item,
                        result=result,
                        success=True,
                        duration=time.time() - start_time,
                        attempt_count=item.attempt_count,
                    )

                except Exception as e:
                    item.last_error = e

                    # Check if this is a timeout error; heuristic also
                    # matches wrapped causes and "timeout" in the message.
                    is_timeout_error = (
                        isinstance(e, (httpx.ReadTimeout, httpx.ConnectTimeout)) or
                        (hasattr(e, '__cause__') and isinstance(e.__cause__, (httpx.ReadTimeout, httpx.ConnectTimeout))) or
                        'ReadTimeout' in str(type(e)) or 'timeout' in str(e).lower()
                    )

                    # Print user-friendly error messages based on config
                    if is_timeout_error and self.config.verbose_timeouts:
                        # Get current timeout from the limiter
                        current_timeout = getattr(self.limiter, '_user_timeout', 120.0)
                        print(f"⏱️ TIMEOUT ERROR in batch item {item.id} (attempt {attempt + 1}):")
                        print(f" Current timeout setting: {current_timeout} seconds")
                        print(f" The request took longer than {current_timeout}s to complete.")
                        print("")
                        print("💡 How to fix this:")
                        print(f" 1. Increase timeout: ChatLimiter.for_model('{getattr(self.limiter, 'provider', 'your-model')}', timeout={current_timeout + 60})")
                        print(f" 2. Reduce concurrency: BatchConfig(max_concurrent_requests={max(1, self.config.max_concurrent_requests // 2)})")
                        print(f" 3. Current concurrency: {self.config.max_concurrent_requests} requests")
                        print("")
                    elif not is_timeout_error and self.config.verbose_exceptions:
                        print(f"❌ Exception in batch item {item.id} (attempt {attempt + 1}):")

                    if self.config.verbose_exceptions:
                        traceback.print_exc()

                    # If this is the last attempt or we should stop on error
                    if (
                        attempt == self.config.max_retries_per_item
                        or self.config.stop_on_first_error
                    ):
                        # Update progress bar on final failure
                        if progress_bar:
                            progress_bar.update(1)

                        return BatchResult(
                            item=item,
                            success=False,
                            error_message=str(e),
                            duration=time.time() - start_time,
                            attempt_count=item.attempt_count,
                        )

                    # Wait before retry - longer for timeout errors
                    if is_timeout_error:
                        # For timeout errors, wait longer and suggest more aggressive backing off
                        retry_delay = self.config.retry_delay * (3**attempt)  # More aggressive backoff
                    else:
                        retry_delay = self.config.retry_delay * (2**attempt)

                    await asyncio.sleep(retry_delay)

            # This should never be reached, but added for type checking
            return BatchResult(
                item=item,
                success=False,
                error_message="Unexpected error in retry logic",
                duration=time.time() - start_time,
                attempt_count=item.attempt_count,
            )

    def _process_item_sync_with_retry(
        self,
        item: BatchItem[BatchItemT],
        progress_bar: tqdm | None = None,
    ) -> BatchResult[BatchResultT]:
        """Process a single item with retry logic (sync).

        Simpler than the async variant: uniform 2**attempt backoff and no
        timeout-specific messaging.
        """
        import time

        start_time = time.time()

        for attempt in range(self.config.max_retries_per_item + 1):
            item.attempt_count = attempt + 1

            try:
                # Print request initiation if enabled
                if self.config.print_request_initiation:
                    print(f"Sent request for batch item {item.id} (attempt {attempt + 1})")

                # Process the item
                result = self.process_item_sync(item)

                # Update progress bar on success
                if progress_bar:
                    progress_bar.update(1)

                # Success
                return BatchResult(
                    item=item,
                    result=result,
                    success=True,
                    duration=time.time() - start_time,
                    attempt_count=item.attempt_count,
                )

            except Exception as e:
                item.last_error = e

                # Print traceback if verbose exceptions is enabled
                if self.config.verbose_exceptions:
                    print(f"Exception in batch item {item.id} (attempt {attempt + 1}):")
                    traceback.print_exc()

                # If this is the last attempt or we should stop on error
                if (
                    attempt == self.config.max_retries_per_item
                    or self.config.stop_on_first_error
                ):
                    # Update progress bar on final failure
                    if progress_bar:
                        progress_bar.update(1)

                    return BatchResult(
                        item=item,
                        success=False,
                        error_message=str(e),
                        duration=time.time() - start_time,
                        attempt_count=item.attempt_count,
                    )

                # Wait before retry
                time.sleep(self.config.retry_delay * (2**attempt))

        # This should never be reached, but added for type checking
        return BatchResult(
            item=item,
            success=False,
            error_message="Unexpected error in retry logic",
            duration=time.time() - start_time,
            attempt_count=item.attempt_count,
        )

    def get_success_rate(self) -> float:
        """Get the success rate of the last batch (0.0 if none run)."""
        if not self._results:
            return 0.0

        successful = sum(1 for r in self._results if r.success)
        return successful / len(self._results)

    def get_successful_results(self) -> list[BatchResult[BatchResultT]]:
        """Get only successful results from the last batch."""
        return [r for r in self._results if r.success]

    def get_failed_results(self) -> list[BatchResult[BatchResultT]]:
        """Get only failed results from the last batch."""
        return [r for r in self._results if not r.success]

    def get_stats(self) -> dict[str, Any]:
        """Get comprehensive processing statistics.

        Returns:
            Dict with total/successful/failed counts, success_rate, and
            duration/attempt aggregates (zeros when no batch has run).
        """
        if not self._results:
            return {"total": 0, "successful": 0, "failed": 0, "success_rate": 0.0}

        successful = self.get_successful_results()
        failed = self.get_failed_results()

        # Calculate timing statistics
        durations = [r.duration for r in self._results]
        avg_duration = sum(durations) / len(durations) if durations else 0

        return {
            "total": len(self._results),
            "successful": len(successful),
            "failed": len(failed),
            "success_rate": self.get_success_rate(),
            "avg_duration": avg_duration,
            "total_duration": sum(durations),
            "avg_attempts": sum(r.attempt_count for r in self._results)
            / len(self._results),
        }

557 

# High-level chat completion batch processing
class ChatCompletionBatchProcessor(BatchProcessor[ChatCompletionRequest, ChatCompletionResponse]):
    """High-level batch processor for chat completion requests.

    Translates each ChatCompletionRequest into a call on the limiter's
    high-level chat-completion API, with optional prompt/response logging
    and JSON response mode. The async and sync paths share the same
    validation, kwargs-building, and logging helpers.
    """

    def _check_json_mode(self) -> None:
        """Validate that json_mode is used with a supporting provider.

        Raises an explicit ValueError (not ``assert``, which would be
        stripped under ``python -O``) when json_mode is enabled for a
        non-OpenAI provider.
        """
        if self.config.json_mode and self.limiter.provider.value != "openai":
            raise ValueError(
                f"json_mode is only supported with OpenAI provider, "
                f"but got '{self.limiter.provider.value}'"
            )

    def _build_kwargs(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Build the keyword arguments for the limiter's chat completion call."""
        kwargs: dict[str, Any] = {
            "model": request.model,
            "messages": request.messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "stop": request.stop,
            "stream": request.stream,
            # Provider-specific parameters
            "frequency_penalty": request.frequency_penalty,
            "presence_penalty": request.presence_penalty,
            "top_k": request.top_k,
            "reasoning_effort": self.config.reasoning_effort,
        }

        # Add json_mode if enabled
        if self.config.json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        return kwargs

    def _log_prompt(self, item: BatchItem[ChatCompletionRequest],
                    request: ChatCompletionRequest) -> None:
        """Print the outgoing prompt when config.print_prompts is enabled."""
        if not self.config.print_prompts:
            return
        print(f"\n--- PROMPT (Item {item.id}) ---")
        print(f"MODEL: {request.model}")
        for msg in request.messages:
            print(f"{msg.role.value.upper()}: {msg.content}")
        print("--- END PROMPT ---\n")

    def _log_response(self, item: BatchItem[ChatCompletionRequest],
                      response: ChatCompletionResponse) -> None:
        """Print the model response when config.print_responses is enabled."""
        if not self.config.print_responses:
            return
        print(f"\n--- RESPONSE (Item {item.id}) ---")
        print(f"MODEL: {response.model}")
        if response.choices:
            for i, choice in enumerate(response.choices):
                print(f"CHOICE {i}: {choice.message.content}")
        print("--- END RESPONSE ---\n")

    @staticmethod
    def _check_response(response: ChatCompletionResponse) -> None:
        """Surface a provider-reported failure as an exception.

        The base class's retry wrapper only retries on raised exceptions,
        so unsuccessful responses must raise here.
        """
        if not response.success:
            raise Exception(f"Chat completion failed: {response.error_message}")

    async def process_item(self, item: BatchItem[ChatCompletionRequest]) -> ChatCompletionResponse:
        """Process a single chat completion request using high-level interface."""
        request = item.data

        self._check_json_mode()
        self._log_prompt(item, request)

        # Use the high-level chat completion method
        response = await self.limiter.chat_completion(**self._build_kwargs(request))

        self._check_response(response)
        self._log_response(item, response)
        return response

    def process_item_sync(self, item: BatchItem[ChatCompletionRequest]) -> ChatCompletionResponse:
        """Process a single chat completion request synchronously using high-level interface."""
        request = item.data

        self._check_json_mode()
        self._log_prompt(item, request)

        # Use the high-level chat completion method (sync)
        response = self.limiter.chat_completion_sync(**self._build_kwargs(request))

        self._check_response(response)
        self._log_response(item, response)
        return response

669 

670 

# Convenience functions for high-level chat completion batches
async def process_chat_completion_batch(
    limiter: ChatLimiter,
    requests: list[ChatCompletionRequest],
    config: BatchConfig | None = None,
) -> list[BatchResult[ChatCompletionResponse]]:
    """
    Process a batch of high-level chat completion requests.

    Args:
        limiter: Configured ChatLimiter instance
        requests: List of ChatCompletionRequest objects
        config: Optional batch processing configuration

    Returns:
        List of batch results containing ChatCompletionResponse objects

    Example:
        from chat_limiter import ChatLimiter, Message, MessageRole, ChatCompletionRequest

        requests = [
            ChatCompletionRequest(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="Hello!")],
                max_tokens=50
            ),
            ChatCompletionRequest(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="How are you?")],
                max_tokens=50
            )
        ]

        async with ChatLimiter.for_model("gpt-4o", api_key) as limiter:
            results = await process_chat_completion_batch(limiter, requests)
    """
    return await ChatCompletionBatchProcessor(limiter, config).process_batch(requests)

709 

710 

def process_chat_completion_batch_sync(
    limiter: ChatLimiter,
    requests: list[ChatCompletionRequest],
    config: BatchConfig | None = None,
) -> list[BatchResult[ChatCompletionResponse]]:
    """
    Process a batch of high-level chat completion requests synchronously.

    Args:
        limiter: Configured ChatLimiter instance
        requests: List of ChatCompletionRequest objects
        config: Optional batch processing configuration

    Returns:
        List of batch results containing ChatCompletionResponse objects
    """
    return ChatCompletionBatchProcessor(limiter, config).process_batch_sync(requests)

729 

730 

# Helper function for creating chat completion requests from simple data
def create_chat_completion_requests(
    model: str,
    prompts: list[str],
    max_tokens: int | None = None,
    temperature: float | None = None,
    **kwargs: Any,
) -> list[ChatCompletionRequest]:
    """
    Create a list of ChatCompletionRequest objects from simple prompts.

    Each prompt becomes a single-message user request sharing the same
    model and sampling settings.

    Args:
        model: The model to use for all requests
        prompts: List of user prompts
        max_tokens: Maximum tokens per completion
        temperature: Sampling temperature
        **kwargs: Additional parameters for all requests

    Returns:
        List of ChatCompletionRequest objects

    Example:
        requests = create_chat_completion_requests(
            model="gpt-4o",
            prompts=["Hello!", "How are you?", "What is Python?"],
            max_tokens=50,
            temperature=0.7
        )
    """
    # Imported here (as elsewhere in this module) to avoid import cycles.
    from .types import Message, MessageRole

    return [
        ChatCompletionRequest(
            model=model,
            messages=[Message(role=MessageRole.USER, content=prompt)],
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )
        for prompt in prompts
    ]