Coverage for src/chat_limiter/batch.py: 77%
298 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-18 21:15 +0100
1"""
2Batch processing functionality for handling multiple requests efficiently.
3"""
5import asyncio
6import logging
7import traceback
8from abc import ABC, abstractmethod
9from collections.abc import Callable
10from concurrent.futures import ThreadPoolExecutor, as_completed
11from dataclasses import dataclass, field
12from typing import (
13 TYPE_CHECKING,
14 Any,
15 Generic,
16 TypeVar,
17)
19import httpx
20from tqdm import tqdm
22if TYPE_CHECKING:
23 pass
25from .limiter import ChatLimiter
26from .types import ChatCompletionRequest, ChatCompletionResponse
# Module-level logger; configure via the application's logging setup.
logger = logging.getLogger(__name__)

# Type variables for generic batch processing:
# BatchItemT is the raw input payload type, BatchResultT the per-item result type.
BatchItemT = TypeVar("BatchItemT")
BatchResultT = TypeVar("BatchResultT")
35@dataclass
36class BatchConfig:
37 """Configuration for batch processing."""
39 # Concurrency settings
40 max_concurrent_requests: int = 10
41 max_workers: int = 4 # For sync processing
43 # Retry settings
44 max_retries_per_item: int = 3
45 retry_delay: float = 1.0
47 # Progress tracking
48 show_progress: bool = True
49 progress_desc: str = "Processing batch"
51 # Error handling
52 stop_on_first_error: bool = False
53 collect_errors: bool = True
55 # Fine-grained logging configuration
56 print_prompts: bool = False
57 print_responses: bool = False
58 verbose_timeouts: bool = False
59 verbose_exceptions: bool = False
60 print_rate_limits: bool = False
61 print_request_initiation: bool = False
63 # Response format configuration
64 json_mode: bool = False
66 # Reasoning configuration (for thinking models like o1, o3, o4)
67 reasoning_effort: str | None = None # None, "low", "medium", or "high"
69 # Batch size optimization
70 adaptive_batch_size: bool = True
71 min_batch_size: int = 1
72 max_batch_size: int = 100
74 # Request grouping
75 group_by_model: bool = True
76 group_by_provider: bool = True
78 def __post_init__(self):
79 """Validate configuration after initialization."""
80 if self.reasoning_effort is not None:
81 valid_efforts = {"low", "medium", "high"}
82 if self.reasoning_effort not in valid_efforts:
83 raise ValueError(
84 f"reasoning_effort must be one of {valid_efforts} or None, "
85 f"got: {self.reasoning_effort}"
86 )
@dataclass
class BatchItem(Generic[BatchItemT]):
    """A single item in a batch request.

    Wraps one raw payload together with its HTTP request configuration,
    identifying metadata, and mutable per-item retry state.
    """

    # Item data (the raw payload this item carries)
    data: BatchItemT

    # Request configuration (used when items map directly to HTTP calls)
    method: str = "POST"
    url: str = "/chat/completions"
    json_data: dict[str, Any] | None = None

    # Metadata: id is assigned as "item_<index>" by create_batch_items when
    # not provided; metadata is free-form caller data.
    id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    # Processing state, mutated by the retry loops in BatchProcessor.
    attempt_count: int = 0
    last_error: Exception | None = None
@dataclass
class BatchResult(Generic[BatchResultT]):
    """Result of processing a batch item.

    Produced once per item by BatchProcessor; ``success`` distinguishes a
    usable ``result`` from a failure described by ``error_message``.
    """

    # Original item this result corresponds to
    item: "BatchItem[Any]"

    # Result data (None when processing failed)
    result: BatchResultT | None = None

    # Processing metadata: wall-clock duration in seconds and total attempts
    duration: float = 0.0
    attempt_count: int = 0

    # Error information
    success: bool = True
    error_message: str | None = None

    # Response metadata (populated when an HTTP response is available)
    response_headers: dict[str, str] = field(default_factory=dict)
    status_code: int | None = None
