Merge pull request #2471 from code-yeongyu/fix/compaction-model-filter

fix(compaction): guard model update during compaction
This commit is contained in:
YeonGyu-Kim
2026-03-11 22:01:53 +09:00
committed by GitHub
7 changed files with 280 additions and 29 deletions

View File

@@ -1,6 +1,7 @@
/// <reference path="../../../bun-test.d.ts" />
import { describe, expect, it } from "bun:test"
import { setCompactionAgentConfigCheckpoint } from "../../shared/compaction-agent-config-checkpoint"
import { createCompactionContextInjector } from "./index"
type SessionMessageResponse = Array<{
@@ -291,4 +292,69 @@ describe("createCompactionContextInjector recovery", () => {
//#then
expect(promptAsyncRecorder.calls.length).toBe(0)
})
it("falls back to the current non-compaction model when a checkpoint model is poisoned", async () => {
//#given
const sessionID = "ses_poisoned_checkpoint_model"
const promptAsyncRecorder = createPromptAsyncRecorder()
setCompactionAgentConfigCheckpoint(sessionID, {
agent: "atlas",
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
tools: { bash: true },
})
const ctx = createMockContext(
[
[
{
info: {
role: "user",
agent: "atlas",
model: { providerID: "openai", modelID: "gpt-5" },
tools: { bash: true },
},
},
{
info: {
role: "user",
agent: "compaction",
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
},
},
],
[
{
info: {
role: "user",
agent: "compaction",
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
},
},
],
[
{
info: {
role: "user",
agent: "atlas",
model: { providerID: "openai", modelID: "gpt-5" },
tools: { bash: true },
},
},
],
],
promptAsyncRecorder.promptAsync,
)
const injector = createCompactionContextInjector({ ctx })
//#when
await injector.event({
event: { type: "session.compacted", properties: { sessionID } },
})
//#then
expect(promptAsyncRecorder.calls.length).toBe(1)
expect(promptAsyncRecorder.calls[0]?.body.model).toEqual({
providerID: "openai",
modelID: "gpt-5",
})
})
})

View File

