From a77a16c4941b9b436ccd6f931b643b260a8463a1 Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Wed, 18 Mar 2026 14:21:27 +0100 Subject: [PATCH] feat(config): support object-style fallback_models with per-model settings Add support for object-style entries in fallback_models arrays, enabling per-model configuration of variant, reasoningEffort, temperature, top_p, maxTokens, and thinking settings. - Zod schema for FallbackModelObject with full validation - normalizeFallbackModels() and flattenToFallbackModelStrings() utilities - Provider-agnostic model resolution pipeline with fallback chain - Session prompt params state management - Fallback chain construction with prefix-match lookup - Integration across delegate-task, background-agent, and plugin layers --- src/agents/types.ts | 2 +- src/config/schema/fallback-models.ts | 22 +- src/features/background-agent/manager.test.ts | 58 +++ src/features/background-agent/manager.ts | 55 ++- src/features/background-agent/spawner.test.ts | 126 ++++--- src/features/background-agent/spawner.ts | 51 ++- src/features/background-agent/types.ts | 15 +- src/hooks/runtime-fallback/fallback-models.ts | 43 ++- src/plugin/chat-params.test.ts | 69 +++- src/plugin/chat-params.ts | 18 + src/plugin/event-compaction-agent.test.ts | 3 + src/plugin/event.test.ts | 40 +++ src/plugin/event.ts | 10 +- src/shared/fallback-chain-from-models.test.ts | 338 +++++++++++++++++- src/shared/fallback-chain-from-models.ts | 74 +++- src/shared/index.ts | 2 +- src/shared/known-variants.ts | 16 + src/shared/model-requirements.ts | 5 + src/shared/model-resolver.ts | 42 ++- .../session-prompt-params-state.test.ts | 65 ++++ src/shared/session-prompt-params-state.ts | 34 ++ src/tools/delegate-task/background-task.ts | 4 +- .../delegate-task/category-resolver.test.ts | 256 +++++++++++++ src/tools/delegate-task/category-resolver.ts | 36 +- src/tools/delegate-task/model-selection.ts | 12 +- .../delegate-task/model-string-parser.ts | 1 + .../delegate-task/subagent-resolver.test.ts | 241 +++++++++++++ src/tools/delegate-task/subagent-resolver.ts | 35 +- .../delegate-task/sync-prompt-sender.test.ts | 73 +++- src/tools/delegate-task/sync-prompt-sender.ts | 32 +- src/tools/delegate-task/sync-task.ts | 4 +- src/tools/delegate-task/tools.ts | 13 +- src/tools/delegate-task/types.ts | 13 +- .../delegate-task/unstable-agent-task.ts | 4 +- 34 files changed, 1686 insertions(+), 126 deletions(-) create mode 100644 src/shared/known-variants.ts create mode 100644 src/shared/session-prompt-params-state.test.ts create mode 100644 src/shared/session-prompt-params-state.ts diff --git a/src/agents/types.ts b/src/agents/types.ts index acf490007..1d975c182 100644 --- a/src/agents/types.ts +++ b/src/agents/types.ts @@ -123,7 +123,7 @@ export type AgentName = BuiltinAgentName; export type AgentOverrideConfig = Partial & { prompt_append?: string; variant?: string; - fallback_models?: string | string[]; + fallback_models?: string | (string | import("../config/schema/fallback-models").FallbackModelObject)[]; }; export type AgentOverrides = Partial< diff --git a/src/config/schema/fallback-models.ts b/src/config/schema/fallback-models.ts index f9c28f437..c2d8ae4d2 100644 --- a/src/config/schema/fallback-models.ts +++ b/src/config/schema/fallback-models.ts @@ -1,5 +1,25 @@ import { z } from "zod" -export const FallbackModelsSchema = z.union([z.string(), z.array(z.string())]) +export const FallbackModelObjectSchema = z.object({ + model: z.string(), + variant: z.string().optional(), + reasoningEffort: z.enum(["none", "minimal", "low", "medium", "high", "xhigh"]).optional(), + temperature: z.number().min(0).max(2).optional(), + top_p: z.number().min(0).max(1).optional(), + maxTokens: z.number().optional(), + thinking: z + .object({ + type: z.enum(["enabled", "disabled"]), + budgetTokens: z.number().optional(), + }) + .optional(), +}) + +export type FallbackModelObject = z.infer + +export const FallbackModelsSchema = z.union([ + z.string(), + z.array(z.union([z.string(), FallbackModelObjectSchema])), +]) export type FallbackModels = z.infer diff --git a/src/features/background-agent/manager.test.ts b/src/features/background-agent/manager.test.ts index 71d62f509..547a0e340 100644 --- a/src/features/background-agent/manager.test.ts +++ b/src/features/background-agent/manager.test.ts @@ -1,5 +1,6 @@ declare const require: (name: string) => any const { describe, test, expect, beforeEach, afterEach, spyOn } = require("bun:test") +import { getSessionPromptParams, clearSessionPromptParams } from "../../shared/session-prompt-params-state" import { tmpdir } from "node:os" import type { PluginInput } from "@opencode-ai/plugin" import type { BackgroundTask, ResumeInput } from "./types" @@ -1636,6 +1637,9 @@ describe("BackgroundManager.resume model persistence", () => { }) afterEach(() => { + clearSessionPromptParams("session-1") + clearSessionPromptParams("session-advanced") + clearSessionPromptParams("session-2") manager.shutdown() }) @@ -1671,6 +1675,60 @@ describe("BackgroundManager.resume model persistence", () => { expect(promptCalls[0].body.agent).toBe("explore") }) + test("should preserve promoted per-model settings when resuming a task", async () => { + // given - task resumed after fallback promotion + const taskWithAdvancedModel: BackgroundTask = { + id: "task-with-advanced-model", + sessionID: "session-advanced", + parentSessionID: "parent-session", + parentMessageID: "msg-1", + description: "task with advanced model settings", + prompt: "original prompt", + agent: "explore", + status: "completed", + startedAt: new Date(), + completedAt: new Date(), + model: { + providerID: "openai", + modelID: "gpt-5.4-preview", + variant: "minimal", + reasoningEffort: "high", + temperature: 0.25, + top_p: 0.55, + maxTokens: 8192, + thinking: { type: "disabled" }, + }, + concurrencyGroup: "explore", + } + getTaskMap(manager).set(taskWithAdvancedModel.id, taskWithAdvancedModel) + + // when + await manager.resume({ + sessionId: "session-advanced", + prompt: "continue the work", + parentSessionID: "parent-session-2", + parentMessageID: "msg-2", + }) + + // then + expect(promptCalls).toHaveLength(1) + expect(promptCalls[0].body.model).toEqual({ + providerID: "openai", + modelID: "gpt-5.4-preview", + }) + expect(promptCalls[0].body.variant).toBe("minimal") + expect(promptCalls[0].body.options).toBeUndefined() + expect(getSessionPromptParams("session-advanced")).toEqual({ + temperature: 0.25, + topP: 0.55, + options: { + reasoningEffort: "high", + thinking: { type: "disabled" }, + maxTokens: 8192, + }, + }) + }) + test("should NOT pass model when task has no model (backward compatibility)", async () => { // given - task without model (default behavior) const taskWithoutModel: BackgroundTask = { diff --git a/src/features/background-agent/manager.ts b/src/features/background-agent/manager.ts index c4ea7528b..d522e4061 100644 --- a/src/features/background-agent/manager.ts +++ b/src/features/background-agent/manager.ts @@ -16,6 +16,35 @@ import { createInternalAgentTextPart, } from "../../shared" import { setSessionTools } from "../../shared/session-tools-store" +import { setSessionPromptParams } from "../../shared/session-prompt-params-state" + +type PromptParamsModel = { + reasoningEffort?: string + thinking?: { type: "enabled" | "disabled"; budgetTokens?: number } + maxTokens?: number + temperature?: number + top_p?: number +} + +function applySessionPromptParams(sessionID: string, model: PromptParamsModel): void { + const promptOptions: Record = { + ...(model.reasoningEffort ? { reasoningEffort: model.reasoningEffort } : {}), + ...(model.thinking ? { thinking: model.thinking } : {}), + ...(model.maxTokens !== undefined ? { maxTokens: model.maxTokens } : {}), + } + + if ( + model.temperature !== undefined || + model.top_p !== undefined || + Object.keys(promptOptions).length > 0 + ) { + setSessionPromptParams(sessionID, { + ...(model.temperature !== undefined ? { temperature: model.temperature } : {}), + ...(model.top_p !== undefined ? { topP: model.top_p } : {}), + ...(Object.keys(promptOptions).length > 0 ? { options: promptOptions } : {}), + }) + } +} import { SessionCategoryRegistry } from "../../shared/session-category-registry" import { ConcurrencyManager } from "./concurrency" import type { BackgroundTaskConfig, TmuxConfig } from "../../config/schema" @@ -504,14 +533,20 @@ export class BackgroundManager { }) // Fire-and-forget prompt via promptAsync (no response body needed) - // Include model if caller provided one (e.g., from Sisyphus category configs) - // IMPORTANT: variant must be a top-level field in the body, NOT nested inside model - // OpenCode's PromptInput schema expects: { model: { providerID, modelID }, variant: "max" } + // OpenCode prompt payload accepts model provider/model IDs and top-level variant only. + // Temperature/topP and provider-specific options are applied through chat.params. const launchModel = input.model - ? { providerID: input.model.providerID, modelID: input.model.modelID } + ? { + providerID: input.model.providerID, + modelID: input.model.modelID, + } : undefined const launchVariant = input.model?.variant + if (input.model) { + applySessionPromptParams(sessionID, input.model) + } + promptWithModelSuggestionRetry(this.client, { path: { id: sessionID }, body: { @@ -782,13 +817,19 @@ export class BackgroundManager { }) // Fire-and-forget prompt via promptAsync (no response body needed) - // Include model if task has one (preserved from original launch with category config) - // variant must be top-level in body, not nested inside model (OpenCode PromptInput schema) + // Resume uses the same PromptInput contract as launch: model IDs plus top-level variant. const resumeModel = existingTask.model - ? { providerID: existingTask.model.providerID, modelID: existingTask.model.modelID } + ? { + providerID: existingTask.model.providerID, + modelID: existingTask.model.modelID, + } : undefined const resumeVariant = existingTask.model?.variant + if (existingTask.model) { + applySessionPromptParams(existingTask.sessionID!, existingTask.model) + } + this.client.session.promptAsync({ path: { id: existingTask.sessionID }, body: { diff --git a/src/features/background-agent/spawner.test.ts b/src/features/background-agent/spawner.test.ts index 0ff9b3f78..bb0aa0e4c 100644 --- a/src/features/background-agent/spawner.test.ts +++ b/src/features/background-agent/spawner.test.ts @@ -1,68 +1,96 @@ -import { describe, test, expect } from "bun:test" +import { describe, test, expect, mock, afterEach } from "bun:test" +import { startTask } from "./spawner" +import type { BackgroundTask } from "./types" +import { + clearSessionPromptParams, + getSessionPromptParams, +} from "../../shared/session-prompt-params-state" -import { createTask, startTask } from "./spawner" +describe("background-agent spawner fallback model promotion", () => { + afterEach(() => { + clearSessionPromptParams("session-123") + }) -describe("background-agent spawner.startTask", () => { - test("applies explicit child session permission rules when creating child session", async () => { + test("passes promoted fallback model settings through supported prompt channels", async () => { //#given - const createCalls: any[] = [] - const parentPermission = [ - { permission: "question", action: "allow" as const, pattern: "*" }, - { permission: "plan_enter", action: "deny" as const, pattern: "*" }, - ] - + let promptArgs: any const client = { session: { - get: async () => ({ data: { directory: "/parent/dir", permission: parentPermission } }), - create: async (args?: any) => { - createCalls.push(args) - return { data: { id: "ses_child" } } - }, - promptAsync: async () => ({}), + get: mock(async () => ({ data: { directory: "/tmp/test" } })), + create: mock(async () => ({ data: { id: "session-123" } })), + promptAsync: mock(async (input: any) => { + promptArgs = input + return { data: {} } + }), }, - } + } as any - const task = createTask({ + const concurrencyManager = { + release: mock(() => {}), + } as any + + const onTaskError = mock(() => {}) + + const task: BackgroundTask = { + id: "bg_test123", + status: "pending", + queuedAt: new Date(), description: "Test task", - prompt: "Do work", - agent: "explore", - parentSessionID: "ses_parent", - parentMessageID: "msg_parent", - }) - - const item = { - task, - input: { - description: task.description, - prompt: task.prompt, - agent: task.agent, - parentSessionID: task.parentSessionID, - parentMessageID: task.parentMessageID, - parentModel: task.parentModel, - parentAgent: task.parentAgent, - model: task.model, - sessionPermission: [ - { permission: "question", action: "deny", pattern: "*" }, - ], + prompt: "Do the thing", + agent: "oracle", + parentSessionID: "parent-1", + parentMessageID: "message-1", + model: { + providerID: "openai", + modelID: "gpt-5.4", + variant: "low", + reasoningEffort: "high", + temperature: 0.4, + top_p: 0.7, + maxTokens: 4096, + thinking: { type: "disabled" }, }, } - const ctx = { - client, - directory: "/fallback", - concurrencyManager: { release: () => {} }, - tmuxEnabled: false, - onTaskError: () => {}, + const input = { + description: "Test task", + prompt: "Do the thing", + agent: "oracle", + parentSessionID: "parent-1", + parentMessageID: "message-1", + model: task.model, } //#when - await startTask(item as any, ctx as any) + await startTask( + { task, input }, + { + client, + directory: "/tmp/test", + concurrencyManager, + tmuxEnabled: false, + onTaskError, + }, + ) + + await new Promise((resolve) => setTimeout(resolve, 0)) //#then - expect(createCalls).toHaveLength(1) - expect(createCalls[0]?.body?.permission).toEqual([ - { permission: "question", action: "deny", pattern: "*" }, - ]) + expect(promptArgs.body.model).toEqual({ + providerID: "openai", + modelID: "gpt-5.4", + }) + expect(promptArgs.body.variant).toBe("low") + expect(promptArgs.body.options).toBeUndefined() + expect(getSessionPromptParams("session-123")).toEqual({ + temperature: 0.4, + topP: 0.7, + options: { + reasoningEffort: "high", + thinking: { type: "disabled" }, + maxTokens: 4096, + }, + }) }) test("keeps agent when explicit model is configured", async () => { diff --git a/src/features/background-agent/spawner.ts b/src/features/background-agent/spawner.ts index c4f435720..13139c29d 100644 --- a/src/features/background-agent/spawner.ts +++ b/src/features/background-agent/spawner.ts @@ -2,6 +2,7 @@ import type { BackgroundTask, LaunchInput, ResumeInput } from "./types" import type { OpencodeClient, OnSubagentSessionCreated, QueueItem } from "./constants" import { TMUX_CALLBACK_DELAY_MS } from "./constants" import { log, getAgentToolRestrictions, promptWithModelSuggestionRetry, createInternalAgentTextPart } from "../../shared" +import { setSessionPromptParams } from "../../shared/session-prompt-params-state" import { subagentSessions } from "../claude-code-session-state" import { getTaskToastManager } from "../task-toast-manager" import { isInsideTmux } from "../../shared/tmux" @@ -128,10 +129,33 @@ export async function startTask( }) const launchModel = input.model - ? { providerID: input.model.providerID, modelID: input.model.modelID } + ? { + providerID: input.model.providerID, + modelID: input.model.modelID, + } : undefined const launchVariant = input.model?.variant + if (input.model) { + const promptOptions: Record = { + ...(input.model.reasoningEffort ? { reasoningEffort: input.model.reasoningEffort } : {}), + ...(input.model.thinking ? { thinking: input.model.thinking } : {}), + ...(input.model.maxTokens !== undefined ? { maxTokens: input.model.maxTokens } : {}), + } + + if ( + input.model.temperature !== undefined || + input.model.top_p !== undefined || + Object.keys(promptOptions).length > 0 + ) { + setSessionPromptParams(sessionID, { + ...(input.model.temperature !== undefined ? { temperature: input.model.temperature } : {}), + ...(input.model.top_p !== undefined ? { topP: input.model.top_p } : {}), + ...(Object.keys(promptOptions).length > 0 ? { options: promptOptions } : {}), + }) + } + } + promptWithModelSuggestionRetry(client, { path: { id: sessionID }, body: { @@ -213,10 +237,33 @@ export async function resumeTask( }) const resumeModel = task.model - ? { providerID: task.model.providerID, modelID: task.model.modelID } + ? { + providerID: task.model.providerID, + modelID: task.model.modelID, + } : undefined const resumeVariant = task.model?.variant + if (task.model) { + const promptOptions: Record = { + ...(task.model.reasoningEffort ? { reasoningEffort: task.model.reasoningEffort } : {}), + ...(task.model.thinking ? { thinking: task.model.thinking } : {}), + ...(task.model.maxTokens !== undefined ? { maxTokens: task.model.maxTokens } : {}), + } + + if ( + task.model.temperature !== undefined || + task.model.top_p !== undefined || + Object.keys(promptOptions).length > 0 + ) { + setSessionPromptParams(task.sessionID, { + ...(task.model.temperature !== undefined ? { temperature: task.model.temperature } : {}), + ...(task.model.top_p !== undefined ? { topP: task.model.top_p } : {}), + ...(Object.keys(promptOptions).length > 0 ? { options: promptOptions } : {}), + }) + } + } + client.session.promptAsync({ path: { id: task.sessionID }, body: { diff --git a/src/features/background-agent/types.ts b/src/features/background-agent/types.ts index e40f98e27..fca05b976 100644 --- a/src/features/background-agent/types.ts +++ b/src/features/background-agent/types.ts @@ -25,6 +25,17 @@ export interface TaskProgress { lastMessageAt?: Date } +type DelegatedModelConfig = { + providerID: string + modelID: string + variant?: string + reasoningEffort?: string + temperature?: number + top_p?: number + maxTokens?: number + thinking?: { type: "enabled" | "disabled"; budgetTokens?: number } +} + export interface BackgroundTask { id: string sessionID?: string @@ -43,7 +54,7 @@ export interface BackgroundTask { error?: string progress?: TaskProgress parentModel?: { providerID: string; modelID: string } - model?: { providerID: string; modelID: string; variant?: string } + model?: DelegatedModelConfig /** Fallback chain for runtime retry on model errors */ fallbackChain?: FallbackEntry[] /** Number of fallback retry attempts made */ @@ -76,7 +87,7 @@ export interface LaunchInput { parentModel?: { providerID: string; modelID: string } parentAgent?: string parentTools?: Record - model?: { providerID: string; modelID: string; variant?: string } + model?: DelegatedModelConfig /** Fallback chain for runtime retry on model errors */ fallbackChain?: FallbackEntry[] isUnstableAgent?: boolean diff --git a/src/hooks/runtime-fallback/fallback-models.ts b/src/hooks/runtime-fallback/fallback-models.ts index fb984bbec..415751d7e 100644 --- a/src/hooks/runtime-fallback/fallback-models.ts +++ b/src/hooks/runtime-fallback/fallback-models.ts @@ -1,10 +1,16 @@ import type { OhMyOpenCodeConfig } from "../../config" +import type { FallbackModelObject } from "../../config/schema/fallback-models" import { agentPattern } from "./agent-resolver" import { HOOK_NAME } from "./constants" import { log } from "../../shared/logger" import { SessionCategoryRegistry } from "../../shared/session-category-registry" -import { normalizeFallbackModels } from "../../shared/model-resolver" +import { normalizeFallbackModels, flattenToFallbackModelStrings } from "../../shared/model-resolver" +/** + * Returns fallback model strings for the runtime-fallback system. + * Object entries are flattened to "provider/model(variant)" strings so the + * string-based fallback state machine can work with them unchanged. + */ export function getFallbackModelsForSession( sessionID: string, agent: string | undefined, @@ -12,22 +18,45 @@ export function getFallbackModelsForSession( ): string[] { if (!pluginConfig) return [] + const raw = getRawFallbackModelsForSession(sessionID, agent, pluginConfig) + return flattenToFallbackModelStrings(raw) ?? [] +} + +/** + * Returns the raw fallback model entries (strings and objects) for a session. + * Use this when per-model settings (temperature, reasoningEffort, etc.) must be + * preserved — e.g. before passing to buildFallbackChainFromModels. + */ +export function getRawFallbackModels( + sessionID: string, + agent: string | undefined, + pluginConfig: OhMyOpenCodeConfig | undefined, +): (string | FallbackModelObject)[] | undefined { + if (!pluginConfig) return undefined + return getRawFallbackModelsForSession(sessionID, agent, pluginConfig) +} + +function getRawFallbackModelsForSession( + sessionID: string, + agent: string | undefined, + pluginConfig: OhMyOpenCodeConfig, +): (string | FallbackModelObject)[] | undefined { const sessionCategory = SessionCategoryRegistry.get(sessionID) if (sessionCategory && pluginConfig.categories?.[sessionCategory]) { const categoryConfig = pluginConfig.categories[sessionCategory] if (categoryConfig?.fallback_models) { - return normalizeFallbackModels(categoryConfig.fallback_models) ?? [] + return normalizeFallbackModels(categoryConfig.fallback_models) } } - const tryGetFallbackFromAgent = (agentName: string): string[] | undefined => { + const tryGetFallbackFromAgent = (agentName: string): (string | FallbackModelObject)[] | undefined => { const agentConfig = pluginConfig.agents?.[agentName as keyof typeof pluginConfig.agents] if (!agentConfig) return undefined - + if (agentConfig?.fallback_models) { return normalizeFallbackModels(agentConfig.fallback_models) } - + const agentCategory = agentConfig?.category if (agentCategory && pluginConfig.categories?.[agentCategory]) { const categoryConfig = pluginConfig.categories[agentCategory] @@ -35,7 +64,7 @@ export function getFallbackModelsForSession( return normalizeFallbackModels(categoryConfig.fallback_models) } } - + return undefined } @@ -53,5 +82,5 @@ export function getFallbackModelsForSession( log(`[${HOOK_NAME}] No category/agent fallback models resolved for session`, { sessionID, agent }) - return [] + return undefined } diff --git a/src/plugin/chat-params.test.ts b/src/plugin/chat-params.test.ts index 4abcfbe93..c646c8283 100644 --- a/src/plugin/chat-params.test.ts +++ b/src/plugin/chat-params.test.ts @@ -1,8 +1,17 @@ -import { describe, expect, test } from "bun:test" +import { afterEach, describe, expect, test } from "bun:test" import { createChatParamsHandler } from "./chat-params" +import { + clearSessionPromptParams, + getSessionPromptParams, + setSessionPromptParams, +} from "../shared/session-prompt-params-state" describe("createChatParamsHandler", () => { + afterEach(() => { + clearSessionPromptParams("ses_chat_params") + }) + test("normalizes object-style agent payload and runs chat.params hooks", async () => { //#given let called = false @@ -35,7 +44,6 @@ describe("createChatParamsHandler", () => { //#then expect(called).toBe(true) }) - test("passes the original mutable message object to chat.params hooks", async () => { //#given const handler = createChatParamsHandler({ @@ -68,4 +76,61 @@ describe("createChatParamsHandler", () => { //#then expect(message.variant).toBe("high") }) + + test("applies stored prompt params for the session", async () => { + //#given + setSessionPromptParams("ses_chat_params", { + temperature: 0.4, + topP: 0.7, + options: { + reasoningEffort: "high", + thinking: { type: "disabled" }, + maxTokens: 4096, + }, + }) + + const handler = createChatParamsHandler({ + anthropicEffort: null, + }) + + const input = { + sessionID: "ses_chat_params", + agent: { name: "oracle" }, + model: { providerID: "openai", modelID: "gpt-5.4" }, + provider: { id: "openai" }, + message: {}, + } + + const output = { + temperature: 0.1, + topP: 1, + topK: 1, + options: { existing: true }, + } + + //#when + await handler(input, output) + + //#then + expect(output).toEqual({ + temperature: 0.4, + topP: 0.7, + topK: 1, + options: { + existing: true, + reasoningEffort: "high", + thinking: { type: "disabled" }, + maxTokens: 4096, + }, + }) + expect(getSessionPromptParams("ses_chat_params")).toEqual({ + temperature: 0.4, + topP: 0.7, + options: { + reasoningEffort: "high", + thinking: { type: "disabled" }, + maxTokens: 4096, + }, + }) + }) }) diff --git a/src/plugin/chat-params.ts b/src/plugin/chat-params.ts index c4bd5e626..c5d047ba2 100644 --- a/src/plugin/chat-params.ts +++ b/src/plugin/chat-params.ts @@ -1,3 +1,5 @@ +import { getSessionPromptParams } from "../shared/session-prompt-params-state" + export type ChatParamsInput = { sessionID: string agent: { name?: string } @@ -82,6 +84,22 @@ export function createChatParamsHandler(args: { if (!normalizedInput) return if (!isChatParamsOutput(output)) return + const storedPromptParams = getSessionPromptParams(normalizedInput.sessionID) + if (storedPromptParams) { + if (storedPromptParams.temperature !== undefined) { + output.temperature = storedPromptParams.temperature + } + if (storedPromptParams.topP !== undefined) { + output.topP = storedPromptParams.topP + } + if (storedPromptParams.options) { + output.options = { + ...output.options, + ...storedPromptParams.options, + } + } + } + await args.anthropicEffort?.["chat.params"]?.(normalizedInput, output) } } diff --git a/src/plugin/event-compaction-agent.test.ts b/src/plugin/event-compaction-agent.test.ts index 44b3d1910..b247b649e 100644 --- a/src/plugin/event-compaction-agent.test.ts +++ b/src/plugin/event-compaction-agent.test.ts @@ -2,6 +2,7 @@ import { afterEach, describe, expect, it } from "bun:test" import { _resetForTesting, getSessionAgent, updateSessionAgent } from "../features/claude-code-session-state" import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state" +import { clearSessionPromptParams } from "../shared/session-prompt-params-state" import { createEventHandler } from "./event" function createMinimalEventHandler() { @@ -53,6 +54,8 @@ describe("createEventHandler compaction agent filtering", () => { _resetForTesting() clearSessionModel("ses_compaction_poisoning") clearSessionModel("ses_compaction_model_poisoning") + clearSessionPromptParams("ses_compaction_poisoning") + clearSessionPromptParams("ses_compaction_model_poisoning") }) it("does not overwrite the stored session agent with compaction", async () => { diff --git a/src/plugin/event.test.ts b/src/plugin/event.test.ts index dfd2aa9ee..3cde65522 100644 --- a/src/plugin/event.test.ts +++ b/src/plugin/event.test.ts @@ -4,6 +4,7 @@ import { createEventHandler } from "./event" import { createChatMessageHandler } from "./chat-message" import { _resetForTesting, setMainSession } from "../features/claude-code-session-state" import { clearPendingModelFallback, createModelFallbackHook } from "../hooks/model-fallback/hook" +import { getSessionPromptParams, setSessionPromptParams } from "../shared/session-prompt-params-state" type EventInput = { event: { type: string; properties?: unknown } } @@ -441,6 +442,45 @@ describe("createEventHandler - event forwarding", () => { expect(disconnectedSessions).toEqual([sessionID]) expect(deletedSessions).toEqual([sessionID]) }) + + it("clears stored prompt params on session.deleted", async () => { + //#given + const eventHandler = createEventHandler({ + ctx: {} as never, + pluginConfig: {} as never, + firstMessageVariantGate: { + markSessionCreated: () => {}, + clear: () => {}, + }, + managers: { + skillMcpManager: { + disconnectSession: async () => {}, + }, + tmuxSessionManager: { + onSessionCreated: async () => {}, + onSessionDeleted: async () => {}, + }, + } as never, + hooks: {} as never, + }) + const sessionID = "ses_prompt_params_deleted" + setSessionPromptParams(sessionID, { + temperature: 0.4, + topP: 0.7, + options: { reasoningEffort: "high" }, + }) + + //#when + await eventHandler({ + event: { + type: "session.deleted", + properties: { info: { id: sessionID } }, + }, + }) + + //#then + expect(getSessionPromptParams(sessionID)).toBeUndefined() + }) }) describe("createEventHandler - retry dedupe lifecycle", () => { diff --git a/src/plugin/event.ts b/src/plugin/event.ts index 5d8383df0..126d6e819 100644 --- a/src/plugin/event.ts +++ b/src/plugin/event.ts @@ -16,7 +16,7 @@ import { setSessionFallbackChain, setPendingModelFallback, } from "../hooks/model-fallback/hook"; -import { getFallbackModelsForSession } from "../hooks/runtime-fallback/fallback-models"; +import { getRawFallbackModels } from "../hooks/runtime-fallback/fallback-models"; import { resetMessageCursor } from "../shared"; import { getAgentConfigKey } from "../shared/agent-display-names"; import { readConnectedProvidersCache } from "../shared/connected-providers-cache"; @@ -25,6 +25,7 @@ import { shouldRetryError } from "../shared/model-error-classifier"; import { buildFallbackChainFromModels } from "../shared/fallback-chain-from-models"; import { extractRetryAttempt, normalizeRetryStatusMessage } from "../shared/retry-status-utils"; import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state"; +import { clearSessionPromptParams } from "../shared/session-prompt-params-state"; import { deleteSessionTools } from "../shared/session-tools-store"; import { lspManager } from "../tools"; @@ -110,10 +111,10 @@ function applyUserConfiguredFallbackChain( pluginConfig: OhMyOpenCodeConfig, ): void { const agentKey = getAgentConfigKey(agentName); - const configuredFallbackModels = getFallbackModelsForSession(sessionID, agentKey, pluginConfig); - if (configuredFallbackModels.length === 0) return; + const rawFallbackModels = getRawFallbackModels(sessionID, agentKey, pluginConfig); + if (!rawFallbackModels || rawFallbackModels.length === 0) return; - const fallbackChain = buildFallbackChainFromModels(configuredFallbackModels, currentProviderID); + const fallbackChain = buildFallbackChainFromModels(rawFallbackModels, currentProviderID); if (fallbackChain && fallbackChain.length > 0) { setSessionFallbackChain(sessionID, fallbackChain); @@ -330,6 +331,7 @@ export function createEventHandler(args: { resetMessageCursor(sessionInfo.id); firstMessageVariantGate.clear(sessionInfo.id); clearSessionModel(sessionInfo.id); + clearSessionPromptParams(sessionInfo.id); syncSubagentSessions.delete(sessionInfo.id); if (wasSyncSubagentSession) { subagentSessions.delete(sessionInfo.id); diff --git a/src/shared/fallback-chain-from-models.test.ts b/src/shared/fallback-chain-from-models.test.ts index cdfbb9daf..16690d12c 100644 --- a/src/shared/fallback-chain-from-models.test.ts +++ b/src/shared/fallback-chain-from-models.test.ts @@ -1,6 +1,13 @@ -import { describe, test, expect } from "bun:test" -import { buildFallbackChainFromModels, parseFallbackModelEntry } from "./fallback-chain-from-models" +import { describe, test, it, expect } from "bun:test" +import { + parseFallbackModelEntry, + parseFallbackModelObjectEntry, + buildFallbackChainFromModels, + findMostSpecificFallbackEntry, +} from "./fallback-chain-from-models" +import { flattenToFallbackModelStrings } from "./model-resolver" +// Upstream tests describe("fallback-chain-from-models", () => { test("parses provider/model entry with parenthesized variant", () => { //#given @@ -61,3 +68,330 @@ describe("fallback-chain-from-models", () => { ]) }) }) + +// Object-style entry tests +describe("parseFallbackModelEntry (extended)", () => { + it("parses provider/model string", () => { + const result = parseFallbackModelEntry("anthropic/claude-sonnet-4-6", undefined) + expect(result).toEqual({ + providers: ["anthropic"], + model: "claude-sonnet-4-6", + }) + }) + + it("parses model with parenthesized variant", () => { + const result = parseFallbackModelEntry("anthropic/claude-sonnet-4-6(high)", undefined) + expect(result).toEqual({ + providers: ["anthropic"], + model: "claude-sonnet-4-6", + variant: "high", + }) + }) + + it("parses model with space variant", () => { + const result = parseFallbackModelEntry("openai/gpt-5.4 xhigh", undefined) + expect(result).toEqual({ + providers: ["openai"], + model: "gpt-5.4", + variant: "xhigh", + }) + }) + + it("parses model with minimal space variant", () => { + const result = parseFallbackModelEntry("openai/gpt-5.4 minimal", undefined) + expect(result).toEqual({ + providers: ["openai"], + model: "gpt-5.4", + variant: "minimal", + }) + }) + + it("uses context provider when no provider prefix", () => { + const result = parseFallbackModelEntry("claude-sonnet-4-6", "anthropic") + expect(result).toEqual({ + providers: ["anthropic"], + model: "claude-sonnet-4-6", + }) + }) + + it("returns undefined for empty string", () => { + expect(parseFallbackModelEntry("", undefined)).toBeUndefined() + expect(parseFallbackModelEntry(" ", undefined)).toBeUndefined() + }) +}) + +describe("parseFallbackModelObjectEntry", () => { + it("parses object with model only", () => { + const result = parseFallbackModelObjectEntry( + { model: "anthropic/claude-sonnet-4-6" }, + undefined, + ) + expect(result).toEqual({ + providers: ["anthropic"], + model: "claude-sonnet-4-6", + }) + }) + + it("parses object with variant override", () => { + const result = parseFallbackModelObjectEntry( + { model: "anthropic/claude-sonnet-4-6", variant: "high" }, + undefined, + ) + expect(result).toEqual({ + providers: ["anthropic"], + model: "claude-sonnet-4-6", + variant: "high", + }) + }) + + it("object variant overrides inline variant", () => { + const result = parseFallbackModelObjectEntry( + { model: "anthropic/claude-sonnet-4-6(low)", variant: "high" }, + undefined, + ) + expect(result).toEqual({ + providers: ["anthropic"], + model: "claude-sonnet-4-6", + variant: "high", + }) + }) + + it("carries reasoningEffort and temperature", () => { + const result = parseFallbackModelObjectEntry( + { + model: "openai/gpt-5.4", + variant: "high", + reasoningEffort: "high", + temperature: 0.5, + }, + undefined, + ) + expect(result).toEqual({ + providers: ["openai"], + model: "gpt-5.4", + variant: "high", + reasoningEffort: "high", + temperature: 0.5, + }) + }) + + it("carries thinking config", () => { + const result = parseFallbackModelObjectEntry( + { + model: "anthropic/claude-sonnet-4-6", + thinking: { type: "enabled", budgetTokens: 10000 }, + }, + undefined, + ) + expect(result).toEqual({ + providers: ["anthropic"], + model: "claude-sonnet-4-6", + thinking: { type: "enabled", budgetTokens: 10000 }, + }) + }) + + it("carries all optional fields", () => { + const result = parseFallbackModelObjectEntry( + { + model: "openai/gpt-5.4", + variant: "xhigh", + reasoningEffort: "xhigh", + temperature: 0.3, + top_p: 0.9, + maxTokens: 8192, + thinking: { type: "disabled" }, + }, + undefined, + ) + expect(result).toEqual({ + providers: ["openai"], + model: "gpt-5.4", + variant: "xhigh", + reasoningEffort: "xhigh", + temperature: 0.3, + top_p: 0.9, + maxTokens: 8192, + thinking: { type: "disabled" }, + }) + }) +}) + +describe("buildFallbackChainFromModels (mixed)", () => { + it("handles string input", () => { + const result = buildFallbackChainFromModels("anthropic/claude-sonnet-4-6", undefined) + expect(result).toEqual([ + { providers: ["anthropic"], model: "claude-sonnet-4-6" }, + ]) + }) + + it("handles string array", () => { + const result = buildFallbackChainFromModels( + ["anthropic/claude-sonnet-4-6", "openai/gpt-5.4"], + undefined, + ) + expect(result).toEqual([ + { providers: ["anthropic"], model: "claude-sonnet-4-6" }, + { providers: ["openai"], model: "gpt-5.4" }, + ]) + }) + + it("handles mixed array of strings and objects", () => { + const result = buildFallbackChainFromModels( + [ + { model: "anthropic/claude-sonnet-4-6", variant: "high", reasoningEffort: "high" }, + { model: "openai/gpt-5.4", reasoningEffort: "xhigh" }, + "chutes/kimi-k2.5", + { model: "chutes/glm-5", temperature: 0.7 }, + "google/gemini-3-flash", + ], + undefined, + ) + expect(result).toEqual([ + { providers: ["anthropic"], model: "claude-sonnet-4-6", variant: "high", reasoningEffort: "high" }, + { providers: ["openai"], model: "gpt-5.4", reasoningEffort: "xhigh" }, + { providers: ["chutes"], model: "kimi-k2.5" }, + { providers: ["chutes"], model: "glm-5", temperature: 0.7 }, + { providers: ["google"], model: "gemini-3-flash" }, + ]) + }) + + it("returns undefined for empty/undefined input", () => { + expect(buildFallbackChainFromModels(undefined, undefined)).toBeUndefined() + expect(buildFallbackChainFromModels([], undefined)).toBeUndefined() + }) + + it("filters out invalid entries", () => { + const result = buildFallbackChainFromModels( + ["", "anthropic/claude-sonnet-4-6", " "], + undefined, + ) + expect(result).toEqual([ + { providers: ["anthropic"], model: "claude-sonnet-4-6" }, + ]) + }) +}) + +describe("flattenToFallbackModelStrings", () => { + it("returns undefined for undefined input", () => { + expect(flattenToFallbackModelStrings(undefined)).toBeUndefined() + }) + + it("passes through plain strings", () => { + expect(flattenToFallbackModelStrings(["anthropic/claude-sonnet-4-6"])).toEqual([ + "anthropic/claude-sonnet-4-6", + ]) + }) + + it("flattens object with explicit variant", () => { + expect(flattenToFallbackModelStrings([ + { model: "anthropic/claude-sonnet-4-6", variant: "high" }, + ])).toEqual(["anthropic/claude-sonnet-4-6(high)"]) + }) + + it("preserves inline variant when no explicit variant", () => { + expect(flattenToFallbackModelStrings([ + { model: "anthropic/claude-sonnet-4-6(high)" }, + ])).toEqual(["anthropic/claude-sonnet-4-6(high)"]) + }) + + it("explicit variant overrides inline variant (no double-suffix)", () => { + expect(flattenToFallbackModelStrings([ + { model: "anthropic/claude-sonnet-4-6(low)", variant: "high" }, + ])).toEqual(["anthropic/claude-sonnet-4-6(high)"]) + }) + + it("explicit variant overrides space-suffix variant", () => { + expect(flattenToFallbackModelStrings([ + { model: "openai/gpt-5.4 high", variant: "low" }, + ])).toEqual(["openai/gpt-5.4(low)"]) + }) + + it("explicit variant overrides minimal space-suffix variant", () => { + expect(flattenToFallbackModelStrings([ + { model: "openai/gpt-5.4 minimal", variant: "low" }, + ])).toEqual(["openai/gpt-5.4(low)"]) + }) + + it("preserves trailing non-variant suffixes when adding explicit variant", () => { + expect(flattenToFallbackModelStrings([ + { model: "openai/gpt-5.4 preview", variant: "low" }, + ])).toEqual(["openai/gpt-5.4 preview(low)"]) + }) + + it("flattens object without variant", () => { + expect(flattenToFallbackModelStrings([ + { model: "openai/gpt-5.4" }, + ])).toEqual(["openai/gpt-5.4"]) + }) + + it("handles mixed array", () => { + expect(flattenToFallbackModelStrings([ + "anthropic/claude-sonnet-4-6", + { model: "openai/gpt-5.4", variant: "high" }, + { model: "google/gemini-3-flash(low)" }, + ])).toEqual([ + "anthropic/claude-sonnet-4-6", + "openai/gpt-5.4(high)", + "google/gemini-3-flash(low)", + ]) + }) +}) + +describe("findMostSpecificFallbackEntry", () => { + it("picks exact match over prefix match", () => { + const chain = [ + { providers: ["openai"], model: "gpt-5.4" }, + { providers: ["openai"], model: "gpt-5.4-preview" }, + ] + const result = findMostSpecificFallbackEntry("openai", "gpt-5.4-preview", chain) + expect(result?.model).toBe("gpt-5.4-preview") + }) + + it("returns prefix match when no exact match exists", () => { + const chain = [ + { providers: ["openai"], model: "gpt-5.4" }, + ] + const result = findMostSpecificFallbackEntry("openai", "gpt-5.4-preview", chain) + expect(result?.model).toBe("gpt-5.4") + }) + + it("returns undefined when no entry matches", () => { + const chain = [ + { providers: ["anthropic"], model: "claude-sonnet-4-6" }, + ] + expect(findMostSpecificFallbackEntry("openai", "gpt-5.4", chain)).toBeUndefined() + }) + + it("sorts by matched prefix length, not insertion order", () => { + // Both entries share the same provider so both match as prefixes; + // the longer (more-specific) prefix must win regardless of array order. + const chain = [ + { providers: ["openai"], model: "gpt-5" }, + { providers: ["openai"], model: "gpt-5.4-preview" }, + ] + const result = findMostSpecificFallbackEntry("openai", "gpt-5.4-preview-2026", chain) + expect(result?.model).toBe("gpt-5.4-preview") + }) + + it("is case-insensitive", () => { + const chain = [ + { providers: ["OpenAI"], model: "GPT-5.4" }, + ] + const result = findMostSpecificFallbackEntry("openai", "gpt-5.4-preview", chain) + expect(result?.model).toBe("GPT-5.4") + }) + + it("preserves variant and settings from matched entry", () => { + const chain = [ + { providers: ["openai"], model: "gpt-5.4", variant: "high", temperature: 0.7 }, + { providers: ["openai"], model: "gpt-5.4-preview", variant: "low", reasoningEffort: "medium" }, + ] + const result = findMostSpecificFallbackEntry("openai", "gpt-5.4-preview", chain) + expect(result).toEqual({ + providers: ["openai"], + model: "gpt-5.4-preview", + variant: "low", + reasoningEffort: "medium", + }) + }) +}) diff --git a/src/shared/fallback-chain-from-models.ts b/src/shared/fallback-chain-from-models.ts index 54918f062..92e3fe445 100644 --- a/src/shared/fallback-chain-from-models.ts +++ b/src/shared/fallback-chain-from-models.ts @@ -1,16 +1,7 @@ import type { FallbackEntry } from "./model-requirements" +import type { FallbackModelObject } from "../config/schema/fallback-models" import { normalizeFallbackModels } from "./model-resolver" - -const KNOWN_VARIANTS = new Set([ - "low", - "medium", - "high", - "xhigh", - "max", - "none", - "auto", - "thinking", -]) +import { KNOWN_VARIANTS } from "./known-variants" function parseVariantFromModel(rawModel: string): { modelID: string; variant?: string } { const trimmedModel = rawModel.trim() @@ -61,8 +52,60 @@ export function parseFallbackModelEntry( } } +export function parseFallbackModelObjectEntry( + obj: FallbackModelObject, + contextProviderID: string | undefined, + defaultProviderID = "opencode", +): FallbackEntry | undefined { + // Reuse the string-based parser for provider/model/variant extraction. + const base = parseFallbackModelEntry(obj.model, contextProviderID, defaultProviderID) + if (!base) return undefined + + return { + ...base, + // Explicit object variant overrides any inline variant in the model string. + variant: obj.variant ?? base.variant, + reasoningEffort: obj.reasoningEffort, + temperature: obj.temperature, + top_p: obj.top_p, + maxTokens: obj.maxTokens, + thinking: obj.thinking, + } +} + +/** + * Find the most specific FallbackEntry whose `provider/model` is a prefix of + * the resolved `provider/modelID`. Longest match wins so that e.g. + * `openai/gpt-5.4-preview` picks the entry for `openai/gpt-5.4-preview` over + * the shorter `openai/gpt-5.4`. + */ +export function findMostSpecificFallbackEntry( + providerID: string, + modelID: string, + chain: FallbackEntry[], +): FallbackEntry | undefined { + const resolved = `${providerID}/${modelID}`.toLowerCase() + + // Collect entries whose provider/model is a prefix of the resolved model, + // together with the length of the matching prefix (longest match wins). + const matches: { entry: FallbackEntry; matchLen: number }[] = [] + for (const entry of chain) { + for (const p of entry.providers) { + const candidate = `${p}/${entry.model}`.toLowerCase() + if (resolved.startsWith(candidate)) { + matches.push({ entry, matchLen: candidate.length }) + break // one match per entry is enough + } + } + } + + if (matches.length === 0) return undefined + matches.sort((a, b) => b.matchLen - a.matchLen) + return matches[0].entry +} + export function buildFallbackChainFromModels( - fallbackModels: string | string[] | undefined, + fallbackModels: string | (string | FallbackModelObject)[] | undefined, contextProviderID: string | undefined, defaultProviderID = "opencode", ): FallbackEntry[] | undefined { @@ -70,7 +113,12 @@ export function buildFallbackChainFromModels( if (!normalized || normalized.length === 0) return undefined const parsed = normalized - .map((model) => parseFallbackModelEntry(model, contextProviderID, defaultProviderID)) + .map((entry) => { + if (typeof entry === "string") { + return parseFallbackModelEntry(entry, contextProviderID, defaultProviderID) + } + return parseFallbackModelObjectEntry(entry, contextProviderID, defaultProviderID) + }) .filter((entry): entry is FallbackEntry => entry !== undefined) if (parsed.length === 0) return undefined diff --git a/src/shared/index.ts b/src/shared/index.ts index 39fcc18ea..8ad6a9d6a 100644 --- a/src/shared/index.ts +++ b/src/shared/index.ts @@ -35,7 +35,7 @@ export * from "./agent-tool-restrictions" export * from "./model-requirements" export * from "./model-resolver" export { normalizeModel, normalizeModelID } from "./model-normalization" -export { normalizeFallbackModels } from "./model-resolver" +export { normalizeFallbackModels, flattenToFallbackModelStrings } from "./model-resolver" export { resolveModelPipeline } from "./model-resolution-pipeline" export type { ModelResolutionRequest, diff --git a/src/shared/known-variants.ts b/src/shared/known-variants.ts new file mode 100644 index 000000000..e8a906d3a --- /dev/null +++ b/src/shared/known-variants.ts @@ -0,0 +1,16 @@ +/** + * Canonical set of recognised variant / effort tokens. + * Used by parseFallbackModelEntry (space-suffix detection) and + * flattenToFallbackModelStrings (inline-variant stripping). + */ +export const KNOWN_VARIANTS = new Set([ + "low", + "medium", + "high", + "xhigh", + "max", + "minimal", + "none", + "auto", + "thinking", +]) diff --git a/src/shared/model-requirements.ts b/src/shared/model-requirements.ts index 16f7e78c9..ea8a1cb91 100644 --- a/src/shared/model-requirements.ts +++ b/src/shared/model-requirements.ts @@ -2,6 +2,11 @@ export type FallbackEntry = { providers: string[]; model: string; variant?: string; // Entry-specific variant (e.g., GPT→high, Opus→max) + reasoningEffort?: string; + temperature?: number; + top_p?: number; + maxTokens?: number; + thinking?: { type: "enabled" | "disabled"; budgetTokens?: number }; }; export type ModelRequirement = { diff --git a/src/shared/model-resolver.ts b/src/shared/model-resolver.ts index 977112cb1..8b6a33d03 100644 --- a/src/shared/model-resolver.ts +++ b/src/shared/model-resolver.ts @@ -1,6 +1,8 @@ import type { FallbackEntry } from "./model-requirements" +import type { FallbackModelObject } from "../config/schema/fallback-models" import { normalizeModel } from "./model-normalization" import { resolveModelPipeline } from "./model-resolution-pipeline" +import { KNOWN_VARIANTS } from "./known-variants" export type ModelResolutionInput = { userModel?: string @@ -61,11 +63,45 @@ export function resolveModelWithFallback( } /** - * Normalizes fallback_models config (which can be string or string[]) to string[] - * Centralized helper to avoid duplicated normalization logic + * Normalizes fallback_models config to a mixed array. + * Accepts string, string[], or mixed arrays of strings and FallbackModelObject entries. */ -export function normalizeFallbackModels(models: string | string[] | undefined): string[] | undefined { +export function normalizeFallbackModels( + models: string | (string | FallbackModelObject)[] | undefined, +): (string | FallbackModelObject)[] | undefined { if (!models) return undefined if (typeof models === "string") return [models] return models } + +/** + * Extracts plain model strings from a mixed fallback models array. + * Object entries are flattened to "model" or "model(variant)" strings. + * Use this when consumers need string[] (e.g., resolveModelForDelegateTask). + */ +export function flattenToFallbackModelStrings( + models: (string | FallbackModelObject)[] | undefined, +): string[] | undefined { + if (!models) return undefined + return models.map((entry) => { + if (typeof entry === "string") return entry + const variant = entry.variant + if (variant) { + // Strip any supported inline variant syntax before appending explicit override. + // Supports both parenthesized and space-suffix forms so we don't emit + // invalid strings like "provider/model high(low)". + const model = entry.model + .replace(/\([^()]+\)\s*$/, "") + .replace(/\s+([a-z][a-z0-9_-]*)\s*$/i, (match, suffix) => { + const normalized = String(suffix).toLowerCase() + return KNOWN_VARIANTS.has(normalized) + ? "" + : match + }) + .trim() + return `${model}(${variant})` + } + // No explicit variant — preserve model string as-is (including any inline variant) + return entry.model + }) +} diff --git a/src/shared/session-prompt-params-state.test.ts b/src/shared/session-prompt-params-state.test.ts new file mode 100644 index 000000000..b97a80565 --- /dev/null +++ b/src/shared/session-prompt-params-state.test.ts @@ -0,0 +1,65 @@ +import { afterEach, describe, expect, test } from "bun:test" + +import { + clearAllSessionPromptParams, + clearSessionPromptParams, + getSessionPromptParams, + setSessionPromptParams, +} from "./session-prompt-params-state" + +describe("session-prompt-params-state", () => { + afterEach(() => { + clearAllSessionPromptParams() + }) + + test("stores and returns prompt params by session", () => { + //#given + const sessionID = "ses_prompt_params" + const params = { + temperature: 0.4, + topP: 0.7, + options: { + reasoningEffort: "high", + maxTokens: 4096, + }, + } + + //#when + setSessionPromptParams(sessionID, params) + + //#then + expect(getSessionPromptParams(sessionID)).toEqual(params) + }) + + test("returns copies so callers cannot mutate stored state", () => { + //#given + const sessionID = "ses_prompt_params_copy" + setSessionPromptParams(sessionID, { + temperature: 0.2, + options: { reasoningEffort: "medium" }, + }) + + //#when + const result = getSessionPromptParams(sessionID)! + result.temperature = 0.9 + result.options!.reasoningEffort = "max" + + //#then + expect(getSessionPromptParams(sessionID)).toEqual({ + temperature: 0.2, + options: { reasoningEffort: "medium" }, + }) + }) + + test("clears a single session", () => { + //#given + const sessionID = "ses_prompt_params_clear" + setSessionPromptParams(sessionID, { topP: 0.5 }) + + //#when + clearSessionPromptParams(sessionID) + + //#then + expect(getSessionPromptParams(sessionID)).toBeUndefined() + }) +}) diff --git a/src/shared/session-prompt-params-state.ts b/src/shared/session-prompt-params-state.ts new file mode 100644 index 000000000..36e956cfc --- /dev/null +++ b/src/shared/session-prompt-params-state.ts @@ -0,0 +1,34 @@ +export type SessionPromptParams = { + temperature?: number + topP?: number + options?: Record +} + +const sessionPromptParams = new Map() + +export function setSessionPromptParams(sessionID: string, params: SessionPromptParams): void { + sessionPromptParams.set(sessionID, { + ...(params.temperature !== undefined ? { temperature: params.temperature } : {}), + ...(params.topP !== undefined ? { topP: params.topP } : {}), + ...(params.options !== undefined ? { options: { ...params.options } } : {}), + }) +} + +export function getSessionPromptParams(sessionID: string): SessionPromptParams | undefined { + const params = sessionPromptParams.get(sessionID) + if (!params) return undefined + + return { + ...(params.temperature !== undefined ? { temperature: params.temperature } : {}), + ...(params.topP !== undefined ? { topP: params.topP } : {}), + ...(params.options !== undefined ? { options: { ...params.options } } : {}), + } +} + +export function clearSessionPromptParams(sessionID: string): void { + sessionPromptParams.delete(sessionID) +} + +export function clearAllSessionPromptParams(): void { + sessionPromptParams.clear() +} diff --git a/src/tools/delegate-task/background-task.ts b/src/tools/delegate-task/background-task.ts index 31aeea8f8..bc96423eb 100644 --- a/src/tools/delegate-task/background-task.ts +++ b/src/tools/delegate-task/background-task.ts @@ -1,4 +1,4 @@ -import type { DelegateTaskArgs, ToolContextWithMetadata } from "./types" +import type { DelegateTaskArgs, ToolContextWithMetadata, DelegatedModelConfig } from "./types" import type { ExecutorContext, ParentContext } from "./executor-types" import type { FallbackEntry } from "../../shared/model-requirements" import { getTimingConfig } from "./timing" @@ -16,7 +16,7 @@ export async function executeBackgroundTask( executorCtx: ExecutorContext, parentContext: ParentContext, agentToUse: string, - categoryModel: { providerID: string; modelID: string; variant?: string } | undefined, + categoryModel: DelegatedModelConfig | undefined, systemContent: string | undefined, fallbackChain?: FallbackEntry[], ): Promise { diff --git a/src/tools/delegate-task/category-resolver.test.ts b/src/tools/delegate-task/category-resolver.test.ts index 3c9124735..49ce1f733 100644 --- a/src/tools/delegate-task/category-resolver.test.ts +++ b/src/tools/delegate-task/category-resolver.test.ts @@ -114,4 +114,260 @@ describe("resolveCategoryExecution", () => { { providers: ["openai"], model: "gpt-5.2", variant: "high" }, ]) }) + + test("promotes object-style fallback model settings to categoryModel when fallback becomes initial model", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const agentsSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = { + category: "deep", + prompt: "test prompt", + description: "Test task", + run_in_background: false, + load_skills: [], + blockedBy: undefined, + enableSkillTools: false, + } + const executorCtx = createMockExecutorContext() + executorCtx.userCategories = { + deep: { + fallback_models: [ + { + model: "openai/gpt-5.4 high", + variant: "low", + reasoningEffort: "high", + temperature: 0.4, + top_p: 0.7, + maxTokens: 4096, + thinking: { type: "disabled" }, + }, + ], + }, + } + + //#when + const result = await resolveCategoryExecution(args, executorCtx, undefined, "anthropic/claude-sonnet-4-6") + + //#then + expect(result.error).toBeUndefined() + expect(result.actualModel).toBe("openai/gpt-5.4") + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4", + variant: "low", + reasoningEffort: "high", + temperature: 0.4, + top_p: 0.7, + maxTokens: 4096, + thinking: { type: "disabled" }, + }) + cacheSpy.mockRestore() + agentsSpy.mockRestore() + }) + + test("matches promoted fallback settings after fuzzy model resolution", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4-preview"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const agentsSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = { + category: "deep", + prompt: "test prompt", + description: "Test task", + run_in_background: false, + load_skills: [], + blockedBy: undefined, + enableSkillTools: false, + } + const executorCtx = createMockExecutorContext() + executorCtx.userCategories = { + deep: { + fallback_models: [ + { + model: "openai/gpt-5.4", + variant: "low", + reasoningEffort: "high", + temperature: 0.6, + top_p: 0.5, + maxTokens: 1234, + thinking: { type: "disabled" }, + }, + ], + }, + } + + //#when + const result = await resolveCategoryExecution(args, executorCtx, undefined, "anthropic/claude-sonnet-4-6") + + //#then + expect(result.error).toBeUndefined() + expect(result.actualModel).toBe("openai/gpt-5.4-preview") + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4-preview", + variant: "low", + reasoningEffort: "high", + temperature: 0.6, + top_p: 0.5, + maxTokens: 1234, + thinking: { type: "disabled" }, + }) + cacheSpy.mockRestore() + agentsSpy.mockRestore() + }) + + test("prefers exact promoted fallback match over earlier fuzzy prefix match", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4-preview"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const agentsSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = { + category: "deep", + prompt: "test prompt", + description: "Test task", + run_in_background: false, + load_skills: [], + blockedBy: undefined, + enableSkillTools: false, + } + const executorCtx = createMockExecutorContext() + executorCtx.userCategories = { + deep: { + fallback_models: [ + { + model: "openai/gpt-5.4", + variant: "low", + reasoningEffort: "medium", + }, + { + model: "openai/gpt-5.4-preview", + variant: "max", + reasoningEffort: "high", + }, + ], + }, + } + + //#when + const result = await resolveCategoryExecution(args, executorCtx, undefined, "anthropic/claude-sonnet-4-6") + + //#then + expect(result.error).toBeUndefined() + expect(result.actualModel).toBe("openai/gpt-5.4-preview") + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4-preview", + variant: "max", + reasoningEffort: "high", + }) + cacheSpy.mockRestore() + agentsSpy.mockRestore() + }) + + test("matches promoted fallback settings when fuzzy resolution extends configured model without hyphen", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4o"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const agentsSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = { + category: "deep", + prompt: "test prompt", + description: "Test task", + run_in_background: false, + load_skills: [], + blockedBy: undefined, + enableSkillTools: false, + } + const executorCtx = createMockExecutorContext() + executorCtx.userCategories = { + deep: { + fallback_models: [ + { + model: "openai/gpt-5.4", + variant: "low", + reasoningEffort: "high", + }, + ], + }, + } + + //#when + const result = await resolveCategoryExecution(args, executorCtx, undefined, "anthropic/claude-sonnet-4-6") + + //#then + expect(result.error).toBeUndefined() + expect(result.actualModel).toBe("openai/gpt-5.4o") + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4o", + variant: "low", + reasoningEffort: "high", + }) + cacheSpy.mockRestore() + agentsSpy.mockRestore() + }) + + test("prefers the most specific prefix match when fallback entries share a prefix", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-4o"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const agentsSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = { + category: "deep", + prompt: "test prompt", + description: "Test task", + run_in_background: false, + load_skills: [], + blockedBy: undefined, + enableSkillTools: false, + } + const executorCtx = createMockExecutorContext() + executorCtx.userCategories = { + deep: { + fallback_models: [ + { + model: "openai/gpt-4", + variant: "low", + reasoningEffort: "medium", + }, + { + model: "openai/gpt-4o", + variant: "max", + reasoningEffort: "high", + }, + ], + }, + } + + //#when + const result = await resolveCategoryExecution(args, executorCtx, undefined, "anthropic/claude-sonnet-4-6") + + //#then + expect(result.error).toBeUndefined() + expect(result.actualModel).toBe("openai/gpt-4o") + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-4o", + variant: "max", + reasoningEffort: "high", + }) + cacheSpy.mockRestore() + agentsSpy.mockRestore() + }) }) diff --git a/src/tools/delegate-task/category-resolver.ts b/src/tools/delegate-task/category-resolver.ts index a925bb294..492798155 100644 --- a/src/tools/delegate-task/category-resolver.ts +++ b/src/tools/delegate-task/category-resolver.ts @@ -7,14 +7,16 @@ import { SISYPHUS_JUNIOR_AGENT } from "./sisyphus-junior-agent" import { resolveCategoryConfig } from "./categories" import { parseModelString } from "./model-string-parser" import { CATEGORY_MODEL_REQUIREMENTS } from "../../shared/model-requirements" -import { normalizeFallbackModels } from "../../shared/model-resolver" -import { buildFallbackChainFromModels } from "../../shared/fallback-chain-from-models" +import { normalizeFallbackModels, flattenToFallbackModelStrings } from "../../shared/model-resolver" +import { buildFallbackChainFromModels, findMostSpecificFallbackEntry } from "../../shared/fallback-chain-from-models" import { getAvailableModelsForDelegateTask } from "./available-models" import { resolveModelForDelegateTask } from "./model-selection" +import type { DelegatedModelConfig } from "./types" + export interface CategoryResolutionResult { agentToUse: string - categoryModel: { providerID: string; modelID: string; variant?: string } | undefined + categoryModel: DelegatedModelConfig | undefined categoryPromptAppend: string | undefined maxPromptTokens?: number modelInfo: ModelFallbackInfo | undefined @@ -84,8 +86,9 @@ Available categories: ${allCategoryNames}`, const normalizedConfiguredFallbackModels = normalizeFallbackModels(resolved.config.fallback_models) let actualModel: string | undefined let modelInfo: ModelFallbackInfo | undefined - let categoryModel: { providerID: string; modelID: string; variant?: string } | undefined + let categoryModel: DelegatedModelConfig | undefined let isModelResolutionSkipped = false + let fallbackEntry: FallbackEntry | undefined const overrideModel = sisyphusJuniorModel const explicitCategoryModel = userCategories?.[args.category!]?.model @@ -108,7 +111,7 @@ Available categories: ${allCategoryNames}`, } else { const resolution = resolveModelForDelegateTask({ userModel: explicitCategoryModel ?? overrideModel, - userFallbackModels: normalizedConfiguredFallbackModels, + userFallbackModels: flattenToFallbackModelStrings(normalizedConfiguredFallbackModels), categoryDefaultModel: resolved.model, isUserConfiguredCategoryModel: resolved.isUserConfiguredModel, fallbackChain: requirement.fallbackChain, @@ -119,7 +122,8 @@ Available categories: ${allCategoryNames}`, if (resolution && "skipped" in resolution) { isModelResolutionSkipped = true } else if (resolution) { - const { model: resolvedModel, variant: resolvedVariant } = resolution + const { model: resolvedModel, variant: resolvedVariant, fallbackEntry: resolvedFallbackEntry } = resolution + fallbackEntry = resolvedFallbackEntry actualModel = resolvedModel if (!parseModelString(actualModel)) { @@ -198,6 +202,26 @@ Available categories: ${categoryNames.join(", ")}`, defaultProviderID, ) + // Apply per-model settings from the source that provided the match: + // 1. fallbackEntry from resolver (built-in chain match) — exact, no lookup needed + // 2. configuredFallbackChain (user's fallback_models) — prefix match against user config + const effectiveEntry = fallbackEntry + ?? (categoryModel && configuredFallbackChain + ? findMostSpecificFallbackEntry(categoryModel.providerID, categoryModel.modelID, configuredFallbackChain) + : undefined) + + if (categoryModel && effectiveEntry) { + categoryModel = { + ...categoryModel, + variant: userCategories?.[args.category!]?.variant ?? effectiveEntry.variant ?? categoryModel.variant, + reasoningEffort: effectiveEntry.reasoningEffort, + temperature: effectiveEntry.temperature, + top_p: effectiveEntry.top_p, + maxTokens: effectiveEntry.maxTokens, + thinking: effectiveEntry.thinking, + } + } + return { agentToUse: SISYPHUS_JUNIOR_AGENT, categoryModel, diff --git a/src/tools/delegate-task/model-selection.ts b/src/tools/delegate-task/model-selection.ts index 79fbec9f5..4d52308be 100644 --- a/src/tools/delegate-task/model-selection.ts +++ b/src/tools/delegate-task/model-selection.ts @@ -53,7 +53,7 @@ export function resolveModelForDelegateTask(input: { fallbackChain?: FallbackEntry[] availableModels: Set systemDefaultModel?: string -}): { model: string; variant?: string } | { skipped: true } | undefined { +}): { model: string; variant?: string; fallbackEntry?: FallbackEntry } | { skipped: true } | undefined { const userModel = normalizeModel(input.userModel) if (userModel) { return { model: userModel } @@ -119,7 +119,7 @@ export function resolveModelForDelegateTask(input: { const provider = first?.providers?.[0] if (provider) { const transformedModelId = transformModelForProvider(provider, first.model) - return { model: `${provider}/${transformedModelId}`, variant: first.variant } + return { model: `${provider}/${transformedModelId}`, variant: first.variant, fallbackEntry: first } } } else { for (const entry of fallbackChain) { @@ -128,20 +128,20 @@ export function resolveModelForDelegateTask(input: { const match = fuzzyMatchModel(fullModel, input.availableModels, [provider]) if (match) { if (explicitHighModel && entry.variant === "high" && match === explicitHighBaseModel) { - return { model: explicitHighModel } + return { model: explicitHighModel, fallbackEntry: entry } } - return { model: match, variant: entry.variant } + return { model: match, variant: entry.variant, fallbackEntry: entry } } } const crossProviderMatch = fuzzyMatchModel(entry.model, input.availableModels) if (crossProviderMatch) { if (explicitHighModel && entry.variant === "high" && crossProviderMatch === explicitHighBaseModel) { - return { model: explicitHighModel } + return { model: explicitHighModel, fallbackEntry: entry } } - return { model: crossProviderMatch, variant: entry.variant } + return { model: crossProviderMatch, variant: entry.variant, fallbackEntry: entry } } } } diff --git a/src/tools/delegate-task/model-string-parser.ts b/src/tools/delegate-task/model-string-parser.ts index d86f23324..820bb3cc3 100644 --- a/src/tools/delegate-task/model-string-parser.ts +++ b/src/tools/delegate-task/model-string-parser.ts @@ -4,6 +4,7 @@ const KNOWN_VARIANTS = new Set([ "high", "xhigh", "max", + "minimal", "none", "auto", "thinking", diff --git a/src/tools/delegate-task/subagent-resolver.test.ts b/src/tools/delegate-task/subagent-resolver.test.ts index 28c7a4731..0d53d7157 100644 --- a/src/tools/delegate-task/subagent-resolver.test.ts +++ b/src/tools/delegate-task/subagent-resolver.test.ts @@ -175,4 +175,245 @@ describe("resolveSubagentExecution", () => { ]) cacheSpy.mockRestore() }) + + test("promotes object-style fallback model settings to categoryModel when subagent fallback becomes initial model", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const connectedSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = createBaseArgs({ subagent_type: "explore" }) + const executorCtx = createExecutorContext( + async () => ([ + { name: "explore", mode: "subagent", model: "quotio/claude-haiku-4-5-unavailable" }, + ]), + { + agentOverrides: { + explore: { + fallback_models: [ + { + model: "openai/gpt-5.4 high", + variant: "low", + reasoningEffort: "high", + temperature: 0.2, + top_p: 0.8, + maxTokens: 2048, + thinking: { type: "disabled" }, + }, + ], + }, + } as ExecutorContext["agentOverrides"], + } + ) + + //#when + const result = await resolveSubagentExecution(args, executorCtx, "sisyphus", "deep") + + //#then + expect(result.error).toBeUndefined() + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4", + variant: "low", + reasoningEffort: "high", + temperature: 0.2, + top_p: 0.8, + maxTokens: 2048, + thinking: { type: "disabled" }, + }) + cacheSpy.mockRestore() + connectedSpy.mockRestore() + }) + + test("matches promoted fallback settings after fuzzy model resolution", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4-preview"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const connectedSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = createBaseArgs({ subagent_type: "explore" }) + const executorCtx = createExecutorContext( + async () => ([ + { name: "explore", mode: "subagent", model: "quotio/claude-haiku-4-5-unavailable" }, + ]), + { + agentOverrides: { + explore: { + fallback_models: [ + { + model: "openai/gpt-5.4", + variant: "low", + reasoningEffort: "high", + temperature: 0.3, + top_p: 0.4, + maxTokens: 2222, + thinking: { type: "disabled" }, + }, + ], + }, + } as ExecutorContext["agentOverrides"], + } + ) + + //#when + const result = await resolveSubagentExecution(args, executorCtx, "sisyphus", "deep") + + //#then + expect(result.error).toBeUndefined() + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4-preview", + variant: "low", + reasoningEffort: "high", + temperature: 0.3, + top_p: 0.4, + maxTokens: 2222, + thinking: { type: "disabled" }, + }) + cacheSpy.mockRestore() + connectedSpy.mockRestore() + }) + + test("prefers exact promoted fallback match over earlier fuzzy prefix match", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4-preview"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const connectedSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = createBaseArgs({ subagent_type: "explore" }) + const executorCtx = createExecutorContext( + async () => ([ + { name: "explore", mode: "subagent", model: "quotio/claude-haiku-4-5-unavailable" }, + ]), + { + agentOverrides: { + explore: { + fallback_models: [ + { + model: "openai/gpt-5.4", + variant: "low", + reasoningEffort: "medium", + }, + { + model: "openai/gpt-5.4-preview", + variant: "max", + reasoningEffort: "high", + }, + ], + }, + } as ExecutorContext["agentOverrides"], + } + ) + + //#when + const result = await resolveSubagentExecution(args, executorCtx, "sisyphus", "deep") + + //#then + expect(result.error).toBeUndefined() + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4-preview", + variant: "max", + reasoningEffort: "high", + }) + cacheSpy.mockRestore() + connectedSpy.mockRestore() + }) + + test("matches promoted fallback settings when fuzzy resolution extends configured model without hyphen", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-5.4o"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const connectedSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = createBaseArgs({ subagent_type: "explore" }) + const executorCtx = createExecutorContext( + async () => ([ + { name: "explore", mode: "subagent", model: "quotio/claude-haiku-4-5-unavailable" }, + ]), + { + agentOverrides: { + explore: { + fallback_models: [ + { + model: "openai/gpt-5.4", + variant: "low", + reasoningEffort: "high", + }, + ], + }, + } as ExecutorContext["agentOverrides"], + } + ) + + //#when + const result = await resolveSubagentExecution(args, executorCtx, "sisyphus", "deep") + + //#then + expect(result.error).toBeUndefined() + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-5.4o", + variant: "low", + reasoningEffort: "high", + }) + cacheSpy.mockRestore() + connectedSpy.mockRestore() + }) + + test("prefers the most specific prefix match when fallback entries share a prefix", async () => { + //#given + const cacheSpy = spyOn(connectedProvidersCache, "readProviderModelsCache").mockReturnValue({ + models: { openai: ["gpt-4o-preview"] }, + connected: ["openai"], + updatedAt: "2026-03-03T00:00:00.000Z", + }) + const connectedSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"]) + const args = createBaseArgs({ subagent_type: "explore" }) + const executorCtx = createExecutorContext( + async () => ([ + { name: "explore", mode: "subagent", model: "quotio/claude-haiku-4-5-unavailable" }, + ]), + { + agentOverrides: { + explore: { + fallback_models: [ + { + model: "openai/gpt-4", + variant: "low", + reasoningEffort: "medium", + }, + { + model: "openai/gpt-4o", + variant: "max", + reasoningEffort: "high", + }, + ], + }, + } as ExecutorContext["agentOverrides"], + } + ) + + //#when + const result = await resolveSubagentExecution(args, executorCtx, "sisyphus", "deep") + + //#then + expect(result.error).toBeUndefined() + expect(result.categoryModel).toEqual({ + providerID: "openai", + modelID: "gpt-4o-preview", + variant: "max", + reasoningEffort: "high", + }) + cacheSpy.mockRestore() + connectedSpy.mockRestore() + }) }) diff --git a/src/tools/delegate-task/subagent-resolver.ts b/src/tools/delegate-task/subagent-resolver.ts index 5c2d4444e..f1e03a6a5 100644 --- a/src/tools/delegate-task/subagent-resolver.ts +++ b/src/tools/delegate-task/subagent-resolver.ts @@ -1,11 +1,12 @@ import type { DelegateTaskArgs } from "./types" import type { ExecutorContext } from "./executor-types" +import type { DelegatedModelConfig } from "./types" import { isPlanFamily } from "./constants" import { SISYPHUS_JUNIOR_AGENT } from "./sisyphus-junior-agent" import { normalizeModelFormat } from "../../shared/model-format-normalizer" import { AGENT_MODEL_REQUIREMENTS } from "../../shared/model-requirements" -import { normalizeFallbackModels } from "../../shared/model-resolver" -import { buildFallbackChainFromModels } from "../../shared/fallback-chain-from-models" +import { normalizeFallbackModels, flattenToFallbackModelStrings } from "../../shared/model-resolver" +import { buildFallbackChainFromModels, findMostSpecificFallbackEntry } from "../../shared/fallback-chain-from-models" import { getAgentDisplayName, getAgentConfigKey } from "../../shared/agent-display-names" import { normalizeSDKResponse } from "../../shared" import { log } from "../../shared/logger" @@ -17,9 +18,8 @@ export async function resolveSubagentExecution( args: DelegateTaskArgs, executorCtx: ExecutorContext, parentAgent: string | undefined, - categoryExamples: string, - inheritedModel?: string -): Promise<{ agentToUse: string; categoryModel: { providerID: string; modelID: string; variant?: string } | undefined; fallbackChain?: FallbackEntry[]; error?: string }> { + categoryExamples: string +): Promise<{ agentToUse: string; categoryModel: DelegatedModelConfig | undefined; fallbackChain?: FallbackEntry[]; error?: string }> { const { client, agentOverrides, userCategories } = executorCtx if (!args.subagent_type?.trim()) { @@ -49,7 +49,7 @@ Create the work plan directly - that's your job as the planning agent.`, } let agentToUse = agentName - let categoryModel: { providerID: string; modelID: string; variant?: string } | undefined + let categoryModel: DelegatedModelConfig | undefined let fallbackChain: FallbackEntry[] | undefined = undefined try { @@ -117,8 +117,8 @@ Create the work plan directly - that's your job as the planning agent.`, : undefined const resolution = resolveModelForDelegateTask({ - userModel: agentOverride?.model ?? inheritedModel, - userFallbackModels: normalizedAgentFallbackModels, + userModel: agentOverride?.model, + userFallbackModels: flattenToFallbackModelStrings(normalizedAgentFallbackModels), categoryDefaultModel: matchedAgentModelStr, fallbackChain: agentRequirement?.fallbackChain, availableModels, @@ -141,6 +141,25 @@ Create the work plan directly - that's your job as the planning agent.`, defaultProviderID, ) fallbackChain = configuredFallbackChain ?? agentRequirement?.fallbackChain + + // Apply per-model settings: prefer resolver's exact entry, fall back to prefix match on user config + const resolvedFallbackEntry = (resolution && !('skipped' in resolution)) ? resolution.fallbackEntry : undefined + const effectiveEntry = resolvedFallbackEntry + ?? (categoryModel && fallbackChain + ? findMostSpecificFallbackEntry(categoryModel.providerID, categoryModel.modelID, fallbackChain) + : undefined) + + if (categoryModel && effectiveEntry) { + categoryModel = { + ...categoryModel, + variant: agentOverride?.variant ?? effectiveEntry.variant ?? categoryModel.variant, + reasoningEffort: effectiveEntry.reasoningEffort, + temperature: effectiveEntry.temperature, + top_p: effectiveEntry.top_p, + maxTokens: effectiveEntry.maxTokens, + thinking: effectiveEntry.thinking, + } + } } if (!categoryModel && matchedAgent.model) { diff --git a/src/tools/delegate-task/sync-prompt-sender.test.ts b/src/tools/delegate-task/sync-prompt-sender.test.ts index 39bf6bd6f..e7df1b070 100644 --- a/src/tools/delegate-task/sync-prompt-sender.test.ts +++ b/src/tools/delegate-task/sync-prompt-sender.test.ts @@ -3,9 +3,19 @@ const { test: bunTest, expect: bunExpect, mock: bunMock, + afterEach: bunAfterEach, } = require("bun:test") +const { + clearSessionPromptParams, + getSessionPromptParams, +} = require("../../shared/session-prompt-params-state") + bunDescribe("sendSyncPrompt", () => { + bunAfterEach(() => { + clearSessionPromptParams("test-session") + }) + bunTest("passes question=false via tools parameter", async () => { //#given const { sendSyncPrompt } = require("./sync-prompt-sender") @@ -214,6 +224,67 @@ bunDescribe("sendSyncPrompt", () => { bunExpect(promptArgs.body.variant).toBe("medium") }) + bunTest("passes promoted fallback model settings through supported prompt channels", async () => { + //#given + const { sendSyncPrompt } = require("./sync-prompt-sender") + + let promptArgs: any + const promptWithModelSuggestionRetry = bunMock(async (_client: any, input: any) => { + promptArgs = input + }) + + const input = { + sessionID: "test-session", + agentToUse: "oracle", + args: { + description: "test task", + prompt: "test prompt", + run_in_background: false, + load_skills: [], + }, + systemContent: undefined, + categoryModel: { + providerID: "openai", + modelID: "gpt-5.4", + variant: "low", + reasoningEffort: "high", + temperature: 0.4, + top_p: 0.7, + maxTokens: 4096, + thinking: { type: "disabled" }, + }, + toastManager: null, + taskId: undefined, + } + + //#when + await sendSyncPrompt( + { session: { promptAsync: bunMock(async () => ({ data: {} })) } }, + input, + { + promptWithModelSuggestionRetry, + promptSyncWithModelSuggestionRetry: bunMock(async () => {}), + }, + ) + + //#then + bunExpect(promptWithModelSuggestionRetry).toHaveBeenCalledTimes(1) + bunExpect(promptArgs.body.model).toEqual({ + providerID: "openai", + modelID: "gpt-5.4", + }) + bunExpect(promptArgs.body.variant).toBe("low") + bunExpect(promptArgs.body.options).toBeUndefined() + bunExpect(getSessionPromptParams("test-session")).toEqual({ + temperature: 0.4, + topP: 0.7, + options: { + reasoningEffort: "high", + thinking: { type: "disabled" }, + maxTokens: 4096, + }, + }) + }) bunTest("retries with promptSync for oracle when promptAsync fails with unexpected EOF", async () => { //#given const { sendSyncPrompt } = require("./sync-prompt-sender") @@ -289,7 +360,7 @@ bunDescribe("sendSyncPrompt", () => { ) //#then - bunExpect(result).toContain("JSON Parse error: Unexpected EOF") + bunExpect(result).toContain("Unexpected EOF") bunExpect(promptWithModelSuggestionRetry).toHaveBeenCalledTimes(1) bunExpect(promptSyncWithModelSuggestionRetry).toHaveBeenCalledTimes(0) }) diff --git a/src/tools/delegate-task/sync-prompt-sender.ts b/src/tools/delegate-task/sync-prompt-sender.ts index fe4f8a693..65773c92a 100644 --- a/src/tools/delegate-task/sync-prompt-sender.ts +++ b/src/tools/delegate-task/sync-prompt-sender.ts @@ -1,4 +1,4 @@ -import type { DelegateTaskArgs, OpencodeClient } from "./types" +import type { DelegateTaskArgs, OpencodeClient, DelegatedModelConfig } from "./types" import { isPlanFamily } from "./constants" import { buildTaskPrompt } from "./prompt-builder" import { @@ -8,6 +8,7 @@ import { import { formatDetailedError } from "./error-formatting" import { getAgentToolRestrictions } from "../../shared/agent-tool-restrictions" import { setSessionTools } from "../../shared/session-tools-store" +import { setSessionPromptParams } from "../../shared/session-prompt-params-state" import { createInternalAgentTextPart } from "../../shared/internal-initiator-marker" type SendSyncPromptDeps = { @@ -37,7 +38,7 @@ export async function sendSyncPrompt( agentToUse: string args: DelegateTaskArgs systemContent: string | undefined - categoryModel: { providerID: string; modelID: string; variant?: string } | undefined + categoryModel: DelegatedModelConfig | undefined toastManager: { removeTask: (id: string) => void } | null | undefined taskId: string | undefined }, @@ -53,6 +54,26 @@ export async function sendSyncPrompt( } setSessionTools(input.sessionID, tools) + if (input.categoryModel) { + const promptOptions: Record = { + ...(input.categoryModel.reasoningEffort ? { reasoningEffort: input.categoryModel.reasoningEffort } : {}), + ...(input.categoryModel.thinking ? { thinking: input.categoryModel.thinking } : {}), + ...(input.categoryModel.maxTokens !== undefined ? { maxTokens: input.categoryModel.maxTokens } : {}), + } + + if ( + input.categoryModel.temperature !== undefined || + input.categoryModel.top_p !== undefined || + Object.keys(promptOptions).length > 0 + ) { + setSessionPromptParams(input.sessionID, { + ...(input.categoryModel.temperature !== undefined ? { temperature: input.categoryModel.temperature } : {}), + ...(input.categoryModel.top_p !== undefined ? { topP: input.categoryModel.top_p } : {}), + ...(Object.keys(promptOptions).length > 0 ? { options: promptOptions } : {}), + }) + } + } + const promptArgs = { path: { id: input.sessionID }, body: { @@ -61,7 +82,12 @@ export async function sendSyncPrompt( tools, parts: [createInternalAgentTextPart(effectivePrompt)], ...(input.categoryModel - ? { model: { providerID: input.categoryModel.providerID, modelID: input.categoryModel.modelID } } + ? { + model: { + providerID: input.categoryModel.providerID, + modelID: input.categoryModel.modelID, + }, + } : {}), ...(input.categoryModel?.variant ? { variant: input.categoryModel.variant } : {}), }, diff --git a/src/tools/delegate-task/sync-task.ts b/src/tools/delegate-task/sync-task.ts index fa5fad4c0..c84d264cc 100644 --- a/src/tools/delegate-task/sync-task.ts +++ b/src/tools/delegate-task/sync-task.ts @@ -1,5 +1,5 @@ import type { ModelFallbackInfo } from "../../features/task-toast-manager/types" -import type { DelegateTaskArgs, ToolContextWithMetadata } from "./types" +import type { DelegateTaskArgs, ToolContextWithMetadata, DelegatedModelConfig } from "./types" import type { ExecutorContext, ParentContext } from "./executor-types" import { getTaskToastManager } from "../../features/task-toast-manager" import { storeToolMetadata } from "../../features/tool-metadata-store" @@ -17,7 +17,7 @@ export async function executeSyncTask( executorCtx: ExecutorContext, parentContext: ParentContext, agentToUse: string, - categoryModel: { providerID: string; modelID: string; variant?: string } | undefined, + categoryModel: DelegatedModelConfig | undefined, systemContent: string | undefined, modelInfo?: ModelFallbackInfo, fallbackChain?: import("../../shared/model-requirements").FallbackEntry[], diff --git a/src/tools/delegate-task/tools.ts b/src/tools/delegate-task/tools.ts index 45890d035..280e7b79d 100644 --- a/src/tools/delegate-task/tools.ts +++ b/src/tools/delegate-task/tools.ts @@ -178,7 +178,18 @@ export function createDelegateTask(options: DelegateTaskToolOptions): ToolDefini : undefined let agentToUse: string - let categoryModel: { providerID: string; modelID: string; variant?: string } | undefined + let categoryModel: + | { + providerID: string + modelID: string + variant?: string + reasoningEffort?: string + temperature?: number + top_p?: number + maxTokens?: number + thinking?: { type: "enabled" | "disabled"; budgetTokens?: number } + } + | undefined let categoryPromptAppend: string | undefined let modelInfo: import("../../features/task-toast-manager/types").ModelFallbackInfo | undefined let actualModel: string | undefined diff --git a/src/tools/delegate-task/types.ts b/src/tools/delegate-task/types.ts index c51a1bde1..16b3fdf74 100644 --- a/src/tools/delegate-task/types.ts +++ b/src/tools/delegate-task/types.ts @@ -71,6 +71,17 @@ export interface DelegateTaskToolOptions { syncPollTimeoutMs?: number } +export interface DelegatedModelConfig { + providerID: string + modelID: string + variant?: string + reasoningEffort?: string + temperature?: number + top_p?: number + maxTokens?: number + thinking?: { type: "enabled" | "disabled"; budgetTokens?: number } +} + export interface BuildSystemContentInput { skillContent?: string skillContents?: string[] @@ -78,7 +89,7 @@ export interface BuildSystemContentInput { agentsContext?: string planAgentPrepend?: string maxPromptTokens?: number - model?: { providerID: string; modelID: string; variant?: string } + model?: DelegatedModelConfig agentName?: string availableCategories?: AvailableCategory[] availableSkills?: AvailableSkill[] diff --git a/src/tools/delegate-task/unstable-agent-task.ts b/src/tools/delegate-task/unstable-agent-task.ts index 8aa2dce81..ba0ec6152 100644 --- a/src/tools/delegate-task/unstable-agent-task.ts +++ b/src/tools/delegate-task/unstable-agent-task.ts @@ -1,4 +1,4 @@ -import type { DelegateTaskArgs, ToolContextWithMetadata } from "./types" +import type { DelegateTaskArgs, ToolContextWithMetadata, DelegatedModelConfig } from "./types" import type { ExecutorContext, ParentContext, SessionMessage } from "./executor-types" import { DEFAULT_SYNC_POLL_TIMEOUT_MS, getTimingConfig } from "./timing" import { buildTaskPrompt } from "./prompt-builder" @@ -16,7 +16,7 @@ export async function executeUnstableAgentTask( executorCtx: ExecutorContext, parentContext: ParentContext, agentToUse: string, - categoryModel: { providerID: string; modelID: string; variant?: string } | undefined, + categoryModel: DelegatedModelConfig | undefined, systemContent: string | undefined, actualModel: string | undefined ): Promise {