Merge pull request #2798 from code-yeongyu/fix/2353-model-selection-v2
fix(plugin): persist selected model only for main session
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import { describe, test, expect } from "bun:test"
|
||||
import { afterEach, describe, test, expect } from "bun:test"
|
||||
|
||||
import { createChatMessageHandler } from "./chat-message"
|
||||
import { _resetForTesting, setMainSession, subagentSessions } from "../features/claude-code-session-state"
|
||||
import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state"
|
||||
|
||||
type ChatMessagePart = { type: string; text?: string; [key: string]: unknown }
|
||||
type ChatMessageHandlerOutput = { message: Record<string, unknown>; parts: ChatMessagePart[] }
|
||||
@@ -30,6 +32,13 @@ function createMockHandlerArgs(overrides?: {
|
||||
}
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
_resetForTesting()
|
||||
clearSessionModel("test-session")
|
||||
clearSessionModel("main-session")
|
||||
clearSessionModel("subagent-session")
|
||||
})
|
||||
|
||||
function createMockInput(agent?: string, model?: { providerID: string; modelID: string }) {
|
||||
return {
|
||||
sessionID: "test-session",
|
||||
@@ -142,4 +151,100 @@ describe("createChatMessageHandler - TUI variant passthrough", () => {
|
||||
expect(output.parts).toHaveLength(1)
|
||||
expect(output.parts[0].text).toContain("[BACKGROUND TASK COMPLETED]")
|
||||
})
|
||||
|
||||
test("reuses the stored model for subsequent messages in the main session when the UI sends none", async () => {
|
||||
//#given
|
||||
setMainSession("test-session")
|
||||
setSessionModel("test-session", { providerID: "openai", modelID: "gpt-5.4" })
|
||||
const args = createMockHandlerArgs({ shouldOverride: false })
|
||||
const handler = createChatMessageHandler(args)
|
||||
const input = createMockInput("sisyphus")
|
||||
const output = createMockOutput()
|
||||
|
||||
//#when
|
||||
await handler(input, output)
|
||||
|
||||
//#then
|
||||
expect(output.message["model"]).toEqual({ providerID: "openai", modelID: "gpt-5.4" })
|
||||
expect(getSessionModel("test-session")).toEqual({ providerID: "openai", modelID: "gpt-5.4" })
|
||||
})
|
||||
|
||||
test("does not reuse a stored model for the first message of a session", async () => {
|
||||
//#given
|
||||
setMainSession("test-session")
|
||||
setSessionModel("test-session", { providerID: "openai", modelID: "gpt-5.4" })
|
||||
const args = createMockHandlerArgs({ shouldOverride: true })
|
||||
const handler = createChatMessageHandler(args)
|
||||
const input = createMockInput("sisyphus")
|
||||
const output = createMockOutput()
|
||||
|
||||
//#when
|
||||
await handler(input, output)
|
||||
|
||||
//#then
|
||||
expect(output.message["model"]).toBeUndefined()
|
||||
})
|
||||
|
||||
test("does not reuse the main-session model for subagent sessions", async () => {
|
||||
//#given
|
||||
setMainSession("main-session")
|
||||
setSessionModel("main-session", { providerID: "openai", modelID: "gpt-5.4" })
|
||||
subagentSessions.add("subagent-session")
|
||||
const args = createMockHandlerArgs({ shouldOverride: false })
|
||||
const handler = createChatMessageHandler(args)
|
||||
const input = {
|
||||
sessionID: "subagent-session",
|
||||
agent: "oracle",
|
||||
}
|
||||
const output = createMockOutput()
|
||||
|
||||
//#when
|
||||
await handler(input, output)
|
||||
|
||||
//#then
|
||||
expect(output.message["model"]).toBeUndefined()
|
||||
expect(getSessionModel("subagent-session")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("does not override explicit agent model overrides with stored session model", async () => {
|
||||
//#given
|
||||
setMainSession("test-session")
|
||||
setSessionModel("test-session", { providerID: "openai", modelID: "gpt-5.4" })
|
||||
const args = createMockHandlerArgs({
|
||||
shouldOverride: false,
|
||||
pluginConfig: {
|
||||
agents: {
|
||||
sisyphus: { model: "anthropic/claude-opus-4-6" },
|
||||
},
|
||||
},
|
||||
})
|
||||
const handler = createChatMessageHandler(args)
|
||||
const input = createMockInput("sisyphus")
|
||||
const output = createMockOutput()
|
||||
|
||||
//#when
|
||||
await handler(input, output)
|
||||
|
||||
//#then
|
||||
expect(output.message["model"]).toBeUndefined()
|
||||
expect(getSessionModel("test-session")).toEqual({ providerID: "openai", modelID: "gpt-5.4" })
|
||||
})
|
||||
|
||||
test("respects a mid-conversation model switch instead of reusing the previous stored model", async () => {
|
||||
//#given
|
||||
setMainSession("test-session")
|
||||
setSessionModel("test-session", { providerID: "anthropic", modelID: "claude-opus-4-6" })
|
||||
const args = createMockHandlerArgs({ shouldOverride: false })
|
||||
const handler = createChatMessageHandler(args)
|
||||
const nextModel = { providerID: "openai", modelID: "gpt-5.4" }
|
||||
const input = createMockInput("sisyphus", nextModel)
|
||||
const output = createMockOutput()
|
||||
|
||||
//#when
|
||||
await handler(input, output)
|
||||
|
||||
//#then
|
||||
expect(output.message["model"]).toBeUndefined()
|
||||
expect(getSessionModel("test-session")).toEqual(nextModel)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,8 +2,8 @@ import type { OhMyOpenCodeConfig } from "../config"
|
||||
import type { PluginContext } from "./types"
|
||||
|
||||
import { hasConnectedProvidersCache } from "../shared"
|
||||
import { setSessionModel } from "../shared/session-model-state"
|
||||
import { setSessionAgent } from "../features/claude-code-session-state"
|
||||
import { getSessionModel, setSessionModel } from "../shared/session-model-state"
|
||||
import { getMainSessionID, setSessionAgent, subagentSessions } from "../features/claude-code-session-state"
|
||||
import { applyUltraworkModelOverrideOnMessage } from "./ultrawork-model-override"
|
||||
import { parseRalphLoopArguments } from "../hooks/ralph-loop/command-arguments"
|
||||
|
||||
@@ -23,6 +23,8 @@ export type ChatMessageInput = {
|
||||
}
|
||||
type StartWorkHookOutput = { parts: Array<{ type: string; text?: string }> }
|
||||
|
||||
type SessionModelOverride = { providerID: string; modelID: string }
|
||||
|
||||
function isStartWorkHookOutput(value: unknown): value is StartWorkHookOutput {
|
||||
if (typeof value !== "object" || value === null) return false
|
||||
const record = value as Record<string, unknown>
|
||||
@@ -35,6 +37,53 @@ function isStartWorkHookOutput(value: unknown): value is StartWorkHookOutput {
|
||||
})
|
||||
}
|
||||
|
||||
function hasExplicitAgentModelOverride(
|
||||
agent: string | undefined,
|
||||
pluginConfig: OhMyOpenCodeConfig
|
||||
): boolean {
|
||||
const configuredAgents = pluginConfig.agents
|
||||
if (!agent || !configuredAgents || !(agent in configuredAgents)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const configuredAgent = configuredAgents[agent as keyof typeof configuredAgents]
|
||||
const configuredModel = configuredAgent?.model
|
||||
return typeof configuredModel === "string" && configuredModel.trim().length > 0
|
||||
}
|
||||
|
||||
function getStoredMainSessionModel(
|
||||
input: ChatMessageInput,
|
||||
pluginConfig: OhMyOpenCodeConfig,
|
||||
isFirstMessage: boolean,
|
||||
output: ChatMessageHandlerOutput
|
||||
): SessionModelOverride | undefined {
|
||||
if (isFirstMessage) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (subagentSessions.has(input.sessionID)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (getMainSessionID() !== input.sessionID) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (input.model) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (output.message["model"] !== undefined) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (hasExplicitAgentModelOverride(input.agent, pluginConfig)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return getSessionModel(input.sessionID)
|
||||
}
|
||||
|
||||
export function createChatMessageHandler(args: {
|
||||
ctx: PluginContext
|
||||
pluginConfig: OhMyOpenCodeConfig
|
||||
@@ -74,10 +123,21 @@ export function createChatMessageHandler(args: {
|
||||
setSessionAgent(input.sessionID, input.agent)
|
||||
}
|
||||
|
||||
if (firstMessageVariantGate.shouldOverride(input.sessionID)) {
|
||||
const isFirstMessage = firstMessageVariantGate.shouldOverride(input.sessionID)
|
||||
if (isFirstMessage) {
|
||||
firstMessageVariantGate.markApplied(input.sessionID)
|
||||
}
|
||||
|
||||
const storedMainSessionModel = getStoredMainSessionModel(
|
||||
input,
|
||||
pluginConfig,
|
||||
isFirstMessage,
|
||||
output,
|
||||
)
|
||||
if (storedMainSessionModel) {
|
||||
output.message["model"] = storedMainSessionModel
|
||||
}
|
||||
|
||||
if (!isRuntimeFallbackEnabled) {
|
||||
await hooks.modelFallback?.["chat.message"]?.(input, output)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user