Coverage for src/dataknobs_fsm/patterns/api_orchestration.py: 0%

281 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-20 16:46 -0600

1"""API orchestration pattern implementation. 

2 

3This module provides pre-configured FSM patterns for orchestrating API calls, 

4including parallel requests, sequential workflows, rate limiting, and retries. 

5""" 

6 

7from typing import Any, Dict, List, Union, Callable 

8from dataclasses import dataclass 

9from enum import Enum 

10import asyncio 

11from datetime import datetime, timedelta 

12 

13from ..api.simple import SimpleFSM 

14from ..core.data_modes import DataHandlingMode 

15from ..io.base import IOConfig, IOMode, IOFormat 

16from ..io.utils import create_io_provider, retry_io_operation, IOMetrics 

17 

18 

class OrchestrationMode(Enum):
    """API orchestration modes."""

    SEQUENTIAL = "sequential"    # Execute APIs one after another
    PARALLEL = "parallel"        # Execute APIs concurrently
    FANOUT = "fanout"            # One request triggers multiple APIs
    PIPELINE = "pipeline"        # Output of one API feeds into next
    CONDITIONAL = "conditional"  # Execute based on conditions
    HYBRID = "hybrid"            # Mix of above patterns

@dataclass
class APIEndpoint:
    """Configuration for a single API endpoint."""

    name: str                                        # Unique identifier for this endpoint
    url: str                                         # Full request URL
    method: str = "GET"                              # HTTP method ("GET" and "POST" handled explicitly)
    headers: Dict[str, str] | None = None            # Extra request headers
    params: Dict[str, Any] | None = None             # Query parameters
    body: Union[Dict[str, Any], str] | None = None   # Request body (used when no transform_input applies)
    timeout: float = 30.0                            # Per-request timeout, seconds
    retry_count: int = 3                             # Retry attempts on failure
    retry_delay: float = 1.0                         # Delay between retries, seconds

    # Rate limiting
    rate_limit: int | None = None  # Requests per minute
    burst_limit: int | None = None  # Max burst size

    # Response handling
    response_parser: Callable[[Any], Any] | None = None       # Post-processes the raw response
    error_handler: Callable[[Exception], Any] | None = None   # Maps an exception to a fallback result

    # Dependencies
    depends_on: List[str] | None = None  # Names of endpoints this one depends on
    transform_input: Callable[[Dict[str, Any]], Dict[str, Any]] | None = None  # Maps upstream output to this request's body

@dataclass
class APIOrchestrationConfig:
    """Configuration for API orchestration."""

    endpoints: List[APIEndpoint]                       # Endpoints to orchestrate, in declaration order
    mode: OrchestrationMode = OrchestrationMode.SEQUENTIAL

    # Global settings
    max_concurrent: int = 10       # Max concurrent endpoint calls
    total_timeout: float = 300.0   # Overall orchestration budget, seconds
    fail_fast: bool = False        # Stop on first error

    # Rate limiting (global)
    global_rate_limit: int | None = None  # Requests per window across all endpoints
    rate_limit_window: int = 60  # seconds

    # Result handling
    result_merger: Callable[[List[Dict[str, Any]]], Any] | None = None  # Combines all endpoint results
    result_transformer: Callable[[Any], Any] | None = None              # Final transform over the results dict

    # Error handling
    error_threshold: float = 0.1  # Max 10% errors
    circuit_breaker_threshold: int = 5  # Consecutive failures
    circuit_breaker_timeout: float = 60.0  # seconds

    # Caching
    cache_ttl: int | None = None  # seconds; caching is active only with a key generator too
    cache_key_generator: Callable[[APIEndpoint], str] | None = None

    # Monitoring
    metrics_enabled: bool = True
    log_requests: bool = False
    log_responses: bool = False

class RateLimiter:
    """Sliding-window rate limiter for API calls.

    Keeps timestamps of recent requests and makes callers wait once the
    per-window quota is exhausted.
    """

    def __init__(self, rate_limit: int, window: int = 60):
        """Initialize rate limiter.

        Args:
            rate_limit: Maximum requests per window
            window: Time window in seconds
        """
        self.rate_limit = rate_limit
        self.window = window
        self.requests = []  # Timestamps of requests still inside the window
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        """Wait until a request slot is free within the window, then claim it."""
        while True:
            async with self._lock:
                current = datetime.now()
                window_start = current - timedelta(seconds=self.window)

                # Drop timestamps that have aged out of the window.
                self.requests = [stamp for stamp in self.requests if stamp > window_start]

                if len(self.requests) < self.rate_limit:
                    # Capacity available: record this request and proceed.
                    self.requests.append(current)
                    return

                # Window is full; compute when the earliest entry expires.
                earliest = self.requests[0]
                delay = (earliest + timedelta(seconds=self.window) - current).total_seconds()

            # Sleep outside the lock, in small increments, then re-check.
            if delay > 0:
                await asyncio.sleep(min(delay, 0.1))

