fix: resolve 3 community-reported bugs (#2915, #2917, #2918)

- background_output: snapshot read cursor before consuming, restore on
  /undo message removal so re-reads return data (fixes #2915)
- MCP loader: preserve oauth field in transformMcpServer, add scope/
  projectPath filtering so local-scoped MCPs only load in matching
  directories (fixes #2917)
- runtime-fallback: add 'reached your usage limit' to retryable error
  patterns so quota exhaustion triggers model fallback (fixes #2918)

Verified: bun test (4606 pass / 0 fail), tsc --noEmit clean
This commit is contained in:
YeonGyu-Kim
2026-03-29 04:53:36 +09:00
parent 9fc56ab544
commit b2497f1327
16 changed files with 575 additions and 6 deletions

View File

@@ -10,6 +10,7 @@ import type {
} from "./types" } from "./types"
import { transformMcpServer } from "./transformer" import { transformMcpServer } from "./transformer"
import { log } from "../../shared/logger" import { log } from "../../shared/logger"
import { shouldLoadMcpServer } from "./scope-filter"
interface McpConfigPath { interface McpConfigPath {
path: string path: string
@@ -75,6 +76,7 @@ export async function loadMcpConfigs(
const loadedServers: LoadedMcpServer[] = [] const loadedServers: LoadedMcpServer[] = []
const paths = getMcpConfigPaths() const paths = getMcpConfigPaths()
const disabledSet = new Set(disabledMcps) const disabledSet = new Set(disabledMcps)
const cwd = process.cwd()
for (const { path, scope } of paths) { for (const { path, scope } of paths) {
const config = await loadMcpConfigFile(path) const config = await loadMcpConfigFile(path)
@@ -86,6 +88,15 @@ export async function loadMcpConfigs(
continue continue
} }
if (!shouldLoadMcpServer(serverConfig, cwd)) {
log(`Skipping MCP server "${name}" because local scope does not match cwd`, {
path,
projectPath: serverConfig.projectPath,
cwd,
})
continue
}
if (serverConfig.disabled) { if (serverConfig.disabled) {
log(`Disabling MCP server "${name}"`, { path }) log(`Disabling MCP server "${name}"`, { path })
delete servers[name] delete servers[name]

View File

@@ -0,0 +1,28 @@
import { existsSync, realpathSync } from "fs"
import { resolve } from "path"
import type { ClaudeCodeMcpServer } from "./types"
function normalizePath(path: string): string {
const resolvedPath = resolve(path)
if (!existsSync(resolvedPath)) {
return resolvedPath
}
return realpathSync(resolvedPath)
}
export function shouldLoadMcpServer(
server: Pick<ClaudeCodeMcpServer, "scope" | "projectPath">,
cwd = process.cwd()
): boolean {
if (server.scope !== "local") {
return true
}
if (!server.projectPath) {
return false
}
return normalizePath(server.projectPath) === normalizePath(cwd)
}

View File

@@ -0,0 +1,82 @@
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
import { mkdirSync, rmSync, writeFileSync } from "fs"
import { tmpdir } from "os"
import { join } from "path"
const TEST_DIR = join(tmpdir(), `mcp-scope-filtering-test-${Date.now()}`)
const TEST_HOME = join(TEST_DIR, "home")
describe("loadMcpConfigs", () => {
beforeEach(() => {
mkdirSync(TEST_DIR, { recursive: true })
mkdirSync(TEST_HOME, { recursive: true })
mock.module("os", () => ({
homedir: () => TEST_HOME,
tmpdir,
}))
mock.module("../../shared", () => ({
getClaudeConfigDir: () => join(TEST_HOME, ".claude"),
}))
mock.module("../../shared/logger", () => ({
log: () => {},
}))
})
afterEach(() => {
mock.restore()
rmSync(TEST_DIR, { recursive: true, force: true })
})
describe("#given user-scoped MCP entries with local scope metadata", () => {
it("#when loading configs #then only servers matching the current project path are loaded", async () => {
writeFileSync(
join(TEST_HOME, ".claude.json"),
JSON.stringify({
mcpServers: {
globalServer: {
command: "npx",
args: ["global-server"],
},
matchingLocal: {
command: "npx",
args: ["matching-local"],
scope: "local",
projectPath: TEST_DIR,
},
nonMatchingLocal: {
command: "npx",
args: ["non-matching-local"],
scope: "local",
projectPath: join(TEST_DIR, "other-project"),
},
missingProjectPath: {
command: "npx",
args: ["missing-project-path"],
scope: "local",
},
},
})
)
const originalCwd = process.cwd()
process.chdir(TEST_DIR)
try {
const { loadMcpConfigs } = await import("./loader")
const result = await loadMcpConfigs()
expect(result.servers).toHaveProperty("globalServer")
expect(result.servers).toHaveProperty("matchingLocal")
expect(result.servers).not.toHaveProperty("nonMatchingLocal")
expect(result.servers).not.toHaveProperty("missingProjectPath")
expect(result.loadedServers.map((server) => server.name)).toEqual([
"globalServer",
"matchingLocal",
])
} finally {
process.chdir(originalCwd)
}
})
})
})

View File

@@ -0,0 +1,29 @@
import { describe, expect, it } from "bun:test"
import { transformMcpServer } from "./transformer"
describe("transformMcpServer", () => {
describe("#given a remote MCP server with oauth config", () => {
it("#when transforming the server #then preserves oauth on the remote config", () => {
const transformed = transformMcpServer("remote-oauth", {
type: "http",
url: "https://mcp.example.com",
headers: { Authorization: "Bearer test" },
oauth: {
clientId: "client-id",
scopes: ["read", "write"],
},
})
expect(transformed).toEqual({
type: "remote",
url: "https://mcp.example.com",
headers: { Authorization: "Bearer test" },
oauth: {
clientId: "client-id",
scopes: ["read", "write"],
},
enabled: true,
})
})
})
})

View File

@@ -30,6 +30,10 @@ export function transformMcpServer(
config.headers = expanded.headers config.headers = expanded.headers
} }
if (expanded.oauth && Object.keys(expanded.oauth).length > 0) {
config.oauth = expanded.oauth
}
return config return config
} }

View File

@@ -1,5 +1,10 @@
export type McpScope = "user" | "project" | "local" export type McpScope = "user" | "project" | "local"
export interface McpOAuthConfig {
clientId?: string
scopes?: string[]
}
export interface ClaudeCodeMcpServer { export interface ClaudeCodeMcpServer {
type?: "http" | "sse" | "stdio" type?: "http" | "sse" | "stdio"
url?: string url?: string
@@ -7,10 +12,9 @@ export interface ClaudeCodeMcpServer {
args?: string[] args?: string[]
env?: Record<string, string> env?: Record<string, string>
headers?: Record<string, string> headers?: Record<string, string>
oauth?: { oauth?: McpOAuthConfig
clientId?: string scope?: McpScope
scopes?: string[] projectPath?: string
}
disabled?: boolean disabled?: boolean
} }
@@ -29,6 +33,7 @@ export interface McpRemoteConfig {
type: "remote" type: "remote"
url: string url: string
headers?: Record<string, string> headers?: Record<string, string>
oauth?: McpOAuthConfig
enabled?: boolean enabled?: boolean
} }

View File

@@ -0,0 +1,76 @@
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
import { mkdirSync, rmSync, writeFileSync } from "fs"
import { tmpdir } from "os"
import { join } from "path"
import type { LoadedPlugin } from "./types"
const TEST_DIR = join(tmpdir(), `plugin-mcp-loader-test-${Date.now()}`)
const PROJECT_DIR = join(TEST_DIR, "project")
const PLUGIN_DIR = join(TEST_DIR, "plugin")
const MCP_CONFIG_PATH = join(PLUGIN_DIR, "mcp.json")
describe("loadPluginMcpServers", () => {
beforeEach(() => {
mkdirSync(PROJECT_DIR, { recursive: true })
mkdirSync(PLUGIN_DIR, { recursive: true })
mock.module("../../shared/logger", () => ({
log: () => {},
}))
})
afterEach(() => {
mock.restore()
rmSync(TEST_DIR, { recursive: true, force: true })
})
describe("#given plugin MCP entries with local scope metadata", () => {
it("#when loading plugin MCP servers #then only entries matching the current cwd are included", async () => {
writeFileSync(
MCP_CONFIG_PATH,
JSON.stringify({
mcpServers: {
globalServer: {
command: "npx",
args: ["global-plugin-server"],
},
matchingLocal: {
command: "npx",
args: ["matching-plugin-local"],
scope: "local",
projectPath: PROJECT_DIR,
},
nonMatchingLocal: {
command: "npx",
args: ["non-matching-plugin-local"],
scope: "local",
projectPath: join(PROJECT_DIR, "other-project"),
},
},
})
)
const plugin: LoadedPlugin = {
name: "demo-plugin",
version: "1.0.0",
scope: "project",
installPath: PLUGIN_DIR,
pluginKey: "demo-plugin@test",
mcpPath: MCP_CONFIG_PATH,
}
const originalCwd = process.cwd()
process.chdir(PROJECT_DIR)
try {
const { loadPluginMcpServers } = await import("./mcp-server-loader")
const servers = await loadPluginMcpServers([plugin])
expect(servers).toHaveProperty("demo-plugin:globalServer")
expect(servers).toHaveProperty("demo-plugin:matchingLocal")
expect(servers).not.toHaveProperty("demo-plugin:nonMatchingLocal")
} finally {
process.chdir(originalCwd)
}
})
})
})

View File

@@ -1,6 +1,7 @@
import { existsSync } from "fs" import { existsSync } from "fs"
import type { McpServerConfig } from "../claude-code-mcp-loader/types" import type { McpServerConfig } from "../claude-code-mcp-loader/types"
import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander" import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander"
import { shouldLoadMcpServer } from "../claude-code-mcp-loader/scope-filter"
import { transformMcpServer } from "../claude-code-mcp-loader/transformer" import { transformMcpServer } from "../claude-code-mcp-loader/transformer"
import type { ClaudeCodeMcpConfig } from "../claude-code-mcp-loader/types" import type { ClaudeCodeMcpConfig } from "../claude-code-mcp-loader/types"
import { log } from "../../shared/logger" import { log } from "../../shared/logger"
@@ -11,6 +12,7 @@ export async function loadPluginMcpServers(
plugins: LoadedPlugin[], plugins: LoadedPlugin[],
): Promise<Record<string, McpServerConfig>> { ): Promise<Record<string, McpServerConfig>> {
const servers: Record<string, McpServerConfig> = {} const servers: Record<string, McpServerConfig> = {}
const cwd = process.cwd()
for (const plugin of plugins) { for (const plugin of plugins) {
if (!plugin.mcpPath || !existsSync(plugin.mcpPath)) continue if (!plugin.mcpPath || !existsSync(plugin.mcpPath)) continue
@@ -25,6 +27,15 @@ export async function loadPluginMcpServers(
if (!config.mcpServers) continue if (!config.mcpServers) continue
for (const [name, serverConfig] of Object.entries(config.mcpServers)) { for (const [name, serverConfig] of Object.entries(config.mcpServers)) {
if (!shouldLoadMcpServer(serverConfig, cwd)) {
log(`Skipping local plugin MCP server "${name}" outside current cwd`, {
path: plugin.mcpPath,
projectPath: serverConfig.projectPath,
cwd,
})
continue
}
if (serverConfig.disabled) { if (serverConfig.disabled) {
log(`Skipping disabled MCP server "${name}" from plugin ${plugin.name}`) log(`Skipping disabled MCP server "${name}" from plugin ${plugin.name}`)
continue continue

View File

@@ -27,6 +27,7 @@ export const RETRYABLE_ERROR_PATTERNS = [
/too.?many.?requests/i, /too.?many.?requests/i,
/quota.?exceeded/i, /quota.?exceeded/i,
/quota\s+will\s+reset\s+after/i, /quota\s+will\s+reset\s+after/i,
/(?:you(?:'ve|\s+have)\s+)?reached\s+your\s+usage\s+limit/i,
/all\s+credentials\s+for\s+model/i, /all\s+credentials\s+for\s+model/i,
/cool(?:ing)?\s+down/i, /cool(?:ing)?\s+down/i,
/exhausted\s+your\s+capacity/i, /exhausted\s+your\s+capacity/i,

View File

@@ -253,6 +253,17 @@ describe("quota error detection (fixes #2747)", () => {
expect(retryable).toBe(true) expect(retryable).toBe(true)
}) })
test("treats hard usage-limit wording as retryable", () => {
//#given
const error = { message: "You've reached your usage limit for this month. Please upgrade to continue." }
//#when
const retryable = isRetryableError(error, [429, 503])
//#then
expect(retryable).toBe(true)
})
test("classifies QuotaExceededError by errorName even without quota keywords in message", () => { test("classifies QuotaExceededError by errorName even without quota keywords in message", () => {
//#given //#given
const error = { name: "QuotaExceededError", message: "Request failed." } const error = { name: "QuotaExceededError", message: "Request failed." }

View File

@@ -64,6 +64,11 @@ describe("runtime-fallback", () => {
function createMockPluginConfigWithCategoryFallback(fallbackModels: string[]): OhMyOpenCodeConfig { function createMockPluginConfigWithCategoryFallback(fallbackModels: string[]): OhMyOpenCodeConfig {
return { return {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
categories: { categories: {
test: { test: {
fallback_models: fallbackModels, fallback_models: fallbackModels,
@@ -79,6 +84,11 @@ describe("runtime-fallback", () => {
variant?: string, variant?: string,
): OhMyOpenCodeConfig { ): OhMyOpenCodeConfig {
return { return {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
categories: { categories: {
[categoryName]: { [categoryName]: {
model, model,
@@ -272,6 +282,39 @@ describe("runtime-fallback", () => {
expect(errorLog).toBeDefined() expect(errorLog).toBeDefined()
}) })
test("should trigger fallback when session.error says you've reached your usage limit", async () => {
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
config: createMockConfig({ notify_on_fallback: false }),
pluginConfig: createMockPluginConfigWithCategoryFallback(["zai-coding-plan/glm-5.1"]),
})
const sessionID = "test-session-usage-limit"
SessionCategoryRegistry.register(sessionID, "test")
await hook.event({
event: {
type: "session.created",
properties: { info: { id: sessionID, model: "kimi-for-coding/k2p5" } },
},
})
await hook.event({
event: {
type: "session.error",
properties: {
sessionID,
error: { message: "You've reached your usage limit for this month. Please upgrade to continue." },
},
},
})
const fallbackLog = logCalls.find((c) => c.msg.includes("Preparing fallback"))
expect(fallbackLog).toBeDefined()
expect(fallbackLog?.data).toMatchObject({ from: "kimi-for-coding/k2p5", to: "zai-coding-plan/glm-5.1" })
const skipLog = logCalls.find((c) => c.msg.includes("Error not retryable"))
expect(skipLog).toBeUndefined()
})
test("should continue fallback chain when fallback model is not found", async () => { test("should continue fallback chain when fallback model is not found", async () => {
const hook = createRuntimeFallbackHook(createMockPluginInput(), { const hook = createRuntimeFallbackHook(createMockPluginInput(), {
config: createMockConfig({ notify_on_fallback: false }), config: createMockConfig({ notify_on_fallback: false }),
@@ -767,7 +810,13 @@ describe("runtime-fallback", () => {
test("should log when no fallback models configured", async () => { test("should log when no fallback models configured", async () => {
const hook = createRuntimeFallbackHook(createMockPluginInput(), { const hook = createRuntimeFallbackHook(createMockPluginInput(), {
config: createMockConfig(), config: createMockConfig(),
pluginConfig: {}, pluginConfig: {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
},
}) })
const sessionID = "test-session-no-fallbacks" const sessionID = "test-session-no-fallbacks"
@@ -2299,6 +2348,11 @@ describe("runtime-fallback", () => {
describe("fallback models configuration", () => { describe("fallback models configuration", () => {
function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[]): OhMyOpenCodeConfig { function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[]): OhMyOpenCodeConfig {
return { return {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
agents: { agents: {
[agentName]: { [agentName]: {
fallback_models: fallbackModels, fallback_models: fallbackModels,
@@ -2496,6 +2550,11 @@ describe("runtime-fallback", () => {
{ {
config: createMockConfig({ notify_on_fallback: false }), config: createMockConfig({ notify_on_fallback: false }),
pluginConfig: { pluginConfig: {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
categories: { categories: {
test: { test: {
fallback_models: ["provider-a/model-a", "provider-b/model-b"], fallback_models: ["provider-a/model-a", "provider-b/model-b"],
@@ -2548,6 +2607,11 @@ describe("runtime-fallback", () => {
const hook = createRuntimeFallbackHook(createMockPluginInput(), { const hook = createRuntimeFallbackHook(createMockPluginInput(), {
config: createMockConfig({ notify_on_fallback: false }), config: createMockConfig({ notify_on_fallback: false }),
pluginConfig: { pluginConfig: {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
categories: { categories: {
test: { test: {
fallback_models: ["provider-a/model-a", "provider-b/model-b"], fallback_models: ["provider-a/model-a", "provider-b/model-b"],
@@ -2605,6 +2669,11 @@ describe("runtime-fallback", () => {
{ {
config: createMockConfig({ notify_on_fallback: false }), config: createMockConfig({ notify_on_fallback: false }),
pluginConfig: { pluginConfig: {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
categories: { categories: {
test: { test: {
fallback_models: ["provider-a/model-a", "provider-b/model-b"], fallback_models: ["provider-a/model-a", "provider-b/model-b"],
@@ -2647,6 +2716,11 @@ describe("runtime-fallback", () => {
const hook = createRuntimeFallbackHook(createMockPluginInput(), { const hook = createRuntimeFallbackHook(createMockPluginInput(), {
config: createMockConfig({ notify_on_fallback: false }), config: createMockConfig({ notify_on_fallback: false }),
pluginConfig: { pluginConfig: {
git_master: {
commit_footer: true,
include_co_authored_by: true,
git_env_prefix: "GIT_MASTER=1",
},
categories: { categories: {
test: { test: {
fallback_models: ["provider-a/model-a", "provider-b/model-b"], fallback_models: ["provider-a/model-a", "provider-b/model-b"],

View File

@@ -17,6 +17,11 @@ import {
setPendingModelFallback, setPendingModelFallback,
} from "../hooks/model-fallback/hook"; } from "../hooks/model-fallback/hook";
import { getRawFallbackModels } from "../hooks/runtime-fallback/fallback-models"; import { getRawFallbackModels } from "../hooks/runtime-fallback/fallback-models";
import {
clearBackgroundOutputConsumptionsForParentSession,
clearBackgroundOutputConsumptionsForTaskSession,
restoreBackgroundOutputConsumption,
} from "../shared/background-output-consumption";
import { resetMessageCursor } from "../shared"; import { resetMessageCursor } from "../shared";
import { getAgentConfigKey } from "../shared/agent-display-names"; import { getAgentConfigKey } from "../shared/agent-display-names";
import { readConnectedProvidersCache } from "../shared/connected-providers-cache"; import { readConnectedProvidersCache } from "../shared/connected-providers-cache";
@@ -366,6 +371,8 @@ export function createEventHandler(args: {
clearPendingModelFallback(sessionInfo.id); clearPendingModelFallback(sessionInfo.id);
clearSessionFallbackChain(sessionInfo.id); clearSessionFallbackChain(sessionInfo.id);
resetMessageCursor(sessionInfo.id); resetMessageCursor(sessionInfo.id);
clearBackgroundOutputConsumptionsForParentSession(sessionInfo.id);
clearBackgroundOutputConsumptionsForTaskSession(sessionInfo.id);
firstMessageVariantGate.clear(sessionInfo.id); firstMessageVariantGate.clear(sessionInfo.id);
clearSessionModel(sessionInfo.id); clearSessionModel(sessionInfo.id);
clearSessionPromptParams(sessionInfo.id); clearSessionPromptParams(sessionInfo.id);
@@ -382,6 +389,12 @@ export function createEventHandler(args: {
} }
} }
if (event.type === "message.removed") {
const messageID = props?.messageID as string | undefined;
const sessionID = props?.sessionID as string | undefined;
restoreBackgroundOutputConsumption(sessionID, messageID);
}
if (event.type === "message.updated") { if (event.type === "message.updated") {
const info = props?.info as Record<string, unknown> | undefined; const info = props?.info as Record<string, unknown> | undefined;
const sessionID = info?.sessionID as string | undefined; const sessionID = info?.sessionID as string | undefined;

View File

@@ -0,0 +1,69 @@
import { getMessageCursor, restoreMessageCursor, type CursorState } from "./session-cursor"
type MessageConsumptionKey = `${string}:${string}`
const cursorSnapshotsByMessage = new Map<MessageConsumptionKey, Map<string, CursorState | undefined>>()
function getMessageKey(sessionID: string, messageID: string): MessageConsumptionKey {
return `${sessionID}:${messageID}`
}
export function recordBackgroundOutputConsumption(
parentSessionID: string | undefined,
parentMessageID: string | undefined,
taskSessionID: string | undefined
): void {
if (!parentSessionID || !parentMessageID || !taskSessionID) return
const messageKey = getMessageKey(parentSessionID, parentMessageID)
const existing = cursorSnapshotsByMessage.get(messageKey) ?? new Map<string, CursorState | undefined>()
if (!cursorSnapshotsByMessage.has(messageKey)) {
cursorSnapshotsByMessage.set(messageKey, existing)
}
if (existing.has(taskSessionID)) return
existing.set(taskSessionID, getMessageCursor(taskSessionID))
}
export function restoreBackgroundOutputConsumption(
parentSessionID: string | undefined,
parentMessageID: string | undefined
): void {
if (!parentSessionID || !parentMessageID) return
const messageKey = getMessageKey(parentSessionID, parentMessageID)
const snapshots = cursorSnapshotsByMessage.get(messageKey)
if (!snapshots) return
cursorSnapshotsByMessage.delete(messageKey)
for (const [taskSessionID, cursor] of snapshots) {
restoreMessageCursor(taskSessionID, cursor)
}
}
export function clearBackgroundOutputConsumptionsForParentSession(sessionID: string | undefined): void {
if (!sessionID) return
const prefix = `${sessionID}:`
for (const messageKey of cursorSnapshotsByMessage.keys()) {
if (messageKey.startsWith(prefix)) {
cursorSnapshotsByMessage.delete(messageKey)
}
}
}
export function clearBackgroundOutputConsumptionsForTaskSession(taskSessionID: string | undefined): void {
if (!taskSessionID) return
for (const [messageKey, snapshots] of cursorSnapshotsByMessage) {
snapshots.delete(taskSessionID)
if (snapshots.size === 0) {
cursorSnapshotsByMessage.delete(messageKey)
}
}
}
export function clearBackgroundOutputConsumptionState(): void {
cursorSnapshotsByMessage.clear()
}

View File

@@ -13,13 +13,21 @@ export type CursorMessage = {
info?: MessageInfo info?: MessageInfo
} }
interface CursorState { export interface CursorState {
lastKey?: string lastKey?: string
lastCount: number lastCount: number
} }
const sessionCursors = new Map<string, CursorState>() const sessionCursors = new Map<string, CursorState>()
function cloneCursorState(state: CursorState | undefined): CursorState | undefined {
if (!state) return undefined
return {
lastKey: state.lastKey,
lastCount: state.lastCount,
}
}
function buildMessageKey(message: CursorMessage, index: number): string { function buildMessageKey(message: CursorMessage, index: number): string {
const id = message.info?.id const id = message.info?.id
if (id) return `id:${id}` if (id) return `id:${id}`
@@ -83,3 +91,18 @@ export function resetMessageCursor(sessionID?: string): void {
} }
sessionCursors.clear() sessionCursors.clear()
} }
export function getMessageCursor(sessionID: string | undefined): CursorState | undefined {
if (!sessionID) return undefined
return cloneCursorState(sessionCursors.get(sessionID))
}
export function restoreMessageCursor(sessionID: string | undefined, cursor: CursorState | undefined): void {
if (!sessionID) return
if (!cursor) {
sessionCursors.delete(sessionID)
return
}
sessionCursors.set(sessionID, cloneCursorState(cursor)!)
}

View File

@@ -10,11 +10,13 @@ import { formatTaskResult } from "./task-result-format"
import { formatTaskStatus } from "./task-status-format" import { formatTaskStatus } from "./task-status-format"
import { getAgentDisplayName } from "../../shared/agent-display-names" import { getAgentDisplayName } from "../../shared/agent-display-names"
import { recordBackgroundOutputConsumption } from "../../shared/background-output-consumption"
const SISYPHUS_JUNIOR_AGENT = getAgentDisplayName("sisyphus-junior") const SISYPHUS_JUNIOR_AGENT = getAgentDisplayName("sisyphus-junior")
type ToolContextWithMetadata = { type ToolContextWithMetadata = {
sessionID: string sessionID: string
messageID?: string
metadata?: (input: { title?: string; metadata?: Record<string, unknown> }) => void metadata?: (input: { title?: string; metadata?: Record<string, unknown> }) => void
callID?: string callID?: string
callId?: string callId?: string
@@ -139,6 +141,7 @@ export function createBackgroundOutput(manager: BackgroundOutputManager, client:
} }
if (resolvedTask.status === "completed") { if (resolvedTask.status === "completed") {
recordBackgroundOutputConsumption(ctx.sessionID, ctx.messageID, resolvedTask.sessionID)
return await formatTaskResult(resolvedTask, client) return await formatTaskResult(resolvedTask, client)
} }

View File

@@ -0,0 +1,129 @@
/// <reference types="bun-types" />
import { afterEach, describe, expect, test } from "bun:test"
import type { ToolContext } from "@opencode-ai/plugin/tool"
import type { BackgroundTask } from "../../features/background-agent"
import { createEventHandler } from "../../plugin/event"
import { clearBackgroundOutputConsumptionState } from "../../shared/background-output-consumption"
import { resetMessageCursor } from "../../shared/session-cursor"
import type { BackgroundOutputClient, BackgroundOutputManager } from "./clients"
import { createBackgroundOutput } from "./create-background-output"
const projectDir = "/Users/yeongyu/local-workspaces/oh-my-opencode"
const parentSessionID = "parent-session"
const taskSessionID = "task-session"
type ToolContextWithCallID = ToolContext & {
callID: string
}
const baseContext = {
sessionID: parentSessionID,
agent: "test-agent",
directory: projectDir,
worktree: projectDir,
abort: new AbortController().signal,
metadata: () => {},
ask: async () => {},
callID: "call-1",
} as const satisfies Partial<ToolContextWithCallID>
function createTask(overrides: Partial<BackgroundTask> = {}): BackgroundTask {
return {
id: "task-1",
sessionID: taskSessionID,
parentSessionID,
parentMessageID: "msg-parent",
description: "background task",
prompt: "do work",
agent: "test-agent",
status: "completed",
...overrides,
}
}
function createMockClient(): BackgroundOutputClient {
return {
session: {
messages: async () => ({
data: [
{
id: "m1",
info: { role: "assistant", time: "2026-01-01T00:00:00Z" },
parts: [{ type: "text", text: "final result" }],
},
],
}),
},
}
}
function createMockEventHandler() {
return createEventHandler({
ctx: {} as never,
pluginConfig: {} as never,
firstMessageVariantGate: {
markSessionCreated: () => {},
clear: () => {},
},
managers: {
skillMcpManager: {
disconnectSession: async () => {},
},
tmuxSessionManager: {
onSessionCreated: async () => {},
onSessionDeleted: async () => {},
},
} as never,
hooks: {} as never,
})
}
afterEach(() => {
resetMessageCursor(taskSessionID)
clearBackgroundOutputConsumptionState()
})
describe("createBackgroundOutput undo regression", () => {
test("#given consumed background output #when undo removes the parent message #then output can be consumed again", async () => {
// #given
const task = createTask()
const manager: BackgroundOutputManager = {
getTask: id => (id === task.id ? task : undefined),
}
const tool = createBackgroundOutput(manager, createMockClient())
const eventHandler = createMockEventHandler()
// #when
const firstOutput = await tool.execute(
{ task_id: task.id },
{ ...baseContext, messageID: "msg-result-1" } as ToolContextWithCallID
)
const secondOutput = await tool.execute(
{ task_id: task.id },
{ ...baseContext, callID: "call-2", messageID: "msg-result-2" } as ToolContextWithCallID
)
await eventHandler({
event: {
type: "message.removed",
properties: {
sessionID: parentSessionID,
messageID: "msg-result-1",
},
},
})
const thirdOutput = await tool.execute(
{ task_id: task.id },
{ ...baseContext, callID: "call-3", messageID: "msg-result-3" } as ToolContextWithCallID
)
// #then
expect(firstOutput).toContain("final result")
expect(secondOutput).toContain("No new output since last check")
expect(thirdOutput).toContain("final result")
})
})