Coverage for src/mcpadapt/auth/handlers.py: 97%

114 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-06 19:35 +0530

1"""OAuth handlers for managing authentication flows.""" 

2 

import asyncio
import errno
import html
import threading
import time
import webbrowser
from abc import ABC, abstractmethod
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs, urlparse

from .exceptions import (
    OAuthCallbackError,
    OAuthCancellationError,
    OAuthNetworkError,
    OAuthServerError,
    OAuthTimeoutError,
)

18 

19 

class CallbackHandler(BaseHTTPRequestHandler):
    """Simple HTTP handler to capture OAuth callback.

    Copies the ``code``/``state`` (or ``error``) query parameters of the
    incoming GET request into the shared ``callback_data`` dict supplied by
    the owning server, and answers the browser with a small HTML page.
    """

    _SUCCESS_PAGE = b"""
<html>
<body>
<h1>Authorization Successful!</h1>
<p>You can close this window and return to the terminal.</p>
<script>setTimeout(() => window.close(), 2000);</script>
</body>
</html>
"""

    def __init__(self, request, client_address, server, callback_data):
        """Initialize with callback data storage.

        ``callback_data`` must be assigned before ``super().__init__``: the
        socketserver base constructor handles the request (and therefore
        calls ``do_GET``) before it returns.
        """
        self.callback_data = callback_data
        super().__init__(request, client_address, server)

    def _send_html(self, status: int, body: bytes) -> None:
        """Write a complete HTML response with the given status code."""
        self.send_response(status)
        # The bodies are UTF-8 encoded, so say so explicitly.
        self.send_header("Content-type", "text/html; charset=utf-8")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle GET request from OAuth redirect.

        - ``?code=...``: stores the code and (optional) state, replies 200.
        - ``?error=...``: stores the error, replies 400.
        - anything else (e.g. ``/favicon.ico``): replies 404, stores nothing.
        """
        query_params = parse_qs(urlparse(self.path).query)

        if "code" in query_params:
            # Store state BEFORE the code. A thread polling for the code
            # (LocalCallbackServer.wait_for_callback) returns the moment it
            # sees one, and its caller then reads state. The reverse order
            # could let that caller observe state=None.
            self.callback_data["state"] = query_params.get("state", [None])[0]
            self.callback_data["authorization_code"] = query_params["code"][0]
            self._send_html(200, self._SUCCESS_PAGE)
        elif "error" in query_params:
            error = query_params["error"][0]
            self.callback_data["error"] = error
            # The error text comes straight from the request URL. Escape it
            # so a crafted link cannot inject markup/script into this
            # localhost page (reflected XSS).
            safe_error = html.escape(error)
            body = f"""
<html>
<body>
<h1>Authorization Failed</h1>
<p>Error: {safe_error}</p>
<p>You can close this window and return to the terminal.</p>
</body>
</html>
"""
            self._send_html(400, body.encode("utf-8"))
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        """Suppress default logging."""
        pass

71 

72 

73class LocalCallbackServer: 

74 """Simple server to handle OAuth callbacks.""" 

75 

76 def __init__(self, port: int = 3030): 

77 """Initialize callback server. 

78  

79 Args: 

80 port: Port to listen on for OAuth callbacks 

81 """ 

82 self.port = port 

83 self.server = None 

84 self.thread = None 

85 self.callback_data = {"authorization_code": None, "state": None, "error": None} 

86 

87 def _create_handler_with_data(self): 

88 """Create a handler class with access to callback data.""" 

89 callback_data = self.callback_data 

90 

91 class DataCallbackHandler(CallbackHandler): 

92 def __init__(self, request, client_address, server): 

93 super().__init__(request, client_address, server, callback_data) 

94 

95 return DataCallbackHandler 

96 

97 def start(self) -> None: 

98 """Start the callback server in a background thread. 

99  

100 Raises: 

101 OAuthCallbackError: If server cannot be started 

102 """ 

103 try: 

104 handler_class = self._create_handler_with_data() 

105 self.server = HTTPServer(("localhost", self.port), handler_class) 

106 self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) 

107 self.thread.start() 

108 except OSError as e: 

109 if e.errno == 48: # Address already in use 

110 raise OAuthCallbackError( 

111 f"Port {self.port} is already in use. Try using a different port or check if another OAuth flow is running.", 

112 context={"port": self.port, "original_error": str(e)} 

113 ) 

114 else: 

115 raise OAuthCallbackError( 

116 f"Failed to start OAuth callback server on port {self.port}: {str(e)}", 

117 context={"port": self.port, "original_error": str(e)} 

118 ) 

119 except Exception as e: 

120 raise OAuthCallbackError( 

121 f"Unexpected error starting OAuth callback server: {str(e)}", 

122 context={"port": self.port, "original_error": str(e)} 

123 ) 

124 

125 def stop(self) -> None: 

126 """Stop the callback server.""" 

127 if self.server: 

128 self.server.shutdown() 

129 self.server.server_close() 

130 if self.thread: 

131 self.thread.join(timeout=1) 

132 

133 def wait_for_callback(self, timeout: int = 300) -> str: 

134 """Wait for OAuth callback with timeout. 

135  

136 Args: 

137 timeout: Maximum time to wait for callback in seconds 

138  

139 Returns: 

140 Authorization code from OAuth callback 

141  

142 Raises: 

143 OAuthTimeoutError: If timeout occurs 

144 OAuthCancellationError: If user cancels authorization 

145 OAuthServerError: If OAuth server returns an error 

