Coverage for src/dataknobs_fsm/patterns/api_orchestration.py: 0%
281 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-20 16:46 -0600
1"""API orchestration pattern implementation.
3This module provides pre-configured FSM patterns for orchestrating API calls,
4including parallel requests, sequential workflows, rate limiting, and retries.
5"""
7from typing import Any, Dict, List, Union, Callable
8from dataclasses import dataclass
9from enum import Enum
10import asyncio
11from datetime import datetime, timedelta
13from ..api.simple import SimpleFSM
14from ..core.data_modes import DataHandlingMode
15from ..io.base import IOConfig, IOMode, IOFormat
16from ..io.utils import create_io_provider, retry_io_operation, IOMetrics
class OrchestrationMode(Enum):
    """API orchestration modes.

    Selects how ``APIOrchestrator`` executes its endpoints. SEQUENTIAL,
    PARALLEL, FANOUT and PIPELINE are dispatched by the visible code in
    this module; CONDITIONAL and HYBRID are declared but not handled here.
    """
    SEQUENTIAL = "sequential"  # Execute APIs one after another
    PARALLEL = "parallel"  # Execute APIs concurrently
    FANOUT = "fanout"  # One request triggers multiple APIs
    PIPELINE = "pipeline"  # Output of one API feeds into next
    CONDITIONAL = "conditional"  # Execute based on conditions (NOTE(review): no branch in this module handles this mode — confirm)
    HYBRID = "hybrid"  # Mix of above patterns (NOTE(review): no branch in this module handles this mode — confirm)
@dataclass
class APIEndpoint:
    """Configuration for a single API endpoint.

    Declarative description of one HTTP call: target URL, request shape,
    retry/rate-limit policy, and optional hooks for response parsing,
    error handling, and input transformation.
    """
    name: str  # Unique endpoint name; used as the key in orchestration results
    url: str  # Full request URL
    method: str = "GET"  # HTTP method; GET maps to provider reads, POST to writes
    headers: Dict[str, str] | None = None  # Extra request headers
    params: Dict[str, Any] | None = None  # Query parameters
    body: Union[Dict[str, Any], str] | None = None  # Request body (used when transform_input does not apply)
    timeout: float = 30.0  # Per-request timeout in seconds
    retry_count: int = 3  # Maximum retry attempts on failure
    retry_delay: float = 1.0  # Delay between retries in seconds

    # Rate limiting
    rate_limit: int | None = None  # Requests per minute for this endpoint only
    burst_limit: int | None = None  # Max burst size (NOTE(review): not consumed by the visible orchestrator code — confirm)

    # Response handling
    response_parser: Callable[[Any], Any] | None = None  # Post-processes the raw response before caching/returning
    error_handler: Callable[[Exception], Any] | None = None  # Called on failure; its return value becomes the result

    # Dependencies
    depends_on: List[str] | None = None  # Names of endpoints this one depends on
    transform_input: Callable[[Dict[str, Any]], Dict[str, Any]] | None = None  # Maps upstream output into this request's payload