class BatchProcessor(ABC, Generic[BatchItemT, BatchResultT]):
    """Abstract base class for batch processing.

    Subclasses implement ``process_item`` (async) and ``process_item_sync``;
    this base class supplies item grouping, concurrency control, retries with
    exponential backoff, tqdm progress reporting, and result/statistics
    helpers over the most recent batch run.
    """

    def __init__(
        self,
        limiter: ChatLimiter,
        config: BatchConfig | None = None,
    ):
        """Initialize the processor.

        Args:
            limiter: Rate-limited client used to issue requests.
            config: Batch behavior configuration; defaults to ``BatchConfig()``.
        """
        self.limiter = limiter
        self.config = config or BatchConfig()
        # Results of the most recent batch run (backs the stats helpers below).
        self._results: list[BatchResult[BatchResultT]] = []
        self._errors: list[Exception] = []

        # Configure limiter logging based on batch config
        self.limiter.set_print_rate_limit_info(self.config.print_rate_limits)
        self.limiter.set_print_request_initiation(self.config.print_request_initiation)

    @abstractmethod
    async def process_item(self, item: BatchItem[BatchItemT]) -> BatchResultT:
        """Process a single batch item.

        Implementations should raise on failure; the retry wrapper converts
        exceptions into failed ``BatchResult`` entries.
        """
        pass

    @abstractmethod
    def process_item_sync(self, item: BatchItem[BatchItemT]) -> BatchResultT:
        """Process a single batch item synchronously."""
        pass

    def create_batch_items(
        self,
        items: list[BatchItemT],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchItem[BatchItemT]]:
        """Create batch items from raw data.

        Args:
            items: Raw payloads, one per request.
            request_fn: Optional mapper returning ``(method, url, json_data)``
                for each raw item.

        Returns:
            ``BatchItem`` wrappers with sequential ids ("item_0", "item_1", ...).
        """
        batch_items = []

        for i, item in enumerate(items):
            batch_item = BatchItem(
                data=item,
                id=f"item_{i}",
            )

            # Configure request if function provided
            if request_fn:
                method, url, json_data = request_fn(item)
                batch_item.method = method
                batch_item.url = url
                batch_item.json_data = json_data

            batch_items.append(batch_item)

        return batch_items

    async def process_batch(
        self,
        items: list[BatchItemT] | list[BatchItem[BatchItemT]],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchResult[BatchResultT]]:
        """Process a batch of items asynchronously.

        Items are optionally grouped (by model/provider), then each group is
        processed concurrently under a semaphore of
        ``config.max_concurrent_requests``. Stores and returns one
        ``BatchResult`` per item, in per-group task order.
        """
        # Convert to batch items if needed
        # NOTE: only the first element is inspected; assumes a homogeneous list.
        if items and not isinstance(items[0], BatchItem):
            batch_items = self.create_batch_items(items, request_fn)  # type: ignore
        else:
            batch_items = items  # type: ignore

        # Group items if configured
        if self.config.group_by_model or self.config.group_by_provider:
            grouped_items = self._group_items(batch_items)
        else:
            grouped_items = {"default": batch_items}

        # Process groups
        all_results = []

        # Calculate total items for progress tracking
        total_items = sum(len(group_items) for group_items in grouped_items.values())

        # Initialize progress bar if enabled
        progress_bar = None
        if self.config.show_progress:
            progress_bar = tqdm(
                total=total_items,
                desc=self.config.progress_desc,
                unit="item"
            )

        for group_name, group_items in grouped_items.items():
            logger.info(
                f"Processing group '{group_name}' with {len(group_items)} items"
            )

            # Create semaphore for concurrency control
            # (new semaphore per group, so the cap applies within each group)
            semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)

            # Process items with concurrency control and progress tracking
            tasks = [
                self._process_item_with_retry(item, semaphore, progress_bar) for item in group_items
            ]

            # Wait for all tasks to complete; return_exceptions=True keeps one
            # failed task from cancelling the rest of the group.
            group_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Handle exceptions from gather
            for i, result in enumerate(group_results):
                if isinstance(result, Exception):
                    # Create error result
                    error_result: BatchResult[BatchResultT] = BatchResult(
                        item=group_items[i],
                        success=False,
                        error_message=str(result),
                        attempt_count=group_items[i].attempt_count,
                    )
                    all_results.append(error_result)
                else:
                    all_results.append(result)  # type: ignore

        # Close progress bar if it was created
        if progress_bar:
            progress_bar.close()

        self._results = all_results
        return all_results

    def process_batch_sync(
        self,
        items: list[BatchItemT] | list[BatchItem[BatchItemT]],
        request_fn: Callable[[BatchItemT], tuple[str, str, dict[str, Any]]] | None = None,
    ) -> list[BatchResult[BatchResultT]]:
        """Process a batch of items synchronously.

        Mirrors ``process_batch`` but fans work out to a thread pool of
        ``config.max_workers``. Results are appended in completion order,
        which may differ from input order.
        """
        # Convert to batch items if needed
        if items and not isinstance(items[0], BatchItem):
            batch_items = self.create_batch_items(items, request_fn)  # type: ignore
        else:
            batch_items = items  # type: ignore

        # Group items if configured
        if self.config.group_by_model or self.config.group_by_provider:
            grouped_items = self._group_items(batch_items)
        else:
            grouped_items = {"default": batch_items}

        # Calculate total items for progress tracking
        total_items = sum(len(group_items) for group_items in grouped_items.values())

        # Initialize progress bar if enabled
        progress_bar = None
        if self.config.show_progress:
            progress_bar = tqdm(
                total=total_items,
                desc=self.config.progress_desc,
                unit="item"
            )

        # Process groups
        all_results = []
        for group_name, group_items in grouped_items.items():
            logger.info(
                f"Processing group '{group_name}' with {len(group_items)} items"
            )

            # Use ThreadPoolExecutor for concurrent processing
            with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
                # Submit all tasks
                future_to_item = {
                    executor.submit(self._process_item_sync_with_retry, item, progress_bar): item
                    for item in group_items
                }

                # Collect results as they finish (completion order, not input order)
                for future in as_completed(future_to_item):
                    item = future_to_item[future]
                    try:
                        result = future.result()
                        all_results.append(result)
                    except Exception as e:
                        # Defensive: the retry wrapper normally returns a failed
                        # BatchResult rather than raising.
                        error_result: BatchResult[BatchResultT] = BatchResult(
                            item=item,
                            success=False,
                            error_message=str(e),
                            attempt_count=item.attempt_count,
                        )
                        all_results.append(error_result)

        # Close progress bar if it was created
        if progress_bar:
            progress_bar.close()

        self._results = all_results
        return all_results

    def _group_items(
        self, items: list[BatchItem[BatchItemT]]
    ) -> dict[str, list[BatchItem[BatchItemT]]]:
        """Group items by model or provider.

        Model grouping (from ``json_data["model"]``) takes precedence over
        provider grouping; items matching neither fall into "default".
        """
        groups: dict[str, list[BatchItem[BatchItemT]]] = {}

        for item in items:
            # Determine group key
            group_key = "default"

            if (
                self.config.group_by_model
                and item.json_data
                and "model" in item.json_data
            ):
                group_key = item.json_data["model"]
            elif self.config.group_by_provider:
                group_key = self.limiter.provider.value

            # Add to group
            if group_key not in groups:
                groups[group_key] = []
            groups[group_key].append(item)

        return groups

    async def _process_item_with_retry(
        self,
        item: BatchItem[BatchItemT],
        semaphore: asyncio.Semaphore,
        progress_bar: tqdm | None = None,
    ) -> BatchResult[BatchResultT]:
        """Process a single item with retry logic.

        Holds the semaphore for the whole attempt sequence (including backoff
        sleeps). Timeout-like errors back off at 3**attempt, others at
        2**attempt, times ``config.retry_delay``.
        """
        async with semaphore:
            import time

            start_time = time.time()

            for attempt in range(self.config.max_retries_per_item + 1):
                item.attempt_count = attempt + 1

                try:
                    # Print request initiation if enabled
                    if self.config.print_request_initiation:
                        print(f"Sent request for batch item {item.id} (attempt {attempt + 1})")

                    # Process the item
                    result = await self.process_item(item)

                    # Update progress bar on success
                    if progress_bar:
                        progress_bar.update(1)

                    # Success
                    return BatchResult(
                        item=item,
                        result=result,
                        success=True,
                        duration=time.time() - start_time,
                        attempt_count=item.attempt_count,
                    )

                except Exception as e:
                    item.last_error = e

                    # Check if this is a timeout error; the string checks catch
                    # timeouts wrapped in other exception types.
                    is_timeout_error = (
                        isinstance(e, (httpx.ReadTimeout, httpx.ConnectTimeout)) or
                        (hasattr(e, '__cause__') and isinstance(e.__cause__, (httpx.ReadTimeout, httpx.ConnectTimeout))) or
                        'ReadTimeout' in str(type(e)) or 'timeout' in str(e).lower()
                    )

                    # Print user-friendly error messages based on config
                    if is_timeout_error and self.config.verbose_timeouts:
                        # Get current timeout from the limiter
                        current_timeout = getattr(self.limiter, '_user_timeout', 120.0)
                        print(f"⏱️ TIMEOUT ERROR in batch item {item.id} (attempt {attempt + 1}):")
                        print(f" Current timeout setting: {current_timeout} seconds")
                        print(f" The request took longer than {current_timeout}s to complete.")
                        print("")
                        print("💡 How to fix this:")
                        print(f" 1. Increase timeout: ChatLimiter.for_model('{getattr(self.limiter, 'provider', 'your-model')}', timeout={current_timeout + 60})")
                        print(f" 2. Reduce concurrency: BatchConfig(max_concurrent_requests={max(1, self.config.max_concurrent_requests // 2)})")
                        print(f" 3. Current concurrency: {self.config.max_concurrent_requests} requests")
                        print("")
                    elif not is_timeout_error and self.config.verbose_exceptions:
                        print(f"❌ Exception in batch item {item.id} (attempt {attempt + 1}):")

                    # NOTE(review): traceback also prints for timeout errors when
                    # verbose_exceptions is set, even without the header above.
                    if self.config.verbose_exceptions:
                        traceback.print_exc()

                    # If this is the last attempt or we should stop on error
                    if (
                        attempt == self.config.max_retries_per_item
                        or self.config.stop_on_first_error
                    ):
                        # Update progress bar on final failure
                        if progress_bar:
                            progress_bar.update(1)

                        return BatchResult(
                            item=item,
                            success=False,
                            error_message=str(e),
                            duration=time.time() - start_time,
                            attempt_count=item.attempt_count,
                        )

                    # Wait before retry - longer for timeout errors
                    if is_timeout_error:
                        # For timeout errors, wait longer and suggest more aggressive backing off
                        retry_delay = self.config.retry_delay * (3**attempt)  # More aggressive backoff
                    else:
                        retry_delay = self.config.retry_delay * (2**attempt)

                    await asyncio.sleep(retry_delay)

            # This should never be reached, but added for type checking
            return BatchResult(
                item=item,
                success=False,
                error_message="Unexpected error in retry logic",
                duration=time.time() - start_time,
                attempt_count=item.attempt_count,
            )

    def _process_item_sync_with_retry(
        self,
        item: BatchItem[BatchItemT],
        progress_bar: tqdm | None = None,
    ) -> BatchResult[BatchResultT]:
        """Process a single item with retry logic (sync).

        Thread-pool counterpart of ``_process_item_with_retry``; uses a plain
        2**attempt backoff with no timeout special-casing.
        """
        import time

        start_time = time.time()

        for attempt in range(self.config.max_retries_per_item + 1):
            item.attempt_count = attempt + 1

            try:
                # Print request initiation if enabled
                if self.config.print_request_initiation:
                    print(f"Sent request for batch item {item.id} (attempt {attempt + 1})")

                # Process the item
                result = self.process_item_sync(item)

                # Update progress bar on success
                if progress_bar:
                    progress_bar.update(1)

                # Success
                return BatchResult(
                    item=item,
                    result=result,
                    success=True,
                    duration=time.time() - start_time,
                    attempt_count=item.attempt_count,
                )

            except Exception as e:
                item.last_error = e

                # Print traceback if verbose exceptions is enabled
                if self.config.verbose_exceptions:
                    print(f"Exception in batch item {item.id} (attempt {attempt + 1}):")
                    traceback.print_exc()

                # If this is the last attempt or we should stop on error
                if (
                    attempt == self.config.max_retries_per_item
                    or self.config.stop_on_first_error
                ):
                    # Update progress bar on final failure
                    if progress_bar:
                        progress_bar.update(1)

                    return BatchResult(
                        item=item,
                        success=False,
                        error_message=str(e),
                        duration=time.time() - start_time,
                        attempt_count=item.attempt_count,
                    )

                # Wait before retry (blocks this worker thread)
                time.sleep(self.config.retry_delay * (2**attempt))

        # This should never be reached, but added for type checking
        return BatchResult(
            item=item,
            success=False,
            error_message="Unexpected error in retry logic",
            duration=time.time() - start_time,
            attempt_count=item.attempt_count,
        )

    def get_success_rate(self) -> float:
        """Get the success rate of the last batch (0.0 when no batch has run)."""
        if not self._results:
            return 0.0

        successful = sum(1 for r in self._results if r.success)
        return successful / len(self._results)

    def get_successful_results(self) -> list[BatchResult[BatchResultT]]:
        """Get only successful results from the last batch."""
        return [r for r in self._results if r.success]

    def get_failed_results(self) -> list[BatchResult[BatchResultT]]:
        """Get only failed results from the last batch."""
        return [r for r in self._results if not r.success]

    def get_stats(self) -> dict[str, Any]:
        """Get comprehensive processing statistics for the last batch.

        Returns:
            Dict with total/successful/failed counts, success_rate,
            avg_duration, total_duration, and avg_attempts.
        """
        if not self._results:
            return {"total": 0, "successful": 0, "failed": 0, "success_rate": 0.0}

        successful = self.get_successful_results()
        failed = self.get_failed_results()

        # Calculate timing statistics
        durations = [r.duration for r in self._results]
        avg_duration = sum(durations) / len(durations) if durations else 0

        return {
            "total": len(self._results),
            "successful": len(successful),
            "failed": len(failed),
            "success_rate": self.get_success_rate(),
            "avg_duration": avg_duration,
            "total_duration": sum(durations),
            "avg_attempts": sum(r.attempt_count for r in self._results)
            / len(self._results),
        }
