Merge pull request #2471 from code-yeongyu/fix/compaction-model-filter
fix(compaction): guard model update during compaction
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
/// <reference path="../../../bun-test.d.ts" />
|
||||
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import { setCompactionAgentConfigCheckpoint } from "../../shared/compaction-agent-config-checkpoint"
|
||||
import { createCompactionContextInjector } from "./index"
|
||||
|
||||
type SessionMessageResponse = Array<{
|
||||
@@ -291,4 +292,69 @@ describe("createCompactionContextInjector recovery", () => {
|
||||
//#then
|
||||
expect(promptAsyncRecorder.calls.length).toBe(0)
|
||||
})
|
||||
|
||||
it("falls back to the current non-compaction model when a checkpoint model is poisoned", async () => {
|
||||
//#given
|
||||
const sessionID = "ses_poisoned_checkpoint_model"
|
||||
const promptAsyncRecorder = createPromptAsyncRecorder()
|
||||
setCompactionAgentConfigCheckpoint(sessionID, {
|
||||
agent: "atlas",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
|
||||
tools: { bash: true },
|
||||
})
|
||||
const ctx = createMockContext(
|
||||
[
|
||||
[
|
||||
{
|
||||
info: {
|
||||
role: "user",
|
||||
agent: "atlas",
|
||||
model: { providerID: "openai", modelID: "gpt-5" },
|
||||
tools: { bash: true },
|
||||
},
|
||||
},
|
||||
{
|
||||
info: {
|
||||
role: "user",
|
||||
agent: "compaction",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
info: {
|
||||
role: "user",
|
||||
agent: "compaction",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
info: {
|
||||
role: "user",
|
||||
agent: "atlas",
|
||||
model: { providerID: "openai", modelID: "gpt-5" },
|
||||
tools: { bash: true },
|
||||
},
|
||||
},
|
||||
],
|
||||
],
|
||||
promptAsyncRecorder.promptAsync,
|
||||
)
|
||||
const injector = createCompactionContextInjector({ ctx })
|
||||
|
||||
//#when
|
||||
await injector.event({
|
||||
event: { type: "session.compacted", properties: { sessionID } },
|
||||
})
|
||||
|
||||
//#then
|
||||
expect(promptAsyncRecorder.calls.length).toBe(1)
|
||||
expect(promptAsyncRecorder.calls[0]?.body.model).toEqual({
|
||||
providerID: "openai",
|
||||
modelID: "gpt-5",
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
createExpectedRecoveryPromptConfig,
|
||||
isPromptConfigRecovered,
|
||||
} from "./recovery-prompt-config"
|
||||
import { validateCheckpointModel } from "./validated-model"
|
||||
import {
|
||||
resolveLatestSessionPromptConfig,
|
||||
resolveSessionPromptConfig,
|
||||
@@ -34,10 +35,6 @@ export function createRecoveryLogic(
|
||||
if (!checkpoint?.agent) {
|
||||
return false
|
||||
}
|
||||
const checkpointWithAgent = {
|
||||
...checkpoint,
|
||||
agent: checkpoint.agent,
|
||||
}
|
||||
|
||||
const tailState = getTailState(sessionID)
|
||||
const now = Date.now()
|
||||
@@ -46,6 +43,25 @@ export function createRecoveryLogic(
|
||||
}
|
||||
|
||||
const currentPromptConfig = await resolveSessionPromptConfig(ctx, sessionID)
|
||||
const validatedCheckpointModel = validateCheckpointModel(
|
||||
checkpoint.model,
|
||||
currentPromptConfig.model,
|
||||
)
|
||||
const { model: checkpointModel, ...checkpointWithoutModel } = checkpoint
|
||||
const checkpointWithAgent = {
|
||||
...checkpointWithoutModel,
|
||||
agent: checkpoint.agent,
|
||||
...(validatedCheckpointModel ? { model: validatedCheckpointModel } : {}),
|
||||
}
|
||||
|
||||
if (checkpointModel && !validatedCheckpointModel) {
|
||||
log(`[compaction-context-injector] Ignoring checkpoint model that disagrees with current prompt config`, {
|
||||
sessionID,
|
||||
checkpointModel,
|
||||
currentModel: currentPromptConfig.model,
|
||||
})
|
||||
}
|
||||
|
||||
const expectedPromptConfig = createExpectedRecoveryPromptConfig(
|
||||
checkpointWithAgent,
|
||||
currentPromptConfig,
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
import { afterEach, describe, expect, it } from "bun:test"
|
||||
|
||||
import { _resetForTesting } from "../../features/claude-code-session-state"
|
||||
import { clearSessionModel, setSessionModel } from "../../shared/session-model-state"
|
||||
import { clearSessionTools } from "../../shared/session-tools-store"
|
||||
import {
|
||||
resolveLatestSessionPromptConfig,
|
||||
resolveSessionPromptConfig,
|
||||
} from "./session-prompt-config-resolver"
|
||||
|
||||
type SessionMessage = {
|
||||
info?: {
|
||||
agent?: string
|
||||
model?: {
|
||||
providerID?: string
|
||||
modelID?: string
|
||||
}
|
||||
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||
}
|
||||
}
|
||||
|
||||
function createMockContext(messages: SessionMessage[]) {
|
||||
return {
|
||||
client: {
|
||||
session: {
|
||||
messages: async () => ({ data: messages }),
|
||||
},
|
||||
},
|
||||
directory: "/tmp/test",
|
||||
}
|
||||
}
|
||||
|
||||
describe("session prompt config resolver", () => {
|
||||
const sessionID = "ses_compaction_model_validation"
|
||||
|
||||
afterEach(() => {
|
||||
_resetForTesting()
|
||||
clearSessionModel(sessionID)
|
||||
clearSessionTools()
|
||||
})
|
||||
|
||||
it("prefers the latest non-compaction model over poisoned session state", async () => {
|
||||
// given
|
||||
setSessionModel(sessionID, {
|
||||
providerID: "anthropic",
|
||||
modelID: "claude-opus-4-1",
|
||||
})
|
||||
const ctx = createMockContext([
|
||||
{
|
||||
info: {
|
||||
agent: "atlas",
|
||||
model: { providerID: "openai", modelID: "gpt-5" },
|
||||
tools: { bash: "allow" },
|
||||
},
|
||||
},
|
||||
{
|
||||
info: {
|
||||
agent: "compaction",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
|
||||
},
|
||||
},
|
||||
])
|
||||
|
||||
// when
|
||||
const promptConfig = await resolveSessionPromptConfig(ctx, sessionID)
|
||||
|
||||
// then
|
||||
expect(promptConfig).toEqual({
|
||||
agent: "atlas",
|
||||
model: { providerID: "openai", modelID: "gpt-5" },
|
||||
tools: { bash: true },
|
||||
})
|
||||
})
|
||||
|
||||
it("omits a compaction model from the latest prompt config", async () => {
|
||||
// given
|
||||
const ctx = createMockContext([
|
||||
{
|
||||
info: {
|
||||
agent: "atlas",
|
||||
model: { providerID: "openai", modelID: "gpt-5" },
|
||||
},
|
||||
},
|
||||
{
|
||||
info: {
|
||||
agent: "compaction",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-1" },
|
||||
},
|
||||
},
|
||||
])
|
||||
|
||||
// when
|
||||
const promptConfig = await resolveLatestSessionPromptConfig(ctx, sessionID)
|
||||
|
||||
// then
|
||||
expect(promptConfig).toEqual({ agent: "compaction" })
|
||||
})
|
||||
})
|
||||
@@ -5,6 +5,8 @@ import { normalizeSDKResponse } from "../../shared/normalize-sdk-response"
|
||||
import { normalizePromptTools } from "../../shared/prompt-tools"
|
||||
import { getSessionModel } from "../../shared/session-model-state"
|
||||
import { getSessionTools } from "../../shared/session-tools-store"
|
||||
import { isCompactionAgent } from "./session-id"
|
||||
import { resolveValidatedModel } from "./validated-model"
|
||||
|
||||
type SessionMessage = {
|
||||
info?: {
|
||||
@@ -28,30 +30,13 @@ type ResolverContext = {
|
||||
directory: string
|
||||
}
|
||||
|
||||
function isCompactionAgent(agent: string | undefined): boolean {
|
||||
return agent?.trim().toLowerCase() === "compaction"
|
||||
}
|
||||
|
||||
function resolveModel(
|
||||
info: SessionMessage["info"],
|
||||
): CompactionAgentConfigCheckpoint["model"] | undefined {
|
||||
const providerID = info?.model?.providerID ?? info?.providerID
|
||||
const modelID = info?.model?.modelID ?? info?.modelID
|
||||
|
||||
if (!providerID || !modelID) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return { providerID, modelID }
|
||||
}
|
||||
|
||||
export async function resolveSessionPromptConfig(
|
||||
ctx: ResolverContext,
|
||||
sessionID: string,
|
||||
): Promise<CompactionAgentConfigCheckpoint> {
|
||||
const storedModel = getSessionModel(sessionID)
|
||||
const promptConfig: CompactionAgentConfigCheckpoint = {
|
||||
agent: getSessionAgent(sessionID),
|
||||
model: getSessionModel(sessionID),
|
||||
tools: getSessionTools(sessionID),
|
||||
}
|
||||
|
||||
@@ -69,7 +54,7 @@ export async function resolveSessionPromptConfig(
|
||||
}
|
||||
|
||||
if (!promptConfig.model) {
|
||||
const model = resolveModel(info)
|
||||
const model = resolveValidatedModel(info)
|
||||
if (model) {
|
||||
promptConfig.model = model
|
||||
}
|
||||
@@ -94,6 +79,10 @@ export async function resolveSessionPromptConfig(
|
||||
})
|
||||
}
|
||||
|
||||
if (!promptConfig.model && storedModel) {
|
||||
promptConfig.model = storedModel
|
||||
}
|
||||
|
||||
return promptConfig
|
||||
}
|
||||
|
||||
@@ -112,7 +101,7 @@ export async function resolveLatestSessionPromptConfig(
|
||||
return {}
|
||||
}
|
||||
|
||||
const model = resolveModel(latestInfo)
|
||||
const model = resolveValidatedModel(latestInfo)
|
||||
const tools = normalizePromptTools(latestInfo.tools)
|
||||
|
||||
return {
|
||||
|
||||
47
src/hooks/compaction-context-injector/validated-model.ts
Normal file
47
src/hooks/compaction-context-injector/validated-model.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import type { CompactionAgentConfigCheckpoint } from "../../shared/compaction-agent-config-checkpoint"
|
||||
import { isCompactionAgent } from "./session-id"
|
||||
|
||||
type PromptConfigInfo = {
|
||||
agent?: string
|
||||
model?: {
|
||||
providerID?: string
|
||||
modelID?: string
|
||||
}
|
||||
providerID?: string
|
||||
modelID?: string
|
||||
}
|
||||
|
||||
export function resolveValidatedModel(
|
||||
info: PromptConfigInfo | undefined,
|
||||
): CompactionAgentConfigCheckpoint["model"] | undefined {
|
||||
if (isCompactionAgent(info?.agent)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const providerID = info?.model?.providerID ?? info?.providerID
|
||||
const modelID = info?.model?.modelID ?? info?.modelID
|
||||
|
||||
if (!providerID || !modelID) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return { providerID, modelID }
|
||||
}
|
||||
|
||||
export function validateCheckpointModel(
|
||||
checkpointModel: CompactionAgentConfigCheckpoint["model"],
|
||||
currentModel: CompactionAgentConfigCheckpoint["model"],
|
||||
): CompactionAgentConfigCheckpoint["model"] | undefined {
|
||||
if (!checkpointModel) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (!currentModel) {
|
||||
return checkpointModel
|
||||
}
|
||||
|
||||
return checkpointModel.providerID === currentModel.providerID &&
|
||||
checkpointModel.modelID === currentModel.modelID
|
||||
? checkpointModel
|
||||
: undefined
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
declare const require: (name: string) => any
|
||||
const { afterEach, describe, expect, test } = require("bun:test")
|
||||
import { afterEach, describe, expect, it } from "bun:test"
|
||||
|
||||
import { _resetForTesting, getSessionAgent, updateSessionAgent } from "../features/claude-code-session-state"
|
||||
import { clearSessionModel, getSessionModel, setSessionModel } from "../shared/session-model-state"
|
||||
import { createEventHandler } from "./event"
|
||||
|
||||
function createMinimalEventHandler() {
|
||||
@@ -51,9 +51,11 @@ function createMinimalEventHandler() {
|
||||
describe("createEventHandler compaction agent filtering", () => {
|
||||
afterEach(() => {
|
||||
_resetForTesting()
|
||||
clearSessionModel("ses_compaction_poisoning")
|
||||
clearSessionModel("ses_compaction_model_poisoning")
|
||||
})
|
||||
|
||||
test("does not overwrite the stored session agent with compaction", async () => {
|
||||
it("does not overwrite the stored session agent with compaction", async () => {
|
||||
// given
|
||||
const sessionID = "ses_compaction_poisoning"
|
||||
updateSessionAgent(sessionID, "atlas")
|
||||
@@ -80,4 +82,36 @@ describe("createEventHandler compaction agent filtering", () => {
|
||||
// then
|
||||
expect(getSessionAgent(sessionID)).toBe("atlas")
|
||||
})
|
||||
|
||||
it("does not overwrite the stored session model with compaction", async () => {
|
||||
// given
|
||||
const sessionID = "ses_compaction_model_poisoning"
|
||||
setSessionModel(sessionID, { providerID: "openai", modelID: "gpt-5" })
|
||||
const eventHandler = createMinimalEventHandler()
|
||||
const input: Parameters<ReturnType<typeof createEventHandler>>[0] = {
|
||||
event: {
|
||||
type: "message.updated",
|
||||
properties: {
|
||||
info: {
|
||||
id: "msg-compaction-model",
|
||||
sessionID,
|
||||
role: "user",
|
||||
agent: "compaction",
|
||||
providerID: "anthropic",
|
||||
modelID: "claude-opus-4-1",
|
||||
time: { created: Date.now() },
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// when
|
||||
await eventHandler(input)
|
||||
|
||||
// then
|
||||
expect(getSessionModel(sessionID)).toEqual({
|
||||
providerID: "openai",
|
||||
modelID: "gpt-5",
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -317,12 +317,13 @@ export function createEventHandler(args: {
|
||||
const agent = info?.agent as string | undefined;
|
||||
const role = info?.role as string | undefined;
|
||||
if (sessionID && role === "user") {
|
||||
if (agent && !isCompactionAgent(agent)) {
|
||||
const isCompactionMessage = agent ? isCompactionAgent(agent) : false;
|
||||
if (agent && !isCompactionMessage) {
|
||||
updateSessionAgent(sessionID, agent);
|
||||
}
|
||||
const providerID = info?.providerID as string | undefined;
|
||||
const modelID = info?.modelID as string | undefined;
|
||||
if (providerID && modelID) {
|
||||
if (providerID && modelID && !isCompactionMessage) {
|
||||
lastKnownModelBySession.set(sessionID, { providerID, modelID });
|
||||
setSessionModel(sessionID, { providerID, modelID });
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user