Coverage for src/mcpadapt/auth/handlers.py: 97%
114 statements
coverage.py v7.10.6, created at 2025-09-06 19:35 +0530
1"""OAuth handlers for managing authentication flows."""
3import threading
4import time
5import webbrowser
6from abc import ABC, abstractmethod
7from http.server import BaseHTTPRequestHandler, HTTPServer
8from typing import Any
9from urllib.parse import parse_qs, urlparse
11from .exceptions import (
12 OAuthCallbackError,
13 OAuthCancellationError,
14 OAuthNetworkError,
15 OAuthServerError,
16 OAuthTimeoutError,
17)
class CallbackHandler(BaseHTTPRequestHandler):
    """HTTP request handler that captures the OAuth redirect callback.

    Values parsed from the redirect's query string are written into the
    ``callback_data`` dict passed to the constructor, where another thread
    can pick them up.
    """

    def __init__(self, request, client_address, server, callback_data):
        """Attach the shared callback dict, then process the request.

        ``BaseHTTPRequestHandler.__init__`` handles the request before it
        returns, so ``callback_data`` must be assigned first.

        Args:
            request: Socket for the incoming connection.
            client_address: Address of the connecting client.
            server: Server instance that accepted the connection.
            callback_data: Mutable dict that receives the
                ``authorization_code``, ``state`` and ``error`` values.
        """
        self.callback_data = callback_data
        super().__init__(request, client_address, server)

    def do_GET(self):
        """Handle the GET request produced by the OAuth provider's redirect.

        * ``code`` present: store the code and ``state`` (or ``None``), reply 200.
        * ``error`` present: store the error code, reply 400 with an HTML page.
        * otherwise: reply 404 with no body.
        """
        from html import escape

        query_params = parse_qs(urlparse(self.path).query)

        if "code" in query_params:
            # Store ``state`` before ``authorization_code``: the waiting side
            # returns as soon as it observes a code and only then reads
            # ``state``, so the opposite order can surface a stale ``None``.
            self.callback_data["state"] = query_params.get("state", [None])[0]
            self.callback_data["authorization_code"] = query_params["code"][0]
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            self.wfile.write(b"""
            <html>
            <body>
                <h1>Authorization Successful!</h1>
                <p>You can close this window and return to the terminal.</p>
                <script>setTimeout(() => window.close(), 2000);</script>
            </body>
            </html>
            """)
        elif "error" in query_params:
            error = query_params["error"][0]
            self.callback_data["error"] = error
            self.send_response(400)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            # ``error`` comes straight from the request URL (untrusted), so it
            # is HTML-escaped before being reflected into the page to prevent
            # script injection.
            self.wfile.write(
                f"""
            <html>
            <body>
                <h1>Authorization Failed</h1>
                <p>Error: {escape(error)}</p>
                <p>You can close this window and return to the terminal.</p>
            </body>
            </html>
            """.encode()
            )
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        """Suppress the base class's per-request logging to stderr."""
        pass
