diff --git a/src/plugin/hooks/create-model-fallback-session-hook.ts b/src/plugin/hooks/create-model-fallback-session-hook.ts new file mode 100644 index 000000000..96bac6d13 --- /dev/null +++ b/src/plugin/hooks/create-model-fallback-session-hook.ts @@ -0,0 +1,131 @@ +import type { OhMyOpenCodeConfig, HookName } from "../../config" + +import { createModelFallbackHook } from "../../hooks" +import { normalizeSDKResponse } from "../../shared" + +import { resolveModelFallbackEnabled } from "./model-fallback-config" + +type SafeHook = (hookName: HookName, factory: () => THook) => THook | null + +type ModelFallbackSessionContext = { + directory: string + client: { + session: { + get: (input: { path: { id: string } }) => Promise + update: (input: { + path: { id: string } + body: { title: string } + query: { directory: string } + }) => Promise + } + tui: { + showToast: (input: { + body: { + title: string + message: string + variant: "success" | "error" | "info" | "warning" + duration: number + } + }) => Promise + } + } +} + +function createFallbackTitleUpdater( + ctx: ModelFallbackSessionContext, + enabled: boolean, +): + | ((input: { + sessionID: string + providerID: string + modelID: string + variant?: string + }) => Promise) + | undefined { + if (!enabled) { + return undefined + } + + const fallbackTitleMaxEntries = 200 + const fallbackTitleState = new Map() + + return async (input) => { + const key = `${input.providerID}/${input.modelID}${input.variant ? `:${input.variant}` : ""}` + const existing = fallbackTitleState.get(input.sessionID) ?? {} + if (existing.lastKey === key) { + return + } + + if (!existing.baseTitle) { + const sessionResp = await ctx.client.session.get({ path: { id: input.sessionID } }).catch(() => null) + const sessionInfo = sessionResp + ? normalizeSDKResponse(sessionResp, null as { title?: string } | null, { + preferResponseOnMissingData: true, + }) + : null + const rawTitle = sessionInfo?.title + if (typeof rawTitle === "string" && rawTitle.length > 0) { + existing.baseTitle = rawTitle.replace(/\s*\[fallback:[^\]]+\]$/i, "").trim() + } else { + existing.baseTitle = "Session" + } + } + + const variantLabel = input.variant ? ` ${input.variant}` : "" + const newTitle = `${existing.baseTitle} [fallback: ${input.providerID}/${input.modelID}${variantLabel}]` + + await ctx.client.session + .update({ + path: { id: input.sessionID }, + body: { title: newTitle }, + query: { directory: ctx.directory }, + }) + .catch(() => {}) + + existing.lastKey = key + fallbackTitleState.set(input.sessionID, existing) + if (fallbackTitleState.size > fallbackTitleMaxEntries) { + const oldestKey = fallbackTitleState.keys().next().value + if (oldestKey) { + fallbackTitleState.delete(oldestKey) + } + } + } +} + +export function createConfiguredModelFallbackHook(args: { + ctx: ModelFallbackSessionContext + pluginConfig: OhMyOpenCodeConfig + isHookEnabled: (hookName: HookName) => boolean + safeHook: SafeHook +}): ReturnType | null { + const { ctx, pluginConfig, isHookEnabled, safeHook } = args + const isModelFallbackEnabled = resolveModelFallbackEnabled(pluginConfig) + + if (!isModelFallbackEnabled || !isHookEnabled("model-fallback")) { + return null + } + + const onApplied = createFallbackTitleUpdater( + ctx, + pluginConfig.experimental?.model_fallback_title ?? false, + ) + + return safeHook("model-fallback", () => + createModelFallbackHook({ + toast: async ({ title, message, variant, duration }) => { + await ctx.client.tui + .showToast({ + body: { + title, + message, + variant: variant ?? "warning", + duration: duration ?? 5000, + }, + }) + .catch(() => {}) + }, + onApplied, + }), + ) +} diff --git a/src/plugin/hooks/create-session-hooks.ts b/src/plugin/hooks/create-session-hooks.ts index daa5e4ff5..3dcc0d7d6 100644 --- a/src/plugin/hooks/create-session-hooks.ts +++ b/src/plugin/hooks/create-session-hooks.ts @@ -1,13 +1,10 @@ import type { OhMyOpenCodeConfig, HookName } from "../../config" import type { ModelCacheState } from "../../plugin-state" -import type { PluginContext } from "../types" - import { createContextWindowMonitorHook, createSessionRecoveryHook, createSessionNotification, createThinkModeHook, - createModelFallbackHook, createAnthropicContextWindowLimitRecoveryHook, createAutoUpdateCheckerHook, createAgentUsageReminderHook, @@ -31,10 +28,10 @@ import { detectExternalNotificationPlugin, getNotificationConflictWarning, log, - normalizeSDKResponse, } from "../../shared" import { safeCreateHook } from "../../shared/safe-create-hook" import { sessionExists } from "../../tools" +import { createConfiguredModelFallbackHook } from "./create-model-fallback-session-hook" export type SessionHooks = { contextWindowMonitor: ReturnType | null @@ -42,7 +39,7 @@ export type SessionHooks = { sessionRecovery: ReturnType | null sessionNotification: ReturnType | null thinkMode: ReturnType | null - modelFallback: ReturnType | null + modelFallback: ReturnType anthropicContextWindowLimitRecovery: ReturnType | null autoUpdateChecker: ReturnType | null agentUsageReminder: ReturnType | null @@ -63,7 +60,7 @@ export type SessionHooks = { } export function createSessionHooks(args: { - ctx: PluginContext + ctx: Parameters[0] pluginConfig: OhMyOpenCodeConfig modelCacheState: ModelCacheState isHookEnabled: (hookName: HookName) => boolean @@ -105,73 +102,12 @@ export function createSessionHooks(args: { ? safeHook("think-mode", () => createThinkModeHook()) : null - const enableFallbackTitle = pluginConfig.experimental?.model_fallback_title ?? false - const fallbackTitleMaxEntries = 200 - const fallbackTitleState = new Map() - const updateFallbackTitle = async (input: { - sessionID: string - providerID: string - modelID: string - variant?: string - }) => { - if (!enableFallbackTitle) return - const key = `${input.providerID}/${input.modelID}${input.variant ? `:${input.variant}` : ""}` - const existing = fallbackTitleState.get(input.sessionID) ?? {} - if (existing.lastKey === key) return - - if (!existing.baseTitle) { - const sessionResp = await ctx.client.session.get({ path: { id: input.sessionID } }).catch(() => null) - const sessionInfo = sessionResp - ? normalizeSDKResponse(sessionResp, null as { title?: string } | null, { preferResponseOnMissingData: true }) - : null - const rawTitle = sessionInfo?.title - if (typeof rawTitle === "string" && rawTitle.length > 0) { - existing.baseTitle = rawTitle.replace(/\s*\[fallback:[^\]]+\]$/i, "").trim() - } else { - existing.baseTitle = "Session" - } - } - - const variantLabel = input.variant ? ` ${input.variant}` : "" - const newTitle = `${existing.baseTitle} [fallback: ${input.providerID}/${input.modelID}${variantLabel}]` - - await ctx.client.session - .update({ - path: { id: input.sessionID }, - body: { title: newTitle }, - query: { directory: ctx.directory }, - }) - .catch(() => {}) - - existing.lastKey = key - fallbackTitleState.set(input.sessionID, existing) - if (fallbackTitleState.size > fallbackTitleMaxEntries) { - const oldestKey = fallbackTitleState.keys().next().value - if (oldestKey) fallbackTitleState.delete(oldestKey) - } - } - - // Model fallback hook (configurable via model_fallback config + disabled_hooks) - // This handles automatic model switching when model errors occur - const isModelFallbackConfigEnabled = pluginConfig.model_fallback ?? false - const modelFallback = isModelFallbackConfigEnabled && isHookEnabled("model-fallback") - ? safeHook("model-fallback", () => - createModelFallbackHook({ - toast: async ({ title, message, variant, duration }) => { - await ctx.client.tui - .showToast({ - body: { - title, - message, - variant: variant ?? "warning", - duration: duration ?? 5000, - }, - }) - .catch(() => {}) - }, - onApplied: enableFallbackTitle ? updateFallbackTitle : undefined, - })) - : null + const modelFallback = createConfiguredModelFallbackHook({ + ctx, + pluginConfig, + isHookEnabled, + safeHook, + }) const anthropicContextWindowLimitRecovery = isHookEnabled("anthropic-context-window-limit-recovery") ? safeHook("anthropic-context-window-limit-recovery", () => diff --git a/src/plugin/hooks/model-fallback-config.test.ts b/src/plugin/hooks/model-fallback-config.test.ts new file mode 100644 index 000000000..283b7946b --- /dev/null +++ b/src/plugin/hooks/model-fallback-config.test.ts @@ -0,0 +1,63 @@ +declare const require: (name: string) => any +const { describe, expect, test } = require("bun:test") + +import type { OhMyOpenCodeConfig } from "../../config" + +import { + hasConfiguredModelFallbacks, + resolveModelFallbackEnabled, +} from "./model-fallback-config" + +describe("model-fallback-config", () => { + test("detects agent fallback_models configuration", () => { + //#given + const pluginConfig: OhMyOpenCodeConfig = { + agents: { + sisyphus: { + fallback_models: ["openai/gpt-5.2", "anthropic/claude-opus-4-6"], + }, + }, + } + + //#when + const result = hasConfiguredModelFallbacks(pluginConfig) + + //#then + expect(result).toBe(true) + }) + + test("auto-enables model fallback when category fallback_models are configured", () => { + //#given + const pluginConfig: OhMyOpenCodeConfig = { + categories: { + quick: { + fallback_models: ["openai/gpt-5.2"], + }, + }, + } + + //#when + const result = resolveModelFallbackEnabled(pluginConfig) + + //#then + expect(result).toBe(true) + }) + + test("keeps model fallback disabled when explicitly turned off", () => { + //#given + const pluginConfig: OhMyOpenCodeConfig = { + model_fallback: false, + agents: { + sisyphus: { + fallback_models: ["openai/gpt-5.2"], + }, + }, + } + + //#when + const result = resolveModelFallbackEnabled(pluginConfig) + + //#then + expect(result).toBe(false) + }) +}) diff --git a/src/plugin/hooks/model-fallback-config.ts b/src/plugin/hooks/model-fallback-config.ts new file mode 100644 index 000000000..a63b3ae58 --- /dev/null +++ b/src/plugin/hooks/model-fallback-config.ts @@ -0,0 +1,33 @@ +import type { OhMyOpenCodeConfig } from "../../config" + +import { log, normalizeFallbackModels } from "../../shared" + +type FallbackModelsConfig = { + fallback_models?: string | string[] +} + +function hasFallbackModels(config: FallbackModelsConfig | undefined): boolean { + return (normalizeFallbackModels(config?.fallback_models)?.length ?? 0) > 0 +} + +export function hasConfiguredModelFallbacks(pluginConfig: OhMyOpenCodeConfig): boolean { + const agentConfigs = Object.values(pluginConfig.agents ?? {}) + if (agentConfigs.some(hasFallbackModels)) { + return true + } + + const categoryConfigs = Object.values(pluginConfig.categories ?? {}) + return categoryConfigs.some(hasFallbackModels) +} + +export function resolveModelFallbackEnabled(pluginConfig: OhMyOpenCodeConfig): boolean { + const hasConfiguredFallbacks = hasConfiguredModelFallbacks(pluginConfig) + + if (pluginConfig.model_fallback === false && hasConfiguredFallbacks) { + log( + "model_fallback is disabled while fallback_models are configured; set model_fallback=true to keep provider fallback retries enabled", + ) + } + + return pluginConfig.model_fallback ?? hasConfiguredFallbacks +} diff --git a/src/shared/model-error-classifier.test.ts b/src/shared/model-error-classifier.test.ts index d359c26d3..a9fa2fb99 100644 --- a/src/shared/model-error-classifier.test.ts +++ b/src/shared/model-error-classifier.test.ts @@ -40,6 +40,28 @@ describe("model-error-classifier", () => { expect(result).toBe(true) }) + test("treats FreeUsageLimitError names as retryable", () => { + //#given + const error = { name: "FreeUsageLimitError" } + + //#when + const result = shouldRetryError(error) + + //#then + expect(result).toBe(true) + }) + + test("treats free tier usage limit messages as retryable", () => { + //#given + const error = { message: "Free tier daily limit reached for this provider" } + + //#when + const result = shouldRetryError(error) + + //#then + expect(result).toBe(true) + }) + test("selectFallbackProvider prefers first connected provider in preference order", () => { //#given readConnectedProvidersCacheMock.mockReturnValue(["anthropic", "nvidia"]) diff --git a/src/shared/model-error-classifier.ts b/src/shared/model-error-classifier.ts index 22d5606c7..ecfd8f036 100644 --- a/src/shared/model-error-classifier.ts +++ b/src/shared/model-error-classifier.ts @@ -6,13 +6,14 @@ import { readConnectedProvidersCache } from "./connected-providers-cache" * These errors completely halt the action loop and should trigger fallback retry. */ const RETRYABLE_ERROR_NAMES = new Set([ - "ProviderModelNotFoundError", - "RateLimitError", - "QuotaExceededError", - "InsufficientCreditsError", - "ModelUnavailableError", - "ProviderConnectionError", - "AuthenticationError", + "providermodelnotfounderror", + "ratelimiterror", + "quotaexceedederror", + "insufficientcreditserror", + "modelunavailableerror", + "providerconnectionerror", + "authenticationerror", + "freeusagelimiterror", ]) /** @@ -20,24 +21,28 @@ const RETRYABLE_ERROR_NAMES = new Set([ * These errors are typically user-induced or fixable without switching models. */ const NON_RETRYABLE_ERROR_NAMES = new Set([ - "MessageAbortedError", - "PermissionDeniedError", - "ContextLengthError", - "TimeoutError", - "ValidationError", - "SyntaxError", - "UserError", + "messageabortederror", + "permissiondeniederror", + "contextlengtherror", + "timeouterror", + "validationerror", + "syntaxerror", + "usererror", ]) /** * Message patterns that indicate a retryable error even without a known error name. */ -const RETRYABLE_MESSAGE_PATTERNS = [ +const RETRYABLE_MESSAGE_PATTERNS: Array = [ "rate_limit", "rate limit", "quota", "quota will reset after", "usage limit has been reached", + /free\s+usage/i, + /free\s+tier/i, + /daily\s+limit/i, + /limit\s+reached/i, "all credentials for model", "cooling down", "exhausted your capacity", @@ -77,6 +82,11 @@ function hasProviderAutoRetrySignal(message: string): boolean { return AUTO_RETRY_GATE_PATTERNS.some((pattern) => message.includes(pattern)) } +function matchesRetryableMessagePattern(message: string): boolean { + return RETRYABLE_MESSAGE_PATTERNS.some((pattern) => + typeof pattern === "string" ? message.includes(pattern) : pattern.test(message)) +} + export interface ErrorInfo { name?: string message?: string @@ -89,12 +99,14 @@ export interface ErrorInfo { export function isRetryableModelError(error: ErrorInfo): boolean { // If we have an error name, check against known lists if (error.name) { + const normalizedErrorName = error.name.toLowerCase() + // Explicit non-retryable takes precedence - if (NON_RETRYABLE_ERROR_NAMES.has(error.name)) { + if (NON_RETRYABLE_ERROR_NAMES.has(normalizedErrorName)) { return false } // Check if it's a known retryable error - if (RETRYABLE_ERROR_NAMES.has(error.name)) { + if (RETRYABLE_ERROR_NAMES.has(normalizedErrorName)) { return true } } @@ -104,7 +116,7 @@ export function isRetryableModelError(error: ErrorInfo): boolean { if (hasProviderAutoRetrySignal(msg)) { return true } - return RETRYABLE_MESSAGE_PATTERNS.some((pattern) => msg.includes(pattern)) + return matchesRetryableMessagePattern(msg) } /**