diff --git a/src/tools/delegate-task/model-selection.test.ts b/src/tools/delegate-task/model-selection.test.ts index 857efa7ad..18d0d5f82 100644 --- a/src/tools/delegate-task/model-selection.test.ts +++ b/src/tools/delegate-task/model-selection.test.ts @@ -1,5 +1,4 @@ -declare const require: (name: string) => any -const { describe, test, expect, beforeEach, afterEach, spyOn, mock } = require("bun:test") +import { afterEach, beforeEach, describe, expect, mock, spyOn, test } from "bun:test" import { resolveModelForDelegateTask } from "./model-selection" import * as connectedProvidersCache from "../../shared/connected-providers-cache" @@ -105,6 +104,26 @@ describe("resolveModelForDelegateTask", () => { expect(result!.model).toBe("anthropic/claude-sonnet-4-6") }) }) + + describe("#when user fallback models include variant syntax", () => { + test("#then resolves a parenthesized variant against the base available model", () => { + const result = resolveModelForDelegateTask({ + userFallbackModels: ["openai/gpt-5.2(high)"], + availableModels: new Set(["openai/gpt-5.2"]), + }) + + expect(result).toEqual({ model: "openai/gpt-5.2", variant: "high" }) + }) + + test("#then resolves a space-separated variant against the base available model", () => { + const result = resolveModelForDelegateTask({ + userFallbackModels: ["gpt-5.2 medium"], + availableModels: new Set(["openai/gpt-5.2"]), + }) + + expect(result).toEqual({ model: "openai/gpt-5.2", variant: "medium" }) + }) + }) }) describe("#given only connected providers cache exists (no provider-models cache)", () => { diff --git a/src/tools/delegate-task/model-selection.ts b/src/tools/delegate-task/model-selection.ts index 92a33db56..6beb24192 100644 --- a/src/tools/delegate-task/model-selection.ts +++ b/src/tools/delegate-task/model-selection.ts @@ -3,6 +3,7 @@ import { normalizeModel } from "../../shared/model-normalization" import { fuzzyMatchModel } from "../../shared/model-availability" import { transformModelForProvider } from "../../shared/provider-model-id-transform" import { hasConnectedProvidersCache, hasProviderModelsCache } from "../../shared/connected-providers-cache" +import { parseModelString, parseVariantFromModelID } from "./model-string-parser" function isExplicitHighModel(model: string): boolean { return /(?:^|\/)[^/]+-high$/.test(model) @@ -12,6 +13,36 @@ function getExplicitHighBaseModel(model: string): string | null { return isExplicitHighModel(model) ? model.replace(/-high$/, "") : null } +function parseUserFallbackModel(fallbackModel: string): { + baseModel: string + providerHint?: string[] + variant?: string +} | undefined { + const normalizedFallback = normalizeModel(fallbackModel) + if (!normalizedFallback) { + return undefined + } + + const parsedFullModel = parseModelString(normalizedFallback) + if (parsedFullModel) { + return { + baseModel: `${parsedFullModel.providerID}/${parsedFullModel.modelID}`, + providerHint: [parsedFullModel.providerID], + variant: parsedFullModel.variant, + } + } + + const parsedModel = parseVariantFromModelID(normalizedFallback) + if (!parsedModel.modelID) { + return undefined + } + + return { + baseModel: parsedModel.modelID, + variant: parsedModel.variant, + } +} + export function resolveModelForDelegateTask(input: { userModel?: string @@ -55,20 +86,18 @@ export function resolveModelForDelegateTask(input: { const userFallbackModels = input.userFallbackModels if (userFallbackModels && userFallbackModels.length > 0) { if (input.availableModels.size === 0) { - const first = normalizeModel(userFallbackModels[0]) + const first = userFallbackModels[0] ? parseUserFallbackModel(userFallbackModels[0]) : undefined if (first) { - return { model: first } + return { model: first.baseModel, variant: first.variant } } } else { for (const fallbackModel of userFallbackModels) { - const normalizedFallback = normalizeModel(fallbackModel) - if (!normalizedFallback) continue + const parsedFallback = parseUserFallbackModel(fallbackModel) + if (!parsedFallback) continue - const parts = normalizedFallback.split("/") - const providerHint = parts.length >= 2 ? [parts[0]] : undefined - const match = fuzzyMatchModel(normalizedFallback, input.availableModels, providerHint) + const match = fuzzyMatchModel(parsedFallback.baseModel, input.availableModels, parsedFallback.providerHint) if (match) { - return { model: match } + return { model: match, variant: parsedFallback.variant } } } } diff --git a/src/tools/delegate-task/model-string-parser.ts b/src/tools/delegate-task/model-string-parser.ts index 061a97876..d86f23324 100644 --- a/src/tools/delegate-task/model-string-parser.ts +++ b/src/tools/delegate-task/model-string-parser.ts @@ -9,7 +9,7 @@ const KNOWN_VARIANTS = new Set([ "thinking", ]) -function parseVariantFromModelID(rawModelID: string): { modelID: string; variant?: string } { +export function parseVariantFromModelID(rawModelID: string): { modelID: string; variant?: string } { const trimmedModelID = rawModelID.trim() if (!trimmedModelID) { return { modelID: "" }