Coverage for src/otg_mcp/server.py: 25%

111 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-19 00:42 -0700

1"""OTG MCP Server implementation. 

2 

3This module implements a Model Context Protocol (MCP) server that provides access to 

4Open Traffic Generator (OTG) APIs via direct connections to traffic generators. 

5""" 

6 

7import argparse 

8import logging 

9import sys 

10import traceback 

11from typing import Annotated, Any, Dict, List, Literal, Optional, Union 

12 

13from fastmcp import FastMCP 

14from pydantic import Field 

15 

16from otg_mcp.client import OtgClient 

17from otg_mcp.config import Config 

18from otg_mcp.models import ( 

19 CaptureResponse, 

20 ConfigResponse, 

21 ControlResponse, 

22 HealthStatus, 

23 MetricsResponse, 

24) 

25 

# Module-wide logger for the server; its level is pinned to INFO at import time.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

28 

29 

class OtgMcpServer:
    """MCP front-end that exposes Open Traffic Generator operations as tools.

    Every public method whose name starts with ``tool_`` is registered with
    the FastMCP instance under the name that follows that prefix (for example
    ``tool_get_config`` becomes the ``get_config`` tool). Each tool logs the
    call and delegates to the matching coroutine on ``OtgClient``.

    The ``tool_*`` docstrings and ``Field`` descriptions are deliberately left
    verbatim: they are presumably what FastMCP publishes as the tool
    description / parameter schema to MCP clients.

    Attributes:
        mcp: FastMCP instance that owns tool registration and execution.
        schema_registry: SchemaRegistry built from the optional custom schema
            path found in the configuration.
        client: OtgClient that every tool delegates to.
    """

    def __init__(self, config_file: str):
        """Load configuration, build the MCP plumbing and register the tools.

        Args:
            config_file: Path to the configuration file passed to ``Config``.

        Raises:
            Exception: Any failure during setup is logged at CRITICAL level
                (message, then traceback) and re-raised unchanged.
        """
        try:
            logger.info("Initializing config with the provided file")
            config = Config(config_file)

            logger.info("Setting up logging configuration")
            config.setup_logging()

            logger.info("Creating the FastMCP instance")
            self.mcp: FastMCP = FastMCP("otg-mcp-server", log_level="INFO")

            logger.info("Initializing schema registry")
            custom_schema_path = None
            # The "schemas" section is optional; only a truthy schema_path
            # replaces the registry's default (None is passed otherwise).
            if hasattr(config, "schemas"):
                configured_path = config.schemas.schema_path
                if configured_path:
                    custom_schema_path = configured_path
                    logger.info(
                        f"Using custom schema path from config: {custom_schema_path}"
                    )

            # Deferred import; the reason is not visible here (possibly an
            # import cycle or startup cost) — TODO confirm.
            from otg_mcp.schema_registry import SchemaRegistry

            self.schema_registry = SchemaRegistry(custom_schema_path)

            logger.info("Initializing OTG client with schema registry")
            self.client = OtgClient(config=config, schema_registry=self.schema_registry)

            logger.info("Registering all endpoints")
            self._register_tools()
        except Exception as exc:
            logger.critical(f"Failed to initialize server: {exc}")
            logger.critical(f"Stack trace: {traceback.format_exc()}")
            raise

    def _register_tools(self):
        """Register each public, callable ``tool_*`` attribute with FastMCP.

        Attributes are visited in ``dir()`` order. Names starting with an
        underscore and non-callable attributes are skipped. Callables without
        the ``tool_`` prefix (such as ``run``) are silently ignored. A tool is
        exposed under its method name with the ``tool_`` prefix removed.
        """
        logger.info("Discovering and registering tools")

        registered = 0
        for name in dir(self):
            logger.debug(f"Checking attribute: {name}")
            # Short-circuit: private names are never fetched with getattr.
            is_private = name.startswith("_")
            if is_private or not callable(getattr(self, name)):
                logger.debug(f"Skipping non-tool attribute: {name}")
                continue
            if not name.startswith("tool_"):
                continue

            handler = getattr(self, name)
            exposed_name = name.removeprefix("tool_")
            logger.debug(
                f"Found tool method: {name}, registering as: {exposed_name}"
            )
            logger.info(f"Registering tool: {exposed_name}")
            self.mcp.add_tool(handler, name=exposed_name)
            registered += 1

        logger.info(f"Registered {registered} tools successfully")

    async def tool_set_config(
        self,
        config: Annotated[
            Dict[str, Any], Field(description="The configuration to set")
        ],
        target: Annotated[
            str, Field(description="Target traffic generator hostname or IP address")
        ],
    ) -> ConfigResponse:
        """Set the configuration of the traffic generator and retrieve the applied configuration."""
        logger.info(f"Tool: set_config for target {target}")
        applied = await self.client.set_config(config=config, target=target)
        return applied

    async def tool_get_config(
        self, target: Annotated[str, Field(description="Target traffic generator")]
    ) -> ConfigResponse:
        """Get the current configuration of the traffic generator."""
        logger.info(f"Tool: get_config for target {target}")
        current = await self.client.get_config(target=target)
        return current

    async def tool_get_metrics(
        self,
        flow_names: Annotated[
            Optional[Union[str, List[str]]],
            Field(description="Optional flow name(s) to get metrics for"),
        ] = None,
        port_names: Annotated[
            Optional[Union[str, List[str]]],
            Field(description="Optional port name(s) to get metrics for"),
        ] = None,
        target: Annotated[
            Optional[str], Field(description="Optional target traffic generator")
        ] = None,
    ) -> MetricsResponse:
        """Get metrics from the traffic generator."""
        logger.info(
            f"Tool: get_metrics for target {target}, flow_names={flow_names}, port_names={port_names}"
        )
        metrics = await self.client.get_metrics(
            target=target, flow_names=flow_names, port_names=port_names
        )
        return metrics

    async def tool_start_traffic(
        self, target: Annotated[str, Field(description="Target traffic generator")]
    ) -> ControlResponse:
        """Start traffic generation."""
        logger.info(f"Tool: start_traffic for target {target}")
        outcome = await self.client.start_traffic(target=target)
        return outcome

    async def tool_stop_traffic(
        self, target: Annotated[str, Field(description="Target traffic generator")]
    ) -> ControlResponse:
        """Stop traffic generation."""
        logger.info(f"Tool: stop_traffic for target {target}")
        outcome = await self.client.stop_traffic(target=target)
        return outcome

    async def tool_start_capture(
        self,
        port_name: Annotated[
            str, Field(description="Name of the port to capture packets on")
        ],
        target: Annotated[str, Field(description="Target traffic generator")],
    ) -> CaptureResponse:
        """Start packet capture on a port."""
        logger.info(f"Tool: start_capture for port {port_name} on target {target}")
        result = await self.client.start_capture(port_name=port_name, target=target)
        return result

    async def tool_stop_capture(
        self,
        port_name: Annotated[
            str, Field(description="Name of the port to stop capturing packets on")
        ],
        target: Annotated[str, Field(description="Target traffic generator")],
    ) -> CaptureResponse:
        """Stop packet capture on a port."""
        logger.info(f"Tool: stop_capture for port {port_name} on target {target}")
        result = await self.client.stop_capture(port_name=port_name, target=target)
        return result

    async def tool_get_capture(
        self,
        port_name: Annotated[
            str, Field(description="Name of the port to get capture from")
        ],
        target: Annotated[str, Field(description="Target traffic generator")],
        output_dir: Annotated[
            Optional[str],
            Field(description="Directory to save the capture file (default: /tmp)"),
        ] = None,
    ) -> CaptureResponse:
        """
        Get packet capture from a port and save it to a file.

        The capture data is saved as a .pcap file that can be opened with tools like Wireshark.
        """
        logger.info(f"Tool: get_capture for port {port_name} on target {target}")
        # output_dir=None is forwarded as-is; any default directory is chosen
        # by the client, not here.
        capture = await self.client.get_capture(
            port_name=port_name, target=target, output_dir=output_dir
        )
        return capture

    async def tool_get_available_targets(self) -> Dict[str, Dict[str, Any]]:
        """Get all available traffic generator targets with comprehensive information."""
        logger.info("Tool: get_available_targets")
        targets = await self.client.get_available_targets()
        return targets

    async def tool_health(
        self,
        target: Annotated[
            Optional[str],
            Field(
                description="Optional target to check. If None, checks all available targets"
            ),
        ] = None,
    ) -> HealthStatus:
        """Health check tool."""
        scope = target or "all targets"
        logger.info(f"Tool: health for {scope}")
        # Passed positionally, matching the original call into the client.
        status = await self.client.health(target)
        return status

    async def tool_get_schemas_for_target(
        self,
        target_name: Annotated[str, Field(description="Name of the target")],
        schema_names: Annotated[
            List[str],
            Field(
                description='List of schema names to retrieve (e.g., ["Flow", "Port"] or ["components.schemas.Flow"])'
            ),
        ],
    ) -> Dict[str, Any]:
        """Get schemas for a specific target's API version."""
        logger.info(
            f"Tool: get_schemas_for_target for {target_name}, schemas {schema_names}"
        )
        schemas = await self.client.get_schemas_for_target(target_name, schema_names)
        return schemas

    async def tool_list_schemas_for_target(
        self, target_name: Annotated[str, Field(description="Name of the target")]
    ) -> List[str]:
        """List available schemas for a specific target's API version."""
        logger.info(f"Tool: list_schemas_for_target for {target_name}")
        names = await self.client.list_schemas_for_target(target_name)
        return names

    def run(self, transport: Literal["stdio", "sse"] = "stdio"):
        """Run the FastMCP server on the chosen transport.

        Args:
            transport: Either "stdio" (default) or "sse"; forwarded unchanged
                to ``FastMCP.run``.

        Raises:
            Exception: Any error from ``FastMCP.run`` is logged at CRITICAL
                level (message, then traceback) and re-raised unchanged.
        """
        try:
            self.mcp.run(transport=transport)
        except Exception as exc:
            logger.critical(f"Error running server: {exc}")
            logger.critical(f"Stack trace: {traceback.format_exc()}")
            raise

