Coverage for src/otg_mcp/client.py: 34%
703 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-19 00:42 -0700
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-19 00:42 -0700
1"""
2OTG Client module providing a direct interface to traffic generator APIs.
4This simplified client provides a single entry point for traffic generator operations
5using snappi API directly, with proper target management and version detection.
6"""
8import logging
9import os
10import time
11import traceback
12import uuid
13from dataclasses import dataclass, field
14from typing import Any, Dict, List, Optional, Union
16import aiohttp
17import snappi # type: ignore
19from otg_mcp.client_capture import get_capture, start_capture, stop_capture
20from otg_mcp.config import Config
21from otg_mcp.models import (
22 CapabilitiesVersionResponse,
23 CaptureResponse,
24 ConfigResponse,
25 ControlResponse,
26 HealthStatus,
27 MetricsResponse,
28 PortInfo,
29 TargetHealthInfo,
30 TrafficGeneratorInfo,
31 TrafficGeneratorStatus,
32)
33from otg_mcp.schema_registry import SchemaRegistry
35logger = logging.getLogger(__name__)
36logger.setLevel(logging.INFO)
39@dataclass
40class OtgClient:
41 """
42 Client for OTG traffic generator operations using snappi.
44 This client provides a unified interface for all traffic generator operations,
45 handling target resolution, API version differences, and client caching.
46 """
48 config: Config
49 api_clients: Dict[str, Any] = field(default_factory=dict)
50 schema_registry: Optional[SchemaRegistry] = field(default=None)
def __post_init__(self):
    """
    Finish construction by making sure a schema registry is attached.

    A registry supplied by the caller is kept untouched. Otherwise a new
    ``SchemaRegistry`` is created, pointed at ``config.schemas.schema_path``
    when that setting is truthy and at ``None`` when it is not.
    """
    logger.info("Initializing OTG client")
    logger.debug("Checking if we need to create a SchemaRegistry")

    if self.schema_registry is not None:
        logger.info("Using provided SchemaRegistry instance")
    else:
        logger.info("No SchemaRegistry provided, creating one")
        schema_location = None
        configured_location = self.config.schemas.schema_path
        if configured_location:
            schema_location = configured_location
            logger.info(f"Using custom schema path from config: {schema_location}")

        self.schema_registry = SchemaRegistry(schema_location)
        logger.info("Created new SchemaRegistry instance")

    logger.info("OTG client initialized")
def _get_api_client(self, target: str):
    """
    Get or create the snappi API client for a target.

    Clients are cached in ``self.api_clients`` keyed by target ID, so the
    location lookup and schema introspection only happen on first use.

    Args:
        target: Target ID (required)

    Returns:
        The snappi API client for the target. Only the client is returned;
        the capabilities discovered for a new client are logged, not
        returned or cached.
    """
    logger.info(f"Getting API client for target {target}")

    logger.info(f"Checking if client is cached for target {target}")
    if target in self.api_clients:
        logger.info(f"Using cached API client for target {target}")
        return self.api_clients[target]

    logger.info(f"Resolving location for target {target}")
    location = self._get_location_for_target(target)
    logger.info(f"Target {target} resolved to location {location}")

    logger.info(f"Creating new snappi API client for {location}")
    # NOTE(review): verify=False presumably disables TLS certificate
    # verification in snappi -- confirm this is intended outside lab setups.
    api = snappi.api(location=location, verify=False)

    logger.info("Detecting API capabilities through schema introspection")
    api_schema = self._discover_api_schema(api)
    logger.info(f"API schema detected: version={api_schema['version']}")

    logger.info(f"Caching client for target {target}")
    self.api_clients[target] = api

    return api
def _get_location_for_target(self, target: str) -> str:
    """
    Get location string for target.

    Args:
        target: Target ID, either ``host`` or ``host:port``. A target that
            already carries a URL scheme (contains ``://``) is returned
            unchanged instead of being prefixed a second time.

    Returns:
        Location string for snappi client
    """
    logger.info(f"Creating URL for direct connection to target {target}")
    if "://" in target:
        return target
    # Host-only and host:port targets get the same https:// prefix.
    return f"https://{target}"
def _discover_api_schema(self, api) -> Dict[str, Any]:
    """
    Inspect a snappi API client and summarise what it offers.

    Args:
        api: Snappi API client

    Returns:
        Dictionary with the public callable names (``methods``), the detected
        ``version`` and one ``has_<name>`` flag per traffic-control entry point
    """
    logger.info("Getting all available API methods through introspection")
    public_callables = []
    for attr_name in dir(api):
        if attr_name.startswith("_"):
            continue
        if callable(getattr(api, attr_name)):
            public_callables.append(attr_name)

    logger.info("Detecting API version and capabilities")
    capabilities: Dict[str, Any] = {
        "methods": public_callables,
        "version": self._get_api_version(api),
    }
    probed_names = (
        "control_state",
        "transmit_state",
        "start_transmit",
        "stop_transmit",
        "set_flow_transmit",
    )
    for probed in probed_names:
        capabilities[f"has_{probed}"] = hasattr(api, probed)

    return capabilities
def _get_api_version(self, api) -> str:
    """
    Report the API version as a string.

    The client object is consulted first, then the ``snappi`` module.

    Args:
        api: Snappi API client

    Returns:
        API version string, or "unknown" when neither exposes ``__version__``
    """
    version_holder = None
    if hasattr(api, "__version__"):
        version_holder = api
    elif hasattr(snappi, "__version__"):
        version_holder = snappi

    if version_holder is None:
        return "unknown"
    return str(version_holder.__version__)
def _start_traffic(self, api) -> None:
    """
    Start traffic with the first entry point the client exposes.

    Tried in order: ``start_transmit``, ``set_flow_transmit``, then
    ``control_state``.

    Args:
        api: Snappi API client

    Raises:
        NotImplementedError: If the client exposes none of the three
    """
    logger.info("Starting traffic generation")

    if hasattr(api, "start_transmit"):
        logger.info("Using start_transmit() method")
        api.start_transmit()
        return

    if hasattr(api, "set_flow_transmit"):
        logger.info("Using set_flow_transmit() method")
        api.set_flow_transmit(state="start")
        return

    if not hasattr(api, "control_state"):
        raise NotImplementedError("No method available to start traffic")

    logger.info("Using control_state() method")
    self._start_traffic_control_state(api)
