- background_output: snapshot read cursor before consuming, restore on /undo message removal so re-reads return data (fixes #2915) - MCP loader: preserve oauth field in transformMcpServer, add scope/ projectPath filtering so local-scoped MCPs only load in matching directories (fixes #2917) - runtime-fallback: add 'reached your usage limit' to retryable error patterns so quota exhaustion triggers model fallback (fixes #2918) Verified: bun test (4606 pass / 0 fail), tsc --noEmit clean
This commit is contained in:
@@ -10,6 +10,7 @@ import type {
|
|||||||
} from "./types"
|
} from "./types"
|
||||||
import { transformMcpServer } from "./transformer"
|
import { transformMcpServer } from "./transformer"
|
||||||
import { log } from "../../shared/logger"
|
import { log } from "../../shared/logger"
|
||||||
|
import { shouldLoadMcpServer } from "./scope-filter"
|
||||||
|
|
||||||
interface McpConfigPath {
|
interface McpConfigPath {
|
||||||
path: string
|
path: string
|
||||||
@@ -75,6 +76,7 @@ export async function loadMcpConfigs(
|
|||||||
const loadedServers: LoadedMcpServer[] = []
|
const loadedServers: LoadedMcpServer[] = []
|
||||||
const paths = getMcpConfigPaths()
|
const paths = getMcpConfigPaths()
|
||||||
const disabledSet = new Set(disabledMcps)
|
const disabledSet = new Set(disabledMcps)
|
||||||
|
const cwd = process.cwd()
|
||||||
|
|
||||||
for (const { path, scope } of paths) {
|
for (const { path, scope } of paths) {
|
||||||
const config = await loadMcpConfigFile(path)
|
const config = await loadMcpConfigFile(path)
|
||||||
@@ -86,6 +88,15 @@ export async function loadMcpConfigs(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!shouldLoadMcpServer(serverConfig, cwd)) {
|
||||||
|
log(`Skipping MCP server "${name}" because local scope does not match cwd`, {
|
||||||
|
path,
|
||||||
|
projectPath: serverConfig.projectPath,
|
||||||
|
cwd,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if (serverConfig.disabled) {
|
if (serverConfig.disabled) {
|
||||||
log(`Disabling MCP server "${name}"`, { path })
|
log(`Disabling MCP server "${name}"`, { path })
|
||||||
delete servers[name]
|
delete servers[name]
|
||||||
|
|||||||
28
src/features/claude-code-mcp-loader/scope-filter.ts
Normal file
28
src/features/claude-code-mcp-loader/scope-filter.ts
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import { existsSync, realpathSync } from "fs"
|
||||||
|
import { resolve } from "path"
|
||||||
|
import type { ClaudeCodeMcpServer } from "./types"
|
||||||
|
|
||||||
|
function normalizePath(path: string): string {
|
||||||
|
const resolvedPath = resolve(path)
|
||||||
|
|
||||||
|
if (!existsSync(resolvedPath)) {
|
||||||
|
return resolvedPath
|
||||||
|
}
|
||||||
|
|
||||||
|
return realpathSync(resolvedPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function shouldLoadMcpServer(
|
||||||
|
server: Pick<ClaudeCodeMcpServer, "scope" | "projectPath">,
|
||||||
|
cwd = process.cwd()
|
||||||
|
): boolean {
|
||||||
|
if (server.scope !== "local") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!server.projectPath) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalizePath(server.projectPath) === normalizePath(cwd)
|
||||||
|
}
|
||||||
82
src/features/claude-code-mcp-loader/scope-filtering.test.ts
Normal file
82
src/features/claude-code-mcp-loader/scope-filtering.test.ts
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
|
||||||
|
import { mkdirSync, rmSync, writeFileSync } from "fs"
|
||||||
|
import { tmpdir } from "os"
|
||||||
|
import { join } from "path"
|
||||||
|
|
||||||
|
const TEST_DIR = join(tmpdir(), `mcp-scope-filtering-test-${Date.now()}`)
|
||||||
|
const TEST_HOME = join(TEST_DIR, "home")
|
||||||
|
|
||||||
|
describe("loadMcpConfigs", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
mkdirSync(TEST_DIR, { recursive: true })
|
||||||
|
mkdirSync(TEST_HOME, { recursive: true })
|
||||||
|
mock.module("os", () => ({
|
||||||
|
homedir: () => TEST_HOME,
|
||||||
|
tmpdir,
|
||||||
|
}))
|
||||||
|
mock.module("../../shared", () => ({
|
||||||
|
getClaudeConfigDir: () => join(TEST_HOME, ".claude"),
|
||||||
|
}))
|
||||||
|
mock.module("../../shared/logger", () => ({
|
||||||
|
log: () => {},
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
mock.restore()
|
||||||
|
rmSync(TEST_DIR, { recursive: true, force: true })
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("#given user-scoped MCP entries with local scope metadata", () => {
|
||||||
|
it("#when loading configs #then only servers matching the current project path are loaded", async () => {
|
||||||
|
writeFileSync(
|
||||||
|
join(TEST_HOME, ".claude.json"),
|
||||||
|
JSON.stringify({
|
||||||
|
mcpServers: {
|
||||||
|
globalServer: {
|
||||||
|
command: "npx",
|
||||||
|
args: ["global-server"],
|
||||||
|
},
|
||||||
|
matchingLocal: {
|
||||||
|
command: "npx",
|
||||||
|
args: ["matching-local"],
|
||||||
|
scope: "local",
|
||||||
|
projectPath: TEST_DIR,
|
||||||
|
},
|
||||||
|
nonMatchingLocal: {
|
||||||
|
command: "npx",
|
||||||
|
args: ["non-matching-local"],
|
||||||
|
scope: "local",
|
||||||
|
projectPath: join(TEST_DIR, "other-project"),
|
||||||
|
},
|
||||||
|
missingProjectPath: {
|
||||||
|
command: "npx",
|
||||||
|
args: ["missing-project-path"],
|
||||||
|
scope: "local",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
const originalCwd = process.cwd()
|
||||||
|
process.chdir(TEST_DIR)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { loadMcpConfigs } = await import("./loader")
|
||||||
|
const result = await loadMcpConfigs()
|
||||||
|
|
||||||
|
expect(result.servers).toHaveProperty("globalServer")
|
||||||
|
expect(result.servers).toHaveProperty("matchingLocal")
|
||||||
|
expect(result.servers).not.toHaveProperty("nonMatchingLocal")
|
||||||
|
expect(result.servers).not.toHaveProperty("missingProjectPath")
|
||||||
|
|
||||||
|
expect(result.loadedServers.map((server) => server.name)).toEqual([
|
||||||
|
"globalServer",
|
||||||
|
"matchingLocal",
|
||||||
|
])
|
||||||
|
} finally {
|
||||||
|
process.chdir(originalCwd)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
29
src/features/claude-code-mcp-loader/transformer.test.ts
Normal file
29
src/features/claude-code-mcp-loader/transformer.test.ts
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import { describe, expect, it } from "bun:test"
|
||||||
|
import { transformMcpServer } from "./transformer"
|
||||||
|
|
||||||
|
describe("transformMcpServer", () => {
|
||||||
|
describe("#given a remote MCP server with oauth config", () => {
|
||||||
|
it("#when transforming the server #then preserves oauth on the remote config", () => {
|
||||||
|
const transformed = transformMcpServer("remote-oauth", {
|
||||||
|
type: "http",
|
||||||
|
url: "https://mcp.example.com",
|
||||||
|
headers: { Authorization: "Bearer test" },
|
||||||
|
oauth: {
|
||||||
|
clientId: "client-id",
|
||||||
|
scopes: ["read", "write"],
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(transformed).toEqual({
|
||||||
|
type: "remote",
|
||||||
|
url: "https://mcp.example.com",
|
||||||
|
headers: { Authorization: "Bearer test" },
|
||||||
|
oauth: {
|
||||||
|
clientId: "client-id",
|
||||||
|
scopes: ["read", "write"],
|
||||||
|
},
|
||||||
|
enabled: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -30,6 +30,10 @@ export function transformMcpServer(
|
|||||||
config.headers = expanded.headers
|
config.headers = expanded.headers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (expanded.oauth && Object.keys(expanded.oauth).length > 0) {
|
||||||
|
config.oauth = expanded.oauth
|
||||||
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
export type McpScope = "user" | "project" | "local"
|
export type McpScope = "user" | "project" | "local"
|
||||||
|
|
||||||
|
export interface McpOAuthConfig {
|
||||||
|
clientId?: string
|
||||||
|
scopes?: string[]
|
||||||
|
}
|
||||||
|
|
||||||
export interface ClaudeCodeMcpServer {
|
export interface ClaudeCodeMcpServer {
|
||||||
type?: "http" | "sse" | "stdio"
|
type?: "http" | "sse" | "stdio"
|
||||||
url?: string
|
url?: string
|
||||||
@@ -7,10 +12,9 @@ export interface ClaudeCodeMcpServer {
|
|||||||
args?: string[]
|
args?: string[]
|
||||||
env?: Record<string, string>
|
env?: Record<string, string>
|
||||||
headers?: Record<string, string>
|
headers?: Record<string, string>
|
||||||
oauth?: {
|
oauth?: McpOAuthConfig
|
||||||
clientId?: string
|
scope?: McpScope
|
||||||
scopes?: string[]
|
projectPath?: string
|
||||||
}
|
|
||||||
disabled?: boolean
|
disabled?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,6 +33,7 @@ export interface McpRemoteConfig {
|
|||||||
type: "remote"
|
type: "remote"
|
||||||
url: string
|
url: string
|
||||||
headers?: Record<string, string>
|
headers?: Record<string, string>
|
||||||
|
oauth?: McpOAuthConfig
|
||||||
enabled?: boolean
|
enabled?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,76 @@
|
|||||||
|
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
|
||||||
|
import { mkdirSync, rmSync, writeFileSync } from "fs"
|
||||||
|
import { tmpdir } from "os"
|
||||||
|
import { join } from "path"
|
||||||
|
import type { LoadedPlugin } from "./types"
|
||||||
|
|
||||||
|
const TEST_DIR = join(tmpdir(), `plugin-mcp-loader-test-${Date.now()}`)
|
||||||
|
const PROJECT_DIR = join(TEST_DIR, "project")
|
||||||
|
const PLUGIN_DIR = join(TEST_DIR, "plugin")
|
||||||
|
const MCP_CONFIG_PATH = join(PLUGIN_DIR, "mcp.json")
|
||||||
|
|
||||||
|
describe("loadPluginMcpServers", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
mkdirSync(PROJECT_DIR, { recursive: true })
|
||||||
|
mkdirSync(PLUGIN_DIR, { recursive: true })
|
||||||
|
mock.module("../../shared/logger", () => ({
|
||||||
|
log: () => {},
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
mock.restore()
|
||||||
|
rmSync(TEST_DIR, { recursive: true, force: true })
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("#given plugin MCP entries with local scope metadata", () => {
|
||||||
|
it("#when loading plugin MCP servers #then only entries matching the current cwd are included", async () => {
|
||||||
|
writeFileSync(
|
||||||
|
MCP_CONFIG_PATH,
|
||||||
|
JSON.stringify({
|
||||||
|
mcpServers: {
|
||||||
|
globalServer: {
|
||||||
|
command: "npx",
|
||||||
|
args: ["global-plugin-server"],
|
||||||
|
},
|
||||||
|
matchingLocal: {
|
||||||
|
command: "npx",
|
||||||
|
args: ["matching-plugin-local"],
|
||||||
|
scope: "local",
|
||||||
|
projectPath: PROJECT_DIR,
|
||||||
|
},
|
||||||
|
nonMatchingLocal: {
|
||||||
|
command: "npx",
|
||||||
|
args: ["non-matching-plugin-local"],
|
||||||
|
scope: "local",
|
||||||
|
projectPath: join(PROJECT_DIR, "other-project"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
const plugin: LoadedPlugin = {
|
||||||
|
name: "demo-plugin",
|
||||||
|
version: "1.0.0",
|
||||||
|
scope: "project",
|
||||||
|
installPath: PLUGIN_DIR,
|
||||||
|
pluginKey: "demo-plugin@test",
|
||||||
|
mcpPath: MCP_CONFIG_PATH,
|
||||||
|
}
|
||||||
|
|
||||||
|
const originalCwd = process.cwd()
|
||||||
|
process.chdir(PROJECT_DIR)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { loadPluginMcpServers } = await import("./mcp-server-loader")
|
||||||
|
const servers = await loadPluginMcpServers([plugin])
|
||||||
|
|
||||||
|
expect(servers).toHaveProperty("demo-plugin:globalServer")
|
||||||
|
expect(servers).toHaveProperty("demo-plugin:matchingLocal")
|
||||||
|
expect(servers).not.toHaveProperty("demo-plugin:nonMatchingLocal")
|
||||||
|
} finally {
|
||||||
|
process.chdir(originalCwd)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import { existsSync } from "fs"
|
import { existsSync } from "fs"
|
||||||
import type { McpServerConfig } from "../claude-code-mcp-loader/types"
|
import type { McpServerConfig } from "../claude-code-mcp-loader/types"
|
||||||
import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander"
|
import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander"
|
||||||
|
import { shouldLoadMcpServer } from "../claude-code-mcp-loader/scope-filter"
|
||||||
import { transformMcpServer } from "../claude-code-mcp-loader/transformer"
|
import { transformMcpServer } from "../claude-code-mcp-loader/transformer"
|
||||||
import type { ClaudeCodeMcpConfig } from "../claude-code-mcp-loader/types"
|
import type { ClaudeCodeMcpConfig } from "../claude-code-mcp-loader/types"
|
||||||
import { log } from "../../shared/logger"
|
import { log } from "../../shared/logger"
|
||||||
@@ -11,6 +12,7 @@ export async function loadPluginMcpServers(
|
|||||||
plugins: LoadedPlugin[],
|
plugins: LoadedPlugin[],
|
||||||
): Promise<Record<string, McpServerConfig>> {
|
): Promise<Record<string, McpServerConfig>> {
|
||||||
const servers: Record<string, McpServerConfig> = {}
|
const servers: Record<string, McpServerConfig> = {}
|
||||||
|
const cwd = process.cwd()
|
||||||
|
|
||||||
for (const plugin of plugins) {
|
for (const plugin of plugins) {
|
||||||
if (!plugin.mcpPath || !existsSync(plugin.mcpPath)) continue
|
if (!plugin.mcpPath || !existsSync(plugin.mcpPath)) continue
|
||||||
@@ -25,6 +27,15 @@ export async function loadPluginMcpServers(
|
|||||||
if (!config.mcpServers) continue
|
if (!config.mcpServers) continue
|
||||||
|
|
||||||
for (const [name, serverConfig] of Object.entries(config.mcpServers)) {
|
for (const [name, serverConfig] of Object.entries(config.mcpServers)) {
|
||||||
|
if (!shouldLoadMcpServer(serverConfig, cwd)) {
|
||||||
|
log(`Skipping local plugin MCP server "${name}" outside current cwd`, {
|
||||||
|
path: plugin.mcpPath,
|
||||||
|
projectPath: serverConfig.projectPath,
|
||||||
|
cwd,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if (serverConfig.disabled) {
|
if (serverConfig.disabled) {
|
||||||
log(`Skipping disabled MCP server "${name}" from plugin ${plugin.name}`)
|
log(`Skipping disabled MCP server "${name}" from plugin ${plugin.name}`)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ export const RETRYABLE_ERROR_PATTERNS = [
|
|||||||
/too.?many.?requests/i,
|
/too.?many.?requests/i,
|
||||||
/quota.?exceeded/i,
|
/quota.?exceeded/i,
|
||||||
/quota\s+will\s+reset\s+after/i,
|
/quota\s+will\s+reset\s+after/i,
|
||||||
|
/(?:you(?:'ve|\s+have)\s+)?reached\s+your\s+usage\s+limit/i,
|
||||||
/all\s+credentials\s+for\s+model/i,
|
/all\s+credentials\s+for\s+model/i,
|
||||||
/cool(?:ing)?\s+down/i,
|
/cool(?:ing)?\s+down/i,
|
||||||
/exhausted\s+your\s+capacity/i,
|
/exhausted\s+your\s+capacity/i,
|
||||||
|
|||||||
@@ -253,6 +253,17 @@ describe("quota error detection (fixes #2747)", () => {
|
|||||||
expect(retryable).toBe(true)
|
expect(retryable).toBe(true)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test("treats hard usage-limit wording as retryable", () => {
|
||||||
|
//#given
|
||||||
|
const error = { message: "You've reached your usage limit for this month. Please upgrade to continue." }
|
||||||
|
|
||||||
|
//#when
|
||||||
|
const retryable = isRetryableError(error, [429, 503])
|
||||||
|
|
||||||
|
//#then
|
||||||
|
expect(retryable).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
test("classifies QuotaExceededError by errorName even without quota keywords in message", () => {
|
test("classifies QuotaExceededError by errorName even without quota keywords in message", () => {
|
||||||
//#given
|
//#given
|
||||||
const error = { name: "QuotaExceededError", message: "Request failed." }
|
const error = { name: "QuotaExceededError", message: "Request failed." }
|
||||||
|
|||||||
@@ -64,6 +64,11 @@ describe("runtime-fallback", () => {
|
|||||||
|
|
||||||
function createMockPluginConfigWithCategoryFallback(fallbackModels: string[]): OhMyOpenCodeConfig {
|
function createMockPluginConfigWithCategoryFallback(fallbackModels: string[]): OhMyOpenCodeConfig {
|
||||||
return {
|
return {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
categories: {
|
categories: {
|
||||||
test: {
|
test: {
|
||||||
fallback_models: fallbackModels,
|
fallback_models: fallbackModels,
|
||||||
@@ -79,6 +84,11 @@ describe("runtime-fallback", () => {
|
|||||||
variant?: string,
|
variant?: string,
|
||||||
): OhMyOpenCodeConfig {
|
): OhMyOpenCodeConfig {
|
||||||
return {
|
return {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
categories: {
|
categories: {
|
||||||
[categoryName]: {
|
[categoryName]: {
|
||||||
model,
|
model,
|
||||||
@@ -272,6 +282,39 @@ describe("runtime-fallback", () => {
|
|||||||
expect(errorLog).toBeDefined()
|
expect(errorLog).toBeDefined()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test("should trigger fallback when session.error says you've reached your usage limit", async () => {
|
||||||
|
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||||
|
config: createMockConfig({ notify_on_fallback: false }),
|
||||||
|
pluginConfig: createMockPluginConfigWithCategoryFallback(["zai-coding-plan/glm-5.1"]),
|
||||||
|
})
|
||||||
|
const sessionID = "test-session-usage-limit"
|
||||||
|
SessionCategoryRegistry.register(sessionID, "test")
|
||||||
|
|
||||||
|
await hook.event({
|
||||||
|
event: {
|
||||||
|
type: "session.created",
|
||||||
|
properties: { info: { id: sessionID, model: "kimi-for-coding/k2p5" } },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
await hook.event({
|
||||||
|
event: {
|
||||||
|
type: "session.error",
|
||||||
|
properties: {
|
||||||
|
sessionID,
|
||||||
|
error: { message: "You've reached your usage limit for this month. Please upgrade to continue." },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const fallbackLog = logCalls.find((c) => c.msg.includes("Preparing fallback"))
|
||||||
|
expect(fallbackLog).toBeDefined()
|
||||||
|
expect(fallbackLog?.data).toMatchObject({ from: "kimi-for-coding/k2p5", to: "zai-coding-plan/glm-5.1" })
|
||||||
|
|
||||||
|
const skipLog = logCalls.find((c) => c.msg.includes("Error not retryable"))
|
||||||
|
expect(skipLog).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
test("should continue fallback chain when fallback model is not found", async () => {
|
test("should continue fallback chain when fallback model is not found", async () => {
|
||||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||||
config: createMockConfig({ notify_on_fallback: false }),
|
config: createMockConfig({ notify_on_fallback: false }),
|
||||||
@@ -767,7 +810,13 @@ describe("runtime-fallback", () => {
|
|||||||
test("should log when no fallback models configured", async () => {
|
test("should log when no fallback models configured", async () => {
|
||||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||||
config: createMockConfig(),
|
config: createMockConfig(),
|
||||||
pluginConfig: {},
|
pluginConfig: {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
|
},
|
||||||
})
|
})
|
||||||
const sessionID = "test-session-no-fallbacks"
|
const sessionID = "test-session-no-fallbacks"
|
||||||
|
|
||||||
@@ -2299,6 +2348,11 @@ describe("runtime-fallback", () => {
|
|||||||
describe("fallback models configuration", () => {
|
describe("fallback models configuration", () => {
|
||||||
function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[]): OhMyOpenCodeConfig {
|
function createMockPluginConfigWithAgentFallback(agentName: string, fallbackModels: string[]): OhMyOpenCodeConfig {
|
||||||
return {
|
return {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
agents: {
|
agents: {
|
||||||
[agentName]: {
|
[agentName]: {
|
||||||
fallback_models: fallbackModels,
|
fallback_models: fallbackModels,
|
||||||
@@ -2496,6 +2550,11 @@ describe("runtime-fallback", () => {
|
|||||||
{
|
{
|
||||||
config: createMockConfig({ notify_on_fallback: false }),
|
config: createMockConfig({ notify_on_fallback: false }),
|
||||||
pluginConfig: {
|
pluginConfig: {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
categories: {
|
categories: {
|
||||||
test: {
|
test: {
|
||||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||||
@@ -2548,6 +2607,11 @@ describe("runtime-fallback", () => {
|
|||||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||||
config: createMockConfig({ notify_on_fallback: false }),
|
config: createMockConfig({ notify_on_fallback: false }),
|
||||||
pluginConfig: {
|
pluginConfig: {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
categories: {
|
categories: {
|
||||||
test: {
|
test: {
|
||||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||||
@@ -2605,6 +2669,11 @@ describe("runtime-fallback", () => {
|
|||||||
{
|
{
|
||||||
config: createMockConfig({ notify_on_fallback: false }),
|
config: createMockConfig({ notify_on_fallback: false }),
|
||||||
pluginConfig: {
|
pluginConfig: {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
categories: {
|
categories: {
|
||||||
test: {
|
test: {
|
||||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||||
@@ -2647,6 +2716,11 @@ describe("runtime-fallback", () => {
|
|||||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||||
config: createMockConfig({ notify_on_fallback: false }),
|
config: createMockConfig({ notify_on_fallback: false }),
|
||||||
pluginConfig: {
|
pluginConfig: {
|
||||||
|
git_master: {
|
||||||
|
commit_footer: true,
|
||||||
|
include_co_authored_by: true,
|
||||||
|
git_env_prefix: "GIT_MASTER=1",
|
||||||
|
},
|
||||||
categories: {
|
categories: {
|
||||||
test: {
|
test: {
|
||||||
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
fallback_models: ["provider-a/model-a", "provider-b/model-b"],
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ import {
|
|||||||
setPendingModelFallback,
|
setPendingModelFallback,
|
||||||
} from "../hooks/model-fallback/hook";
|
} from "../hooks/model-fallback/hook";
|
||||||
import { getRawFallbackModels } from "../hooks/runtime-fallback/fallback-models";
|
import { getRawFallbackModels } from "../hooks/runtime-fallback/fallback-models";
|
||||||
|
import {
|
||||||
|
clearBackgroundOutputConsumptionsForParentSession,
|
||||||
|
clearBackgroundOutputConsumptionsForTaskSession,
|
||||||
|
restoreBackgroundOutputConsumption,
|
||||||
|
} from "../shared/background-output-consumption";
|
||||||
import { resetMessageCursor } from "../shared";
|
import { resetMessageCursor } from "../shared";
|
||||||
import { getAgentConfigKey } from "../shared/agent-display-names";
|
import { getAgentConfigKey } from "../shared/agent-display-names";
|
||||||
import { readConnectedProvidersCache } from "../shared/connected-providers-cache";
|
import { readConnectedProvidersCache } from "../shared/connected-providers-cache";
|
||||||
@@ -366,6 +371,8 @@ export function createEventHandler(args: {
|
|||||||
clearPendingModelFallback(sessionInfo.id);
|
clearPendingModelFallback(sessionInfo.id);
|
||||||
clearSessionFallbackChain(sessionInfo.id);
|
clearSessionFallbackChain(sessionInfo.id);
|
||||||
resetMessageCursor(sessionInfo.id);
|
resetMessageCursor(sessionInfo.id);
|
||||||
|
clearBackgroundOutputConsumptionsForParentSession(sessionInfo.id);
|
||||||
|
clearBackgroundOutputConsumptionsForTaskSession(sessionInfo.id);
|
||||||
firstMessageVariantGate.clear(sessionInfo.id);
|
firstMessageVariantGate.clear(sessionInfo.id);
|
||||||
clearSessionModel(sessionInfo.id);
|
clearSessionModel(sessionInfo.id);
|
||||||
clearSessionPromptParams(sessionInfo.id);
|
clearSessionPromptParams(sessionInfo.id);
|
||||||
@@ -382,6 +389,12 @@ export function createEventHandler(args: {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (event.type === "message.removed") {
|
||||||
|
const messageID = props?.messageID as string | undefined;
|
||||||
|
const sessionID = props?.sessionID as string | undefined;
|
||||||
|
restoreBackgroundOutputConsumption(sessionID, messageID);
|
||||||
|
}
|
||||||
|
|
||||||
if (event.type === "message.updated") {
|
if (event.type === "message.updated") {
|
||||||
const info = props?.info as Record<string, unknown> | undefined;
|
const info = props?.info as Record<string, unknown> | undefined;
|
||||||
const sessionID = info?.sessionID as string | undefined;
|
const sessionID = info?.sessionID as string | undefined;
|
||||||
|
|||||||
69
src/shared/background-output-consumption.ts
Normal file
69
src/shared/background-output-consumption.ts
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import { getMessageCursor, restoreMessageCursor, type CursorState } from "./session-cursor"
|
||||||
|
|
||||||
|
type MessageConsumptionKey = `${string}:${string}`
|
||||||
|
|
||||||
|
const cursorSnapshotsByMessage = new Map<MessageConsumptionKey, Map<string, CursorState | undefined>>()
|
||||||
|
|
||||||
|
function getMessageKey(sessionID: string, messageID: string): MessageConsumptionKey {
|
||||||
|
return `${sessionID}:${messageID}`
|
||||||
|
}
|
||||||
|
|
||||||
|
export function recordBackgroundOutputConsumption(
|
||||||
|
parentSessionID: string | undefined,
|
||||||
|
parentMessageID: string | undefined,
|
||||||
|
taskSessionID: string | undefined
|
||||||
|
): void {
|
||||||
|
if (!parentSessionID || !parentMessageID || !taskSessionID) return
|
||||||
|
|
||||||
|
const messageKey = getMessageKey(parentSessionID, parentMessageID)
|
||||||
|
const existing = cursorSnapshotsByMessage.get(messageKey) ?? new Map<string, CursorState | undefined>()
|
||||||
|
|
||||||
|
if (!cursorSnapshotsByMessage.has(messageKey)) {
|
||||||
|
cursorSnapshotsByMessage.set(messageKey, existing)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (existing.has(taskSessionID)) return
|
||||||
|
existing.set(taskSessionID, getMessageCursor(taskSessionID))
|
||||||
|
}
|
||||||
|
|
||||||
|
export function restoreBackgroundOutputConsumption(
|
||||||
|
parentSessionID: string | undefined,
|
||||||
|
parentMessageID: string | undefined
|
||||||
|
): void {
|
||||||
|
if (!parentSessionID || !parentMessageID) return
|
||||||
|
|
||||||
|
const messageKey = getMessageKey(parentSessionID, parentMessageID)
|
||||||
|
const snapshots = cursorSnapshotsByMessage.get(messageKey)
|
||||||
|
if (!snapshots) return
|
||||||
|
|
||||||
|
cursorSnapshotsByMessage.delete(messageKey)
|
||||||
|
for (const [taskSessionID, cursor] of snapshots) {
|
||||||
|
restoreMessageCursor(taskSessionID, cursor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function clearBackgroundOutputConsumptionsForParentSession(sessionID: string | undefined): void {
|
||||||
|
if (!sessionID) return
|
||||||
|
|
||||||
|
const prefix = `${sessionID}:`
|
||||||
|
for (const messageKey of cursorSnapshotsByMessage.keys()) {
|
||||||
|
if (messageKey.startsWith(prefix)) {
|
||||||
|
cursorSnapshotsByMessage.delete(messageKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function clearBackgroundOutputConsumptionsForTaskSession(taskSessionID: string | undefined): void {
|
||||||
|
if (!taskSessionID) return
|
||||||
|
|
||||||
|
for (const [messageKey, snapshots] of cursorSnapshotsByMessage) {
|
||||||
|
snapshots.delete(taskSessionID)
|
||||||
|
if (snapshots.size === 0) {
|
||||||
|
cursorSnapshotsByMessage.delete(messageKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function clearBackgroundOutputConsumptionState(): void {
|
||||||
|
cursorSnapshotsByMessage.clear()
|
||||||
|
}
|
||||||
@@ -13,13 +13,21 @@ export type CursorMessage = {
|
|||||||
info?: MessageInfo
|
info?: MessageInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
interface CursorState {
|
export interface CursorState {
|
||||||
lastKey?: string
|
lastKey?: string
|
||||||
lastCount: number
|
lastCount: number
|
||||||
}
|
}
|
||||||
|
|
||||||
const sessionCursors = new Map<string, CursorState>()
|
const sessionCursors = new Map<string, CursorState>()
|
||||||
|
|
||||||
|
function cloneCursorState(state: CursorState | undefined): CursorState | undefined {
|
||||||
|
if (!state) return undefined
|
||||||
|
return {
|
||||||
|
lastKey: state.lastKey,
|
||||||
|
lastCount: state.lastCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function buildMessageKey(message: CursorMessage, index: number): string {
|
function buildMessageKey(message: CursorMessage, index: number): string {
|
||||||
const id = message.info?.id
|
const id = message.info?.id
|
||||||
if (id) return `id:${id}`
|
if (id) return `id:${id}`
|
||||||
@@ -83,3 +91,18 @@ export function resetMessageCursor(sessionID?: string): void {
|
|||||||
}
|
}
|
||||||
sessionCursors.clear()
|
sessionCursors.clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getMessageCursor(sessionID: string | undefined): CursorState | undefined {
|
||||||
|
if (!sessionID) return undefined
|
||||||
|
return cloneCursorState(sessionCursors.get(sessionID))
|
||||||
|
}
|
||||||
|
|
||||||
|
export function restoreMessageCursor(sessionID: string | undefined, cursor: CursorState | undefined): void {
|
||||||
|
if (!sessionID) return
|
||||||
|
if (!cursor) {
|
||||||
|
sessionCursors.delete(sessionID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionCursors.set(sessionID, cloneCursorState(cursor)!)
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,11 +10,13 @@ import { formatTaskResult } from "./task-result-format"
|
|||||||
import { formatTaskStatus } from "./task-status-format"
|
import { formatTaskStatus } from "./task-status-format"
|
||||||
|
|
||||||
import { getAgentDisplayName } from "../../shared/agent-display-names"
|
import { getAgentDisplayName } from "../../shared/agent-display-names"
|
||||||
|
import { recordBackgroundOutputConsumption } from "../../shared/background-output-consumption"
|
||||||
|
|
||||||
const SISYPHUS_JUNIOR_AGENT = getAgentDisplayName("sisyphus-junior")
|
const SISYPHUS_JUNIOR_AGENT = getAgentDisplayName("sisyphus-junior")
|
||||||
|
|
||||||
type ToolContextWithMetadata = {
|
type ToolContextWithMetadata = {
|
||||||
sessionID: string
|
sessionID: string
|
||||||
|
messageID?: string
|
||||||
metadata?: (input: { title?: string; metadata?: Record<string, unknown> }) => void
|
metadata?: (input: { title?: string; metadata?: Record<string, unknown> }) => void
|
||||||
callID?: string
|
callID?: string
|
||||||
callId?: string
|
callId?: string
|
||||||
@@ -139,6 +141,7 @@ export function createBackgroundOutput(manager: BackgroundOutputManager, client:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (resolvedTask.status === "completed") {
|
if (resolvedTask.status === "completed") {
|
||||||
|
recordBackgroundOutputConsumption(ctx.sessionID, ctx.messageID, resolvedTask.sessionID)
|
||||||
return await formatTaskResult(resolvedTask, client)
|
return await formatTaskResult(resolvedTask, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
129
src/tools/background-task/create-background-output.undo.test.ts
Normal file
129
src/tools/background-task/create-background-output.undo.test.ts
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
/// <reference types="bun-types" />
|
||||||
|
|
||||||
|
import { afterEach, describe, expect, test } from "bun:test"
|
||||||
|
import type { ToolContext } from "@opencode-ai/plugin/tool"
|
||||||
|
import type { BackgroundTask } from "../../features/background-agent"
|
||||||
|
import { createEventHandler } from "../../plugin/event"
|
||||||
|
import { clearBackgroundOutputConsumptionState } from "../../shared/background-output-consumption"
|
||||||
|
import { resetMessageCursor } from "../../shared/session-cursor"
|
||||||
|
import type { BackgroundOutputClient, BackgroundOutputManager } from "./clients"
|
||||||
|
import { createBackgroundOutput } from "./create-background-output"
|
||||||
|
|
||||||
|
const projectDir = "/Users/yeongyu/local-workspaces/oh-my-opencode"
|
||||||
|
|
||||||
|
const parentSessionID = "parent-session"
|
||||||
|
const taskSessionID = "task-session"
|
||||||
|
|
||||||
|
type ToolContextWithCallID = ToolContext & {
|
||||||
|
callID: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const baseContext = {
|
||||||
|
sessionID: parentSessionID,
|
||||||
|
agent: "test-agent",
|
||||||
|
directory: projectDir,
|
||||||
|
worktree: projectDir,
|
||||||
|
abort: new AbortController().signal,
|
||||||
|
metadata: () => {},
|
||||||
|
ask: async () => {},
|
||||||
|
callID: "call-1",
|
||||||
|
} as const satisfies Partial<ToolContextWithCallID>
|
||||||
|
|
||||||
|
function createTask(overrides: Partial<BackgroundTask> = {}): BackgroundTask {
|
||||||
|
return {
|
||||||
|
id: "task-1",
|
||||||
|
sessionID: taskSessionID,
|
||||||
|
parentSessionID,
|
||||||
|
parentMessageID: "msg-parent",
|
||||||
|
description: "background task",
|
||||||
|
prompt: "do work",
|
||||||
|
agent: "test-agent",
|
||||||
|
status: "completed",
|
||||||
|
...overrides,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function createMockClient(): BackgroundOutputClient {
|
||||||
|
return {
|
||||||
|
session: {
|
||||||
|
messages: async () => ({
|
||||||
|
data: [
|
||||||
|
{
|
||||||
|
id: "m1",
|
||||||
|
info: { role: "assistant", time: "2026-01-01T00:00:00Z" },
|
||||||
|
parts: [{ type: "text", text: "final result" }],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function createMockEventHandler() {
|
||||||
|
return createEventHandler({
|
||||||
|
ctx: {} as never,
|
||||||
|
pluginConfig: {} as never,
|
||||||
|
firstMessageVariantGate: {
|
||||||
|
markSessionCreated: () => {},
|
||||||
|
clear: () => {},
|
||||||
|
},
|
||||||
|
managers: {
|
||||||
|
skillMcpManager: {
|
||||||
|
disconnectSession: async () => {},
|
||||||
|
},
|
||||||
|
tmuxSessionManager: {
|
||||||
|
onSessionCreated: async () => {},
|
||||||
|
onSessionDeleted: async () => {},
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
hooks: {} as never,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
resetMessageCursor(taskSessionID)
|
||||||
|
clearBackgroundOutputConsumptionState()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe("createBackgroundOutput undo regression", () => {
|
||||||
|
test("#given consumed background output #when undo removes the parent message #then output can be consumed again", async () => {
|
||||||
|
// #given
|
||||||
|
const task = createTask()
|
||||||
|
const manager: BackgroundOutputManager = {
|
||||||
|
getTask: id => (id === task.id ? task : undefined),
|
||||||
|
}
|
||||||
|
const tool = createBackgroundOutput(manager, createMockClient())
|
||||||
|
const eventHandler = createMockEventHandler()
|
||||||
|
|
||||||
|
// #when
|
||||||
|
const firstOutput = await tool.execute(
|
||||||
|
{ task_id: task.id },
|
||||||
|
{ ...baseContext, messageID: "msg-result-1" } as ToolContextWithCallID
|
||||||
|
)
|
||||||
|
|
||||||
|
const secondOutput = await tool.execute(
|
||||||
|
{ task_id: task.id },
|
||||||
|
{ ...baseContext, callID: "call-2", messageID: "msg-result-2" } as ToolContextWithCallID
|
||||||
|
)
|
||||||
|
|
||||||
|
await eventHandler({
|
||||||
|
event: {
|
||||||
|
type: "message.removed",
|
||||||
|
properties: {
|
||||||
|
sessionID: parentSessionID,
|
||||||
|
messageID: "msg-result-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const thirdOutput = await tool.execute(
|
||||||
|
{ task_id: task.id },
|
||||||
|
{ ...baseContext, callID: "call-3", messageID: "msg-result-3" } as ToolContextWithCallID
|
||||||
|
)
|
||||||
|
|
||||||
|
// #then
|
||||||
|
expect(firstOutput).toContain("final result")
|
||||||
|
expect(secondOutput).toContain("No new output since last check")
|
||||||
|
expect(thirdOutput).toContain("final result")
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user