fix(session-recovery): filter invalid prt_* part IDs from tool_use_id reconstruction

When recovering missing tool results, the session recovery hook was using
raw part.id (prt_* format) as tool_use_id when callID was absent, causing
ZodError validation failures from the API.

Added isValidToolUseID() guard that only accepts toolu_* and call_* prefixed
IDs, and normalizeMessagePart() that returns null for parts without valid
callIDs. Both the SQLite fallback and stored-parts paths now filter out
invalid entries before constructing tool_result payloads.

Includes 4 regression tests covering both valid/invalid callID paths for
both SQLite and stored-parts backends.
This commit is contained in:
YeonGyu-Kim
2026-03-26 20:48:33 +09:00
parent 8e65d6cf2c
commit 3e13a4cf57
2 changed files with 159 additions and 9 deletions

View File

@@ -0,0 +1,134 @@
const { describe, it, expect, mock, beforeEach } = require("bun:test")
import type { MessageData } from "./types"
let sqliteBackend = false
let storedParts: Array<{ type: string; id?: string; callID?: string; [key: string]: unknown }> = []
mock.module("../../shared/opencode-storage-detection", () => ({
isSqliteBackend: () => sqliteBackend,
}))
mock.module("../../shared", () => ({
normalizeSDKResponse: <TData>(response: { data?: TData }, fallback: TData): TData => response.data ?? fallback,
}))
mock.module("./storage", () => ({
readParts: () => storedParts,
}))
const { recoverToolResultMissing } = await import("./recover-tool-result-missing")
function createMockClient(messages: MessageData[] = []) {
const promptAsync = mock(() => Promise.resolve({}))
return {
client: {
session: {
messages: mock(() => Promise.resolve({ data: messages })),
promptAsync,
},
} as never,
promptAsync,
}
}
const failedAssistantMsg: MessageData = {
info: { id: "msg_failed", role: "assistant" },
parts: [],
}
describe("recoverToolResultMissing", () => {
beforeEach(() => {
sqliteBackend = false
storedParts = []
})
it("returns false for sqlite fallback when tool part has no valid callID", async () => {
//#given
sqliteBackend = true
const { client, promptAsync } = createMockClient([
{
info: { id: "msg_failed", role: "assistant" },
parts: [{ type: "tool", id: "prt_missing_call", name: "bash", input: {} }],
},
])
//#when
const result = await recoverToolResultMissing(client, "ses_1", failedAssistantMsg)
//#then
expect(result).toBe(false)
expect(promptAsync).not.toHaveBeenCalled()
})
it("sends the recovered sqlite tool result when callID is valid", async () => {
//#given
sqliteBackend = true
const { client, promptAsync } = createMockClient([
{
info: { id: "msg_failed", role: "assistant" },
parts: [{ type: "tool", id: "prt_valid_call", callID: "call_recovered", name: "bash", input: {} }],
},
])
//#when
const result = await recoverToolResultMissing(client, "ses_1", failedAssistantMsg)
//#then
expect(result).toBe(true)
expect(promptAsync).toHaveBeenCalledWith({
path: { id: "ses_1" },
body: {
parts: [{
type: "tool_result",
tool_use_id: "call_recovered",
content: "Operation cancelled by user (ESC pressed)",
}],
},
})
})
it("returns false for stored parts when tool part has no valid callID", async () => {
//#given
storedParts = [{ type: "tool", id: "prt_stored_missing_call", tool: "bash", state: { input: {} } }]
const { client, promptAsync } = createMockClient()
//#when
const result = await recoverToolResultMissing(client, "ses_2", failedAssistantMsg)
//#then
expect(result).toBe(false)
expect(promptAsync).not.toHaveBeenCalled()
})
it("sends the recovered stored tool result when callID is valid", async () => {
//#given
storedParts = [{
type: "tool",
id: "prt_stored_valid_call",
callID: "toolu_recovered",
tool: "bash",
state: { input: {} },
}]
const { client, promptAsync } = createMockClient()
//#when
const result = await recoverToolResultMissing(client, "ses_2", failedAssistantMsg)
//#then
expect(result).toBe(true)
expect(promptAsync).toHaveBeenCalledWith({
path: { id: "ses_2" },
body: {
parts: [{
type: "tool_result",
tool_use_id: "toolu_recovered",
content: "Operation cancelled by user (ESC pressed)",
}],
},
})
})
})
export {}

View File

@@ -24,8 +24,30 @@ interface MessagePart {
id?: string
}
function isValidToolUseID(id: string | undefined): id is string {
return typeof id === "string" && /^(toolu_|call_)/.test(id)
}
function normalizeMessagePart(part: { type: string; id?: string; callID?: string }): MessagePart | null {
if (part.type === "tool" || part.type === "tool_use") {
if (!isValidToolUseID(part.callID)) {
return null
}
return {
type: "tool_use",
id: part.callID,
}
}
return {
type: part.type,
id: part.id,
}
}
function extractToolUseIds(parts: MessagePart[]): string[] {
return parts.filter((part): part is ToolUsePart => part.type === "tool_use" && !!part.id).map((part) => part.id)
return parts.filter((part): part is ToolUsePart => part.type === "tool_use" && isValidToolUseID(part.id)).map((part) => part.id)
}
async function readPartsFromSDKFallback(
@@ -39,10 +61,7 @@ async function readPartsFromSDKFallback(
const target = messages.find((m) => m.info?.id === messageID)
if (!target?.parts) return []
return target.parts.map((part) => ({
type: part.type === "tool" ? "tool_use" : part.type,
id: "callID" in part ? (part as { callID?: string }).callID : part.id,
}))
return target.parts.map((part) => normalizeMessagePart(part)).filter((part): part is MessagePart => part !== null)
} catch {
return []
}
@@ -59,10 +78,7 @@ export async function recoverToolResultMissing(
parts = await readPartsFromSDKFallback(client, sessionID, failedAssistantMsg.info.id)
} else {
const storedParts = readParts(failedAssistantMsg.info.id)
parts = storedParts.map((part) => ({
type: part.type === "tool" ? "tool_use" : part.type,
id: "callID" in part ? (part as { callID?: string }).callID : part.id,
}))
parts = storedParts.map((part) => normalizeMessagePart(part)).filter((part): part is MessagePart => part !== null)
}
}