diff --git a/src/shared/model-suggestion-retry.test.ts b/src/shared/model-suggestion-retry.test.ts index 52edba3aa..9732367de 100644 --- a/src/shared/model-suggestion-retry.test.ts +++ b/src/shared/model-suggestion-retry.test.ts @@ -399,6 +399,43 @@ describe("promptSyncWithModelSuggestionRetry", () => { expect(promptAsyncMock).toHaveBeenCalledTimes(0) }) + it("should abort and throw timeout error when sync prompt hangs", async () => { + // given a client where sync prompt never resolves unless aborted + let receivedSignal: AbortSignal | undefined + const promptMock = mock((input: { signal?: AbortSignal }) => { + receivedSignal = input.signal + return new Promise((_, reject) => { + const signal = input.signal + if (!signal) { + return + } + signal.addEventListener("abort", () => { + reject(signal.reason) + }) + }) + }) + const client = { + session: { + prompt: promptMock, + promptAsync: mock(() => Promise.resolve()), + }, + } + + // when calling with short timeout + // then should abort the request and throw timeout error + await expect( + promptSyncWithModelSuggestionRetry(client as any, { + path: { id: "session-1" }, + body: { + parts: [{ type: "text", text: "hello" }], + model: { providerID: "anthropic", modelID: "claude-sonnet-4" }, + }, + }, { timeoutMs: 1 }) + ).rejects.toThrow("prompt timed out after 1ms") + + expect(receivedSignal?.aborted).toBe(true) + }) + it("should retry with suggested model on ProviderModelNotFoundError", async () => { // given a client that fails first with model-not-found, then succeeds const promptMock = mock() diff --git a/src/shared/model-suggestion-retry.ts b/src/shared/model-suggestion-retry.ts index 6a34deacb..0ff9ca86e 100644 --- a/src/shared/model-suggestion-retry.ts +++ b/src/shared/model-suggestion-retry.ts @@ -1,5 +1,10 @@ import type { createOpencodeClient } from "@opencode-ai/sdk" import { log } from "./logger" +import { + createPromptTimeoutContext, + PROMPT_TIMEOUT_MS, + type PromptRetryOptions, +} from "./prompt-timeout-context" type Client = ReturnType @@ -77,30 +82,36 @@ interface PromptBody { interface PromptArgs { path: { id: string } body: PromptBody + signal?: AbortSignal [key: string]: unknown } export async function promptWithModelSuggestionRetry( client: Client, args: PromptArgs, + options: PromptRetryOptions = {}, ): Promise { + const timeoutMs = options.timeoutMs ?? PROMPT_TIMEOUT_MS + const timeoutContext = createPromptTimeoutContext(args, timeoutMs) // NOTE: Model suggestion retry removed — promptAsync returns 204 immediately, // model errors happen asynchronously server-side and cannot be caught here - const promptPromise = client.session.promptAsync( - args as Parameters[0], - ) - - let timeoutID: ReturnType | null = null - const timeoutPromise = new Promise((_, reject) => { - timeoutID = setTimeout(() => { - reject(new Error("promptAsync timed out after 120000ms")) - }, 120000) - }) + const promptPromise = client.session.promptAsync({ + ...args, + signal: timeoutContext.signal, + } as Parameters[0]) try { - await Promise.race([promptPromise, timeoutPromise]) + await promptPromise + if (timeoutContext.wasTimedOut()) { + throw new Error(`promptAsync timed out after ${timeoutMs}ms`) + } + } catch (error) { + if (timeoutContext.wasTimedOut()) { + throw new Error(`promptAsync timed out after ${timeoutMs}ms`) + } + throw error } finally { - if (timeoutID !== null) clearTimeout(timeoutID) + timeoutContext.cleanup() } } @@ -116,9 +127,28 @@ export async function promptWithModelSuggestionRetry( export async function promptSyncWithModelSuggestionRetry( client: Client, args: PromptArgs, + options: PromptRetryOptions = {}, ): Promise { + const timeoutMs = options.timeoutMs ?? PROMPT_TIMEOUT_MS + try { - await client.session.prompt(args as Parameters[0]) + const timeoutContext = createPromptTimeoutContext(args, timeoutMs) + try { + await client.session.prompt({ + ...args, + signal: timeoutContext.signal, + } as Parameters[0]) + if (timeoutContext.wasTimedOut()) { + throw new Error(`prompt timed out after ${timeoutMs}ms`) + } + } catch (error) { + if (timeoutContext.wasTimedOut()) { + throw new Error(`prompt timed out after ${timeoutMs}ms`) + } + throw error + } finally { + timeoutContext.cleanup() + } } catch (error) { const suggestion = parseModelSuggestion(error) if (!suggestion || !args.body.model) { @@ -130,7 +160,7 @@ export async function promptSyncWithModelSuggestionRetry( suggested: suggestion.suggestion, }) - await client.session.prompt({ + const retryArgs: PromptArgs = { ...args, body: { ...args.body, @@ -139,6 +169,24 @@ export async function promptSyncWithModelSuggestionRetry( modelID: suggestion.suggestion, }, }, - } as Parameters[0]) + } + + const timeoutContext = createPromptTimeoutContext(retryArgs, timeoutMs) + try { + await client.session.prompt({ + ...retryArgs, + signal: timeoutContext.signal, + } as Parameters[0]) + if (timeoutContext.wasTimedOut()) { + throw new Error(`prompt timed out after ${timeoutMs}ms`) + } + } catch (retryError) { + if (timeoutContext.wasTimedOut()) { + throw new Error(`prompt timed out after ${timeoutMs}ms`) + } + throw retryError + } finally { + timeoutContext.cleanup() + } } } diff --git a/src/shared/prompt-timeout-context.ts b/src/shared/prompt-timeout-context.ts new file mode 100644 index 000000000..99f081278 --- /dev/null +++ b/src/shared/prompt-timeout-context.ts @@ -0,0 +1,49 @@ +export interface PromptTimeoutArgs { + signal?: AbortSignal +} + +export interface PromptRetryOptions { + timeoutMs?: number +} + +export const PROMPT_TIMEOUT_MS = 120000 + +export function createPromptTimeoutContext(args: PromptTimeoutArgs, timeoutMs: number): { + signal: AbortSignal + wasTimedOut: () => boolean + cleanup: () => void +} { + const timeoutController = new AbortController() + let timeoutID: ReturnType | null = null + let timedOut = false + + const abortOnUpstreamSignal = (): void => { + timeoutController.abort(args.signal?.reason) + } + + if (args.signal) { + if (args.signal.aborted) { + timeoutController.abort(args.signal.reason) + } else { + args.signal.addEventListener("abort", abortOnUpstreamSignal, { once: true }) + } + } + + timeoutID = setTimeout(() => { + timedOut = true + timeoutController.abort(new Error(`prompt timed out after ${timeoutMs}ms`)) + }, timeoutMs) + + return { + signal: timeoutController.signal, + wasTimedOut: () => timedOut, + cleanup: () => { + if (timeoutID !== null) { + clearTimeout(timeoutID) + } + if (args.signal) { + args.signal.removeEventListener("abort", abortOnUpstreamSignal) + } + }, + } +}