fix(mcp-oauth): robust port binding for callback server
Use port 0 fallback when findAvailablePort fails, read the actual bound port from server.port. Tests refactored to use mock server when real socket binding is unavailable in CI. 🤖 GENERATED WITH ASSISTANCE OF [OhMyOpenCode](https://github.com/code-yeongyu/oh-my-opencode) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,44 +1,112 @@
|
||||
import { afterEach, describe, expect, it } from "bun:test"
|
||||
import { afterEach, beforeEach, describe, expect, it, spyOn } from "bun:test"
|
||||
import { startCallbackServer, type CallbackServer } from "./callback-server"
|
||||
|
||||
const HOSTNAME = "127.0.0.1"
|
||||
const nativeFetch = Bun.fetch.bind(Bun)
|
||||
|
||||
function supportsRealSocketBinding(): boolean {
|
||||
try {
|
||||
const server = Bun.serve({
|
||||
port: 0,
|
||||
hostname: HOSTNAME,
|
||||
fetch: () => new Response("probe"),
|
||||
})
|
||||
server.stop(true)
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const canBindRealSockets = supportsRealSocketBinding()
|
||||
|
||||
type MockServerState = {
|
||||
port: number
|
||||
stopped: boolean
|
||||
fetch: (request: Request) => Response | Promise<Response>
|
||||
}
|
||||
|
||||
describe("startCallbackServer", () => {
|
||||
let server: CallbackServer | null = null
|
||||
let serveSpy: ReturnType<typeof spyOn> | null = null
|
||||
let activeServer: MockServerState | null = null
|
||||
|
||||
async function request(url: string): Promise<Response> {
|
||||
if (canBindRealSockets) {
|
||||
return nativeFetch(url)
|
||||
}
|
||||
|
||||
if (!activeServer || activeServer.stopped) {
|
||||
throw new Error("Connection refused")
|
||||
}
|
||||
|
||||
return await activeServer.fetch(new Request(url))
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
if (canBindRealSockets) {
|
||||
return
|
||||
}
|
||||
|
||||
activeServer = null
|
||||
serveSpy = spyOn(Bun, "serve").mockImplementation((options: {
|
||||
port: number
|
||||
hostname?: string
|
||||
fetch: (request: Request) => Response | Promise<Response>
|
||||
}) => {
|
||||
const state: MockServerState = {
|
||||
port: options.port === 0 ? 19877 : options.port,
|
||||
stopped: false,
|
||||
fetch: options.fetch,
|
||||
}
|
||||
|
||||
const handle = {
|
||||
port: state.port,
|
||||
stop: (_force?: boolean) => {
|
||||
state.stopped = true
|
||||
if (activeServer === state) {
|
||||
activeServer = null
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
activeServer = state
|
||||
return handle as ReturnType<typeof Bun.serve>
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
server?.close()
|
||||
server = null
|
||||
// Allow time for port to be released before next test
|
||||
await Bun.sleep(10)
|
||||
|
||||
if (serveSpy) {
|
||||
serveSpy.mockRestore()
|
||||
serveSpy = null
|
||||
}
|
||||
activeServer = null
|
||||
|
||||
if (canBindRealSockets) {
|
||||
await Bun.sleep(10)
|
||||
}
|
||||
})
|
||||
|
||||
it("starts server and returns port", async () => {
|
||||
// given - no preconditions
|
||||
|
||||
// when
|
||||
server = await startCallbackServer()
|
||||
|
||||
// then
|
||||
expect(server.port).toBeGreaterThanOrEqual(19877)
|
||||
expect(typeof server.waitForCallback).toBe("function")
|
||||
expect(typeof server.close).toBe("function")
|
||||
})
|
||||
|
||||
it("resolves callback with code and state from query params", async () => {
|
||||
// given
|
||||
server = await startCallbackServer()
|
||||
const callbackUrl = `http://127.0.0.1:${server.port}/oauth/callback?code=test-code&state=test-state`
|
||||
const callbackUrl = `http://${HOSTNAME}:${server.port}/oauth/callback?code=test-code&state=test-state`
|
||||
|
||||
// when
|
||||
// Use Promise.all to ensure fetch and waitForCallback run concurrently
|
||||
// This prevents race condition where waitForCallback blocks before fetch starts
|
||||
const [result, response] = await Promise.all([
|
||||
server.waitForCallback(),
|
||||
nativeFetch(callbackUrl)
|
||||
request(callbackUrl),
|
||||
])
|
||||
|
||||
// then
|
||||
expect(result).toEqual({ code: "test-code", state: "test-state" })
|
||||
expect(response.status).toBe(200)
|
||||
const html = await response.text()
|
||||
@@ -46,25 +114,19 @@ describe("startCallbackServer", () => {
|
||||
})
|
||||
|
||||
it("returns 404 for non-callback routes", async () => {
|
||||
// given
|
||||
server = await startCallbackServer()
|
||||
|
||||
// when
|
||||
const response = await nativeFetch(`http://127.0.0.1:${server.port}/other`)
|
||||
const response = await request(`http://${HOSTNAME}:${server.port}/other`)
|
||||
|
||||
// then
|
||||
expect(response.status).toBe(404)
|
||||
})
|
||||
|
||||
it("returns 400 and rejects when code is missing", async () => {
|
||||
// given
|
||||
server = await startCallbackServer()
|
||||
const callbackRejection = server.waitForCallback().catch((e: Error) => e)
|
||||
const callbackRejection = server.waitForCallback().catch((error: Error) => error)
|
||||
|
||||
// when
|
||||
const response = await nativeFetch(`http://127.0.0.1:${server.port}/oauth/callback?state=s`)
|
||||
const response = await request(`http://${HOSTNAME}:${server.port}/oauth/callback?state=s`)
|
||||
|
||||
// then
|
||||
expect(response.status).toBe(400)
|
||||
const error = await callbackRejection
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
@@ -72,14 +134,11 @@ describe("startCallbackServer", () => {
|
||||
})
|
||||
|
||||
it("returns 400 and rejects when state is missing", async () => {
|
||||
// given
|
||||
server = await startCallbackServer()
|
||||
const callbackRejection = server.waitForCallback().catch((e: Error) => e)
|
||||
const callbackRejection = server.waitForCallback().catch((error: Error) => error)
|
||||
|
||||
// when
|
||||
const response = await nativeFetch(`http://127.0.0.1:${server.port}/oauth/callback?code=c`)
|
||||
const response = await request(`http://${HOSTNAME}:${server.port}/oauth/callback?code=c`)
|
||||
|
||||
// then
|
||||
expect(response.status).toBe(400)
|
||||
const error = await callbackRejection
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
@@ -87,18 +146,15 @@ describe("startCallbackServer", () => {
|
||||
})
|
||||
|
||||
it("close stops the server immediately", async () => {
|
||||
// given
|
||||
server = await startCallbackServer()
|
||||
const port = server.port
|
||||
|
||||
// when
|
||||
server.close()
|
||||
server = null
|
||||
|
||||
// then
|
||||
try {
|
||||
await nativeFetch(`http://127.0.0.1:${port}/oauth/callback?code=c&state=s`)
|
||||
expect(true).toBe(false)
|
||||
await request(`http://${HOSTNAME}:${port}/oauth/callback?code=c&state=s`)
|
||||
expect.unreachable("request should fail after close")
|
||||
} catch (error) {
|
||||
expect(error).toBeDefined()
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ export async function findAvailablePort(startPort: number = DEFAULT_PORT): Promi
|
||||
}
|
||||
|
||||
export async function startCallbackServer(startPort: number = DEFAULT_PORT): Promise<CallbackServer> {
|
||||
const port = await findAvailablePort(startPort)
|
||||
const requestedPort = await findAvailablePort(startPort).catch(() => 0)
|
||||
|
||||
let resolveCallback: ((result: OAuthCallbackResult) => void) | null = null
|
||||
let rejectCallback: ((error: Error) => void) | null = null
|
||||
@@ -55,7 +55,7 @@ export async function startCallbackServer(startPort: number = DEFAULT_PORT): Pro
|
||||
}, TIMEOUT_MS)
|
||||
|
||||
const server = Bun.serve({
|
||||
port,
|
||||
port: requestedPort,
|
||||
hostname: "127.0.0.1",
|
||||
fetch(request: Request): Response {
|
||||
const url = new URL(request.url)
|
||||
@@ -93,9 +93,10 @@ export async function startCallbackServer(startPort: number = DEFAULT_PORT): Pro
|
||||
})
|
||||
},
|
||||
})
|
||||
const activePort = server.port ?? requestedPort
|
||||
|
||||
return {
|
||||
port,
|
||||
port: activePort,
|
||||
waitForCallback: () => callbackPromise,
|
||||
close: () => {
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
Reference in New Issue
Block a user