def _start_traffic_control_state(self, api) -> None:
    """
    Start traffic by submitting a control_state object to the client.

    Args:
        api: Snappi API client

    Raises:
        AttributeError: If the control_state object lacks ``traffic`` or its
            traffic object lacks ``flow_transmit``
    """
    control = api.control_state()
    control.choice = control.TRAFFIC

    if not hasattr(control, "traffic"):
        raise AttributeError("control_state object does not have traffic attribute")

    control.traffic.choice = control.traffic.FLOW_TRANSMIT

    if not hasattr(control.traffic, "flow_transmit"):
        raise AttributeError("traffic object does not have flow_transmit attribute")

    logger.debug("Handling different versions of control_state API")
    flow_transmit = control.traffic.flow_transmit
    # Use the START constant when the object defines one, else the literal.
    flow_transmit.state = getattr(flow_transmit, "START", "start")

    api.set_control_state(control)
def _stop_traffic(self, api) -> bool:
    """
    Stop traffic, trying each known stop strategy until one does not raise.

    Args:
        api: Snappi API client

    Returns:
        The result of ``_verify_traffic_stopped`` after the first strategy
        that completes, or False when every strategy raised
    """
    logger.info("Stopping traffic generation")

    strategies = (
        self._stop_traffic_direct,
        self._stop_traffic_transmit,
        self._stop_traffic_control_state,
        self._stop_traffic_flow_transmit,
    )

    for strategy in strategies:
        try:
            strategy(api)
            logger.info(f"Successfully stopped traffic using {strategy.__name__}")
            return self._verify_traffic_stopped(api)
        except Exception as exc:
            # A failing strategy is not fatal; fall through to the next one.
            logger.info(f"Failed to stop traffic using {strategy.__name__}: {exc}")

    logger.warning("All methods to stop traffic failed")
    return False
def _stop_traffic_direct(self, api) -> None:
    """
    Stop traffic by calling the client's ``stop_transmit`` method.

    Args:
        api: Snappi API client

    Raises:
        AttributeError: If the client has no ``stop_transmit``
    """
    if hasattr(api, "stop_transmit"):
        api.stop_transmit()
        return
    raise AttributeError("stop_transmit method not available")
def _stop_traffic_transmit(self, api) -> None:
    """
    Stop traffic by submitting a transmit_state object set to STOP.

    Args:
        api: Snappi API client

    Raises:
        AttributeError: If the client has no ``transmit_state``
    """
    if hasattr(api, "transmit_state"):
        transmit = api.transmit_state()
        transmit.state = transmit.STOP
        api.set_transmit_state(transmit)
        return
    raise AttributeError("transmit_state method not available")
def _stop_traffic_control_state(self, api) -> None:
    """
    Stop traffic by submitting a control_state object to the client.

    Args:
        api: Snappi API client

    Raises:
        AttributeError: If the client has no ``control_state``, the
            control_state object lacks ``traffic``, or its traffic object
            lacks ``flow_transmit``
    """
    if not hasattr(api, "control_state"):
        raise AttributeError("control_state method not available")

    control = api.control_state()
    control.choice = control.TRAFFIC

    if not hasattr(control, "traffic"):
        raise AttributeError("control_state object does not have traffic attribute")

    control.traffic.choice = control.traffic.FLOW_TRANSMIT

    if not hasattr(control.traffic, "flow_transmit"):
        raise AttributeError("traffic object does not have flow_transmit attribute")

    logger.debug(
        "Handling different versions of control_state API for stopping traffic"
    )
    flow_transmit = control.traffic.flow_transmit
    # Use the STOP constant when the object defines one, else the literal.
    flow_transmit.state = getattr(flow_transmit, "STOP", "stop")

    api.set_control_state(control)
def _stop_traffic_flow_transmit(self, api) -> None:
    """
    Stop traffic by calling ``set_flow_transmit`` with state "stop".

    Args:
        api: Snappi API client

    Raises:
        AttributeError: If the client has no ``set_flow_transmit``
    """
    if hasattr(api, "set_flow_transmit"):
        api.set_flow_transmit(state="stop")
        return
    raise AttributeError("set_flow_transmit method not available")
def _verify_traffic_stopped(self, api, timeout=5, threshold=0.1) -> bool:
    """
    Verify that traffic has actually stopped by polling flow metrics.

    Metrics are checked at least once, even when ``timeout`` is 0, and then
    re-checked every 0.5 seconds until the deadline. Elapsed time is measured
    with a monotonic clock so wall-clock adjustments cannot stretch or cut
    the wait. Errors raised while fetching or reading metrics are logged and
    the check is retried.

    Args:
        api: Snappi API client
        timeout: Maximum time in seconds to wait
        threshold: Threshold below which traffic is considered stopped

    Returns:
        True if traffic is stopped, False otherwise
    """
    logger.info(f"Verifying traffic has stopped (timeout={timeout}s)")

    poll_interval = 0.5
    deadline = time.monotonic() + timeout

    while True:
        try:
            logger.debug("Getting flow metrics")
            request = api.metrics_request()
            metrics = api.get_metrics(request)

            logger.debug("Checking if there are any flow metrics")
            if (
                not hasattr(metrics, "flow_metrics")
                or len(metrics.flow_metrics) == 0
            ):
                logger.info(
                    "No flow metrics available, assuming traffic is stopped"
                )
                return True

            logger.debug("Checking if all flows have stopped")
            all_stopped = True
            for flow in metrics.flow_metrics:
                # A missing or None rate cannot be compared; such a flow is
                # not counted as still transmitting.
                rate = getattr(flow, "frames_tx_rate", None)
                if rate is not None and rate >= threshold:
                    logger.info(
                        f"Flow {getattr(flow, 'name', 'unknown')} still running with rate {rate}"
                    )
                    all_stopped = False
                    break

            if all_stopped:
                logger.info("All flows verified stopped")
                return True
        except Exception as e:
            logger.warning(f"Error checking traffic status: {str(e)}")

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(poll_interval, remaining))

    logger.warning(f"Timed out waiting for traffic to stop after {timeout}s")
    return False
