Merge pull request #2488 from code-yeongyu/fix/issue-2295-fallback-provider-preserve

fix: preserve session provider context in fallback chain
This commit is contained in:
YeonGyu-Kim
2026-03-12 11:24:43 +09:00
committed by GitHub
3 changed files with 57 additions and 8 deletions

View File

@@ -19,11 +19,12 @@ import {
import { getFallbackModelsForSession } from "../hooks/runtime-fallback/fallback-models";
import { resetMessageCursor } from "../shared";
import { getAgentConfigKey } from "../shared/agent-display-names";
import { readConnectedProvidersCache } from "../shared/connected-providers-cache";
import { log } from "../shared/logger";
import { shouldRetryError } from "../shared/model-error-classifier";
import { buildFallbackChainFromModels } from "../shared/fallback-chain-from-models";
import { extractRetryAttempt, normalizeRetryStatusMessage } from "../shared/retry-status-utils";
import { clearSessionModel, setSessionModel } from "../shared/session-model-state";
import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state";
import { deleteSessionTools } from "../shared/session-tools-store";
import { lspManager } from "../tools";
@@ -165,6 +166,30 @@ export function createEventHandler(args: {
const lastHandledRetryStatusKey = new Map<string, string>();
const lastKnownModelBySession = new Map<string, { providerID: string; modelID: string }>();
const resolveFallbackProviderID = (sessionID: string, providerHint?: string): string => {
const sessionModel = getSessionModel(sessionID);
if (sessionModel?.providerID) {
return sessionModel.providerID;
}
const lastKnownModel = lastKnownModelBySession.get(sessionID);
if (lastKnownModel?.providerID) {
return lastKnownModel.providerID;
}
const normalizedProviderHint = providerHint?.trim();
if (normalizedProviderHint) {
return normalizedProviderHint;
}
const connectedProvider = readConnectedProvidersCache()?.[0];
if (connectedProvider) {
return connectedProvider;
}
return "opencode";
};
const dispatchToHooks = async (input: EventInput): Promise<void> => {
await Promise.resolve(hooks.autoUpdateChecker?.event?.(input));
await Promise.resolve(hooks.claudeCodeHooks?.event?.(input));
@@ -361,7 +386,10 @@ export function createEventHandler(args: {
}
if (agentName) {
const currentProvider = (info?.providerID as string | undefined) ?? "opencode";
const currentProvider = resolveFallbackProviderID(
sessionID,
info?.providerID as string | undefined,
);
const rawModel = (info?.modelID as string | undefined) ?? "claude-opus-4-6";
const currentModel = normalizeFallbackModelID(rawModel);
applyUserConfiguredFallbackChain(sessionID, agentName, currentProvider, args.pluginConfig);
@@ -418,7 +446,7 @@ export function createEventHandler(args: {
if (agentName) {
const parsed = extractProviderModelFromErrorMessage(retryMessage);
const lastKnown = lastKnownModelBySession.get(sessionID);
const currentProvider = parsed.providerID ?? lastKnown?.providerID ?? "opencode";
const currentProvider = resolveFallbackProviderID(sessionID, parsed.providerID);
let currentModel = parsed.modelID ?? lastKnown?.modelID ?? "claude-opus-4-6";
currentModel = normalizeFallbackModelID(currentModel);
applyUserConfiguredFallbackChain(sessionID, agentName, currentProvider, args.pluginConfig);
@@ -490,7 +518,10 @@ export function createEventHandler(args: {
if (agentName) {
const parsed = extractProviderModelFromErrorMessage(errorMessage);
const currentProvider = (props?.providerID as string) || parsed.providerID || "opencode";
const currentProvider = resolveFallbackProviderID(
sessionID,
(props?.providerID as string | undefined) || parsed.providerID,
);
let currentModel = (props?.modelID as string) || parsed.modelID || "claude-opus-4-6";
currentModel = normalizeFallbackModelID(currentModel);
applyUserConfiguredFallbackChain(sessionID, agentName, currentProvider, args.pluginConfig);

View File

@@ -32,6 +32,21 @@ describe("fallback-chain-from-models", () => {
})
})
test("uses opencode as absolute fallback provider when context provider is missing", () => {
//#given
const fallbackModel = "gemini-3-flash"
//#when
const parsed = parseFallbackModelEntry(fallbackModel, undefined)
//#then
expect(parsed).toEqual({
providers: ["opencode"],
model: "gemini-3-flash",
variant: undefined,
})
})
test("builds fallback chain from normalized fallback_models input", () => {
//#given
const fallbackModels = ["quotio/kimi-k2.5", "gpt-5.2 medium"]

View File

@@ -39,13 +39,15 @@ function parseVariantFromModel(rawModel: string): { modelID: string; variant?: s
export function parseFallbackModelEntry(
model: string,
defaultProviderID: string,
contextProviderID: string | undefined,
defaultProviderID = "opencode",
): FallbackEntry | undefined {
const trimmed = model.trim()
if (!trimmed) return undefined
const parts = trimmed.split("/")
const providerID = parts.length >= 2 ? parts[0].trim() : defaultProviderID
const providerID =
parts.length >= 2 ? parts[0].trim() : (contextProviderID?.trim() || defaultProviderID)
const rawModelID = parts.length >= 2 ? parts.slice(1).join("/").trim() : trimmed
if (!providerID || !rawModelID) return undefined
@@ -61,13 +63,14 @@ export function parseFallbackModelEntry(
export function buildFallbackChainFromModels(
fallbackModels: string | string[] | undefined,
defaultProviderID: string,
contextProviderID: string | undefined,
defaultProviderID = "opencode",
): FallbackEntry[] | undefined {
const normalized = normalizeFallbackModels(fallbackModels)
if (!normalized || normalized.length === 0) return undefined
const parsed = normalized
.map((model) => parseFallbackModelEntry(model, defaultProviderID))
.map((model) => parseFallbackModelEntry(model, contextProviderID, defaultProviderID))
.filter((entry): entry is FallbackEntry => entry !== undefined)
if (parsed.length === 0) return undefined