128 

class CircuitBreaker:
    """Circuit breaker for API calls.

    Opens after a configurable run of consecutive failures and rejects
    calls until a cool-down period elapses.
    """

    def __init__(self, threshold: int, timeout: float):
        """Initialize circuit breaker.

        Args:
            threshold: Number of consecutive failures to trip
            timeout: Time to wait before attempting reset
        """
        self.threshold = threshold
        self.timeout = timeout
        self.failure_count = 0
        self.last_failure = None
        self.is_open = False
        self._lock = asyncio.Lock()

    async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Invoke ``func`` under circuit-breaker protection.

        Args:
            func: Sync or async callable to invoke
            *args: Positional arguments for ``func``
            **kwargs: Keyword arguments for ``func``

        Returns:
            Whatever ``func`` returns

        Raises:
            Exception: If the circuit is open or ``func`` fails
        """
        async with self._lock:
            if self.is_open:
                if self.last_failure:
                    since_failure = (datetime.now() - self.last_failure).total_seconds()  # type: ignore[unreachable]
                    if since_failure < self.timeout:
                        # Still cooling down — refuse the call.
                        from ..core.exceptions import CircuitBreakerError
                        raise CircuitBreakerError(wait_time=self.timeout - since_failure)
                # Cool-down elapsed (or no recorded failure): half-open reset.
                self.is_open = False
                self.failure_count = 0

        try:
            if asyncio.iscoroutinefunction(func):
                outcome = await func(*args, **kwargs)
            else:
                outcome = func(*args, **kwargs)
        except Exception:
            # Record the failure and trip the breaker at the threshold.
            async with self._lock:
                self.failure_count += 1
                self.last_failure = datetime.now()
                if self.failure_count >= self.threshold:
                    self.is_open = True
            raise
        else:
            # Success clears the consecutive-failure streak.
            async with self._lock:
                self.failure_count = 0
            return outcome

194 

195 

class APIOrchestrator:
    """API orchestrator using FSM pattern.

    Drives the configured endpoints according to the selected
    ``OrchestrationMode``, applying global and per-endpoint rate limiting,
    per-endpoint circuit breaking, optional response caching, and optional
    result merging/transformation.
    """

    def __init__(self, config: APIOrchestrationConfig):
        """Initialize API orchestrator.

        Args:
            config: Orchestration configuration
        """
        self.config = config
        self._fsm = self._build_fsm()
        self._providers = {}         # endpoint name -> I/O provider
        self._rate_limiters = {}     # endpoint name -> RateLimiter
        self._circuit_breakers = {}  # endpoint name -> CircuitBreaker
        self._cache = {}             # cache key -> (response, cached_at)
        self._metrics = IOMetrics() if config.metrics_enabled else None

        # Optional rate limiter shared by every endpoint.
        if config.global_rate_limit:
            self._global_rate_limiter = RateLimiter(
                config.global_rate_limit,
                config.rate_limit_window
            )
        else:
            self._global_rate_limiter = None

        # Per-endpoint limiters, only where the endpoint declares one.
        for endpoint in config.endpoints:
            if endpoint.rate_limit:
                self._rate_limiters[endpoint.name] = RateLimiter(
                    endpoint.rate_limit,
                    config.rate_limit_window
                )

        # Every endpoint gets its own circuit breaker.
        for endpoint in config.endpoints:
            self._circuit_breakers[endpoint.name] = CircuitBreaker(
                config.circuit_breaker_threshold,
                config.circuit_breaker_timeout
            )

    def _build_fsm(self) -> SimpleFSM:
        """Build FSM for API orchestration.

        Returns:
            A SimpleFSM whose states/arcs mirror the configured mode.

        NOTE(review): only SEQUENTIAL, PARALLEL and PIPELINE get call
        states here; FANOUT/CONDITIONAL/HYBRID produce just start/end —
        confirm whether those modes need dedicated FSM topologies.
        """
        # Every topology starts from a single initial state.
        states = [{'name': 'start', 'type': 'initial', 'is_start': True}]
        arcs = []

        if self.config.mode == OrchestrationMode.SEQUENTIAL:
            # One task state per endpoint, chained start -> ... -> end.
            for i, endpoint in enumerate(self.config.endpoints):
                state_name = f"call_{endpoint.name}"
                states.append({'name': state_name, 'type': 'task'})

                if i == 0:
                    arcs.append({
                        'from': 'start',
                        'to': state_name,
                        'name': f'init_{endpoint.name}'
                    })

                if i < len(self.config.endpoints) - 1:
                    next_state = f"call_{self.config.endpoints[i + 1].name}"
                    arcs.append({
                        'from': state_name,
                        'to': next_state,
                        'name': f'{endpoint.name}_to_{self.config.endpoints[i + 1].name}'
                    })
                else:
                    arcs.append({
                        'from': state_name,
                        'to': 'end',
                        'name': f'{endpoint.name}_complete'
                    })

        elif self.config.mode == OrchestrationMode.PARALLEL:
            # Fork/join topology: start -> fork -> each call -> join -> end.
            states.append({'name': 'fork', 'type': 'fork'})
            states.append({'name': 'join', 'type': 'join'})

            arcs.append({'from': 'start', 'to': 'fork', 'name': 'init_parallel'})

            for endpoint in self.config.endpoints:
                state_name = f"call_{endpoint.name}"
                states.append({'name': state_name, 'type': 'task'})
                arcs.append({
                    'from': 'fork',
                    'to': state_name,
                    'name': f'fork_to_{endpoint.name}'
                })
                arcs.append({
                    'from': state_name,
                    'to': 'join',
                    'name': f'{endpoint.name}_to_join'
                })

            arcs.append({'from': 'join', 'to': 'end', 'name': 'parallel_complete'})

        elif self.config.mode == OrchestrationMode.PIPELINE:
            # Like SEQUENTIAL, but endpoints with transform_input get an
            # extra transform state in front of their call state.
            for i, endpoint in enumerate(self.config.endpoints):
                state_name = f"call_{endpoint.name}"
                transform_name = f"transform_{endpoint.name}"

                states.append({'name': state_name, 'type': 'task'})

                if endpoint.transform_input:
                    states.append({'name': transform_name, 'type': 'task'})

                if i == 0:
                    if endpoint.transform_input:
                        arcs.append({
                            'from': 'start',
                            'to': transform_name,
                            'name': f'init_transform_{endpoint.name}'
                        })
                        arcs.append({
                            'from': transform_name,
                            'to': state_name,
                            'name': f'transform_to_{endpoint.name}'
                        })
                    else:
                        arcs.append({
                            'from': 'start',
                            'to': state_name,
                            'name': f'init_{endpoint.name}'
                        })

                if i < len(self.config.endpoints) - 1:
                    next_endpoint = self.config.endpoints[i + 1]
                    if next_endpoint.transform_input:
                        next_transform = f"transform_{next_endpoint.name}"
                        arcs.append({
                            'from': state_name,
                            'to': next_transform,
                            'name': f'{endpoint.name}_to_transform'
                        })
                    else:
                        next_state = f"call_{next_endpoint.name}"
                        arcs.append({
                            'from': state_name,
                            'to': next_state,
                            'name': f'{endpoint.name}_to_{next_endpoint.name}'
                        })
                else:
                    arcs.append({
                        'from': state_name,
                        'to': 'end',
                        'name': f'{endpoint.name}_complete'
                    })

        # Terminal state shared by all topologies.
        states.append({'name': 'end', 'type': 'terminal'})

        fsm_config = {
            'name': 'API_Orchestration',
            'data_mode': DataHandlingMode.REFERENCE.value,
            'states': states,
            'arcs': arcs,
            'resources': []  # HTTP providers created dynamically
        }

        return SimpleFSM(fsm_config)

    def _create_provider(self, endpoint: APIEndpoint):
        """Create I/O provider for endpoint.

        Args:
            endpoint: API endpoint configuration

        Returns:
            An async I/O provider instance bound to the endpoint's URL.
        """
        io_config = IOConfig(
            mode=IOMode.READ if endpoint.method == "GET" else IOMode.WRITE,
            format=IOFormat.API,
            source=endpoint.url,
            headers=endpoint.headers,
            timeout=endpoint.timeout,
            retry_count=endpoint.retry_count,
            retry_delay=endpoint.retry_delay
        )

        return create_io_provider(io_config, is_async=True)

    async def _call_endpoint(
        self,
        endpoint: APIEndpoint,
        input_data: Dict[str, Any] | None = None
    ) -> Any:
        """Call a single API endpoint.

        Applies rate limits, cache lookup, circuit breaking and retries.

        Args:
            endpoint: Endpoint configuration
            input_data: Input data for the endpoint (fed to transform_input)

        Returns:
            API response, or the error handler's result, or None when the
            call failed and fail_fast is disabled.
        """
        # Apply rate limiting (global first, then per-endpoint).
        if self._global_rate_limiter:
            await self._global_rate_limiter.acquire()

        if endpoint.name in self._rate_limiters:
            await self._rate_limiters[endpoint.name].acquire()

        # Serve from cache when caching is fully configured and fresh.
        if self.config.cache_ttl and self.config.cache_key_generator:
            cache_key = self.config.cache_key_generator(endpoint)
            if cache_key in self._cache:
                cached_data, cached_time = self._cache[cache_key]
                if (datetime.now() - cached_time).total_seconds() < self.config.cache_ttl:
                    return cached_data

        # Transform input if needed; otherwise fall back to the static body.
        if endpoint.transform_input and input_data:
            request_data = endpoint.transform_input(input_data)
        else:
            request_data = endpoint.body or {}

        # Create provider lazily, once per endpoint.
        if endpoint.name not in self._providers:
            self._providers[endpoint.name] = self._create_provider(endpoint)

        provider = self._providers[endpoint.name]

        # Make API call with circuit breaker protection.
        circuit_breaker = self._circuit_breakers[endpoint.name]

        async def make_request():
            if not provider.is_open:
                await provider.open()

            if endpoint.method == "GET":
                response = await provider.read(params=endpoint.params)
            elif endpoint.method == "POST":
                response = await provider.write(request_data, params=endpoint.params)
            else:
                # NOTE(review): other verbs (PUT/DELETE/...) fall back to a
                # read — confirm the provider supports them or map explicitly.
                response = await provider.read(params=endpoint.params)

            return response

        try:
            # Execute with retry around the circuit-breaker-protected call.
            response = await retry_io_operation(
                lambda: circuit_breaker.call(make_request),
                max_retries=endpoint.retry_count,
                delay=endpoint.retry_delay
            )

            # Parse response if a parser is provided.
            if endpoint.response_parser:
                response = endpoint.response_parser(response)

            # Cache response when caching is fully configured.
            if self.config.cache_ttl and self.config.cache_key_generator:
                cache_key = self.config.cache_key_generator(endpoint)
                self._cache[cache_key] = (response, datetime.now())

            # Record metrics.
            if self._metrics:
                self._metrics.record_read()

            return response

        except Exception as e:
            # Delegate to the endpoint's own error handler if present.
            if endpoint.error_handler:
                return endpoint.error_handler(e)

            if self._metrics:
                self._metrics.record_error()

            if self.config.fail_fast:
                raise

            return None

    async def _run_concurrent(
        self,
        input_data: Dict[str, Any] | None
    ) -> Dict[str, Any]:
        """Call every endpoint concurrently, bounded by ``max_concurrent``.

        Args:
            input_data: Input passed unchanged to each endpoint call.

        Returns:
            Mapping of endpoint name to its response.
        """
        semaphore = asyncio.Semaphore(self.config.max_concurrent)

        async def bounded_call(endpoint: APIEndpoint) -> Any:
            async with semaphore:
                return await self._call_endpoint(endpoint, input_data)

        # gather() schedules all calls at once; with fail_fast the first
        # raised exception propagates and cancels the remaining calls.
        responses = await asyncio.gather(
            *(bounded_call(ep) for ep in self.config.endpoints)
        )
        return {ep.name: resp for ep, resp in zip(self.config.endpoints, responses)}

    async def orchestrate(
        self,
        input_data: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Execute API orchestration.

        Args:
            input_data: Initial input data

        Returns:
            Orchestration results keyed by endpoint name, plus optional
            'merged' and '_metrics' entries.
        """
        results = {}

        if self.config.mode == OrchestrationMode.SEQUENTIAL:
            # Execute sequentially, threading each result into the next call.
            current_data = input_data
            for endpoint in self.config.endpoints:
                result = await self._call_endpoint(endpoint, current_data)
                results[endpoint.name] = result
                current_data = result  # Pass result to next

        elif self.config.mode == OrchestrationMode.PARALLEL:
            # BUG FIX: previously the coroutines were awaited one at a time
            # in a loop, which ran them sequentially (a coroutine does not
            # execute until awaited). gather() runs them concurrently.
            results.update(await self._run_concurrent(input_data))

        elif self.config.mode == OrchestrationMode.PIPELINE:
            # Pipeline: same data-threading shape as sequential; transforms
            # are applied inside _call_endpoint via transform_input.
            current_data = input_data
            for endpoint in self.config.endpoints:
                result = await self._call_endpoint(endpoint, current_data)
                results[endpoint.name] = result
                current_data = result  # Pass result to next

        elif self.config.mode == OrchestrationMode.FANOUT:
            # Fan out the same input to every endpoint, concurrently
            # (same fix as PARALLEL above).
            results.update(await self._run_concurrent(input_data))

        # Merge results if a merger is provided.
        if self.config.result_merger:
            merged = self.config.result_merger(list(results.values()))
            results['merged'] = merged

        # Transform the final result dict if a transformer is provided.
        if self.config.result_transformer:
            results = self.config.result_transformer(results)

        # Attach metrics snapshot when enabled.
        if self._metrics:
            results['_metrics'] = self._metrics.get_metrics()

        return results

    async def close(self) -> None:
        """Close all open providers."""
        for provider in self._providers.values():
            if provider.is_open:
                await provider.close()