@dataclass
class APIOrchestrationConfig:
    """Configuration for API orchestration.

    Bundles the endpoint list with global execution, rate-limiting,
    error-handling, caching, and monitoring policy.
    """
    endpoints: List[APIEndpoint]  # Endpoints to orchestrate, in declaration order
    mode: OrchestrationMode = OrchestrationMode.SEQUENTIAL  # Execution strategy

    # Global settings
    max_concurrent: int = 10  # Intended cap on concurrent requests
    total_timeout: float = 300.0  # Overall time budget in seconds (NOTE(review): not enforced by the visible code — confirm)
    fail_fast: bool = False  # Stop on first error (re-raise instead of returning None)

    # Rate limiting (global)
    global_rate_limit: int | None = None  # Requests per window across all endpoints
    rate_limit_window: int = 60  # seconds; shared by global and per-endpoint limiters

    # Result handling
    result_merger: Callable[[List[Dict[str, Any]]], Any] | None = None  # Combines all results; stored under the 'merged' key
    result_transformer: Callable[[Any], Any] | None = None  # Final transform applied to the whole results mapping

    # Error handling
    error_threshold: float = 0.1  # Max 10% errors (NOTE(review): not enforced by the visible code — confirm)
    circuit_breaker_threshold: int = 5  # Consecutive failures before an endpoint's breaker opens
    circuit_breaker_timeout: float = 60.0  # seconds an open breaker waits before attempting reset

    # Caching
    cache_ttl: int | None = None  # seconds; caching is only active when cache_key_generator is also set
    cache_key_generator: Callable[[APIEndpoint], str] | None = None  # Builds the cache key for an endpoint

    # Monitoring
    metrics_enabled: bool = True  # Collect IOMetrics during orchestration
    log_requests: bool = False  # NOTE(review): flag not consumed by the visible code — confirm
    log_responses: bool = False  # NOTE(review): flag not consumed by the visible code — confirm
class RateLimiter:
    """Sliding-window rate limiter for API calls.

    Tracks request timestamps inside a rolling window and blocks
    ``acquire`` callers until a slot frees up.
    """

    def __init__(self, rate_limit: int, window: int = 60):
        """Initialize rate limiter.

        Args:
            rate_limit: Maximum requests per window.
            window: Time window in seconds.
        """
        self.rate_limit = rate_limit
        self.window = window
        self.requests = []  # Timestamps of requests made within the window
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        """Block until a request slot is available, then claim it."""
        while True:
            async with self._lock:
                moment = datetime.now()
                horizon = moment - timedelta(seconds=self.window)

                # Drop timestamps that have aged out of the window.
                self.requests = [stamp for stamp in self.requests if stamp > horizon]

                # A free slot: record this request and stop waiting.
                if len(self.requests) < self.rate_limit:
                    self.requests.append(moment)
                    return

                # Window is full; compute time until the oldest entry expires.
                expiry = self.requests[0] + timedelta(seconds=self.window)
                pause = (expiry - moment).total_seconds()

            # Sleep outside the lock, in small increments, so other
            # waiters can make progress.
            if pause > 0:
                await asyncio.sleep(min(pause, 0.1))