73class LocalCallbackServer:
74 """Simple server to handle OAuth callbacks."""
76 def __init__(self, port: int = 3030):
77 """Initialize callback server.
79 Args:
80 port: Port to listen on for OAuth callbacks
81 """
82 self.port = port
83 self.server = None
84 self.thread = None
85 self.callback_data = {"authorization_code": None, "state": None, "error": None}
87 def _create_handler_with_data(self):
88 """Create a handler class with access to callback data."""
89 callback_data = self.callback_data
91 class DataCallbackHandler(CallbackHandler):
92 def __init__(self, request, client_address, server):
93 super().__init__(request, client_address, server, callback_data)
95 return DataCallbackHandler
97 def start(self) -> None:
98 """Start the callback server in a background thread.
100 Raises:
101 OAuthCallbackError: If server cannot be started
102 """
103 try:
104 handler_class = self._create_handler_with_data()
105 self.server = HTTPServer(("localhost", self.port), handler_class)
106 self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
107 self.thread.start()
108 except OSError as e:
109 if e.errno == 48: # Address already in use
110 raise OAuthCallbackError(
111 f"Port {self.port} is already in use. Try using a different port or check if another OAuth flow is running.",
112 context={"port": self.port, "original_error": str(e)}
113 )
114 else:
115 raise OAuthCallbackError(
116 f"Failed to start OAuth callback server on port {self.port}: {str(e)}",
117 context={"port": self.port, "original_error": str(e)}
118 )
119 except Exception as e:
120 raise OAuthCallbackError(
121 f"Unexpected error starting OAuth callback server: {str(e)}",
122 context={"port": self.port, "original_error": str(e)}
123 )
125 def stop(self) -> None:
126 """Stop the callback server."""
127 if self.server:
128 self.server.shutdown()
129 self.server.server_close()
130 if self.thread:
131 self.thread.join(timeout=1)
133 def wait_for_callback(self, timeout: int = 300) -> str:
134 """Wait for OAuth callback with timeout.
136 Args:
137 timeout: Maximum time to wait for callback in seconds
139 Returns:
140 Authorization code from OAuth callback
142 Raises:
143 OAuthTimeoutError: If timeout occurs
144 OAuthCancellationError: If user cancels authorization
145 OAuthServerError: If OAuth server returns an error
146 """
147 start_time = time.time()
148 while time.time() - start_time < timeout:
149 if self.callback_data["authorization_code"]:
150 return self.callback_data["authorization_code"]
151 elif self.callback_data["error"]:
152 error = self.callback_data["error"]
153 context = {"port": self.port, "timeout": timeout}
155 # Map common OAuth error codes to specific exceptions
156 if error in ["access_denied", "user_cancelled"]:
157 raise OAuthCancellationError(
158 error_details=error,
159 context=context
160 )
161 else:
162 # Generic server error for other OAuth errors
163 raise OAuthServerError(
164 server_error=error,
165 context=context
166 )
167 time.sleep(0.1)
169 raise OAuthTimeoutError(
170 timeout_seconds=timeout,
171 context={"port": self.port}
172 )
174 def get_state(self) -> str | None:
175 """Get the received state parameter.
177 Returns:
178 OAuth state parameter or None
179 """
180 return self.callback_data["state"]
183class BaseOAuthHandler(ABC):
184 """Base class for OAuth authentication handlers.
186 Combines redirect and callback handling into a single cohesive interface.
187 Subclasses should implement both the redirect flow (opening authorization URL)
188 and callback flow (receiving authorization code).
189 """
191 @abstractmethod
192 async def handle_redirect(self, authorization_url: str) -> None:
193 """Handle OAuth redirect to authorization URL.
195 Args:
196 authorization_url: The OAuth authorization URL to redirect the user to
197 """
198 pass
200 @abstractmethod
201 async def handle_callback(self) -> tuple[str, str | None]:
202 """Handle OAuth callback and return authorization code and state.
204 Returns:
205 Tuple of (authorization_code, state) received from OAuth provider
207 Raises:
208 Exception: If OAuth flow fails or times out
209 """
210 pass
class LocalBrowserOAuthHandler(BaseOAuthHandler):
    """Default OAuth handler using local browser and callback server.

    Opens authorization URL in the user's default browser and starts a local
    HTTP server to receive the OAuth callback. This is the most user-friendly
    approach for desktop applications.
    """

    def __init__(self, callback_port: int = 3030, timeout: int = 300):
        """Initialize the local browser OAuth handler.

        Args:
            callback_port: Port to run the local callback server on
            timeout: Maximum time to wait for OAuth callback in seconds
        """
        self.callback_port = callback_port
        self.timeout = timeout
        self.callback_server: LocalCallbackServer | None = None

    async def handle_redirect(self, authorization_url: str) -> None:
        """Open authorization URL in the user's default browser.

        The URL is always printed first so the user can open it manually if
        launching the browser fails.

        Args:
            authorization_url: OAuth authorization URL to open

        Raises:
            OAuthNetworkError: If browser cannot be opened
        """
        print(f"Opening OAuth authorization URL: {authorization_url}")

        fallback_hint = "Failed to automatically open browser. Please manually open the URL above."
        context = {"authorization_url": authorization_url}
        try:
            opened = webbrowser.open(authorization_url)
        except Exception as e:
            print(fallback_hint)
            raise OAuthNetworkError(e, context=context) from e

        if not opened:
            # webbrowser.open() returned a falsy value instead of raising.
            print(fallback_hint)
            raise OAuthNetworkError(
                Exception("Failed to open browser - no suitable browser found"),
                context=context,
            )

    async def handle_callback(self) -> tuple[str, str | None]:
        """Start local server and wait for OAuth callback.

        The blocking wait runs in a worker thread (``asyncio.to_thread``) so
        the event loop keeps running other tasks while the user authorizes.
        The callback server is stopped on every exit path, including
        cancellation.

        Returns:
            Tuple of (authorization_code, state) from OAuth callback

        Raises:
            OAuthCallbackError: If callback server cannot be started
            OAuthTimeoutError: If callback doesn't arrive within timeout
            OAuthCancellationError: If user cancels authorization
            OAuthServerError: If OAuth server returns an error
        """
        import asyncio

        # Tracked locally so that ``finally`` only stops a server created by
        # this call, never one left over in ``self.callback_server``.
        server: LocalCallbackServer | None = None
        try:
            server = LocalCallbackServer(port=self.callback_port)
            self.callback_server = server
            server.start()

            # wait_for_callback() polls with time.sleep(); calling it directly
            # would block the whole event loop for up to ``self.timeout``.
            # NOTE: if this coroutine is cancelled, the worker thread is not
            # interrupted and keeps polling until its own timeout elapses.
            auth_code = await asyncio.to_thread(server.wait_for_callback, timeout=self.timeout)
            return auth_code, server.get_state()

        except (OAuthTimeoutError, OAuthCancellationError, OAuthServerError, OAuthCallbackError):
            # Re-raise OAuth-specific exceptions as-is
            raise
        except Exception as e:
            # Wrap unexpected errors, keeping the original as the cause
            raise OAuthCallbackError(
                f"Unexpected error during OAuth callback handling: {str(e)}",
                context={
                    "port": self.callback_port,
                    "timeout": self.timeout,
                    "original_error": str(e),
                },
            ) from e
        finally:
            if server is not None:
                server.stop()