Source code for kelvin.sdk.services.auth.callback_server
"""Local HTTP server for handling OAuth callbacks."""
from __future__ import annotations
import contextlib
import threading
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Final, Optional
from urllib.parse import parse_qs, urlparse
from typing_extensions import override
from kelvin.sdk.services.auth.errors import OAuthFlowError
[docs]
@dataclass
class CallbackResult:
    """Values extracted from the OAuth redirect to the local callback endpoint.

    Attributes:
        code: The authorization code, or ``None`` if none was provided.
        state: The opaque ``state`` value echoed back by the authorization
            server, or ``None`` if the redirect did not carry one.
    """

    # Both fields default to None, so CallbackResult() builds an empty result.
    code: Optional[str] = None
    state: Optional[str] = None
[docs]
class CallbackHandler(BaseHTTPRequestHandler):
    """HTTP request handler for OAuth callbacks.

    ``BaseHTTPRequestHandler`` creates a new handler instance for every
    request, so the outcome of the callback is published through the class
    attributes below rather than through instance state.
    """

    # Outcome of the most recent /callback request (shared by all instances).
    authorization_code: Optional[str] = None
    state: Optional[str] = None
    error: Optional[str] = None
    error_description: Optional[str] = None
    # Signalled by do_GET once a /callback request has been processed.
    callback_received: Optional[threading.Event] = None

    # Per-connection socket timeout in seconds, applied by
    # StreamRequestHandler.setup(). HTTPServer serves one connection at a time,
    # so without a timeout a client that connects but never sends a request
    # (browsers may pre-open such sockets) would block every later request --
    # including the real callback -- and make server shutdown hang.
    timeout: Optional[float] = 5.0

    def do_GET(self) -> None:
        """Handle GET requests to the callback endpoint.

        * ``/callback`` with a ``code`` parameter: store ``code`` and ``state``
          and answer 200 with the success page.
        * ``/callback`` without ``code``: store ``error`` and
          ``error_description`` (defaulting to ``"unknown"`` and
          ``"No description"``) and answer 400 with the error page.
        * Any other path: answer 404 and leave the event untouched.

        For ``/callback`` the ``callback_received`` event is set even if
        writing the response to the browser fails, because the outcome has
        already been recorded by then.
        """
        parsed_url = urlparse(self.path)
        if parsed_url.path != "/callback":
            self._send_404_response()
            return

        query_params = parse_qs(parsed_url.query)
        try:
            if "code" in query_params:
                CallbackHandler.authorization_code = self._first_param(query_params, "code")
                CallbackHandler.state = self._first_param(query_params, "state")
                self._send_success_response()
            else:
                CallbackHandler.error = self._first_param(query_params, "error", "unknown")
                CallbackHandler.error_description = self._first_param(
                    query_params, "error_description", "No description"
                )
                self._send_error_response()
        finally:
            # Wake the waiter even if the browser write raised (e.g. the client
            # already disconnected); otherwise it would sit until its timeout.
            if CallbackHandler.callback_received is not None:
                CallbackHandler.callback_received.set()

    @staticmethod
    def _first_param(
        query_params: dict[str, list[str]],
        name: str,
        default: Optional[str] = None,
    ) -> Optional[str]:
        """Return the first value of query parameter *name*, else *default*."""
        values = query_params.get(name)
        return values[0] if values else default

    @override
    def finish(self) -> None:
        """Flush and close the request streams, suppressing any exception.

        Errors here typically mean the client has already gone away; they are
        ignored so they do not surface as request-handling failures.
        """
        with contextlib.suppress(Exception):
            super().finish()

    def _send_success_response(self) -> None:
        """Send the success page (``html/success.html``) with status 200."""
        self._send_html(
            200,
            "success.html",
            "<html><body><h1>Authentication Successful!</h1>"
            "<p>You can close this window.</p></body></html>",
        )

    def _send_error_response(self) -> None:
        """Send the error page (``html/error.html``) with status 400."""
        self._send_html(
            400,
            "error.html",
            "<html><body><h1>Authentication Error!</h1></body></html>",
        )

    def _send_404_response(self) -> None:
        """Send a plain-text 404 response."""
        self._send_body(404, "text/plain; charset=utf-8", b"Not Found")

    def _send_html(self, status: int, filename: str, fallback: str) -> None:
        """Send ``html/<filename>`` from this module's directory with *status*.

        Uses the inline *fallback* markup when the file is missing, unreadable,
        or not valid UTF-8.
        """
        html_path = Path(__file__).parent / "html" / filename
        try:
            html = html_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            html = fallback
        # The body is encoded as UTF-8, so declare that charset explicitly
        # instead of leaving the browser to guess.
        self._send_body(status, "text/html; charset=utf-8", html.encode("utf-8"))

    def _send_body(self, status: int, content_type: str, body: bytes) -> None:
        """Write the status line, headers (incl. Content-Length) and *body*."""
        self.send_response(status)
        self.send_header("Content-type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        _ = self.wfile.write(body)

    @override
    def log_message(self, format: str, *args: object) -> None:
        """Discard all log lines (the base implementation writes to stderr)."""
[docs]
class CallbackServer:
    """Local HTTP server for handling OAuth callbacks.

    ``start()`` binds ``ADDRESS:port`` and serves ``CallbackHandler`` on a
    background daemon thread. ``wait_for_callback()`` blocks until the handler
    signals a ``/callback`` request and then tears the server down; ``stop()``
    tears it down early. Teardown also closes the listening socket so the
    port is released immediately.
    """

    ADDRESS: Final[str] = "localhost"

    def __init__(self, port: int = 8080) -> None:
        """Initialize the callback server.

        Args:
            port: The port to listen on.
        """
        self.port: int = port
        self.server: Optional[HTTPServer] = None
        self.thread: Optional[threading.Thread] = None

    def get_callback_url(self) -> str:
        """Get the callback URL for the OAuth flow.

        Returns:
            The callback URL, ``http://localhost:<port>/callback``.
        """
        return f"http://{self.ADDRESS}:{self.port}/callback"

    def start(self) -> None:
        """Start the callback server.

        Resets the handler's shared result attributes, binds the listening
        socket and starts serving on a daemon thread.

        Raises:
            OAuthFlowError: If the server fails to start (e.g. the port is
                already in use). A socket bound before the failure is closed.
        """
        # Clear results left over from any previous flow.
        CallbackHandler.authorization_code = None
        CallbackHandler.state = None
        CallbackHandler.error = None
        CallbackHandler.error_description = None
        CallbackHandler.callback_received = threading.Event()

        server: Optional[HTTPServer] = None
        try:
            server = HTTPServer((self.ADDRESS, self.port), CallbackHandler)
            # NOTE: BaseServer.timeout is only consulted by handle_request();
            # serve_forever() (used here) polls on its own interval.
            server.timeout = 1.0
            # Assign before starting the thread: _run_server reads self.server.
            self.server = server
            self.thread = threading.Thread(target=self._run_server, daemon=True)
            self.thread.start()
        except Exception as e:
            if server is not None:
                # The socket is already bound: release it, and do not leave a
                # never-served server behind (shutdown() on it would block).
                with contextlib.suppress(Exception):
                    server.server_close()
                self.server = None
                self.thread = None
            raise OAuthFlowError(f"Failed to start callback server: {e}") from e

    def _run_server(self) -> None:
        """Worker-thread target: run ``serve_forever`` on the current server.

        Exceptions raised by the serving loop are suppressed.
        """
        server = self.server
        if server is not None:
            with contextlib.suppress(Exception):
                server.serve_forever()

    def wait_for_callback(self, timeout: float = 120.0) -> CallbackResult:
        """Wait for the OAuth callback, then shut the server down.

        The server is torn down (and its port released) whether or not a
        callback arrived.

        Args:
            timeout: Maximum time to wait in seconds.

        Returns:
            CallbackResult with the authorization code and state.

        Raises:
            OAuthFlowError: If no callback arrives within *timeout*, if the
                callback reported an error, or if it carried no authorization
                code.
        """
        event = CallbackHandler.callback_received
        callback_received = event.wait(timeout=timeout) if event is not None else False

        # shutdown() returns only after the serving loop exits, so the handler
        # attributes read below can no longer change underneath us.
        self._shutdown_server()

        if not callback_received:
            raise OAuthFlowError("Timed out waiting for OAuth callback")
        if CallbackHandler.error:
            raise OAuthFlowError(f"OAuth callback error: {CallbackHandler.error_description}")
        if CallbackHandler.authorization_code:
            return CallbackResult(
                code=CallbackHandler.authorization_code,
                state=CallbackHandler.state,
            )
        raise OAuthFlowError("OAuth callback did not contain an authorization code")

    def stop(self) -> None:
        """Stop the callback server and release its port.

        Safe to call more than once, after ``wait_for_callback()``, or when
        the server was never started.
        """
        self._shutdown_server()

    def _shutdown_server(self) -> None:
        """Stop serving, close the listening socket and join the worker thread.

        ``shutdown()`` only ends the ``serve_forever`` loop; the socket stays
        bound until ``server_close()``. The attributes are cleared last because
        the worker thread reads ``self.server`` when it starts, and shutdown()
        waits for that loop to run and exit.
        """
        server = self.server
        if server is not None:
            with contextlib.suppress(Exception):
                server.shutdown()
            with contextlib.suppress(Exception):
                server.server_close()
        thread = self.thread
        if thread is not None and thread.is_alive():
            # serve_forever has already returned, so this join is brief.
            thread.join(timeout=5.0)
        self.server = None
        self.thread = None