class CircuitBreaker:
    """Circuit breaker guarding repeated calls to a flaky operation.

    Opens after a configured number of consecutive failures and rejects
    calls until a cool-down period has elapsed, after which it resets
    and allows a new attempt.
    """

    def __init__(self, threshold: int, timeout: float):
        """Initialize circuit breaker.

        Args:
            threshold: Number of consecutive failures that trips the breaker.
            timeout: Seconds to wait before attempting a reset.
        """
        self.threshold = threshold
        self.timeout = timeout
        self.failure_count = 0  # Consecutive failures observed so far
        self.last_failure = None  # Timestamp of the most recent failure
        self.is_open = False  # True while calls are being rejected
        self._lock = asyncio.Lock()

    async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Invoke *func* under circuit-breaker protection.

        Args:
            func: Sync or async callable to invoke.
            *args: Positional arguments forwarded to *func*.
            **kwargs: Keyword arguments forwarded to *func*.

        Returns:
            Whatever *func* returns.

        Raises:
            CircuitBreakerError: If the circuit is open and the cool-down
                has not yet elapsed.
            Exception: Whatever *func* raises.
        """
        async with self._lock:
            if self.is_open:
                if self.last_failure is not None:
                    since_failure = (datetime.now() - self.last_failure).total_seconds()
                    if since_failure < self.timeout:
                        from ..core.exceptions import CircuitBreakerError
                        raise CircuitBreakerError(wait_time=self.timeout - since_failure)
                # Cool-down elapsed: clear state and allow the attempt.
                self.is_open = False
                self.failure_count = 0

        try:
            if asyncio.iscoroutinefunction(func):
                outcome = await func(*args, **kwargs)
            else:
                outcome = func(*args, **kwargs)
        except Exception:
            # Record the failure; trip the breaker once the streak
            # reaches the threshold, then propagate the error.
            async with self._lock:
                self.failure_count += 1
                self.last_failure = datetime.now()
                if self.failure_count >= self.threshold:
                    self.is_open = True
            raise

        # Success ends any failure streak.
        async with self._lock:
            self.failure_count = 0
        return outcome
class APIOrchestrator:
    """API orchestrator using FSM pattern.

    Coordinates calls to the configured endpoints according to the
    selected ``OrchestrationMode``, applying global and per-endpoint
    rate limiting, per-endpoint circuit breaking, optional response
    caching, and metrics collection.
    """

    def __init__(self, config: APIOrchestrationConfig):
        """Initialize API orchestrator.

        Args:
            config: Orchestration configuration.
        """
        self.config = config
        self._fsm = self._build_fsm()
        self._providers = {}  # endpoint name -> I/O provider (created lazily)
        self._rate_limiters = {}  # endpoint name -> RateLimiter
        self._circuit_breakers = {}  # endpoint name -> CircuitBreaker
        self._cache = {}  # cache key -> (response, timestamp)
        self._metrics = IOMetrics() if config.metrics_enabled else None

        # Optional global rate limiter shared by all endpoints.
        if config.global_rate_limit:
            self._global_rate_limiter = RateLimiter(
                config.global_rate_limit,
                config.rate_limit_window
            )
        else:
            self._global_rate_limiter = None

        # Per-endpoint rate limiters, only where an endpoint requests one.
        for endpoint in config.endpoints:
            if endpoint.rate_limit:
                self._rate_limiters[endpoint.name] = RateLimiter(
                    endpoint.rate_limit,
                    config.rate_limit_window
                )

        # Every endpoint gets its own circuit breaker.
        for endpoint in config.endpoints:
            self._circuit_breakers[endpoint.name] = CircuitBreaker(
                config.circuit_breaker_threshold,
                config.circuit_breaker_timeout
            )

    def _build_fsm(self) -> SimpleFSM:
        """Build the FSM describing the orchestration flow.

        Returns:
            A ``SimpleFSM`` wired according to ``self.config.mode``.
            Only SEQUENTIAL, PARALLEL and PIPELINE produce arcs here;
            other modes yield a start/end-only graph.
        """
        # Start state first; mode-specific task states follow; the
        # terminal 'end' state is appended at the bottom.
        states = [{'name': 'start', 'type': 'initial', 'is_start': True}]
        arcs = []

        if self.config.mode == OrchestrationMode.SEQUENTIAL:
            # Chain: start -> call_a -> call_b -> ... -> end
            for i, endpoint in enumerate(self.config.endpoints):
                state_name = f"call_{endpoint.name}"
                states.append({
                    'name': state_name,
                    'type': 'task'
                })

                if i == 0:
                    arcs.append({
                        'from': 'start',
                        'to': state_name,
                        'name': f'init_{endpoint.name}'
                    })

                if i < len(self.config.endpoints) - 1:
                    next_state = f"call_{self.config.endpoints[i + 1].name}"
                    arcs.append({
                        'from': state_name,
                        'to': next_state,
                        'name': f'{endpoint.name}_to_{self.config.endpoints[i + 1].name}'
                    })
                else:
                    arcs.append({
                        'from': state_name,
                        'to': 'end',
                        'name': f'{endpoint.name}_complete'
                    })

        elif self.config.mode == OrchestrationMode.PARALLEL:
            # Fan out through a fork state and rejoin before 'end'.
            states.append({'name': 'fork', 'type': 'fork'})
            states.append({'name': 'join', 'type': 'join'})

            arcs.append({'from': 'start', 'to': 'fork', 'name': 'init_parallel'})

            for endpoint in self.config.endpoints:
                state_name = f"call_{endpoint.name}"
                states.append({
                    'name': state_name,
                    'type': 'task'
                })
                arcs.append({
                    'from': 'fork',
                    'to': state_name,
                    'name': f'fork_to_{endpoint.name}'
                })
                arcs.append({
                    'from': state_name,
                    'to': 'join',
                    'name': f'{endpoint.name}_to_join'
                })

            arcs.append({'from': 'join', 'to': 'end', 'name': 'parallel_complete'})

        elif self.config.mode == OrchestrationMode.PIPELINE:
            # Like SEQUENTIAL, but an optional transform state precedes
            # each call that declares transform_input.
            for i, endpoint in enumerate(self.config.endpoints):
                state_name = f"call_{endpoint.name}"
                transform_name = f"transform_{endpoint.name}"

                states.append({
                    'name': state_name,
                    'type': 'task'
                })

                if endpoint.transform_input:
                    states.append({
                        'name': transform_name,
                        'type': 'task'
                    })

                if i == 0:
                    if endpoint.transform_input:
                        arcs.append({
                            'from': 'start',
                            'to': transform_name,
                            'name': f'init_transform_{endpoint.name}'
                        })
                        arcs.append({
                            'from': transform_name,
                            'to': state_name,
                            'name': f'transform_to_{endpoint.name}'
                        })
                    else:
                        arcs.append({
                            'from': 'start',
                            'to': state_name,
                            'name': f'init_{endpoint.name}'
                        })

                if i < len(self.config.endpoints) - 1:
                    next_endpoint = self.config.endpoints[i + 1]
                    if next_endpoint.transform_input:
                        next_transform = f"transform_{next_endpoint.name}"
                        arcs.append({
                            'from': state_name,
                            'to': next_transform,
                            'name': f'{endpoint.name}_to_transform'
                        })
                    else:
                        next_state = f"call_{next_endpoint.name}"
                        arcs.append({
                            'from': state_name,
                            'to': next_state,
                            'name': f'{endpoint.name}_to_{next_endpoint.name}'
                        })
                else:
                    arcs.append({
                        'from': state_name,
                        'to': 'end',
                        'name': f'{endpoint.name}_complete'
                    })

        # Terminal state shared by all modes.
        states.append({
            'name': 'end',
            'type': 'terminal'
        })

        fsm_config = {
            'name': 'API_Orchestration',
            'data_mode': DataHandlingMode.REFERENCE.value,
            'states': states,
            'arcs': arcs,
            'resources': []  # HTTP providers created dynamically
        }

        return SimpleFSM(fsm_config)

    def _create_provider(self, endpoint: APIEndpoint):
        """Create an I/O provider for an endpoint.

        Args:
            endpoint: API endpoint configuration.

        Returns:
            An async I/O provider instance.
        """
        io_config = IOConfig(
            mode=IOMode.READ if endpoint.method == "GET" else IOMode.WRITE,
            format=IOFormat.API,
            source=endpoint.url,
            headers=endpoint.headers,
            timeout=endpoint.timeout,
            retry_count=endpoint.retry_count,
            retry_delay=endpoint.retry_delay
        )

        return create_io_provider(io_config, is_async=True)

    async def _call_endpoint(
        self,
        endpoint: APIEndpoint,
        input_data: Dict[str, Any] | None = None
    ) -> Any:
        """Call a single API endpoint with all policies applied.

        Args:
            endpoint: Endpoint configuration.
            input_data: Input data for the endpoint (fed to
                ``transform_input`` when present).

        Returns:
            The (optionally parsed) API response, the error handler's
            result on failure, or None when errors are tolerated.
        """
        # Apply rate limiting: global first, then per-endpoint.
        if self._global_rate_limiter:
            await self._global_rate_limiter.acquire()

        if endpoint.name in self._rate_limiters:
            await self._rate_limiters[endpoint.name].acquire()

        # Serve from cache when fresh.
        # NOTE(review): the cache key is derived from the endpoint only,
        # not from input_data — confirm that is intended for PIPELINE mode.
        if self.config.cache_ttl and self.config.cache_key_generator:
            cache_key = self.config.cache_key_generator(endpoint)
            if cache_key in self._cache:
                cached_data, cached_time = self._cache[cache_key]
                if (datetime.now() - cached_time).total_seconds() < self.config.cache_ttl:
                    return cached_data

        # Transform input if needed; otherwise fall back to the static body.
        if endpoint.transform_input and input_data:
            request_data = endpoint.transform_input(input_data)
        else:
            request_data = endpoint.body or {}

        # Lazily create and memoize the provider.
        if endpoint.name not in self._providers:
            self._providers[endpoint.name] = self._create_provider(endpoint)

        provider = self._providers[endpoint.name]

        # Route the call through this endpoint's circuit breaker.
        circuit_breaker = self._circuit_breakers[endpoint.name]

        async def make_request():
            if not provider.is_open:
                await provider.open()

            if endpoint.method == "GET":
                response = await provider.read(params=endpoint.params)
            elif endpoint.method == "POST":
                response = await provider.write(request_data, params=endpoint.params)
            else:
                # Other methods currently fall back to a read.
                response = await provider.read(params=endpoint.params)

            return response

        try:
            # NOTE(review): retry_count is applied both in the provider's
            # IOConfig and here in retry_io_operation — this may multiply
            # the effective retry attempts; confirm intended.
            response = await retry_io_operation(
                lambda: circuit_breaker.call(make_request),
                max_retries=endpoint.retry_count,
                delay=endpoint.retry_delay
            )

            # Parse response if a parser was provided.
            if endpoint.response_parser:
                response = endpoint.response_parser(response)

            # Cache the (parsed) response.
            if self.config.cache_ttl and self.config.cache_key_generator:
                cache_key = self.config.cache_key_generator(endpoint)
                self._cache[cache_key] = (response, datetime.now())

            # Record metrics (counted as a read regardless of method,
            # matching the original behavior).
            if self._metrics:
                self._metrics.record_read()

            return response

        except Exception as e:
            # Delegate to the endpoint's error handler when present.
            if endpoint.error_handler:
                return endpoint.error_handler(e)

            if self._metrics:
                self._metrics.record_error()

            if self.config.fail_fast:
                raise

            return None

    async def orchestrate(
        self,
        input_data: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Execute the configured orchestration.

        Args:
            input_data: Initial input data. In SEQUENTIAL/PIPELINE modes
                it seeds the first call and each response feeds the next;
                in PARALLEL/FANOUT modes every endpoint receives it.

        Returns:
            Mapping of endpoint name to response, plus an optional
            'merged' entry (when result_merger is set) and a '_metrics'
            entry (when metrics are enabled).
        """
        results = {}

        if self.config.mode in (
            OrchestrationMode.SEQUENTIAL,
            OrchestrationMode.PIPELINE,
        ):
            # Both modes thread each response into the next call.
            current_data = input_data
            for endpoint in self.config.endpoints:
                result = await self._call_endpoint(endpoint, current_data)
                results[endpoint.name] = result
                current_data = result  # Pass result to next

        elif self.config.mode in (
            OrchestrationMode.PARALLEL,
            OrchestrationMode.FANOUT,
        ):
            # BUG FIX: the previous implementation awaited each coroutine
            # one at a time in a loop, which executed the calls
            # sequentially (a coroutine only runs when awaited). Run them
            # concurrently with asyncio.gather, bounded by the
            # max_concurrent setting (previously unused).
            semaphore = asyncio.Semaphore(self.config.max_concurrent)

            async def bounded_call(endpoint: APIEndpoint) -> Any:
                async with semaphore:
                    return await self._call_endpoint(endpoint, input_data)

            responses = await asyncio.gather(
                *(bounded_call(endpoint) for endpoint in self.config.endpoints)
            )
            for endpoint, response in zip(self.config.endpoints, responses):
                results[endpoint.name] = response

        # Merge results if a merger was provided.
        if self.config.result_merger:
            merged = self.config.result_merger(list(results.values()))
            results['merged'] = merged

        # Transform the final result if a transformer was provided.
        if self.config.result_transformer:
            results = self.config.result_transformer(results)

        # Attach a metrics snapshot when enabled.
        if self._metrics:
            results['_metrics'] = self._metrics.get_metrics()

        return results

    async def close(self) -> None:
        """Close all providers that were opened during orchestration."""
        for provider in self._providers.values():
            if provider.is_open:
                await provider.close()
