- background_output: snapshot read cursor before consuming, restore on /undo message removal so re-reads return data (fixes #2915) - MCP loader: preserve oauth field in transformMcpServer, add scope/ projectPath filtering so local-scoped MCPs only load in matching directories (fixes #2917) - runtime-fallback: add 'reached your usage limit' to retryable error patterns so quota exhaustion triggers model fallback (fixes #2918) Verified: bun test (4606 pass / 0 fail), tsc --noEmit clean
This commit is contained in:
@@ -10,6 +10,7 @@ import type {
|
||||
} from "./types"
|
||||
import { transformMcpServer } from "./transformer"
|
||||
import { log } from "../../shared/logger"
|
||||
import { shouldLoadMcpServer } from "./scope-filter"
|
||||
|
||||
interface McpConfigPath {
|
||||
path: string
|
||||
@@ -75,6 +76,7 @@ export async function loadMcpConfigs(
|
||||
const loadedServers: LoadedMcpServer[] = []
|
||||
const paths = getMcpConfigPaths()
|
||||
const disabledSet = new Set(disabledMcps)
|
||||
const cwd = process.cwd()
|
||||
|
||||
for (const { path, scope } of paths) {
|
||||
const config = await loadMcpConfigFile(path)
|
||||
@@ -86,6 +88,15 @@ export async function loadMcpConfigs(
|
||||
continue
|
||||
}
|
||||
|
||||
if (!shouldLoadMcpServer(serverConfig, cwd)) {
|
||||
log(`Skipping MCP server "${name}" because local scope does not match cwd`, {
|
||||
path,
|
||||
projectPath: serverConfig.projectPath,
|
||||
cwd,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if (serverConfig.disabled) {
|
||||
log(`Disabling MCP server "${name}"`, { path })
|
||||
delete servers[name]
|
||||
|
||||
28
src/features/claude-code-mcp-loader/scope-filter.ts
Normal file
28
src/features/claude-code-mcp-loader/scope-filter.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import { existsSync, realpathSync } from "fs"
|
||||
import { resolve } from "path"
|
||||
import type { ClaudeCodeMcpServer } from "./types"
|
||||
|
||||
function normalizePath(path: string): string {
|
||||
const resolvedPath = resolve(path)
|
||||
|
||||
if (!existsSync(resolvedPath)) {
|
||||
return resolvedPath
|
||||
}
|
||||
|
||||
return realpathSync(resolvedPath)
|
||||
}
|
||||
|
||||
export function shouldLoadMcpServer(
|
||||
server: Pick<ClaudeCodeMcpServer, "scope" | "projectPath">,
|
||||
cwd = process.cwd()
|
||||
): boolean {
|
||||
if (server.scope !== "local") {
|
||||
return true
|
||||
}
|
||||
|
||||
if (!server.projectPath) {
|
||||
return false
|
||||
}
|
||||
|
||||
return normalizePath(server.projectPath) === normalizePath(cwd)
|
||||
}
|
||||
82
src/features/claude-code-mcp-loader/scope-filtering.test.ts
Normal file
82
src/features/claude-code-mcp-loader/scope-filtering.test.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
|
||||
import { mkdirSync, rmSync, writeFileSync } from "fs"
|
||||
import { tmpdir } from "os"
|
||||
import { join } from "path"
|
||||
|
||||
const TEST_DIR = join(tmpdir(), `mcp-scope-filtering-test-${Date.now()}`)
|
||||
const TEST_HOME = join(TEST_DIR, "home")
|
||||
|
||||
describe("loadMcpConfigs", () => {
|
||||
beforeEach(() => {
|
||||
mkdirSync(TEST_DIR, { recursive: true })
|
||||
mkdirSync(TEST_HOME, { recursive: true })
|
||||
mock.module("os", () => ({
|
||||
homedir: () => TEST_HOME,
|
||||
tmpdir,
|
||||
}))
|
||||
mock.module("../../shared", () => ({
|
||||
getClaudeConfigDir: () => join(TEST_HOME, ".claude"),
|
||||
}))
|
||||
mock.module("../../shared/logger", () => ({
|
||||
log: () => {},
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
mock.restore()
|
||||
rmSync(TEST_DIR, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
describe("#given user-scoped MCP entries with local scope metadata", () => {
|
||||
it("#when loading configs #then only servers matching the current project path are loaded", async () => {
|
||||
writeFileSync(
|
||||
join(TEST_HOME, ".claude.json"),
|
||||
JSON.stringify({
|
||||
mcpServers: {
|
||||
globalServer: {
|
||||
command: "npx",
|
||||
args: ["global-server"],
|
||||
},
|
||||
matchingLocal: {
|
||||
command: "npx",
|
||||
args: ["matching-local"],
|
||||
scope: "local",
|
||||
projectPath: TEST_DIR,
|
||||
},
|
||||
nonMatchingLocal: {
|
||||
command: "npx",
|
||||
args: ["non-matching-local"],
|
||||
scope: "local",
|
||||
projectPath: join(TEST_DIR, "other-project"),
|
||||
},
|
||||
missingProjectPath: {
|
||||
command: "npx",
|
||||
args: ["missing-project-path"],
|
||||
scope: "local",
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
const originalCwd = process.cwd()
|
||||
process.chdir(TEST_DIR)
|
||||
|
||||
try {
|
||||
const { loadMcpConfigs } = await import("./loader")
|
||||
const result = await loadMcpConfigs()
|
||||
|
||||
expect(result.servers).toHaveProperty("globalServer")
|
||||
expect(result.servers).toHaveProperty("matchingLocal")
|
||||
expect(result.servers).not.toHaveProperty("nonMatchingLocal")
|
||||
expect(result.servers).not.toHaveProperty("missingProjectPath")
|
||||
|
||||
expect(result.loadedServers.map((server) => server.name)).toEqual([
|
||||
"globalServer",
|
||||
"matchingLocal",
|
||||
])
|
||||
} finally {
|
||||
process.chdir(originalCwd)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
29
src/features/claude-code-mcp-loader/transformer.test.ts
Normal file
29
src/features/claude-code-mcp-loader/transformer.test.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import { transformMcpServer } from "./transformer"
|
||||
|
||||
describe("transformMcpServer", () => {
|
||||
describe("#given a remote MCP server with oauth config", () => {
|
||||
it("#when transforming the server #then preserves oauth on the remote config", () => {
|
||||
const transformed = transformMcpServer("remote-oauth", {
|
||||
type: "http",
|
||||
url: "https://mcp.example.com",
|
||||
headers: { Authorization: "Bearer test" },
|
||||
oauth: {
|
||||
clientId: "client-id",
|
||||
scopes: ["read", "write"],
|
||||
},
|
||||
})
|
||||
|
||||
expect(transformed).toEqual({
|
||||
type: "remote",
|
||||
url: "https://mcp.example.com",
|
||||
headers: { Authorization: "Bearer test" },
|
||||
oauth: {
|
||||
clientId: "client-id",
|
||||
scopes: ["read", "write"],
|
||||
},
|
||||
enabled: true,
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -30,6 +30,10 @@ export function transformMcpServer(
|
||||
config.headers = expanded.headers
|
||||
}
|
||||
|
||||
if (expanded.oauth && Object.keys(expanded.oauth).length > 0) {
|
||||
config.oauth = expanded.oauth
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
export type McpScope = "user" | "project" | "local"
|
||||
|
||||
export interface McpOAuthConfig {
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
|
||||
export interface ClaudeCodeMcpServer {
|
||||
type?: "http" | "sse" | "stdio"
|
||||
url?: string
|
||||
@@ -7,10 +12,9 @@ export interface ClaudeCodeMcpServer {
|
||||
args?: string[]
|
||||
env?: Record<string, string>
|
||||
headers?: Record<string, string>
|
||||
oauth?: {
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
oauth?: McpOAuthConfig
|
||||
scope?: McpScope
|
||||
projectPath?: string
|
||||
disabled?: boolean
|
||||
}
|
||||
|
||||
@@ -29,6 +33,7 @@ export interface McpRemoteConfig {
|
||||
type: "remote"
|
||||
url: string
|
||||
headers?: Record<string, string>
|
||||
oauth?: McpOAuthConfig
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
|
||||
import { mkdirSync, rmSync, writeFileSync } from "fs"
|
||||
import { tmpdir } from "os"
|
||||
import { join } from "path"
|
||||
import type { LoadedPlugin } from "./types"
|
||||
|
||||
const TEST_DIR = join(tmpdir(), `plugin-mcp-loader-test-${Date.now()}`)
|
||||
const PROJECT_DIR = join(TEST_DIR, "project")
|
||||
const PLUGIN_DIR = join(TEST_DIR, "plugin")
|
||||
const MCP_CONFIG_PATH = join(PLUGIN_DIR, "mcp.json")
|
||||
|
||||
describe("loadPluginMcpServers", () => {
|
||||
beforeEach(() => {
|
||||
mkdirSync(PROJECT_DIR, { recursive: true })
|
||||
mkdirSync(PLUGIN_DIR, { recursive: true })
|
||||
mock.module("../../shared/logger", () => ({
|
||||
log: () => {},
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
mock.restore()
|
||||
rmSync(TEST_DIR, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
describe("#given plugin MCP entries with local scope metadata", () => {
|
||||
it("#when loading plugin MCP servers #then only entries matching the current cwd are included", async () => {
|
||||
writeFileSync(
|
||||
MCP_CONFIG_PATH,
|
||||
JSON.stringify({
|
||||
mcpServers: {
|
||||
globalServer: {
|
||||
command: "npx",
|
||||
args: ["global-plugin-server"],
|
||||
},
|
||||
matchingLocal: {
|
||||
command: "npx",
|
||||
args: ["matching-plugin-local"],
|
||||
scope: "local",
|
||||
projectPath: PROJECT_DIR,
|
||||
},
|
||||
nonMatchingLocal: {
|
||||
command: "npx",
|
||||
args: ["non-matching-plugin-local"],
|
||||
scope: "local",
|
||||
projectPath: join(PROJECT_DIR, "other-project"),
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
const plugin: LoadedPlugin = {
|
||||
name: "demo-plugin",
|
||||
version: "1.0.0",
|
||||
scope: "project",
|
||||
installPath: PLUGIN_DIR,
|
||||
pluginKey: "demo-plugin@test",
|
||||
mcpPath: MCP_CONFIG_PATH,
|
||||
}
|
||||
|
||||
const originalCwd = process.cwd()
|
||||
process.chdir(PROJECT_DIR)
|
||||
|
||||
try {
|
||||
const { loadPluginMcpServers } = await import("./mcp-server-loader")
|
||||
const servers = await loadPluginMcpServers([plugin])
|
||||
|
||||
expect(servers).toHaveProperty("demo-plugin:globalServer")
|
||||
expect(servers).toHaveProperty("demo-plugin:matchingLocal")
|
||||
expect(servers).not.toHaveProperty("demo-plugin:nonMatchingLocal")
|
||||
} finally {
|
||||
process.chdir(originalCwd)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,7 @@
|
||||
import { existsSync } from "fs"
|
||||
import type { McpServerConfig } from "../claude-code-mcp-loader/types"
|
||||
import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander"
|
||||
import { shouldLoadMcpServer } from "../claude-code-mcp-loader/scope-filter"
|
||||
import { transformMcpServer } from "../claude-code-mcp-loader/transformer"
|
||||
import type { ClaudeCodeMcpConfig } from "../claude-code-mcp-loader/types"
|
||||
import { log } from "../../shared/logger"
|
||||
@@ -11,6 +12,7 @@ export async function loadPluginMcpServers(
|
||||
plugins: LoadedPlugin[],
|
||||
): Promise<Record<string, McpServerConfig>> {
|
||||
const servers: Record<string, McpServerConfig> = {}
|
||||
const cwd = process.cwd()
|
||||
|
||||
for (const plugin of plugins) {
|
||||
if (!plugin.mcpPath || !existsSync(plugin.mcpPath)) continue
|
||||
@@ -25,6 +27,15 @@ export async function loadPluginMcpServers(
|
||||
if (!config.mcpServers) continue
|
||||
|
||||
for (const [name, serverConfig] of Object.entries(config.mcpServers)) {
|
||||
if (!shouldLoadMcpServer(serverConfig, cwd)) {
|
||||
log(`Skipping local plugin MCP server "${name}" outside current cwd`, {
|
||||
path: plugin.mcpPath,
|
||||
projectPath: serverConfig.projectPath,
|
||||
cwd,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if (serverConfig.disabled) {
|
||||
log(`Skipping disabled MCP server "${name}" from plugin ${plugin.name}`)
|
||||
continue
|
||||
|
||||
@@ -27,6 +27,7 @@ export const RETRYABLE_ERROR_PATTERNS = [
|
||||
/too.?many.?requests/i,
|
||||
/quota.?exceeded/i,
|
||||
/quota\s+will\s+reset\s+after/i,
|
||||
/(?:you(?:'ve|\s+have)\s+)?reached\s+your\s+usage\s+limit/i,
|
||||
/all\s+credentials\s+for\s+model/i,
|
||||
/cool(?:ing)?\s+down/i,
|
||||
/exhausted\s+your\s+capacity/i,
|
||||
|
||||
@@ -253,6 +253,17 @@ describe("quota error detection (fixes #2747)", () => {
|
||||
expect(retryable).toBe(true)
|
||||
})
|
||||
|
||||
test("treats hard usage-limit wording as retryable", () => {
|
||||
//#given
|
||||
const error = { message: "You've reached your usage limit for this month. Please upgrade to continue." }
|
||||
|
||||
//#when
|
||||
const retryable = isRetryableError(error, [429, 503])
|
||||
|
||||
//#then
|
||||
expect(retryable).toBe(true)
|
||||
})
|
||||
|
||||
test("classifies QuotaExceededError by errorName even without quota keywords in message", () => {
|
||||
//#given
|
||||
const error = { name: "QuotaExceededError", message: "Request failed." }
|
||||
|
||||
@@ -64,6 +64,11 @@ describe("runtime-fallback", () => {
|
||||
|
||||
function createMockPluginConfigWithCategoryFallback(fallbackModels: string[]): OhMyOpenCodeConfig {
|
||||
return {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
categories: {
|
||||
test: {
|
||||
fallback_models: fallbackModels,
|
||||
@@ -79,6 +84,11 @@ describe("runtime-fallback", () => {
|
||||
variant?: string,
|
||||
): OhMyOpenCodeConfig {
|
||||
return {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
categories: {
|
||||
[categoryName]: {
|
||||
model,
|
||||
@@ -272,6 +282,39 @@ describe("runtime-fallback", () => {
|
||||
expect(errorLog).toBeDefined()
|
||||
})
|
||||
|
||||
test("should trigger fallback when session.error says you've reached your usage limit", async () => {
|
||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||
config: createMockConfig({ notify_on_fallback: false }),
|
||||
pluginConfig: createMockPluginConfigWithCategoryFallback(["zai-coding-plan/glm-5.1"]),
|
||||
})
|
||||
const sessionID = "test-session-usage-limit"
|
||||
SessionCategoryRegistry.register(sessionID, "test")
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.created",
|
||||
properties: { info: { id: sessionID, model: "kimi-for-coding/k2p5" } },
|
||||
},
|
||||
})
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.error",
|
||||
properties: {
|
||||
sessionID,
|
||||
error: { message: "You've reached your usage limit for this month. Please upgrade to continue." },
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const fallbackLog = logCalls.find((c) => c.msg.includes("Preparing fallback"))
|
||||
expect(fallbackLog).toBeDefined()
|
||||
expect(fallbackLog?.data).toMatchObject({ from: "kimi-for-coding/k2p5", to: "zai-coding-plan/glm-5.1" })
|
||||
|
||||
const skipLog = logCalls.find((c) => c.msg.includes("Error not retryable"))
|
||||
expect(skipLog).toBeUndefined()
|
||||
})
|
||||
|
||||
test("should continue fallback chain when fallback model is not found", async () => {
|
||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||
config: createMockConfig({ notify_on_fallback: false }),
|
||||
@@ -767,7 +810,13 @@ describe("runtime-fallback", () => {
|
||||
test("should log when no fallback models configured", async () => {
|
||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||
config: createMockConfig(),
|
||||
pluginConfig: {},
|
||||
pluginConfig: {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
},
|
||||
})
|
||||
const sessionID = "test-session-no-fallbacks"
|
||||
|
||||
@@ -2299,6 +2348,11 @@ describe("runtime-fallback", () => {
|
||||
describe("fallback models configuration", () => {
|
||||
function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[]): OhMyOpenCodeConfig {
|
||||
return {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
agents: {
|
||||
[agentName]: {
|
||||
fallback_models: fallbackModels,
|
||||
@@ -2496,6 +2550,11 @@ describe("runtime-fallback", () => {
|
||||
{
|
||||
config: createMockConfig({ notify_on_fallback: false }),
|
||||
pluginConfig: {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
categories: {
|
||||
test: {
|
||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||
@@ -2548,6 +2607,11 @@ describe("runtime-fallback", () => {
|
||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||
config: createMockConfig({ notify_on_fallback: false }),
|
||||
pluginConfig: {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
categories: {
|
||||
test: {
|
||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||
@@ -2605,6 +2669,11 @@ describe("runtime-fallback", () => {
|
||||
{
|
||||
config: createMockConfig({ notify_on_fallback: false }),
|
||||
pluginConfig: {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
categories: {
|
||||
test: {
|
||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||
@@ -2647,6 +2716,11 @@ describe("runtime-fallback", () => {
|
||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||
config: createMockConfig({ notify_on_fallback: false }),
|
||||
pluginConfig: {
|
||||
git_master: {
|
||||
commit_footer: true,
|
||||
include_co_authored_by: true,
|
||||
git_env_prefix: "GIT_MASTER=1",
|
||||
},
|
||||
categories: {
|
||||
test: {
|
||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||
|
||||
@@ -17,6 +17,11 @@ import {
|
||||
setPendingModelFallback,
|
||||
} from "../hooks/model-fallback/hook";
|
||||
import { getRawFallbackModels } from "../hooks/runtime-fallback/fallback-models";
|
||||
import {
|
||||
clearBackgroundOutputConsumptionsForParentSession,
|
||||
clearBackgroundOutputConsumptionsForTaskSession,
|
||||
restoreBackgroundOutputConsumption,
|
||||
} from "../shared/background-output-consumption";
|
||||
import { resetMessageCursor } from "../shared";
|
||||
import { getAgentConfigKey } from "../shared/agent-display-names";
|
||||
import { readConnectedProvidersCache } from "../shared/connected-providers-cache";
|
||||
@@ -366,6 +371,8 @@ export function createEventHandler(args: {
|
||||
clearPendingModelFallback(sessionInfo.id);
|
||||
clearSessionFallbackChain(sessionInfo.id);
|
||||
resetMessageCursor(sessionInfo.id);
|
||||
clearBackgroundOutputConsumptionsForParentSession(sessionInfo.id);
|
||||
clearBackgroundOutputConsumptionsForTaskSession(sessionInfo.id);
|
||||
firstMessageVariantGate.clear(sessionInfo.id);
|
||||
clearSessionModel(sessionInfo.id);
|
||||
clearSessionPromptParams(sessionInfo.id);
|
||||
@@ -382,6 +389,12 @@ export function createEventHandler(args: {
|
||||
}
|
||||
}
|
||||
|
||||
if (event.type === "message.removed") {
|
||||
const messageID = props?.messageID as string | undefined;
|
||||
const sessionID = props?.sessionID as string | undefined;
|
||||
restoreBackgroundOutputConsumption(sessionID, messageID);
|
||||
}
|
||||
|
||||
if (event.type === "message.updated") {
|
||||
const info = props?.info as Record<string, unknown> | undefined;
|
||||
const sessionID = info?.sessionID as string | undefined;
|
||||
|
||||
69
src/shared/background-output-consumption.ts
Normal file
69
src/shared/background-output-consumption.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { getMessageCursor, restoreMessageCursor, type CursorState } from "./session-cursor"
|
||||
|
||||
type MessageConsumptionKey = `${string}:${string}`
|
||||
|
||||
const cursorSnapshotsByMessage = new Map<MessageConsumptionKey, Map<string, CursorState | undefined>>()
|
||||
|
||||
function getMessageKey(sessionID: string, messageID: string): MessageConsumptionKey {
|
||||
return `${sessionID}:${messageID}`
|
||||
}
|
||||
|
||||
export function recordBackgroundOutputConsumption(
|
||||
parentSessionID: string | undefined,
|
||||
parentMessageID: string | undefined,
|
||||
taskSessionID: string | undefined
|
||||
): void {
|
||||
if (!parentSessionID || !parentMessageID || !taskSessionID) return
|
||||
|
||||
const messageKey = getMessageKey(parentSessionID, parentMessageID)
|
||||
const existing = cursorSnapshotsByMessage.get(messageKey) ?? new Map<string, CursorState | undefined>()
|
||||
|
||||
if (!cursorSnapshotsByMessage.has(messageKey)) {
|
||||
cursorSnapshotsByMessage.set(messageKey, existing)
|
||||
}
|
||||
|
||||
if (existing.has(taskSessionID)) return
|
||||
existing.set(taskSessionID, getMessageCursor(taskSessionID))
|
||||
}
|
||||
|
||||
export function restoreBackgroundOutputConsumption(
|
||||
parentSessionID: string | undefined,
|
||||
parentMessageID: string | undefined
|
||||
): void {
|
||||
if (!parentSessionID || !parentMessageID) return
|
||||
|
||||
const messageKey = getMessageKey(parentSessionID, parentMessageID)
|
||||
const snapshots = cursorSnapshotsByMessage.get(messageKey)
|
||||
if (!snapshots) return
|
||||
|
||||
cursorSnapshotsByMessage.delete(messageKey)
|
||||
for (const [taskSessionID, cursor] of snapshots) {
|
||||
restoreMessageCursor(taskSessionID, cursor)
|
||||
}
|
||||
}
|
||||
|
||||
export function clearBackgroundOutputConsumptionsForParentSession(sessionID: string | undefined): void {
|
||||
if (!sessionID) return
|
||||
|
||||
const prefix = `${sessionID}:`
|
||||
for (const messageKey of cursorSnapshotsByMessage.keys()) {
|
||||
if (messageKey.startsWith(prefix)) {
|
||||
cursorSnapshotsByMessage.delete(messageKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function clearBackgroundOutputConsumptionsForTaskSession(taskSessionID: string | undefined): void {
|
||||
if (!taskSessionID) return
|
||||
|
||||
for (const [messageKey, snapshots] of cursorSnapshotsByMessage) {
|
||||
snapshots.delete(taskSessionID)
|
||||
if (snapshots.size === 0) {
|
||||
cursorSnapshotsByMessage.delete(messageKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function clearBackgroundOutputConsumptionState(): void {
|
||||
cursorSnapshotsByMessage.clear()
|
||||
}
|
||||
@@ -13,13 +13,21 @@ export type CursorMessage = {
|
||||
info?: MessageInfo
|
||||
}
|
||||
|
||||
interface CursorState {
|
||||
export interface CursorState {
|
||||
lastKey?: string
|
||||
lastCount: number
|
||||
}
|
||||
|
||||
const sessionCursors = new Map<string, CursorState>()
|
||||
|
||||
function cloneCursorState(state: CursorState | undefined): CursorState | undefined {
|
||||
if (!state) return undefined
|
||||
return {
|
||||
lastKey: state.lastKey,
|
||||
lastCount: state.lastCount,
|
||||
}
|
||||
}
|
||||
|
||||
function buildMessageKey(message: CursorMessage, index: number): string {
|
||||
const id = message.info?.id
|
||||
if (id) return `id:${id}`
|
||||
@@ -83,3 +91,18 @@ export function resetMessageCursor(sessionID?: string): void {
|
||||
}
|
||||
sessionCursors.clear()
|
||||
}
|
||||
|
||||
export function getMessageCursor(sessionID: string | undefined): CursorState | undefined {
|
||||
if (!sessionID) return undefined
|
||||
return cloneCursorState(sessionCursors.get(sessionID))
|
||||
}
|
||||
|
||||
export function restoreMessageCursor(sessionID: string | undefined, cursor: CursorState | undefined): void {
|
||||
if (!sessionID) return
|
||||
if (!cursor) {
|
||||
sessionCursors.delete(sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
sessionCursors.set(sessionID, cloneCursorState(cursor)!)
|
||||
}
|
||||
|
||||
@@ -10,11 +10,13 @@ import { formatTaskResult } from "./task-result-format"
|
||||
import { formatTaskStatus } from "./task-status-format"
|
||||
|
||||
import { getAgentDisplayName } from "../../shared/agent-display-names"
|
||||
import { recordBackgroundOutputConsumption } from "../../shared/background-output-consumption"
|
||||
|
||||
const SISYPHUS_JUNIOR_AGENT = getAgentDisplayName("sisyphus-junior")
|
||||
|
||||
type ToolContextWithMetadata = {
|
||||
sessionID: string
|
||||
messageID?: string
|
||||
metadata?: (input: { title?: string; metadata?: Record<string, unknown> }) => void
|
||||
callID?: string
|
||||
callId?: string
|
||||
@@ -139,6 +141,7 @@ export function createBackgroundOutput(manager: BackgroundOutputManager, client:
|
||||
}
|
||||
|
||||
if (resolvedTask.status === "completed") {
|
||||
recordBackgroundOutputConsumption(ctx.sessionID, ctx.messageID, resolvedTask.sessionID)
|
||||
return await formatTaskResult(resolvedTask, client)
|
||||
}
|
||||
|
||||
|
||||
129
src/tools/background-task/create-background-output.undo.test.ts
Normal file
129
src/tools/background-task/create-background-output.undo.test.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
/// <reference types="bun-types" />
|
||||
|
||||
import { afterEach, describe, expect, test } from "bun:test"
|
||||
import type { ToolContext } from "@opencode-ai/plugin/tool"
|
||||
import type { BackgroundTask } from "../../features/background-agent"
|
||||
import { createEventHandler } from "../../plugin/event"
|
||||
import { clearBackgroundOutputConsumptionState } from "../../shared/background-output-consumption"
|
||||
import { resetMessageCursor } from "../../shared/session-cursor"
|
||||
import type { BackgroundOutputClient, BackgroundOutputManager } from "./clients"
|
||||
import { createBackgroundOutput } from "./create-background-output"
|
||||
|
||||
const projectDir = "/Users/yeongyu/local-workspaces/oh-my-opencode"
|
||||
|
||||
const parentSessionID = "parent-session"
|
||||
const taskSessionID = "task-session"
|
||||
|
||||
type ToolContextWithCallID = ToolContext & {
|
||||
callID: string
|
||||
}
|
||||
|
||||
const baseContext = {
|
||||
sessionID: parentSessionID,
|
||||
agent: "test-agent",
|
||||
directory: projectDir,
|
||||
worktree: projectDir,
|
||||
abort: new AbortController().signal,
|
||||
metadata: () => {},
|
||||
ask: async () => {},
|
||||
callID: "call-1",
|
||||
} as const satisfies Partial<ToolContextWithCallID>
|
||||
|
||||
function createTask(overrides: Partial<BackgroundTask> = {}): BackgroundTask {
|
||||
return {
|
||||
id: "task-1",
|
||||
sessionID: taskSessionID,
|
||||
parentSessionID,
|
||||
parentMessageID: "msg-parent",
|
||||
description: "background task",
|
||||
prompt: "do work",
|
||||
agent: "test-agent",
|
||||
status: "completed",
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
function createMockClient(): BackgroundOutputClient {
|
||||
return {
|
||||
session: {
|
||||
messages: async () => ({
|
||||
data: [
|
||||
{
|
||||
id: "m1",
|
||||
info: { role: "assistant", time: "2026-01-01T00:00:00Z" },
|
||||
parts: [{ type: "text", text: "final result" }],
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
function createMockEventHandler() {
|
||||
return createEventHandler({
|
||||
ctx: {} as never,
|
||||
pluginConfig: {} as never,
|
||||
firstMessageVariantGate: {
|
||||
markSessionCreated: () => {},
|
||||
clear: () => {},
|
||||
},
|
||||
managers: {
|
||||
skillMcpManager: {
|
||||
disconnectSession: async () => {},
|
||||
},
|
||||
tmuxSessionManager: {
|
||||
onSessionCreated: async () => {},
|
||||
onSessionDeleted: async () => {},
|
||||
},
|
||||
} as never,
|
||||
hooks: {} as never,
|
||||
})
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
resetMessageCursor(taskSessionID)
|
||||
clearBackgroundOutputConsumptionState()
|
||||
})
|
||||
|
||||
describe("createBackgroundOutput undo regression", () => {
|
||||
test("#given consumed background output #when undo removes the parent message #then output can be consumed again", async () => {
|
||||
// #given
|
||||
const task = createTask()
|
||||
const manager: BackgroundOutputManager = {
|
||||
getTask: id => (id === task.id ? task : undefined),
|
||||
}
|
||||
const tool = createBackgroundOutput(manager, createMockClient())
|
||||
const eventHandler = createMockEventHandler()
|
||||
|
||||
// #when
|
||||
const firstOutput = await tool.execute(
|
||||
{ task_id: task.id },
|
||||
{ ...baseContext, messageID: "msg-result-1" } as ToolContextWithCallID
|
||||
)
|
||||
|
||||
const secondOutput = await tool.execute(
|
||||
{ task_id: task.id },
|
||||
{ ...baseContext, callID: "call-2", messageID: "msg-result-2" } as ToolContextWithCallID
|
||||
)
|
||||
|
||||
await eventHandler({
|
||||
event: {
|
||||
type: "message.removed",
|
||||
properties: {
|
||||
sessionID: parentSessionID,
|
||||
messageID: "msg-result-1",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const thirdOutput = await tool.execute(
|
||||
{ task_id: task.id },
|
||||
{ ...baseContext, callID: "call-3", messageID: "msg-result-3" } as ToolContextWithCallID
|
||||
)
|
||||
|
||||
// #then
|
||||
expect(firstOutput).toContain("final result")
|
||||
expect(secondOutput).toContain("No new output since last check")
|
||||
expect(thirdOutput).toContain("final result")
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user