def create_rest_api_orchestrator(
    base_url: str,
    endpoints: List[Dict[str, Any]],
    auth_token: str | None = None,
    rate_limit: int = 60,
    mode: OrchestrationMode = OrchestrationMode.SEQUENTIAL
) -> APIOrchestrator:
    """Create REST API orchestrator.

    Args:
        base_url: Base URL for all endpoints
        endpoints: List of endpoint configurations (dicts with 'name',
            'path' and optional 'method'/'headers'/'params'/'body'/
            'depends_on'/'transform_input' entries)
        auth_token: Optional authentication token (sent as a Bearer header)
        rate_limit: Requests per minute
        mode: Orchestration mode

    Returns:
        Configured API orchestrator
    """
    # Shared auth header, merged under each endpoint's own headers so the
    # endpoint can override it.
    auth_headers = {'Authorization': f'Bearer {auth_token}'} if auth_token else {}

    api_endpoints = [
        APIEndpoint(
            name=ep['name'],
            url=f"{base_url}{ep['path']}",
            method=ep.get('method', 'GET'),
            headers={**auth_headers, **ep.get('headers', {})},
            params=ep.get('params'),
            body=ep.get('body'),
            depends_on=ep.get('depends_on'),
            transform_input=ep.get('transform_input'),
        )
        for ep in endpoints
    ]

    config = APIOrchestrationConfig(
        endpoints=api_endpoints,
        mode=mode,
        global_rate_limit=rate_limit,
        metrics_enabled=True,
    )

    return APIOrchestrator(config)

