diff --git a/src/auth/antigravity/oauth.test.ts b/src/auth/antigravity/oauth.test.ts new file mode 100644 index 000000000..9fcfa675e --- /dev/null +++ b/src/auth/antigravity/oauth.test.ts @@ -0,0 +1,191 @@ +import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test" +import { buildAuthURL, exchangeCode } from "./oauth" +import { ANTIGRAVITY_CLIENT_ID, GOOGLE_TOKEN_URL } from "./constants" + +describe("OAuth PKCE Removal", () => { + describe("buildAuthURL", () => { + it("should NOT include code_challenge parameter", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + const url = new URL(result.url) + + // #then + expect(url.searchParams.has("code_challenge")).toBe(false) + }) + + it("should NOT include code_challenge_method parameter", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + const url = new URL(result.url) + + // #then + expect(url.searchParams.has("code_challenge_method")).toBe(false) + }) + + it("should include state parameter for CSRF protection", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + const url = new URL(result.url) + const state = url.searchParams.get("state") + + // #then + expect(state).toBeTruthy() + }) + + it("should have state as simple random string (not JSON/base64)", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + const url = new URL(result.url) + const state = url.searchParams.get("state")! + + // #then - positive assertions for simple random string + expect(state.length).toBeGreaterThanOrEqual(16) + expect(state.length).toBeLessThanOrEqual(64) + // Should be URL-safe (alphanumeric, no special chars like { } " :) + expect(state).toMatch(/^[a-zA-Z0-9_-]+$/) + // Should NOT contain JSON indicators + expect(state).not.toContain("{") + expect(state).not.toContain("}") + expect(state).not.toContain('"') + }) + + it("should include access_type=offline", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + const url = new URL(result.url) + + // #then + expect(url.searchParams.get("access_type")).toBe("offline") + }) + + it("should include prompt=consent", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + const url = new URL(result.url) + + // #then + expect(url.searchParams.get("prompt")).toBe("consent") + }) + + it("should NOT return verifier property (PKCE removed)", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + + // #then + expect(result).not.toHaveProperty("verifier") + expect(result).toHaveProperty("url") + expect(result).toHaveProperty("state") + }) + + it("should return state that matches URL state param", async () => { + // #given + const projectId = "test-project" + + // #when + const result = await buildAuthURL(projectId) + const url = new URL(result.url) + + // #then + expect(result.state).toBe(url.searchParams.get("state")!) + }) + }) + + describe("exchangeCode", () => { + let originalFetch: typeof fetch + + beforeEach(() => { + originalFetch = globalThis.fetch + }) + + afterEach(() => { + globalThis.fetch = originalFetch + }) + + it("should NOT send code_verifier in token exchange", async () => { + // #given + let capturedBody: string | null = null + globalThis.fetch = mock(async (url: string, init?: RequestInit) => { + if (url === GOOGLE_TOKEN_URL) { + capturedBody = init?.body as string + return new Response(JSON.stringify({ + access_token: "test-access", + refresh_token: "test-refresh", + expires_in: 3600, + token_type: "Bearer" + })) + } + return new Response("", { status: 404 }) + }) as unknown as typeof fetch + + // #when + await exchangeCode("test-code", "http://localhost:51121/oauth-callback") + + // #then + expect(capturedBody).toBeTruthy() + const params = new URLSearchParams(capturedBody!) + expect(params.has("code_verifier")).toBe(false) + }) + + it("should send required OAuth parameters", async () => { + // #given + let capturedBody: string | null = null + globalThis.fetch = mock(async (url: string, init?: RequestInit) => { + if (url === GOOGLE_TOKEN_URL) { + capturedBody = init?.body as string + return new Response(JSON.stringify({ + access_token: "test-access", + refresh_token: "test-refresh", + expires_in: 3600, + token_type: "Bearer" + })) + } + return new Response("", { status: 404 }) + }) as unknown as typeof fetch + + // #when + await exchangeCode("test-code", "http://localhost:51121/oauth-callback") + + // #then + const params = new URLSearchParams(capturedBody!) + expect(params.get("grant_type")).toBe("authorization_code") + expect(params.get("code")).toBe("test-code") + expect(params.get("client_id")).toBe(ANTIGRAVITY_CLIENT_ID) + expect(params.get("redirect_uri")).toBe("http://localhost:51121/oauth-callback") + }) + }) + + describe("State/CSRF Validation", () => { + it("should generate unique state for each call", async () => { + // #given + const projectId = "test-project" + + // #when + const result1 = await buildAuthURL(projectId) + const result2 = await buildAuthURL(projectId) + + // #then + expect(result1.state).not.toBe(result2.state) + }) + }) +}) diff --git a/src/auth/antigravity/oauth.ts b/src/auth/antigravity/oauth.ts index 7e76b4417..aa7ca78d0 100644 --- a/src/auth/antigravity/oauth.ts +++ b/src/auth/antigravity/oauth.ts @@ -1,9 +1,7 @@ /** - * Antigravity OAuth 2.0 flow implementation with PKCE. + * Antigravity OAuth 2.0 flow implementation. * Handles Google OAuth for Antigravity authentication. */ -import { generatePKCE } from "@openauthjs/openauth/pkce" - import { ANTIGRAVITY_CLIENT_ID, ANTIGRAVITY_CLIENT_SECRET, @@ -19,37 +17,14 @@ import type { AntigravityUserInfo, } from "./types" -/** - * PKCE pair containing verifier and challenge. - */ -export interface PKCEPair { - /** PKCE verifier - used during token exchange */ - verifier: string - /** PKCE challenge - sent in auth URL */ - challenge: string - /** Challenge method - always "S256" */ - method: string -} - -/** - * OAuth state encoded in the auth URL. - * Contains the PKCE verifier for later retrieval. - */ -export interface OAuthState { - /** PKCE verifier */ - verifier: string - /** Optional project ID */ - projectId?: string -} - /** * Result from building an OAuth authorization URL. */ export interface AuthorizationResult { /** Full OAuth URL to open in browser */ url: string - /** PKCE verifier to use during code exchange */ - verifier: string + /** State for CSRF protection */ + state: string } /** @@ -64,70 +39,12 @@ export interface CallbackResult { error?: string } -/** - * Generate PKCE verifier and challenge pair. - * Uses @openauthjs/openauth for cryptographically secure generation. - * - * @returns PKCE pair with verifier, challenge, and method - */ -export async function generatePKCEPair(): Promise { - const pkce = await generatePKCE() - return { - verifier: pkce.verifier, - challenge: pkce.challenge, - method: pkce.method, - } -} - -/** - * Encode OAuth state into a URL-safe base64 string. - * - * @param state - OAuth state object - * @returns Base64URL encoded state - */ -function encodeState(state: OAuthState): string { - const json = JSON.stringify(state) - return Buffer.from(json, "utf8").toString("base64url") -} - -/** - * Decode OAuth state from a base64 string. - * - * @param encoded - Base64URL or Base64 encoded state - * @returns Decoded OAuth state - */ -export function decodeState(encoded: string): OAuthState { - // Handle both base64url and standard base64 - const normalized = encoded.replace(/-/g, "+").replace(/_/g, "/") - const padded = normalized.padEnd( - normalized.length + ((4 - (normalized.length % 4)) % 4), - "=" - ) - const json = Buffer.from(padded, "base64").toString("utf8") - const parsed = JSON.parse(json) - - if (typeof parsed.verifier !== "string") { - throw new Error("Missing PKCE verifier in state") - } - - return { - verifier: parsed.verifier, - projectId: - typeof parsed.projectId === "string" ? parsed.projectId : undefined, - } -} - export async function buildAuthURL( projectId?: string, clientId: string = ANTIGRAVITY_CLIENT_ID, port: number = ANTIGRAVITY_CALLBACK_PORT ): Promise { - const pkce = await generatePKCEPair() - - const state: OAuthState = { - verifier: pkce.verifier, - projectId, - } + const state = crypto.randomUUID().replace(/-/g, "") const redirectUri = `http://localhost:${port}/oauth-callback` @@ -136,15 +53,13 @@ export async function buildAuthURL( url.searchParams.set("redirect_uri", redirectUri) url.searchParams.set("response_type", "code") url.searchParams.set("scope", ANTIGRAVITY_SCOPES.join(" ")) - url.searchParams.set("state", encodeState(state)) - url.searchParams.set("code_challenge", pkce.challenge) - url.searchParams.set("code_challenge_method", "S256") + url.searchParams.set("state", state) url.searchParams.set("access_type", "offline") url.searchParams.set("prompt", "consent") return { url: url.toString(), - verifier: pkce.verifier, + state, } } @@ -152,26 +67,23 @@ export async function buildAuthURL( * Exchange authorization code for tokens. * * @param code - Authorization code from OAuth callback - * @param verifier - PKCE verifier from initial auth request + * @param redirectUri - OAuth redirect URI * @param clientId - Optional custom client ID (defaults to ANTIGRAVITY_CLIENT_ID) * @param clientSecret - Optional custom client secret (defaults to ANTIGRAVITY_CLIENT_SECRET) * @returns Token exchange result with access and refresh tokens */ export async function exchangeCode( code: string, - verifier: string, + redirectUri: string, clientId: string = ANTIGRAVITY_CLIENT_ID, - clientSecret: string = ANTIGRAVITY_CLIENT_SECRET, - port: number = ANTIGRAVITY_CALLBACK_PORT + clientSecret: string = ANTIGRAVITY_CLIENT_SECRET ): Promise { - const redirectUri = `http://localhost:${port}/oauth-callback` const params = new URLSearchParams({ client_id: clientId, client_secret: clientSecret, code, grant_type: "authorization_code", redirect_uri: redirectUri, - code_verifier: verifier, }) const response = await fetch(GOOGLE_TOKEN_URL, { @@ -324,7 +236,7 @@ export async function performOAuthFlow( ): Promise<{ tokens: AntigravityTokenExchangeResult userInfo: AntigravityUserInfo - verifier: string + state: string }> { const serverHandle = startCallbackServer() @@ -345,15 +257,15 @@ export async function performOAuthFlow( throw new Error("No authorization code received") } - const state = decodeState(callback.state) - if (state.verifier !== auth.verifier) { - throw new Error("PKCE verifier mismatch - possible CSRF attack") + if (callback.state !== auth.state) { + throw new Error("State mismatch - possible CSRF attack") } - const tokens = await exchangeCode(callback.code, auth.verifier, clientId, clientSecret, serverHandle.port) + const redirectUri = `http://localhost:${serverHandle.port}/oauth-callback` + const tokens = await exchangeCode(callback.code, redirectUri, clientId, clientSecret) const userInfo = await fetchUserInfo(tokens.access_token) - return { tokens, userInfo, verifier: auth.verifier } + return { tokens, userInfo, state: auth.state } } catch (err) { serverHandle.close() throw err