diff --git a/src/plugin/chat-message.test.ts b/src/plugin/chat-message.test.ts index a10968303..e79eb8d2d 100644 --- a/src/plugin/chat-message.test.ts +++ b/src/plugin/chat-message.test.ts @@ -1,6 +1,8 @@ -import { describe, test, expect } from "bun:test" +import { afterEach, describe, test, expect } from "bun:test" import { createChatMessageHandler } from "./chat-message" +import { _resetForTesting, setMainSession, subagentSessions } from "../features/claude-code-session-state" +import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state" type ChatMessagePart = { type: string; text?: string; [key: string]: unknown } type ChatMessageHandlerOutput = { message: Record; parts: ChatMessagePart[] } @@ -30,6 +32,13 @@ function createMockHandlerArgs(overrides?: { } } +afterEach(() => { + _resetForTesting() + clearSessionModel("test-session") + clearSessionModel("main-session") + clearSessionModel("subagent-session") +}) + function createMockInput(agent?: string, model?: { providerID: string; modelID: string }) { return { sessionID: "test-session", @@ -142,4 +151,100 @@ describe("createChatMessageHandler - TUI variant passthrough", () => { expect(output.parts).toHaveLength(1) expect(output.parts[0].text).toContain("[BACKGROUND TASK COMPLETED]") }) + + test("reuses the stored model for subsequent messages in the main session when the UI sends none", async () => { + //#given + setMainSession("test-session") + setSessionModel("test-session", { providerID: "openai", modelID: "gpt-5.4" }) + const args = createMockHandlerArgs({ shouldOverride: false }) + const handler = createChatMessageHandler(args) + const input = createMockInput("sisyphus") + const output = createMockOutput() + + //#when + await handler(input, output) + + //#then + expect(output.message["model"]).toEqual({ providerID: "openai", modelID: "gpt-5.4" }) + expect(getSessionModel("test-session")).toEqual({ providerID: "openai", modelID: "gpt-5.4" }) + }) + + test("does not reuse a stored model for the first message of a session", async () => { + //#given + setMainSession("test-session") + setSessionModel("test-session", { providerID: "openai", modelID: "gpt-5.4" }) + const args = createMockHandlerArgs({ shouldOverride: true }) + const handler = createChatMessageHandler(args) + const input = createMockInput("sisyphus") + const output = createMockOutput() + + //#when + await handler(input, output) + + //#then + expect(output.message["model"]).toBeUndefined() + }) + + test("does not reuse the main-session model for subagent sessions", async () => { + //#given + setMainSession("main-session") + setSessionModel("main-session", { providerID: "openai", modelID: "gpt-5.4" }) + subagentSessions.add("subagent-session") + const args = createMockHandlerArgs({ shouldOverride: false }) + const handler = createChatMessageHandler(args) + const input = { + sessionID: "subagent-session", + agent: "oracle", + } + const output = createMockOutput() + + //#when + await handler(input, output) + + //#then + expect(output.message["model"]).toBeUndefined() + expect(getSessionModel("subagent-session")).toBeUndefined() + }) + + test("does not override explicit agent model overrides with stored session model", async () => { + //#given + setMainSession("test-session") + setSessionModel("test-session", { providerID: "openai", modelID: "gpt-5.4" }) + const args = createMockHandlerArgs({ + shouldOverride: false, + pluginConfig: { + agents: { + sisyphus: { model: "anthropic/claude-opus-4-6" }, + }, + }, + }) + const handler = createChatMessageHandler(args) + const input = createMockInput("sisyphus") + const output = createMockOutput() + + //#when + await handler(input, output) + + //#then + expect(output.message["model"]).toBeUndefined() + expect(getSessionModel("test-session")).toEqual({ providerID: "openai", modelID: "gpt-5.4" }) + }) + + test("respects a mid-conversation model switch instead of reusing the previous stored model", async () => { + //#given + setMainSession("test-session") + setSessionModel("test-session", { providerID: "anthropic", modelID: "claude-opus-4-6" }) + const args = createMockHandlerArgs({ shouldOverride: false }) + const handler = createChatMessageHandler(args) + const nextModel = { providerID: "openai", modelID: "gpt-5.4" } + const input = createMockInput("sisyphus", nextModel) + const output = createMockOutput() + + //#when + await handler(input, output) + + //#then + expect(output.message["model"]).toBeUndefined() + expect(getSessionModel("test-session")).toEqual(nextModel) + }) }) diff --git a/src/plugin/chat-message.ts b/src/plugin/chat-message.ts index 750eaf667..b7bfea33f 100644 --- a/src/plugin/chat-message.ts +++ b/src/plugin/chat-message.ts @@ -2,8 +2,8 @@ import type { OhMyOpenCodeConfig } from "../config" import type { PluginContext } from "./types" import { hasConnectedProvidersCache } from "../shared" -import { setSessionModel } from "../shared/session-model-state" -import { setSessionAgent } from "../features/claude-code-session-state" +import { getSessionModel, setSessionModel } from "../shared/session-model-state" +import { getMainSessionID, setSessionAgent, subagentSessions } from "../features/claude-code-session-state" import { applyUltraworkModelOverrideOnMessage } from "./ultrawork-model-override" import { parseRalphLoopArguments } from "../hooks/ralph-loop/command-arguments" @@ -23,6 +23,8 @@ export type ChatMessageInput = { } type StartWorkHookOutput = { parts: Array<{ type: string; text?: string }> } +type SessionModelOverride = { providerID: string; modelID: string } + function isStartWorkHookOutput(value: unknown): value is StartWorkHookOutput { if (typeof value !== "object" || value === null) return false const record = value as Record @@ -35,6 +37,53 @@ function isStartWorkHookOutput(value: unknown): value is StartWorkHookOutput { }) } +function hasExplicitAgentModelOverride( + agent: string | undefined, + pluginConfig: OhMyOpenCodeConfig +): boolean { + const configuredAgents = pluginConfig.agents + if (!agent || !configuredAgents || !(agent in configuredAgents)) { + return false + } + + const configuredAgent = configuredAgents[agent as keyof typeof configuredAgents] + const configuredModel = configuredAgent?.model + return typeof configuredModel === "string" && configuredModel.trim().length > 0 +} + +function getStoredMainSessionModel( + input: ChatMessageInput, + pluginConfig: OhMyOpenCodeConfig, + isFirstMessage: boolean, + output: ChatMessageHandlerOutput +): SessionModelOverride | undefined { + if (isFirstMessage) { + return undefined + } + + if (subagentSessions.has(input.sessionID)) { + return undefined + } + + if (getMainSessionID() !== input.sessionID) { + return undefined + } + + if (input.model) { + return undefined + } + + if (output.message["model"] !== undefined) { + return undefined + } + + if (hasExplicitAgentModelOverride(input.agent, pluginConfig)) { + return undefined + } + + return getSessionModel(input.sessionID) +} + export function createChatMessageHandler(args: { ctx: PluginContext pluginConfig: OhMyOpenCodeConfig @@ -74,10 +123,21 @@ export function createChatMessageHandler(args: { setSessionAgent(input.sessionID, input.agent) } - if (firstMessageVariantGate.shouldOverride(input.sessionID)) { + const isFirstMessage = firstMessageVariantGate.shouldOverride(input.sessionID) + if (isFirstMessage) { firstMessageVariantGate.markApplied(input.sessionID) } + const storedMainSessionModel = getStoredMainSessionModel( + input, + pluginConfig, + isFirstMessage, + output, + ) + if (storedMainSessionModel) { + output.message["model"] = storedMainSessionModel + } + if (!isRuntimeFallbackEnabled) { await hooks.modelFallback?.["chat.message"]?.(input, output) }