From 613ef8eee8a175ba2a5c57ba8c29929b8d7ba8b7 Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Wed, 25 Mar 2026 15:09:25 +0100 Subject: [PATCH] fix(model-capabilities): harden runtime capability handling --- src/shared/connected-providers-cache.test.ts | 34 +++++ src/shared/connected-providers-cache.ts | 2 +- src/shared/model-capabilities-cache.test.ts | 31 ++++ src/shared/model-capabilities-cache.ts | 25 ++-- src/shared/model-capabilities.test.ts | 62 ++++++++ src/shared/model-capabilities.ts | 140 ++++++++++++++++-- src/shared/model-capability-heuristics.ts | 4 +- .../model-settings-compatibility.test.ts | 21 +++ 8 files changed, 296 insertions(+), 23 deletions(-) diff --git a/src/shared/connected-providers-cache.test.ts b/src/shared/connected-providers-cache.test.ts index cd2573b93..73c905d25 100644 --- a/src/shared/connected-providers-cache.test.ts +++ b/src/shared/connected-providers-cache.test.ts @@ -229,4 +229,38 @@ describe("updateConnectedProvidersCache", () => { limit: { output: 128000 }, }) }) + + test("keeps normalized fallback ids when raw metadata id is not a string", async () => { + const mockClient = { + provider: { + list: async () => ({ + data: { + connected: ["openai"], + all: [ + { + id: "openai", + models: { + "o3-mini": { + id: 123, + name: "o3-mini", + }, + }, + }, + ], + }, + }), + }, + } + + await testCacheStore.updateConnectedProvidersCache(mockClient) + const cache = testCacheStore.readProviderModelsCache() + + expect(cache?.models.openai).toEqual([ + { id: "o3-mini", name: "o3-mini" }, + ]) + expect(findProviderModelMetadata("openai", "o3-mini", cache)).toEqual({ + id: "o3-mini", + name: "o3-mini", + }) + }) }) diff --git a/src/shared/connected-providers-cache.ts b/src/shared/connected-providers-cache.ts index 61003cd38..cf17852cd 100644 --- a/src/shared/connected-providers-cache.ts +++ b/src/shared/connected-providers-cache.ts @@ -198,8 +198,8 @@ export function createConnectedProvidersCacheStore( : modelID return { - id: normalizedID, ...rawMetadata, + id: normalizedID, } satisfies ModelMetadata }) if (modelMetadata.length > 0) { diff --git a/src/shared/model-capabilities-cache.test.ts b/src/shared/model-capabilities-cache.test.ts index 2575577c3..0773a5fe0 100644 --- a/src/shared/model-capabilities-cache.test.ts +++ b/src/shared/model-capabilities-cache.test.ts @@ -97,6 +97,37 @@ describe("model-capabilities-cache", () => { }) }) + test("merges repeated snapshot entries without materializing empty optional objects", () => { + const raw = { + openai: { + models: { + "gpt-5.4": { + id: "gpt-5.4", + family: "gpt", + }, + }, + }, + alias: { + models: { + "gpt-5.4-preview": { + id: "gpt-5.4", + reasoning: true, + }, + }, + }, + } + + const snapshot = buildModelCapabilitiesSnapshotFromModelsDev(raw) + + expect(snapshot.models["gpt-5.4"]).toEqual({ + id: "gpt-5.4", + family: "gpt", + reasoning: true, + }) + expect(snapshot.models["gpt-5.4"]).not.toHaveProperty("modalities") + expect(snapshot.models["gpt-5.4"]).not.toHaveProperty("limit") + }) + test("refresh writes cache and preserves unrelated files in the cache directory", async () => { //#given const sentinelPath = join(testCacheDir, "keep-me.json") diff --git a/src/shared/model-capabilities-cache.ts b/src/shared/model-capabilities-cache.ts index c3339cd8d..bff841c68 100644 --- a/src/shared/model-capabilities-cache.ts +++ b/src/shared/model-capabilities-cache.ts @@ -8,7 +8,7 @@ export const MODELS_DEV_SOURCE_URL = "https://models.dev/api.json" const MODEL_CAPABILITIES_CACHE_FILE = "model-capabilities.json" function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null + return typeof value === "object" && value !== null && !Array.isArray(value) } function readBoolean(value: unknown): boolean | undefined { @@ -84,17 +84,24 @@ function mergeSnapshotEntries( return incoming } + const mergedModalities = existing.modalities || incoming.modalities + ? { + ...existing.modalities, + ...incoming.modalities, + } + : undefined + const mergedLimit = existing.limit || incoming.limit + ? { + ...existing.limit, + ...incoming.limit, + } + : undefined + return { ...existing, ...incoming, - modalities: { - ...existing.modalities, - ...incoming.modalities, - }, - limit: { - ...existing.limit, - ...incoming.limit, - }, + ...(mergedModalities ? { modalities: mergedModalities } : {}), + ...(mergedLimit ? { limit: mergedLimit } : {}), } } diff --git a/src/shared/model-capabilities.test.ts b/src/shared/model-capabilities.test.ts index 172e7a523..82b5ea649 100644 --- a/src/shared/model-capabilities.test.ts +++ b/src/shared/model-capabilities.test.ts @@ -81,6 +81,53 @@ describe("getModelCapabilities", () => { }) }) + test("reads structured runtime capabilities from the SDK v2 shape", () => { + const result = getModelCapabilities({ + providerID: "openai", + modelID: "gpt-5.4", + runtimeModel: { + capabilities: { + reasoning: true, + temperature: false, + toolcall: true, + input: { + text: true, + image: true, + }, + output: { + text: true, + }, + }, + }, + bundledSnapshot, + }) + + expect(result).toMatchObject({ + canonicalModelID: "gpt-5.4", + reasoning: true, + supportsThinking: true, + supportsTemperature: false, + toolCall: true, + modalities: { + input: ["text", "image"], + output: ["text"], + }, + }) + }) + + test("accepts runtime variant arrays without corrupting them into numeric keys", () => { + const result = getModelCapabilities({ + providerID: "openai", + modelID: "gpt-5.4", + runtimeModel: { + variants: ["low", "medium", "high", "xhigh"], + }, + bundledSnapshot, + }) + + expect(result.variants).toEqual(["low", "medium", "high", "xhigh"]) + }) + test("normalizes thinking suffix aliases before snapshot lookup", () => { const result = getModelCapabilities({ providerID: "anthropic", @@ -156,4 +203,19 @@ describe("getModelCapabilities", () => { reasoningEfforts: ["none", "minimal", "low", "medium", "high"], }) }) + + test("detects prefixed o-series model IDs through the heuristic fallback", () => { + const result = getModelCapabilities({ + providerID: "azure-openai", + modelID: "openai/o3-mini", + bundledSnapshot, + }) + + expect(result).toMatchObject({ + canonicalModelID: "openai/o3-mini", + family: "openai-reasoning", + variants: ["low", "medium", "high"], + reasoningEfforts: ["none", "minimal", "low", "medium", "high"], + }) + }) }) diff --git a/src/shared/model-capabilities.ts b/src/shared/model-capabilities.ts index 887d15286..cead7f00e 100644 --- a/src/shared/model-capabilities.ts +++ b/src/shared/model-capabilities.ts @@ -72,7 +72,7 @@ const MODEL_ID_OVERRIDES: Record = { } function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null + return typeof value === "object" && value !== null && !Array.isArray(value) } function normalizeLookupModelID(modelID: string): string { @@ -97,6 +97,11 @@ function readStringArray(value: unknown): string[] | undefined { } function normalizeVariantKeys(value: unknown): string[] | undefined { + const arrayVariants = readStringArray(value) + if (arrayVariants) { + return arrayVariants.map((variant) => variant.toLowerCase()) + } + if (!isRecord(value)) { return undefined } @@ -105,13 +110,30 @@ function normalizeVariantKeys(value: unknown): string[] | undefined { return variants.length > 0 ? variants : undefined } +function readModalityKeys(value: unknown): string[] | undefined { + const stringArray = readStringArray(value) + if (stringArray) { + return stringArray.map((entry) => entry.toLowerCase()) + } + + if (!isRecord(value)) { + return undefined + } + + const enabled = Object.entries(value) + .filter(([, supported]) => supported === true) + .map(([modality]) => modality.toLowerCase()) + + return enabled.length > 0 ? enabled : undefined +} + function normalizeModalities(value: unknown): ModelCapabilities["modalities"] | undefined { if (!isRecord(value)) { return undefined } - const input = readStringArray(value.input) - const output = readStringArray(value.output) + const input = readModalityKeys(value.input) + const output = readModalityKeys(value.output) if (!input && !output) { return undefined @@ -145,12 +167,18 @@ function getOverride(modelID: string): ModelCapabilityOverride | undefined { return MODEL_ID_OVERRIDES[normalizeLookupModelID(modelID)] } +function readRuntimeModelCapabilities(runtimeModel: Record | undefined): Record | undefined { + return isRecord(runtimeModel?.capabilities) ? runtimeModel.capabilities : undefined +} + function readRuntimeModelLimitOutput(runtimeModel: Record | undefined): number | undefined { if (!runtimeModel) { return undefined } - const limit = runtimeModel.limit + const limit = isRecord(runtimeModel.limit) + ? runtimeModel.limit + : readRuntimeModelCapabilities(runtimeModel)?.limit if (!isRecord(limit)) { return undefined } @@ -163,11 +191,101 @@ function readRuntimeModelBoolean(runtimeModel: Record | undefin return undefined } + const runtimeCapabilities = readRuntimeModelCapabilities(runtimeModel) + for (const key of keys) { const value = runtimeModel[key] if (typeof value === "boolean") { return value } + + const capabilityValue = runtimeCapabilities?.[key] + if (typeof capabilityValue === "boolean") { + return capabilityValue + } + } + + return undefined +} + +function readRuntimeModelModalities(runtimeModel: Record | undefined): ModelCapabilities["modalities"] | undefined { + if (!runtimeModel) { + return undefined + } + + const rootModalities = normalizeModalities(runtimeModel.modalities) + if (rootModalities) { + return rootModalities + } + + const runtimeCapabilities = readRuntimeModelCapabilities(runtimeModel) + if (!runtimeCapabilities) { + return undefined + } + + const nestedModalities = normalizeModalities(runtimeCapabilities.modalities) + if (nestedModalities) { + return nestedModalities + } + + const capabilityModalities = normalizeModalities(runtimeCapabilities) + if (capabilityModalities) { + return capabilityModalities + } + + return undefined +} + +function readRuntimeModelVariants(runtimeModel: Record | undefined): string[] | undefined { + if (!runtimeModel) { + return undefined + } + + const rootVariants = normalizeVariantKeys(runtimeModel.variants) + if (rootVariants) { + return rootVariants + } + + const runtimeCapabilities = readRuntimeModelCapabilities(runtimeModel) + if (!runtimeCapabilities) { + return undefined + } + + return normalizeVariantKeys(runtimeCapabilities.variants) +} + +function readRuntimeModelTopPSupport(runtimeModel: Record | undefined): boolean | undefined { + return readRuntimeModelBoolean(runtimeModel, ["topP", "top_p"]) +} + +function readRuntimeModelToolCallSupport(runtimeModel: Record | undefined): boolean | undefined { + return readRuntimeModelBoolean(runtimeModel, ["toolCall", "tool_call", "toolcall"]) +} + +function readRuntimeModelReasoningSupport(runtimeModel: Record | undefined): boolean | undefined { + return readRuntimeModelBoolean(runtimeModel, ["reasoning"]) +} + +function readRuntimeModelTemperatureSupport(runtimeModel: Record | undefined): boolean | undefined { + return readRuntimeModelBoolean(runtimeModel, ["temperature"]) +} + +function readRuntimeModelThinkingSupport(runtimeModel: Record | undefined): boolean | undefined { + const capabilityValue = readRuntimeModelReasoningSupport(runtimeModel) + if (capabilityValue !== undefined) { + return capabilityValue + } + + const runtimeCapabilities = readRuntimeModelCapabilities(runtimeModel) + if (!runtimeCapabilities) { + return undefined + } + + for (const key of ["thinking", "supportsThinking"] as const) { + const value = runtimeCapabilities[key] + if (typeof value === "boolean") { + return value + } } return undefined @@ -194,7 +312,7 @@ export function getModelCapabilities(input: GetModelCapabilitiesInput): ModelCap const bundledSnapshot = input.bundledSnapshot ?? bundledModelCapabilitiesSnapshot const snapshotEntry = runtimeSnapshot?.models?.[canonicalModelID] ?? bundledSnapshot.models[canonicalModelID] const heuristicFamily = detectHeuristicModelFamily(canonicalModelID) - const runtimeVariants = normalizeVariantKeys(runtimeModel?.variants) + const runtimeVariants = readRuntimeModelVariants(runtimeModel) return { requestedModelID, @@ -202,27 +320,27 @@ export function getModelCapabilities(input: GetModelCapabilitiesInput): ModelCap family: snapshotEntry?.family ?? heuristicFamily?.family, variants: runtimeVariants ?? override?.variants ?? heuristicFamily?.variants, reasoningEfforts: override?.reasoningEfforts ?? heuristicFamily?.reasoningEfforts, - reasoning: readRuntimeModelBoolean(runtimeModel, ["reasoning"]) ?? snapshotEntry?.reasoning, + reasoning: readRuntimeModelReasoningSupport(runtimeModel) ?? snapshotEntry?.reasoning, supportsThinking: override?.supportsThinking ?? heuristicFamily?.supportsThinking - ?? readRuntimeModelBoolean(runtimeModel, ["reasoning"]) + ?? readRuntimeModelThinkingSupport(runtimeModel) ?? snapshotEntry?.reasoning, supportsTemperature: - readRuntimeModelBoolean(runtimeModel, ["temperature"]) + readRuntimeModelTemperatureSupport(runtimeModel) ?? override?.supportsTemperature ?? snapshotEntry?.temperature, supportsTopP: - readRuntimeModelBoolean(runtimeModel, ["topP", "top_p"]) + readRuntimeModelTopPSupport(runtimeModel) ?? override?.supportsTopP, maxOutputTokens: readRuntimeModelLimitOutput(runtimeModel) ?? snapshotEntry?.limit?.output, toolCall: - readRuntimeModelBoolean(runtimeModel, ["toolCall", "tool_call"]) + readRuntimeModelToolCallSupport(runtimeModel) ?? snapshotEntry?.toolCall, modalities: - normalizeModalities(runtimeModel?.modalities) + readRuntimeModelModalities(runtimeModel) ?? snapshotEntry?.modalities, } } diff --git a/src/shared/model-capability-heuristics.ts b/src/shared/model-capability-heuristics.ts index 73286badc..374c185ea 100644 --- a/src/shared/model-capability-heuristics.ts +++ b/src/shared/model-capability-heuristics.ts @@ -24,14 +24,14 @@ export const HEURISTIC_MODEL_FAMILY_REGISTRY: ReadonlyArray { }) }) + test("GPT-5 downgrades unsupported max variant to xhigh", () => { + const result = resolveCompatibleModelSettings({ + providerID: "openai", + modelID: "gpt-5.4", + desired: { variant: "max" }, + }) + + expect(result).toEqual({ + variant: "xhigh", + reasoningEffort: undefined, + changes: [ + { + field: "variant", + from: "max", + to: "xhigh", + reason: "unsupported-by-model-family", + }, + ], + }) + }) + // Reasoning effort: "none" and "minimal" are valid per Vercel AI SDK test("GPT-5 keeps none reasoningEffort", () => { const result = resolveCompatibleModelSettings({