diff --git a/src/config/schema/hooks.ts b/src/config/schema/hooks.ts index 6b0219f72..543266c7a 100644 --- a/src/config/schema/hooks.ts +++ b/src/config/schema/hooks.ts @@ -41,6 +41,7 @@ export const HookNameSchema = z.enum([ "no-hephaestus-non-gpt", "start-work", "atlas", + "agent-switch", "unstable-agent-babysitter", "task-resume-info", "stop-continuation-guard", diff --git a/src/features/claude-code-session-state/index.ts b/src/features/claude-code-session-state/index.ts index 2b19e022c..3b05d7408 100644 --- a/src/features/claude-code-session-state/index.ts +++ b/src/features/claude-code-session-state/index.ts @@ -1 +1,2 @@ export * from "./state" +export * from "./switch-agent-state" diff --git a/src/features/claude-code-session-state/state.ts b/src/features/claude-code-session-state/state.ts index 60a8f8a84..646b1945b 100644 --- a/src/features/claude-code-session-state/state.ts +++ b/src/features/claude-code-session-state/state.ts @@ -1,3 +1,5 @@ +import { resetPendingSessionAgentSwitchesForTesting } from "./switch-agent-state" + export const subagentSessions = new Set() export const syncSubagentSessions = new Set() @@ -17,6 +19,7 @@ export function _resetForTesting(): void { subagentSessions.clear() syncSubagentSessions.clear() sessionAgentMap.clear() + resetPendingSessionAgentSwitchesForTesting() } const sessionAgentMap = new Map() diff --git a/src/features/claude-code-session-state/switch-agent-state.test.ts b/src/features/claude-code-session-state/switch-agent-state.test.ts new file mode 100644 index 000000000..f22b9a49b --- /dev/null +++ b/src/features/claude-code-session-state/switch-agent-state.test.ts @@ -0,0 +1,38 @@ +import { describe, expect, test, beforeEach } from "bun:test" +import { + clearPendingSessionAgentSwitch, + consumePendingSessionAgentSwitch, + getPendingSessionAgentSwitch, + resetPendingSessionAgentSwitchesForTesting, + setPendingSessionAgentSwitch, +} from "./switch-agent-state" + +describe("switch-agent-state", () => { + beforeEach(() => { + resetPendingSessionAgentSwitchesForTesting() + }) + + test("#given pending switch #when consuming #then consumes once and clears", () => { + // given + setPendingSessionAgentSwitch("ses-1", "explore") + + // when + const first = consumePendingSessionAgentSwitch("ses-1") + const second = consumePendingSessionAgentSwitch("ses-1") + + // then + expect(first?.agent).toBe("explore") + expect(second).toBeUndefined() + }) + + test("#given pending switch #when clearing #then state is removed", () => { + // given + setPendingSessionAgentSwitch("ses-1", "librarian") + + // when + clearPendingSessionAgentSwitch("ses-1") + + // then + expect(getPendingSessionAgentSwitch("ses-1")).toBeUndefined() + }) +}) diff --git a/src/features/claude-code-session-state/switch-agent-state.ts b/src/features/claude-code-session-state/switch-agent-state.ts new file mode 100644 index 000000000..5e6706205 --- /dev/null +++ b/src/features/claude-code-session-state/switch-agent-state.ts @@ -0,0 +1,37 @@ +type PendingAgentSwitch = { + agent: string + requestedAt: Date +} + +const pendingAgentSwitchBySession = new Map() + +export function setPendingSessionAgentSwitch(sessionID: string, agent: string): PendingAgentSwitch { + const pendingSwitch: PendingAgentSwitch = { + agent, + requestedAt: new Date(), + } + pendingAgentSwitchBySession.set(sessionID, pendingSwitch) + return pendingSwitch +} + +export function getPendingSessionAgentSwitch(sessionID: string): PendingAgentSwitch | undefined { + return pendingAgentSwitchBySession.get(sessionID) +} + +export function consumePendingSessionAgentSwitch(sessionID: string): PendingAgentSwitch | undefined { + const pendingSwitch = pendingAgentSwitchBySession.get(sessionID) + if (!pendingSwitch) { + return undefined + } + + pendingAgentSwitchBySession.delete(sessionID) + return pendingSwitch +} + +export function clearPendingSessionAgentSwitch(sessionID: string): void { + pendingAgentSwitchBySession.delete(sessionID) +} + +export function resetPendingSessionAgentSwitchesForTesting(): void { + pendingAgentSwitchBySession.clear() +} diff --git a/src/hooks/index.ts b/src/hooks/index.ts index 8121f5097..7ab6e4639 100644 --- a/src/hooks/index.ts +++ b/src/hooks/index.ts @@ -53,3 +53,4 @@ export { createJsonErrorRecoveryHook, JSON_ERROR_TOOL_EXCLUDE_LIST, JSON_ERROR_P export { createReadImageResizerHook } from "./read-image-resizer" export { createTodoDescriptionOverrideHook } from "./todo-description-override" export { createWebFetchRedirectGuardHook } from "./webfetch-redirect-guard" +export { createSwitchAgentHook } from "./switch-agent" diff --git a/src/hooks/switch-agent/hook.test.ts b/src/hooks/switch-agent/hook.test.ts new file mode 100644 index 000000000..dace6cd19 --- /dev/null +++ b/src/hooks/switch-agent/hook.test.ts @@ -0,0 +1,32 @@ +import { describe, expect, test, beforeEach } from "bun:test" +import { createSwitchAgentHook } from "./hook" +import { + _resetForTesting, + getSessionAgent, + setPendingSessionAgentSwitch, +} from "../../features/claude-code-session-state" + +describe("switch-agent hook", () => { + beforeEach(() => { + _resetForTesting() + }) + + test("#given pending switch #when chat.message hook runs #then output agent is switched and persisted", async () => { + // given + const hook = createSwitchAgentHook() + setPendingSessionAgentSwitch("ses-1", "explore") + const input = { sessionID: "ses-1", agent: "sisyphus" } + const output = { + message: {} as Record, + parts: [] as Array<{ type: string; text?: string }>, + } + + // when + await hook["chat.message"](input, output) + + // then + expect(input.agent).toBe("explore") + expect(output.message["agent"]).toBe("explore") + expect(getSessionAgent("ses-1")).toBe("explore") + }) +}) diff --git a/src/hooks/switch-agent/hook.ts b/src/hooks/switch-agent/hook.ts new file mode 100644 index 000000000..e9255f7f7 --- /dev/null +++ b/src/hooks/switch-agent/hook.ts @@ -0,0 +1,20 @@ +import type { ChatMessageHandlerOutput, ChatMessageInput } from "../../plugin/chat-message" +import { + consumePendingSessionAgentSwitch, + updateSessionAgent, +} from "../../features/claude-code-session-state" + +export function createSwitchAgentHook() { + return { + "chat.message": async (input: ChatMessageInput, output: ChatMessageHandlerOutput): Promise => { + const pendingSwitch = consumePendingSessionAgentSwitch(input.sessionID) + if (!pendingSwitch) { + return + } + + output.message["agent"] = pendingSwitch.agent + input.agent = pendingSwitch.agent + updateSessionAgent(input.sessionID, pendingSwitch.agent) + }, + } +} diff --git a/src/hooks/switch-agent/index.ts b/src/hooks/switch-agent/index.ts new file mode 100644 index 000000000..cae12eb9e --- /dev/null +++ b/src/hooks/switch-agent/index.ts @@ -0,0 +1 @@ +export { createSwitchAgentHook } from "./hook" diff --git a/src/plugin/chat-message.ts b/src/plugin/chat-message.ts index b7bfea33f..80107e184 100644 --- a/src/plugin/chat-message.ts +++ b/src/plugin/chat-message.ts @@ -157,6 +157,7 @@ export function createChatMessageHandler(args: { setSessionModel(input.sessionID, input.model) } await hooks.stopContinuationGuard?.["chat.message"]?.(input) + await hooks.switchAgentHook?.["chat.message"]?.(input, output) await hooks.backgroundNotificationHook?.["chat.message"]?.(input, output) await hooks.runtimeFallback?.["chat.message"]?.(input, output) await hooks.keywordDetector?.["chat.message"]?.(input, output) diff --git a/src/plugin/event.ts b/src/plugin/event.ts index 126d6e819..61db9583d 100644 --- a/src/plugin/event.ts +++ b/src/plugin/event.ts @@ -5,6 +5,7 @@ import { clearSessionAgent, getMainSessionID, getSessionAgent, + clearPendingSessionAgentSwitch, setMainSession, subagentSessions, syncSubagentSessions, @@ -323,6 +324,7 @@ export function createEventHandler(args: { if (sessionInfo?.id) { const wasSyncSubagentSession = syncSubagentSessions.has(sessionInfo.id); clearSessionAgent(sessionInfo.id); + clearPendingSessionAgentSwitch(sessionInfo.id); lastHandledModelErrorMessageID.delete(sessionInfo.id); lastHandledRetryStatusKey.delete(sessionInfo.id); lastKnownModelBySession.delete(sessionInfo.id); diff --git a/src/plugin/hooks/create-continuation-hooks.ts b/src/plugin/hooks/create-continuation-hooks.ts index c44247af9..bb9bb6d10 100644 --- a/src/plugin/hooks/create-continuation-hooks.ts +++ b/src/plugin/hooks/create-continuation-hooks.ts @@ -9,6 +9,7 @@ import { createCompactionContextInjector, createCompactionTodoPreserverHook, createAtlasHook, + createSwitchAgentHook, } from "../../hooks" import { safeCreateHook } from "../../shared/safe-create-hook" import { createUnstableAgentBabysitter } from "../unstable-agent-babysitter" @@ -21,6 +22,7 @@ export type ContinuationHooks = { unstableAgentBabysitter: ReturnType | null backgroundNotificationHook: ReturnType | null atlasHook: ReturnType | null + switchAgentHook: ReturnType | null } type SessionRecovery = { @@ -116,6 +118,10 @@ export function createContinuationHooks(args: { })) : null + const switchAgentHook = isHookEnabled("agent-switch") + ? safeHook("agent-switch", () => createSwitchAgentHook()) + : null + return { stopContinuationGuard, compactionContextInjector, @@ -124,5 +130,6 @@ export function createContinuationHooks(args: { unstableAgentBabysitter, backgroundNotificationHook, atlasHook, + switchAgentHook, } } diff --git a/src/plugin/tool-registry.ts b/src/plugin/tool-registry.ts index 82a9a2ceb..b6df5bb6a 100644 --- a/src/plugin/tool-registry.ts +++ b/src/plugin/tool-registry.ts @@ -25,6 +25,7 @@ import { createTaskList, createTaskUpdateTool, createHashlineEditTool, + createSwitchAgentTool, } from "../tools" import { getMainSessionID } from "../features/claude-code-session-state" import { filterDisabledTools } from "../shared/disabled-tools" @@ -144,6 +145,7 @@ export function createToolRegistry(args: { interactive_bash, ...taskToolsRecord, ...hashlineToolsRecord, + switch_agent: createSwitchAgentTool(ctx.client, pluginConfig.disabled_agents ?? []), } for (const toolDefinition of Object.values(allTools)) { diff --git a/src/tools/background-task/constants.ts b/src/tools/background-task/constants.ts index 0a1eeedea..5cfcdca41 100644 --- a/src/tools/background-task/constants.ts +++ b/src/tools/background-task/constants.ts @@ -5,3 +5,5 @@ Use \`background_output\` to get results. Prompts MUST be in English.` export const BACKGROUND_OUTPUT_DESCRIPTION = `Get output from background task. Use full_session=true to fetch session messages with filters. System notifies on completion, so block=true rarely needed. - Timeout values are in milliseconds (ms), NOT seconds.` export const BACKGROUND_CANCEL_DESCRIPTION = `Cancel running background task(s). Use all=true to cancel ALL before final answer.` + +export const BACKGROUND_WAIT_DESCRIPTION = `Wait on grouped background tasks with all/any/quorum semantics. Returns structured grouped status for orchestration.` diff --git a/src/tools/background-task/create-background-wait.test.ts b/src/tools/background-task/create-background-wait.test.ts new file mode 100644 index 000000000..48fc0df61 --- /dev/null +++ b/src/tools/background-task/create-background-wait.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, test } from "bun:test" +import { createBackgroundWait } from "./create-background-wait" +import type { BackgroundOutputManager, BackgroundWaitResult } from "./types" +import type { BackgroundTask } from "../../features/background-agent" + +function parseResult(result: string): BackgroundWaitResult { + return JSON.parse(result) as BackgroundWaitResult +} + +function createTask(overrides: Partial): BackgroundTask { + return { + id: "bg-1", + parentSessionID: "main-1", + parentMessageID: "msg-1", + description: "task", + prompt: "prompt", + agent: "explore", + status: "running", + ...overrides, + } +} + +describe("background_wait", () => { + test("#given grouped task IDs #when block=false #then returns grouped structured status", async () => { + // given + const runningTask = createTask({ id: "bg-running", status: "running" }) + const completedTask = createTask({ id: "bg-done", status: "completed" }) + const manager: BackgroundOutputManager = { + getTask: (taskID: string) => { + if (taskID === runningTask.id) return runningTask + if (taskID === completedTask.id) return completedTask + return undefined + }, + } + const tool = createBackgroundWait(manager) + + // when + const output = await tool.execute({ + task_ids: [runningTask.id, completedTask.id, "bg-missing"], + block: false, + }, {} as never) + const parsed = parseResult(output) + + // then + expect(parsed.summary.total).toBe(3) + expect(parsed.summary.by_status.running).toBe(1) + expect(parsed.summary.by_status.completed).toBe(1) + expect(parsed.summary.by_status.not_found).toBe(1) + expect(parsed.grouped.completed).toContain(completedTask.id) + expect(parsed.grouped.not_found).toContain("bg-missing") + }) + + test("#given race mode #when block=true and one task reaches terminal #then returns quorum_reached", async () => { + // given + const task = createTask({ id: "bg-race", status: "running" }) + let readCount = 0 + const manager: BackgroundOutputManager = { + getTask: (taskID: string) => { + if (taskID !== task.id) return undefined + readCount += 1 + if (readCount >= 2) { + task.status = "completed" + } + return task + }, + } + const tool = createBackgroundWait(manager) + + // when + const output = await tool.execute({ + task_ids: [task.id], + mode: "any", + block: true, + timeout: 500, + poll_interval: 20, + }, {} as never) + const parsed = parseResult(output) + + // then + expect(parsed.done).toBe(true) + expect(parsed.reason).toBe("quorum_reached") + expect(parsed.quorum.target).toBe(1) + expect(parsed.quorum.reached).toBe(1) + }) + + test("#given unmet quorum #when block=true until timeout #then returns timeout status", async () => { + // given + const runningTask = createTask({ id: "bg-still-running", status: "running" }) + const manager: BackgroundOutputManager = { + getTask: (taskID: string) => (taskID === runningTask.id ? runningTask : undefined), + } + const tool = createBackgroundWait(manager) + + // when + const output = await tool.execute({ + task_ids: [runningTask.id], + quorum: 1, + block: true, + timeout: 120, + poll_interval: 20, + }, {} as never) + const parsed = parseResult(output) + + // then + expect(parsed.done).toBe(false) + expect(parsed.reason).toBe("timeout") + expect(parsed.summary.by_status.running).toBe(1) + expect(parsed.quorum.reached).toBe(0) + }) +}) diff --git a/src/tools/background-task/create-background-wait.ts b/src/tools/background-task/create-background-wait.ts new file mode 100644 index 000000000..cc7511831 --- /dev/null +++ b/src/tools/background-task/create-background-wait.ts @@ -0,0 +1,158 @@ +import { tool, type ToolDefinition } from "@opencode-ai/plugin" +import type { BackgroundTask } from "../../features/background-agent" +import { BACKGROUND_WAIT_DESCRIPTION } from "./constants" +import { delay } from "./delay" +import type { BackgroundOutputManager, BackgroundWaitArgs, BackgroundWaitResult } from "./types" + +type WaitTaskStatus = "pending" | "running" | "completed" | "error" | "cancelled" | "interrupt" | "not_found" + +const TERMINAL_STATUSES: ReadonlySet = new Set([ + "completed", + "error", + "cancelled", + "interrupt", +]) + +function isTerminalStatus(status: BackgroundTask["status"]): boolean { + return TERMINAL_STATUSES.has(status) +} + +function toValidTaskIDs(taskIDs: string[]): string[] { + const uniqueTaskIDs = new Set() + for (const taskID of taskIDs) { + const normalized = taskID.trim() + if (normalized) { + uniqueTaskIDs.add(normalized) + } + } + return [...uniqueTaskIDs] +} + +export function createBackgroundWait(manager: BackgroundOutputManager): ToolDefinition { + return tool({ + description: BACKGROUND_WAIT_DESCRIPTION, + args: { + task_ids: tool.schema.array(tool.schema.string()).describe("Task IDs to inspect as a group"), + mode: tool.schema.string().optional().describe("all (default) waits for all, any returns on first quorum/race completion"), + quorum: tool.schema.number().optional().describe("Optional terminal-task quorum target"), + block: tool.schema.boolean().optional().describe("Wait for quorum/race completion (default: false)"), + timeout: tool.schema.number().optional().describe("Max wait time in ms when block=true (default: 60000, max: 600000)"), + poll_interval: tool.schema.number().optional().describe("Polling interval in ms when block=true (default: 1000, min: 100)"), + }, + async execute(args: BackgroundWaitArgs) { + const taskIDs = toValidTaskIDs(args.task_ids) + if (taskIDs.length === 0) { + return "Error: task_ids must contain at least one task ID." + } + + const mode = args.mode === "any" ? "any" : args.mode === undefined || args.mode === "all" ? "all" : null + if (!mode) { + return `Error: invalid mode \"${args.mode}\". Use \"all\" or \"any\".` + } + + if (args.quorum !== undefined && (!Number.isInteger(args.quorum) || args.quorum < 1)) { + return "Error: quorum must be a positive integer." + } + + const timeoutMs = Math.min(args.timeout ?? 60000, 600000) + const pollIntervalMs = Math.max(args.poll_interval ?? 1000, 100) + const block = args.block === true + const quorumTarget = Math.min(args.quorum ?? (mode === "any" ? 1 : taskIDs.length), taskIDs.length) + const startTime = Date.now() + + const buildSnapshot = (): BackgroundWaitResult => { + const byStatus: Record = { + pending: 0, + running: 0, + completed: 0, + error: 0, + cancelled: 0, + interrupt: 0, + not_found: 0, + } + + const tasks = taskIDs.map((taskID) => { + const task = manager.getTask(taskID) + if (!task) { + byStatus.not_found += 1 + return { + task_id: taskID, + found: false, + status: "not_found" as const, + } + } + + byStatus[task.status] += 1 + return { + task_id: task.id, + found: true, + status: task.status, + agent: task.agent, + description: task.description, + session_id: task.sessionID, + started_at: task.startedAt?.toISOString(), + completed_at: task.completedAt?.toISOString(), + } + }) + + const terminalCount = tasks.filter((task) => task.found && isTerminalStatus(task.status as BackgroundTask["status"])) + .length + const activeCount = tasks.filter((task) => task.status === "pending" || task.status === "running").length + const quorumReached = terminalCount >= quorumTarget + + return { + mode, + block, + timeout_ms: timeoutMs, + waited_ms: Date.now() - startTime, + done: quorumReached, + reason: block ? "waiting" : "non_blocking", + quorum: { + target: quorumTarget, + reached: terminalCount, + remaining: Math.max(quorumTarget - terminalCount, 0), + progress: quorumTarget === 0 ? 1 : terminalCount / quorumTarget, + }, + summary: { + total: tasks.length, + terminal: terminalCount, + active: activeCount, + by_status: byStatus, + }, + grouped: { + pending: tasks.filter((task) => task.status === "pending").map((task) => task.task_id), + running: tasks.filter((task) => task.status === "running").map((task) => task.task_id), + completed: tasks.filter((task) => task.status === "completed").map((task) => task.task_id), + error: tasks.filter((task) => task.status === "error").map((task) => task.task_id), + cancelled: tasks.filter((task) => task.status === "cancelled").map((task) => task.task_id), + interrupt: tasks.filter((task) => task.status === "interrupt").map((task) => task.task_id), + not_found: tasks.filter((task) => task.status === "not_found").map((task) => task.task_id), + }, + tasks: tasks.map((task) => ({ + ...task, + status: task.status as WaitTaskStatus, + })), + } + } + + let snapshot = buildSnapshot() + if (!block) { + return JSON.stringify(snapshot, null, 2) + } + + while (!snapshot.done && Date.now() - startTime < timeoutMs) { + await delay(pollIntervalMs) + snapshot = buildSnapshot() + } + + const finalSnapshot: BackgroundWaitResult = { + ...snapshot, + waited_ms: Date.now() - startTime, + done: snapshot.done, + reason: snapshot.done ? "quorum_reached" : "timeout", + } + + return JSON.stringify(finalSnapshot, null, 2) + }, + }) +} diff --git a/src/tools/background-task/index.ts b/src/tools/background-task/index.ts index 22324f8dd..8c9113acc 100644 --- a/src/tools/background-task/index.ts +++ b/src/tools/background-task/index.ts @@ -2,6 +2,7 @@ export { createBackgroundTask, createBackgroundOutput, createBackgroundCancel, + createBackgroundWait, } from "./tools" export type * from "./types" diff --git a/src/tools/background-task/tools.ts b/src/tools/background-task/tools.ts index ce30adb91..8fb0b438b 100644 --- a/src/tools/background-task/tools.ts +++ b/src/tools/background-task/tools.ts @@ -9,3 +9,4 @@ export type { export { createBackgroundTask } from "./create-background-task" export { createBackgroundOutput } from "./create-background-output" export { createBackgroundCancel } from "./create-background-cancel" +export { createBackgroundWait } from "./create-background-wait" diff --git a/src/tools/background-task/types.ts b/src/tools/background-task/types.ts index eafd87c07..161bcb422 100644 --- a/src/tools/background-task/types.ts +++ b/src/tools/background-task/types.ts @@ -21,6 +21,49 @@ export interface BackgroundCancelArgs { all?: boolean } +export interface BackgroundWaitArgs { + task_ids: string[] + mode?: "all" | "any" + quorum?: number + block?: boolean + timeout?: number + poll_interval?: number +} + +export type BackgroundWaitTaskSnapshot = { + task_id: string + found: boolean + status: "pending" | "running" | "completed" | "error" | "cancelled" | "interrupt" | "not_found" + agent?: string + description?: string + session_id?: string + started_at?: string + completed_at?: string +} + +export type BackgroundWaitResult = { + mode: "all" | "any" + block: boolean + timeout_ms: number + waited_ms: number + done: boolean + reason: "non_blocking" | "waiting" | "quorum_reached" | "timeout" + quorum: { + target: number + reached: number + remaining: number + progress: number + } + summary: { + total: number + terminal: number + active: number + by_status: Record + } + grouped: Record + tasks: BackgroundWaitTaskSnapshot[] +} + export type BackgroundOutputMessage = { info?: { role?: string; time?: string | { created?: number }; agent?: string } parts?: Array<{ diff --git a/src/tools/index.ts b/src/tools/index.ts index 9d9bd9c04..e55601ed5 100644 --- a/src/tools/index.ts +++ b/src/tools/index.ts @@ -21,10 +21,12 @@ export { sessionExists } from "./session-manager/storage" export { interactive_bash, startBackgroundCheck as startTmuxCheck } from "./interactive-bash" export { createSkillMcpTool } from "./skill-mcp" +export { createSwitchAgentTool } from "./switch-agent" import { createBackgroundOutput, createBackgroundCancel, + createBackgroundWait, type BackgroundOutputManager, type BackgroundCancelClient, } from "./background-task" @@ -51,6 +53,7 @@ export function createBackgroundTools(manager: BackgroundManager, client: Openco return { background_output: createBackgroundOutput(outputManager, client), background_cancel: createBackgroundCancel(manager, cancelClient), + background_wait: createBackgroundWait(outputManager), } } diff --git a/src/tools/switch-agent/constants.ts b/src/tools/switch-agent/constants.ts new file mode 100644 index 000000000..bde139a75 --- /dev/null +++ b/src/tools/switch-agent/constants.ts @@ -0,0 +1 @@ +export const SWITCH_AGENT_DESCRIPTION = "Queue an agent switch for the current or target session. Switch is applied on next chat.message through hook flow." diff --git a/src/tools/switch-agent/index.ts b/src/tools/switch-agent/index.ts new file mode 100644 index 000000000..3f1703bc2 --- /dev/null +++ b/src/tools/switch-agent/index.ts @@ -0,0 +1,2 @@ +export { createSwitchAgentTool } from "./tools" +export type { SwitchAgentArgs } from "./types" diff --git a/src/tools/switch-agent/tools.test.ts b/src/tools/switch-agent/tools.test.ts new file mode 100644 index 000000000..e91215339 --- /dev/null +++ b/src/tools/switch-agent/tools.test.ts @@ -0,0 +1,79 @@ +import { describe, expect, test, beforeEach } from "bun:test" +import { createSwitchAgentTool } from "./tools" +import { + _resetForTesting, + getPendingSessionAgentSwitch, +} from "../../features/claude-code-session-state" + +describe("switch_agent tool", () => { + beforeEach(() => { + _resetForTesting() + }) + + test("#given empty agent #when executing #then returns validation error", async () => { + // given + const client = { + app: { + agents: async () => ({ data: [{ name: "sisyphus" }] }), + }, + } as unknown as Parameters[0] + const tool = createSwitchAgentTool(client) + + // when + const output = await tool.execute({ agent: " " }, { sessionID: "ses-1" } as never) + + // then + expect(output).toContain("agent is required") + }) + + test("#given unknown agent #when executing #then returns invalid switch error", async () => { + // given + const client = { + app: { + agents: async () => ({ data: [{ name: "sisyphus" }, { name: "explore" }] }), + }, + } as unknown as Parameters[0] + const tool = createSwitchAgentTool(client) + + // when + const output = await tool.execute({ agent: "ghost" }, { sessionID: "ses-1" } as never) + + // then + expect(output).toContain("unknown agent") + expect(getPendingSessionAgentSwitch("ses-1")).toBeUndefined() + }) + + test("#given known but disabled agent #when executing #then returns disabled error", async () => { + // given + const client = { + app: { + agents: async () => ({ data: [{ name: "explore" }] }), + }, + } as unknown as Parameters[0] + const tool = createSwitchAgentTool(client, ["explore"]) + + // when + const output = await tool.execute({ agent: "explore" }, { sessionID: "ses-1" } as never) + + // then + expect(output).toContain("disabled") + expect(getPendingSessionAgentSwitch("ses-1")).toBeUndefined() + }) + + test("#given known enabled agent #when executing #then queues pending switch", async () => { + // given + const client = { + app: { + agents: async () => ({ data: [{ name: "explore" }, { name: "Athena" }] }), + }, + } as unknown as Parameters[0] + const tool = createSwitchAgentTool(client) + + // when + const output = await tool.execute({ agent: "explore" }, { sessionID: "ses-1" } as never) + + // then + expect(output).toContain("Agent switch queued") + expect(getPendingSessionAgentSwitch("ses-1")?.agent).toBe("explore") + }) +}) diff --git a/src/tools/switch-agent/tools.ts b/src/tools/switch-agent/tools.ts new file mode 100644 index 000000000..3451e58d1 --- /dev/null +++ b/src/tools/switch-agent/tools.ts @@ -0,0 +1,57 @@ +import { tool, type PluginInput, type ToolDefinition } from "@opencode-ai/plugin" +import { setPendingSessionAgentSwitch } from "../../features/claude-code-session-state" +import { normalizeSDKResponse } from "../../shared" +import type { SwitchAgentArgs } from "./types" +import { SWITCH_AGENT_DESCRIPTION } from "./constants" + +type SwitchableAgent = { + name: string + mode?: "subagent" | "primary" | "all" +} + +export function createSwitchAgentTool(client: PluginInput["client"], disabledAgents: string[] = []): ToolDefinition { + return tool({ + description: SWITCH_AGENT_DESCRIPTION, + args: { + agent: tool.schema.string().describe("Agent name to switch to"), + session_id: tool.schema.string().optional().describe("Session ID to switch. Defaults to current session"), + }, + async execute(args: SwitchAgentArgs, toolContext) { + const targetSessionID = args.session_id ?? toolContext.sessionID + const requestedAgent = args.agent?.trim() + + if (!requestedAgent) { + return "Error: agent is required." + } + + try { + const agentsResponse = await client.app.agents() + const agents = normalizeSDKResponse(agentsResponse, [] as SwitchableAgent[], { + preferResponseOnMissingData: true, + }) + const matchedAgent = agents.find((agent) => agent.name.toLowerCase() === requestedAgent.toLowerCase()) + + if (!matchedAgent) { + const availableAgents = agents.map((agent) => agent.name).sort() + return `Error: unknown agent \"${requestedAgent}\". Available agents: ${availableAgents.join(", ")}` + } + + if (disabledAgents.some((disabledAgent) => disabledAgent.toLowerCase() === matchedAgent.name.toLowerCase())) { + return `Error: agent \"${matchedAgent.name}\" is disabled via disabled_agents configuration.` + } + + const pendingSwitch = setPendingSessionAgentSwitch(targetSessionID, matchedAgent.name) + + return [ + "Agent switch queued.", + `Session ID: ${targetSessionID}`, + `Next agent: ${pendingSwitch.agent}`, + `Requested at: ${pendingSwitch.requestedAt.toISOString()}`, + "The switch will be applied by hook flow on the next chat.message turn.", + ].join("\n") + } catch (error) { + return `Error: failed to queue agent switch: ${error instanceof Error ? error.message : String(error)}` + } + }, + }) +} diff --git a/src/tools/switch-agent/types.ts b/src/tools/switch-agent/types.ts new file mode 100644 index 000000000..83415c5cb --- /dev/null +++ b/src/tools/switch-agent/types.ts @@ -0,0 +1,4 @@ +export type SwitchAgentArgs = { + agent: string + session_id?: string +}