diff --git a/src/features/mcp-oauth/index.ts b/src/features/mcp-oauth/index.ts index 06861aae9..cf042d888 100644 --- a/src/features/mcp-oauth/index.ts +++ b/src/features/mcp-oauth/index.ts @@ -1 +1,3 @@ export * from "./schema" +export * from "./oauth-authorization-flow" +export * from "./provider" diff --git a/src/features/mcp-oauth/oauth-authorization-flow.ts b/src/features/mcp-oauth/oauth-authorization-flow.ts new file mode 100644 index 000000000..26f7d31ac --- /dev/null +++ b/src/features/mcp-oauth/oauth-authorization-flow.ts @@ -0,0 +1,150 @@ +import { spawn } from "node:child_process" +import { createHash, randomBytes } from "node:crypto" +import { createServer } from "node:http" + +export type OAuthCallbackResult = { + code: string + state: string +} + +export function generateCodeVerifier(): string { + return randomBytes(32).toString("base64url") +} + +export function generateCodeChallenge(verifier: string): string { + return createHash("sha256").update(verifier).digest("base64url") +} + +export function buildAuthorizationUrl( + authorizationEndpoint: string, + options: { + clientId: string + redirectUri: string + codeChallenge: string + state: string + scopes?: string[] + resource?: string + } +): string { + const url = new URL(authorizationEndpoint) + url.searchParams.set("response_type", "code") + url.searchParams.set("client_id", options.clientId) + url.searchParams.set("redirect_uri", options.redirectUri) + url.searchParams.set("code_challenge", options.codeChallenge) + url.searchParams.set("code_challenge_method", "S256") + url.searchParams.set("state", options.state) + if (options.scopes && options.scopes.length > 0) { + url.searchParams.set("scope", options.scopes.join(" ")) + } + if (options.resource) { + url.searchParams.set("resource", options.resource) + } + return url.toString() +} + +const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 + +export function startCallbackServer(port: number): Promise { + return new Promise((resolve, reject) => { + let timeoutId: ReturnType + + const server = createServer((request, response) => { + clearTimeout(timeoutId) + + const requestUrl = new URL(request.url ?? "/", `http://localhost:${port}`) + const code = requestUrl.searchParams.get("code") + const state = requestUrl.searchParams.get("state") + const error = requestUrl.searchParams.get("error") + + if (error) { + const errorDescription = requestUrl.searchParams.get("error_description") ?? error + response.writeHead(400, { "content-type": "text/html" }) + response.end("

Authorization failed

") + server.close() + reject(new Error(`OAuth authorization error: ${errorDescription}`)) + return + } + + if (!code || !state) { + response.writeHead(400, { "content-type": "text/html" }) + response.end("

Missing code or state

") + server.close() + reject(new Error("OAuth callback missing code or state parameter")) + return + } + + response.writeHead(200, { "content-type": "text/html" }) + response.end("

Authorization successful. You can close this tab.

") + server.close() + resolve({ code, state }) + }) + + timeoutId = setTimeout(() => { + server.close() + reject(new Error("OAuth callback timed out after 5 minutes")) + }, CALLBACK_TIMEOUT_MS) + + server.listen(port, "127.0.0.1") + server.on("error", (err) => { + clearTimeout(timeoutId) + reject(err) + }) + }) +} + +function openBrowser(url: string): void { + const platform = process.platform + let command: string + let args: string[] + + if (platform === "darwin") { + command = "open" + args = [url] + } else if (platform === "win32") { + command = "explorer" + args = [url] + } else { + command = "xdg-open" + args = [url] + } + + try { + const child = spawn(command, args, { stdio: "ignore", detached: true }) + child.on("error", () => {}) + child.unref() + } catch { + // Browser open failed — user must navigate manually + } +} + +export async function runAuthorizationCodeRedirect(options: { + authorizationEndpoint: string + callbackPort: number + clientId: string + redirectUri: string + scopes?: string[] + resource?: string +}): Promise<{ code: string; verifier: string }> { + const verifier = generateCodeVerifier() + const challenge = generateCodeChallenge(verifier) + const state = randomBytes(16).toString("hex") + + const authorizationUrl = buildAuthorizationUrl(options.authorizationEndpoint, { + clientId: options.clientId, + redirectUri: options.redirectUri, + codeChallenge: challenge, + state, + scopes: options.scopes, + resource: options.resource, + }) + + const callbackPromise = startCallbackServer(options.callbackPort) + openBrowser(authorizationUrl) + + const result = await callbackPromise + if (result.state !== state) { + throw new Error("OAuth state mismatch") + } + + return { code: result.code, verifier } +} diff --git a/src/features/mcp-oauth/provider.ts b/src/features/mcp-oauth/provider.ts index 6b4a69b34..bf098fdd4 100644 --- a/src/features/mcp-oauth/provider.ts +++ b/src/features/mcp-oauth/provider.ts @@ -1,6 +1,3 @@ -import { createHash, randomBytes } from "node:crypto" -import { createServer } from "node:http" -import { spawn } from "node:child_process" import type { OAuthTokenData } from "./storage" import { loadToken, saveToken } from "./storage" import { discoverOAuthServerMetadata } from "./discovery" @@ -8,6 +5,13 @@ import type { OAuthServerMetadata } from "./discovery" import { getOrRegisterClient } from "./dcr" import type { ClientCredentials, ClientRegistrationStorage } from "./dcr" import { findAvailablePort } from "./callback-server" +import { + buildAuthorizationUrl, + generateCodeChallenge, + generateCodeVerifier, + runAuthorizationCodeRedirect, + startCallbackServer, +} from "./oauth-authorization-flow" export type McpOAuthProviderOptions = { serverUrl: string @@ -15,121 +19,6 @@ export type McpOAuthProviderOptions = { scopes?: string[] } -type CallbackResult = { - code: string - state: string -} - -function generateCodeVerifier(): string { - return randomBytes(32).toString("base64url") -} - -function generateCodeChallenge(verifier: string): string { - return createHash("sha256").update(verifier).digest("base64url") -} - -function buildAuthorizationUrl( - authorizationEndpoint: string, - options: { - clientId: string - redirectUri: string - codeChallenge: string - state: string - scopes?: string[] - resource?: string - } -): string { - const url = new URL(authorizationEndpoint) - url.searchParams.set("response_type", "code") - url.searchParams.set("client_id", options.clientId) - url.searchParams.set("redirect_uri", options.redirectUri) - url.searchParams.set("code_challenge", options.codeChallenge) - url.searchParams.set("code_challenge_method", "S256") - url.searchParams.set("state", options.state) - if (options.scopes && options.scopes.length > 0) { - url.searchParams.set("scope", options.scopes.join(" ")) - } - if (options.resource) { - url.searchParams.set("resource", options.resource) - } - return url.toString() -} - -const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 - -function startCallbackServer(port: number): Promise { - return new Promise((resolve, reject) => { - let timeoutId: ReturnType - - const server = createServer((request, response) => { - clearTimeout(timeoutId) - - const requestUrl = new URL(request.url ?? "/", `http://localhost:${port}`) - const code = requestUrl.searchParams.get("code") - const state = requestUrl.searchParams.get("state") - const error = requestUrl.searchParams.get("error") - - if (error) { - const errorDescription = requestUrl.searchParams.get("error_description") ?? error - response.writeHead(400, { "content-type": "text/html" }) - response.end("

Authorization failed

") - server.close() - reject(new Error(`OAuth authorization error: ${errorDescription}`)) - return - } - - if (!code || !state) { - response.writeHead(400, { "content-type": "text/html" }) - response.end("

Missing code or state

") - server.close() - reject(new Error("OAuth callback missing code or state parameter")) - return - } - - response.writeHead(200, { "content-type": "text/html" }) - response.end("

Authorization successful. You can close this tab.

") - server.close() - resolve({ code, state }) - }) - - timeoutId = setTimeout(() => { - server.close() - reject(new Error("OAuth callback timed out after 5 minutes")) - }, CALLBACK_TIMEOUT_MS) - - server.listen(port, "127.0.0.1") - server.on("error", (err) => { - clearTimeout(timeoutId) - reject(err) - }) - }) -} - -function openBrowser(url: string): void { - const platform = process.platform - let cmd: string - let args: string[] - - if (platform === "darwin") { - cmd = "open" - args = [url] - } else if (platform === "win32") { - cmd = "explorer" - args = [url] - } else { - cmd = "xdg-open" - args = [url] - } - - try { - const child = spawn(cmd, args, { stdio: "ignore", detached: true }) - child.on("error", () => {}) - child.unref() - } catch { - // Browser open failed — user must navigate manually - } -} - export class McpOAuthProvider { private readonly serverUrl: string private readonly configClientId: string | undefined @@ -174,12 +63,7 @@ export class McpOAuthProvider { return this.storedCodeVerifier } - async redirectToAuthorization(metadata: OAuthServerMetadata): Promise { - const verifier = generateCodeVerifier() - this.saveCodeVerifier(verifier) - const challenge = generateCodeChallenge(verifier) - const state = randomBytes(16).toString("hex") - + async redirectToAuthorization(metadata: OAuthServerMetadata): Promise<{ code: string }> { const clientInfo = this.clientInformation() if (!clientInfo) { throw new Error("No client information available. Run login() or register a client first.") @@ -189,24 +73,17 @@ export class McpOAuthProvider { this.callbackPort = await findAvailablePort() } - const authUrl = buildAuthorizationUrl(metadata.authorizationEndpoint, { + const result = await runAuthorizationCodeRedirect({ + authorizationEndpoint: metadata.authorizationEndpoint, + callbackPort: this.callbackPort, clientId: clientInfo.clientId, redirectUri: this.redirectUrl(), - codeChallenge: challenge, - state, scopes: this.scopes, resource: metadata.resource, }) - const callbackPromise = startCallbackServer(this.callbackPort) - openBrowser(authUrl) - - const result = await callbackPromise - if (result.state !== state) { - throw new Error("OAuth state mismatch") - } - - return result + this.saveCodeVerifier(result.verifier) + return { code: result.code } } async login(): Promise {