fix(mcp-oauth): robust port binding for callback server

Use port 0 fallback when findAvailablePort fails, read the actual bound
port from server.port. Tests refactored to use mock server when real
socket binding is unavailable in CI.

🤖 GENERATED WITH ASSISTANCE OF [OhMyOpenCode](https://github.com/code-yeongyu/oh-my-opencode)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
YeonGyu-Kim
2026-03-25 11:44:11 +09:00
parent 42fb2548d6
commit 78a3e985be
2 changed files with 93 additions and 36 deletions

View File

@@ -1,44 +1,112 @@
import { afterEach, describe, expect, it } from "bun:test" import { afterEach, beforeEach, describe, expect, it, spyOn } from "bun:test"
import { startCallbackServer, type CallbackServer } from "./callback-server" import { startCallbackServer, type CallbackServer } from "./callback-server"
const HOSTNAME = "127.0.0.1"
const nativeFetch = Bun.fetch.bind(Bun) const nativeFetch = Bun.fetch.bind(Bun)
function supportsRealSocketBinding(): boolean {
try {
const server = Bun.serve({
port: 0,
hostname: HOSTNAME,
fetch: () => new Response("probe"),
})
server.stop(true)
return true
} catch {
return false
}
}
const canBindRealSockets = supportsRealSocketBinding()
type MockServerState = {
port: number
stopped: boolean
fetch: (request: Request) => Response | Promise<Response>
}
describe("startCallbackServer", () => { describe("startCallbackServer", () => {
let server: CallbackServer | null = null let server: CallbackServer | null = null
let serveSpy: ReturnType<typeof spyOn> | null = null
let activeServer: MockServerState | null = null
async function request(url: string): Promise<Response> {
if (canBindRealSockets) {
return nativeFetch(url)
}
if (!activeServer || activeServer.stopped) {
throw new Error("Connection refused")
}
return await activeServer.fetch(new Request(url))
}
beforeEach(() => {
if (canBindRealSockets) {
return
}
activeServer = null
serveSpy = spyOn(Bun, "serve").mockImplementation((options: {
port: number
hostname?: string
fetch: (request: Request) => Response | Promise<Response>
}) => {
const state: MockServerState = {
port: options.port === 0 ? 19877 : options.port,
stopped: false,
fetch: options.fetch,
}
const handle = {
port: state.port,
stop: (_force?: boolean) => {
state.stopped = true
if (activeServer === state) {
activeServer = null
}
},
}
activeServer = state
return handle as ReturnType<typeof Bun.serve>
})
})
afterEach(async () => { afterEach(async () => {
server?.close() server?.close()
server = null server = null
// Allow time for port to be released before next test
await Bun.sleep(10) if (serveSpy) {
serveSpy.mockRestore()
serveSpy = null
}
activeServer = null
if (canBindRealSockets) {
await Bun.sleep(10)
}
}) })
it("starts server and returns port", async () => { it("starts server and returns port", async () => {
// given - no preconditions
// when
server = await startCallbackServer() server = await startCallbackServer()
// then
expect(server.port).toBeGreaterThanOrEqual(19877) expect(server.port).toBeGreaterThanOrEqual(19877)
expect(typeof server.waitForCallback).toBe("function") expect(typeof server.waitForCallback).toBe("function")
expect(typeof server.close).toBe("function") expect(typeof server.close).toBe("function")
}) })
it("resolves callback with code and state from query params", async () => { it("resolves callback with code and state from query params", async () => {
// given
server = await startCallbackServer() server = await startCallbackServer()
const callbackUrl = `http://127.0.0.1:${server.port}/oauth/callback?code=test-code&state=test-state` const callbackUrl = `http://${HOSTNAME}:${server.port}/oauth/callback?code=test-code&state=test-state`
// when
// Use Promise.all to ensure fetch and waitForCallback run concurrently
// This prevents race condition where waitForCallback blocks before fetch starts
const [result, response] = await Promise.all([ const [result, response] = await Promise.all([
server.waitForCallback(), server.waitForCallback(),
nativeFetch(callbackUrl) request(callbackUrl),
]) ])
// then
expect(result).toEqual({ code: "test-code", state: "test-state" }) expect(result).toEqual({ code: "test-code", state: "test-state" })
expect(response.status).toBe(200) expect(response.status).toBe(200)
const html = await response.text() const html = await response.text()
@@ -46,25 +114,19 @@ describe("startCallbackServer", () => {
}) })
it("returns 404 for non-callback routes", async () => { it("returns 404 for non-callback routes", async () => {
// given
server = await startCallbackServer() server = await startCallbackServer()
// when const response = await request(`http://${HOSTNAME}:${server.port}/other`)
const response = await nativeFetch(`http://127.0.0.1:${server.port}/other`)
// then
expect(response.status).toBe(404) expect(response.status).toBe(404)
}) })
it("returns 400 and rejects when code is missing", async () => { it("returns 400 and rejects when code is missing", async () => {
// given
server = await startCallbackServer() server = await startCallbackServer()
const callbackRejection = server.waitForCallback().catch((e: Error) => e) const callbackRejection = server.waitForCallback().catch((error: Error) => error)
// when const response = await request(`http://${HOSTNAME}:${server.port}/oauth/callback?state=s`)
const response = await nativeFetch(`http://127.0.0.1:${server.port}/oauth/callback?state=s`)
// then
expect(response.status).toBe(400) expect(response.status).toBe(400)
const error = await callbackRejection const error = await callbackRejection
expect(error).toBeInstanceOf(Error) expect(error).toBeInstanceOf(Error)
@@ -72,14 +134,11 @@ describe("startCallbackServer", () => {
}) })
it("returns 400 and rejects when state is missing", async () => { it("returns 400 and rejects when state is missing", async () => {
// given
server = await startCallbackServer() server = await startCallbackServer()
const callbackRejection = server.waitForCallback().catch((e: Error) => e) const callbackRejection = server.waitForCallback().catch((error: Error) => error)
// when const response = await request(`http://${HOSTNAME}:${server.port}/oauth/callback?code=c`)
const response = await nativeFetch(`http://127.0.0.1:${server.port}/oauth/callback?code=c`)
// then
expect(response.status).toBe(400) expect(response.status).toBe(400)
const error = await callbackRejection const error = await callbackRejection
expect(error).toBeInstanceOf(Error) expect(error).toBeInstanceOf(Error)
@@ -87,18 +146,15 @@ describe("startCallbackServer", () => {
}) })
it("close stops the server immediately", async () => { it("close stops the server immediately", async () => {
// given
server = await startCallbackServer() server = await startCallbackServer()
const port = server.port const port = server.port
// when
server.close() server.close()
server = null server = null
// then
try { try {
await nativeFetch(`http://127.0.0.1:${port}/oauth/callback?code=c&state=s`) await request(`http://${HOSTNAME}:${port}/oauth/callback?code=c&state=s`)
expect(true).toBe(false) expect.unreachable("request should fail after close")
} catch (error) { } catch (error) {
expect(error).toBeDefined() expect(error).toBeDefined()
} }

View File

@@ -39,7 +39,7 @@ export async function findAvailablePort(startPort: number = DEFAULT_PORT): Promi
} }
export async function startCallbackServer(startPort: number = DEFAULT_PORT): Promise<CallbackServer> { export async function startCallbackServer(startPort: number = DEFAULT_PORT): Promise<CallbackServer> {
const port = await findAvailablePort(startPort) const requestedPort = await findAvailablePort(startPort).catch(() => 0)
let resolveCallback: ((result: OAuthCallbackResult) => void) | null = null let resolveCallback: ((result: OAuthCallbackResult) => void) | null = null
let rejectCallback: ((error: Error) => void) | null = null let rejectCallback: ((error: Error) => void) | null = null
@@ -55,7 +55,7 @@ export async function startCallbackServer(startPort: number = DEFAULT_PORT): Pro
}, TIMEOUT_MS) }, TIMEOUT_MS)
const server = Bun.serve({ const server = Bun.serve({
port, port: requestedPort,
hostname: "127.0.0.1", hostname: "127.0.0.1",
fetch(request: Request): Response { fetch(request: Request): Response {
const url = new URL(request.url) const url = new URL(request.url)
@@ -93,9 +93,10 @@ export async function startCallbackServer(startPort: number = DEFAULT_PORT): Pro
}) })
}, },
}) })
const activePort = server.port ?? requestedPort
return { return {
port, port: activePort,
waitForCallback: () => callbackPromise, waitForCallback: () => callbackPromise,
close: () => { close: () => {
clearTimeout(timeoutId) clearTimeout(timeoutId)