diff --git a/.changeset/sse-auth-events.md b/.changeset/sse-auth-events.md new file mode 100644 index 0000000..49a1f60 --- /dev/null +++ b/.changeset/sse-auth-events.md @@ -0,0 +1,5 @@ +--- +"@onkernel/managed-auth-react": minor +--- + +Subscribe to managed auth state via the `/auth/connections/{id}/events` SSE endpoint instead of polling `/auth/connections/{id}` every 2s. Removes the post-submit race where the UI could briefly snap back to `awaiting_input` after submission. diff --git a/packages/managed-auth-react/src/lib/api.ts b/packages/managed-auth-react/src/lib/api.ts index 3f12960..fba39b6 100644 --- a/packages/managed-auth-react/src/lib/api.ts +++ b/packages/managed-auth-react/src/lib/api.ts @@ -1,4 +1,10 @@ -import type { ManagedAuthResponse, MFAType } from "./types"; +import type { + ManagedAuthResponse, + ManagedAuthStateEventData, + MFAType, +} from "./types"; + +export type { ManagedAuthStateEventData }; export interface ApiClientOptions { baseUrl?: string; @@ -10,11 +16,13 @@ const DEFAULT_BASE_URL = "https://api.onkernel.com"; export class ManagedAuthApiError extends Error { public readonly status: number; public readonly body: string; - constructor(message: string, status: number, body: string) { + public readonly fatal: boolean; + constructor(message: string, status: number, body: string, fatal = false) { super(message); this.name = "ManagedAuthApiError"; this.status = status; this.body = body; + this.fatal = fatal; } } @@ -156,3 +164,116 @@ export function submitSignInOption( options, ); } + +export interface ManagedAuthStreamHandlers { + onState: (event: ManagedAuthStateEventData) => void; + onError: (err: ManagedAuthApiError) => void; + onClose: () => void; +} + +interface ParsedSSEMessage { + event?: string; + data: string; +} + +function parseSSEMessage(raw: string): ParsedSSEMessage | null { + if (!raw.trim()) return null; + let event: string | undefined; + const dataLines: string[] = []; + for (const line of raw.split(/\r\n|\r|\n/)) { + if (!line || line.startsWith(":")) continue; + const colonIdx = line.indexOf(":"); + const field = colonIdx === -1 ? line : line.slice(0, colonIdx); + let value = colonIdx === -1 ? "" : line.slice(colonIdx + 1); + if (value.startsWith(" ")) value = value.slice(1); + if (field === "event") event = value; + else if (field === "data") dataLines.push(value); + } + if (dataLines.length === 0) return null; + return { event, data: dataLines.join("\n") }; +} + +export function streamManagedAuthEvents( + id: string, + jwt: string, + handlers: ManagedAuthStreamHandlers, + options?: ApiClientOptions, +): () => void { + const ac = new AbortController(); + void (async () => { + try { + const f = getFetch(options); + const res = await f( + `${getBaseUrl(options)}/auth/connections/${id}/events`, + { + method: "GET", + headers: { + Authorization: `Bearer ${jwt}`, + Accept: "text/event-stream", + }, + signal: ac.signal, + }, + ); + if (!res.ok) { + const msg = await parseError(res); + handlers.onError(new ManagedAuthApiError(msg, res.status, msg)); + return; + } + if (!res.body) { + handlers.onError( + new ManagedAuthApiError("SSE response has no body", 500, ""), + ); + return; + } + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let buf = ""; + // SSE message separator: blank line. Per spec, line endings can be + // \n, \r\n, or \r — so the separator can be \n\n, \r\n\r\n, or \r\r. + const SEPARATOR_RE = /\r\n\r\n|\r\r|\n\n/; + for (;;) { + const { value, done } = await reader.read(); + if (done) break; + buf += decoder.decode(value, { stream: true }); + for (;;) { + const match = SEPARATOR_RE.exec(buf); + if (!match) break; + const raw = buf.slice(0, match.index); + buf = buf.slice(match.index + match[0].length); + const msg = parseSSEMessage(raw); + if (!msg) continue; + if (msg.event === "managed_auth_state") { + try { + handlers.onState( + JSON.parse(msg.data) as ManagedAuthStateEventData, + ); + } catch { + /* ignore malformed payload */ + } + } else if (msg.event === "error") { + let message = "Stream error"; + try { + const data = JSON.parse(msg.data) as { + error?: { code?: string; message?: string }; + }; + if (data.error?.message) message = data.error.message; + } catch { + /* fall through with default message */ + } + handlers.onError( + new ManagedAuthApiError(message, 500, message, true), + ); + ac.abort(); + return; + } + } + } + handlers.onClose(); + } catch (err) { + if ((err as { name?: string })?.name === "AbortError") return; + const message = err instanceof Error ? err.message : "Stream failed"; + handlers.onError(new ManagedAuthApiError(message, 0, message)); + } + })(); + return () => ac.abort(); +} diff --git a/packages/managed-auth-react/src/lib/types.ts b/packages/managed-auth-react/src/lib/types.ts index fbce640..3caf3ef 100644 --- a/packages/managed-auth-react/src/lib/types.ts +++ b/packages/managed-auth-react/src/lib/types.ts @@ -73,6 +73,26 @@ export interface ManagedAuthResponse { error_code?: string | null; } +// Mirrors @onkernel/sdk's ConnectionFollowResponse.ManagedAuthStateEvent. +export interface ManagedAuthStateEventData { + event: "managed_auth_state"; + timestamp: string; + flow_status: FlowStatus; + flow_step: FlowStep; + flow_type?: "LOGIN" | "REAUTH"; + discovered_fields?: DiscoveredField[]; + mfa_options?: MFAOption[]; + sign_in_options?: SignInOption[]; + pending_sso_buttons?: SSOButton[]; + external_action_message?: string; + website_error?: string; + error_message?: string; + error_code?: string; + post_login_url?: string; + live_view_url?: string; + hosted_url?: string; +} + export type UIState = | "prime" | "discovering" diff --git a/packages/managed-auth-react/src/session/useManagedAuthSession.ts b/packages/managed-auth-react/src/session/useManagedAuthSession.ts index 515db26..6cceb57 100644 --- a/packages/managed-auth-react/src/session/useManagedAuthSession.ts +++ b/packages/managed-auth-react/src/session/useManagedAuthSession.ts @@ -1,13 +1,14 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { exchangeHandoffCode, - ManagedAuthApiError, retrieveManagedAuth, + streamManagedAuthEvents, submitFieldValues, submitMFASelection, submitSignInOption, submitSSOButton, type ApiClientOptions, + type ManagedAuthStateEventData, } from "../lib/api"; import type { AuthErrorPayload, @@ -18,8 +19,8 @@ import type { UIState, } from "../lib/types"; -const POLL_INTERVAL_MS = 2000; -const POST_SUBMIT_DELAY_MS = 2000; +const RECONNECT_BASE_MS = 1000; +const RECONNECT_MAX_MS = 15000; function deriveUIState(state: ManagedAuthResponse): UIState { if (state.flow_status === "FAILED" || state.flow_status === "CANCELED") { @@ -42,6 +43,29 @@ function deriveUIState(state: ManagedAuthResponse): UIState { } } +function isTerminal(uiState: UIState): boolean { + return uiState === "success" || uiState === "expired" || uiState === "error"; +} + +function mergeStateEvent( + base: ManagedAuthResponse, + ev: ManagedAuthStateEventData, +): ManagedAuthResponse { + return { + ...base, + flow_status: ev.flow_status, + flow_step: ev.flow_step, + discovered_fields: ev.discovered_fields ?? null, + pending_sso_buttons: ev.pending_sso_buttons ?? null, + mfa_options: ev.mfa_options ?? null, + sign_in_options: ev.sign_in_options ?? null, + external_action_message: ev.external_action_message ?? null, + website_error: ev.website_error ?? null, + error_message: ev.error_message ?? null, + error_code: ev.error_code ?? null, + }; +} + export interface ManagedAuthSessionOptions extends ApiClientOptions { sessionId: string; handoffCode: string; @@ -66,7 +90,7 @@ export interface ManagedAuthSessionValue { /** * Internal hook that owns the full state machine for a managed auth session — - * handoff code exchange, polling, submissions, UI-state derivation. + * handoff code exchange, SSE subscription, submissions, UI-state derivation. */ export function useManagedAuthSession( options: ManagedAuthSessionOptions, @@ -80,94 +104,175 @@ export function useManagedAuthSession( const [submitError, setSubmitError] = useState(null); const [initError, setInitError] = useState(null); - const pollRef = useRef | null>(null); - const pollDelayRef = useRef | null>(null); + const stateRef = useRef(null); + const disconnectRef = useRef<(() => void) | null>(null); + const reconnectTimerRef = useRef | null>(null); + const reconnectAttemptsRef = useRef(0); + const terminalRef = useRef(false); const callbackFiredRef = useRef<{ success: boolean; error: boolean }>({ success: false, error: false, }); - const stopPolling = useCallback(() => { - if (pollDelayRef.current) { - clearTimeout(pollDelayRef.current); - pollDelayRef.current = null; + const fireSuccessOnce = useCallback( + (payload: AuthSuccessPayload) => { + if (callbackFiredRef.current.success) return; + callbackFiredRef.current.success = true; + onSuccess?.(payload); + }, + [onSuccess], + ); + + const fireErrorOnce = useCallback( + (payload: AuthErrorPayload) => { + if (callbackFiredRef.current.error) return; + callbackFiredRef.current.error = true; + onError?.(payload); + }, + [onError], + ); + + const disconnectStream = useCallback(() => { + if (reconnectTimerRef.current) { + clearTimeout(reconnectTimerRef.current); + reconnectTimerRef.current = null; } - if (pollRef.current) { - clearInterval(pollRef.current); - pollRef.current = null; + if (disconnectRef.current) { + disconnectRef.current(); + disconnectRef.current = null; } }, []); - const pollOnce = useCallback( - async (tokenOverride?: string) => { - const token = tokenOverride ?? jwt; - if (!token) return; - try { - const newState = await retrieveManagedAuth(sessionId, token, options); - setState(newState); - setSubmitError(null); + const connectStream = useCallback( + (token: string) => { + if (terminalRef.current) return; + if (disconnectRef.current) return; - const nextUI = deriveUIState(newState); + const handleStateEvent = (ev: ManagedAuthStateEventData) => { + reconnectAttemptsRef.current = 0; + setSubmitError(null); + const base = stateRef.current; + if (!base) return; + const merged = mergeStateEvent(base, ev); + stateRef.current = merged; + setState(merged); + const nextUI = deriveUIState(merged); setUIState(nextUI); - if (nextUI === "success") { - if (!callbackFiredRef.current.success) { - callbackFiredRef.current.success = true; - onSuccess?.({ - profileName: newState.profile_name, - domain: newState.domain, - }); - } - stopPolling(); + terminalRef.current = true; + fireSuccessOnce({ + profileName: merged.profile_name, + domain: merged.domain, + }); + disconnectStream(); } else if (nextUI === "error" || nextUI === "expired") { - if (!callbackFiredRef.current.error) { - callbackFiredRef.current.error = true; - onError?.({ - code: newState.error_code ?? undefined, - message: - newState.error_message || - newState.website_error || - (nextUI === "expired" ? "Session expired" : "Login failed"), - }); - } - stopPolling(); + terminalRef.current = true; + fireErrorOnce({ + code: merged.error_code ?? undefined, + message: + merged.error_message || + merged.website_error || + (nextUI === "expired" ? "Session expired" : "Login failed"), + }); + disconnectStream(); } - } catch (err) { - const apiErr = err as ManagedAuthApiError; - if (apiErr?.status === 401 || apiErr?.status === 410) { - stopPolling(); - setUIState("expired"); - if (!callbackFiredRef.current.error) { - callbackFiredRef.current.error = true; - onError?.({ message: "Session expired" }); + }; + + const scheduleReconnect = () => { + if (terminalRef.current) return; + const attempt = reconnectAttemptsRef.current++; + const delay = Math.min( + RECONNECT_BASE_MS * Math.pow(2, attempt), + RECONNECT_MAX_MS, + ); + reconnectTimerRef.current = setTimeout(() => { + reconnectTimerRef.current = null; + void resyncAndConnect(token); + }, delay); + }; + + // SSE only emits future deltas. After a drop, resync via GET so we don't + // miss state changes that happened during the disconnect window before + // resubscribing to the stream. + const resyncAndConnect = async (t: string) => { + if (terminalRef.current) return; + try { + const fresh = await retrieveManagedAuth(sessionId, t, options); + if (terminalRef.current) return; + stateRef.current = fresh; + setState(fresh); + const derived = deriveUIState(fresh); + setUIState(derived); + if (isTerminal(derived)) { + terminalRef.current = true; + if (derived === "success") { + fireSuccessOnce({ + profileName: fresh.profile_name, + domain: fresh.domain, + }); + } else { + fireErrorOnce({ + code: fresh.error_code ?? undefined, + message: + fresh.error_message || + fresh.website_error || + (derived === "expired" ? "Session expired" : "Login failed"), + }); + } + return; + } + connectStream(t); + } catch (err) { + const status = (err as { status?: number })?.status; + if (status === 401 || status === 410) { + terminalRef.current = true; + setUIState("expired"); + fireErrorOnce({ message: "Session expired" }); + return; } + scheduleReconnect(); } - } - }, - [jwt, onError, onSuccess, options, sessionId, stopPolling], - ); - - const startPolling = useCallback( - (immediate = true, delayMs = 0, tokenOverride?: string) => { - if (pollRef.current) return; - const begin = () => { - if (pollRef.current) return; - pollRef.current = setInterval(() => { - void pollOnce(tokenOverride); - }, POLL_INTERVAL_MS); - if (immediate) void pollOnce(tokenOverride); }; - if (delayMs > 0) { - pollDelayRef.current = setTimeout(begin, delayMs); - } else { - begin(); - } + + disconnectRef.current = streamManagedAuthEvents( + sessionId, + token, + { + onState: handleStateEvent, + onError: (err) => { + disconnectRef.current = null; + if (err.status === 401 || err.status === 410) { + terminalRef.current = true; + setUIState("expired"); + fireErrorOnce({ message: "Session expired" }); + return; + } + if (err.fatal) { + terminalRef.current = true; + setUIState("error"); + fireErrorOnce({ message: err.message }); + return; + } + scheduleReconnect(); + }, + onClose: () => { + disconnectRef.current = null; + if (terminalRef.current) return; + scheduleReconnect(); + }, + }, + options, + ); }, - [pollOnce], + [disconnectStream, fireErrorOnce, fireSuccessOnce, options, sessionId], ); useEffect(() => { let cancelled = false; + terminalRef.current = false; + reconnectAttemptsRef.current = 0; + callbackFiredRef.current = { success: false, error: false }; + (async () => { try { const token = await exchangeHandoffCode( @@ -179,26 +284,19 @@ export function useManagedAuthSession( setJwt(token); const initial = await retrieveManagedAuth(sessionId, token, options); if (cancelled) return; + stateRef.current = initial; setState(initial); const derived = deriveUIState(initial); - if ( - derived === "success" || - derived === "expired" || - derived === "error" - ) { + if (isTerminal(derived)) { + terminalRef.current = true; setUIState(derived); - if (derived === "success" && !callbackFiredRef.current.success) { - callbackFiredRef.current.success = true; - onSuccess?.({ + if (derived === "success") { + fireSuccessOnce({ profileName: initial.profile_name, domain: initial.domain, }); - } else if ( - (derived === "error" || derived === "expired") && - !callbackFiredRef.current.error - ) { - callbackFiredRef.current.error = true; - onError?.({ + } else { + fireErrorOnce({ code: initial.error_code ?? undefined, message: initial.error_message || @@ -208,7 +306,7 @@ export function useManagedAuthSession( } } else if (autoStart) { setUIState("discovering"); - startPolling(true, 0, token); + connectStream(token); } else { setUIState("prime"); } @@ -218,15 +316,13 @@ export function useManagedAuthSession( err instanceof Error ? err.message : "Failed to start session"; setInitError(message); setUIState("error"); - if (!callbackFiredRef.current.error) { - callbackFiredRef.current.error = true; - onError?.({ message }); - } + terminalRef.current = true; + fireErrorOnce({ message }); } })(); return () => { cancelled = true; - stopPolling(); + disconnectStream(); }; // eslint-disable-next-line react-hooks/exhaustive-deps }, [sessionId, handoffCode]); @@ -234,8 +330,8 @@ export function useManagedAuthSession( const startFlow = useCallback(() => { if (!jwt) return; setUIState("discovering"); - startPolling(true, 0); - }, [jwt, startPolling]); + connectStream(jwt); + }, [jwt, connectStream]); const submit = useCallback( async (fn: () => Promise, onFail: string) => { @@ -243,24 +339,19 @@ export function useManagedAuthSession( setIsSubmitting(true); setSubmitError(null); setUIState("submitting"); - stopPolling(); try { await fn(); - startPolling(false, POST_SUBMIT_DELAY_MS); } catch (err) { const msg = err instanceof Error ? err.message : onFail; setSubmitError(msg); setUIState((current) => - current === "success" || current === "expired" || current === "error" - ? current - : "awaiting_input", + isTerminal(current) ? current : "awaiting_input", ); - startPolling(); } finally { setIsSubmitting(false); } }, - [jwt, startPolling, stopPolling], + [jwt], ); const submitFields = useCallback(