fix(session-recovery): filter invalid prt_* part IDs from tool_use_id reconstruction
When recovering missing tool results, the session recovery hook was using raw part.id (prt_* format) as tool_use_id when callID was absent, causing ZodError validation failures from the API. Added isValidToolUseID() guard that only accepts toolu_* and call_* prefixed IDs, and normalizeMessagePart() that returns null for parts without valid callIDs. Both the SQLite fallback and stored-parts paths now filter out invalid entries before constructing tool_result payloads. Includes 4 regression tests covering both valid/invalid callID paths for both SQLite and stored-parts backends.
This commit is contained in:
134
src/hooks/session-recovery/recover-tool-result-missing.test.ts
Normal file
134
src/hooks/session-recovery/recover-tool-result-missing.test.ts
Normal file
@@ -0,0 +1,134 @@
|
||||
const { describe, it, expect, mock, beforeEach } = require("bun:test")
|
||||
|
||||
import type { MessageData } from "./types"
|
||||
|
||||
let sqliteBackend = false
|
||||
let storedParts: Array<{ type: string; id?: string; callID?: string; [key: string]: unknown }> = []
|
||||
|
||||
mock.module("../../shared/opencode-storage-detection", () => ({
|
||||
isSqliteBackend: () => sqliteBackend,
|
||||
}))
|
||||
|
||||
mock.module("../../shared", () => ({
|
||||
normalizeSDKResponse: <TData>(response: { data?: TData }, fallback: TData): TData => response.data ?? fallback,
|
||||
}))
|
||||
|
||||
mock.module("./storage", () => ({
|
||||
readParts: () => storedParts,
|
||||
}))
|
||||
|
||||
const { recoverToolResultMissing } = await import("./recover-tool-result-missing")
|
||||
|
||||
function createMockClient(messages: MessageData[] = []) {
|
||||
const promptAsync = mock(() => Promise.resolve({}))
|
||||
|
||||
return {
|
||||
client: {
|
||||
session: {
|
||||
messages: mock(() => Promise.resolve({ data: messages })),
|
||||
promptAsync,
|
||||
},
|
||||
} as never,
|
||||
promptAsync,
|
||||
}
|
||||
}
|
||||
|
||||
const failedAssistantMsg: MessageData = {
|
||||
info: { id: "msg_failed", role: "assistant" },
|
||||
parts: [],
|
||||
}
|
||||
|
||||
describe("recoverToolResultMissing", () => {
|
||||
beforeEach(() => {
|
||||
sqliteBackend = false
|
||||
storedParts = []
|
||||
})
|
||||
|
||||
it("returns false for sqlite fallback when tool part has no valid callID", async () => {
|
||||
//#given
|
||||
sqliteBackend = true
|
||||
const { client, promptAsync } = createMockClient([
|
||||
{
|
||||
info: { id: "msg_failed", role: "assistant" },
|
||||
parts: [{ type: "tool", id: "prt_missing_call", name: "bash", input: {} }],
|
||||
},
|
||||
])
|
||||
|
||||
//#when
|
||||
const result = await recoverToolResultMissing(client, "ses_1", failedAssistantMsg)
|
||||
|
||||
//#then
|
||||
expect(result).toBe(false)
|
||||
expect(promptAsync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it("sends the recovered sqlite tool result when callID is valid", async () => {
|
||||
//#given
|
||||
sqliteBackend = true
|
||||
const { client, promptAsync } = createMockClient([
|
||||
{
|
||||
info: { id: "msg_failed", role: "assistant" },
|
||||
parts: [{ type: "tool", id: "prt_valid_call", callID: "call_recovered", name: "bash", input: {} }],
|
||||
},
|
||||
])
|
||||
|
||||
//#when
|
||||
const result = await recoverToolResultMissing(client, "ses_1", failedAssistantMsg)
|
||||
|
||||
//#then
|
||||
expect(result).toBe(true)
|
||||
expect(promptAsync).toHaveBeenCalledWith({
|
||||
path: { id: "ses_1" },
|
||||
body: {
|
||||
parts: [{
|
||||
type: "tool_result",
|
||||
tool_use_id: "call_recovered",
|
||||
content: "Operation cancelled by user (ESC pressed)",
|
||||
}],
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it("returns false for stored parts when tool part has no valid callID", async () => {
|
||||
//#given
|
||||
storedParts = [{ type: "tool", id: "prt_stored_missing_call", tool: "bash", state: { input: {} } }]
|
||||
const { client, promptAsync } = createMockClient()
|
||||
|
||||
//#when
|
||||
const result = await recoverToolResultMissing(client, "ses_2", failedAssistantMsg)
|
||||
|
||||
//#then
|
||||
expect(result).toBe(false)
|
||||
expect(promptAsync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it("sends the recovered stored tool result when callID is valid", async () => {
|
||||
//#given
|
||||
storedParts = [{
|
||||
type: "tool",
|
||||
id: "prt_stored_valid_call",
|
||||
callID: "toolu_recovered",
|
||||
tool: "bash",
|
||||
state: { input: {} },
|
||||
}]
|
||||
const { client, promptAsync } = createMockClient()
|
||||
|
||||
//#when
|
||||
const result = await recoverToolResultMissing(client, "ses_2", failedAssistantMsg)
|
||||
|
||||
//#then
|
||||
expect(result).toBe(true)
|
||||
expect(promptAsync).toHaveBeenCalledWith({
|
||||
path: { id: "ses_2" },
|
||||
body: {
|
||||
parts: [{
|
||||
type: "tool_result",
|
||||
tool_use_id: "toolu_recovered",
|
||||
content: "Operation cancelled by user (ESC pressed)",
|
||||
}],
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
export {}
|
||||
@@ -24,8 +24,30 @@ interface MessagePart {
|
||||
id?: string
|
||||
}
|
||||
|
||||
function isValidToolUseID(id: string | undefined): id is string {
|
||||
return typeof id === "string" && /^(toolu_|call_)/.test(id)
|
||||
}
|
||||
|
||||
function normalizeMessagePart(part: { type: string; id?: string; callID?: string }): MessagePart | null {
|
||||
if (part.type === "tool" || part.type === "tool_use") {
|
||||
if (!isValidToolUseID(part.callID)) {
|
||||
return null
|
||||
}
|
||||
|
||||
return {
|
||||
type: "tool_use",
|
||||
id: part.callID,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
type: part.type,
|
||||
id: part.id,
|
||||
}
|
||||
}
|
||||
|
||||
function extractToolUseIds(parts: MessagePart[]): string[] {
|
||||
return parts.filter((part): part is ToolUsePart => part.type === "tool_use" && !!part.id).map((part) => part.id)
|
||||
return parts.filter((part): part is ToolUsePart => part.type === "tool_use" && isValidToolUseID(part.id)).map((part) => part.id)
|
||||
}
|
||||
|
||||
async function readPartsFromSDKFallback(
|
||||
@@ -39,10 +61,7 @@ async function readPartsFromSDKFallback(
|
||||
const target = messages.find((m) => m.info?.id === messageID)
|
||||
if (!target?.parts) return []
|
||||
|
||||
return target.parts.map((part) => ({
|
||||
type: part.type === "tool" ? "tool_use" : part.type,
|
||||
id: "callID" in part ? (part as { callID?: string }).callID : part.id,
|
||||
}))
|
||||
return target.parts.map((part) => normalizeMessagePart(part)).filter((part): part is MessagePart => part !== null)
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
@@ -59,10 +78,7 @@ export async function recoverToolResultMissing(
|
||||
parts = await readPartsFromSDKFallback(client, sessionID, failedAssistantMsg.info.id)
|
||||
} else {
|
||||
const storedParts = readParts(failedAssistantMsg.info.id)
|
||||
parts = storedParts.map((part) => ({
|
||||
type: part.type === "tool" ? "tool_use" : part.type,
|
||||
id: "callID" in part ? (part as { callID?: string }).callID : part.id,
|
||||
}))
|
||||
parts = storedParts.map((part) => normalizeMessagePart(part)).filter((part): part is MessagePart => part !== null)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user