@@ -10,6 +10,7 @@ import {
createExpectedRecoveryPromptConfig,
isPromptConfigRecovered,
} from "./recovery-prompt-config"
import { validateCheckpointModel } from "./validated-model"
import {
resolveLatestSessionPromptConfig,
resolveSessionPromptConfig,
@@ -34,10 +35,6 @@ export function createRecoveryLogic(
if (!checkpoint?.agent) {
return false
}
const checkpointWithAgent = {
...checkpoint,
agent: checkpoint.agent,
}
const tailState = getTailState(sessionID)
const now = Date.now()
@@ -46,6 +43,25 @@ export function createRecoveryLogic(
}
const currentPromptConfig = await resolveSessionPromptConfig(ctx, sessionID)
const validatedCheckpointModel = validateCheckpointModel(
checkpoint.model,
currentPromptConfig.model,
)
const { model: checkpointModel, ...checkpointWithoutModel } = checkpoint
const checkpointWithAgent = {
...checkpointWithoutModel,
agent: checkpoint.agent,
...(validatedCheckpointModel ? { model: validatedCheckpointModel } : {}),
}
if (checkpointModel && !validatedCheckpointModel) {
log(`[compaction-context-injector] Ignoring checkpoint model that disagrees with current prompt config`, {
sessionID,
checkpointModel,
currentModel: currentPromptConfig.model,
})
}
const expectedPromptConfig = createExpectedRecoveryPromptConfig(
checkpointWithAgent,
currentPromptConfig,

View File

@@ -0,0 +1,98 @@
import { afterEach, describe, expect, it } from "bun:test"
import { _resetForTesting } from "../../features/claude-code-session-state"
import { clearSessionModel, setSessionModel } from "../../shared/session-model-state"
import { clearSessionTools } from "../../shared/session-tools-store"
import {
resolveLatestSessionPromptConfig,
resolveSessionPromptConfig,
} from "./session-prompt-config-resolver"
type SessionMessage = {
info?: {
agent?: string
model?: {
providerID?: string
modelID?: string
}
tools?: Record<string, boolean | "allow" | "deny" | "ask">
}
}
function createMockContext(messages: SessionMessage[]) {
return {
client: {
session: {
messages: async () => ({ data: messages }),
},
},
directory: "/tmp/test",
}
}
describe("session prompt config resolver", () => {
const sessionID = "ses_compaction_model_validation"
afterEach(() => {
_resetForTesting()
clearSessionModel(sessionID)
clearSessionTools()
})
it("prefers the latest non-compaction model over poisoned session state", async () => {
// given
setSessionModel(sessionID, {
providerID: "anthropic",
modelID: "claude-opus-4-1",
})
const ctx = createMockContext([
{
info: {
agent: "atlas",
model: { providerID: "openai", modelID: "gpt-5" },
tools: { bash: "allow" },
},
},
{
info: {
agent: "compaction",
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
},
},
])
// when
const promptConfig = await resolveSessionPromptConfig(ctx, sessionID)
// then
expect(promptConfig).toEqual({
agent: "atlas",
model: { providerID: "openai", modelID: "gpt-5" },
tools: { bash: true },
})
})
it("omits a compaction model from the latest prompt config", async () => {
// given
const ctx = createMockContext([
{
info: {
agent: "atlas",
model: { providerID: "openai", modelID: "gpt-5" },
},
},
{
info: {
agent: "compaction",
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
},
},
])
// when
const promptConfig = await resolveLatestSessionPromptConfig(ctx, sessionID)
// then
expect(promptConfig).toEqual({ agent: "compaction" })
})
})

View File

@@ -5,6 +5,8 @@ import { normalizeSDKResponse } from "../../shared/normalize-sdk-response"
import { normalizePromptTools } from "../../shared/prompt-tools"
import { getSessionModel } from "../../shared/session-model-state"
import { getSessionTools } from "../../shared/session-tools-store"
import { isCompactionAgent } from "./session-id"
import { resolveValidatedModel } from "./validated-model"
type SessionMessage = {
info?: {
@@ -28,30 +30,13 @@ type ResolverContext = {
directory: string
}
function isCompactionAgent(agent: string | undefined): boolean {
return agent?.trim().toLowerCase() === "compaction"
}
function resolveModel(
info: SessionMessage["info"],
): CompactionAgentConfigCheckpoint["model"] | undefined {
const providerID = info?.model?.providerID ?? info?.providerID
const modelID = info?.model?.modelID ?? info?.modelID
if (!providerID || !modelID) {
return undefined
}
return { providerID, modelID }
}
export async function resolveSessionPromptConfig(
ctx: ResolverContext,
sessionID: string,
): Promise<CompactionAgentConfigCheckpoint> {
const storedModel = getSessionModel(sessionID)
const promptConfig: CompactionAgentConfigCheckpoint = {
agent: getSessionAgent(sessionID),
model: getSessionModel(sessionID),
tools: getSessionTools(sessionID),
}
@@ -69,7 +54,7 @@ export async function resolveSessionPromptConfig(
}
if (!promptConfig.model) {
const model = resolveModel(info)
const model = resolveValidatedModel(info)
if (model) {
promptConfig.model = model
}
@@ -94,6 +79,10 @@ export async function resolveSessionPromptConfig(
})
}
if (!promptConfig.model && storedModel) {
promptConfig.model = storedModel
}
return promptConfig
}
@@ -112,7 +101,7 @@ export async function resolveLatestSessionPromptConfig(
return {}
}
const model = resolveModel(latestInfo)
const model = resolveValidatedModel(latestInfo)
const tools = normalizePromptTools(latestInfo.tools)
return {

View File

@@ -0,0 +1,47 @@
import type { CompactionAgentConfigCheckpoint } from "../../shared/compaction-agent-config-checkpoint"
import { isCompactionAgent } from "./session-id"
type PromptConfigInfo = {
agent?: string
model?: {
providerID?: string
modelID?: string
}
providerID?: string
modelID?: string
}
export function resolveValidatedModel(
info: PromptConfigInfo | undefined,
): CompactionAgentConfigCheckpoint["model"] | undefined {
if (isCompactionAgent(info?.agent)) {
return undefined
}
const providerID = info?.model?.providerID ?? info?.providerID
const modelID = info?.model?.modelID ?? info?.modelID
if (!providerID || !modelID) {
return undefined
}
return { providerID, modelID }
}
export function validateCheckpointModel(
checkpointModel: CompactionAgentConfigCheckpoint["model"],
currentModel: CompactionAgentConfigCheckpoint["model"],
): CompactionAgentConfigCheckpoint["model"] | undefined {
if (!checkpointModel) {
return undefined
}
if (!currentModel) {
return checkpointModel
}
return checkpointModel.providerID === currentModel.providerID &&
checkpointModel.modelID === currentModel.modelID
? checkpointModel
: undefined
}

View File

@@ -1,7 +1,7 @@
declare const require: (name: string) => any
const { afterEach, describe, expect, test } = require("bun:test")
import { afterEach, describe, expect, it } from "bun:test"
import { _resetForTesting, getSessionAgent, updateSessionAgent } from "../features/claude-code-session-state"
import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state"
import { createEventHandler } from "./event"
function createMinimalEventHandler() {
@@ -51,9 +51,11 @@ function createMinimalEventHandler() {
describe("createEventHandler compaction agent filtering", () => {
afterEach(() => {
_resetForTesting()
clearSessionModel("ses_compaction_poisoning")
clearSessionModel("ses_compaction_model_poisoning")
})
test("does not overwrite the stored session agent with compaction", async () => {
it("does not overwrite the stored session agent with compaction", async () => {
// given
const sessionID = "ses_compaction_poisoning"
updateSessionAgent(sessionID, "atlas")
@@ -80,4 +82,36 @@ describe("createEventHandler compaction agent filtering", () => {
// then
expect(getSessionAgent(sessionID)).toBe("atlas")
})
it("does not overwrite the stored session model with compaction", async () => {
// given
const sessionID = "ses_compaction_model_poisoning"
setSessionModel(sessionID, { providerID: "openai", modelID: "gpt-5" })
const eventHandler = createMinimalEventHandler()
const input: Parameters<ReturnType<typeof createEventHandler>>[0] = {
event: {
type: "message.updated",
properties: {
info: {
id: "msg-compaction-model",
sessionID,
role: "user",
agent: "compaction",
providerID: "anthropic",
modelID: "claude-opus-4-1",
time: { created: Date.now() },
},
},
},
}
// when
await eventHandler(input)
// then
expect(getSessionModel(sessionID)).toEqual({
providerID: "openai",
modelID: "gpt-5",
})
})
})

View File

@@ -317,12 +317,13 @@ export function createEventHandler(args: {
const agent = info?.agent as string | undefined;
const role = info?.role as string | undefined;
if (sessionID && role === "user") {
if (agent && !isCompactionAgent(agent)) {
const isCompactionMessage = agent ? isCompactionAgent(agent) : false;
if (agent && !isCompactionMessage) {
updateSessionAgent(sessionID, agent);
}
const providerID = info?.providerID as string | undefined;
const modelID = info?.modelID as string | undefined;
if (providerID && modelID) {
if (providerID && modelID && !isCompactionMessage) {
lastKnownModelBySession.set(sessionID, { providerID, modelID });
setSessionModel(sessionID, { providerID, modelID });
}