def _get_metrics(self, api, flow_names=None, port_names=None):
    """
    Fetch metrics, optionally narrowed to given flows and/or ports.

    Args:
        api: Snappi API client
        flow_names: Optional list of flow names
        port_names: Optional list of port names

    Returns:
        Metrics object
    """
    request = api.metrics_request()

    # Flow filter first, then port filter; an empty/None filter is skipped.
    for section, names in (("flow", flow_names), ("port", port_names)):
        if names:
            setattr(getattr(request, section), f"{section}_names", names)

    return api.get_metrics(request)
def _start_capture(self, api: Any, port_names: Union[str, List[str]]) -> None:
    """
    Begin packet capture on the given port(s).

    The client's public attribute names are inspected and the first matching
    entry point is used: ``capture_state``, then ``start_capture`` (called
    once per port), then ``control_state``.

    Args:
        api: Snappi API client
        port_names: List or single name of port(s) to capture on

    Raises:
        NotImplementedError: If the client exposes none of the entry points
    """
    logger.info(f"Starting capture for ports: {port_names}")

    logger.debug("Converting port names to list for consistent handling")
    if isinstance(port_names, str):
        ports = [port_names]
    else:
        ports = list(port_names)

    logger.debug("Detecting available API methods for capture")
    public_names = [name for name in dir(api) if not name.startswith("_")]
    logger.debug(f"Available API methods: {public_names}")

    logger.info("Trying multiple methods to start capture based on available API")
    try:
        if "capture_state" in public_names:
            logger.info("Using capture_state() method")
            capture_state = api.capture_state()
            capture_state.state = "start"
            capture_state.port_names = ports
            api.set_capture_state(capture_state)
        elif "start_capture" in public_names:
            logger.info("Using start_capture() method")
            for single_port in ports:
                api.start_capture(port_name=single_port)
        elif "control_state" in public_names:
            logger.info("Using control_state() method for capture")
            control = api.control_state()

            logger.debug("Checking if there's a CAPTURE choice available")
            if hasattr(control, "CAPTURE") and hasattr(control, "choice"):
                logger.debug("Setting control_state choice to CAPTURE")
                control.choice = control.CAPTURE

            logger.debug("Checking for capture attribute in control_state")
            if hasattr(control, "capture"):
                logger.debug("Found capture attribute in control_state")

                if hasattr(control.capture, "port_names"):
                    logger.debug("Setting port_names in capture")
                    control.capture.port_names = ports

                if hasattr(control.capture, "state"):
                    logger.debug("Setting capture state to start")
                    # Use the START constant when defined, else the literal.
                    control.capture.state = getattr(
                        control.capture, "START", "start"
                    )

            logger.info(
                f"Setting control state to start capture on ports: {ports}"
            )
            api.set_control_state(control)
        else:
            logger.error("No compatible capture method found in API")
            raise NotImplementedError("No method available to start capture")
    except Exception as exc:
        logger.error(f"Error starting capture: {exc}")
        raise
def _stop_capture(self, api: Any, port_names: Union[str, List[str]]) -> None:
    """
    End packet capture on the given port(s).

    The client's public attribute names are inspected and the first matching
    entry point is used: ``capture_state``, then ``stop_capture`` (called
    once per port), then ``control_state``.

    Args:
        api: Snappi API client
        port_names: List or single name of port(s) to stop capture on

    Raises:
        NotImplementedError: If the client exposes none of the entry points
    """
    logger.info(f"Stopping capture for ports: {port_names}")

    logger.debug("Converting port names to list for consistent handling")
    if isinstance(port_names, str):
        ports = [port_names]
    else:
        ports = list(port_names)

    logger.debug("Detecting available API methods for capture")
    public_names = [name for name in dir(api) if not name.startswith("_")]
    logger.debug(f"Available API methods: {public_names}")

    try:
        if "capture_state" in public_names:
            logger.info("Using capture_state() method")
            capture_state = api.capture_state()
            capture_state.state = "stop"
            capture_state.port_names = ports
            api.set_capture_state(capture_state)
        elif "stop_capture" in public_names:
            logger.info("Using stop_capture() method")
            for single_port in ports:
                api.stop_capture(port_name=single_port)
        elif "control_state" in public_names:
            logger.info("Using control_state() method for capture")
            control = api.control_state()

            if hasattr(control, "CAPTURE") and hasattr(control, "choice"):
                logger.debug("Setting control_state choice to CAPTURE")
                control.choice = control.CAPTURE

            if hasattr(control, "capture"):
                logger.debug("Found capture attribute in control_state")

                if hasattr(control.capture, "port_names"):
                    logger.debug("Setting port_names in capture")
                    control.capture.port_names = ports

                if hasattr(control.capture, "state"):
                    logger.debug("Setting capture state to stop")
                    # Use the STOP constant when defined, else the literal.
                    control.capture.state = getattr(
                        control.capture, "STOP", "stop"
                    )

            logger.info(
                f"Setting control state to stop capture on ports: {ports}"
            )
            api.set_control_state(control)
        else:
            logger.error("No compatible capture method found in API")
            raise NotImplementedError("No method available to stop capture")
    except Exception as exc:
        logger.error(f"Error stopping capture: {exc}")
        raise
