feat: add models.dev-backed model capabilities

This commit is contained in:
Ravi Tharuma
2026-03-25 14:47:46 +01:00
parent 7a52639a1b
commit 2af9324400
29 changed files with 42264 additions and 114 deletions

View File

@@ -1,6 +1,6 @@
import { normalizeSDKResponse } from "../shared/normalize-sdk-response"
import { getSessionPromptParams } from "../shared/session-prompt-params-state"
import { resolveCompatibleModelSettings } from "../shared"
import { getModelCapabilities, resolveCompatibleModelSettings } from "../shared"
export type ChatParamsInput = {
sessionID: string
@@ -21,25 +21,6 @@ export type ChatParamsOutput = {
options: Record<string, unknown>
}
type ProviderListClient = {
provider?: {
list?: () => Promise<unknown>
}
}
type ProviderModelMetadata = {
variants?: Record<string, unknown>
}
type ProviderListEntry = {
id?: string
models?: Record<string, ProviderModelMetadata>
}
type ProviderListData = {
all?: ProviderListEntry[]
}
function isRecord(value: unknown): value is Record<string, unknown> {
return typeof value === "object" && value !== null
}
@@ -101,33 +82,9 @@ function isChatParamsOutput(raw: unknown): raw is ChatParamsOutput {
return isRecord(raw.options)
}
async function getVariantCapabilities(
client: ProviderListClient | undefined,
model: { providerID: string; modelID: string },
): Promise<string[] | undefined> {
const providerList = client?.provider?.list
if (typeof providerList !== "function") {
return undefined
}
try {
const response = await providerList()
const data = normalizeSDKResponse<ProviderListData>(response, {})
const providerEntry = data.all?.find((entry) => entry.id === model.providerID)
const variants = providerEntry?.models?.[model.modelID]?.variants
if (!variants) {
return undefined
}
return Object.keys(variants)
} catch {
return undefined
}
}
export function createChatParamsHandler(args: {
anthropicEffort: { "chat.params"?: (input: ChatParamsHookInput, output: ChatParamsOutput) => Promise<void> } | null
client?: ProviderListClient
client?: unknown
}): (input: unknown, output: unknown) => Promise<void> {
return async (input, output): Promise<void> => {
const normalizedInput = buildChatParamsInput(input)
@@ -150,7 +107,10 @@ export function createChatParamsHandler(args: {
}
}
const variantCapabilities = await getVariantCapabilities(args.client, normalizedInput.model)
const capabilities = getModelCapabilities({
providerID: normalizedInput.model.providerID,
modelID: normalizedInput.model.modelID,
})
const compatibility = resolveCompatibleModelSettings({
providerID: normalizedInput.model.providerID,
@@ -162,10 +122,12 @@ export function createChatParamsHandler(args: {
reasoningEffort: typeof output.options.reasoningEffort === "string"
? output.options.reasoningEffort
: undefined,
temperature: typeof output.temperature === "number" ? output.temperature : undefined,
topP: typeof output.topP === "number" ? output.topP : undefined,
maxTokens: typeof output.options.maxTokens === "number" ? output.options.maxTokens : undefined,
thinking: isRecord(output.options.thinking) ? output.options.thinking : undefined,
},
capabilities: {
variants: variantCapabilities,
},
capabilities,
})
if (normalizedInput.rawMessage) {
@@ -183,6 +145,38 @@ export function createChatParamsHandler(args: {
delete output.options.reasoningEffort
}
if ("temperature" in compatibility) {
if (compatibility.temperature !== undefined) {
output.temperature = compatibility.temperature
} else {
delete output.temperature
}
}
if ("topP" in compatibility) {
if (compatibility.topP !== undefined) {
output.topP = compatibility.topP
} else {
delete output.topP
}
}
if ("maxTokens" in compatibility) {
if (compatibility.maxTokens !== undefined) {
output.options.maxTokens = compatibility.maxTokens
} else {
delete output.options.maxTokens
}
}
if ("thinking" in compatibility) {
if (compatibility.thinking !== undefined) {
output.options.thinking = compatibility.thinking
} else {
delete output.options.thinking
}
}
await args.anthropicEffort?.["chat.params"]?.(normalizedInput, output)
}
}