fix(runtime-fallback): make fallback provider selection provider-agnostic (fixes #2303)
This commit is contained in:
@@ -19,6 +19,8 @@ mock.module("../../shared/provider-model-id-transform", () => ({
|
||||
|
||||
import { tryFallbackRetry } from "./fallback-retry-handler"
|
||||
import { shouldRetryError } from "../../shared/model-error-classifier"
|
||||
import { selectFallbackProvider } from "../../shared/model-error-classifier"
|
||||
import { readProviderModelsCache } from "../../shared"
|
||||
import type { BackgroundTask } from "./types"
|
||||
import type { ConcurrencyManager } from "./concurrency"
|
||||
|
||||
@@ -82,6 +84,8 @@ function createDefaultArgs(taskOverrides: Partial<BackgroundTask> = {}) {
|
||||
describe("tryFallbackRetry", () => {
|
||||
beforeEach(() => {
|
||||
;(shouldRetryError as any).mockImplementation(() => true)
|
||||
;(selectFallbackProvider as any).mockImplementation((providers: string[]) => providers[0])
|
||||
;(readProviderModelsCache as any).mockReturnValue(null)
|
||||
})
|
||||
|
||||
describe("#given retryable error with fallback chain", () => {
|
||||
@@ -267,4 +271,24 @@ describe("tryFallbackRetry", () => {
|
||||
expect(args.task.attemptCount).toBe(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given disconnected fallback providers with connected preferred provider", () => {
|
||||
test("keeps fallback entry and selects connected preferred provider", () => {
|
||||
;(readProviderModelsCache as any).mockReturnValue({ connected: ["provider-a"] })
|
||||
;(selectFallbackProvider as any).mockImplementation(
|
||||
(_providers: string[], preferredProviderID?: string) => preferredProviderID ?? "provider-b",
|
||||
)
|
||||
|
||||
const args = createDefaultArgs({
|
||||
fallbackChain: [{ model: "fallback-model-1", providers: ["provider-b"], variant: undefined }],
|
||||
model: { providerID: "provider-a", modelID: "original-model" },
|
||||
})
|
||||
|
||||
const result = tryFallbackRetry(args)
|
||||
|
||||
expect(result).toBe(true)
|
||||
expect(args.task.model?.providerID).toBe("provider-a")
|
||||
expect(args.task.model?.modelID).toBe("fallback-model-1")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -35,10 +35,14 @@ export function tryFallbackRetry(args: {
|
||||
const providerModelsCache = readProviderModelsCache()
|
||||
const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache()
|
||||
const connectedSet = connectedProviders ? new Set(connectedProviders.map(p => p.toLowerCase())) : null
|
||||
const preferredProvider = task.model?.providerID?.toLowerCase()
|
||||
|
||||
const isReachable = (entry: FallbackEntry): boolean => {
|
||||
if (!connectedSet) return true
|
||||
return entry.providers.some((p) => connectedSet.has(p.toLowerCase()))
|
||||
if (entry.providers.some((provider) => connectedSet.has(provider.toLowerCase()))) {
|
||||
return true
|
||||
}
|
||||
return preferredProvider ? connectedSet.has(preferredProvider) : false
|
||||
}
|
||||
|
||||
let selectedAttemptCount = attemptCount
|
||||
|
||||
@@ -255,6 +255,50 @@ describe("model fallback hook", () => {
|
||||
clearPendingModelFallback(sessionID)
|
||||
})
|
||||
|
||||
test("uses connected preferred provider when fallback entry providers are disconnected", async () => {
|
||||
//#given
|
||||
const sessionID = "ses_model_fallback_preferred_provider"
|
||||
clearPendingModelFallback(sessionID)
|
||||
readConnectedProvidersCacheMock.mockReturnValue(["provider-x"])
|
||||
|
||||
const hook = createModelFallbackHook() as unknown as {
|
||||
"chat.message"?: (
|
||||
input: { sessionID: string },
|
||||
output: { message: Record<string, unknown>; parts: Array<{ type: string; text?: string }> },
|
||||
) => Promise<void>
|
||||
}
|
||||
|
||||
setSessionFallbackChain(sessionID, [
|
||||
{ providers: ["provider-y"], model: "fallback-model" },
|
||||
])
|
||||
|
||||
expect(
|
||||
setPendingModelFallback(
|
||||
sessionID,
|
||||
"Sisyphus (Ultraworker)",
|
||||
"provider-x",
|
||||
"current-model",
|
||||
),
|
||||
).toBe(true)
|
||||
|
||||
const output = {
|
||||
message: {
|
||||
model: { providerID: "provider-x", modelID: "current-model" },
|
||||
},
|
||||
parts: [{ type: "text", text: "continue" }],
|
||||
}
|
||||
|
||||
//#when
|
||||
await hook["chat.message"]?.({ sessionID }, output)
|
||||
|
||||
//#then
|
||||
expect(output.message["model"]).toEqual({
|
||||
providerID: "provider-x",
|
||||
modelID: "fallback-model",
|
||||
})
|
||||
clearPendingModelFallback(sessionID)
|
||||
})
|
||||
|
||||
test("shows toast when fallback is applied", async () => {
|
||||
//#given
|
||||
const toastCalls: Array<{ title: string; message: string }> = []
|
||||
|
||||
@@ -130,14 +130,21 @@ export function getNextFallback(
|
||||
|
||||
const providerModelsCache = readProviderModelsCache()
|
||||
const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache()
|
||||
const connectedSet = connectedProviders ? new Set(connectedProviders) : null
|
||||
const connectedSet = connectedProviders
|
||||
? new Set(connectedProviders.map((provider) => provider.toLowerCase()))
|
||||
: null
|
||||
|
||||
const isReachable = (entry: FallbackEntry): boolean => {
|
||||
if (!connectedSet) return true
|
||||
|
||||
// Gate only on provider connectivity. Provider model lists can be stale/incomplete,
|
||||
// especially after users manually add models to opencode.json.
|
||||
return entry.providers.some((p) => connectedSet.has(p))
|
||||
if (entry.providers.some((provider) => connectedSet.has(provider.toLowerCase()))) {
|
||||
return true
|
||||
}
|
||||
|
||||
const preferredProvider = state.providerID.toLowerCase()
|
||||
return connectedSet.has(preferredProvider)
|
||||
}
|
||||
|
||||
while (state.attemptCount < fallbackChain.length) {
|
||||
|
||||
Reference in New Issue
Block a user