def create_rest_api_orchestrator(
    base_url: str,
    endpoints: List[Dict[str, Any]],
    auth_token: str | None = None,
    rate_limit: int = 60,
    mode: OrchestrationMode = OrchestrationMode.SEQUENTIAL
) -> APIOrchestrator:
    """Create REST API orchestrator.

    Args:
        base_url: Base URL prepended to each endpoint's 'path'.
        endpoints: Endpoint specs; each dict supplies 'name' and 'path',
            optionally 'method', 'headers', 'params', 'body',
            'depends_on' and 'transform_input'.
        auth_token: Optional bearer token added to every request.
        rate_limit: Global requests-per-minute limit.
        mode: Orchestration mode.

    Returns:
        Configured API orchestrator with metrics enabled.
    """
    common_headers = {}
    if auth_token:
        common_headers['Authorization'] = f'Bearer {auth_token}'

    # Per-endpoint headers override the shared auth headers on conflict.
    api_endpoints = [
        APIEndpoint(
            name=spec['name'],
            url=f"{base_url}{spec['path']}",
            method=spec.get('method', 'GET'),
            headers={**common_headers, **spec.get('headers', {})},
            params=spec.get('params'),
            body=spec.get('body'),
            depends_on=spec.get('depends_on'),
            transform_input=spec.get('transform_input')
        )
        for spec in endpoints
    ]

    orchestration_config = APIOrchestrationConfig(
        endpoints=api_endpoints,
        mode=mode,
        global_rate_limit=rate_limit,
        metrics_enabled=True
    )

    return APIOrchestrator(orchestration_config)
