Merge pull request #2826 from RaviTharuma/feat/model-capabilities-models-dev
feat(model-capabilities): add models.dev snapshot and runtime capability refresh
This commit is contained in:
@@ -4696,6 +4696,27 @@
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"model_capabilities": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"auto_refresh_on_start": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"refresh_timeout_ms": {
|
||||
"type": "integer",
|
||||
"exclusiveMinimum": 0,
|
||||
"maximum": 9007199254740991
|
||||
},
|
||||
"source_url": {
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"openclaw": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
"build:all": "bun run build && bun run build:binaries",
|
||||
"build:binaries": "bun run script/build-binaries.ts",
|
||||
"build:schema": "bun run script/build-schema.ts",
|
||||
"build:model-capabilities": "bun run script/build-model-capabilities.ts",
|
||||
"clean": "rm -rf dist",
|
||||
"prepare": "bun run build",
|
||||
"postinstall": "node postinstall.mjs",
|
||||
|
||||
13
script/build-model-capabilities.ts
Normal file
13
script/build-model-capabilities.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import { writeFileSync } from "fs"
|
||||
import { resolve } from "path"
|
||||
import {
|
||||
fetchModelCapabilitiesSnapshot,
|
||||
MODELS_DEV_SOURCE_URL,
|
||||
} from "../src/shared/model-capabilities-cache"
|
||||
|
||||
const OUTPUT_PATH = resolve(import.meta.dir, "../src/generated/model-capabilities.generated.json")
|
||||
|
||||
console.log(`Fetching model capabilities snapshot from ${MODELS_DEV_SOURCE_URL}...`)
|
||||
const snapshot = await fetchModelCapabilitiesSnapshot()
|
||||
writeFileSync(OUTPUT_PATH, `${JSON.stringify(snapshot, null, 2)}\n`)
|
||||
console.log(`Generated ${OUTPUT_PATH} with ${Object.keys(snapshot.models).length} models`)
|
||||
@@ -3,6 +3,7 @@ import { install } from "./install"
|
||||
import { run } from "./run"
|
||||
import { getLocalVersion } from "./get-local-version"
|
||||
import { doctor } from "./doctor"
|
||||
import { refreshModelCapabilities } from "./refresh-model-capabilities"
|
||||
import { createMcpOAuthCommand } from "./mcp-oauth"
|
||||
import type { InstallArgs } from "./types"
|
||||
import type { RunOptions } from "./run"
|
||||
@@ -176,6 +177,21 @@ Examples:
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
program
|
||||
.command("refresh-model-capabilities")
|
||||
.description("Refresh the cached models.dev-based model capabilities snapshot")
|
||||
.option("-d, --directory <path>", "Working directory to read oh-my-opencode config from")
|
||||
.option("--source-url <url>", "Override the models.dev source URL")
|
||||
.option("--json", "Output refresh summary as JSON")
|
||||
.action(async (options) => {
|
||||
const exitCode = await refreshModelCapabilities({
|
||||
directory: options.directory,
|
||||
sourceUrl: options.sourceUrl,
|
||||
json: options.json ?? false,
|
||||
})
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
program
|
||||
.command("version")
|
||||
.description("Show version information")
|
||||
|
||||
114
src/cli/refresh-model-capabilities.test.ts
Normal file
114
src/cli/refresh-model-capabilities.test.ts
Normal file
@@ -0,0 +1,114 @@
|
||||
import { describe, expect, it, mock } from "bun:test"
|
||||
|
||||
import { refreshModelCapabilities } from "./refresh-model-capabilities"
|
||||
|
||||
describe("refreshModelCapabilities", () => {
|
||||
it("uses config source_url when CLI override is absent", async () => {
|
||||
const loadConfig = mock(() => ({
|
||||
model_capabilities: {
|
||||
source_url: "https://mirror.example/api.json",
|
||||
},
|
||||
}))
|
||||
const refreshCache = mock(async () => ({
|
||||
generatedAt: "2026-03-25T00:00:00.000Z",
|
||||
sourceUrl: "https://mirror.example/api.json",
|
||||
models: {
|
||||
"gpt-5.4": { id: "gpt-5.4" },
|
||||
},
|
||||
}))
|
||||
let stdout = ""
|
||||
|
||||
const exitCode = await refreshModelCapabilities(
|
||||
{ directory: "/repo", json: false },
|
||||
{
|
||||
loadConfig,
|
||||
refreshCache,
|
||||
stdout: {
|
||||
write: (chunk: string) => {
|
||||
stdout += chunk
|
||||
return true
|
||||
},
|
||||
} as never,
|
||||
stderr: {
|
||||
write: () => true,
|
||||
} as never,
|
||||
},
|
||||
)
|
||||
|
||||
expect(exitCode).toBe(0)
|
||||
expect(loadConfig).toHaveBeenCalledWith("/repo", null)
|
||||
expect(refreshCache).toHaveBeenCalledWith({
|
||||
sourceUrl: "https://mirror.example/api.json",
|
||||
})
|
||||
expect(stdout).toContain("Refreshed model capabilities cache (1 models)")
|
||||
})
|
||||
|
||||
it("CLI sourceUrl overrides config and supports json output", async () => {
|
||||
const refreshCache = mock(async () => ({
|
||||
generatedAt: "2026-03-25T00:00:00.000Z",
|
||||
sourceUrl: "https://override.example/api.json",
|
||||
models: {
|
||||
"gpt-5.4": { id: "gpt-5.4" },
|
||||
"claude-opus-4-6": { id: "claude-opus-4-6" },
|
||||
},
|
||||
}))
|
||||
let stdout = ""
|
||||
|
||||
const exitCode = await refreshModelCapabilities(
|
||||
{
|
||||
directory: "/repo",
|
||||
json: true,
|
||||
sourceUrl: "https://override.example/api.json",
|
||||
},
|
||||
{
|
||||
loadConfig: () => ({}),
|
||||
refreshCache,
|
||||
stdout: {
|
||||
write: (chunk: string) => {
|
||||
stdout += chunk
|
||||
return true
|
||||
},
|
||||
} as never,
|
||||
stderr: {
|
||||
write: () => true,
|
||||
} as never,
|
||||
},
|
||||
)
|
||||
|
||||
expect(exitCode).toBe(0)
|
||||
expect(refreshCache).toHaveBeenCalledWith({
|
||||
sourceUrl: "https://override.example/api.json",
|
||||
})
|
||||
expect(JSON.parse(stdout)).toEqual({
|
||||
sourceUrl: "https://override.example/api.json",
|
||||
generatedAt: "2026-03-25T00:00:00.000Z",
|
||||
modelCount: 2,
|
||||
})
|
||||
})
|
||||
|
||||
it("returns exit code 1 when refresh fails", async () => {
|
||||
let stderr = ""
|
||||
|
||||
const exitCode = await refreshModelCapabilities(
|
||||
{ directory: "/repo" },
|
||||
{
|
||||
loadConfig: () => ({}),
|
||||
refreshCache: async () => {
|
||||
throw new Error("boom")
|
||||
},
|
||||
stdout: {
|
||||
write: () => true,
|
||||
} as never,
|
||||
stderr: {
|
||||
write: (chunk: string) => {
|
||||
stderr += chunk
|
||||
return true
|
||||
},
|
||||
} as never,
|
||||
},
|
||||
)
|
||||
|
||||
expect(exitCode).toBe(1)
|
||||
expect(stderr).toContain("Failed to refresh model capabilities cache")
|
||||
})
|
||||
})
|
||||
51
src/cli/refresh-model-capabilities.ts
Normal file
51
src/cli/refresh-model-capabilities.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import { loadPluginConfig } from "../plugin-config"
|
||||
import { refreshModelCapabilitiesCache } from "../shared/model-capabilities-cache"
|
||||
|
||||
export type RefreshModelCapabilitiesOptions = {
|
||||
directory?: string
|
||||
json?: boolean
|
||||
sourceUrl?: string
|
||||
}
|
||||
|
||||
type RefreshModelCapabilitiesDeps = {
|
||||
loadConfig?: typeof loadPluginConfig
|
||||
refreshCache?: typeof refreshModelCapabilitiesCache
|
||||
stdout?: Pick<typeof process.stdout, "write">
|
||||
stderr?: Pick<typeof process.stderr, "write">
|
||||
}
|
||||
|
||||
export async function refreshModelCapabilities(
|
||||
options: RefreshModelCapabilitiesOptions,
|
||||
deps: RefreshModelCapabilitiesDeps = {},
|
||||
): Promise<number> {
|
||||
const directory = options.directory ?? process.cwd()
|
||||
const loadConfig = deps.loadConfig ?? loadPluginConfig
|
||||
const refreshCache = deps.refreshCache ?? refreshModelCapabilitiesCache
|
||||
const stdout = deps.stdout ?? process.stdout
|
||||
const stderr = deps.stderr ?? process.stderr
|
||||
|
||||
try {
|
||||
const config = loadConfig(directory, null)
|
||||
const sourceUrl = options.sourceUrl ?? config.model_capabilities?.source_url
|
||||
const snapshot = await refreshCache({ sourceUrl })
|
||||
|
||||
const summary = {
|
||||
sourceUrl: snapshot.sourceUrl,
|
||||
generatedAt: snapshot.generatedAt,
|
||||
modelCount: Object.keys(snapshot.models).length,
|
||||
}
|
||||
|
||||
if (options.json) {
|
||||
stdout.write(`${JSON.stringify(summary, null, 2)}\n`)
|
||||
} else {
|
||||
stdout.write(
|
||||
`Refreshed model capabilities cache (${summary.modelCount} models) from ${summary.sourceUrl}\n`,
|
||||
)
|
||||
}
|
||||
|
||||
return 0
|
||||
} catch (error) {
|
||||
stderr.write(`Failed to refresh model capabilities cache: ${String(error)}\n`)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
@@ -19,5 +19,6 @@ export type {
|
||||
SisyphusConfig,
|
||||
SisyphusTasksConfig,
|
||||
RuntimeFallbackConfig,
|
||||
ModelCapabilitiesConfig,
|
||||
FallbackModels,
|
||||
} from "./schema"
|
||||
|
||||
@@ -147,6 +147,37 @@ describe("disabled_mcps schema", () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe("OhMyOpenCodeConfigSchema - model_capabilities", () => {
|
||||
test("accepts valid model capabilities config", () => {
|
||||
const input = {
|
||||
model_capabilities: {
|
||||
enabled: true,
|
||||
auto_refresh_on_start: true,
|
||||
refresh_timeout_ms: 5000,
|
||||
source_url: "https://models.dev/api.json",
|
||||
},
|
||||
}
|
||||
|
||||
const result = OhMyOpenCodeConfigSchema.safeParse(input)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.data.model_capabilities).toEqual(input.model_capabilities)
|
||||
}
|
||||
})
|
||||
|
||||
test("rejects invalid model capabilities config", () => {
|
||||
const result = OhMyOpenCodeConfigSchema.safeParse({
|
||||
model_capabilities: {
|
||||
refresh_timeout_ms: -1,
|
||||
source_url: "not-a-url",
|
||||
},
|
||||
})
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("AgentOverrideConfigSchema", () => {
|
||||
describe("category field", () => {
|
||||
test("accepts category as optional string", () => {
|
||||
|
||||
@@ -13,6 +13,7 @@ export * from "./schema/fallback-models"
|
||||
export * from "./schema/git-env-prefix"
|
||||
export * from "./schema/git-master"
|
||||
export * from "./schema/hooks"
|
||||
export * from "./schema/model-capabilities"
|
||||
export * from "./schema/notification"
|
||||
export * from "./schema/oh-my-opencode-config"
|
||||
export * from "./schema/ralph-loop"
|
||||
|
||||
10
src/config/schema/model-capabilities.ts
Normal file
10
src/config/schema/model-capabilities.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { z } from "zod"
|
||||
|
||||
export const ModelCapabilitiesConfigSchema = z.object({
|
||||
enabled: z.boolean().optional(),
|
||||
auto_refresh_on_start: z.boolean().optional(),
|
||||
refresh_timeout_ms: z.number().int().positive().optional(),
|
||||
source_url: z.string().url().optional(),
|
||||
})
|
||||
|
||||
export type ModelCapabilitiesConfig = z.infer<typeof ModelCapabilitiesConfigSchema>
|
||||
@@ -13,6 +13,7 @@ import { ExperimentalConfigSchema } from "./experimental"
|
||||
import { GitMasterConfigSchema } from "./git-master"
|
||||
import { NotificationConfigSchema } from "./notification"
|
||||
import { OpenClawConfigSchema } from "./openclaw"
|
||||
import { ModelCapabilitiesConfigSchema } from "./model-capabilities"
|
||||
import { RalphLoopConfigSchema } from "./ralph-loop"
|
||||
import { RuntimeFallbackConfigSchema } from "./runtime-fallback"
|
||||
import { SkillsConfigSchema } from "./skills"
|
||||
@@ -56,6 +57,7 @@ export const OhMyOpenCodeConfigSchema = z.object({
|
||||
runtime_fallback: z.union([z.boolean(), RuntimeFallbackConfigSchema]).optional(),
|
||||
background_task: BackgroundTaskConfigSchema.optional(),
|
||||
notification: NotificationConfigSchema.optional(),
|
||||
model_capabilities: ModelCapabilitiesConfigSchema.optional(),
|
||||
openclaw: OpenClawConfigSchema.optional(),
|
||||
babysitting: BabysittingConfigSchema.optional(),
|
||||
git_master: GitMasterConfigSchema.optional(),
|
||||
|
||||
40690
src/generated/model-capabilities.generated.json
Normal file
40690
src/generated/model-capabilities.generated.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@ import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
|
||||
const mockShowConfigErrorsIfAny = mock(async () => {})
|
||||
const mockShowModelCacheWarningIfNeeded = mock(async () => {})
|
||||
const mockUpdateAndShowConnectedProvidersCacheStatus = mock(async () => {})
|
||||
const mockRefreshModelCapabilitiesOnStartup = mock(async () => {})
|
||||
const mockShowLocalDevToast = mock(async () => {})
|
||||
const mockShowVersionToast = mock(async () => {})
|
||||
const mockRunBackgroundUpdateCheck = mock(async () => {})
|
||||
@@ -22,6 +23,10 @@ mock.module("./hook/connected-providers-status", () => ({
|
||||
mockUpdateAndShowConnectedProvidersCacheStatus,
|
||||
}))
|
||||
|
||||
mock.module("./hook/model-capabilities-status", () => ({
|
||||
refreshModelCapabilitiesOnStartup: mockRefreshModelCapabilitiesOnStartup,
|
||||
}))
|
||||
|
||||
mock.module("./hook/startup-toasts", () => ({
|
||||
showLocalDevToast: mockShowLocalDevToast,
|
||||
showVersionToast: mockShowVersionToast,
|
||||
@@ -78,6 +83,7 @@ beforeEach(() => {
|
||||
mockShowConfigErrorsIfAny.mockClear()
|
||||
mockShowModelCacheWarningIfNeeded.mockClear()
|
||||
mockUpdateAndShowConnectedProvidersCacheStatus.mockClear()
|
||||
mockRefreshModelCapabilitiesOnStartup.mockClear()
|
||||
mockShowLocalDevToast.mockClear()
|
||||
mockShowVersionToast.mockClear()
|
||||
mockRunBackgroundUpdateCheck.mockClear()
|
||||
@@ -112,6 +118,7 @@ describe("createAutoUpdateCheckerHook", () => {
|
||||
expect(mockShowConfigErrorsIfAny).not.toHaveBeenCalled()
|
||||
expect(mockShowModelCacheWarningIfNeeded).not.toHaveBeenCalled()
|
||||
expect(mockUpdateAndShowConnectedProvidersCacheStatus).not.toHaveBeenCalled()
|
||||
expect(mockRefreshModelCapabilitiesOnStartup).not.toHaveBeenCalled()
|
||||
expect(mockShowLocalDevToast).not.toHaveBeenCalled()
|
||||
expect(mockShowVersionToast).not.toHaveBeenCalled()
|
||||
expect(mockRunBackgroundUpdateCheck).not.toHaveBeenCalled()
|
||||
@@ -129,6 +136,7 @@ describe("createAutoUpdateCheckerHook", () => {
|
||||
//#then - startup checks, toast, and background check run
|
||||
expect(mockShowConfigErrorsIfAny).toHaveBeenCalledTimes(1)
|
||||
expect(mockUpdateAndShowConnectedProvidersCacheStatus).toHaveBeenCalledTimes(1)
|
||||
expect(mockRefreshModelCapabilitiesOnStartup).toHaveBeenCalledTimes(1)
|
||||
expect(mockShowModelCacheWarningIfNeeded).toHaveBeenCalledTimes(1)
|
||||
expect(mockShowVersionToast).toHaveBeenCalledTimes(1)
|
||||
expect(mockRunBackgroundUpdateCheck).toHaveBeenCalledTimes(1)
|
||||
@@ -146,6 +154,7 @@ describe("createAutoUpdateCheckerHook", () => {
|
||||
//#then - no startup actions run
|
||||
expect(mockShowConfigErrorsIfAny).not.toHaveBeenCalled()
|
||||
expect(mockUpdateAndShowConnectedProvidersCacheStatus).not.toHaveBeenCalled()
|
||||
expect(mockRefreshModelCapabilitiesOnStartup).not.toHaveBeenCalled()
|
||||
expect(mockShowModelCacheWarningIfNeeded).not.toHaveBeenCalled()
|
||||
expect(mockShowLocalDevToast).not.toHaveBeenCalled()
|
||||
expect(mockShowVersionToast).not.toHaveBeenCalled()
|
||||
@@ -165,6 +174,7 @@ describe("createAutoUpdateCheckerHook", () => {
|
||||
//#then - side effects execute only once
|
||||
expect(mockShowConfigErrorsIfAny).toHaveBeenCalledTimes(1)
|
||||
expect(mockUpdateAndShowConnectedProvidersCacheStatus).toHaveBeenCalledTimes(1)
|
||||
expect(mockRefreshModelCapabilitiesOnStartup).toHaveBeenCalledTimes(1)
|
||||
expect(mockShowModelCacheWarningIfNeeded).toHaveBeenCalledTimes(1)
|
||||
expect(mockShowVersionToast).toHaveBeenCalledTimes(1)
|
||||
expect(mockRunBackgroundUpdateCheck).toHaveBeenCalledTimes(1)
|
||||
@@ -183,6 +193,7 @@ describe("createAutoUpdateCheckerHook", () => {
|
||||
//#then - local dev toast is shown and background check is skipped
|
||||
expect(mockShowConfigErrorsIfAny).toHaveBeenCalledTimes(1)
|
||||
expect(mockUpdateAndShowConnectedProvidersCacheStatus).toHaveBeenCalledTimes(1)
|
||||
expect(mockRefreshModelCapabilitiesOnStartup).toHaveBeenCalledTimes(1)
|
||||
expect(mockShowModelCacheWarningIfNeeded).toHaveBeenCalledTimes(1)
|
||||
expect(mockShowLocalDevToast).toHaveBeenCalledTimes(1)
|
||||
expect(mockShowVersionToast).not.toHaveBeenCalled()
|
||||
@@ -205,6 +216,7 @@ describe("createAutoUpdateCheckerHook", () => {
|
||||
//#then - no startup actions run
|
||||
expect(mockShowConfigErrorsIfAny).not.toHaveBeenCalled()
|
||||
expect(mockUpdateAndShowConnectedProvidersCacheStatus).not.toHaveBeenCalled()
|
||||
expect(mockRefreshModelCapabilitiesOnStartup).not.toHaveBeenCalled()
|
||||
expect(mockShowModelCacheWarningIfNeeded).not.toHaveBeenCalled()
|
||||
expect(mockShowLocalDevToast).not.toHaveBeenCalled()
|
||||
expect(mockShowVersionToast).not.toHaveBeenCalled()
|
||||
|
||||
@@ -5,11 +5,17 @@ import type { AutoUpdateCheckerOptions } from "./types"
|
||||
import { runBackgroundUpdateCheck } from "./hook/background-update-check"
|
||||
import { showConfigErrorsIfAny } from "./hook/config-errors-toast"
|
||||
import { updateAndShowConnectedProvidersCacheStatus } from "./hook/connected-providers-status"
|
||||
import { refreshModelCapabilitiesOnStartup } from "./hook/model-capabilities-status"
|
||||
import { showModelCacheWarningIfNeeded } from "./hook/model-cache-warning"
|
||||
import { showLocalDevToast, showVersionToast } from "./hook/startup-toasts"
|
||||
|
||||
export function createAutoUpdateCheckerHook(ctx: PluginInput, options: AutoUpdateCheckerOptions = {}) {
|
||||
const { showStartupToast = true, isSisyphusEnabled = false, autoUpdate = true } = options
|
||||
const {
|
||||
showStartupToast = true,
|
||||
isSisyphusEnabled = false,
|
||||
autoUpdate = true,
|
||||
modelCapabilities,
|
||||
} = options
|
||||
const isCliRunMode = process.env.OPENCODE_CLI_RUN_MODE === "true"
|
||||
|
||||
const getToastMessage = (isUpdate: boolean, latestVersion?: string): string => {
|
||||
@@ -43,6 +49,7 @@ export function createAutoUpdateCheckerHook(ctx: PluginInput, options: AutoUpdat
|
||||
|
||||
await showConfigErrorsIfAny(ctx)
|
||||
await updateAndShowConnectedProvidersCacheStatus(ctx)
|
||||
await refreshModelCapabilitiesOnStartup(modelCapabilities)
|
||||
await showModelCacheWarningIfNeeded(ctx)
|
||||
|
||||
if (localDevVersion) {
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
import type { ModelCapabilitiesConfig } from "../../../config/schema/model-capabilities"
|
||||
import { refreshModelCapabilitiesCache } from "../../../shared/model-capabilities-cache"
|
||||
import { log } from "../../../shared/logger"
|
||||
|
||||
const DEFAULT_REFRESH_TIMEOUT_MS = 5000
|
||||
|
||||
export async function refreshModelCapabilitiesOnStartup(
|
||||
config: ModelCapabilitiesConfig | undefined,
|
||||
): Promise<void> {
|
||||
if (config?.enabled === false) {
|
||||
return
|
||||
}
|
||||
|
||||
if (config?.auto_refresh_on_start === false) {
|
||||
return
|
||||
}
|
||||
|
||||
const timeoutMs = config?.refresh_timeout_ms ?? DEFAULT_REFRESH_TIMEOUT_MS
|
||||
|
||||
let timeoutId: ReturnType<typeof setTimeout> | undefined
|
||||
try {
|
||||
await Promise.race([
|
||||
refreshModelCapabilitiesCache({
|
||||
sourceUrl: config?.source_url,
|
||||
}),
|
||||
new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(() => reject(new Error("Model capabilities refresh timed out")), timeoutMs)
|
||||
}),
|
||||
])
|
||||
} catch (error) {
|
||||
log("[auto-update-checker] Model capabilities refresh failed", { error: String(error) })
|
||||
} finally {
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
import type { ModelCapabilitiesConfig } from "../../config/schema/model-capabilities"
|
||||
|
||||
export interface NpmDistTags {
|
||||
latest: string
|
||||
[key: string]: string
|
||||
@@ -26,4 +28,5 @@ export interface AutoUpdateCheckerOptions {
|
||||
showStartupToast?: boolean
|
||||
isSisyphusEnabled?: boolean
|
||||
autoUpdate?: boolean
|
||||
modelCapabilities?: ModelCapabilitiesConfig
|
||||
}
|
||||
|
||||
@@ -113,7 +113,6 @@ describe("createChatParamsHandler", () => {
|
||||
|
||||
//#then
|
||||
expect(output).toEqual({
|
||||
temperature: 0.4,
|
||||
topP: 0.7,
|
||||
topK: 1,
|
||||
options: {
|
||||
@@ -133,4 +132,86 @@ describe("createChatParamsHandler", () => {
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("drops unsupported temperature and clamps maxTokens from bundled model capabilities", async () => {
|
||||
//#given
|
||||
setSessionPromptParams("ses_chat_params", {
|
||||
temperature: 0.7,
|
||||
options: {
|
||||
maxTokens: 200_000,
|
||||
},
|
||||
})
|
||||
|
||||
const handler = createChatParamsHandler({
|
||||
anthropicEffort: null,
|
||||
})
|
||||
|
||||
const input = {
|
||||
sessionID: "ses_chat_params",
|
||||
agent: { name: "oracle" },
|
||||
model: { providerID: "openai", modelID: "gpt-5.4" },
|
||||
provider: { id: "openai" },
|
||||
message: {},
|
||||
}
|
||||
|
||||
const output = {
|
||||
temperature: 0.1,
|
||||
topP: 1,
|
||||
topK: 1,
|
||||
options: {},
|
||||
}
|
||||
|
||||
//#when
|
||||
await handler(input, output)
|
||||
|
||||
//#then
|
||||
expect(output).toEqual({
|
||||
topP: 1,
|
||||
topK: 1,
|
||||
options: {
|
||||
maxTokens: 128_000,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
test("drops unsupported reasoning settings from bundled model capabilities", async () => {
|
||||
//#given
|
||||
setSessionPromptParams("ses_chat_params", {
|
||||
temperature: 0.4,
|
||||
options: {
|
||||
reasoningEffort: "high",
|
||||
thinking: { type: "enabled", budgetTokens: 4096 },
|
||||
},
|
||||
})
|
||||
|
||||
const handler = createChatParamsHandler({
|
||||
anthropicEffort: null,
|
||||
})
|
||||
|
||||
const input = {
|
||||
sessionID: "ses_chat_params",
|
||||
agent: { name: "oracle" },
|
||||
model: { providerID: "openai", modelID: "gpt-4.1" },
|
||||
provider: { id: "openai" },
|
||||
message: {},
|
||||
}
|
||||
|
||||
const output = {
|
||||
temperature: 0.1,
|
||||
topP: 1,
|
||||
topK: 1,
|
||||
options: {},
|
||||
}
|
||||
|
||||
//#when
|
||||
await handler(input, output)
|
||||
|
||||
//#then
|
||||
expect(output).toEqual({
|
||||
temperature: 0.4,
|
||||
topP: 1,
|
||||
topK: 1,
|
||||
options: {},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { normalizeSDKResponse } from "../shared/normalize-sdk-response"
|
||||
import { getSessionPromptParams } from "../shared/session-prompt-params-state"
|
||||
import { resolveCompatibleModelSettings } from "../shared"
|
||||
import { getModelCapabilities, resolveCompatibleModelSettings } from "../shared"
|
||||
|
||||
export type ChatParamsInput = {
|
||||
sessionID: string
|
||||
@@ -21,25 +21,6 @@ export type ChatParamsOutput = {
|
||||
options: Record<string, unknown>
|
||||
}
|
||||
|
||||
type ProviderListClient = {
|
||||
provider?: {
|
||||
list?: () => Promise<unknown>
|
||||
}
|
||||
}
|
||||
|
||||
type ProviderModelMetadata = {
|
||||
variants?: Record<string, unknown>
|
||||
}
|
||||
|
||||
type ProviderListEntry = {
|
||||
id?: string
|
||||
models?: Record<string, ProviderModelMetadata>
|
||||
}
|
||||
|
||||
type ProviderListData = {
|
||||
all?: ProviderListEntry[]
|
||||
}
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
@@ -101,33 +82,9 @@ function isChatParamsOutput(raw: unknown): raw is ChatParamsOutput {
|
||||
return isRecord(raw.options)
|
||||
}
|
||||
|
||||
async function getVariantCapabilities(
|
||||
client: ProviderListClient | undefined,
|
||||
model: { providerID: string; modelID: string },
|
||||
): Promise<string[] | undefined> {
|
||||
const providerList = client?.provider?.list
|
||||
if (typeof providerList !== "function") {
|
||||
return undefined
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await providerList()
|
||||
const data = normalizeSDKResponse<ProviderListData>(response, {})
|
||||
const providerEntry = data.all?.find((entry) => entry.id === model.providerID)
|
||||
const variants = providerEntry?.models?.[model.modelID]?.variants
|
||||
if (!variants) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return Object.keys(variants)
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
export function createChatParamsHandler(args: {
|
||||
anthropicEffort: { "chat.params"?: (input: ChatParamsHookInput, output: ChatParamsOutput) => Promise<void> } | null
|
||||
client?: ProviderListClient
|
||||
client?: unknown
|
||||
}): (input: unknown, output: unknown) => Promise<void> {
|
||||
return async (input, output): Promise<void> => {
|
||||
const normalizedInput = buildChatParamsInput(input)
|
||||
@@ -150,7 +107,10 @@ export function createChatParamsHandler(args: {
|
||||
}
|
||||
}
|
||||
|
||||
const variantCapabilities = await getVariantCapabilities(args.client, normalizedInput.model)
|
||||
const capabilities = getModelCapabilities({
|
||||
providerID: normalizedInput.model.providerID,
|
||||
modelID: normalizedInput.model.modelID,
|
||||
})
|
||||
|
||||
const compatibility = resolveCompatibleModelSettings({
|
||||
providerID: normalizedInput.model.providerID,
|
||||
@@ -162,10 +122,12 @@ export function createChatParamsHandler(args: {
|
||||
reasoningEffort: typeof output.options.reasoningEffort === "string"
|
||||
? output.options.reasoningEffort
|
||||
: undefined,
|
||||
temperature: typeof output.temperature === "number" ? output.temperature : undefined,
|
||||
topP: typeof output.topP === "number" ? output.topP : undefined,
|
||||
maxTokens: typeof output.options.maxTokens === "number" ? output.options.maxTokens : undefined,
|
||||
thinking: isRecord(output.options.thinking) ? output.options.thinking : undefined,
|
||||
},
|
||||
capabilities: {
|
||||
variants: variantCapabilities,
|
||||
},
|
||||
capabilities,
|
||||
})
|
||||
|
||||
if (normalizedInput.rawMessage) {
|
||||
@@ -183,6 +145,38 @@ export function createChatParamsHandler(args: {
|
||||
delete output.options.reasoningEffort
|
||||
}
|
||||
|
||||
if ("temperature" in compatibility) {
|
||||
if (compatibility.temperature !== undefined) {
|
||||
output.temperature = compatibility.temperature
|
||||
} else {
|
||||
delete output.temperature
|
||||
}
|
||||
}
|
||||
|
||||
if ("topP" in compatibility) {
|
||||
if (compatibility.topP !== undefined) {
|
||||
output.topP = compatibility.topP
|
||||
} else {
|
||||
delete output.topP
|
||||
}
|
||||
}
|
||||
|
||||
if ("maxTokens" in compatibility) {
|
||||
if (compatibility.maxTokens !== undefined) {
|
||||
output.options.maxTokens = compatibility.maxTokens
|
||||
} else {
|
||||
delete output.options.maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
if ("thinking" in compatibility) {
|
||||
if (compatibility.thinking !== undefined) {
|
||||
output.options.thinking = compatibility.thinking
|
||||
} else {
|
||||
delete output.options.thinking
|
||||
}
|
||||
}
|
||||
|
||||
await args.anthropicEffort?.["chat.params"]?.(normalizedInput, output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,6 +184,7 @@ export function createSessionHooks(args: {
|
||||
showStartupToast: isHookEnabled("startup-toast"),
|
||||
isSisyphusEnabled: pluginConfig.sisyphus_agent?.disabled !== true,
|
||||
autoUpdate: pluginConfig.auto_update ?? true,
|
||||
modelCapabilities: pluginConfig.model_capabilities,
|
||||
}))
|
||||
: null
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import { tmpdir } from "node:os"
|
||||
import { join } from "node:path"
|
||||
import {
|
||||
createConnectedProvidersCacheStore,
|
||||
findProviderModelMetadata,
|
||||
} from "./connected-providers-cache"
|
||||
|
||||
let fakeUserCacheRoot = ""
|
||||
@@ -68,8 +69,14 @@ describe("updateConnectedProvidersCache", () => {
|
||||
expect(cache).not.toBeNull()
|
||||
expect(cache!.connected).toEqual(["openai", "anthropic"])
|
||||
expect(cache!.models).toEqual({
|
||||
openai: ["gpt-5.3-codex", "gpt-5.4"],
|
||||
anthropic: ["claude-opus-4-6", "claude-sonnet-4-6"],
|
||||
openai: [
|
||||
{ id: "gpt-5.3-codex", name: "GPT-5.3 Codex" },
|
||||
{ id: "gpt-5.4", name: "GPT-5.4" },
|
||||
],
|
||||
anthropic: [
|
||||
{ id: "claude-opus-4-6", name: "Claude Opus 4.6" },
|
||||
{ id: "claude-sonnet-4-6", name: "Claude Sonnet 4.6" },
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
@@ -174,4 +181,52 @@ describe("updateConnectedProvidersCache", () => {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
test("findProviderModelMetadata returns rich cached metadata", async () => {
|
||||
//#given
|
||||
const mockClient = {
|
||||
provider: {
|
||||
list: async () => ({
|
||||
data: {
|
||||
connected: ["openai"],
|
||||
all: [
|
||||
{
|
||||
id: "openai",
|
||||
models: {
|
||||
"gpt-5.4": {
|
||||
id: "gpt-5.4",
|
||||
name: "GPT-5.4",
|
||||
temperature: false,
|
||||
variants: {
|
||||
low: {},
|
||||
high: {},
|
||||
},
|
||||
limit: { output: 128000 },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
await testCacheStore.updateConnectedProvidersCache(mockClient)
|
||||
const cache = testCacheStore.readProviderModelsCache()
|
||||
|
||||
//#when
|
||||
const result = findProviderModelMetadata("openai", "gpt-5.4", cache)
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({
|
||||
id: "gpt-5.4",
|
||||
name: "GPT-5.4",
|
||||
temperature: false,
|
||||
variants: {
|
||||
low: {},
|
||||
high: {},
|
||||
},
|
||||
limit: { output: 128000 },
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -11,20 +11,39 @@ interface ConnectedProvidersCache {
|
||||
updatedAt: string
|
||||
}
|
||||
|
||||
interface ModelMetadata {
|
||||
export interface ModelMetadata {
|
||||
id: string
|
||||
provider?: string
|
||||
context?: number
|
||||
output?: number
|
||||
name?: string
|
||||
variants?: Record<string, unknown>
|
||||
limit?: {
|
||||
context?: number
|
||||
input?: number
|
||||
output?: number
|
||||
}
|
||||
modalities?: {
|
||||
input?: string[]
|
||||
output?: string[]
|
||||
}
|
||||
capabilities?: Record<string, unknown>
|
||||
reasoning?: boolean
|
||||
temperature?: boolean
|
||||
tool_call?: boolean
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
interface ProviderModelsCache {
|
||||
export interface ProviderModelsCache {
|
||||
models: Record<string, string[] | ModelMetadata[]>
|
||||
connected: string[]
|
||||
updatedAt: string
|
||||
}
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
|
||||
export function createConnectedProvidersCacheStore(
|
||||
getCacheDir: () => string = dataPath.getOmoOpenCodeCacheDir
|
||||
) {
|
||||
@@ -119,7 +138,7 @@ export function createConnectedProvidersCacheStore(
|
||||
return existsSync(cacheFile)
|
||||
}
|
||||
|
||||
function writeProviderModelsCache(data: { models: Record<string, string[]>; connected: string[] }): void {
|
||||
function writeProviderModelsCache(data: { models: Record<string, string[] | ModelMetadata[]>; connected: string[] }): void {
|
||||
ensureCacheDir()
|
||||
const cacheFile = getCacheFilePath(PROVIDER_MODELS_CACHE_FILE)
|
||||
|
||||
@@ -164,14 +183,27 @@ export function createConnectedProvidersCacheStore(
|
||||
|
||||
writeConnectedProvidersCache(connected)
|
||||
|
||||
const modelsByProvider: Record<string, string[]> = {}
|
||||
const modelsByProvider: Record<string, ModelMetadata[]> = {}
|
||||
const allProviders = result.data?.all ?? []
|
||||
|
||||
for (const provider of allProviders) {
|
||||
if (provider.models) {
|
||||
const modelIds = Object.keys(provider.models)
|
||||
if (modelIds.length > 0) {
|
||||
modelsByProvider[provider.id] = modelIds
|
||||
const modelMetadata = Object.entries(provider.models).map(([modelID, rawMetadata]) => {
|
||||
if (!isRecord(rawMetadata)) {
|
||||
return { id: modelID }
|
||||
}
|
||||
|
||||
const normalizedID = typeof rawMetadata.id === "string"
|
||||
? rawMetadata.id
|
||||
: modelID
|
||||
|
||||
return {
|
||||
id: normalizedID,
|
||||
...rawMetadata,
|
||||
} satisfies ModelMetadata
|
||||
})
|
||||
if (modelMetadata.length > 0) {
|
||||
modelsByProvider[provider.id] = modelMetadata
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,6 +232,32 @@ export function createConnectedProvidersCacheStore(
|
||||
}
|
||||
}
|
||||
|
||||
export function findProviderModelMetadata(
|
||||
providerID: string,
|
||||
modelID: string,
|
||||
cache: ProviderModelsCache | null = defaultConnectedProvidersCacheStore.readProviderModelsCache(),
|
||||
): ModelMetadata | undefined {
|
||||
const providerModels = cache?.models?.[providerID]
|
||||
if (!providerModels) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
for (const entry of providerModels) {
|
||||
if (typeof entry === "string") {
|
||||
if (entry === modelID) {
|
||||
return { id: entry }
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if (entry?.id === modelID) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Module-level default store used by `findProviderModelMetadata` when the
// caller does not supply an explicit cache; bound to the standard
// oh-my-opencode cache directory.
const defaultConnectedProvidersCacheStore = createConnectedProvidersCacheStore(
  () => dataPath.getOmoOpenCodeCacheDir()
)
|
||||
|
||||
@@ -43,6 +43,9 @@ export type {
|
||||
ModelResolutionResult,
|
||||
} from "./model-resolution-types"
|
||||
export * from "./model-availability"
|
||||
export * from "./model-capabilities"
|
||||
export * from "./model-capabilities-cache"
|
||||
export * from "./model-capability-heuristics"
|
||||
export * from "./model-settings-compatibility"
|
||||
export * from "./fallback-model-availability"
|
||||
export * from "./connected-providers-cache"
|
||||
|
||||
134
src/shared/model-capabilities-cache.test.ts
Normal file
134
src/shared/model-capabilities-cache.test.ts
Normal file
@@ -0,0 +1,134 @@
|
||||
/// <reference types="bun-types" />

import { afterEach, beforeEach, describe, expect, test } from "bun:test"

import { existsSync, mkdirSync, mkdtempSync, readFileSync, rmSync, writeFileSync } from "node:fs"
import { tmpdir } from "node:os"
import { join } from "node:path"
import {
  buildModelCapabilitiesSnapshotFromModelsDev,
  createModelCapabilitiesCacheStore,
  MODELS_DEV_SOURCE_URL,
} from "./model-capabilities-cache"

// Fresh throwaway directories per test case so cases cannot observe each
// other's cache files.
let fakeUserCacheRoot = ""
let testCacheDir = ""

describe("model-capabilities-cache", () => {
  beforeEach(() => {
    fakeUserCacheRoot = mkdtempSync(join(tmpdir(), "model-capabilities-cache-"))
    testCacheDir = join(fakeUserCacheRoot, "oh-my-opencode")
  })

  afterEach(() => {
    // Best-effort cleanup of the temp tree created in beforeEach.
    if (existsSync(fakeUserCacheRoot)) {
      rmSync(fakeUserCacheRoot, { recursive: true, force: true })
    }
    fakeUserCacheRoot = ""
    testCacheDir = ""
  })

  test("builds a normalized snapshot from provider-keyed models.dev data", () => {
    //#given — raw payload shaped like models.dev api.json (provider -> models map)
    const raw = {
      openai: {
        models: {
          "gpt-5.4": {
            id: "gpt-5.4",
            family: "gpt",
            reasoning: true,
            temperature: false,
            tool_call: true,
            modalities: {
              input: ["text", "image"],
              output: ["text"],
            },
            limit: {
              context: 1_050_000,
              output: 128_000,
            },
          },
        },
      },
      anthropic: {
        models: {
          // deliberately no `id` field — must fall back to the map key
          "claude-sonnet-4-6": {
            family: "claude-sonnet",
            reasoning: true,
            temperature: true,
            limit: {
              context: 1_000_000,
              output: 64_000,
            },
          },
        },
      },
    }

    //#when
    const snapshot = buildModelCapabilitiesSnapshotFromModelsDev(raw)

    //#then — entries flattened across providers, keyed by model id, and
    // snake_case `tool_call` renamed to camelCase `toolCall`
    expect(snapshot.sourceUrl).toBe(MODELS_DEV_SOURCE_URL)
    expect(snapshot.models["gpt-5.4"]).toEqual({
      id: "gpt-5.4",
      family: "gpt",
      reasoning: true,
      temperature: false,
      toolCall: true,
      modalities: {
        input: ["text", "image"],
        output: ["text"],
      },
      limit: {
        context: 1_050_000,
        output: 128_000,
      },
    })
    expect(snapshot.models["claude-sonnet-4-6"]).toEqual({
      id: "claude-sonnet-4-6",
      family: "claude-sonnet",
      reasoning: true,
      temperature: true,
      limit: {
        context: 1_000_000,
        output: 64_000,
      },
    })
  })

  test("refresh writes cache and preserves unrelated files in the cache directory", async () => {
    //#given — a pre-existing unrelated file that must survive the refresh
    const sentinelPath = join(testCacheDir, "keep-me.json")
    const store = createModelCapabilitiesCacheStore(() => testCacheDir)
    mkdirSync(testCacheDir, { recursive: true })
    writeFileSync(sentinelPath, JSON.stringify({ keep: true }))

    // Stubbed fetch so the test never touches the network.
    const fetchImpl: typeof fetch = async () =>
      new Response(JSON.stringify({
        openai: {
          models: {
            "gpt-5.4": {
              id: "gpt-5.4",
              family: "gpt",
              reasoning: true,
              limit: { output: 128_000 },
            },
          },
        },
      }), {
        status: 200,
        headers: { "content-type": "application/json" },
      })

    //#when — refresh once, then reload through a brand-new store instance to
    // prove the snapshot round-trips through disk
    const snapshot = await store.refreshModelCapabilitiesCache({ fetchImpl })
    const reloadedStore = createModelCapabilitiesCacheStore(() => testCacheDir)

    //#then
    expect(snapshot.models["gpt-5.4"]?.limit?.output).toBe(128_000)
    expect(existsSync(sentinelPath)).toBe(true)
    expect(readFileSync(sentinelPath, "utf-8")).toBe(JSON.stringify({ keep: true }))
    expect(reloadedStore.readModelCapabilitiesCache()).toEqual(snapshot)
  })
})
|
||||
241
src/shared/model-capabilities-cache.ts
Normal file
241
src/shared/model-capabilities-cache.ts
Normal file
@@ -0,0 +1,241 @@
|
||||
import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"
|
||||
import { join } from "path"
|
||||
import * as dataPath from "./data-path"
|
||||
import { log } from "./logger"
|
||||
import type { ModelCapabilitiesSnapshot, ModelCapabilitiesSnapshotEntry } from "./model-capabilities"
|
||||
|
||||
// Canonical upstream catalog of model metadata (models.dev public API).
export const MODELS_DEV_SOURCE_URL = "https://models.dev/api.json"
// File name of the on-disk snapshot inside the cache directory.
const MODEL_CAPABILITIES_CACHE_FILE = "model-capabilities.json"
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
|
||||
function readBoolean(value: unknown): boolean | undefined {
|
||||
return typeof value === "boolean" ? value : undefined
|
||||
}
|
||||
|
||||
function readNumber(value: unknown): number | undefined {
|
||||
return typeof value === "number" ? value : undefined
|
||||
}
|
||||
|
||||
function readString(value: unknown): string | undefined {
|
||||
return typeof value === "string" ? value : undefined
|
||||
}
|
||||
|
||||
function readStringArray(value: unknown): string[] | undefined {
|
||||
if (!Array.isArray(value)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const result = value.filter((item): item is string => typeof item === "string")
|
||||
return result.length > 0 ? result : undefined
|
||||
}
|
||||
|
||||
function normalizeSnapshotEntry(rawModelID: string, rawModel: unknown): ModelCapabilitiesSnapshotEntry | undefined {
|
||||
if (!isRecord(rawModel)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const id = readString(rawModel.id) ?? rawModelID
|
||||
const family = readString(rawModel.family)
|
||||
const reasoning = readBoolean(rawModel.reasoning)
|
||||
const temperature = readBoolean(rawModel.temperature)
|
||||
const toolCall = readBoolean(rawModel.tool_call)
|
||||
|
||||
const rawModalities = isRecord(rawModel.modalities) ? rawModel.modalities : undefined
|
||||
const modalitiesInput = readStringArray(rawModalities?.input)
|
||||
const modalitiesOutput = readStringArray(rawModalities?.output)
|
||||
const modalities = modalitiesInput || modalitiesOutput
|
||||
? {
|
||||
...(modalitiesInput ? { input: modalitiesInput } : {}),
|
||||
...(modalitiesOutput ? { output: modalitiesOutput } : {}),
|
||||
}
|
||||
: undefined
|
||||
|
||||
const rawLimit = isRecord(rawModel.limit) ? rawModel.limit : undefined
|
||||
const limitContext = readNumber(rawLimit?.context)
|
||||
const limitInput = readNumber(rawLimit?.input)
|
||||
const limitOutput = readNumber(rawLimit?.output)
|
||||
const limit = limitContext !== undefined || limitInput !== undefined || limitOutput !== undefined
|
||||
? {
|
||||
...(limitContext !== undefined ? { context: limitContext } : {}),
|
||||
...(limitInput !== undefined ? { input: limitInput } : {}),
|
||||
...(limitOutput !== undefined ? { output: limitOutput } : {}),
|
||||
}
|
||||
: undefined
|
||||
|
||||
return {
|
||||
id,
|
||||
...(family ? { family } : {}),
|
||||
...(reasoning !== undefined ? { reasoning } : {}),
|
||||
...(temperature !== undefined ? { temperature } : {}),
|
||||
...(toolCall !== undefined ? { toolCall } : {}),
|
||||
...(modalities ? { modalities } : {}),
|
||||
...(limit ? { limit } : {}),
|
||||
}
|
||||
}
|
||||
|
||||
function mergeSnapshotEntries(
|
||||
existing: ModelCapabilitiesSnapshotEntry | undefined,
|
||||
incoming: ModelCapabilitiesSnapshotEntry,
|
||||
): ModelCapabilitiesSnapshotEntry {
|
||||
if (!existing) {
|
||||
return incoming
|
||||
}
|
||||
|
||||
return {
|
||||
...existing,
|
||||
...incoming,
|
||||
modalities: {
|
||||
...existing.modalities,
|
||||
...incoming.modalities,
|
||||
},
|
||||
limit: {
|
||||
...existing.limit,
|
||||
...incoming.limit,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
export function buildModelCapabilitiesSnapshotFromModelsDev(raw: unknown): ModelCapabilitiesSnapshot {
|
||||
const models: Record<string, ModelCapabilitiesSnapshotEntry> = {}
|
||||
const providers = isRecord(raw) ? raw : {}
|
||||
|
||||
for (const providerValue of Object.values(providers)) {
|
||||
if (!isRecord(providerValue)) {
|
||||
continue
|
||||
}
|
||||
|
||||
const providerModels = providerValue.models
|
||||
if (!isRecord(providerModels)) {
|
||||
continue
|
||||
}
|
||||
|
||||
for (const [rawModelID, rawModel] of Object.entries(providerModels)) {
|
||||
const normalizedEntry = normalizeSnapshotEntry(rawModelID, rawModel)
|
||||
if (!normalizedEntry) {
|
||||
continue
|
||||
}
|
||||
|
||||
models[normalizedEntry.id.toLowerCase()] = mergeSnapshotEntries(
|
||||
models[normalizedEntry.id.toLowerCase()],
|
||||
normalizedEntry,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
generatedAt: new Date().toISOString(),
|
||||
sourceUrl: MODELS_DEV_SOURCE_URL,
|
||||
models,
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchModelCapabilitiesSnapshot(args: {
|
||||
sourceUrl?: string
|
||||
fetchImpl?: typeof fetch
|
||||
} = {}): Promise<ModelCapabilitiesSnapshot> {
|
||||
const sourceUrl = args.sourceUrl ?? MODELS_DEV_SOURCE_URL
|
||||
const fetchImpl = args.fetchImpl ?? fetch
|
||||
const response = await fetchImpl(sourceUrl)
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`models.dev fetch failed with ${response.status}`)
|
||||
}
|
||||
|
||||
const raw = await response.json()
|
||||
const snapshot = buildModelCapabilitiesSnapshotFromModelsDev(raw)
|
||||
return {
|
||||
...snapshot,
|
||||
sourceUrl,
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Creates a snapshot store backed by `<cacheDir>/model-capabilities.json`
 * with an in-memory memo of the last read/written snapshot.
 *
 * @param getCacheDir Resolves the cache directory lazily, so a store can be
 *                    created before the directory exists; injectable in tests.
 */
export function createModelCapabilitiesCacheStore(
  getCacheDir: () => string = dataPath.getOmoOpenCodeCacheDir,
) {
  // undefined = never read; null = read attempted, no usable cache;
  // otherwise the memoized snapshot.
  let memSnapshot: ModelCapabilitiesSnapshot | null | undefined

  function getCacheFilePath(): string {
    return join(getCacheDir(), MODEL_CAPABILITIES_CACHE_FILE)
  }

  function ensureCacheDir(): void {
    const cacheDir = getCacheDir()
    if (!existsSync(cacheDir)) {
      mkdirSync(cacheDir, { recursive: true })
    }
  }

  /**
   * Reads the snapshot from disk, memoizing the result (including a missing
   * or corrupt cache, memoized as null).
   * NOTE(review): a memoized null sticks for the store's lifetime even if the
   * file appears on disk later (e.g. written by another process) — only a
   * write through THIS store updates the memo. Confirm this is intended.
   */
  function readModelCapabilitiesCache(): ModelCapabilitiesSnapshot | null {
    if (memSnapshot !== undefined) {
      return memSnapshot
    }

    const cacheFile = getCacheFilePath()
    if (!existsSync(cacheFile)) {
      memSnapshot = null
      log("[model-capabilities-cache] Cache file not found", { cacheFile })
      return null
    }

    try {
      const content = readFileSync(cacheFile, "utf-8")
      // The parsed JSON is trusted as a valid snapshot — no runtime
      // validation is performed on the cached file's shape.
      const snapshot = JSON.parse(content) as ModelCapabilitiesSnapshot
      memSnapshot = snapshot
      log("[model-capabilities-cache] Read cache", {
        modelCount: Object.keys(snapshot.models).length,
        generatedAt: snapshot.generatedAt,
      })
      return snapshot
    } catch (error) {
      // Unreadable/corrupt cache degrades to "no cache" rather than throwing.
      memSnapshot = null
      log("[model-capabilities-cache] Error reading cache", { error: String(error) })
      return null
    }
  }

  // Disk check only — deliberately bypasses the in-memory memo.
  function hasModelCapabilitiesCache(): boolean {
    return existsSync(getCacheFilePath())
  }

  /** Persists the snapshot (pretty-printed, trailing newline) and memoizes it. */
  function writeModelCapabilitiesCache(snapshot: ModelCapabilitiesSnapshot): void {
    ensureCacheDir()
    const cacheFile = getCacheFilePath()

    writeFileSync(cacheFile, JSON.stringify(snapshot, null, 2) + "\n")
    memSnapshot = snapshot
    log("[model-capabilities-cache] Cache written", {
      modelCount: Object.keys(snapshot.models).length,
      generatedAt: snapshot.generatedAt,
    })
  }

  /** Fetches a fresh snapshot from models.dev and writes it through the cache. */
  async function refreshModelCapabilitiesCache(args: {
    sourceUrl?: string
    fetchImpl?: typeof fetch
  } = {}): Promise<ModelCapabilitiesSnapshot> {
    const snapshot = await fetchModelCapabilitiesSnapshot(args)
    writeModelCapabilitiesCache(snapshot)
    return snapshot
  }

  return {
    readModelCapabilitiesCache,
    hasModelCapabilitiesCache,
    writeModelCapabilitiesCache,
    refreshModelCapabilitiesCache,
  }
}
|
||||
|
||||
// Module-level default store bound to the standard oh-my-opencode cache
// directory; shares one in-memory memo across all default-export callers.
const defaultModelCapabilitiesCacheStore = createModelCapabilitiesCacheStore(
  () => dataPath.getOmoOpenCodeCacheDir(),
)

// Convenience re-exports so callers can use the default store without
// constructing one themselves.
export const {
  readModelCapabilitiesCache,
  hasModelCapabilitiesCache,
  writeModelCapabilitiesCache,
  refreshModelCapabilitiesCache,
} = defaultModelCapabilitiesCacheStore
|
||||
159
src/shared/model-capabilities.test.ts
Normal file
159
src/shared/model-capabilities.test.ts
Normal file
@@ -0,0 +1,159 @@
|
||||
import { describe, expect, test } from "bun:test"

import {
  getModelCapabilities,
  type ModelCapabilitiesSnapshot,
} from "./model-capabilities"

describe("getModelCapabilities", () => {
  // Stand-in for the build-time snapshot so tests don't depend on the real
  // generated JSON.
  const bundledSnapshot: ModelCapabilitiesSnapshot = {
    generatedAt: "2026-03-25T00:00:00.000Z",
    sourceUrl: "https://models.dev/api.json",
    models: {
      "claude-opus-4-6": {
        id: "claude-opus-4-6",
        family: "claude-opus",
        reasoning: true,
        temperature: true,
        modalities: {
          input: ["text", "image", "pdf"],
          output: ["text"],
        },
        limit: {
          context: 1_000_000,
          output: 128_000,
        },
        toolCall: true,
      },
      "gemini-3.1-pro-preview": {
        id: "gemini-3.1-pro-preview",
        family: "gemini",
        reasoning: true,
        temperature: true,
        modalities: {
          input: ["text", "image"],
          output: ["text"],
        },
        limit: {
          context: 1_000_000,
          output: 65_000,
        },
      },
      "gpt-5.4": {
        id: "gpt-5.4",
        family: "gpt",
        reasoning: true,
        temperature: false,
        modalities: {
          input: ["text", "image", "pdf"],
          output: ["text"],
        },
        limit: {
          context: 1_050_000,
          output: 128_000,
        },
      },
    },
  }

  test("uses runtime metadata before snapshot data", () => {
    // Runtime metadata supplies variants; everything else comes from snapshot.
    const result = getModelCapabilities({
      providerID: "anthropic",
      modelID: "claude-opus-4-6",
      runtimeModel: {
        variants: {
          low: {},
          medium: {},
          high: {},
        },
      },
      bundledSnapshot,
    })

    expect(result).toMatchObject({
      canonicalModelID: "claude-opus-4-6",
      family: "claude-opus",
      variants: ["low", "medium", "high"],
      supportsThinking: true,
      supportsTemperature: true,
      maxOutputTokens: 128_000,
      toolCall: true,
    })
  })

  test("normalizes thinking suffix aliases before snapshot lookup", () => {
    // `-thinking` suffix is stripped to find the canonical snapshot entry.
    const result = getModelCapabilities({
      providerID: "anthropic",
      modelID: "claude-opus-4-6-thinking",
      bundledSnapshot,
    })

    expect(result).toMatchObject({
      canonicalModelID: "claude-opus-4-6",
      family: "claude-opus",
      supportsThinking: true,
      supportsTemperature: true,
      maxOutputTokens: 128_000,
    })
  })

  test("maps local gemini aliases to canonical models.dev entries", () => {
    // `gemini-3.1-pro-high` is a local alias routed via MODEL_ID_OVERRIDES.
    const result = getModelCapabilities({
      providerID: "google",
      modelID: "gemini-3.1-pro-high",
      bundledSnapshot,
    })

    expect(result).toMatchObject({
      canonicalModelID: "gemini-3.1-pro-preview",
      family: "gemini",
      supportsThinking: true,
      supportsTemperature: true,
      maxOutputTokens: 65_000,
    })
  })

  test("prefers runtime models.dev cache over bundled snapshot", () => {
    // Runtime snapshot shadows the bundled one for the same model ID.
    const runtimeSnapshot: ModelCapabilitiesSnapshot = {
      ...bundledSnapshot,
      models: {
        ...bundledSnapshot.models,
        "gpt-5.4": {
          ...bundledSnapshot.models["gpt-5.4"],
          limit: {
            context: 1_050_000,
            output: 64_000,
          },
        },
      },
    }

    const result = getModelCapabilities({
      providerID: "openai",
      modelID: "gpt-5.4",
      bundledSnapshot,
      runtimeSnapshot,
    })

    expect(result).toMatchObject({
      canonicalModelID: "gpt-5.4",
      maxOutputTokens: 64_000,
      supportsTemperature: false,
    })
  })

  test("falls back to heuristic family rules when no snapshot entry exists", () => {
    // o3-mini is absent from the snapshot; heuristics classify it by pattern.
    const result = getModelCapabilities({
      providerID: "openai",
      modelID: "o3-mini",
      bundledSnapshot,
    })

    expect(result).toMatchObject({
      canonicalModelID: "o3-mini",
      family: "openai-reasoning",
      variants: ["low", "medium", "high"],
      reasoningEfforts: ["none", "minimal", "low", "medium", "high"],
    })
  })
})
|
||||
228
src/shared/model-capabilities.ts
Normal file
228
src/shared/model-capabilities.ts
Normal file
@@ -0,0 +1,228 @@
|
||||
import bundledModelCapabilitiesSnapshotJson from "../generated/model-capabilities.generated.json"
|
||||
import { findProviderModelMetadata, type ModelMetadata } from "./connected-providers-cache"
|
||||
import { detectHeuristicModelFamily } from "./model-capability-heuristics"
|
||||
|
||||
// One model's capability row as stored in a snapshot. Everything except `id`
// is optional because models.dev data is sparse and provider-specific.
export type ModelCapabilitiesSnapshotEntry = {
  id: string
  family?: string
  reasoning?: boolean
  temperature?: boolean
  toolCall?: boolean
  modalities?: {
    input?: string[]
    output?: string[]
  }
  limit?: {
    context?: number
    input?: number
    output?: number
  }
}

// A full capability catalog: when it was generated, which URL it came from,
// and entries keyed by lowercased model ID.
export type ModelCapabilitiesSnapshot = {
  generatedAt: string
  sourceUrl: string
  models: Record<string, ModelCapabilitiesSnapshotEntry>
}

// Resolved view returned by `getModelCapabilities` after combining runtime
// metadata, snapshots, overrides, and heuristics.
export type ModelCapabilities = {
  requestedModelID: string
  canonicalModelID: string
  family?: string
  variants?: string[]
  reasoningEfforts?: string[]
  reasoning?: boolean
  supportsThinking?: boolean
  supportsTemperature?: boolean
  supportsTopP?: boolean
  maxOutputTokens?: number
  toolCall?: boolean
  modalities?: {
    input?: string[]
    output?: string[]
  }
}

// Lookup inputs; both snapshots are injectable for testing, and
// `runtimeModel` bypasses the connected-providers cache lookup.
type GetModelCapabilitiesInput = {
  providerID: string
  modelID: string
  runtimeModel?: ModelMetadata | Record<string, unknown>
  runtimeSnapshot?: ModelCapabilitiesSnapshot
  bundledSnapshot?: ModelCapabilitiesSnapshot
}

// Hand-maintained per-model corrections applied before snapshot lookup.
type ModelCapabilityOverride = {
  canonicalModelID?: string
  variants?: string[]
  reasoningEfforts?: string[]
  supportsThinking?: boolean
  supportsTemperature?: boolean
  supportsTopP?: boolean
}

// Alias table: local/UI model IDs mapped onto canonical models.dev IDs.
// Keys must be lowercase — lookups go through normalizeLookupModelID.
const MODEL_ID_OVERRIDES: Record<string, ModelCapabilityOverride> = {
  "claude-opus-4-6-thinking": { canonicalModelID: "claude-opus-4-6" },
  "claude-sonnet-4-6-thinking": { canonicalModelID: "claude-sonnet-4-6" },
  "claude-opus-4-5-thinking": { canonicalModelID: "claude-opus-4-5-20251101" },
  "gpt-5.3-codex-spark": { canonicalModelID: "gpt-5.3-codex" },
  "gemini-3.1-pro-high": { canonicalModelID: "gemini-3.1-pro-preview" },
  "gemini-3.1-pro-low": { canonicalModelID: "gemini-3.1-pro-preview" },
  "gemini-3-pro-high": { canonicalModelID: "gemini-3-pro-preview" },
  "gemini-3-pro-low": { canonicalModelID: "gemini-3-pro-preview" },
}
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
|
||||
function normalizeLookupModelID(modelID: string): string {
|
||||
return modelID.trim().toLowerCase()
|
||||
}
|
||||
|
||||
function readBoolean(value: unknown): boolean | undefined {
|
||||
return typeof value === "boolean" ? value : undefined
|
||||
}
|
||||
|
||||
function readNumber(value: unknown): number | undefined {
|
||||
return typeof value === "number" ? value : undefined
|
||||
}
|
||||
|
||||
function readStringArray(value: unknown): string[] | undefined {
|
||||
if (!Array.isArray(value)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const strings = value.filter((item): item is string => typeof item === "string")
|
||||
return strings.length > 0 ? strings : undefined
|
||||
}
|
||||
|
||||
function normalizeVariantKeys(value: unknown): string[] | undefined {
|
||||
if (!isRecord(value)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const variants = Object.keys(value).map((variant) => variant.toLowerCase())
|
||||
return variants.length > 0 ? variants : undefined
|
||||
}
|
||||
|
||||
function normalizeModalities(value: unknown): ModelCapabilities["modalities"] | undefined {
|
||||
if (!isRecord(value)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const input = readStringArray(value.input)
|
||||
const output = readStringArray(value.output)
|
||||
|
||||
if (!input && !output) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return {
|
||||
...(input ? { input } : {}),
|
||||
...(output ? { output } : {}),
|
||||
}
|
||||
}
|
||||
|
||||
// Bridges the statically-imported generated-JSON module type to the
// hand-written snapshot type. Runtime no-op; the assertion is the point.
// NOTE(review): this trusts that the generated JSON matches the type — no
// runtime validation is performed here.
function normalizeSnapshot(snapshot: ModelCapabilitiesSnapshot | typeof bundledModelCapabilitiesSnapshotJson): ModelCapabilitiesSnapshot {
  return snapshot as ModelCapabilitiesSnapshot
}
|
||||
|
||||
function getCanonicalModelID(modelID: string): string {
|
||||
const normalizedModelID = normalizeLookupModelID(modelID)
|
||||
const override = MODEL_ID_OVERRIDES[normalizedModelID]
|
||||
if (override?.canonicalModelID) {
|
||||
return override.canonicalModelID
|
||||
}
|
||||
|
||||
if (normalizedModelID.startsWith("claude-") && normalizedModelID.endsWith("-thinking")) {
|
||||
return normalizedModelID.replace(/-thinking$/i, "")
|
||||
}
|
||||
|
||||
return normalizedModelID
|
||||
}
|
||||
|
||||
function getOverride(modelID: string): ModelCapabilityOverride | undefined {
|
||||
return MODEL_ID_OVERRIDES[normalizeLookupModelID(modelID)]
|
||||
}
|
||||
|
||||
function readRuntimeModelLimitOutput(runtimeModel: Record<string, unknown> | undefined): number | undefined {
|
||||
if (!runtimeModel) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const limit = runtimeModel.limit
|
||||
if (!isRecord(limit)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return readNumber(limit.output)
|
||||
}
|
||||
|
||||
function readRuntimeModelBoolean(runtimeModel: Record<string, unknown> | undefined, keys: string[]): boolean | undefined {
|
||||
if (!runtimeModel) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
for (const key of keys) {
|
||||
const value = runtimeModel[key]
|
||||
if (typeof value === "boolean") {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
function readRuntimeModel(runtimeModel: ModelMetadata | Record<string, unknown> | undefined): Record<string, unknown> | undefined {
|
||||
return isRecord(runtimeModel) ? runtimeModel : undefined
|
||||
}
|
||||
|
||||
// Snapshot baked into the build by `build:model-capabilities`; loaded once
// at module-import time.
const bundledModelCapabilitiesSnapshot = normalizeSnapshot(bundledModelCapabilitiesSnapshotJson)

/** Exposes the build-time snapshot (no disk or network access). */
export function getBundledModelCapabilitiesSnapshot(): ModelCapabilitiesSnapshot {
  return bundledModelCapabilitiesSnapshot
}
|
||||
|
||||
/**
 * Resolves the effective capabilities for one provider/model pair by layering
 * (roughly highest-to-lowest precedence, per field):
 * runtime provider metadata → explicit overrides → runtime snapshot →
 * bundled snapshot → heuristic family rules.
 *
 * When `input.runtimeModel` is absent, metadata is looked up in the
 * connected-providers cache on disk.
 */
export function getModelCapabilities(input: GetModelCapabilitiesInput): ModelCapabilities {
  const requestedModelID = normalizeLookupModelID(input.modelID)
  const canonicalModelID = getCanonicalModelID(input.modelID)
  const override = getOverride(input.modelID)
  const runtimeModel = readRuntimeModel(
    input.runtimeModel ?? findProviderModelMetadata(input.providerID, input.modelID),
  )
  const runtimeSnapshot = input.runtimeSnapshot
  const bundledSnapshot = input.bundledSnapshot ?? bundledModelCapabilitiesSnapshot
  // Runtime snapshot shadows the bundled one for the same canonical ID.
  const snapshotEntry = runtimeSnapshot?.models?.[canonicalModelID] ?? bundledSnapshot.models[canonicalModelID]
  const heuristicFamily = detectHeuristicModelFamily(canonicalModelID)
  const runtimeVariants = normalizeVariantKeys(runtimeModel?.variants)

  return {
    requestedModelID,
    canonicalModelID,
    family: snapshotEntry?.family ?? heuristicFamily?.family,
    variants: runtimeVariants ?? override?.variants ?? heuristicFamily?.variants,
    reasoningEfforts: override?.reasoningEfforts ?? heuristicFamily?.reasoningEfforts,
    reasoning: readRuntimeModelBoolean(runtimeModel, ["reasoning"]) ?? snapshotEntry?.reasoning,
    // NOTE(review): unlike the other fields, supportsThinking consults the
    // heuristic family BEFORE runtime metadata and the snapshot. If a
    // heuristic row sets supportsThinking it wins over live data — confirm
    // this ordering is intentional.
    supportsThinking:
      override?.supportsThinking
      ?? heuristicFamily?.supportsThinking
      ?? readRuntimeModelBoolean(runtimeModel, ["reasoning"])
      ?? snapshotEntry?.reasoning,
    supportsTemperature:
      readRuntimeModelBoolean(runtimeModel, ["temperature"])
      ?? override?.supportsTemperature
      ?? snapshotEntry?.temperature,
    // Snapshots carry no topP flag, so there is no snapshot fallback here.
    supportsTopP:
      readRuntimeModelBoolean(runtimeModel, ["topP", "top_p"])
      ?? override?.supportsTopP,
    maxOutputTokens:
      readRuntimeModelLimitOutput(runtimeModel)
      ?? snapshotEntry?.limit?.output,
    toolCall:
      readRuntimeModelBoolean(runtimeModel, ["toolCall", "tool_call"])
      ?? snapshotEntry?.toolCall,
    modalities:
      normalizeModalities(runtimeModel?.modalities)
      ?? snapshotEntry?.modalities,
  }
}
|
||||
93
src/shared/model-capability-heuristics.ts
Normal file
93
src/shared/model-capability-heuristics.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
import { normalizeModelID } from "./model-normalization"
|
||||
|
||||
// One row of the family-detection registry: either a regex `pattern` or a
// substring `includes` list identifies the family; the remaining fields are
// the default capabilities assumed for that family.
export type HeuristicModelFamilyDefinition = {
  family: string
  includes?: string[]
  pattern?: RegExp
  variants?: string[]
  reasoningEfforts?: string[]
  supportsThinking?: boolean
}

// Fallback capability rules used when a model has no snapshot entry.
// Order matters: more-specific rows must precede broader ones (claude-opus
// before the generic claude row, gpt-5 before gpt-legacy).
export const HEURISTIC_MODEL_FAMILY_REGISTRY: ReadonlyArray<HeuristicModelFamilyDefinition> = [
  {
    family: "claude-opus",
    pattern: /claude(?:-\d+(?:-\d+)*)?-opus/,
    variants: ["low", "medium", "high", "max"],
    supportsThinking: true,
  },
  {
    family: "claude-non-opus",
    includes: ["claude"],
    variants: ["low", "medium", "high"],
    supportsThinking: true,
  },
  {
    // OpenAI o-series (o1, o3, o3-mini, …): bare "o" + digit prefix.
    family: "openai-reasoning",
    pattern: /^o\d(?:$|-)/,
    variants: ["low", "medium", "high"],
    reasoningEfforts: ["none", "minimal", "low", "medium", "high"],
  },
  {
    family: "gpt-5",
    includes: ["gpt-5"],
    variants: ["low", "medium", "high", "xhigh", "max"],
    reasoningEfforts: ["none", "minimal", "low", "medium", "high", "xhigh"],
  },
  {
    family: "gpt-legacy",
    includes: ["gpt"],
    variants: ["low", "medium", "high"],
  },
  {
    family: "gemini",
    includes: ["gemini"],
    variants: ["low", "medium", "high"],
  },
  {
    family: "kimi",
    includes: ["kimi", "k2"],
    variants: ["low", "medium", "high"],
  },
  {
    family: "glm",
    includes: ["glm"],
    variants: ["low", "medium", "high"],
  },
  {
    family: "minimax",
    includes: ["minimax"],
    variants: ["low", "medium", "high"],
  },
  {
    family: "deepseek",
    includes: ["deepseek"],
    variants: ["low", "medium", "high"],
  },
  {
    family: "mistral",
    includes: ["mistral", "codestral"],
    variants: ["low", "medium", "high"],
  },
  {
    family: "llama",
    includes: ["llama"],
    variants: ["low", "medium", "high"],
  },
]
|
||||
|
||||
export function detectHeuristicModelFamily(modelID: string): HeuristicModelFamilyDefinition | undefined {
|
||||
const normalizedModelID = normalizeModelID(modelID).toLowerCase()
|
||||
|
||||
for (const definition of HEURISTIC_MODEL_FAMILY_REGISTRY) {
|
||||
if (definition.pattern?.test(normalizedModelID)) {
|
||||
return definition
|
||||
}
|
||||
|
||||
if (definition.includes?.some((value) => normalizedModelID.includes(value))) {
|
||||
return definition
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
@@ -418,6 +418,63 @@ describe("resolveCompatibleModelSettings", () => {
|
||||
])
|
||||
})
|
||||
|
||||
test("drops unsupported temperature when capability metadata disables it", () => {
|
||||
const result = resolveCompatibleModelSettings({
|
||||
providerID: "openai",
|
||||
modelID: "gpt-5.4",
|
||||
desired: { temperature: 0.7 },
|
||||
capabilities: { supportsTemperature: false },
|
||||
})
|
||||
|
||||
expect(result.temperature).toBeUndefined()
|
||||
expect(result.changes).toEqual([
|
||||
{
|
||||
field: "temperature",
|
||||
from: "0.7",
|
||||
to: undefined,
|
||||
reason: "unsupported-by-model-metadata",
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
test("drops thinking when model capabilities say it is unsupported", () => {
|
||||
const result = resolveCompatibleModelSettings({
|
||||
providerID: "openai",
|
||||
modelID: "gpt-5.4",
|
||||
desired: { thinking: { type: "enabled", budgetTokens: 4096 } },
|
||||
capabilities: { supportsThinking: false },
|
||||
})
|
||||
|
||||
expect(result.thinking).toBeUndefined()
|
||||
expect(result.changes).toEqual([
|
||||
{
|
||||
field: "thinking",
|
||||
from: "{\"type\":\"enabled\",\"budgetTokens\":4096}",
|
||||
to: undefined,
|
||||
reason: "unsupported-by-model-metadata",
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
test("clamps maxTokens to the model output limit", () => {
|
||||
const result = resolveCompatibleModelSettings({
|
||||
providerID: "openai",
|
||||
modelID: "gpt-5.4",
|
||||
desired: { maxTokens: 200_000 },
|
||||
capabilities: { maxOutputTokens: 128_000 },
|
||||
})
|
||||
|
||||
expect(result.maxTokens).toBe(128_000)
|
||||
expect(result.changes).toEqual([
|
||||
{
|
||||
field: "maxTokens",
|
||||
from: "200000",
|
||||
to: "128000",
|
||||
reason: "max-output-limit",
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
// Passthrough: undefined desired values produce no changes
|
||||
test("no-op when desired settings are empty", () => {
|
||||
const result = resolveCompatibleModelSettings({
|
||||
|
||||
@@ -1,84 +1,56 @@
|
||||
import { normalizeModelID } from "./model-normalization"
|
||||
import { detectHeuristicModelFamily } from "./model-capability-heuristics"
|
||||
|
||||
type CompatibilityField = "variant" | "reasoningEffort"
|
||||
type CompatibilityField = "variant" | "reasoningEffort" | "temperature" | "topP" | "maxTokens" | "thinking"
|
||||
|
||||
type DesiredModelSettings = {
|
||||
variant?: string
|
||||
reasoningEffort?: string
|
||||
temperature?: number
|
||||
topP?: number
|
||||
maxTokens?: number
|
||||
thinking?: Record<string, unknown>
|
||||
}
|
||||
|
||||
type VariantCapabilities = {
|
||||
type CompatibilityCapabilities = {
|
||||
variants?: string[]
|
||||
reasoningEfforts?: string[]
|
||||
supportsTemperature?: boolean
|
||||
supportsTopP?: boolean
|
||||
maxOutputTokens?: number
|
||||
supportsThinking?: boolean
|
||||
}
|
||||
|
||||
export type ModelSettingsCompatibilityInput = {
|
||||
providerID: string
|
||||
modelID: string
|
||||
desired: DesiredModelSettings
|
||||
capabilities?: VariantCapabilities
|
||||
capabilities?: CompatibilityCapabilities
|
||||
}
|
||||
|
||||
export type ModelSettingsCompatibilityChange = {
|
||||
field: CompatibilityField
|
||||
from: string
|
||||
to?: string
|
||||
reason: "unsupported-by-model-family" | "unknown-model-family" | "unsupported-by-model-metadata"
|
||||
reason:
|
||||
| "unsupported-by-model-family"
|
||||
| "unknown-model-family"
|
||||
| "unsupported-by-model-metadata"
|
||||
| "max-output-limit"
|
||||
}
|
||||
|
||||
export type ModelSettingsCompatibilityResult = {
|
||||
variant?: string
|
||||
reasoningEffort?: string
|
||||
temperature?: number
|
||||
topP?: number
|
||||
maxTokens?: number
|
||||
thinking?: Record<string, unknown>
|
||||
changes: ModelSettingsCompatibilityChange[]
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Unified model family registry — detection rules + capabilities in ONE row.
|
||||
// New model family = one entry. Zero code changes anywhere else.
|
||||
// Order matters: more-specific patterns first (claude-opus before claude).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type FamilyDefinition = {
|
||||
/** Substring(s) in normalised model ID that identify this family (OR) */
|
||||
includes?: string[]
|
||||
/** Regex when substring matching isn't enough */
|
||||
pattern?: RegExp
|
||||
/** Supported variant levels (ordered low -> max) */
|
||||
variants: string[]
|
||||
/** Supported reasoning-effort levels. Omit = not supported. */
|
||||
reasoningEffort?: string[]
|
||||
}
|
||||
|
||||
const MODEL_FAMILY_REGISTRY: ReadonlyArray<readonly [string, FamilyDefinition]> = [
|
||||
["claude-opus", { pattern: /claude(?:-\d+(?:-\d+)*)?-opus/, variants: ["low", "medium", "high", "max"] }],
|
||||
["claude-non-opus", { includes: ["claude"], variants: ["low", "medium", "high"] }],
|
||||
["openai-reasoning", { pattern: /^o\d(?:$|-)/, variants: ["low", "medium", "high"], reasoningEffort: ["none", "minimal", "low", "medium", "high"] }],
|
||||
["gpt-5", { includes: ["gpt-5"], variants: ["low", "medium", "high", "xhigh", "max"], reasoningEffort: ["none", "minimal", "low", "medium", "high", "xhigh"] }],
|
||||
["gpt-legacy", { includes: ["gpt"], variants: ["low", "medium", "high"] }],
|
||||
["gemini", { includes: ["gemini"], variants: ["low", "medium", "high"] }],
|
||||
["kimi", { includes: ["kimi", "k2"], variants: ["low", "medium", "high"] }],
|
||||
["glm", { includes: ["glm"], variants: ["low", "medium", "high"] }],
|
||||
["minimax", { includes: ["minimax"], variants: ["low", "medium", "high"] }],
|
||||
["deepseek", { includes: ["deepseek"], variants: ["low", "medium", "high"] }],
|
||||
["mistral", { includes: ["mistral", "codestral"], variants: ["low", "medium", "high"] }],
|
||||
["llama", { includes: ["llama"], variants: ["low", "medium", "high"] }],
|
||||
]
|
||||
|
||||
const VARIANT_LADDER = ["low", "medium", "high", "xhigh", "max"]
|
||||
const REASONING_LADDER = ["none", "minimal", "low", "medium", "high", "xhigh"]
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model family detection — single pass over the registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function detectFamily(_providerID: string, modelID: string): FamilyDefinition | undefined {
|
||||
const model = normalizeModelID(modelID).toLowerCase()
|
||||
for (const [, def] of MODEL_FAMILY_REGISTRY) {
|
||||
if (def.pattern?.test(model)) return def
|
||||
if (def.includes?.some((s) => model.includes(s))) return def
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Generic resolution — one function for both fields
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -96,13 +68,20 @@ function downgradeWithinLadder(value: string, allowed: string[], ladder: string[
|
||||
return undefined
|
||||
}
|
||||
|
||||
function normalizeCapabilitiesVariants(capabilities: VariantCapabilities | undefined): string[] | undefined {
|
||||
function normalizeCapabilitiesVariants(capabilities: CompatibilityCapabilities | undefined): string[] | undefined {
|
||||
if (!capabilities?.variants || capabilities.variants.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
return capabilities.variants.map((v) => v.toLowerCase())
|
||||
}
|
||||
|
||||
function normalizeCapabilitiesReasoningEfforts(capabilities: CompatibilityCapabilities | undefined): string[] | undefined {
|
||||
if (!capabilities?.reasoningEfforts || capabilities.reasoningEfforts.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
return capabilities.reasoningEfforts.map((value) => value.toLowerCase())
|
||||
}
|
||||
|
||||
type FieldResolution = { value?: string; reason?: ModelSettingsCompatibilityChange["reason"] }
|
||||
|
||||
function resolveField(
|
||||
@@ -146,10 +125,11 @@ function resolveField(
|
||||
export function resolveCompatibleModelSettings(
|
||||
input: ModelSettingsCompatibilityInput,
|
||||
): ModelSettingsCompatibilityResult {
|
||||
const family = detectFamily(input.providerID, input.modelID)
|
||||
const family = detectHeuristicModelFamily(input.modelID)
|
||||
const familyKnown = family !== undefined
|
||||
const changes: ModelSettingsCompatibilityChange[] = []
|
||||
const metadataVariants = normalizeCapabilitiesVariants(input.capabilities)
|
||||
const metadataReasoningEfforts = normalizeCapabilitiesReasoningEfforts(input.capabilities)
|
||||
|
||||
let variant = input.desired.variant
|
||||
if (variant !== undefined) {
|
||||
@@ -164,12 +144,68 @@ export function resolveCompatibleModelSettings(
|
||||
let reasoningEffort = input.desired.reasoningEffort
|
||||
if (reasoningEffort !== undefined) {
|
||||
const normalized = reasoningEffort.toLowerCase()
|
||||
const resolved = resolveField(normalized, family?.reasoningEffort, REASONING_LADDER, familyKnown)
|
||||
const resolved = resolveField(normalized, family?.reasoningEfforts, REASONING_LADDER, familyKnown, metadataReasoningEfforts)
|
||||
if (resolved.value !== normalized && resolved.reason) {
|
||||
changes.push({ field: "reasoningEffort", from: reasoningEffort, to: resolved.value, reason: resolved.reason })
|
||||
}
|
||||
reasoningEffort = resolved.value
|
||||
}
|
||||
|
||||
return { variant, reasoningEffort, changes }
|
||||
let temperature = input.desired.temperature
|
||||
if (temperature !== undefined && input.capabilities?.supportsTemperature === false) {
|
||||
changes.push({
|
||||
field: "temperature",
|
||||
from: String(temperature),
|
||||
to: undefined,
|
||||
reason: "unsupported-by-model-metadata",
|
||||
})
|
||||
temperature = undefined
|
||||
}
|
||||
|
||||
let topP = input.desired.topP
|
||||
if (topP !== undefined && input.capabilities?.supportsTopP === false) {
|
||||
changes.push({
|
||||
field: "topP",
|
||||
from: String(topP),
|
||||
to: undefined,
|
||||
reason: "unsupported-by-model-metadata",
|
||||
})
|
||||
topP = undefined
|
||||
}
|
||||
|
||||
let maxTokens = input.desired.maxTokens
|
||||
if (
|
||||
maxTokens !== undefined &&
|
||||
input.capabilities?.maxOutputTokens !== undefined &&
|
||||
maxTokens > input.capabilities.maxOutputTokens
|
||||
) {
|
||||
changes.push({
|
||||
field: "maxTokens",
|
||||
from: String(maxTokens),
|
||||
to: String(input.capabilities.maxOutputTokens),
|
||||
reason: "max-output-limit",
|
||||
})
|
||||
maxTokens = input.capabilities.maxOutputTokens
|
||||
}
|
||||
|
||||
let thinking = input.desired.thinking
|
||||
if (thinking !== undefined && input.capabilities?.supportsThinking === false) {
|
||||
changes.push({
|
||||
field: "thinking",
|
||||
from: JSON.stringify(thinking),
|
||||
to: undefined,
|
||||
reason: "unsupported-by-model-metadata",
|
||||
})
|
||||
thinking = undefined
|
||||
}
|
||||
|
||||
return {
|
||||
variant,
|
||||
reasoningEffort,
|
||||
...(input.desired.temperature !== undefined ? { temperature } : {}),
|
||||
...(input.desired.topP !== undefined ? { topP } : {}),
|
||||
...(input.desired.maxTokens !== undefined ? { maxTokens } : {}),
|
||||
...(input.desired.thinking !== undefined ? { thinking } : {}),
|
||||
changes,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user