diff --git a/src/features/background-agent/fallback-retry-handler.test.ts b/src/features/background-agent/fallback-retry-handler.test.ts index 1a99503fa..03cd2b16f 100644 --- a/src/features/background-agent/fallback-retry-handler.test.ts +++ b/src/features/background-agent/fallback-retry-handler.test.ts @@ -19,6 +19,8 @@ mock.module("../../shared/provider-model-id-transform", () => ({ import { tryFallbackRetry } from "./fallback-retry-handler" import { shouldRetryError } from "../../shared/model-error-classifier" +import { selectFallbackProvider } from "../../shared/model-error-classifier" +import { readProviderModelsCache } from "../../shared" import type { BackgroundTask } from "./types" import type { ConcurrencyManager } from "./concurrency" @@ -82,6 +84,8 @@ function createDefaultArgs(taskOverrides: Partial = {}) { describe("tryFallbackRetry", () => { beforeEach(() => { ;(shouldRetryError as any).mockImplementation(() => true) + ;(selectFallbackProvider as any).mockImplementation((providers: string[]) => providers[0]) + ;(readProviderModelsCache as any).mockReturnValue(null) }) describe("#given retryable error with fallback chain", () => { @@ -267,4 +271,24 @@ describe("tryFallbackRetry", () => { expect(args.task.attemptCount).toBe(2) }) }) + + describe("#given disconnected fallback providers with connected preferred provider", () => { + test("keeps fallback entry and selects connected preferred provider", () => { + ;(readProviderModelsCache as any).mockReturnValue({ connected: ["provider-a"] }) + ;(selectFallbackProvider as any).mockImplementation( + (_providers: string[], preferredProviderID?: string) => preferredProviderID ?? "provider-b", + ) + + const args = createDefaultArgs({ + fallbackChain: [{ model: "fallback-model-1", providers: ["provider-b"], variant: undefined }], + model: { providerID: "provider-a", modelID: "original-model" }, + }) + + const result = tryFallbackRetry(args) + + expect(result).toBe(true) + expect(args.task.model?.providerID).toBe("provider-a") + expect(args.task.model?.modelID).toBe("fallback-model-1") + }) + }) }) diff --git a/src/features/background-agent/fallback-retry-handler.ts b/src/features/background-agent/fallback-retry-handler.ts index 94184fdbd..58c828e82 100644 --- a/src/features/background-agent/fallback-retry-handler.ts +++ b/src/features/background-agent/fallback-retry-handler.ts @@ -35,10 +35,14 @@ export function tryFallbackRetry(args: { const providerModelsCache = readProviderModelsCache() const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache() const connectedSet = connectedProviders ? new Set(connectedProviders.map(p => p.toLowerCase())) : null + const preferredProvider = task.model?.providerID?.toLowerCase() const isReachable = (entry: FallbackEntry): boolean => { if (!connectedSet) return true - return entry.providers.some((p) => connectedSet.has(p.toLowerCase())) + if (entry.providers.some((provider) => connectedSet.has(provider.toLowerCase()))) { + return true + } + return preferredProvider ? connectedSet.has(preferredProvider) : false } let selectedAttemptCount = attemptCount diff --git a/src/hooks/model-fallback/hook.test.ts b/src/hooks/model-fallback/hook.test.ts index 04cfcf659..fd3937311 100644 --- a/src/hooks/model-fallback/hook.test.ts +++ b/src/hooks/model-fallback/hook.test.ts @@ -255,6 +255,50 @@ describe("model fallback hook", () => { clearPendingModelFallback(sessionID) }) + test("uses connected preferred provider when fallback entry providers are disconnected", async () => { + //#given + const sessionID = "ses_model_fallback_preferred_provider" + clearPendingModelFallback(sessionID) + readConnectedProvidersCacheMock.mockReturnValue(["provider-x"]) + + const hook = createModelFallbackHook() as unknown as { + "chat.message"?: ( + input: { sessionID: string }, + output: { message: Record; parts: Array<{ type: string; text?: string }> }, + ) => Promise + } + + setSessionFallbackChain(sessionID, [ + { providers: ["provider-y"], model: "fallback-model" }, + ]) + + expect( + setPendingModelFallback( + sessionID, + "Sisyphus (Ultraworker)", + "provider-x", + "current-model", + ), + ).toBe(true) + + const output = { + message: { + model: { providerID: "provider-x", modelID: "current-model" }, + }, + parts: [{ type: "text", text: "continue" }], + } + + //#when + await hook["chat.message"]?.({ sessionID }, output) + + //#then + expect(output.message["model"]).toEqual({ + providerID: "provider-x", + modelID: "fallback-model", + }) + clearPendingModelFallback(sessionID) + }) + test("shows toast when fallback is applied", async () => { //#given const toastCalls: Array<{ title: string; message: string }> = [] diff --git a/src/hooks/model-fallback/hook.ts b/src/hooks/model-fallback/hook.ts index dbb4aa46d..045bba2df 100644 --- a/src/hooks/model-fallback/hook.ts +++ b/src/hooks/model-fallback/hook.ts @@ -130,14 +130,21 @@ export function getNextFallback( const providerModelsCache = readProviderModelsCache() const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache() - const connectedSet = connectedProviders ? new Set(connectedProviders) : null + const connectedSet = connectedProviders + ? new Set(connectedProviders.map((provider) => provider.toLowerCase())) + : null const isReachable = (entry: FallbackEntry): boolean => { if (!connectedSet) return true // Gate only on provider connectivity. Provider model lists can be stale/incomplete, // especially after users manually add models to opencode.json. - return entry.providers.some((p) => connectedSet.has(p)) + if (entry.providers.some((provider) => connectedSet.has(provider.toLowerCase()))) { + return true + } + + const preferredProvider = state.providerID.toLowerCase() + return connectedSet.has(preferredProvider) } while (state.attemptCount < fallbackChain.length) {