558# High-level chat completion batch processing
class ChatCompletionBatchProcessor(BatchProcessor[ChatCompletionRequest, ChatCompletionResponse]):
    """High-level batch processor for chat completion requests.

    The async and sync entry points share request-building and logging logic
    via private helpers so the two code paths cannot drift apart.
    """

    def _validate_json_mode(self) -> None:
        """Raise if json_mode is enabled for a provider that lacks it.

        Uses an explicit ``ValueError`` instead of ``assert`` so the check
        is not stripped under ``python -O``.

        Raises:
            ValueError: If ``config.json_mode`` is set and the limiter's
                provider is not OpenAI.
        """
        if self.config.json_mode and self.limiter.provider.value != "openai":
            raise ValueError(
                f"json_mode is only supported with OpenAI provider, "
                f"but got '{self.limiter.provider.value}'"
            )

    def _build_kwargs(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Map a ChatCompletionRequest onto chat_completion keyword arguments."""
        kwargs: dict[str, Any] = {
            "model": request.model,
            "messages": request.messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "stop": request.stop,
            "stream": request.stream,
            # Provider-specific parameters
            "frequency_penalty": request.frequency_penalty,
            "presence_penalty": request.presence_penalty,
            "top_k": request.top_k,
            "reasoning_effort": self.config.reasoning_effort,
        }
        # Add json_mode if enabled
        if self.config.json_mode:
            kwargs["response_format"] = {"type": "json_object"}
        return kwargs

    def _log_prompt(
        self, item: BatchItem[ChatCompletionRequest], request: ChatCompletionRequest
    ) -> None:
        """Print the outgoing prompt when ``config.print_prompts`` is set."""
        if not self.config.print_prompts:
            return
        print(f"\n--- PROMPT (Item {item.id}) ---")
        print(f"MODEL: {request.model}")
        for msg in request.messages:
            print(f"{msg.role.value.upper()}: {msg.content}")
        print("--- END PROMPT ---\n")

    def _log_response(
        self, item: BatchItem[ChatCompletionRequest], response: ChatCompletionResponse
    ) -> None:
        """Print the model response when ``config.print_responses`` is set."""
        if not self.config.print_responses:
            return
        print(f"\n--- RESPONSE (Item {item.id}) ---")
        print(f"MODEL: {response.model}")
        if response.choices:
            for i, choice in enumerate(response.choices):
                print(f"CHOICE {i}: {choice.message.content}")
        print("--- END RESPONSE ---\n")

    async def process_item(self, item: BatchItem[ChatCompletionRequest]) -> ChatCompletionResponse:
        """Process a single chat completion request using high-level interface.

        Raises:
            ValueError: If json_mode is enabled with a non-OpenAI provider.
            Exception: If the completion response reports failure (so the
                base class retry logic is triggered).
        """
        request = item.data
        self._validate_json_mode()
        self._log_prompt(item, request)

        # Use the high-level chat completion method
        response = await self.limiter.chat_completion(**self._build_kwargs(request))

        # Surface API-level failures as exceptions so retries apply.
        if not response.success:
            raise Exception(f"Chat completion failed: {response.error_message}")

        self._log_response(item, response)
        return response

    def process_item_sync(self, item: BatchItem[ChatCompletionRequest]) -> ChatCompletionResponse:
        """Process a single chat completion request synchronously using high-level interface.

        Raises:
            ValueError: If json_mode is enabled with a non-OpenAI provider.
            Exception: If the completion response reports failure (so the
                base class retry logic is triggered).
        """
        request = item.data
        self._validate_json_mode()
        self._log_prompt(item, request)

        # Use the high-level chat completion method (sync)
        response = self.limiter.chat_completion_sync(**self._build_kwargs(request))

        # Surface API-level failures as exceptions so retries apply.
        if not response.success:
            raise Exception(f"Chat completion failed: {response.error_message}")

        self._log_response(item, response)
        return response
671# Convenience functions for high-level chat completion batches
async def process_chat_completion_batch(
    limiter: ChatLimiter,
    requests: list[ChatCompletionRequest],
    config: BatchConfig | None = None,
) -> list[BatchResult[ChatCompletionResponse]]:
    """
    Run a batch of high-level chat completion requests asynchronously.

    Args:
        limiter: Configured ChatLimiter instance
        requests: List of ChatCompletionRequest objects
        config: Optional batch processing configuration

    Returns:
        List of batch results containing ChatCompletionResponse objects

    Example:
        from chat_limiter import ChatLimiter, Message, MessageRole, ChatCompletionRequest

        requests = [
            ChatCompletionRequest(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="Hello!")],
                max_tokens=50
            ),
            ChatCompletionRequest(
                model="gpt-4o",
                messages=[Message(role=MessageRole.USER, content="How are you?")],
                max_tokens=50
            )
        ]

        async with ChatLimiter.for_model("gpt-4o", api_key) as limiter:
            results = await process_chat_completion_batch(limiter, requests)
    """
    # Delegate everything to the high-level processor.
    batch_processor = ChatCompletionBatchProcessor(limiter, config)
    return await batch_processor.process_batch(requests)
def process_chat_completion_batch_sync(
    limiter: ChatLimiter,
    requests: list[ChatCompletionRequest],
    config: BatchConfig | None = None,
) -> list[BatchResult[ChatCompletionResponse]]:
    """
    Run a batch of high-level chat completion requests synchronously.

    Args:
        limiter: Configured ChatLimiter instance
        requests: List of ChatCompletionRequest objects
        config: Optional batch processing configuration

    Returns:
        List of batch results containing ChatCompletionResponse objects
    """
    # Delegate everything to the high-level processor.
    return ChatCompletionBatchProcessor(limiter, config).process_batch_sync(requests)
731# Helper function for creating chat completion requests from simple data
def create_chat_completion_requests(
    model: str,
    prompts: list[str],
    max_tokens: int | None = None,
    temperature: float | None = None,
    **kwargs: Any,
) -> list[ChatCompletionRequest]:
    """
    Create a list of ChatCompletionRequest objects from simple prompts.

    Each prompt becomes a single-message user request sharing the same model
    and sampling parameters.

    Args:
        model: The model to use for all requests
        prompts: List of user prompts
        max_tokens: Maximum tokens per completion
        temperature: Sampling temperature
        **kwargs: Additional parameters for all requests

    Returns:
        List of ChatCompletionRequest objects, one per prompt (in order)

    Example:
        requests = create_chat_completion_requests(
            model="gpt-4o",
            prompts=["Hello!", "How are you?", "What is Python?"],
            max_tokens=50,
            temperature=0.7
        )
    """
    # Local import mirrors the module's existing style and avoids cycles.
    from .types import Message, MessageRole

    # Comprehension instead of a manual append loop (same order, same objects).
    return [
        ChatCompletionRequest(
            model=model,
            messages=[Message(role=MessageRole.USER, content=prompt)],
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )
        for prompt in prompts
    ]