254 

255 

def run_server() -> None:
    """Parse command-line options, build an ``OtgMcpServer`` and run it.

    Options:
        --config-file  (required) path to the traffic generator configuration.
        --transport    "stdio" (default) or "sse".

    Any ``Exception`` raised while constructing or running the server is
    logged at CRITICAL level with its traceback, and the process then exits
    with status 1. ``SystemExit`` raised by argparse (``--help``, invalid
    arguments) is not an ``Exception`` subclass, so it propagates untouched.
    """
    try:
        logger.info("Parsing command-line arguments")
        cli = argparse.ArgumentParser(description="OTG MCP Server")
        cli.add_argument(
            "--config-file",
            help="Path to the traffic generator configuration file",
            required=True,
            type=str,
        )
        cli.add_argument(
            "--transport",
            help="Transport mechanism to use (stdio or sse)",
            default="stdio",
            choices=["stdio", "sse"],
            type=str,
        )
        options = cli.parse_args()

        logger.info("Initializing and running the server with the config file")
        # argparse yields a plain str; `choices` restricts it to the Literal values.
        OtgMcpServer(config_file=options.config_file).run(transport=options.transport)  # type: ignore
    except Exception as exc:
        logger.critical(f"Server failed with error: {exc}")
        logger.critical(f"Stack trace: {traceback.format_exc()}")
        sys.exit(1)

284 

285 

def main() -> None:
    """Backward-compatible alias for ``run_server``.

    Kept for callers that still use the legacy ``main`` entry point. It
    delegates directly and adds no behaviour of its own.
    """
    run_server()

289 

290 

if __name__ == "__main__":
    # Script execution only (e.g. `python server.py ...`); importing the
    # module does not start the server.
    run_server()