Source code for kelvin.sdk.services.auth.callback_server

"""Local HTTP server for handling OAuth callbacks."""

from __future__ import annotations

import contextlib
import threading
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Final, Optional
from urllib.parse import parse_qs, urlparse

from typing_extensions import override

from kelvin.sdk.services.auth.errors import OAuthFlowError


@dataclass
class CallbackResult:
    """Data extracted from a successful OAuth redirect to the local server.

    Attributes:
        code: Authorization code taken from the ``code`` query parameter.
        state: Value of the ``state`` query parameter, or ``None`` if absent.
    """

    code: Optional[str] = None
    state: Optional[str] = None
[docs] class CallbackHandler(BaseHTTPRequestHandler): """HTTP request handler for OAuth callbacks.""" authorization_code: Optional[str] = None state: Optional[str] = None error: Optional[str] = None error_description: Optional[str] = None callback_received: Optional[threading.Event] = None
[docs] def do_GET(self) -> None: """Handle GET requests to the callback endpoint.""" parsed_url = urlparse(self.path) query_params = parse_qs(parsed_url.query) if parsed_url.path == "/callback": if "code" in query_params: code_list = query_params.get("code") state_list = query_params.get("state") CallbackHandler.authorization_code = code_list[0] if code_list else None CallbackHandler.state = state_list[0] if state_list else None self._send_success_response() else: error_list = query_params.get("error") error_desc_list = query_params.get("error_description") CallbackHandler.error = error_list[0] if error_list else "unknown" CallbackHandler.error_description = error_desc_list[0] if error_desc_list else "No description" self._send_error_response() if CallbackHandler.callback_received: CallbackHandler.callback_received.set() else: self._send_404_response()
[docs] @override def finish(self) -> None: """Override finish to ensure connection is properly closed.""" with contextlib.suppress(Exception): BaseHTTPRequestHandler.finish(self)
def _send_success_response(self) -> None: """Send a success response to the browser.""" html_path = Path(__file__).parent / "html" / "success.html" try: html = html_path.read_text(encoding="utf-8") except Exception: html = "<html><body><h1>Authentication Successful!</h1><p>You can close this window.</p></body></html>" self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() _ = self.wfile.write(html.encode("utf-8")) def _send_error_response(self) -> None: """Send an error response to the browser.""" html_path = Path(__file__).parent / "html" / "error.html" try: html = html_path.read_text(encoding="utf-8") except Exception: html = "<html><body><h1>Authentication Error!</h1></body></html>" self.send_response(400) self.send_header("Content-type", "text/html") self.end_headers() _ = self.wfile.write(html.encode("utf-8")) def _send_404_response(self) -> None: """Send a 404 response.""" self.send_response(404) self.send_header("Content-type", "text/plain") self.end_headers() _ = self.wfile.write(b"Not Found")
[docs] @override def log_message(self, format: str, *args: object) -> None: """Suppress log messages.""" pass
[docs] class CallbackServer: """Local HTTP server for handling OAuth callbacks.""" ADDRESS: Final[str] = "localhost"
[docs] def __init__(self, port: int = 8080) -> None: """Initialize the callback server. Args: port: The port to listen on. """ self.port: int = port self.server: Optional[HTTPServer] = None self.thread: Optional[threading.Thread] = None
[docs] def get_callback_url(self) -> str: """Get the callback URL for the OAuth flow. Returns: The callback URL. """ return f"http://{self.ADDRESS}:{self.port}/callback"
[docs] def start(self) -> None: """Start the callback server. Raises: OAuthFlowError: If the server fails to start. """ # Reset class variables CallbackHandler.authorization_code = None CallbackHandler.state = None CallbackHandler.error = None CallbackHandler.error_description = None CallbackHandler.callback_received = threading.Event() try: self.server = HTTPServer((self.ADDRESS, self.port), CallbackHandler) self.server.timeout = 1.0 self.thread = threading.Thread(target=self._run_server, daemon=True) self.thread.start() except Exception as e: raise OAuthFlowError(f"Failed to start callback server: {e}") from e
def _run_server(self) -> None: """Run the server (called in a separate thread).""" if self.server: with contextlib.suppress(Exception): self.server.serve_forever()
[docs] def wait_for_callback(self, timeout: float = 120.0) -> CallbackResult: """Wait for the OAuth callback. Args: timeout: Maximum time to wait in seconds. Returns: CallbackResult with the authorization code and state. Raises: OAuthFlowError: If timeout or error occurs. """ callback_received = False if CallbackHandler.callback_received: callback_received = CallbackHandler.callback_received.wait(timeout=timeout) # Shutdown the server if self.server: with contextlib.suppress(Exception): self.server.shutdown() if not callback_received: raise OAuthFlowError("Timed out waiting for OAuth callback") if CallbackHandler.error: raise OAuthFlowError(f"OAuth callback error: {CallbackHandler.error_description}") if CallbackHandler.authorization_code: return CallbackResult( code=CallbackHandler.authorization_code, state=CallbackHandler.state, ) raise OAuthFlowError("OAuth callback did not contain an authorization code")
[docs] def stop(self) -> None: """Stop the callback server.""" if self.server: self.server.shutdown() self.server = None self.thread = None