def create_graphql_orchestrator(
    endpoint: str,
    queries: List[Dict[str, Any]],
    auth_token: str | None = None,
    batch_queries: bool = True
) -> APIOrchestrator:
    """Create GraphQL API orchestrator.

    Args:
        endpoint: GraphQL endpoint URL.
        queries: Query specs; each dict supplies 'query' and optionally
            'name' and 'variables'.
        auth_token: Optional bearer token.
        batch_queries: When True, combine all queries into a single POST;
            otherwise issue one POST per query, in parallel.

    Returns:
        Configured API orchestrator.
    """
    request_headers = {'Content-Type': 'application/json'}
    if auth_token:
        request_headers['Authorization'] = f'Bearer {auth_token}'

    if batch_queries:
        # Combine every query into one request body. Later variable
        # definitions win on key collisions.
        # NOTE(review): the join uses a literal backslash-n ('\\n'), not
        # a real newline — confirm this is intended.
        merged_variables = {}
        for spec in queries:
            if 'variables' in spec:
                merged_variables.update(spec['variables'])

        batch_endpoint = APIEndpoint(
            name='graphql_batch',
            url=endpoint,
            method='POST',
            headers=request_headers,
            body={
                'query': '\\n'.join(spec['query'] for spec in queries),
                'variables': merged_variables
            }
        )

        config = APIOrchestrationConfig(
            endpoints=[batch_endpoint],
            mode=OrchestrationMode.SEQUENTIAL
        )
    else:
        # One endpoint per query; unnamed queries get a positional name.
        api_endpoints = [
            APIEndpoint(
                name=spec.get('name', f"query_{index}"),
                url=endpoint,
                method='POST',
                headers=request_headers,
                body={
                    'query': spec['query'],
                    'variables': spec.get('variables', {})
                }
            )
            for index, spec in enumerate(queries)
        ]

        config = APIOrchestrationConfig(
            endpoints=api_endpoints,
            mode=OrchestrationMode.PARALLEL
        )

    return APIOrchestrator(config)