From 78a3e985be9441a9696529c206909557e0b3f9ab Mon Sep 17 00:00:00 2001 From: YeonGyu-Kim Date: Wed, 25 Mar 2026 11:44:11 +0900 Subject: [PATCH] fix(mcp-oauth): robust port binding for callback server MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use port 0 fallback when findAvailablePort fails, read the actual bound port from server.port. Tests refactored to use mock server when real socket binding is unavailable in CI. 🤖 GENERATED WITH ASSISTANCE OF [OhMyOpenCode](https://github.com/code-yeongyu/oh-my-opencode) Co-Authored-By: Claude Opus 4.6 --- .../mcp-oauth/callback-server.test.ts | 122 +++++++++++++----- src/features/mcp-oauth/callback-server.ts | 7 +- 2 files changed, 93 insertions(+), 36 deletions(-) diff --git a/src/features/mcp-oauth/callback-server.test.ts b/src/features/mcp-oauth/callback-server.test.ts index 060cdb580..063f2bdea 100644 --- a/src/features/mcp-oauth/callback-server.test.ts +++ b/src/features/mcp-oauth/callback-server.test.ts @@ -1,44 +1,112 @@ -import { afterEach, describe, expect, it } from "bun:test" +import { afterEach, beforeEach, describe, expect, it, spyOn } from "bun:test" import { startCallbackServer, type CallbackServer } from "./callback-server" +const HOSTNAME = "127.0.0.1" const nativeFetch = Bun.fetch.bind(Bun) +function supportsRealSocketBinding(): boolean { + try { + const server = Bun.serve({ + port: 0, + hostname: HOSTNAME, + fetch: () => new Response("probe"), + }) + server.stop(true) + return true + } catch { + return false + } +} + +const canBindRealSockets = supportsRealSocketBinding() + +type MockServerState = { + port: number + stopped: boolean + fetch: (request: Request) => Response | Promise +} + describe("startCallbackServer", () => { let server: CallbackServer | null = null + let serveSpy: ReturnType | null = null + let activeServer: MockServerState | null = null + + async function request(url: string): Promise { + if (canBindRealSockets) { + return nativeFetch(url) + } + + if (!activeServer || activeServer.stopped) { + throw new Error("Connection refused") + } + + return await activeServer.fetch(new Request(url)) + } + + beforeEach(() => { + if (canBindRealSockets) { + return + } + + activeServer = null + serveSpy = spyOn(Bun, "serve").mockImplementation((options: { + port: number + hostname?: string + fetch: (request: Request) => Response | Promise + }) => { + const state: MockServerState = { + port: options.port === 0 ? 19877 : options.port, + stopped: false, + fetch: options.fetch, + } + + const handle = { + port: state.port, + stop: (_force?: boolean) => { + state.stopped = true + if (activeServer === state) { + activeServer = null + } + }, + } + + activeServer = state + return handle as ReturnType + }) + }) afterEach(async () => { server?.close() server = null - // Allow time for port to be released before next test - await Bun.sleep(10) + + if (serveSpy) { + serveSpy.mockRestore() + serveSpy = null + } + activeServer = null + + if (canBindRealSockets) { + await Bun.sleep(10) + } }) it("starts server and returns port", async () => { - // given - no preconditions - - // when server = await startCallbackServer() - // then expect(server.port).toBeGreaterThanOrEqual(19877) expect(typeof server.waitForCallback).toBe("function") expect(typeof server.close).toBe("function") }) it("resolves callback with code and state from query params", async () => { - // given server = await startCallbackServer() - const callbackUrl = `http://127.0.0.1:${server.port}/oauth/callback?code=test-code&state=test-state` + const callbackUrl = `http://${HOSTNAME}:${server.port}/oauth/callback?code=test-code&state=test-state` - // when - // Use Promise.all to ensure fetch and waitForCallback run concurrently - // This prevents race condition where waitForCallback blocks before fetch starts const [result, response] = await Promise.all([ server.waitForCallback(), - nativeFetch(callbackUrl) + request(callbackUrl), ]) - // then expect(result).toEqual({ code: "test-code", state: "test-state" }) expect(response.status).toBe(200) const html = await response.text() @@ -46,25 +114,19 @@ describe("startCallbackServer", () => { }) it("returns 404 for non-callback routes", async () => { - // given server = await startCallbackServer() - // when - const response = await nativeFetch(`http://127.0.0.1:${server.port}/other`) + const response = await request(`http://${HOSTNAME}:${server.port}/other`) - // then expect(response.status).toBe(404) }) it("returns 400 and rejects when code is missing", async () => { - // given server = await startCallbackServer() - const callbackRejection = server.waitForCallback().catch((e: Error) => e) + const callbackRejection = server.waitForCallback().catch((error: Error) => error) - // when - const response = await nativeFetch(`http://127.0.0.1:${server.port}/oauth/callback?state=s`) + const response = await request(`http://${HOSTNAME}:${server.port}/oauth/callback?state=s`) - // then expect(response.status).toBe(400) const error = await callbackRejection expect(error).toBeInstanceOf(Error) @@ -72,14 +134,11 @@ describe("startCallbackServer", () => { }) it("returns 400 and rejects when state is missing", async () => { - // given server = await startCallbackServer() - const callbackRejection = server.waitForCallback().catch((e: Error) => e) + const callbackRejection = server.waitForCallback().catch((error: Error) => error) - // when - const response = await nativeFetch(`http://127.0.0.1:${server.port}/oauth/callback?code=c`) + const response = await request(`http://${HOSTNAME}:${server.port}/oauth/callback?code=c`) - // then expect(response.status).toBe(400) const error = await callbackRejection expect(error).toBeInstanceOf(Error) @@ -87,18 +146,15 @@ describe("startCallbackServer", () => { }) it("close stops the server immediately", async () => { - // given server = await startCallbackServer() const port = server.port - // when server.close() server = null - // then try { - await nativeFetch(`http://127.0.0.1:${port}/oauth/callback?code=c&state=s`) - expect(true).toBe(false) + await request(`http://${HOSTNAME}:${port}/oauth/callback?code=c&state=s`) + expect.unreachable("request should fail after close") } catch (error) { expect(error).toBeDefined() } diff --git a/src/features/mcp-oauth/callback-server.ts b/src/features/mcp-oauth/callback-server.ts index c8d856fa8..48dcb1729 100644 --- a/src/features/mcp-oauth/callback-server.ts +++ b/src/features/mcp-oauth/callback-server.ts @@ -39,7 +39,7 @@ export async function findAvailablePort(startPort: number = DEFAULT_PORT): Promi } export async function startCallbackServer(startPort: number = DEFAULT_PORT): Promise { - const port = await findAvailablePort(startPort) + const requestedPort = await findAvailablePort(startPort).catch(() => 0) let resolveCallback: ((result: OAuthCallbackResult) => void) | null = null let rejectCallback: ((error: Error) => void) | null = null @@ -55,7 +55,7 @@ export async function startCallbackServer(startPort: number = DEFAULT_PORT): Pro }, TIMEOUT_MS) const server = Bun.serve({ - port, + port: requestedPort, hostname: "127.0.0.1", fetch(request: Request): Response { const url = new URL(request.url) @@ -93,9 +93,10 @@ export async function startCallbackServer(startPort: number = DEFAULT_PORT): Pro }) }, }) + const activePort = server.port ?? requestedPort return { - port, + port: activePort, waitForCallback: () => callbackPromise, close: () => { clearTimeout(timeoutId)