609 

610 

def create_graphql_orchestrator(
    endpoint: str,
    queries: List[Dict[str, Any]],
    auth_token: str | None = None,
    batch_queries: bool = True
) -> APIOrchestrator:
    """Create GraphQL API orchestrator.

    Args:
        endpoint: GraphQL endpoint URL
        queries: List of GraphQL queries; each dict holds a 'query' string
            plus optional 'name' and 'variables' entries
        auth_token: Optional authentication token (sent as a Bearer header)
        batch_queries: Whether to batch all queries into one POST request

    Returns:
        Configured API orchestrator
    """
    headers = {'Content-Type': 'application/json'}
    if auth_token:
        headers['Authorization'] = f'Bearer {auth_token}'

    if batch_queries:
        # Create single batched endpoint.
        # NOTE(review): the join separator here is the two-character
        # sequence backslash+n ('\\n'), not a real newline — confirm this
        # is the intended GraphQL document separator.
        batched_query = {
            'query': '\\n'.join(q['query'] for q in queries),
            'variables': {}
        }
        for q in queries:
            if 'variables' in q:
                # NOTE(review): identically named variables from different
                # queries silently overwrite each other — confirm acceptable.
                batched_query['variables'].update(q['variables'])

        endpoint_config = APIEndpoint(
            name='graphql_batch',
            url=endpoint,
            method='POST',
            headers=headers,
            body=batched_query
        )

        # One batched request runs sequentially (there is only one call).
        config = APIOrchestrationConfig(
            endpoints=[endpoint_config],
            mode=OrchestrationMode.SEQUENTIAL
        )
    else:
        # Create separate endpoints for each query, executed in parallel.
        api_endpoints = []
        for q in queries:
            endpoint_config = APIEndpoint(
                name=q.get('name', f"query_{len(api_endpoints)}"),
                url=endpoint,
                method='POST',
                headers=headers,
                body={
                    'query': q['query'],
                    'variables': q.get('variables', {})
                }
            )
            api_endpoints.append(endpoint_config)

        config = APIOrchestrationConfig(
            endpoints=api_endpoints,
            mode=OrchestrationMode.PARALLEL
        )

    return APIOrchestrator(config)