fix(runtime-fallback): make fallback provider selection provider-agnostic (fixes #2303)

This commit is contained in:
MoerAI
2026-03-19 21:02:34 +09:00
committed by sspark-kisane
parent d2a49428b9
commit 0e610a72bc
4 changed files with 82 additions and 3 deletions

View File

@@ -19,6 +19,8 @@ mock.module("../../shared/provider-model-id-transform", () => ({
import { tryFallbackRetry } from "./fallback-retry-handler"
import { shouldRetryError } from "../../shared/model-error-classifier"
import { selectFallbackProvider } from "../../shared/model-error-classifier"
import { readProviderModelsCache } from "../../shared"
import type { BackgroundTask } from "./types"
import type { ConcurrencyManager } from "./concurrency"
@@ -82,6 +84,8 @@ function createDefaultArgs(taskOverrides: Partial<BackgroundTask> = {}) {
describe("tryFallbackRetry", () => {
beforeEach(() => {
;(shouldRetryError as any).mockImplementation(() => true)
;(selectFallbackProvider as any).mockImplementation((providers: string[]) => providers[0])
;(readProviderModelsCache as any).mockReturnValue(null)
})
describe("#given retryable error with fallback chain", () => {
@@ -267,4 +271,24 @@ describe("tryFallbackRetry", () => {
expect(args.task.attemptCount).toBe(2)
})
})
describe("#given disconnected fallback providers with connected preferred provider", () => {
test("keeps fallback entry and selects connected preferred provider", () => {
;(readProviderModelsCache as any).mockReturnValue({ connected: ["provider-a"] })
;(selectFallbackProvider as any).mockImplementation(
(_providers: string[], preferredProviderID?: string) => preferredProviderID ?? "provider-b",
)
const args = createDefaultArgs({
fallbackChain: [{ model: "fallback-model-1", providers: ["provider-b"], variant: undefined }],
model: { providerID: "provider-a", modelID: "original-model" },
})
const result = tryFallbackRetry(args)
expect(result).toBe(true)
expect(args.task.model?.providerID).toBe("provider-a")
expect(args.task.model?.modelID).toBe("fallback-model-1")
})
})
})

View File

@@ -35,10 +35,14 @@ export function tryFallbackRetry(args: {
const providerModelsCache = readProviderModelsCache()
const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache()
const connectedSet = connectedProviders ? new Set(connectedProviders.map(p => p.toLowerCase())) : null
const preferredProvider = task.model?.providerID?.toLowerCase()
const isReachable = (entry: FallbackEntry): boolean => {
if (!connectedSet) return true
return entry.providers.some((p) => connectedSet.has(p.toLowerCase()))
if (entry.providers.some((provider) => connectedSet.has(provider.toLowerCase()))) {
return true
}
return preferredProvider ? connectedSet.has(preferredProvider) : false
}
let selectedAttemptCount = attemptCount

View File

@@ -255,6 +255,50 @@ describe("model fallback hook", () => {
clearPendingModelFallback(sessionID)
})
test("uses connected preferred provider when fallback entry providers are disconnected", async () => {
//#given
const sessionID = "ses_model_fallback_preferred_provider"
clearPendingModelFallback(sessionID)
readConnectedProvidersCacheMock.mockReturnValue(["provider-x"])
const hook = createModelFallbackHook() as unknown as {
"chat.message"?: (
input: { sessionID: string },
output: { message: Record<string, unknown>; parts: Array<{ type: string; text?: string }> },
) => Promise<void>
}
setSessionFallbackChain(sessionID, [
{ providers: ["provider-y"], model: "fallback-model" },
])
expect(
setPendingModelFallback(
sessionID,
"Sisyphus (Ultraworker)",
"provider-x",
"current-model",
),
).toBe(true)
const output = {
message: {
model: { providerID: "provider-x", modelID: "current-model" },
},
parts: [{ type: "text", text: "continue" }],
}
//#when
await hook["chat.message"]?.({ sessionID }, output)
//#then
expect(output.message["model"]).toEqual({
providerID: "provider-x",
modelID: "fallback-model",
})
clearPendingModelFallback(sessionID)
})
test("shows toast when fallback is applied", async () => {
//#given
const toastCalls: Array<{ title: string; message: string }> = []

View File

@@ -130,14 +130,21 @@ export function getNextFallback(
const providerModelsCache = readProviderModelsCache()
const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache()
const connectedSet = connectedProviders ? new Set(connectedProviders) : null
const connectedSet = connectedProviders
? new Set(connectedProviders.map((provider) => provider.toLowerCase()))
: null
const isReachable = (entry: FallbackEntry): boolean => {
if (!connectedSet) return true
// Gate only on provider connectivity. Provider model lists can be stale/incomplete,
// especially after users manually add models to opencode.json.
return entry.providers.some((p) => connectedSet.has(p))
if (entry.providers.some((provider) => connectedSet.has(provider.toLowerCase()))) {
return true
}
const preferredProvider = state.providerID.toLowerCase()
return connectedSet.has(preferredProvider)
}
while (state.attemptCount < fallbackChain.length) {