def _get_capture(
    self, api: Any, port_name: str, output_dir: Optional[str] = None
) -> str:
    """
    Get capture data and save to a file.

    The file is named ``capture_<port>_<8 hex chars>.pcap``. Path separators
    in the port name are replaced with ``_`` in the file name so that a port
    such as ``eth1/1`` does not point into a nonexistent subdirectory; the
    unmodified port name is still what is sent to the API.

    Args:
        api: Snappi API client
        port_name: Name of port to get capture from
        output_dir: Directory to save the capture file (default: /tmp)

    Returns:
        File path where the capture was saved

    Raises:
        NotImplementedError: If the client exposes no capture retrieval method
        ValueError: If the control_state result carries no capture data
    """
    if output_dir is None:
        logger.debug("Using default output directory: /tmp")
        output_dir = "/tmp"

    logger.debug(f"Creating output directory if it doesn't exist: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    logger.debug("Generating unique file name for capture data")
    safe_port = port_name
    for separator in (os.sep, os.altsep):
        # os.altsep is None on platforms with a single separator.
        if separator:
            safe_port = safe_port.replace(separator, "_")
    file_name = f"capture_{safe_port}_{uuid.uuid4().hex[:8]}.pcap"
    file_path = os.path.join(output_dir, file_name)

    logger.info(f"Getting capture data for port {port_name}")

    logger.debug("Detecting available API methods for capture")
    api_methods = [method for method in dir(api) if not method.startswith("_")]
    logger.debug(f"Available API methods: {api_methods}")

    try:
        if "capture_request" in api_methods and "get_capture" in api_methods:
            logger.info("Using capture_request() and get_capture() methods")
            req = api.capture_request()
            req.port_name = port_name
            capture_data = api.get_capture(req)

            logger.info(f"Saving capture data to {file_path}")
            with open(file_path, "wb") as pcap:
                pcap.write(capture_data.read())

        elif "control_state" in api_methods:
            logger.info("Using control_state() method for capture retrieval")
            cs = api.control_state()

            if hasattr(cs, "CAPTURE") and hasattr(cs, "choice"):
                logger.debug("Setting control_state choice to CAPTURE")
                cs.choice = cs.CAPTURE

            if hasattr(cs, "capture"):
                logger.debug("Found capture attribute in control_state")

                if hasattr(cs.capture, "port_name"):
                    logger.debug(f"Setting port_name to {port_name}")
                    cs.capture.port_name = port_name

                if hasattr(cs.capture, "state"):
                    logger.debug("Setting capture state to RETRIEVE")
                    if hasattr(cs.capture, "RETRIEVE"):
                        cs.capture.state = cs.capture.RETRIEVE
                    else:
                        cs.capture.state = "retrieve"

            logger.info(
                f"Setting control state to retrieve capture on port {port_name}"
            )
            result = api.set_control_state(cs)

            if hasattr(result, "capture") and hasattr(result.capture, "data"):
                logger.info(f"Saving capture data to {file_path}")
                with open(file_path, "wb") as pcap:
                    pcap.write(result.capture.data)
            else:
                raise ValueError(
                    f"No capture data found in control_state result: {result}"
                )
        else:
            logger.error("No compatible capture retrieval method found in API")
            raise NotImplementedError("No method available to get capture data")

    except Exception as e:
        logger.error(f"Error getting capture data: {e}")
        raise

    logger.info(f"Capture data saved to {file_path}")
    return file_path
async def get_traffic_generators_status(self):
    """Legacy alias: delegates to ``list_traffic_generators`` and returns its result."""
    logger.info("Legacy call to get_traffic_generators_status")
    generators = await self.list_traffic_generators()
    return generators
async def set_config(
    self, config: Dict[str, Any], target: Optional[str] = None
) -> ConfigResponse:
    """
    Push a configuration to the traffic generator and read back what was applied.

    Args:
        config: Configuration to set; a dict is deserialized into a new
            config object first, anything else is passed to the API as-is
        target: Optional target ID ("localhost" is used when omitted)

    Returns:
        Configuration response containing the applied configuration, or a
        response with status "error" and the error text if anything raised
    """
    logger.info(f"Setting configuration on target {target or 'default'}")

    try:
        resolved_target = target or "localhost"
        logger.info(f"Getting API client for {resolved_target}")
        api = self._get_api_client(resolved_target)

        logger.info("Processing config based on type")
        if not isinstance(config, dict):
            logger.info("Using config object directly")
            api.set_config(config)
        else:
            logger.info("Deserializing config dictionary")
            payload = api.config()
            payload.deserialize(config)
            api.set_config(payload)

        logger.info("Retrieving the applied configuration")
        applied = api.get_config()
        logger.info("Serializing retrieved config to dictionary")
        logger.debug("Using config directly for serialization")
        applied_dict = applied.serialize(encoding=applied.DICT)  # type: ignore

        return ConfigResponse(status="success", config=applied_dict)
    except Exception as exc:
        logger.error(f"Error setting configuration: {exc}")
        logger.error(traceback.format_exc())
        return ConfigResponse(status="error", config={"error": str(exc)})
async def get_config(self, target: Optional[str] = None) -> ConfigResponse:
    """
    Read the current configuration from the traffic generator.

    Args:
        target: Optional target ID ("localhost" is used when omitted)

    Returns:
        Configuration response with the config serialized to a dict, or a
        response with status "error" and the error text if anything raised
    """
    logger.info(f"Getting configuration from target {target or 'default'}")

    try:
        resolved_target = target or "localhost"
        logger.info(f"Getting API client for {resolved_target}")
        api = self._get_api_client(resolved_target)

        logger.info("Getting configuration from device")
        device_config = api.get_config()

        logger.info("Serializing config to dictionary")
        logger.debug("Using config directly for serialization")
        serialized = device_config.serialize(encoding=device_config.DICT)  # type: ignore

        return ConfigResponse(status="success", config=serialized)
    except Exception as exc:
        logger.error(f"Error getting configuration: {exc}")
        logger.error(traceback.format_exc())
        return ConfigResponse(status="error", config={"error": str(exc)})
async def start_traffic(self, target: Optional[str] = None) -> ControlResponse:
    """
    Start traffic generation on a target.

    Args:
        target: Optional target ID ("localhost" is used when omitted)

    Returns:
        Control response; status is "error" with the error text in ``result``
        if anything raised
    """
    logger.info(f"Starting traffic on target {target or 'default'}")

    try:
        resolved_target = target or "localhost"
        logger.info(f"Getting API client for {resolved_target}")
        api = self._get_api_client(resolved_target)

        logger.info("Starting traffic on device")
        self._start_traffic(api)
    except Exception as exc:
        logger.error(f"Error starting traffic: {exc}")
        logger.error(traceback.format_exc())
        failure = {"error": str(exc)}
        return ControlResponse(
            status="error", action="traffic_generation", result=failure
        )

    return ControlResponse(status="success", action="traffic_generation")
async def stop_traffic(self, target: Optional[str] = None) -> ControlResponse:
    """
    Stop traffic generation on a target.

    Status is "success" whenever no exception is raised; whether the stop was
    actually confirmed is reported separately in ``result["verified"]``.

    Args:
        target: Target ID ("localhost" is used when omitted)

    Returns:
        Control response; status is "error" with the error text in ``result``
        if anything raised
    """
    logger.info(f"Stopping traffic on target {target or 'default'}")

    try:
        resolved_target = target or "localhost"
        logger.info(f"Getting API client for {resolved_target}")
        api = self._get_api_client(resolved_target)

        logger.info("Stopping traffic on device")
        verified = self._stop_traffic(api)

        outcome = {"verified": verified}
        return ControlResponse(
            status="success", action="traffic_generation", result=outcome
        )
    except Exception as exc:
        logger.error(f"Error stopping traffic: {exc}")
        logger.error(traceback.format_exc())
        failure = {"error": str(exc)}
        return ControlResponse(
            status="error", action="traffic_generation", result=failure
        )
async def start_capture(
    self, port_name: Union[str, List[str]], target: Optional[str] = None
) -> CaptureResponse:
    """
    Begin packet capture on one or more ports.

    When several ports are given, the response reports the first one (or an
    empty string for an empty list).

    Args:
        port_name: Name or list of names of port(s) to capture on
        target: Optional target ID ("localhost" is used when omitted)

    Returns:
        Capture response; status is "error" with the error text in ``data``
        when the capture helper reports failure or anything raised
    """
    logger.debug("Determining response port name for multi-port capture")
    response_port = port_name
    if isinstance(response_port, list) and response_port:
        response_port = response_port[0]
    if isinstance(response_port, list):
        logger.debug("Response port is still a list, extracting first element")
        response_port = response_port[0] if response_port else ""

    logger.info(
        f"Starting capture on port(s) {port_name} on target {target or 'default'}"
    )

    try:
        resolved_target = target or "localhost"
        logger.info(f"Getting API client for {resolved_target}")
        api = self._get_api_client(resolved_target)

        logger.info(
            f"Starting capture on port(s) {port_name} with improved implementation"
        )
        outcome = start_capture(api, port_name)

        if outcome["status"] != "success":
            return CaptureResponse(
                status="error",
                port=response_port,
                data={"error": outcome.get("error", "Unknown error")},
            )
        return CaptureResponse(status="success", port=response_port)
    except Exception as exc:
        logger.error(f"Error starting capture: {exc}")
        logger.error(traceback.format_exc())
        return CaptureResponse(
            status="error", port=response_port, data={"error": str(exc)}
        )
async def stop_capture(
    self, port_name: Union[str, List[str]], target: Optional[str] = None
) -> CaptureResponse:
    """
    End packet capture on one or more ports.

    When several ports are given, the response reports the first one (or an
    empty string for an empty list). Warnings returned by the capture helper
    are passed through in ``data["warnings"]``.

    Args:
        port_name: Name or list of names of port(s) to stop capture on
        target: Optional target ID ("localhost" is used when omitted)

    Returns:
        Capture response; status is "error" with the error text in ``data``
        when the capture helper reports failure or anything raised
    """
    logger.debug("Determining response port name for multi-port capture")
    response_port = port_name
    if isinstance(response_port, list) and response_port:
        response_port = response_port[0]
    if isinstance(response_port, list):
        logger.debug("Response port is still a list, extracting first element")
        response_port = response_port[0] if response_port else ""

    logger.info(
        f"Stopping capture on port(s) {port_name} on target {target or 'default'}"
    )

    try:
        resolved_target = target or "localhost"
        logger.info(f"Getting API client for {resolved_target}")
        api = self._get_api_client(resolved_target)

        logger.info(
            f"Stopping capture on port(s) {port_name} with improved implementation"
        )
        outcome = stop_capture(api, port_name)

        if outcome["status"] != "success":
            return CaptureResponse(
                status="error",
                port=response_port,
                data={"error": outcome.get("error", "Unknown error")},
            )

        payload = {"status": "stopped"}
        if "warnings" in outcome:
            payload["warnings"] = outcome["warnings"]
        return CaptureResponse(status="success", port=response_port, data=payload)
    except Exception as exc:
        logger.error(f"Error stopping capture: {exc}")
        logger.error(traceback.format_exc())
        return CaptureResponse(
            status="error", port=response_port, data={"error": str(exc)}
        )
async def get_capture(
    self,
    port_name: str,
    output_dir: Optional[str] = None,
    target: Optional[str] = None,
    filename: Optional[str] = None,
) -> CaptureResponse:
    """
    Retrieve the packet capture of a port and store it in a file.

    Args:
        port_name: Name of port to get capture from
        output_dir: Directory to save the capture file (default: /tmp)
        target: Optional target ID ("localhost" is used when omitted)
        filename: Optional custom filename for the capture file

    Returns:
        Capture response with file path where the capture was saved; status
        is "error" with the error text in ``data`` when the capture helper
        reports failure or anything raised
    """
    logger.info(
        f"Getting capture from port {port_name} on target {target or 'default'}"
    )

    try:
        resolved_target = target or "localhost"
        logger.info(f"Getting API client for {resolved_target}")
        api = self._get_api_client(resolved_target)

        logger.info(
            f"Getting capture for port {port_name} with improved implementation"
        )
        outcome = get_capture(
            api, port_name, output_dir=output_dir, filename=filename
        )

        if outcome["status"] != "success":
            return CaptureResponse(
                status="error",
                port=port_name,
                data={"error": outcome.get("error", "Unknown error")},
            )

        return CaptureResponse(
            status="success",
            port=port_name,
            data={"status": "captured", "file_path": outcome["file_path"]},
            file_path=outcome["file_path"],
            capture_id=outcome.get("capture_id"),
        )
    except Exception as exc:
        logger.error(f"Error getting capture: {exc}")
        logger.error(traceback.format_exc())
        return CaptureResponse(
            status="error", port=port_name, data={"error": str(exc)}
        )
async def list_traffic_generators(self) -> TrafficGeneratorStatus:
    """
    List all traffic generators declared in the configuration.

    No connection is attempted here: every configured target is reported
    with ``available=True``.

    Returns:
        TrafficGeneratorStatus containing all traffic generators, or an empty
        TrafficGeneratorStatus if building the listing raises
    """
    logger.info("Listing all traffic generators")

    try:
        result = TrafficGeneratorStatus()

        logger.info("Getting targets from config")
        for hostname, target_config in self.config.targets.targets.items():
            logger.info(f"Adding target {hostname} to list")

            logger.info(f"Creating generator info for {hostname}")
            gen_info = TrafficGeneratorInfo(hostname=hostname)

            logger.info(f"Adding port configurations for {hostname}")
            for port_name, port_config in target_config.ports.items():
                logger.debug(f"Ensuring location is not None for port {port_name}")
                location = port_config.location or ""
                gen_info.ports[port_name] = PortInfo(
                    name=port_name, location=location, interface=None
                )

            logger.info(f"Testing connection to {hostname}")
            logger.info("Simple availability check")
            # Availability is assumed, not probed. The previous try/except
            # around this assignment had an unreachable handler.
            gen_info.available = True

            logger.info(f"Adding {hostname} to result")
            result.generators[hostname] = gen_info

        return result
    except Exception as e:
        logger.error(f"Error listing traffic generators: {e}")
        logger.error(traceback.format_exc())
        return TrafficGeneratorStatus()
async def get_available_targets(self) -> Dict[str, Dict[str, Any]]:
    """
    Describe every configured target together with its reachability.

    The client cache is always cleared first, so each target gets a freshly
    created API client. For each target the result holds:
    - ports: Port configurations (location and name per port)
    - available: True when creating the API client did not raise
    - apiVersion: SDK version reported by ``get_target_version`` (if it succeeded)
    - apiVersionError: Error text when the version lookup raised
    - error: Error text when creating the API client raised

    Returns:
        Dictionary mapping target names to the entries described above, or an
        empty dictionary if reading the configuration raises
    """
    logger.info("Getting available traffic generator targets")

    logger.info("Clearing client cache to force reconnection")
    self.api_clients.clear()

    targets_view = {}
    try:
        logger.info("Reading targets from config")
        for target_name, target_config in self.config.targets.targets.items():
            logger.info(f"Processing target: {target_name}")

            port_entries = {
                port_name: {
                    "location": port_config.location,
                    "name": port_config.name,
                }
                for port_name, port_config in target_config.ports.items()
            }
            entry = {"ports": port_entries, "available": False}

            logger.info(f"Testing connection to {target_name}")
            try:
                self._get_api_client(target_name)
                logger.debug("Testing availability of the target")
                entry["available"] = True
                logger.info(f"Target {target_name} is available")

                logger.info(
                    f"Attempting to retrieve API version from target {target_name}"
                )
                try:
                    version_info = await self.get_target_version(target_name)
                    entry["apiVersion"] = version_info.sdk_version
                    logger.info(
                        f"Detected API version {version_info.sdk_version} for target {target_name}"
                    )
                except Exception as version_error:
                    # A failed version lookup does not make the target unavailable.
                    logger.warning(
                        f"Could not detect API version for {target_name}: {version_error}"
                    )
                    entry["apiVersionError"] = str(version_error)
            except Exception as exc:
                logger.warning(f"Error connecting to {target_name}: {exc}")
                entry["available"] = False
                entry["error"] = str(exc)

            targets_view[target_name] = entry

        reachable = sum(1 for item in targets_view.values() if item["available"])
        logger.info(f"Found {len(targets_view)} targets, {reachable} available")
        return targets_view
    except Exception as exc:
        logger.error(f"Error getting available targets: {exc}")
        logger.error(traceback.format_exc())
        return {}
1007 async def _get_target_config(self, target_name: str) -> Optional[Dict[str, Any]]:
1008 """
1009 Get configuration for a specific target (internal method).
1011 Args:
1012 target_name: Name of the target to look up
1014 Returns:
1015 Target configuration including ports and dynamically determined apiVersion,
1016 or None if the target doesn't exist
1017 """
1018 logger.info(f"Looking up configuration for target: {target_name}")
1020 try:
1021 targets = await self.get_available_targets()
1023 if target_name in targets: 1023 ↛ 1099line 1023 didn't jump to line 1099 because the condition on line 1023 was always true
1024 logger.info(f"Found configuration for target: {target_name}")
1025 target_config = targets[target_name]
1026 schema_registry = self.schema_registry
1028 logger.info(
1029 "Initializing API version with default value to be overridden by device-reported version"
1030 )
1031 target_config["apiVersion"] = "unknown"
1032 logger.debug(
1033 "API version is always set dynamically from device, never from config"
1034 )
1036 logger.info("Getting API version directly from the target device")
1037 try:
1038 logger.info(
1039 f"Attempting to get API version from target {target_name}"
1040 )
1041 version_info = await self.get_target_version(target_name)
1042 actual_api_version = version_info.sdk_version
1043 normalized_version = actual_api_version.replace(".", "_")
1045 logger.info(
1046 f"Target {target_name} reports API version: {actual_api_version}"
1047 )
1049 logger.debug(
1050 "Verifying schema registry was properly initialized in __post_init__"
1051 )
1052 if schema_registry is None: 1052 ↛ 1053line 1052 didn't jump to line 1053 because the condition on line 1052 was never true
1053 logger.error("Schema registry is not initialized")
1054 raise ValueError("Schema registry is not initialized")
1056 logger.info(
1057 "Checking for schema match for the reported API version"
1058 )
1059 if schema_registry.schema_exists(normalized_version):
1060 logger.info(
1061 f"Found exact schema for actual version: {actual_api_version}"
1062 )
1063 target_config["apiVersion"] = actual_api_version
1064 else:
1065 logger.info(
1066 "No exact schema match found, finding closest available schema"
1067 )
1068 if schema_registry is None: 1068 ↛ 1069line 1068 didn't jump to line 1069 because the condition on line 1068 was never true
1069 logger.error("Schema registry is not initialized")
1070 raise ValueError("Schema registry is not initialized")
1072 closest_version = schema_registry.find_closest_schema_version(
1073 normalized_version
1074 )
1075 closest_version_dotted = closest_version.replace("_", ".")
1076 logger.info(
1077 f"No exact schema for version {actual_api_version}. "
1078 f"Using closest matching version: {closest_version_dotted}"
1079 )
1080 target_config["apiVersion"] = closest_version_dotted
1081 except Exception as e:
1082 logger.info(
1083 "Exception during API version detection, falling back to latest schema"
1084 )
1085 if schema_registry is None: 1085 ↛ 1086line 1085 didn't jump to line 1086 because the condition on line 1085 was never true
1086 logger.error("Schema registry is not initialized")
1087 raise ValueError("Schema registry is not initialized")
1089 latest_version = schema_registry.get_latest_schema_version()
1090 latest_version_dotted = latest_version.replace("_", ".")
1091 logger.warning(
1092 f"Failed to get API version from target {target_name}: {str(e)}. "
1093 f"Using latest available schema version: {latest_version_dotted}"
1094 )
1095 target_config["apiVersion"] = latest_version_dotted
1097 return target_config
1099 logger.warning(f"Target not found: {target_name}")
1100 return None
1101 except Exception as e:
1102 logger.error(f"Error looking up target {target_name}: {e}")
1103 logger.error(traceback.format_exc())
1104 return None
async def get_schemas_for_target(
    self, target_name: str, schema_names: List[str]
) -> Dict[str, Any]:
    """
    Fetch several schemas for the API version resolved for a target.

    Names that are not already written as ``components.schemas.<Name>`` are
    looked up under that prefix. The returned dictionary is always keyed by
    the name exactly as the caller supplied it.

    Args:
        target_name: Name of the target
        schema_names: Schema names/components to retrieve (e.g., ["Flow", "Port"] or
            ["components.schemas.Flow", "components.schemas.Port"])

    Returns:
        Dictionary mapping each requested name to its schema content. A lookup
        that fails is reported in place as ``{"error": "<message>"}`` instead
        of aborting the whole request.

    Raises:
        ValueError: If the target doesn't exist, or an error escapes the
            per-schema handling
    """
    logger.info(
        f"Getting schemas for target {target_name}, schemas: {schema_names}"
    )

    target_config = await self._get_target_config(target_name)
    if not target_config:
        error_msg = f"Target {target_name} not found"
        logger.error(error_msg)
        raise ValueError(error_msg)

    api_version = target_config["apiVersion"]
    logger.info(f"Using API version {api_version} for target {target_name}")

    schemas_by_name = {}
    try:
        registry = self.schema_registry
        for requested in schema_names:
            logger.info(f"Retrieving schema {requested} for target {target_name}")
            try:
                logger.debug("Verifying schema registry is properly initialized")
                if registry is None:
                    logger.error("Schema registry is not initialized")
                    raise ValueError("Schema registry is not initialized")

                # Only names already under components.schemas are used verbatim.
                if "." in requested and requested.startswith("components.schemas."):
                    lookup_path = requested
                else:
                    lookup_path = f"components.schemas.{requested}"
                    logger.info(f"Interpreting {requested} as {lookup_path}")

                schemas_by_name[requested] = registry.get_schema(
                    api_version, lookup_path
                )
            except Exception as lookup_error:
                logger.warning(
                    f"Error retrieving schema {requested}: {str(lookup_error)}"
                )
                logger.debug("Creating error dictionary for exception response")
                schemas_by_name[requested] = {"error": str(lookup_error)}

        return schemas_by_name
    except Exception as e:
        error_msg = f"Error getting schemas for target {target_name}: {str(e)}"
        logger.error(error_msg)
        raise ValueError(error_msg)
async def list_schemas_for_target(self, target_name: str) -> List[str]:
    """
    List the schema names available for a target's resolved API version.

    Only the keys under ``components.schemas`` of the full schema document are
    returned; top-level keys, paths and other components are left out.

    Args:
        target_name: Name of the target

    Returns:
        List of schema names under components.schemas, or an empty list when
        the schema document has no such dictionary

    Raises:
        ValueError: If the target doesn't exist, the schema registry is missing,
            or loading the schema fails
    """
    logger.info(f"Listing schemas for target {target_name}")

    target_config = await self._get_target_config(target_name)
    if not target_config:
        error_msg = f"Target {target_name} not found"
        logger.error(error_msg)
        raise ValueError(error_msg)

    api_version = target_config["apiVersion"]
    logger.info(f"Using API version {api_version} for target {target_name}")

    try:
        registry = self.schema_registry
        logger.debug("Verifying schema registry is properly initialized")
        if registry is None:
            logger.error("Schema registry is not initialized")
            raise ValueError("Schema registry is not initialized")
        schema = registry.get_schema(api_version)

        # Both levels must exist and be dictionaries before keys are listed.
        has_schema_section = (
            "components" in schema
            and isinstance(schema["components"], dict)
            and "schemas" in schema["components"]
            and isinstance(schema["components"]["schemas"], dict)
        )
        if not has_schema_section:
            logger.warning("No components.schemas found in the schema")
            return []

        schema_keys = list(schema["components"]["schemas"])
        logger.info(f"Extracted {len(schema_keys)} schema keys")
        return schema_keys
    except Exception as e:
        error_msg = (
            f"Error listing schema structure for target {target_name}: {str(e)}"
        )
        logger.error(error_msg)
        raise ValueError(error_msg)
async def get_target_version(self, target: str) -> CapabilitiesVersionResponse:
    """
    Query a target's ``/capabilities/version`` endpoint over HTTPS.

    Args:
        target: Target hostname or IP

    Returns:
        CapabilitiesVersionResponse built from the JSON body of the response

    Raises:
        ValueError: If the target answers with a status other than 200.
            Errors raised by aiohttp, by JSON decoding, or while building the
            response model are not caught here and propagate unchanged.
    """
    logger.info(f"Getting version information from target {target}")

    url = f"https://{target}/capabilities/version"
    logger.info(f"Making request to {url}")

    # NOTE(review): ssl=False switches off TLS certificate verification in
    # aiohttp -- confirm this is intended for every deployment.
    async with aiohttp.ClientSession() as session, session.get(
        url, ssl=False
    ) as response:
        if response.status != 200:
            error_msg = f"Failed to get version from {target}: {response.status}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        data = await response.json()
        logger.info(f"Received version data: {data}")
        return CapabilitiesVersionResponse(**data)
async def get_schema_components_for_target(
    self, target_name: str, path_prefix: str = "components.schemas"
) -> List[str]:
    """
    Return the schema component names for a target's resolved API version.

    The lookup itself is delegated to the schema registry's
    ``get_schema_components``.

    Args:
        target_name: Name of the target
        path_prefix: The path prefix to look in (default: "components.schemas")

    Returns:
        List of component names as returned by the schema registry

    Raises:
        ValueError: If the target doesn't exist, the schema registry is missing,
            or the registry lookup fails
    """
    logger.info(
        f"Getting schema components for target {target_name} with prefix {path_prefix}"
    )

    target_config = await self._get_target_config(target_name)
    if not target_config:
        error_msg = f"Target {target_name} not found"
        logger.error(error_msg)
        raise ValueError(error_msg)

    api_version = target_config["apiVersion"]
    logger.info(f"Using API version {api_version} for target {target_name}")

    try:
        registry = self.schema_registry
        logger.debug("Verifying schema registry is properly initialized")
        if registry is None:
            logger.error("Schema registry is not initialized")
            raise ValueError("Schema registry is not initialized")
        component_names = registry.get_schema_components(api_version, path_prefix)
        return component_names
    except Exception as lookup_error:
        # Every failure in this step is reported to callers as ValueError.
        error_msg = (
            f"Error getting schema components for target {target_name}: {str(lookup_error)}"
        )
        logger.error(error_msg)
        raise ValueError(error_msg)
async def health(self, target: Optional[str] = None) -> HealthStatus:
    """
    Report target health by probing each target's version endpoint.

    A target counts as healthy when ``get_target_version`` succeeds for it.
    The overall status becomes "success" only when at least one target was
    checked and none of them failed; otherwise it stays "error".

    Args:
        target: Optional target to check. If None, checks all targets.

    Returns:
        HealthStatus: Per-target health information plus the overall status.
        If the check itself fails unexpectedly, an "error" status with no
        targets is returned.
    """
    logger.info(f"Checking health of {target or 'all targets'}")

    logger.debug("Initializing HealthStatus with 'error' status by default")
    health_status = HealthStatus(status="error")

    try:
        if target:
            logger.info(f"Checking specific target: {target}")
            names_to_check = [target]
        else:
            logger.info("No specific target - checking all available targets")
            discovered = await self.get_available_targets()
            names_to_check = list(discovered.keys())
            logger.info(f"Found {len(names_to_check)} targets to check")

        logger.info("Beginning health checks for all targets")
        failure_seen = False
        for name in names_to_check:
            logger.info(f"Checking health for target: {name}")
            try:
                logger.info(f"Requesting version info from {name}")
                version_info = await self.get_target_version(name)

                logger.info(f"Target {name} is healthy")
                health_status.targets[name] = TargetHealthInfo(
                    name=name,
                    healthy=True,
                    version_info=version_info,
                    error=None,
                )
            except Exception as probe_error:
                # One failing target must not stop the remaining checks.
                logger.warning(f"Target {name} is unhealthy: {str(probe_error)}")
                health_status.targets[name] = TargetHealthInfo(
                    name=name, healthy=False, error=str(probe_error), version_info=None
                )
                failure_seen = True

        if not failure_seen and names_to_check:
            logger.info("All targets are healthy, setting status to 'success'")
            health_status.status = "success"
        else:
            logger.info("One or more targets are unhealthy, status remains 'error'")

        logger.info(f"Health check complete for {len(names_to_check)} targets")
        return health_status

    except Exception as e:
        logger.error(f"Health check failed with error: {str(e)}")
        logger.error(traceback.format_exc())
        return HealthStatus(status="error", targets={})
async def get_metrics(
    self,
    flow_names: Optional[Union[str, List[str]]] = None,
    port_names: Optional[Union[str, List[str]]] = None,
    target: Optional[str] = None,
) -> MetricsResponse:
    """
    Get metrics from traffic generator.

    One entry point covers every metrics query:
    - All metrics (flows and ports): call with no parameters
    - All flow metrics: call with flow_names=[]
    - Specific flow(s): call with flow_names="flow1" or flow_names=["flow1", "flow2"]
    - All port metrics: call with port_names=[]
    - Specific port(s): call with port_names="port1" or port_names=["port1", "port2"]
    - Combination: call with both flow_names and port_names

    Args:
        flow_names: Optional flow name(s) to get metrics for:
            - None: Flow metrics not specifically requested
            - Empty list: All flow metrics
            - str: Single flow metrics
            - List[str]: Multiple flow metrics
        port_names: Optional port name(s) to get metrics for:
            - None: Port metrics not specifically requested
            - Empty list: All port metrics
            - str: Single port metrics
            - List[str]: Multiple port metrics
        target: Optional target ID; "localhost" is used when omitted

    Returns:
        Metrics response containing requested metrics. On any failure the
        response has status "error" and carries the message under "error".
    """
    logger.info(f"Getting metrics from target {target or 'default'}")

    try:
        target_id = target or "localhost"
        logger.info(f"Getting API client for {target_id}")
        api = self._get_api_client(target_id)

        # A bare string is wrapped into a one-element list; lists pass through.
        flow_name_list = None
        if isinstance(flow_names, str):
            logger.info(f"Getting metrics for flow: {flow_names}")
            flow_name_list = [flow_names]
        elif flow_names is not None:
            if flow_names:
                logger.info(f"Getting metrics for flows: {flow_names}")
            else:
                logger.info("Getting metrics for all flows")
            flow_name_list = flow_names

        port_name_list = None
        if isinstance(port_names, str):
            logger.info(f"Getting metrics for port: {port_names}")
            port_name_list = [port_names]
        elif port_names is not None:
            if port_names:
                logger.info(f"Getting metrics for ports: {port_names}")
            else:
                logger.info("Getting metrics for all ports")
            port_name_list = port_names

        nothing_requested = flow_name_list is None and port_name_list is None
        if nothing_requested:
            logger.info("Getting all metrics (no specific filters)")
            metrics = self._get_metrics(api)
        else:
            logger.info(
                f"Calling _get_metrics with flow_names={flow_name_list}, port_names={port_name_list}"
            )
            metrics = self._get_metrics(
                api, flow_names=flow_name_list, port_names=port_name_list
            )

        logger.info("Serializing metrics to dictionary")
        logger.debug("Using metrics directly for serialization")
        metrics_dict = metrics.serialize(encoding=metrics.DICT)  # type: ignore

        return MetricsResponse(status="success", metrics=metrics_dict)
    except Exception as e:
        logger.error(f"Error getting metrics: {e}")
        logger.error(traceback.format_exc())
        return MetricsResponse(status="error", metrics={"error": str(e)})