diff --git a/packages/runtime/src/handlers/websocket.ts b/packages/runtime/src/handlers/websocket.ts index 014dfc44b..d4b556bb7 100644 --- a/packages/runtime/src/handlers/websocket.ts +++ b/packages/runtime/src/handlers/websocket.ts @@ -2,6 +2,7 @@ import type { Context, MiddlewareHandler } from 'hono'; import { upgradeWebSocket } from 'hono/bun'; import { context as otelContext, ROOT_CONTEXT } from '@opentelemetry/api'; import type { Env } from '../app'; +import { getAgentAsyncLocalStorage, getHTTPAsyncLocalStorage } from '../_context'; import { tagRoute } from './_route-meta'; /** @@ -100,6 +101,10 @@ export function websocket( let messageHandler: ((event: MessageEvent) => void | Promise) | undefined; let closeHandler: ((event: CloseEvent) => void | Promise) | undefined; let initialized = false; + const agentAsyncLocalStorage = getAgentAsyncLocalStorage(); + const httpAsyncLocalStorage = getHTTPAsyncLocalStorage(); + const capturedAgentContext = agentAsyncLocalStorage.getStore(); + const capturedHTTPContext = httpAsyncLocalStorage.getStore(); // Create done promise for session lifecycle deferral, but ONLY for actual // WebSocket upgrade requests. The factory runs unconditionally for every @@ -137,12 +142,29 @@ export function websocket( }, }; + const runWithCapturedContext = (callback: () => T): T => { + return otelContext.with(ROOT_CONTEXT, () => { + const runWithAgentContext = () => { + if (capturedAgentContext) { + return agentAsyncLocalStorage.run(capturedAgentContext, callback); + } + return callback(); + }; + + if (capturedHTTPContext) { + return httpAsyncLocalStorage.run(capturedHTTPContext, runWithAgentContext); + } + + return runWithAgentContext(); + }); + }; + // IMPORTANT: We run in ROOT_CONTEXT (no active OTEL span) to avoid a Bun bug // where OTEL-instrumented fetch conflicts with streaming responses. // See: https://github.com/agentuity/sdk/issues/471 // See: https://github.com/oven-sh/bun/issues/24766 const runHandler = () => { - otelContext.with(ROOT_CONTEXT, () => { + runWithCapturedContext(() => { handler(c, wsConnection); }); initialized = true; @@ -157,7 +179,7 @@ export function websocket( wsConnection.send = (data) => ws.send(data); if (openHandler) { - await otelContext.with(ROOT_CONTEXT, () => openHandler!(event)); + await runWithCapturedContext(() => openHandler!(event)); } } catch (err) { c.var.logger?.error('WebSocket onOpen error:', err); @@ -172,7 +194,7 @@ export function websocket( runHandler(); } if (messageHandler) { - await otelContext.with(ROOT_CONTEXT, () => messageHandler!(event)); + await runWithCapturedContext(() => messageHandler!(event)); } } catch (err) { c.var.logger?.error('WebSocket onMessage error:', err); @@ -183,7 +205,7 @@ export function websocket( onClose: async (event: CloseEvent, _ws: any) => { try { if (closeHandler) { - await otelContext.with(ROOT_CONTEXT, () => closeHandler!(event)); + await runWithCapturedContext(() => closeHandler!(event)); } } catch (err) { c.var.logger?.error('WebSocket onClose error:', err); diff --git a/packages/runtime/test/websocket-agent-context.test.ts b/packages/runtime/test/websocket-agent-context.test.ts new file mode 100644 index 000000000..d03e9c936 --- /dev/null +++ b/packages/runtime/test/websocket-agent-context.test.ts @@ -0,0 +1,277 @@ +import { describe, expect, test } from 'bun:test'; +import type { AuthInterface } from '@agentuity/auth/types'; +import { trace } from '@opentelemetry/api'; +import { Hono } from 'hono'; +import { websocket as bunWebsocket } from 'hono/bun'; +import { runInHTTPContext } from '../src/_context'; +import { createAgent, createAgentMiddleware } from '../src/agent'; +import { websocket } from '../src/handlers/websocket'; +import type { Logger } from '../src/logger'; +import type { Session, Thread } from '../src/session'; +import WaitUntilHandler from '../src/_waituntil'; +import { z } from 'zod'; + +interface EchoReadyMessage { + type: 'ready'; +} + +interface EchoSuccessMessage { + type: 'echo'; + routeSessionId: string; + routeThreadId: string; + data: { + echo: string; + sessionId: string; + threadId: string; + userId: string | null; + }; +} + +interface EchoErrorMessage { + type: 'error'; + message: string; +} + +type SocketMessage = EchoReadyMessage | EchoSuccessMessage | EchoErrorMessage; + +function createMockAuth(userId: string): AuthInterface { + return { + user: { id: userId, email: `${userId}@example.com`, name: 'Test User' }, + session: { id: 'session-123', userId }, + authMethod: 'session', + raw: {}, + getUser: async () => ({ id: userId, email: `${userId}@example.com`, name: 'Test User' }), + getToken: async () => null, + getOrg: async () => null, + getOrgRole: async () => null, + hasOrgRole: async () => false, + apiKey: null, + hasPermission: () => false, + }; +} + +function createMockLogger(): Logger { + const noop = () => {}; + return { + trace: noop, + debug: noop, + info: noop, + warn: noop, + error: noop, + fatal: noop as Logger['fatal'], + child: () => createMockLogger(), + }; +} + +function createMockThread(id: string): Thread { + const thread: Thread = { + id, + state: new Map(), + addEventListener: () => {}, + removeEventListener: () => {}, + destroy: async () => {}, + empty: () => thread.state.size === 0, + }; + return thread; +} + +function createMockSession(thread: Thread, id: string): Session { + return { + id, + thread, + state: new Map(), + addEventListener: () => {}, + removeEventListener: () => {}, + serializeUserData: () => undefined, + }; +} + +function waitForOpen(socket: WebSocket): Promise { + return new Promise((resolve, reject) => { + const cleanup = () => { + socket.removeEventListener('open', onOpen); + socket.removeEventListener('error', onError); + socket.removeEventListener('close', onClose); + }; + + const onOpen = () => { + cleanup(); + resolve(); + }; + const onError = () => { + cleanup(); + reject(new Error('WebSocket failed to open')); + }; + const onClose = () => { + cleanup(); + reject(new Error('WebSocket closed before opening')); + }; + + socket.addEventListener('open', onOpen, { once: true }); + socket.addEventListener('error', onError, { once: true }); + socket.addEventListener('close', onClose, { once: true }); + }); +} + +function waitForJsonMessage(socket: WebSocket): Promise { + return new Promise((resolve, reject) => { + const cleanup = () => { + socket.removeEventListener('message', onMessage); + socket.removeEventListener('error', onError); + socket.removeEventListener('close', onClose); + }; + + const onMessage = (event: MessageEvent) => { + cleanup(); + try { + resolve(JSON.parse(String(event.data)) as SocketMessage); + } catch (error) { + reject(error); + } + }; + const onError = () => { + cleanup(); + reject(new Error('WebSocket errored while waiting for a message')); + }; + const onClose = () => { + cleanup(); + reject(new Error('WebSocket closed before a message was received')); + }; + + socket.addEventListener('message', onMessage, { once: true }); + socket.addEventListener('error', onError, { once: true }); + socket.addEventListener('close', onClose, { once: true }); + }); +} + +async function closeSocket(socket: WebSocket): Promise { + if (socket.readyState === WebSocket.CLOSED) { + return; + } + + await new Promise((resolve) => { + socket.addEventListener('close', () => resolve(), { once: true }); + socket.close(); + }); +} + +describe('WebSocket agent context propagation', () => { + test('agent.run inside ws.onMessage preserves session, thread, and routed auth context', async () => { + const tracer = trace.getTracer('websocket-agent-context-test'); + + const echoAgent = createAgent('websocket-agent-context-propagation-test', { + schema: { + input: z.string(), + output: z.object({ + echo: z.string(), + sessionId: z.string(), + threadId: z.string(), + userId: z.string().nullable(), + }), + }, + handler: async (ctx, input) => { + return { + echo: input, + sessionId: ctx.sessionId, + threadId: ctx.thread.id, + userId: ctx.auth?.user?.id ?? null, + }; + }, + }); + + const app = new Hono(); + + app.use('*', async (c, next) => { + await runInHTTPContext(c, next); + }); + + app.use('/api/*', async (c, next) => { + const logger = createMockLogger(); + const thread = createMockThread(`thrd_${crypto.randomUUID()}`); + const session = createMockSession(thread, `sess_${crypto.randomUUID()}`); + const waitUntilHandler = new WaitUntilHandler(tracer); + + c.set('logger', logger); + c.set('tracer', tracer); + c.set('sessionId', session.id); + c.set('thread', thread); + c.set('session', session); + c.set('waitUntilHandler', waitUntilHandler); + c.set('agentIds', new Set()); + c.set('trigger', 'websocket'); + c.set('app', {}); + await next(); + }); + + app.use('/api/*', createAgentMiddleware('')); + + // This middleware runs after createAgentMiddleware('') and should still + // be visible to the agent via the lazy ctx.auth getter. + app.use('/api/*', async (c, next) => { + c.set('auth', createMockAuth('late-bound-user')); + await next(); + }); + + app.get( + '/api/echo', + websocket((c, ws) => { + ws.onOpen(() => { + ws.send(JSON.stringify({ type: 'ready' })); + }); + + ws.onMessage(async (event) => { + try { + const result = await echoAgent.run(String(event.data)); + ws.send( + JSON.stringify({ + type: 'echo', + routeSessionId: c.var.sessionId, + routeThreadId: c.var.thread.id, + data: result, + }) + ); + } catch (error) { + ws.send( + JSON.stringify({ + type: 'error', + message: error instanceof Error ? error.message : String(error), + }) + ); + } + }); + }) + ); + + const server = Bun.serve({ + port: 0, + fetch: (request, server) => app.fetch(request, server), + websocket: bunWebsocket, + }); + + const socket = new WebSocket(`ws://127.0.0.1:${server.port}/api/echo`); + + try { + await waitForOpen(socket); + + const ready = await waitForJsonMessage(socket); + expect(ready).toEqual({ type: 'ready' }); + + socket.send('hello from websocket test'); + + const response = await waitForJsonMessage(socket); + if (response.type !== 'echo') { + throw new Error(`Expected echo response, received ${JSON.stringify(response)}`); + } + + expect(response.data.echo).toBe('hello from websocket test'); + expect(response.data.sessionId).toBe(response.routeSessionId); + expect(response.data.threadId).toBe(response.routeThreadId); + expect(response.data.userId).toBe('late-bound-user'); + expect(response.data.sessionId.length).toBeGreaterThan(0); + expect(response.data.threadId.length).toBeGreaterThan(0); + } finally { + await closeSocket(socket); + server.stop(true); + } + }); +});