146 """ 

147 start_time = time.time() 

148 while time.time() - start_time < timeout: 

149 if self.callback_data["authorization_code"]: 

150 return self.callback_data["authorization_code"] 

151 elif self.callback_data["error"]: 

152 error = self.callback_data["error"] 

153 context = {"port": self.port, "timeout": timeout} 

154 

155 # Map common OAuth error codes to specific exceptions 

156 if error in ["access_denied", "user_cancelled"]: 

157 raise OAuthCancellationError( 

158 error_details=error, 

159 context=context 

160 ) 

161 else: 

162 # Generic server error for other OAuth errors 

163 raise OAuthServerError( 

164 server_error=error, 

165 context=context 

166 ) 

167 time.sleep(0.1) 

168 

169 raise OAuthTimeoutError( 

170 timeout_seconds=timeout, 

171 context={"port": self.port} 

172 ) 

173 

174 def get_state(self) -> str | None: 

175 """Get the received state parameter. 

176  

177 Returns: 

178 OAuth state parameter or None 

179 """ 

180 return self.callback_data["state"] 

181 

182 

183class BaseOAuthHandler(ABC): 

184 """Base class for OAuth authentication handlers. 

185  

186 Combines redirect and callback handling into a single cohesive interface. 

187 Subclasses should implement both the redirect flow (opening authorization URL) 

188 and callback flow (receiving authorization code). 

189 """ 

190 

191 @abstractmethod 

192 async def handle_redirect(self, authorization_url: str) -> None: 

193 """Handle OAuth redirect to authorization URL. 

194  

195 Args: 

196 authorization_url: The OAuth authorization URL to redirect the user to 

197 """ 

198 pass 

199 

200 @abstractmethod 

201 async def handle_callback(self) -> tuple[str, str | None]: 

202 """Handle OAuth callback and return authorization code and state. 

203  

204 Returns: 

205 Tuple of (authorization_code, state) received from OAuth provider 

206  

207 Raises: 

208 Exception: If OAuth flow fails or times out 

209 """ 

210 pass 

211 

212 

class LocalBrowserOAuthHandler(BaseOAuthHandler):
    """Default OAuth handler using local browser and callback server.

    Opens the authorization URL in the user's default browser and starts a
    local HTTP server (LocalCallbackServer) to receive the OAuth callback.
    Intended for interactive desktop use.
    """

    # Seconds between checks for a callback result while awaiting it.
    _POLL_INTERVAL = 0.1

    def __init__(self, callback_port: int = 3030, timeout: int = 300):
        """Initialize the local browser OAuth handler.

        Args:
            callback_port: Port to run the local callback server on
            timeout: Maximum time to wait for OAuth callback in seconds
        """
        self.callback_port = callback_port
        self.timeout = timeout
        # Server created by the most recent handle_callback() call, if any.
        self.callback_server: LocalCallbackServer | None = None

    async def handle_redirect(self, authorization_url: str) -> None:
        """Open authorization URL in the user's default browser.

        The URL is printed first, so the user can open it by hand if the
        browser cannot be launched.

        Args:
            authorization_url: OAuth authorization URL to open

        Raises:
            OAuthNetworkError: If browser cannot be opened
        """
        print(f"Opening OAuth authorization URL: {authorization_url}")

        try:
            opened = webbrowser.open(authorization_url)
        except Exception as e:
            print("Failed to automatically open browser. Please manually open the URL above.")
            raise OAuthNetworkError(
                e,
                context={"authorization_url": authorization_url},
            ) from e

        if not opened:
            print("Failed to automatically open browser. Please manually open the URL above.")
            raise OAuthNetworkError(
                Exception("Failed to open browser - no suitable browser found"),
                context={"authorization_url": authorization_url},
            )

    async def handle_callback(self) -> tuple[str, str | None]:
        """Start local server and wait for OAuth callback.

        Waiting does not block the event loop. The server is always stopped
        before returning, including on error or cancellation.

        Returns:
            Tuple of (authorization_code, state) from OAuth callback

        Raises:
            OAuthCallbackError: If callback server cannot be started
            OAuthTimeoutError: If callback doesn't arrive within timeout
            OAuthCancellationError: If user cancels authorization
            OAuthServerError: If OAuth server returns an error
        """
        # Track the server created by THIS call, so the cleanup below never
        # touches a server left over from an earlier call.
        server: LocalCallbackServer | None = None
        try:
            server = LocalCallbackServer(port=self.callback_port)
            self.callback_server = server
            server.start()

            auth_code = await self._await_authorization_code(server)
            state = server.get_state()
            return auth_code, state

        except (OAuthTimeoutError, OAuthCancellationError, OAuthServerError, OAuthCallbackError):
            # Re-raise OAuth-specific exceptions as-is
            raise
        except Exception as e:
            # Wrap unexpected errors
            raise OAuthCallbackError(
                f"Unexpected error during OAuth callback handling: {str(e)}",
                context={
                    "port": self.callback_port,
                    "timeout": self.timeout,
                    "original_error": str(e),
                },
            ) from e
        finally:
            if server is not None:
                server.stop()

    async def _await_authorization_code(self, server: LocalCallbackServer) -> str:
        """Wait for the callback result without blocking the event loop.

        LocalCallbackServer.wait_for_callback() sleeps in a loop. Calling it
        directly from a coroutine would freeze every other task for up to
        ``self.timeout`` seconds. Instead, poll the server's shared
        ``callback_data`` with ``asyncio.sleep``. Once a code or error is
        present, delegate to wait_for_callback(), which then returns or raises
        on its first check, so its error-code mapping is reused unchanged.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.timeout
        while loop.time() < deadline:
            data = server.callback_data
            if data["authorization_code"] or data["error"]:
                return server.wait_for_callback(timeout=self.timeout)
            await asyncio.sleep(self._POLL_INTERVAL)

        raise OAuthTimeoutError(
            timeout_seconds=self.timeout,
            context={"port": self.callback_port},
        )