feat(tools): add switch agent background workflow
This commit is contained in:
@@ -41,6 +41,7 @@ export const HookNameSchema = z.enum([
|
||||
"no-hephaestus-non-gpt",
|
||||
"start-work",
|
||||
"atlas",
|
||||
"agent-switch",
|
||||
"unstable-agent-babysitter",
|
||||
"task-resume-info",
|
||||
"stop-continuation-guard",
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
export * from "./state"
|
||||
export * from "./switch-agent-state"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { resetPendingSessionAgentSwitchesForTesting } from "./switch-agent-state"
|
||||
|
||||
export const subagentSessions = new Set<string>()
|
||||
export const syncSubagentSessions = new Set<string>()
|
||||
|
||||
@@ -17,6 +19,7 @@ export function _resetForTesting(): void {
|
||||
subagentSessions.clear()
|
||||
syncSubagentSessions.clear()
|
||||
sessionAgentMap.clear()
|
||||
resetPendingSessionAgentSwitchesForTesting()
|
||||
}
|
||||
|
||||
const sessionAgentMap = new Map<string, string>()
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import { describe, expect, test, beforeEach } from "bun:test"
|
||||
import {
|
||||
clearPendingSessionAgentSwitch,
|
||||
consumePendingSessionAgentSwitch,
|
||||
getPendingSessionAgentSwitch,
|
||||
resetPendingSessionAgentSwitchesForTesting,
|
||||
setPendingSessionAgentSwitch,
|
||||
} from "./switch-agent-state"
|
||||
|
||||
describe("switch-agent-state", () => {
|
||||
beforeEach(() => {
|
||||
resetPendingSessionAgentSwitchesForTesting()
|
||||
})
|
||||
|
||||
test("#given pending switch #when consuming #then consumes once and clears", () => {
|
||||
// given
|
||||
setPendingSessionAgentSwitch("ses-1", "explore")
|
||||
|
||||
// when
|
||||
const first = consumePendingSessionAgentSwitch("ses-1")
|
||||
const second = consumePendingSessionAgentSwitch("ses-1")
|
||||
|
||||
// then
|
||||
expect(first?.agent).toBe("explore")
|
||||
expect(second).toBeUndefined()
|
||||
})
|
||||
|
||||
test("#given pending switch #when clearing #then state is removed", () => {
|
||||
// given
|
||||
setPendingSessionAgentSwitch("ses-1", "librarian")
|
||||
|
||||
// when
|
||||
clearPendingSessionAgentSwitch("ses-1")
|
||||
|
||||
// then
|
||||
expect(getPendingSessionAgentSwitch("ses-1")).toBeUndefined()
|
||||
})
|
||||
})
|
||||
37
src/features/claude-code-session-state/switch-agent-state.ts
Normal file
37
src/features/claude-code-session-state/switch-agent-state.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
type PendingAgentSwitch = {
|
||||
agent: string
|
||||
requestedAt: Date
|
||||
}
|
||||
|
||||
const pendingAgentSwitchBySession = new Map<string, PendingAgentSwitch>()
|
||||
|
||||
export function setPendingSessionAgentSwitch(sessionID: string, agent: string): PendingAgentSwitch {
|
||||
const pendingSwitch: PendingAgentSwitch = {
|
||||
agent,
|
||||
requestedAt: new Date(),
|
||||
}
|
||||
pendingAgentSwitchBySession.set(sessionID, pendingSwitch)
|
||||
return pendingSwitch
|
||||
}
|
||||
|
||||
export function getPendingSessionAgentSwitch(sessionID: string): PendingAgentSwitch | undefined {
|
||||
return pendingAgentSwitchBySession.get(sessionID)
|
||||
}
|
||||
|
||||
export function consumePendingSessionAgentSwitch(sessionID: string): PendingAgentSwitch | undefined {
|
||||
const pendingSwitch = pendingAgentSwitchBySession.get(sessionID)
|
||||
if (!pendingSwitch) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
pendingAgentSwitchBySession.delete(sessionID)
|
||||
return pendingSwitch
|
||||
}
|
||||
|
||||
export function clearPendingSessionAgentSwitch(sessionID: string): void {
|
||||
pendingAgentSwitchBySession.delete(sessionID)
|
||||
}
|
||||
|
||||
export function resetPendingSessionAgentSwitchesForTesting(): void {
|
||||
pendingAgentSwitchBySession.clear()
|
||||
}
|
||||
@@ -53,3 +53,4 @@ export { createJsonErrorRecoveryHook, JSON_ERROR_TOOL_EXCLUDE_LIST, JSON_ERROR_P
|
||||
export { createReadImageResizerHook } from "./read-image-resizer"
|
||||
export { createTodoDescriptionOverrideHook } from "./todo-description-override"
|
||||
export { createWebFetchRedirectGuardHook } from "./webfetch-redirect-guard"
|
||||
export { createSwitchAgentHook } from "./switch-agent"
|
||||
|
||||
32
src/hooks/switch-agent/hook.test.ts
Normal file
32
src/hooks/switch-agent/hook.test.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { describe, expect, test, beforeEach } from "bun:test"
|
||||
import { createSwitchAgentHook } from "./hook"
|
||||
import {
|
||||
_resetForTesting,
|
||||
getSessionAgent,
|
||||
setPendingSessionAgentSwitch,
|
||||
} from "../../features/claude-code-session-state"
|
||||
|
||||
describe("switch-agent hook", () => {
|
||||
beforeEach(() => {
|
||||
_resetForTesting()
|
||||
})
|
||||
|
||||
test("#given pending switch #when chat.message hook runs #then output agent is switched and persisted", async () => {
|
||||
// given
|
||||
const hook = createSwitchAgentHook()
|
||||
setPendingSessionAgentSwitch("ses-1", "explore")
|
||||
const input = { sessionID: "ses-1", agent: "sisyphus" }
|
||||
const output = {
|
||||
message: {} as Record<string, unknown>,
|
||||
parts: [] as Array<{ type: string; text?: string }>,
|
||||
}
|
||||
|
||||
// when
|
||||
await hook["chat.message"](input, output)
|
||||
|
||||
// then
|
||||
expect(input.agent).toBe("explore")
|
||||
expect(output.message["agent"]).toBe("explore")
|
||||
expect(getSessionAgent("ses-1")).toBe("explore")
|
||||
})
|
||||
})
|
||||
20
src/hooks/switch-agent/hook.ts
Normal file
20
src/hooks/switch-agent/hook.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { ChatMessageHandlerOutput, ChatMessageInput } from "../../plugin/chat-message"
|
||||
import {
|
||||
consumePendingSessionAgentSwitch,
|
||||
updateSessionAgent,
|
||||
} from "../../features/claude-code-session-state"
|
||||
|
||||
export function createSwitchAgentHook() {
|
||||
return {
|
||||
"chat.message": async (input: ChatMessageInput, output: ChatMessageHandlerOutput): Promise<void> => {
|
||||
const pendingSwitch = consumePendingSessionAgentSwitch(input.sessionID)
|
||||
if (!pendingSwitch) {
|
||||
return
|
||||
}
|
||||
|
||||
output.message["agent"] = pendingSwitch.agent
|
||||
input.agent = pendingSwitch.agent
|
||||
updateSessionAgent(input.sessionID, pendingSwitch.agent)
|
||||
},
|
||||
}
|
||||
}
|
||||
1
src/hooks/switch-agent/index.ts
Normal file
1
src/hooks/switch-agent/index.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { createSwitchAgentHook } from "./hook"
|
||||
@@ -157,6 +157,7 @@ export function createChatMessageHandler(args: {
|
||||
setSessionModel(input.sessionID, input.model)
|
||||
}
|
||||
await hooks.stopContinuationGuard?.["chat.message"]?.(input)
|
||||
await hooks.switchAgentHook?.["chat.message"]?.(input, output)
|
||||
await hooks.backgroundNotificationHook?.["chat.message"]?.(input, output)
|
||||
await hooks.runtimeFallback?.["chat.message"]?.(input, output)
|
||||
await hooks.keywordDetector?.["chat.message"]?.(input, output)
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
clearSessionAgent,
|
||||
getMainSessionID,
|
||||
getSessionAgent,
|
||||
clearPendingSessionAgentSwitch,
|
||||
setMainSession,
|
||||
subagentSessions,
|
||||
syncSubagentSessions,
|
||||
@@ -323,6 +324,7 @@ export function createEventHandler(args: {
|
||||
if (sessionInfo?.id) {
|
||||
const wasSyncSubagentSession = syncSubagentSessions.has(sessionInfo.id);
|
||||
clearSessionAgent(sessionInfo.id);
|
||||
clearPendingSessionAgentSwitch(sessionInfo.id);
|
||||
lastHandledModelErrorMessageID.delete(sessionInfo.id);
|
||||
lastHandledRetryStatusKey.delete(sessionInfo.id);
|
||||
lastKnownModelBySession.delete(sessionInfo.id);
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
createCompactionContextInjector,
|
||||
createCompactionTodoPreserverHook,
|
||||
createAtlasHook,
|
||||
createSwitchAgentHook,
|
||||
} from "../../hooks"
|
||||
import { safeCreateHook } from "../../shared/safe-create-hook"
|
||||
import { createUnstableAgentBabysitter } from "../unstable-agent-babysitter"
|
||||
@@ -21,6 +22,7 @@ export type ContinuationHooks = {
|
||||
unstableAgentBabysitter: ReturnType<typeof createUnstableAgentBabysitter> | null
|
||||
backgroundNotificationHook: ReturnType<typeof createBackgroundNotificationHook> | null
|
||||
atlasHook: ReturnType<typeof createAtlasHook> | null
|
||||
switchAgentHook: ReturnType<typeof createSwitchAgentHook> | null
|
||||
}
|
||||
|
||||
type SessionRecovery = {
|
||||
@@ -116,6 +118,10 @@ export function createContinuationHooks(args: {
|
||||
}))
|
||||
: null
|
||||
|
||||
const switchAgentHook = isHookEnabled("agent-switch")
|
||||
? safeHook("agent-switch", () => createSwitchAgentHook())
|
||||
: null
|
||||
|
||||
return {
|
||||
stopContinuationGuard,
|
||||
compactionContextInjector,
|
||||
@@ -124,5 +130,6 @@ export function createContinuationHooks(args: {
|
||||
unstableAgentBabysitter,
|
||||
backgroundNotificationHook,
|
||||
atlasHook,
|
||||
switchAgentHook,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
createTaskList,
|
||||
createTaskUpdateTool,
|
||||
createHashlineEditTool,
|
||||
createSwitchAgentTool,
|
||||
} from "../tools"
|
||||
import { getMainSessionID } from "../features/claude-code-session-state"
|
||||
import { filterDisabledTools } from "../shared/disabled-tools"
|
||||
@@ -144,6 +145,7 @@ export function createToolRegistry(args: {
|
||||
interactive_bash,
|
||||
...taskToolsRecord,
|
||||
...hashlineToolsRecord,
|
||||
switch_agent: createSwitchAgentTool(ctx.client, pluginConfig.disabled_agents ?? []),
|
||||
}
|
||||
|
||||
for (const toolDefinition of Object.values(allTools)) {
|
||||
|
||||
@@ -5,3 +5,5 @@ Use \`background_output\` to get results. Prompts MUST be in English.`
|
||||
export const BACKGROUND_OUTPUT_DESCRIPTION = `Get output from background task. Use full_session=true to fetch session messages with filters. System notifies on completion, so block=true rarely needed. - Timeout values are in milliseconds (ms), NOT seconds.`
|
||||
|
||||
export const BACKGROUND_CANCEL_DESCRIPTION = `Cancel running background task(s). Use all=true to cancel ALL before final answer.`
|
||||
|
||||
export const BACKGROUND_WAIT_DESCRIPTION = `Wait on grouped background tasks with all/any/quorum semantics. Returns structured grouped status for orchestration.`
|
||||
|
||||
110
src/tools/background-task/create-background-wait.test.ts
Normal file
110
src/tools/background-task/create-background-wait.test.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
import { describe, expect, test } from "bun:test"
|
||||
import { createBackgroundWait } from "./create-background-wait"
|
||||
import type { BackgroundOutputManager, BackgroundWaitResult } from "./types"
|
||||
import type { BackgroundTask } from "../../features/background-agent"
|
||||
|
||||
function parseResult(result: string): BackgroundWaitResult {
|
||||
return JSON.parse(result) as BackgroundWaitResult
|
||||
}
|
||||
|
||||
function createTask(overrides: Partial<BackgroundTask>): BackgroundTask {
|
||||
return {
|
||||
id: "bg-1",
|
||||
parentSessionID: "main-1",
|
||||
parentMessageID: "msg-1",
|
||||
description: "task",
|
||||
prompt: "prompt",
|
||||
agent: "explore",
|
||||
status: "running",
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
describe("background_wait", () => {
|
||||
test("#given grouped task IDs #when block=false #then returns grouped structured status", async () => {
|
||||
// given
|
||||
const runningTask = createTask({ id: "bg-running", status: "running" })
|
||||
const completedTask = createTask({ id: "bg-done", status: "completed" })
|
||||
const manager: BackgroundOutputManager = {
|
||||
getTask: (taskID: string) => {
|
||||
if (taskID === runningTask.id) return runningTask
|
||||
if (taskID === completedTask.id) return completedTask
|
||||
return undefined
|
||||
},
|
||||
}
|
||||
const tool = createBackgroundWait(manager)
|
||||
|
||||
// when
|
||||
const output = await tool.execute({
|
||||
task_ids: [runningTask.id, completedTask.id, "bg-missing"],
|
||||
block: false,
|
||||
}, {} as never)
|
||||
const parsed = parseResult(output)
|
||||
|
||||
// then
|
||||
expect(parsed.summary.total).toBe(3)
|
||||
expect(parsed.summary.by_status.running).toBe(1)
|
||||
expect(parsed.summary.by_status.completed).toBe(1)
|
||||
expect(parsed.summary.by_status.not_found).toBe(1)
|
||||
expect(parsed.grouped.completed).toContain(completedTask.id)
|
||||
expect(parsed.grouped.not_found).toContain("bg-missing")
|
||||
})
|
||||
|
||||
test("#given race mode #when block=true and one task reaches terminal #then returns quorum_reached", async () => {
|
||||
// given
|
||||
const task = createTask({ id: "bg-race", status: "running" })
|
||||
let readCount = 0
|
||||
const manager: BackgroundOutputManager = {
|
||||
getTask: (taskID: string) => {
|
||||
if (taskID !== task.id) return undefined
|
||||
readCount += 1
|
||||
if (readCount >= 2) {
|
||||
task.status = "completed"
|
||||
}
|
||||
return task
|
||||
},
|
||||
}
|
||||
const tool = createBackgroundWait(manager)
|
||||
|
||||
// when
|
||||
const output = await tool.execute({
|
||||
task_ids: [task.id],
|
||||
mode: "any",
|
||||
block: true,
|
||||
timeout: 500,
|
||||
poll_interval: 20,
|
||||
}, {} as never)
|
||||
const parsed = parseResult(output)
|
||||
|
||||
// then
|
||||
expect(parsed.done).toBe(true)
|
||||
expect(parsed.reason).toBe("quorum_reached")
|
||||
expect(parsed.quorum.target).toBe(1)
|
||||
expect(parsed.quorum.reached).toBe(1)
|
||||
})
|
||||
|
||||
test("#given unmet quorum #when block=true until timeout #then returns timeout status", async () => {
|
||||
// given
|
||||
const runningTask = createTask({ id: "bg-still-running", status: "running" })
|
||||
const manager: BackgroundOutputManager = {
|
||||
getTask: (taskID: string) => (taskID === runningTask.id ? runningTask : undefined),
|
||||
}
|
||||
const tool = createBackgroundWait(manager)
|
||||
|
||||
// when
|
||||
const output = await tool.execute({
|
||||
task_ids: [runningTask.id],
|
||||
quorum: 1,
|
||||
block: true,
|
||||
timeout: 120,
|
||||
poll_interval: 20,
|
||||
}, {} as never)
|
||||
const parsed = parseResult(output)
|
||||
|
||||
// then
|
||||
expect(parsed.done).toBe(false)
|
||||
expect(parsed.reason).toBe("timeout")
|
||||
expect(parsed.summary.by_status.running).toBe(1)
|
||||
expect(parsed.quorum.reached).toBe(0)
|
||||
})
|
||||
})
|
||||
158
src/tools/background-task/create-background-wait.ts
Normal file
158
src/tools/background-task/create-background-wait.ts
Normal file
@@ -0,0 +1,158 @@
|
||||
import { tool, type ToolDefinition } from "@opencode-ai/plugin"
|
||||
import type { BackgroundTask } from "../../features/background-agent"
|
||||
import { BACKGROUND_WAIT_DESCRIPTION } from "./constants"
|
||||
import { delay } from "./delay"
|
||||
import type { BackgroundOutputManager, BackgroundWaitArgs, BackgroundWaitResult } from "./types"
|
||||
|
||||
type WaitTaskStatus = "pending" | "running" | "completed" | "error" | "cancelled" | "interrupt" | "not_found"
|
||||
|
||||
const TERMINAL_STATUSES: ReadonlySet<BackgroundTask["status"]> = new Set([
|
||||
"completed",
|
||||
"error",
|
||||
"cancelled",
|
||||
"interrupt",
|
||||
])
|
||||
|
||||
function isTerminalStatus(status: BackgroundTask["status"]): boolean {
|
||||
return TERMINAL_STATUSES.has(status)
|
||||
}
|
||||
|
||||
function toValidTaskIDs(taskIDs: string[]): string[] {
|
||||
const uniqueTaskIDs = new Set<string>()
|
||||
for (const taskID of taskIDs) {
|
||||
const normalized = taskID.trim()
|
||||
if (normalized) {
|
||||
uniqueTaskIDs.add(normalized)
|
||||
}
|
||||
}
|
||||
return [...uniqueTaskIDs]
|
||||
}
|
||||
|
||||
export function createBackgroundWait(manager: BackgroundOutputManager): ToolDefinition {
|
||||
return tool({
|
||||
description: BACKGROUND_WAIT_DESCRIPTION,
|
||||
args: {
|
||||
task_ids: tool.schema.array(tool.schema.string()).describe("Task IDs to inspect as a group"),
|
||||
mode: tool.schema.string().optional().describe("all (default) waits for all, any returns on first quorum/race completion"),
|
||||
quorum: tool.schema.number().optional().describe("Optional terminal-task quorum target"),
|
||||
block: tool.schema.boolean().optional().describe("Wait for quorum/race completion (default: false)"),
|
||||
timeout: tool.schema.number().optional().describe("Max wait time in ms when block=true (default: 60000, max: 600000)"),
|
||||
poll_interval: tool.schema.number().optional().describe("Polling interval in ms when block=true (default: 1000, min: 100)"),
|
||||
},
|
||||
async execute(args: BackgroundWaitArgs) {
|
||||
const taskIDs = toValidTaskIDs(args.task_ids)
|
||||
if (taskIDs.length === 0) {
|
||||
return "Error: task_ids must contain at least one task ID."
|
||||
}
|
||||
|
||||
const mode = args.mode === "any" ? "any" : args.mode === undefined || args.mode === "all" ? "all" : null
|
||||
if (!mode) {
|
||||
return `Error: invalid mode \"${args.mode}\". Use \"all\" or \"any\".`
|
||||
}
|
||||
|
||||
if (args.quorum !== undefined && (!Number.isInteger(args.quorum) || args.quorum < 1)) {
|
||||
return "Error: quorum must be a positive integer."
|
||||
}
|
||||
|
||||
const timeoutMs = Math.min(args.timeout ?? 60000, 600000)
|
||||
const pollIntervalMs = Math.max(args.poll_interval ?? 1000, 100)
|
||||
const block = args.block === true
|
||||
const quorumTarget = Math.min(args.quorum ?? (mode === "any" ? 1 : taskIDs.length), taskIDs.length)
|
||||
const startTime = Date.now()
|
||||
|
||||
const buildSnapshot = (): BackgroundWaitResult => {
|
||||
const byStatus: Record<string, number> = {
|
||||
pending: 0,
|
||||
running: 0,
|
||||
completed: 0,
|
||||
error: 0,
|
||||
cancelled: 0,
|
||||
interrupt: 0,
|
||||
not_found: 0,
|
||||
}
|
||||
|
||||
const tasks = taskIDs.map((taskID) => {
|
||||
const task = manager.getTask(taskID)
|
||||
if (!task) {
|
||||
byStatus.not_found += 1
|
||||
return {
|
||||
task_id: taskID,
|
||||
found: false,
|
||||
status: "not_found" as const,
|
||||
}
|
||||
}
|
||||
|
||||
byStatus[task.status] += 1
|
||||
return {
|
||||
task_id: task.id,
|
||||
found: true,
|
||||
status: task.status,
|
||||
agent: task.agent,
|
||||
description: task.description,
|
||||
session_id: task.sessionID,
|
||||
started_at: task.startedAt?.toISOString(),
|
||||
completed_at: task.completedAt?.toISOString(),
|
||||
}
|
||||
})
|
||||
|
||||
const terminalCount = tasks.filter((task) => task.found && isTerminalStatus(task.status as BackgroundTask["status"]))
|
||||
.length
|
||||
const activeCount = tasks.filter((task) => task.status === "pending" || task.status === "running").length
|
||||
const quorumReached = terminalCount >= quorumTarget
|
||||
|
||||
return {
|
||||
mode,
|
||||
block,
|
||||
timeout_ms: timeoutMs,
|
||||
waited_ms: Date.now() - startTime,
|
||||
done: quorumReached,
|
||||
reason: block ? "waiting" : "non_blocking",
|
||||
quorum: {
|
||||
target: quorumTarget,
|
||||
reached: terminalCount,
|
||||
remaining: Math.max(quorumTarget - terminalCount, 0),
|
||||
progress: quorumTarget === 0 ? 1 : terminalCount / quorumTarget,
|
||||
},
|
||||
summary: {
|
||||
total: tasks.length,
|
||||
terminal: terminalCount,
|
||||
active: activeCount,
|
||||
by_status: byStatus,
|
||||
},
|
||||
grouped: {
|
||||
pending: tasks.filter((task) => task.status === "pending").map((task) => task.task_id),
|
||||
running: tasks.filter((task) => task.status === "running").map((task) => task.task_id),
|
||||
completed: tasks.filter((task) => task.status === "completed").map((task) => task.task_id),
|
||||
error: tasks.filter((task) => task.status === "error").map((task) => task.task_id),
|
||||
cancelled: tasks.filter((task) => task.status === "cancelled").map((task) => task.task_id),
|
||||
interrupt: tasks.filter((task) => task.status === "interrupt").map((task) => task.task_id),
|
||||
not_found: tasks.filter((task) => task.status === "not_found").map((task) => task.task_id),
|
||||
},
|
||||
tasks: tasks.map((task) => ({
|
||||
...task,
|
||||
status: task.status as WaitTaskStatus,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
let snapshot = buildSnapshot()
|
||||
if (!block) {
|
||||
return JSON.stringify(snapshot, null, 2)
|
||||
}
|
||||
|
||||
while (!snapshot.done && Date.now() - startTime < timeoutMs) {
|
||||
await delay(pollIntervalMs)
|
||||
snapshot = buildSnapshot()
|
||||
}
|
||||
|
||||
const finalSnapshot: BackgroundWaitResult = {
|
||||
...snapshot,
|
||||
waited_ms: Date.now() - startTime,
|
||||
done: snapshot.done,
|
||||
reason: snapshot.done ? "quorum_reached" : "timeout",
|
||||
}
|
||||
|
||||
return JSON.stringify(finalSnapshot, null, 2)
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ export {
|
||||
createBackgroundTask,
|
||||
createBackgroundOutput,
|
||||
createBackgroundCancel,
|
||||
createBackgroundWait,
|
||||
} from "./tools"
|
||||
|
||||
export type * from "./types"
|
||||
|
||||
@@ -9,3 +9,4 @@ export type {
|
||||
export { createBackgroundTask } from "./create-background-task"
|
||||
export { createBackgroundOutput } from "./create-background-output"
|
||||
export { createBackgroundCancel } from "./create-background-cancel"
|
||||
export { createBackgroundWait } from "./create-background-wait"
|
||||
|
||||
@@ -21,6 +21,49 @@ export interface BackgroundCancelArgs {
|
||||
all?: boolean
|
||||
}
|
||||
|
||||
export interface BackgroundWaitArgs {
|
||||
task_ids: string[]
|
||||
mode?: "all" | "any"
|
||||
quorum?: number
|
||||
block?: boolean
|
||||
timeout?: number
|
||||
poll_interval?: number
|
||||
}
|
||||
|
||||
export type BackgroundWaitTaskSnapshot = {
|
||||
task_id: string
|
||||
found: boolean
|
||||
status: "pending" | "running" | "completed" | "error" | "cancelled" | "interrupt" | "not_found"
|
||||
agent?: string
|
||||
description?: string
|
||||
session_id?: string
|
||||
started_at?: string
|
||||
completed_at?: string
|
||||
}
|
||||
|
||||
export type BackgroundWaitResult = {
|
||||
mode: "all" | "any"
|
||||
block: boolean
|
||||
timeout_ms: number
|
||||
waited_ms: number
|
||||
done: boolean
|
||||
reason: "non_blocking" | "waiting" | "quorum_reached" | "timeout"
|
||||
quorum: {
|
||||
target: number
|
||||
reached: number
|
||||
remaining: number
|
||||
progress: number
|
||||
}
|
||||
summary: {
|
||||
total: number
|
||||
terminal: number
|
||||
active: number
|
||||
by_status: Record<string, number>
|
||||
}
|
||||
grouped: Record<string, string[]>
|
||||
tasks: BackgroundWaitTaskSnapshot[]
|
||||
}
|
||||
|
||||
export type BackgroundOutputMessage = {
|
||||
info?: { role?: string; time?: string | { created?: number }; agent?: string }
|
||||
parts?: Array<{
|
||||
|
||||
@@ -21,10 +21,12 @@ export { sessionExists } from "./session-manager/storage"
|
||||
|
||||
export { interactive_bash, startBackgroundCheck as startTmuxCheck } from "./interactive-bash"
|
||||
export { createSkillMcpTool } from "./skill-mcp"
|
||||
export { createSwitchAgentTool } from "./switch-agent"
|
||||
|
||||
import {
|
||||
createBackgroundOutput,
|
||||
createBackgroundCancel,
|
||||
createBackgroundWait,
|
||||
type BackgroundOutputManager,
|
||||
type BackgroundCancelClient,
|
||||
} from "./background-task"
|
||||
@@ -51,6 +53,7 @@ export function createBackgroundTools(manager: BackgroundManager, client: Openco
|
||||
return {
|
||||
background_output: createBackgroundOutput(outputManager, client),
|
||||
background_cancel: createBackgroundCancel(manager, cancelClient),
|
||||
background_wait: createBackgroundWait(outputManager),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
1
src/tools/switch-agent/constants.ts
Normal file
1
src/tools/switch-agent/constants.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const SWITCH_AGENT_DESCRIPTION = "Queue an agent switch for the current or target session. Switch is applied on next chat.message through hook flow."
|
||||
2
src/tools/switch-agent/index.ts
Normal file
2
src/tools/switch-agent/index.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
export { createSwitchAgentTool } from "./tools"
|
||||
export type { SwitchAgentArgs } from "./types"
|
||||
79
src/tools/switch-agent/tools.test.ts
Normal file
79
src/tools/switch-agent/tools.test.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import { describe, expect, test, beforeEach } from "bun:test"
|
||||
import { createSwitchAgentTool } from "./tools"
|
||||
import {
|
||||
_resetForTesting,
|
||||
getPendingSessionAgentSwitch,
|
||||
} from "../../features/claude-code-session-state"
|
||||
|
||||
describe("switch_agent tool", () => {
|
||||
beforeEach(() => {
|
||||
_resetForTesting()
|
||||
})
|
||||
|
||||
test("#given empty agent #when executing #then returns validation error", async () => {
|
||||
// given
|
||||
const client = {
|
||||
app: {
|
||||
agents: async () => ({ data: [{ name: "sisyphus" }] }),
|
||||
},
|
||||
} as unknown as Parameters<typeof createSwitchAgentTool>[0]
|
||||
const tool = createSwitchAgentTool(client)
|
||||
|
||||
// when
|
||||
const output = await tool.execute({ agent: " " }, { sessionID: "ses-1" } as never)
|
||||
|
||||
// then
|
||||
expect(output).toContain("agent is required")
|
||||
})
|
||||
|
||||
test("#given unknown agent #when executing #then returns invalid switch error", async () => {
|
||||
// given
|
||||
const client = {
|
||||
app: {
|
||||
agents: async () => ({ data: [{ name: "sisyphus" }, { name: "explore" }] }),
|
||||
},
|
||||
} as unknown as Parameters<typeof createSwitchAgentTool>[0]
|
||||
const tool = createSwitchAgentTool(client)
|
||||
|
||||
// when
|
||||
const output = await tool.execute({ agent: "ghost" }, { sessionID: "ses-1" } as never)
|
||||
|
||||
// then
|
||||
expect(output).toContain("unknown agent")
|
||||
expect(getPendingSessionAgentSwitch("ses-1")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("#given known but disabled agent #when executing #then returns disabled error", async () => {
|
||||
// given
|
||||
const client = {
|
||||
app: {
|
||||
agents: async () => ({ data: [{ name: "explore" }] }),
|
||||
},
|
||||
} as unknown as Parameters<typeof createSwitchAgentTool>[0]
|
||||
const tool = createSwitchAgentTool(client, ["explore"])
|
||||
|
||||
// when
|
||||
const output = await tool.execute({ agent: "explore" }, { sessionID: "ses-1" } as never)
|
||||
|
||||
// then
|
||||
expect(output).toContain("disabled")
|
||||
expect(getPendingSessionAgentSwitch("ses-1")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("#given known enabled agent #when executing #then queues pending switch", async () => {
|
||||
// given
|
||||
const client = {
|
||||
app: {
|
||||
agents: async () => ({ data: [{ name: "explore" }, { name: "Athena" }] }),
|
||||
},
|
||||
} as unknown as Parameters<typeof createSwitchAgentTool>[0]
|
||||
const tool = createSwitchAgentTool(client)
|
||||
|
||||
// when
|
||||
const output = await tool.execute({ agent: "explore" }, { sessionID: "ses-1" } as never)
|
||||
|
||||
// then
|
||||
expect(output).toContain("Agent switch queued")
|
||||
expect(getPendingSessionAgentSwitch("ses-1")?.agent).toBe("explore")
|
||||
})
|
||||
})
|
||||
57
src/tools/switch-agent/tools.ts
Normal file
57
src/tools/switch-agent/tools.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { tool, type PluginInput, type ToolDefinition } from "@opencode-ai/plugin"
|
||||
import { setPendingSessionAgentSwitch } from "../../features/claude-code-session-state"
|
||||
import { normalizeSDKResponse } from "../../shared"
|
||||
import type { SwitchAgentArgs } from "./types"
|
||||
import { SWITCH_AGENT_DESCRIPTION } from "./constants"
|
||||
|
||||
type SwitchableAgent = {
|
||||
name: string
|
||||
mode?: "subagent" | "primary" | "all"
|
||||
}
|
||||
|
||||
export function createSwitchAgentTool(client: PluginInput["client"], disabledAgents: string[] = []): ToolDefinition {
|
||||
return tool({
|
||||
description: SWITCH_AGENT_DESCRIPTION,
|
||||
args: {
|
||||
agent: tool.schema.string().describe("Agent name to switch to"),
|
||||
session_id: tool.schema.string().optional().describe("Session ID to switch. Defaults to current session"),
|
||||
},
|
||||
async execute(args: SwitchAgentArgs, toolContext) {
|
||||
const targetSessionID = args.session_id ?? toolContext.sessionID
|
||||
const requestedAgent = args.agent?.trim()
|
||||
|
||||
if (!requestedAgent) {
|
||||
return "Error: agent is required."
|
||||
}
|
||||
|
||||
try {
|
||||
const agentsResponse = await client.app.agents()
|
||||
const agents = normalizeSDKResponse(agentsResponse, [] as SwitchableAgent[], {
|
||||
preferResponseOnMissingData: true,
|
||||
})
|
||||
const matchedAgent = agents.find((agent) => agent.name.toLowerCase() === requestedAgent.toLowerCase())
|
||||
|
||||
if (!matchedAgent) {
|
||||
const availableAgents = agents.map((agent) => agent.name).sort()
|
||||
return `Error: unknown agent \"${requestedAgent}\". Available agents: ${availableAgents.join(", ")}`
|
||||
}
|
||||
|
||||
if (disabledAgents.some((disabledAgent) => disabledAgent.toLowerCase() === matchedAgent.name.toLowerCase())) {
|
||||
return `Error: agent \"${matchedAgent.name}\" is disabled via disabled_agents configuration.`
|
||||
}
|
||||
|
||||
const pendingSwitch = setPendingSessionAgentSwitch(targetSessionID, matchedAgent.name)
|
||||
|
||||
return [
|
||||
"Agent switch queued.",
|
||||
`Session ID: ${targetSessionID}`,
|
||||
`Next agent: ${pendingSwitch.agent}`,
|
||||
`Requested at: ${pendingSwitch.requestedAt.toISOString()}`,
|
||||
"The switch will be applied by hook flow on the next chat.message turn.",
|
||||
].join("\n")
|
||||
} catch (error) {
|
||||
return `Error: failed to queue agent switch: ${error instanceof Error ? error.message : String(error)}`
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
4
src/tools/switch-agent/types.ts
Normal file
4
src/tools/switch-agent/types.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
export type SwitchAgentArgs = {
|
||||
agent: string
|
||||
session_id?: string
|
||||
}
|
||||
Reference in New Issue
Block a user