From b2497f13276ddf4562456212230777fd049c91dc Mon Sep 17 00:00:00 2001 From: YeonGyu-Kim Date: Sun, 29 Mar 2026 04:53:36 +0900 Subject: [PATCH] fix: resolve 3 community-reported bugs (#2915, #2917, #2918) - background_output: snapshot read cursor before consuming, restore on /undo message removal so re-reads return data (fixes #2915) - MCP loader: preserve oauth field in transformMcpServer, add scope/ projectPath filtering so local-scoped MCPs only load in matching directories (fixes #2917) - runtime-fallback: add 'reached your usage limit' to retryable error patterns so quota exhaustion triggers model fallback (fixes #2918) Verified: bun test (4606 pass / 0 fail), tsc --noEmit clean --- src/features/claude-code-mcp-loader/loader.ts | 11 ++ .../claude-code-mcp-loader/scope-filter.ts | 28 ++++ .../scope-filtering.test.ts | 82 +++++++++++ .../transformer.test.ts | 29 ++++ .../claude-code-mcp-loader/transformer.ts | 4 + src/features/claude-code-mcp-loader/types.ts | 13 +- .../mcp-server-loader.test.ts | 76 +++++++++++ .../mcp-server-loader.ts | 11 ++ src/hooks/runtime-fallback/constants.ts | 1 + .../runtime-fallback/error-classifier.test.ts | 11 ++ src/hooks/runtime-fallback/index.test.ts | 76 ++++++++++- src/plugin/event.ts | 13 ++ src/shared/background-output-consumption.ts | 69 ++++++++++ src/shared/session-cursor.ts | 25 +++- .../create-background-output.ts | 3 + .../create-background-output.undo.test.ts | 129 ++++++++++++++++++ 16 files changed, 575 insertions(+), 6 deletions(-) create mode 100644 src/features/claude-code-mcp-loader/scope-filter.ts create mode 100644 src/features/claude-code-mcp-loader/scope-filtering.test.ts create mode 100644 src/features/claude-code-mcp-loader/transformer.test.ts create mode 100644 src/features/claude-code-plugin-loader/mcp-server-loader.test.ts create mode 100644 src/shared/background-output-consumption.ts create mode 100644 src/tools/background-task/create-background-output.undo.test.ts diff --git a/src/features/claude-code-mcp-loader/loader.ts b/src/features/claude-code-mcp-loader/loader.ts index ae43dc137..6ccf08b42 100644 --- a/src/features/claude-code-mcp-loader/loader.ts +++ b/src/features/claude-code-mcp-loader/loader.ts @@ -10,6 +10,7 @@ import type { } from "./types" import { transformMcpServer } from "./transformer" import { log } from "../../shared/logger" +import { shouldLoadMcpServer } from "./scope-filter" interface McpConfigPath { path: string @@ -75,6 +76,7 @@ export async function loadMcpConfigs( const loadedServers: LoadedMcpServer[] = [] const paths = getMcpConfigPaths() const disabledSet = new Set(disabledMcps) + const cwd = process.cwd() for (const { path, scope } of paths) { const config = await loadMcpConfigFile(path) @@ -86,6 +88,15 @@ export async function loadMcpConfigs( continue } + if (!shouldLoadMcpServer(serverConfig, cwd)) { + log(`Skipping MCP server "${name}" because local scope does not match cwd`, { + path, + projectPath: serverConfig.projectPath, + cwd, + }) + continue + } + if (serverConfig.disabled) { log(`Disabling MCP server "${name}"`, { path }) delete servers[name] diff --git a/src/features/claude-code-mcp-loader/scope-filter.ts b/src/features/claude-code-mcp-loader/scope-filter.ts new file mode 100644 index 000000000..690421e0c --- /dev/null +++ b/src/features/claude-code-mcp-loader/scope-filter.ts @@ -0,0 +1,28 @@ +import { existsSync, realpathSync } from "fs" +import { resolve } from "path" +import type { ClaudeCodeMcpServer } from "./types" + +function normalizePath(path: string): string { + const resolvedPath = resolve(path) + + if (!existsSync(resolvedPath)) { + return resolvedPath + } + + return realpathSync(resolvedPath) +} + +export function shouldLoadMcpServer( + server: Pick, + cwd = process.cwd() +): boolean { + if (server.scope !== "local") { + return true + } + + if (!server.projectPath) { + return false + } + + return normalizePath(server.projectPath) === normalizePath(cwd) +} diff --git a/src/features/claude-code-mcp-loader/scope-filtering.test.ts b/src/features/claude-code-mcp-loader/scope-filtering.test.ts new file mode 100644 index 000000000..e90136b24 --- /dev/null +++ b/src/features/claude-code-mcp-loader/scope-filtering.test.ts @@ -0,0 +1,82 @@ +import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test" +import { mkdirSync, rmSync, writeFileSync } from "fs" +import { tmpdir } from "os" +import { join } from "path" + +const TEST_DIR = join(tmpdir(), `mcp-scope-filtering-test-${Date.now()}`) +const TEST_HOME = join(TEST_DIR, "home") + +describe("loadMcpConfigs", () => { + beforeEach(() => { + mkdirSync(TEST_DIR, { recursive: true }) + mkdirSync(TEST_HOME, { recursive: true }) + mock.module("os", () => ({ + homedir: () => TEST_HOME, + tmpdir, + })) + mock.module("../../shared", () => ({ + getClaudeConfigDir: () => join(TEST_HOME, ".claude"), + })) + mock.module("../../shared/logger", () => ({ + log: () => {}, + })) + }) + + afterEach(() => { + mock.restore() + rmSync(TEST_DIR, { recursive: true, force: true }) + }) + + describe("#given user-scoped MCP entries with local scope metadata", () => { + it("#when loading configs #then only servers matching the current project path are loaded", async () => { + writeFileSync( + join(TEST_HOME, ".claude.json"), + JSON.stringify({ + mcpServers: { + globalServer: { + command: "npx", + args: ["global-server"], + }, + matchingLocal: { + command: "npx", + args: ["matching-local"], + scope: "local", + projectPath: TEST_DIR, + }, + nonMatchingLocal: { + command: "npx", + args: ["non-matching-local"], + scope: "local", + projectPath: join(TEST_DIR, "other-project"), + }, + missingProjectPath: { + command: "npx", + args: ["missing-project-path"], + scope: "local", + }, + }, + }) + ) + + const originalCwd = process.cwd() + process.chdir(TEST_DIR) + + try { + const { loadMcpConfigs } = await import("./loader") + const result = await loadMcpConfigs() + + expect(result.servers).toHaveProperty("globalServer") + expect(result.servers).toHaveProperty("matchingLocal") + expect(result.servers).not.toHaveProperty("nonMatchingLocal") + expect(result.servers).not.toHaveProperty("missingProjectPath") + + expect(result.loadedServers.map((server) => server.name)).toEqual([ + "globalServer", + "matchingLocal", + ]) + } finally { + process.chdir(originalCwd) + } + }) + }) +}) diff --git a/src/features/claude-code-mcp-loader/transformer.test.ts b/src/features/claude-code-mcp-loader/transformer.test.ts new file mode 100644 index 000000000..fa4508372 --- /dev/null +++ b/src/features/claude-code-mcp-loader/transformer.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it } from "bun:test" +import { transformMcpServer } from "./transformer" + +describe("transformMcpServer", () => { + describe("#given a remote MCP server with oauth config", () => { + it("#when transforming the server #then preserves oauth on the remote config", () => { + const transformed = transformMcpServer("remote-oauth", { + type: "http", + url: "https://mcp.example.com", + headers: { Authorization: "Bearer test" }, + oauth: { + clientId: "client-id", + scopes: ["read", "write"], + }, + }) + + expect(transformed).toEqual({ + type: "remote", + url: "https://mcp.example.com", + headers: { Authorization: "Bearer test" }, + oauth: { + clientId: "client-id", + scopes: ["read", "write"], + }, + enabled: true, + }) + }) + }) +}) diff --git a/src/features/claude-code-mcp-loader/transformer.ts b/src/features/claude-code-mcp-loader/transformer.ts index f94e504b9..b08be7e34 100644 --- a/src/features/claude-code-mcp-loader/transformer.ts +++ b/src/features/claude-code-mcp-loader/transformer.ts @@ -30,6 +30,10 @@ export function transformMcpServer( config.headers = expanded.headers } + if (expanded.oauth && Object.keys(expanded.oauth).length > 0) { + config.oauth = expanded.oauth + } + return config } diff --git a/src/features/claude-code-mcp-loader/types.ts b/src/features/claude-code-mcp-loader/types.ts index 66822e8d2..ed89aa264 100644 --- a/src/features/claude-code-mcp-loader/types.ts +++ b/src/features/claude-code-mcp-loader/types.ts @@ -1,5 +1,10 @@ export type McpScope = "user" | "project" | "local" +export interface McpOAuthConfig { + clientId?: string + scopes?: string[] +} + export interface ClaudeCodeMcpServer { type?: "http" | "sse" | "stdio" url?: string @@ -7,10 +12,9 @@ export interface ClaudeCodeMcpServer { args?: string[] env?: Record headers?: Record - oauth?: { - clientId?: string - scopes?: string[] - } + oauth?: McpOAuthConfig + scope?: McpScope + projectPath?: string disabled?: boolean } @@ -29,6 +33,7 @@ export interface McpRemoteConfig { type: "remote" url: string headers?: Record + oauth?: McpOAuthConfig enabled?: boolean } diff --git a/src/features/claude-code-plugin-loader/mcp-server-loader.test.ts b/src/features/claude-code-plugin-loader/mcp-server-loader.test.ts new file mode 100644 index 000000000..7f474b4cc --- /dev/null +++ b/src/features/claude-code-plugin-loader/mcp-server-loader.test.ts @@ -0,0 +1,76 @@ +import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test" +import { mkdirSync, rmSync, writeFileSync } from "fs" +import { tmpdir } from "os" +import { join } from "path" +import type { LoadedPlugin } from "./types" + +const TEST_DIR = join(tmpdir(), `plugin-mcp-loader-test-${Date.now()}`) +const PROJECT_DIR = join(TEST_DIR, "project") +const PLUGIN_DIR = join(TEST_DIR, "plugin") +const MCP_CONFIG_PATH = join(PLUGIN_DIR, "mcp.json") + +describe("loadPluginMcpServers", () => { + beforeEach(() => { + mkdirSync(PROJECT_DIR, { recursive: true }) + mkdirSync(PLUGIN_DIR, { recursive: true }) + mock.module("../../shared/logger", () => ({ + log: () => {}, + })) + }) + + afterEach(() => { + mock.restore() + rmSync(TEST_DIR, { recursive: true, force: true }) + }) + + describe("#given plugin MCP entries with local scope metadata", () => { + it("#when loading plugin MCP servers #then only entries matching the current cwd are included", async () => { + writeFileSync( + MCP_CONFIG_PATH, + JSON.stringify({ + mcpServers: { + globalServer: { + command: "npx", + args: ["global-plugin-server"], + }, + matchingLocal: { + command: "npx", + args: ["matching-plugin-local"], + scope: "local", + projectPath: PROJECT_DIR, + }, + nonMatchingLocal: { + command: "npx", + args: ["non-matching-plugin-local"], + scope: "local", + projectPath: join(PROJECT_DIR, "other-project"), + }, + }, + }) + ) + + const plugin: LoadedPlugin = { + name: "demo-plugin", + version: "1.0.0", + scope: "project", + installPath: PLUGIN_DIR, + pluginKey: "demo-plugin@test", + mcpPath: MCP_CONFIG_PATH, + } + + const originalCwd = process.cwd() + process.chdir(PROJECT_DIR) + + try { + const { loadPluginMcpServers } = await import("./mcp-server-loader") + const servers = await loadPluginMcpServers([plugin]) + + expect(servers).toHaveProperty("demo-plugin:globalServer") + expect(servers).toHaveProperty("demo-plugin:matchingLocal") + expect(servers).not.toHaveProperty("demo-plugin:nonMatchingLocal") + } finally { + process.chdir(originalCwd) + } + }) + }) +}) diff --git a/src/features/claude-code-plugin-loader/mcp-server-loader.ts b/src/features/claude-code-plugin-loader/mcp-server-loader.ts index 9fcfba231..b0f0f8b8f 100644 --- a/src/features/claude-code-plugin-loader/mcp-server-loader.ts +++ b/src/features/claude-code-plugin-loader/mcp-server-loader.ts @@ -1,6 +1,7 @@ import { existsSync } from "fs" import type { McpServerConfig } from "../claude-code-mcp-loader/types" import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander" +import { shouldLoadMcpServer } from "../claude-code-mcp-loader/scope-filter" import { transformMcpServer } from "../claude-code-mcp-loader/transformer" import type { ClaudeCodeMcpConfig } from "../claude-code-mcp-loader/types" import { log } from "../../shared/logger" @@ -11,6 +12,7 @@ export async function loadPluginMcpServers( plugins: LoadedPlugin[], ): Promise> { const servers: Record = {} + const cwd = process.cwd() for (const plugin of plugins) { if (!plugin.mcpPath || !existsSync(plugin.mcpPath)) continue @@ -25,6 +27,15 @@ export async function loadPluginMcpServers( if (!config.mcpServers) continue for (const [name, serverConfig] of Object.entries(config.mcpServers)) { + if (!shouldLoadMcpServer(serverConfig, cwd)) { + log(`Skipping local plugin MCP server "${name}" outside current cwd`, { + path: plugin.mcpPath, + projectPath: serverConfig.projectPath, + cwd, + }) + continue + } + if (serverConfig.disabled) { log(`Skipping disabled MCP server "${name}" from plugin ${plugin.name}`) continue diff --git a/src/hooks/runtime-fallback/constants.ts b/src/hooks/runtime-fallback/constants.ts index eaf4e78c5..0e78fdbc7 100644 --- a/src/hooks/runtime-fallback/constants.ts +++ b/src/hooks/runtime-fallback/constants.ts @@ -27,6 +27,7 @@ export const RETRYABLE_ERROR_PATTERNS = [ /too.?many.?requests/i, /quota.?exceeded/i, /quota\s+will\s+reset\s+after/i, + /(?:you(?:'ve|\s+have)\s+)?reached\s+your\s+usage\s+limit/i, /all\s+credentials\s+for\s+model/i, /cool(?:ing)?\s+down/i, /exhausted\s+your\s+capacity/i, diff --git a/src/hooks/runtime-fallback/error-classifier.test.ts b/src/hooks/runtime-fallback/error-classifier.test.ts index f738ebce9..7a6ca8671 100644 --- a/src/hooks/runtime-fallback/error-classifier.test.ts +++ b/src/hooks/runtime-fallback/error-classifier.test.ts @@ -253,6 +253,17 @@ describe("quota error detection (fixes #2747)", () => { expect(retryable).toBe(true) }) + test("treats hard usage-limit wording as retryable", () => { + //#given + const error = { message: "You've reached your usage limit for this month. Please upgrade to continue." } + + //#when + const retryable = isRetryableError(error, [429, 503]) + + //#then + expect(retryable).toBe(true) + }) + test("classifies QuotaExceededError by errorName even without quota keywords in message", () => { //#given const error = { name: "QuotaExceededError", message: "Request failed." } diff --git a/src/hooks/runtime-fallback/index.test.ts b/src/hooks/runtime-fallback/index.test.ts index 7982b424f..fce27febe 100644 --- a/src/hooks/runtime-fallback/index.test.ts +++ b/src/hooks/runtime-fallback/index.test.ts @@ -64,6 +64,11 @@ describe("runtime-fallback", () => { function createMockPluginConfigWithCategoryFallback(fallbackModels: string[]): OhMyOpenCodeConfig { return { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, categories: { test: { fallback_models: fallbackModels, @@ -79,6 +84,11 @@ describe("runtime-fallback", () => { variant?: string, ): OhMyOpenCodeConfig { return { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, categories: { [categoryName]: { model, @@ -272,6 +282,39 @@ describe("runtime-fallback", () => { expect(errorLog).toBeDefined() }) + test("should trigger fallback when session.error says you've reached your usage limit", async () => { + const hook = createRuntimeFallbackHook(createMockPluginInput(), { + config: createMockConfig({ notify_on_fallback: false }), + pluginConfig: createMockPluginConfigWithCategoryFallback(["zai-coding-plan/glm-5.1"]), + }) + const sessionID = "test-session-usage-limit" + SessionCategoryRegistry.register(sessionID, "test") + + await hook.event({ + event: { + type: "session.created", + properties: { info: { id: sessionID, model: "kimi-for-coding/k2p5" } }, + }, + }) + + await hook.event({ + event: { + type: "session.error", + properties: { + sessionID, + error: { message: "You've reached your usage limit for this month. Please upgrade to continue." }, + }, + }, + }) + + const fallbackLog = logCalls.find((c) => c.msg.includes("Preparing fallback")) + expect(fallbackLog).toBeDefined() + expect(fallbackLog?.data).toMatchObject({ from: "kimi-for-coding/k2p5", to: "zai-coding-plan/glm-5.1" }) + + const skipLog = logCalls.find((c) => c.msg.includes("Error not retryable")) + expect(skipLog).toBeUndefined() + }) + test("should continue fallback chain when fallback model is not found", async () => { const hook = createRuntimeFallbackHook(createMockPluginInput(), { config: createMockConfig({ notify_on_fallback: false }), @@ -767,7 +810,13 @@ describe("runtime-fallback", () => { test("should log when no fallback models configured", async () => { const hook = createRuntimeFallbackHook(createMockPluginInput(), { config: createMockConfig(), - pluginConfig: {}, + pluginConfig: { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, + }, }) const sessionID = "test-session-no-fallbacks" @@ -2299,6 +2348,11 @@ describe("runtime-fallback", () => { describe("fallback models configuration", () => { function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[]): OhMyOpenCodeConfig { return { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, agents: { [agentName]: { fallback_models: fallbackModels, @@ -2496,6 +2550,11 @@ describe("runtime-fallback", () => { { config: createMockConfig({ notify_on_fallback: false }), pluginConfig: { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, categories: { test: { fallback_models: ["provider-a/model-a", "provider-b/model-b"], @@ -2548,6 +2607,11 @@ describe("runtime-fallback", () => { const hook = createRuntimeFallbackHook(createMockPluginInput(), { config: createMockConfig({ notify_on_fallback: false }), pluginConfig: { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, categories: { test: { fallback_models: ["provider-a/model-a", "provider-b/model-b"], @@ -2605,6 +2669,11 @@ describe("runtime-fallback", () => { { config: createMockConfig({ notify_on_fallback: false }), pluginConfig: { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, categories: { test: { fallback_models: ["provider-a/model-a", "provider-b/model-b"], @@ -2647,6 +2716,11 @@ describe("runtime-fallback", () => { const hook = createRuntimeFallbackHook(createMockPluginInput(), { config: createMockConfig({ notify_on_fallback: false }), pluginConfig: { + git_master: { + commit_footer: true, + include_co_authored_by: true, + git_env_prefix: "GIT_MASTER=1", + }, categories: { test: { fallback_models: ["provider-a/model-a", "provider-b/model-b"], diff --git a/src/plugin/event.ts b/src/plugin/event.ts index 52c3f403f..594f82919 100644 --- a/src/plugin/event.ts +++ b/src/plugin/event.ts @@ -17,6 +17,11 @@ import { setPendingModelFallback, } from "../hooks/model-fallback/hook"; import { getRawFallbackModels } from "../hooks/runtime-fallback/fallback-models"; +import { + clearBackgroundOutputConsumptionsForParentSession, + clearBackgroundOutputConsumptionsForTaskSession, + restoreBackgroundOutputConsumption, +} from "../shared/background-output-consumption"; import { resetMessageCursor } from "../shared"; import { getAgentConfigKey } from "../shared/agent-display-names"; import { readConnectedProvidersCache } from "../shared/connected-providers-cache"; @@ -366,6 +371,8 @@ export function createEventHandler(args: { clearPendingModelFallback(sessionInfo.id); clearSessionFallbackChain(sessionInfo.id); resetMessageCursor(sessionInfo.id); + clearBackgroundOutputConsumptionsForParentSession(sessionInfo.id); + clearBackgroundOutputConsumptionsForTaskSession(sessionInfo.id); firstMessageVariantGate.clear(sessionInfo.id); clearSessionModel(sessionInfo.id); clearSessionPromptParams(sessionInfo.id); @@ -382,6 +389,12 @@ export function createEventHandler(args: { } } + if (event.type === "message.removed") { + const messageID = props?.messageID as string | undefined; + const sessionID = props?.sessionID as string | undefined; + restoreBackgroundOutputConsumption(sessionID, messageID); + } + if (event.type === "message.updated") { const info = props?.info as Record | undefined; const sessionID = info?.sessionID as string | undefined; diff --git a/src/shared/background-output-consumption.ts b/src/shared/background-output-consumption.ts new file mode 100644 index 000000000..e672285a9 --- /dev/null +++ b/src/shared/background-output-consumption.ts @@ -0,0 +1,69 @@ +import { getMessageCursor, restoreMessageCursor, type CursorState } from "./session-cursor" + +type MessageConsumptionKey = `${string}:${string}` + +const cursorSnapshotsByMessage = new Map>() + +function getMessageKey(sessionID: string, messageID: string): MessageConsumptionKey { + return `${sessionID}:${messageID}` +} + +export function recordBackgroundOutputConsumption( + parentSessionID: string | undefined, + parentMessageID: string | undefined, + taskSessionID: string | undefined +): void { + if (!parentSessionID || !parentMessageID || !taskSessionID) return + + const messageKey = getMessageKey(parentSessionID, parentMessageID) + const existing = cursorSnapshotsByMessage.get(messageKey) ?? new Map() + + if (!cursorSnapshotsByMessage.has(messageKey)) { + cursorSnapshotsByMessage.set(messageKey, existing) + } + + if (existing.has(taskSessionID)) return + existing.set(taskSessionID, getMessageCursor(taskSessionID)) +} + +export function restoreBackgroundOutputConsumption( + parentSessionID: string | undefined, + parentMessageID: string | undefined +): void { + if (!parentSessionID || !parentMessageID) return + + const messageKey = getMessageKey(parentSessionID, parentMessageID) + const snapshots = cursorSnapshotsByMessage.get(messageKey) + if (!snapshots) return + + cursorSnapshotsByMessage.delete(messageKey) + for (const [taskSessionID, cursor] of snapshots) { + restoreMessageCursor(taskSessionID, cursor) + } +} + +export function clearBackgroundOutputConsumptionsForParentSession(sessionID: string | undefined): void { + if (!sessionID) return + + const prefix = `${sessionID}:` + for (const messageKey of cursorSnapshotsByMessage.keys()) { + if (messageKey.startsWith(prefix)) { + cursorSnapshotsByMessage.delete(messageKey) + } + } +} + +export function clearBackgroundOutputConsumptionsForTaskSession(taskSessionID: string | undefined): void { + if (!taskSessionID) return + + for (const [messageKey, snapshots] of cursorSnapshotsByMessage) { + snapshots.delete(taskSessionID) + if (snapshots.size === 0) { + cursorSnapshotsByMessage.delete(messageKey) + } + } +} + +export function clearBackgroundOutputConsumptionState(): void { + cursorSnapshotsByMessage.clear() +} diff --git a/src/shared/session-cursor.ts b/src/shared/session-cursor.ts index 37ec0bab5..9554c0d86 100644 --- a/src/shared/session-cursor.ts +++ b/src/shared/session-cursor.ts @@ -13,13 +13,21 @@ export type CursorMessage = { info?: MessageInfo } -interface CursorState { +export interface CursorState { lastKey?: string lastCount: number } const sessionCursors = new Map() +function cloneCursorState(state: CursorState | undefined): CursorState | undefined { + if (!state) return undefined + return { + lastKey: state.lastKey, + lastCount: state.lastCount, + } +} + function buildMessageKey(message: CursorMessage, index: number): string { const id = message.info?.id if (id) return `id:${id}` @@ -83,3 +91,18 @@ export function resetMessageCursor(sessionID?: string): void { } sessionCursors.clear() } + +export function getMessageCursor(sessionID: string | undefined): CursorState | undefined { + if (!sessionID) return undefined + return cloneCursorState(sessionCursors.get(sessionID)) +} + +export function restoreMessageCursor(sessionID: string | undefined, cursor: CursorState | undefined): void { + if (!sessionID) return + if (!cursor) { + sessionCursors.delete(sessionID) + return + } + + sessionCursors.set(sessionID, cloneCursorState(cursor)!) +} diff --git a/src/tools/background-task/create-background-output.ts b/src/tools/background-task/create-background-output.ts index f0e31696e..925db344a 100644 --- a/src/tools/background-task/create-background-output.ts +++ b/src/tools/background-task/create-background-output.ts @@ -10,11 +10,13 @@ import { formatTaskResult } from "./task-result-format" import { formatTaskStatus } from "./task-status-format" import { getAgentDisplayName } from "../../shared/agent-display-names" +import { recordBackgroundOutputConsumption } from "../../shared/background-output-consumption" const SISYPHUS_JUNIOR_AGENT = getAgentDisplayName("sisyphus-junior") type ToolContextWithMetadata = { sessionID: string + messageID?: string metadata?: (input: { title?: string; metadata?: Record }) => void callID?: string callId?: string @@ -139,6 +141,7 @@ export function createBackgroundOutput(manager: BackgroundOutputManager, client: } if (resolvedTask.status === "completed") { + recordBackgroundOutputConsumption(ctx.sessionID, ctx.messageID, resolvedTask.sessionID) return await formatTaskResult(resolvedTask, client) } diff --git a/src/tools/background-task/create-background-output.undo.test.ts b/src/tools/background-task/create-background-output.undo.test.ts new file mode 100644 index 000000000..c060cf473 --- /dev/null +++ b/src/tools/background-task/create-background-output.undo.test.ts @@ -0,0 +1,129 @@ +/// + +import { afterEach, describe, expect, test } from "bun:test" +import type { ToolContext } from "@opencode-ai/plugin/tool" +import type { BackgroundTask } from "../../features/background-agent" +import { createEventHandler } from "../../plugin/event" +import { clearBackgroundOutputConsumptionState } from "../../shared/background-output-consumption" +import { resetMessageCursor } from "../../shared/session-cursor" +import type { BackgroundOutputClient, BackgroundOutputManager } from "./clients" +import { createBackgroundOutput } from "./create-background-output" + +const projectDir = "/Users/yeongyu/local-workspaces/oh-my-opencode" + +const parentSessionID = "parent-session" +const taskSessionID = "task-session" + +type ToolContextWithCallID = ToolContext & { + callID: string +} + +const baseContext = { + sessionID: parentSessionID, + agent: "test-agent", + directory: projectDir, + worktree: projectDir, + abort: new AbortController().signal, + metadata: () => {}, + ask: async () => {}, + callID: "call-1", +} as const satisfies Partial + +function createTask(overrides: Partial = {}): BackgroundTask { + return { + id: "task-1", + sessionID: taskSessionID, + parentSessionID, + parentMessageID: "msg-parent", + description: "background task", + prompt: "do work", + agent: "test-agent", + status: "completed", + ...overrides, + } +} + +function createMockClient(): BackgroundOutputClient { + return { + session: { + messages: async () => ({ + data: [ + { + id: "m1", + info: { role: "assistant", time: "2026-01-01T00:00:00Z" }, + parts: [{ type: "text", text: "final result" }], + }, + ], + }), + }, + } +} + +function createMockEventHandler() { + return createEventHandler({ + ctx: {} as never, + pluginConfig: {} as never, + firstMessageVariantGate: { + markSessionCreated: () => {}, + clear: () => {}, + }, + managers: { + skillMcpManager: { + disconnectSession: async () => {}, + }, + tmuxSessionManager: { + onSessionCreated: async () => {}, + onSessionDeleted: async () => {}, + }, + } as never, + hooks: {} as never, + }) +} + +afterEach(() => { + resetMessageCursor(taskSessionID) + clearBackgroundOutputConsumptionState() +}) + +describe("createBackgroundOutput undo regression", () => { + test("#given consumed background output #when undo removes the parent message #then output can be consumed again", async () => { + // #given + const task = createTask() + const manager: BackgroundOutputManager = { + getTask: id => (id === task.id ? task : undefined), + } + const tool = createBackgroundOutput(manager, createMockClient()) + const eventHandler = createMockEventHandler() + + // #when + const firstOutput = await tool.execute( + { task_id: task.id }, + { ...baseContext, messageID: "msg-result-1" } as ToolContextWithCallID + ) + + const secondOutput = await tool.execute( + { task_id: task.id }, + { ...baseContext, callID: "call-2", messageID: "msg-result-2" } as ToolContextWithCallID + ) + + await eventHandler({ + event: { + type: "message.removed", + properties: { + sessionID: parentSessionID, + messageID: "msg-result-1", + }, + }, + }) + + const thirdOutput = await tool.execute( + { task_id: task.id }, + { ...baseContext, callID: "call-3", messageID: "msg-result-3" } as ToolContextWithCallID + ) + + // #then + expect(firstOutput).toContain("final result") + expect(secondOutput).toContain("No new output since last check") + expect(thirdOutput).toContain("final result") + }) +})