diff --git a/src/hooks/session-recovery/recover-tool-result-missing.test.ts b/src/hooks/session-recovery/recover-tool-result-missing.test.ts new file mode 100644 index 000000000..eac10fdd4 --- /dev/null +++ b/src/hooks/session-recovery/recover-tool-result-missing.test.ts @@ -0,0 +1,134 @@ +const { describe, it, expect, mock, beforeEach } = require("bun:test") + +import type { MessageData } from "./types" + +let sqliteBackend = false +let storedParts: Array<{ type: string; id?: string; callID?: string; [key: string]: unknown }> = [] + +mock.module("../../shared/opencode-storage-detection", () => ({ + isSqliteBackend: () => sqliteBackend, +})) + +mock.module("../../shared", () => ({ + normalizeSDKResponse: (response: { data?: TData }, fallback: TData): TData => response.data ?? fallback, +})) + +mock.module("./storage", () => ({ + readParts: () => storedParts, +})) + +const { recoverToolResultMissing } = await import("./recover-tool-result-missing") + +function createMockClient(messages: MessageData[] = []) { + const promptAsync = mock(() => Promise.resolve({})) + + return { + client: { + session: { + messages: mock(() => Promise.resolve({ data: messages })), + promptAsync, + }, + } as never, + promptAsync, + } +} + +const failedAssistantMsg: MessageData = { + info: { id: "msg_failed", role: "assistant" }, + parts: [], +} + +describe("recoverToolResultMissing", () => { + beforeEach(() => { + sqliteBackend = false + storedParts = [] + }) + + it("returns false for sqlite fallback when tool part has no valid callID", async () => { + //#given + sqliteBackend = true + const { client, promptAsync } = createMockClient([ + { + info: { id: "msg_failed", role: "assistant" }, + parts: [{ type: "tool", id: "prt_missing_call", name: "bash", input: {} }], + }, + ]) + + //#when + const result = await recoverToolResultMissing(client, "ses_1", failedAssistantMsg) + + //#then + expect(result).toBe(false) + expect(promptAsync).not.toHaveBeenCalled() + }) + + it("sends the recovered sqlite tool result when callID is valid", async () => { + //#given + sqliteBackend = true + const { client, promptAsync } = createMockClient([ + { + info: { id: "msg_failed", role: "assistant" }, + parts: [{ type: "tool", id: "prt_valid_call", callID: "call_recovered", name: "bash", input: {} }], + }, + ]) + + //#when + const result = await recoverToolResultMissing(client, "ses_1", failedAssistantMsg) + + //#then + expect(result).toBe(true) + expect(promptAsync).toHaveBeenCalledWith({ + path: { id: "ses_1" }, + body: { + parts: [{ + type: "tool_result", + tool_use_id: "call_recovered", + content: "Operation cancelled by user (ESC pressed)", + }], + }, + }) + }) + + it("returns false for stored parts when tool part has no valid callID", async () => { + //#given + storedParts = [{ type: "tool", id: "prt_stored_missing_call", tool: "bash", state: { input: {} } }] + const { client, promptAsync } = createMockClient() + + //#when + const result = await recoverToolResultMissing(client, "ses_2", failedAssistantMsg) + + //#then + expect(result).toBe(false) + expect(promptAsync).not.toHaveBeenCalled() + }) + + it("sends the recovered stored tool result when callID is valid", async () => { + //#given + storedParts = [{ + type: "tool", + id: "prt_stored_valid_call", + callID: "toolu_recovered", + tool: "bash", + state: { input: {} }, + }] + const { client, promptAsync } = createMockClient() + + //#when + const result = await recoverToolResultMissing(client, "ses_2", failedAssistantMsg) + + //#then + expect(result).toBe(true) + expect(promptAsync).toHaveBeenCalledWith({ + path: { id: "ses_2" }, + body: { + parts: [{ + type: "tool_result", + tool_use_id: "toolu_recovered", + content: "Operation cancelled by user (ESC pressed)", + }], + }, + }) + }) +}) + +export {} diff --git a/src/hooks/session-recovery/recover-tool-result-missing.ts b/src/hooks/session-recovery/recover-tool-result-missing.ts index 4d19880b3..c3d12da53 100644 --- a/src/hooks/session-recovery/recover-tool-result-missing.ts +++ b/src/hooks/session-recovery/recover-tool-result-missing.ts @@ -24,8 +24,30 @@ interface MessagePart { id?: string } +function isValidToolUseID(id: string | undefined): id is string { + return typeof id === "string" && /^(toolu_|call_)/.test(id) +} + +function normalizeMessagePart(part: { type: string; id?: string; callID?: string }): MessagePart | null { + if (part.type === "tool" || part.type === "tool_use") { + if (!isValidToolUseID(part.callID)) { + return null + } + + return { + type: "tool_use", + id: part.callID, + } + } + + return { + type: part.type, + id: part.id, + } +} + function extractToolUseIds(parts: MessagePart[]): string[] { - return parts.filter((part): part is ToolUsePart => part.type === "tool_use" && !!part.id).map((part) => part.id) + return parts.filter((part): part is ToolUsePart => part.type === "tool_use" && isValidToolUseID(part.id)).map((part) => part.id) } async function readPartsFromSDKFallback( @@ -39,10 +61,7 @@ async function readPartsFromSDKFallback( const target = messages.find((m) => m.info?.id === messageID) if (!target?.parts) return [] - return target.parts.map((part) => ({ - type: part.type === "tool" ? "tool_use" : part.type, - id: "callID" in part ? (part as { callID?: string }).callID : part.id, - })) + return target.parts.map((part) => normalizeMessagePart(part)).filter((part): part is MessagePart => part !== null) } catch { return [] } @@ -59,10 +78,7 @@ export async function recoverToolResultMissing( parts = await readPartsFromSDKFallback(client, sessionID, failedAssistantMsg.info.id) } else { const storedParts = readParts(failedAssistantMsg.info.id) - parts = storedParts.map((part) => ({ - type: part.type === "tool" ? "tool_use" : part.type, - id: "callID" in part ? (part as { callID?: string }).callID : part.id, - })) + parts = storedParts.map((part) => normalizeMessagePart(part)).filter((part): part is MessagePart => part !== null) } }