From f7085450f1c628a21c00afaeee5a032785b50d49 Mon Sep 17 00:00:00 2001 From: YeonGyu-Kim Date: Wed, 11 Mar 2026 21:57:06 +0900 Subject: [PATCH] fix(compaction): guard model update during compaction and validate checkpoint model --- .../recovery.test.ts | 66 +++++++++++++ .../compaction-context-injector/recovery.ts | 24 ++++- .../session-prompt-config-resolver.test.ts | 98 +++++++++++++++++++ .../session-prompt-config-resolver.ts | 29 ++---- .../validated-model.ts | 47 +++++++++ src/plugin/event-compaction-agent.test.ts | 40 +++++++- src/plugin/event.ts | 5 +- 7 files changed, 280 insertions(+), 29 deletions(-) create mode 100644 src/hooks/compaction-context-injector/session-prompt-config-resolver.test.ts create mode 100644 src/hooks/compaction-context-injector/validated-model.ts diff --git a/src/hooks/compaction-context-injector/recovery.test.ts b/src/hooks/compaction-context-injector/recovery.test.ts index 5d2f88504..3642a40c4 100644 --- a/src/hooks/compaction-context-injector/recovery.test.ts +++ b/src/hooks/compaction-context-injector/recovery.test.ts @@ -1,6 +1,7 @@ /// import { describe, expect, it } from "bun:test" +import { setCompactionAgentConfigCheckpoint } from "../../shared/compaction-agent-config-checkpoint" import { createCompactionContextInjector } from "./index" type SessionMessageResponse = Array<{ @@ -291,4 +292,69 @@ describe("createCompactionContextInjector recovery", () => { //#then expect(promptAsyncRecorder.calls.length).toBe(0) }) + + it("falls back to the current non-compaction model when a checkpoint model is poisoned", async () => { + //#given + const sessionID = "ses_poisoned_checkpoint_model" + const promptAsyncRecorder = createPromptAsyncRecorder() + setCompactionAgentConfigCheckpoint(sessionID, { + agent: "atlas", + model: { providerID: "anthropic", modelID: "claude-opus-4-1" }, + tools: { bash: true }, + }) + const ctx = createMockContext( + [ + [ + { + info: { + role: "user", + agent: "atlas", + model: { providerID: "openai", modelID: "gpt-5" }, + tools: { bash: true }, + }, + }, + { + info: { + role: "user", + agent: "compaction", + model: { providerID: "anthropic", modelID: "claude-opus-4-1" }, + }, + }, + ], + [ + { + info: { + role: "user", + agent: "compaction", + model: { providerID: "anthropic", modelID: "claude-opus-4-1" }, + }, + }, + ], + [ + { + info: { + role: "user", + agent: "atlas", + model: { providerID: "openai", modelID: "gpt-5" }, + tools: { bash: true }, + }, + }, + ], + ], + promptAsyncRecorder.promptAsync, + ) + const injector = createCompactionContextInjector({ ctx }) + + //#when + await injector.event({ + event: { type: "session.compacted", properties: { sessionID } }, + }) + + //#then + expect(promptAsyncRecorder.calls.length).toBe(1) + expect(promptAsyncRecorder.calls[0]?.body.model).toEqual({ + providerID: "openai", + modelID: "gpt-5", + }) + }) }) diff --git a/src/hooks/compaction-context-injector/recovery.ts b/src/hooks/compaction-context-injector/recovery.ts index 9542d6e17..35b8a89de 100644 --- a/src/hooks/compaction-context-injector/recovery.ts +++ b/src/hooks/compaction-context-injector/recovery.ts @@ -10,6 +10,7 @@ import { createExpectedRecoveryPromptConfig, isPromptConfigRecovered, } from "./recovery-prompt-config" +import { validateCheckpointModel } from "./validated-model" import { resolveLatestSessionPromptConfig, resolveSessionPromptConfig, @@ -34,10 +35,6 @@ export function createRecoveryLogic( if (!checkpoint?.agent) { return false } - const checkpointWithAgent = { - ...checkpoint, - agent: checkpoint.agent, - } const tailState = getTailState(sessionID) const now = Date.now() @@ -46,6 +43,25 @@ export function createRecoveryLogic( } const currentPromptConfig = await resolveSessionPromptConfig(ctx, sessionID) + const validatedCheckpointModel = validateCheckpointModel( + checkpoint.model, + currentPromptConfig.model, + ) + const { model: checkpointModel, ...checkpointWithoutModel } = checkpoint + const checkpointWithAgent = { + ...checkpointWithoutModel, + agent: checkpoint.agent, + ...(validatedCheckpointModel ? { model: validatedCheckpointModel } : {}), + } + + if (checkpointModel && !validatedCheckpointModel) { + log(`[compaction-context-injector] Ignoring checkpoint model that disagrees with current prompt config`, { + sessionID, + checkpointModel, + currentModel: currentPromptConfig.model, + }) + } + const expectedPromptConfig = createExpectedRecoveryPromptConfig( checkpointWithAgent, currentPromptConfig, diff --git a/src/hooks/compaction-context-injector/session-prompt-config-resolver.test.ts b/src/hooks/compaction-context-injector/session-prompt-config-resolver.test.ts new file mode 100644 index 000000000..13dfa3fec --- /dev/null +++ b/src/hooks/compaction-context-injector/session-prompt-config-resolver.test.ts @@ -0,0 +1,98 @@ +import { afterEach, describe, expect, it } from "bun:test" + +import { _resetForTesting } from "../../features/claude-code-session-state" +import { clearSessionModel, setSessionModel } from "../../shared/session-model-state" +import { clearSessionTools } from "../../shared/session-tools-store" +import { + resolveLatestSessionPromptConfig, + resolveSessionPromptConfig, +} from "./session-prompt-config-resolver" + +type SessionMessage = { + info?: { + agent?: string + model?: { + providerID?: string + modelID?: string + } + tools?: Record + } +} + +function createMockContext(messages: SessionMessage[]) { + return { + client: { + session: { + messages: async () => ({ data: messages }), + }, + }, + directory: "/tmp/test", + } +} + +describe("session prompt config resolver", () => { + const sessionID = "ses_compaction_model_validation" + + afterEach(() => { + _resetForTesting() + clearSessionModel(sessionID) + clearSessionTools() + }) + + it("prefers the latest non-compaction model over poisoned session state", async () => { + // given + setSessionModel(sessionID, { + providerID: "anthropic", + modelID: "claude-opus-4-1", + }) + const ctx = createMockContext([ + { + info: { + agent: "atlas", + model: { providerID: "openai", modelID: "gpt-5" }, + tools: { bash: "allow" }, + }, + }, + { + info: { + agent: "compaction", + model: { providerID: "anthropic", modelID: "claude-opus-4-1" }, + }, + }, + ]) + + // when + const promptConfig = await resolveSessionPromptConfig(ctx, sessionID) + + // then + expect(promptConfig).toEqual({ + agent: "atlas", + model: { providerID: "openai", modelID: "gpt-5" }, + tools: { bash: true }, + }) + }) + + it("omits a compaction model from the latest prompt config", async () => { + // given + const ctx = createMockContext([ + { + info: { + agent: "atlas", + model: { providerID: "openai", modelID: "gpt-5" }, + }, + }, + { + info: { + agent: "compaction", + model: { providerID: "anthropic", modelID: "claude-opus-4-1" }, + }, + }, + ]) + + // when + const promptConfig = await resolveLatestSessionPromptConfig(ctx, sessionID) + + // then + expect(promptConfig).toEqual({ agent: "compaction" }) + }) +}) diff --git a/src/hooks/compaction-context-injector/session-prompt-config-resolver.ts b/src/hooks/compaction-context-injector/session-prompt-config-resolver.ts index 8d946ef7e..cc42d7a1f 100644 --- a/src/hooks/compaction-context-injector/session-prompt-config-resolver.ts +++ b/src/hooks/compaction-context-injector/session-prompt-config-resolver.ts @@ -5,6 +5,8 @@ import { normalizeSDKResponse } from "../../shared/normalize-sdk-response" import { normalizePromptTools } from "../../shared/prompt-tools" import { getSessionModel } from "../../shared/session-model-state" import { getSessionTools } from "../../shared/session-tools-store" +import { isCompactionAgent } from "./session-id" +import { resolveValidatedModel } from "./validated-model" type SessionMessage = { info?: { @@ -28,30 +30,13 @@ type ResolverContext = { directory: string } -function isCompactionAgent(agent: string | undefined): boolean { - return agent?.trim().toLowerCase() === "compaction" -} - -function resolveModel( - info: SessionMessage["info"], -): CompactionAgentConfigCheckpoint["model"] | undefined { - const providerID = info?.model?.providerID ?? info?.providerID - const modelID = info?.model?.modelID ?? info?.modelID - - if (!providerID || !modelID) { - return undefined - } - - return { providerID, modelID } -} - export async function resolveSessionPromptConfig( ctx: ResolverContext, sessionID: string, ): Promise { + const storedModel = getSessionModel(sessionID) const promptConfig: CompactionAgentConfigCheckpoint = { agent: getSessionAgent(sessionID), - model: getSessionModel(sessionID), tools: getSessionTools(sessionID), } @@ -69,7 +54,7 @@ export async function resolveSessionPromptConfig( } if (!promptConfig.model) { - const model = resolveModel(info) + const model = resolveValidatedModel(info) if (model) { promptConfig.model = model } @@ -94,6 +79,10 @@ export async function resolveSessionPromptConfig( }) } + if (!promptConfig.model && storedModel) { + promptConfig.model = storedModel + } + return promptConfig } @@ -112,7 +101,7 @@ export async function resolveLatestSessionPromptConfig( return {} } - const model = resolveModel(latestInfo) + const model = resolveValidatedModel(latestInfo) const tools = normalizePromptTools(latestInfo.tools) return { diff --git a/src/hooks/compaction-context-injector/validated-model.ts b/src/hooks/compaction-context-injector/validated-model.ts new file mode 100644 index 000000000..5aa3f897a --- /dev/null +++ b/src/hooks/compaction-context-injector/validated-model.ts @@ -0,0 +1,47 @@ +import type { CompactionAgentConfigCheckpoint } from "../../shared/compaction-agent-config-checkpoint" +import { isCompactionAgent } from "./session-id" + +type PromptConfigInfo = { + agent?: string + model?: { + providerID?: string + modelID?: string + } + providerID?: string + modelID?: string +} + +export function resolveValidatedModel( + info: PromptConfigInfo | undefined, +): CompactionAgentConfigCheckpoint["model"] | undefined { + if (isCompactionAgent(info?.agent)) { + return undefined + } + + const providerID = info?.model?.providerID ?? info?.providerID + const modelID = info?.model?.modelID ?? info?.modelID + + if (!providerID || !modelID) { + return undefined + } + + return { providerID, modelID } +} + +export function validateCheckpointModel( + checkpointModel: CompactionAgentConfigCheckpoint["model"], + currentModel: CompactionAgentConfigCheckpoint["model"], +): CompactionAgentConfigCheckpoint["model"] | undefined { + if (!checkpointModel) { + return undefined + } + + if (!currentModel) { + return checkpointModel + } + + return checkpointModel.providerID === currentModel.providerID && + checkpointModel.modelID === currentModel.modelID + ? checkpointModel + : undefined +} diff --git a/src/plugin/event-compaction-agent.test.ts b/src/plugin/event-compaction-agent.test.ts index d08f5e7ec..44b3d1910 100644 --- a/src/plugin/event-compaction-agent.test.ts +++ b/src/plugin/event-compaction-agent.test.ts @@ -1,7 +1,7 @@ -declare const require: (name: string) => any -const { afterEach, describe, expect, test } = require("bun:test") +import { afterEach, describe, expect, it } from "bun:test" import { _resetForTesting, getSessionAgent, updateSessionAgent } from "../features/claude-code-session-state" +import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state" import { createEventHandler } from "./event" function createMinimalEventHandler() { @@ -51,9 +51,11 @@ function createMinimalEventHandler() { describe("createEventHandler compaction agent filtering", () => { afterEach(() => { _resetForTesting() + clearSessionModel("ses_compaction_poisoning") + clearSessionModel("ses_compaction_model_poisoning") }) - test("does not overwrite the stored session agent with compaction", async () => { + it("does not overwrite the stored session agent with compaction", async () => { // given const sessionID = "ses_compaction_poisoning" updateSessionAgent(sessionID, "atlas") @@ -80,4 +82,36 @@ describe("createEventHandler compaction agent filtering", () => { // then expect(getSessionAgent(sessionID)).toBe("atlas") }) + + it("does not overwrite the stored session model with compaction", async () => { + // given + const sessionID = "ses_compaction_model_poisoning" + setSessionModel(sessionID, { providerID: "openai", modelID: "gpt-5" }) + const eventHandler = createMinimalEventHandler() + const input: Parameters>[0] = { + event: { + type: "message.updated", + properties: { + info: { + id: "msg-compaction-model", + sessionID, + role: "user", + agent: "compaction", + providerID: "anthropic", + modelID: "claude-opus-4-1", + time: { created: Date.now() }, + }, + }, + }, + } + + // when + await eventHandler(input) + + // then + expect(getSessionModel(sessionID)).toEqual({ + providerID: "openai", + modelID: "gpt-5", + }) + }) }) diff --git a/src/plugin/event.ts b/src/plugin/event.ts index ce62dffcd..740bd4e73 100644 --- a/src/plugin/event.ts +++ b/src/plugin/event.ts @@ -317,12 +317,13 @@ export function createEventHandler(args: { const agent = info?.agent as string | undefined; const role = info?.role as string | undefined; if (sessionID && role === "user") { - if (agent && !isCompactionAgent(agent)) { + const isCompactionMessage = agent ? isCompactionAgent(agent) : false; + if (agent && !isCompactionMessage) { updateSessionAgent(sessionID, agent); } const providerID = info?.providerID as string | undefined; const modelID = info?.modelID as string | undefined; - if (providerID && modelID) { + if (providerID && modelID && !isCompactionMessage) { lastKnownModelBySession.set(sessionID, { providerID, modelID }); setSessionModel(sessionID, { providerID, modelID }); }