diff --git a/src/features/agent-switch/applier.test.ts b/src/features/agent-switch/applier.test.ts new file mode 100644 index 000000000..8c8b1654a --- /dev/null +++ b/src/features/agent-switch/applier.test.ts @@ -0,0 +1,226 @@ +/// + +import { beforeEach, describe, expect, test } from "bun:test" +import { _resetForTesting, getPendingSwitch, setPendingSwitch } from "./state" +import { + _resetApplierForTesting, + applyPendingSwitch, + clearPendingSwitchRuntime, +} from "./applier" +import { schedulePendingSwitchApply } from "./scheduler" + +describe("agent-switch applier", () => { + beforeEach(() => { + _resetForTesting() + _resetApplierForTesting() + }) + + test("scheduled apply works without idle event", async () => { + const calls: string[] = [] + let switched = false + const client = { + session: { + promptAsync: async (input: { body: { agent: string } }) => { + calls.push(input.body.agent) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] }) + : ({ data: [] }), + }, + } + + setPendingSwitch("ses-1", "prometheus", "create plan") + schedulePendingSwitchApply({ + sessionID: "ses-1", + client: client as any, + }) + + await new Promise((resolve) => setTimeout(resolve, 300)) + + expect(calls).toEqual(["Prometheus (Plan Builder)"]) + expect(getPendingSwitch("ses-1")).toBeUndefined() + }) + + test("normalizes pending agent to canonical prompt display name", async () => { + const calls: string[] = [] + let switched = false + const client = { + session: { + promptAsync: async (input: { body: { agent: string } }) => { + calls.push(input.body.agent) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] }) + : ({ data: [] }), + }, + } + + setPendingSwitch("ses-2", "Prometheus (Plan Builder)", "create plan") + await applyPendingSwitch({ + sessionID: "ses-2", + client: client as any, + source: "idle", + }) + + expect(calls).toEqual(["Prometheus (Plan Builder)"]) + expect(getPendingSwitch("ses-2")).toBeUndefined() + }) + + test("retries transient failures and eventually clears pending switch", async () => { + let attempts = 0 + let switched = false + const client = { + session: { + promptAsync: async () => { + attempts += 1 + if (attempts < 3) { + throw new Error("temporary failure") + } + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] }) + : ({ data: [] }), + }, + } + + setPendingSwitch("ses-3", "atlas", "fix this") + await applyPendingSwitch({ + sessionID: "ses-3", + client: client as any, + source: "idle", + }) + + await new Promise((resolve) => setTimeout(resolve, 800)) + + expect(attempts).toBe(3) + expect(getPendingSwitch("ses-3")).toBeUndefined() + }) + + test("waits for session idle before applying switch", async () => { + let statusChecks = 0 + let promptCalls = 0 + let switched = false + const client = { + session: { + status: async () => { + statusChecks += 1 + return { + "ses-5": { type: statusChecks < 3 ? "running" : "idle" }, + } + }, + promptAsync: async () => { + promptCalls += 1 + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] }) + : ({ data: [] }), + }, + } + + setPendingSwitch("ses-5", "atlas", "fix now") + await applyPendingSwitch({ + sessionID: "ses-5", + client: client as any, + source: "idle", + }) + + expect(statusChecks).toBeGreaterThanOrEqual(3) + expect(promptCalls).toBe(1) + expect(getPendingSwitch("ses-5")).toBeUndefined() + }) + + test("clearPendingSwitchRuntime cancels pending retries", async () => { + let attempts = 0 + const client = { + session: { + promptAsync: async () => { + attempts += 1 + throw new Error("always failing") + }, + messages: async () => ({ data: [] }), + }, + } + + setPendingSwitch("ses-4", "atlas", "fix this") + await applyPendingSwitch({ + sessionID: "ses-4", + client: client as any, + source: "idle", + }) + + clearPendingSwitchRuntime("ses-4") + + const attemptsAfterClear = attempts + + await new Promise((resolve) => setTimeout(resolve, 300)) + + expect(attempts).toBe(attemptsAfterClear) + expect(getPendingSwitch("ses-4")).toBeUndefined() + }) + + test("syncs CLI TUI agent selection for athena-to-atlas handoff", async () => { + const originalClientEnv = process.env["OPENCODE_CLIENT"] + process.env["OPENCODE_CLIENT"] = "cli" + + try { + const promptCalls: string[] = [] + const tuiCommands: string[] = [] + let switched = false + const client = { + session: { + promptAsync: async (input: { body: { agent: string } }) => { + promptCalls.push(input.body.agent) + switched = true + }, + messages: async () => switched + ? ({ + data: [ + { info: { role: "user", agent: "Athena (Council)" } }, + { info: { role: "user", agent: "Atlas (Plan Executor)" } }, + ], + }) + : ({ + data: [{ info: { role: "user", agent: "Athena (Council)" } }], + }), + }, + app: { + agents: async () => ({ + data: [ + { name: "Sisyphus (Ultraworker)", mode: "primary" }, + { name: "Hephaestus (Deep Agent)", mode: "primary" }, + { name: "Prometheus (Plan Builder)", mode: "primary" }, + { name: "Atlas (Plan Executor)", mode: "primary" }, + { name: "Athena (Council)", mode: "primary" }, + ], + }), + }, + tui: { + publish: async (input: { body: { properties: { command: string } } }) => { + tuiCommands.push(input.body.properties.command) + }, + }, + } + + setPendingSwitch("ses-6", "atlas", "fix now") + await applyPendingSwitch({ + sessionID: "ses-6", + client: client as any, + source: "message-updated", + }) + + expect(promptCalls).toEqual(["Atlas (Plan Executor)"]) + expect(tuiCommands).toEqual(["agent.cycle.reverse"]) + expect(getPendingSwitch("ses-6")).toBeUndefined() + } finally { + if (originalClientEnv === undefined) { + delete process.env["OPENCODE_CLIENT"] + } else { + process.env["OPENCODE_CLIENT"] = originalClientEnv + } + } + }) +}) diff --git a/src/features/agent-switch/applier.ts b/src/features/agent-switch/applier.ts new file mode 100644 index 000000000..5fda952da --- /dev/null +++ b/src/features/agent-switch/applier.ts @@ -0,0 +1,211 @@ +import { normalizeAgentForPrompt } from "../../shared/agent-display-names" +import { log } from "../../shared/logger" +import { clearPendingSwitch, getPendingSwitch } from "./state" +import { waitForSessionIdle } from "./session-status" +import { fetchMessages, shouldClearAsAlreadyApplied, verifySwitchObserved } from "./apply-verification" +import { getLatestUserAgent } from "./message-inspection" +import { syncCliTuiAgentSelectionAfterSwitch } from "./tui-agent-sync" +import { + clearInFlight, + clearRetryState, + isApplyInFlight, + markApplyInFlight, + resetRetryStateForTesting, + scheduleRetry, +} from "./retry-state" + +type SessionClient = { + session: { + prompt?: (input: { + path: { id: string } + body: { agent: string; parts: Array<{ type: "text"; text: string }> } + }) => Promise + promptAsync: (input: { + path: { id: string } + body: { agent: string; parts: Array<{ type: "text"; text: string }> } + }) => Promise + messages: (input: { path: { id: string } }) => Promise + status?: () => Promise + } + app?: { + agents?: () => Promise + } + tui?: { + publish?: (input: { + body: { + type: "tui.command.execute" + properties: { command: string } + } + }) => Promise + } +} + +async function tryPromptWithCandidates(args: { + client: SessionClient + sessionID: string + agent: string + context: string + source: string +}): Promise { + const { client, sessionID, agent, context, source } = args + const targetAgent = normalizeAgentForPrompt(agent) + if (!targetAgent) { + throw new Error(`invalid target agent for switch prompt: ${agent}`) + } + + try { + const promptInput = { + path: { id: sessionID }, + body: { + agent: targetAgent, + parts: [{ type: "text" as const, text: context }], + }, + } + + if (client.session.prompt) { + await client.session.prompt(promptInput) + } else { + await client.session.promptAsync(promptInput) + } + + if (targetAgent !== agent) { + log("[agent-switch] Normalized pending switch agent for prompt", { + sessionID, + source, + requestedAgent: agent, + usedAgent: targetAgent, + }) + } + + return targetAgent + } catch (error) { + log("[agent-switch] Prompt attempt failed", { + sessionID, + source, + requestedAgent: agent, + attemptedAgent: targetAgent, + error: String(error), + }) + throw error + } +} + +export async function applyPendingSwitch(args: { + sessionID: string + client: SessionClient + source: string +}): Promise { + const { sessionID, client, source } = args + const pending = getPendingSwitch(sessionID) + if (!pending) { + clearRetryState(sessionID) + return + } + + if (isApplyInFlight(sessionID)) { + return + } + + markApplyInFlight(sessionID) + log("[agent-switch] Applying pending switch", { + sessionID, + source, + agent: pending.agent, + }) + + try { + const alreadyApplied = await shouldClearAsAlreadyApplied({ + client, + sessionID, + targetAgent: pending.agent, + }) + if (alreadyApplied) { + clearPendingSwitch(sessionID) + clearRetryState(sessionID) + log("[agent-switch] Pending switch already applied by user-turn evidence; clearing state", { + sessionID, + source, + agent: pending.agent, + }) + return + } + + const idleReady = await waitForSessionIdle({ client, sessionID }) + if (!idleReady) { + throw new Error("session not idle before applying agent switch") + } + + const beforeMessages = await fetchMessages({ client, sessionID }) + const sourceUserAgent = getLatestUserAgent(beforeMessages) + + const usedAgent = await tryPromptWithCandidates({ + client, + sessionID, + agent: pending.agent, + context: pending.context, + source, + }) + + const verified = await verifySwitchObserved({ + client, + sessionID, + targetAgent: pending.agent, + baselineCount: beforeMessages.length, + }) + if (!verified) { + throw new Error(`agent switch not observed after prompt (attempted ${usedAgent})`) + } + + clearPendingSwitch(sessionID) + clearRetryState(sessionID) + + await syncCliTuiAgentSelectionAfterSwitch({ + client, + sessionID, + source, + sourceAgent: sourceUserAgent, + targetAgent: pending.agent, + }) + + log("[agent-switch] Pending switch applied", { + sessionID, + source, + agent: pending.agent, + }) + } catch (error) { + clearInFlight(sessionID) + log("[agent-switch] Pending switch apply failed", { + sessionID, + source, + error: String(error), + }) + scheduleRetry({ + sessionID, + source, + onLimitReached: (attempts) => { + log("[agent-switch] Retry limit reached; waiting for next trigger", { + sessionID, + attempts, + source, + }) + }, + retryFn: (attemptNumber) => { + void applyPendingSwitch({ + sessionID, + client, + source: `retry:${attemptNumber}`, + }) + }, + }) + } +} + +export function clearPendingSwitchRuntime(sessionID: string): void { + clearPendingSwitch(sessionID) + clearRetryState(sessionID) +} + +/** @internal For testing only */ +export function _resetApplierForTesting(): void { + resetRetryStateForTesting() +} diff --git a/src/features/agent-switch/apply-verification.ts b/src/features/agent-switch/apply-verification.ts new file mode 100644 index 000000000..4cfa81afe --- /dev/null +++ b/src/features/agent-switch/apply-verification.ts @@ -0,0 +1,59 @@ +import { extractMessageList, hasNewUserTurnForTargetAgent, hasRecentUserTurnForTargetAgent } from "./message-inspection" +import { log } from "../../shared/logger" +import { sleepWithDelay } from "./session-status" + +type SessionClient = { + session: { + messages: (input: { path: { id: string } }) => Promise + } +} + +export async function fetchMessages(args: { + client: SessionClient + sessionID: string +}): Promise>> { + const response = await args.client.session.messages({ path: { id: args.sessionID } }) + return extractMessageList(response) +} + +export async function verifySwitchObserved(args: { + client: SessionClient + sessionID: string + targetAgent: string + baselineCount: number +}): Promise { + const { client, sessionID, targetAgent, baselineCount } = args + const delays = [100, 300, 800, 1500] as const + + for (const delay of delays) { + await sleepWithDelay(delay) + try { + const messages = await fetchMessages({ client, sessionID }) + if (hasNewUserTurnForTargetAgent({ messages, targetAgent, baselineCount })) { + return true + } + } catch (error) { + log("[agent-switch] Verification read failed", { + sessionID, + error: String(error), + }) + } + } + + return false +} + +export async function shouldClearAsAlreadyApplied(args: { + client: SessionClient + sessionID: string + targetAgent: string +}): Promise { + const { client, sessionID, targetAgent } = args + + try { + const messages = await fetchMessages({ client, sessionID }) + return hasRecentUserTurnForTargetAgent({ messages, targetAgent }) + } catch { + return false + } +} diff --git a/src/features/agent-switch/index.ts b/src/features/agent-switch/index.ts index abc967b38..15be54231 100644 --- a/src/features/agent-switch/index.ts +++ b/src/features/agent-switch/index.ts @@ -1,2 +1,8 @@ -export { setPendingSwitch, consumePendingSwitch, _resetForTesting } from "./state" +export { + setPendingSwitch, + getPendingSwitch, + clearPendingSwitch, + consumePendingSwitch, + _resetForTesting, +} from "./state" export type { PendingSwitch } from "./state" diff --git a/src/features/agent-switch/message-inspection.ts b/src/features/agent-switch/message-inspection.ts new file mode 100644 index 000000000..4a2fc9d6c --- /dev/null +++ b/src/features/agent-switch/message-inspection.ts @@ -0,0 +1,107 @@ +import { getAgentConfigKey } from "../../shared/agent-display-names" + +export interface MessageRoleAgent { + role: string + agent: string +} + +export function extractMessageList(response: unknown): Array> { + if (Array.isArray(response)) { + return response.filter((item): item is Record => typeof item === "object" && item !== null) + } + if (typeof response === "object" && response !== null) { + const data = (response as Record).data + if (Array.isArray(data)) { + return data.filter((item): item is Record => typeof item === "object" && item !== null) + } + } + return [] +} + +function getRoleAgent(message: Record): MessageRoleAgent | undefined { + const info = message.info + if (typeof info !== "object" || info === null) { + return undefined + } + + const role = (info as Record).role + const agent = (info as Record).agent + if (typeof role !== "string" || typeof agent !== "string") { + return undefined + } + + return { role, agent } +} + +export function getLatestUserAgent(messages: Array>): string | undefined { + for (let index = messages.length - 1; index >= 0; index -= 1) { + const message = messages[index] + if (!message) { + continue + } + + const roleAgent = getRoleAgent(message) + if (!roleAgent || roleAgent.role !== "user") { + continue + } + + return roleAgent.agent + } + + return undefined +} + +export function hasRecentUserTurnForTargetAgent(args: { + messages: Array> + targetAgent: string + lookback?: number +}): boolean { + const { messages, targetAgent, lookback = 8 } = args + const targetKey = getAgentConfigKey(targetAgent) + const start = Math.max(0, messages.length - lookback) + + for (let index = messages.length - 1; index >= start; index -= 1) { + const message = messages[index] + if (!message) { + continue + } + + const roleAgent = getRoleAgent(message) + if (!roleAgent || roleAgent.role !== "user") { + continue + } + + if (getAgentConfigKey(roleAgent.agent) === targetKey) { + return true + } + } + + return false +} + +export function hasNewUserTurnForTargetAgent(args: { + messages: Array> + targetAgent: string + baselineCount: number +}): boolean { + const { messages, targetAgent, baselineCount } = args + const targetKey = getAgentConfigKey(targetAgent) + + if (messages.length <= baselineCount) { + return false + } + + const newMessages = messages.slice(Math.max(0, baselineCount)) + for (const message of newMessages) { + const roleAgent = getRoleAgent(message) + if (!roleAgent || roleAgent.role !== "user") { + continue + } + + if (getAgentConfigKey(roleAgent.agent) === targetKey) { + return true + } + } + + return false +} diff --git a/src/features/agent-switch/retry-state.ts b/src/features/agent-switch/retry-state.ts new file mode 100644 index 000000000..05dfc086f --- /dev/null +++ b/src/features/agent-switch/retry-state.ts @@ -0,0 +1,66 @@ +const RETRY_DELAYS_MS = [50, 250, 500, 1000, 2000, 5000] as const + +const inFlightSessions = new Set() +const retryAttempts = new Map() +const retryTimers = new Map>() + +export function isApplyInFlight(sessionID: string): boolean { + return inFlightSessions.has(sessionID) +} + +export function markApplyInFlight(sessionID: string): void { + inFlightSessions.add(sessionID) +} + +export function clearRetryState(sessionID: string): void { + const timer = retryTimers.get(sessionID) + if (timer) { + clearTimeout(timer) + retryTimers.delete(sessionID) + } + retryAttempts.delete(sessionID) + inFlightSessions.delete(sessionID) +} + +export function clearInFlight(sessionID: string): void { + inFlightSessions.delete(sessionID) +} + +export function scheduleRetry(args: { + sessionID: string + source: string + retryFn: (attemptNumber: number) => void + onLimitReached: (attempts: number) => void +}): void { + const { sessionID, retryFn, onLimitReached } = args + const attempts = retryAttempts.get(sessionID) ?? 0 + if (attempts >= RETRY_DELAYS_MS.length) { + onLimitReached(attempts) + return + } + + const delay = RETRY_DELAYS_MS[attempts] + retryAttempts.set(sessionID, attempts + 1) + + const existing = retryTimers.get(sessionID) + if (existing) { + clearTimeout(existing) + } + + const timer = setTimeout(() => { + retryTimers.delete(sessionID) + retryFn(attempts + 1) + }, delay) + + retryTimers.set(sessionID, timer) +} + +/** @internal For testing only */ +export function resetRetryStateForTesting(): void { + for (const timer of retryTimers.values()) { + clearTimeout(timer) + } + retryTimers.clear() + retryAttempts.clear() + inFlightSessions.clear() +} diff --git a/src/features/agent-switch/scheduler.ts b/src/features/agent-switch/scheduler.ts new file mode 100644 index 000000000..59e5418d0 --- /dev/null +++ b/src/features/agent-switch/scheduler.ts @@ -0,0 +1,43 @@ +import { log } from "../../shared/logger" +import { scheduleRetry } from "./retry-state" +import { applyPendingSwitch } from "./applier" + +type SessionClient = { + session: { + prompt?: (input: { + path: { id: string } + body: { agent: string; parts: Array<{ type: "text"; text: string }> } + }) => Promise + promptAsync: (input: { + path: { id: string } + body: { agent: string; parts: Array<{ type: "text"; text: string }> } + }) => Promise + messages: (input: { path: { id: string } }) => Promise + status?: () => Promise + } +} + +export function schedulePendingSwitchApply(args: { + sessionID: string + client: SessionClient +}): void { + const { sessionID, client } = args + scheduleRetry({ + sessionID, + source: "tool", + onLimitReached: (attempts) => { + log("[agent-switch] Retry limit reached; waiting for next trigger", { + sessionID, + attempts, + source: "tool", + }) + }, + retryFn: (attemptNumber) => { + void applyPendingSwitch({ + sessionID, + client, + source: `retry:${attemptNumber}`, + }) + }, + }) +} diff --git a/src/features/agent-switch/session-status.ts b/src/features/agent-switch/session-status.ts new file mode 100644 index 000000000..dfdc09f94 --- /dev/null +++ b/src/features/agent-switch/session-status.ts @@ -0,0 +1,68 @@ +import { log } from "../../shared/logger" + +type SessionClient = { + session: { + status?: () => Promise + } +} + +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + +function getSessionStatusType(statusResponse: unknown, sessionID: string): string | undefined { + if (typeof statusResponse !== "object" || statusResponse === null) { + return undefined + } + + const root = statusResponse as Record + const data = (typeof root.data === "object" && root.data !== null) + ? root.data as Record + : root + + const entry = data[sessionID] + if (typeof entry !== "object" || entry === null) { + return undefined + } + + const entryType = (entry as Record).type + return typeof entryType === "string" ? entryType : undefined +} + +export async function waitForSessionIdle(args: { + client: SessionClient + sessionID: string + timeoutMs?: number +}): Promise { + const { client, sessionID, timeoutMs = 15000 } = args + if (!client.session.status) { + return true + } + + const start = Date.now() + while (Date.now() - start < timeoutMs) { + try { + const statusResponse = await client.session.status() + const statusType = getSessionStatusType(statusResponse, sessionID) + // /session/status only tracks non-idle sessions in SessionStatus.list(). + // Missing entry means idle. + if (!statusType || statusType === "idle") { + return true + } + } catch (error) { + log("[agent-switch] Session status check failed", { + sessionID, + error: String(error), + }) + return true + } + + await sleep(200) + } + + return false +} + +export async function sleepWithDelay(ms: number): Promise { + await sleep(ms) +} diff --git a/src/features/agent-switch/state.test.ts b/src/features/agent-switch/state.test.ts index 87a07a3ad..00f879bec 100644 --- a/src/features/agent-switch/state.test.ts +++ b/src/features/agent-switch/state.test.ts @@ -1,5 +1,11 @@ -import { describe, test, expect, beforeEach } from "bun:test" -import { setPendingSwitch, consumePendingSwitch, _resetForTesting } from "./state" +const { describe, test, expect, beforeEach } = require("bun:test") +import { + setPendingSwitch, + getPendingSwitch, + clearPendingSwitch, + consumePendingSwitch, + _resetForTesting, +} from "./state" describe("agent-switch state", () => { beforeEach(() => { @@ -47,4 +53,21 @@ describe("agent-switch state", () => { expect(consumePendingSwitch("session-1")).toEqual({ agent: "atlas", context: "Fix A" }) expect(consumePendingSwitch("session-2")).toEqual({ agent: "prometheus", context: "Plan B" }) }) + + test("should allow reading without consuming", () => { + setPendingSwitch("session-1", "atlas", "Fix A") + + expect(getPendingSwitch("session-1")).toEqual({ agent: "atlas", context: "Fix A" }) + expect(getPendingSwitch("session-1")).toEqual({ agent: "atlas", context: "Fix A" }) + }) + + test("should clear pending switch explicitly", () => { + setPendingSwitch("session-1", "atlas", "Fix A") + + clearPendingSwitch("session-1") + + expect(getPendingSwitch("session-1")).toBeUndefined() + }) }) + +export {} diff --git a/src/features/agent-switch/state.ts b/src/features/agent-switch/state.ts index 07a3d530c..17f956122 100644 --- a/src/features/agent-switch/state.ts +++ b/src/features/agent-switch/state.ts @@ -1,23 +1,102 @@ +import { existsSync, readFileSync, rmSync, writeFileSync } from "node:fs" +import { join } from "node:path" +import { tmpdir } from "node:os" + export interface PendingSwitch { agent: string context: string } +const PENDING_SWITCH_STATE_FILE = process.platform === "win32" + ? join(tmpdir(), "oh-my-opencode-agent-switch.json") + : "/tmp/oh-my-opencode-agent-switch.json" + const pendingSwitches = new Map() +function isPendingSwitch(value: unknown): value is PendingSwitch { + if (typeof value !== "object" || value === null) return false + const entry = value as Record + return typeof entry.agent === "string" && typeof entry.context === "string" +} + +function readPersistentState(): Record { + try { + if (!existsSync(PENDING_SWITCH_STATE_FILE)) { + return {} + } + + const raw = readFileSync(PENDING_SWITCH_STATE_FILE, "utf8") + const parsed = JSON.parse(raw) + if (typeof parsed !== "object" || parsed === null) { + return {} + } + + const state: Record = {} + for (const [sessionID, value] of Object.entries(parsed)) { + if (isPendingSwitch(value)) { + state[sessionID] = value + } + } + + return state + } catch { + return {} + } +} + +function writePersistentState(state: Record): void { + try { + const keys = Object.keys(state) + if (keys.length === 0) { + rmSync(PENDING_SWITCH_STATE_FILE, { force: true }) + return + } + + writeFileSync(PENDING_SWITCH_STATE_FILE, JSON.stringify(state), "utf8") + } catch { + // ignore persistence errors + } +} + export function setPendingSwitch(sessionID: string, agent: string, context: string): void { - pendingSwitches.set(sessionID, { agent, context }) + const entry = { agent, context } + pendingSwitches.set(sessionID, entry) + + const state = readPersistentState() + state[sessionID] = entry + writePersistentState(state) +} + +export function getPendingSwitch(sessionID: string): PendingSwitch | undefined { + const inMemory = pendingSwitches.get(sessionID) + if (inMemory) { + return inMemory + } + + const state = readPersistentState() + const fromDisk = state[sessionID] + if (fromDisk) { + pendingSwitches.set(sessionID, fromDisk) + } + return fromDisk +} + +export function clearPendingSwitch(sessionID: string): void { + pendingSwitches.delete(sessionID) + + const state = readPersistentState() + delete state[sessionID] + writePersistentState(state) } export function consumePendingSwitch(sessionID: string): PendingSwitch | undefined { - const entry = pendingSwitches.get(sessionID) - if (entry) { - pendingSwitches.delete(sessionID) - } + const entry = getPendingSwitch(sessionID) + clearPendingSwitch(sessionID) return entry } /** @internal For testing only */ export function _resetForTesting(): void { pendingSwitches.clear() + rmSync(PENDING_SWITCH_STATE_FILE, { force: true }) } diff --git a/src/features/agent-switch/tui-agent-sync.ts b/src/features/agent-switch/tui-agent-sync.ts new file mode 100644 index 000000000..d4764ef39 --- /dev/null +++ b/src/features/agent-switch/tui-agent-sync.ts @@ -0,0 +1,132 @@ +import { getAgentConfigKey } from "../../shared/agent-display-names" +import { log, normalizeSDKResponse } from "../../shared" + +type TuiClient = { + app?: { + agents?: () => Promise + } + tui?: { + publish?: (input: { + body: { + type: "tui.command.execute" + properties: { command: string } + } + }) => Promise + } +} + +type AgentInfo = { + name?: string + mode?: "subagent" | "primary" | "all" + hidden?: boolean +} + +function isCliClient(): boolean { + return (process.env["OPENCODE_CLIENT"] ?? "cli") === "cli" +} + +function resolveCyclePlan(args: { + orderedAgentNames: string[] + sourceAgent: string + targetAgent: string +}): { command: "agent.cycle" | "agent.cycle.reverse"; steps: number } | undefined { + const { orderedAgentNames, sourceAgent, targetAgent } = args + if (orderedAgentNames.length < 2) { + return undefined + } + + const orderedKeys = orderedAgentNames.map((name) => getAgentConfigKey(name)) + const sourceKey = getAgentConfigKey(sourceAgent) + const targetKey = getAgentConfigKey(targetAgent) + + const sourceIndex = orderedKeys.indexOf(sourceKey) + const targetIndex = orderedKeys.indexOf(targetKey) + if (sourceIndex < 0 || targetIndex < 0 || sourceIndex === targetIndex) { + return undefined + } + + const size = orderedKeys.length + const forward = (targetIndex - sourceIndex + size) % size + const backward = (sourceIndex - targetIndex + size) % size + + if (forward <= backward) { + return { command: "agent.cycle", steps: forward } + } + + return { command: "agent.cycle.reverse", steps: backward } +} + +export async function syncCliTuiAgentSelectionAfterSwitch(args: { + client: TuiClient + sessionID: string + sourceAgent: string | undefined + targetAgent: string + source: string +}): Promise { + const { client, sessionID, sourceAgent, targetAgent, source } = args + + if (!isCliClient()) { + return + } + + if (!sourceAgent || !client.app?.agents || !client.tui?.publish) { + return + } + + const sourceKey = getAgentConfigKey(sourceAgent) + const targetKey = getAgentConfigKey(targetAgent) + + // Scope to Athena handoffs where CLI TUI can show stale local-agent selection. + if (sourceKey !== "athena" || (targetKey !== "atlas" && targetKey !== "prometheus")) { + return + } + + try { + const response = await client.app.agents() + const agents = normalizeSDKResponse(response, [] as AgentInfo[], { + preferResponseOnMissingData: true, + }) + + const orderedPrimaryAgents = agents + .filter((agent) => typeof agent.name === "string" && agent.mode !== "subagent" && agent.hidden !== true) + .map((agent) => agent.name as string) + + const plan = resolveCyclePlan({ + orderedAgentNames: orderedPrimaryAgents, + sourceAgent, + targetAgent, + }) + + if (!plan || plan.steps <= 0) { + return + } + + for (let step = 0; step < plan.steps; step += 1) { + await client.tui.publish({ + body: { + type: "tui.command.execute", + properties: { + command: plan.command, + }, + }, + }) + } + + log("[agent-switch] Synced CLI TUI local agent after handoff", { + sessionID, + source, + sourceAgent, + targetAgent, + command: plan.command, + steps: plan.steps, + }) + } catch (error) { + log("[agent-switch] Failed syncing CLI TUI local agent after handoff", { + sessionID, + source, + sourceAgent, + targetAgent, + error: String(error), + }) + } +} diff --git a/src/hooks/agent-switch/fallback-handoff.ts b/src/hooks/agent-switch/fallback-handoff.ts new file mode 100644 index 000000000..d3388d780 --- /dev/null +++ b/src/hooks/agent-switch/fallback-handoff.ts @@ -0,0 +1,75 @@ +export function isTerminalFinishValue(finish: unknown): boolean { + if (typeof finish === "boolean") { + return finish + } + + if (typeof finish === "string") { + const normalized = finish.toLowerCase() + return normalized !== "" && normalized !== "tool-calls" && normalized !== "unknown" + } + + if (typeof finish === "object" && finish !== null) { + const record = finish as Record + const kind = record.type ?? record.reason + if (typeof kind === "string") { + const normalized = kind.toLowerCase() + return normalized !== "" && normalized !== "tool-calls" && normalized !== "unknown" + } + } + + return false +} + +export function isTerminalStepFinishPart(part: unknown): boolean { + if (typeof part !== "object" || part === null) { + return false + } + + const record = part as Record + if (record.type !== "step-finish") { + return false + } + + return isTerminalFinishValue(record.reason) +} + +export function extractTextPartsFromMessageResponse(response: unknown): string { + if (typeof response !== "object" || response === null) return "" + const data = (response as Record).data + if (typeof data !== "object" || data === null) return "" + const parts = (data as Record).parts + if (!Array.isArray(parts)) return "" + + return parts + .map((part) => { + if (typeof part !== "object" || part === null) return "" + const partRecord = part as Record + if (partRecord.type !== "text") return "" + return typeof partRecord.text === "string" ? partRecord.text : "" + }) + .filter((text) => text.length > 0) + .join("\n") +} + +export function detectFallbackHandoffTarget(messageText: string): "atlas" | "prometheus" | undefined { + if (!messageText) return undefined + + const normalized = messageText.toLowerCase() + + if (/switching\s+to\s+\*{0,2}\s*prometheus\b/.test(normalized) || /handing\s+off\s+to\s+\*{0,2}\s*prometheus\b/.test(normalized)) { + return "prometheus" + } + + if (/switching\s+to\s+\*{0,2}\s*atlas\b/.test(normalized) || /handing\s+off\s+to\s+\*{0,2}\s*atlas\b/.test(normalized)) { + return "atlas" + } + + return undefined +} + +export function buildFallbackContext(target: "atlas" | "prometheus"): string { + if (target === "prometheus") { + return "Athena indicated handoff to Prometheus. Continue from the current session context and produce the requested phased plan based on the council findings already gathered." + } + return "Athena indicated handoff to Atlas. Continue from the current session context and implement the agreed fixes from the council findings." +} diff --git a/src/hooks/agent-switch/hook.test.ts b/src/hooks/agent-switch/hook.test.ts new file mode 100644 index 000000000..cb2e7c64e --- /dev/null +++ b/src/hooks/agent-switch/hook.test.ts @@ -0,0 +1,358 @@ +/// + +import { beforeEach, describe, expect, test } from "bun:test" +import { createAgentSwitchHook } from "./hook" +import { + _resetForTesting, + getPendingSwitch, + setPendingSwitch, +} from "../../features/agent-switch" +import { _resetApplierForTesting, clearPendingSwitchRuntime } from "../../features/agent-switch/applier" + +describe("agent-switch hook", () => { + beforeEach(() => { + _resetForTesting() + _resetApplierForTesting() + }) + + test("consumes pending switch only after successful promptAsync", async () => { + const promptAsyncCalls: Array> = [] + let switched = false + const ctx = { + client: { + session: { + promptAsync: async (args: Record) => { + promptAsyncCalls.push(args) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] }) + : ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-1", "prometheus", "plan this") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "session.idle", + properties: { sessionID: "ses-1" }, + }, + }) + + expect(promptAsyncCalls).toHaveLength(1) + expect(getPendingSwitch("ses-1")).toBeUndefined() + }) + + test("keeps pending switch when promptAsync fails", async () => { + const ctx = { + client: { + session: { + promptAsync: async () => { + throw new Error("temporary failure") + }, + messages: async () => ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-2", "atlas", "fix this") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "session.idle", + properties: { sessionID: "ses-2" }, + }, + }) + + expect(getPendingSwitch("ses-2")).toEqual({ + agent: "atlas", + context: "fix this", + }) + + clearPendingSwitchRuntime("ses-2") + }) + + test("retries after transient failure and eventually clears pending switch", async () => { + let attempts = 0 + let switched = false + const ctx = { + client: { + session: { + promptAsync: async () => { + attempts += 1 + if (attempts === 1) { + throw new Error("temporary failure") + } + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] }) + : ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-3", "prometheus", "plan this") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "session.idle", + properties: { sessionID: "ses-3" }, + }, + }) + + await new Promise((resolve) => setTimeout(resolve, 350)) + + expect(attempts).toBe(2) + expect(getPendingSwitch("ses-3")).toBeUndefined() + }) + + test("clears pending switch on session.deleted", async () => { + const ctx = { + client: { + session: { + promptAsync: async () => {}, + messages: async () => ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-4", "atlas", "fix this") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "session.deleted", + properties: { info: { id: "ses-4" } }, + }, + }) + + expect(getPendingSwitch("ses-4")).toBeUndefined() + }) + + test("recovers missing switch_agent tool call from Athena handoff text", async () => { + const promptAsyncCalls: Array> = [] + let switched = false + const ctx = { + client: { + session: { + promptAsync: async (args: Record) => { + promptAsyncCalls.push(args) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] }) + : ({ data: [] }), + message: async () => ({ + data: { + parts: [ + { + type: "text", + text: "Switching to **Prometheus** now — they'll take it from here and craft a plan for you!", + }, + ], + }, + }), + }, + }, + } as any + + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "message.updated", + properties: { + info: { + id: "msg-athena-1", + sessionID: "ses-5", + role: "assistant", + agent: "Athena (Council)", + finish: "stop", + }, + }, + }, + }) + + expect(promptAsyncCalls).toHaveLength(1) + const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined + expect(body?.agent).toBe("Prometheus (Plan Builder)") + expect(getPendingSwitch("ses-5")).toBeUndefined() + }) + + test("applies queued pending switch on terminal message.updated", async () => { + const promptAsyncCalls: Array> = [] + let switched = false + const ctx = { + client: { + session: { + promptAsync: async (args: Record) => { + promptAsyncCalls.push(args) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] }) + : ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-6", "atlas", "fix now") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "message.updated", + properties: { + info: { + id: "msg-6", + sessionID: "ses-6", + role: "assistant", + agent: "Athena (Council)", + finish: "stop", + }, + }, + }, + }) + + expect(promptAsyncCalls).toHaveLength(1) + const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined + expect(body?.agent).toBe("Atlas (Plan Executor)") + expect(getPendingSwitch("ses-6")).toBeUndefined() + }) + + test("applies queued pending switch on terminal message.updated even when role is missing", async () => { + const promptAsyncCalls: Array> = [] + let switched = false + const ctx = { + client: { + session: { + promptAsync: async (args: Record) => { + promptAsyncCalls.push(args) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] }) + : ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-8", "atlas", "fix now") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "message.updated", + properties: { + info: { + id: "msg-8", + sessionID: "ses-8", + agent: "Athena (Council)", + finish: true, + }, + }, + }, + }) + + expect(promptAsyncCalls).toHaveLength(1) + const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined + expect(body?.agent).toBe("Atlas (Plan Executor)") + expect(getPendingSwitch("ses-8")).toBeUndefined() + }) + + test("applies queued pending switch on terminal message.part.updated step-finish", async () => { + const promptAsyncCalls: Array> = [] + let switched = false + const ctx = { + client: { + session: { + promptAsync: async (args: Record) => { + promptAsyncCalls.push(args) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] }) + : ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-7", "atlas", "fix now") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "message.part.updated", + properties: { + info: { + sessionID: "ses-7", + role: "assistant", + }, + part: { + id: "part-finish-1", + sessionID: "ses-7", + type: "step-finish", + reason: "stop", + }, + }, + }, + }) + + expect(promptAsyncCalls).toHaveLength(1) + const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined + expect(body?.agent).toBe("Atlas (Plan Executor)") + expect(getPendingSwitch("ses-7")).toBeUndefined() + }) + + test("applies queued pending switch on session.status idle", async () => { + const promptAsyncCalls: Array> = [] + let switched = false + const ctx = { + client: { + session: { + promptAsync: async (args: Record) => { + promptAsyncCalls.push(args) + switched = true + }, + messages: async () => switched + ? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] }) + : ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-9", "atlas", "fix now") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "session.status", + properties: { + sessionID: "ses-9", + status: { + type: "idle", + }, + }, + }, + }) + + expect(promptAsyncCalls).toHaveLength(1) + const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined + expect(body?.agent).toBe("Atlas (Plan Executor)") + expect(getPendingSwitch("ses-9")).toBeUndefined() + }) +}) diff --git a/src/hooks/agent-switch/hook.ts b/src/hooks/agent-switch/hook.ts index c219eb5f6..61d44023b 100644 --- a/src/hooks/agent-switch/hook.ts +++ b/src/hooks/agent-switch/hook.ts @@ -1,36 +1,194 @@ import type { PluginInput } from "@opencode-ai/plugin" -import { consumePendingSwitch } from "../../features/agent-switch" +import { getPendingSwitch, setPendingSwitch } from "../../features/agent-switch" +import { applyPendingSwitch, clearPendingSwitchRuntime } from "../../features/agent-switch/applier" +import { getAgentConfigKey } from "../../shared/agent-display-names" import { log } from "../../shared/logger" +import { + buildFallbackContext, + detectFallbackHandoffTarget, + extractTextPartsFromMessageResponse, + isTerminalFinishValue, + isTerminalStepFinishPart, +} from "./fallback-handoff" -const HOOK_NAME = "agent-switch" as const +const processedFallbackMessages = new Set() + +function getSessionIDFromStatusEvent(input: { event: { properties?: Record } }): string | undefined { + const props = input.event.properties as Record | undefined + const fromProps = typeof props?.sessionID === "string" ? props.sessionID : undefined + if (fromProps) { + return fromProps + } + + const status = props?.status as Record | undefined + const fromStatus = typeof status?.sessionID === "string" ? status.sessionID : undefined + return fromStatus +} + +function getStatusTypeFromEvent(input: { event: { properties?: Record } }): string | undefined { + const props = input.event.properties as Record | undefined + const directType = typeof props?.type === "string" ? props.type : undefined + if (directType) { + return directType + } + + const status = props?.status as Record | undefined + const statusType = typeof status?.type === "string" ? status.type : undefined + return statusType +} export function createAgentSwitchHook(ctx: PluginInput) { return { event: async (input: { event: { type: string; properties?: Record } }): Promise => { - if (input.event.type !== "session.idle") return + if (input.event.type === "session.deleted") { + const props = input.event.properties as Record | undefined + const info = props?.info as Record | undefined + const deletedSessionID = info?.id + if (typeof deletedSessionID === "string") { + clearPendingSwitchRuntime(deletedSessionID) + for (const key of Array.from(processedFallbackMessages)) { + if (key.startsWith(`${deletedSessionID}:`)) { + processedFallbackMessages.delete(key) + } + } + } + return + } - const props = input.event.properties as Record | undefined - const sessionID = props?.sessionID as string | undefined - if (!sessionID) return + if (input.event.type === "message.updated") { + const props = input.event.properties as Record | undefined + const info = props?.info as Record | undefined + const sessionID = typeof info?.sessionID === "string" ? info.sessionID : undefined + const messageID = typeof info?.id === "string" ? info.id : undefined + const agent = typeof info?.agent === "string" ? info.agent : undefined + const finish = info?.finish - const pending = consumePendingSwitch(sessionID) - if (!pending) return + if (!sessionID) { + return + } - log(`[${HOOK_NAME}] Switching to ${pending.agent}`, { sessionID }) + const isTerminalAssistantUpdate = isTerminalFinishValue(finish) + if (!isTerminalAssistantUpdate) { + return + } - try { - await ctx.client.session.promptAsync({ - path: { id: sessionID }, - body: { - agent: pending.agent, - parts: [{ type: "text", text: pending.context }], - }, - query: { directory: ctx.directory }, + // Primary path: if switch_agent queued a pending switch, apply it as soon as + // assistant turn is terminal (no reliance on session.idle timing). + if (getPendingSwitch(sessionID)) { + await applyPendingSwitch({ + sessionID, + client: ctx.client, + source: "message-updated", + }) + return + } + + if (!messageID) { + return + } + + if (getAgentConfigKey(agent ?? "") !== "athena") { + return + } + + const marker = `${sessionID}:${messageID}` + if (processedFallbackMessages.has(marker)) { + return + } + processedFallbackMessages.add(marker) + + // If switch_agent already queued a handoff, do not synthesize fallback behavior. + if (getPendingSwitch(sessionID)) { + return + } + + try { + const response = await ctx.client.session.message({ + path: { id: sessionID, messageID }, + }) + const text = extractTextPartsFromMessageResponse(response) + const target = detectFallbackHandoffTarget(text) + if (!target) { + return + } + + setPendingSwitch(sessionID, target, buildFallbackContext(target)) + log("[agent-switch] Recovered missing switch_agent tool call from Athena handoff text", { + sessionID, + messageID, + target, + }) + + await applyPendingSwitch({ + sessionID, + client: ctx.client, + source: "athena-message-fallback", + }) + } catch (error) { + log("[agent-switch] Failed to recover fallback handoff from Athena message", { + sessionID, + messageID, + error: String(error), + }) + } + + return + } + + if (input.event.type === "message.part.updated") { + const props = input.event.properties as Record | undefined + const part = props?.part + const info = props?.info as Record | undefined + const sessionIDFromPart = typeof (part as Record | undefined)?.sessionID === "string" + ? ((part as Record).sessionID as string) + : undefined + const sessionIDFromInfo = typeof info?.sessionID === "string" ? info.sessionID : undefined + const sessionID = sessionIDFromPart ?? sessionIDFromInfo + if (!sessionID) { + return + } + + if (!isTerminalStepFinishPart(part)) { + return + } + + if (!getPendingSwitch(sessionID)) { + return + } + + await applyPendingSwitch({ + sessionID, + client: ctx.client, + source: "message-part-step-finish", }) + return + } - log(`[${HOOK_NAME}] Switch to ${pending.agent} complete`, { sessionID }) - } catch (err) { - log(`[${HOOK_NAME}] Switch failed`, { sessionID, error: String(err) }) + if (input.event.type === "session.idle") { + const props = input.event.properties as Record | undefined + const sessionID = props?.sessionID as string | undefined + if (!sessionID) return + + await applyPendingSwitch({ + sessionID, + client: ctx.client, + source: "idle", + }) + return + } + + if (input.event.type === "session.status") { + const sessionID = getSessionIDFromStatusEvent(input) + const statusType = getStatusTypeFromEvent(input) + if (!sessionID || statusType !== "idle") { + return + } + + await applyPendingSwitch({ + sessionID, + client: ctx.client, + source: "status-idle", + }) } }, } diff --git a/src/plugin/event.ts b/src/plugin/event.ts index f50344516..a6015cd18 100644 --- a/src/plugin/event.ts +++ b/src/plugin/event.ts @@ -10,17 +10,6 @@ import { syncSubagentSessions, updateSessionAgent, } from "../features/claude-code-session-state"; -import { - clearPendingModelFallback, - clearSessionFallbackChain, - setPendingModelFallback, -} from "../hooks/model-fallback/hook"; -import { resetMessageCursor } from "../shared"; -import { log } from "../shared/logger"; -import { shouldRetryError } from "../shared/model-error-classifier"; -import { clearSessionModel, setSessionModel } from "../shared/session-model-state"; -import { deleteSessionTools } from "../shared/session-tools-store"; -import { lspManager } from "../tools"; import type { CreatedHooks } from "../create-hooks"; import type { Managers } from "../create-managers"; @@ -134,29 +123,42 @@ export function createEventHandler(args: { const lastHandledRetryStatusKey = new Map(); const lastKnownModelBySession = new Map(); + async function runHookSafely(hookName: string, runner: () => Promise): Promise { + try { + await runner() + } catch (error) { + log("[event] Hook execution failed", { + hookName, + error: String(error), + }) + } + } + const dispatchToHooks = async (input: EventInput): Promise => { - await Promise.resolve(hooks.autoUpdateChecker?.event?.(input)); - await Promise.resolve(hooks.claudeCodeHooks?.event?.(input)); - await Promise.resolve(hooks.backgroundNotificationHook?.event?.(input)); - await Promise.resolve(hooks.sessionNotification?.(input)); - await Promise.resolve(hooks.todoContinuationEnforcer?.handler?.(input)); - await Promise.resolve(hooks.unstableAgentBabysitter?.event?.(input)); - await Promise.resolve(hooks.contextWindowMonitor?.event?.(input)); - await Promise.resolve(hooks.directoryAgentsInjector?.event?.(input)); - await Promise.resolve(hooks.directoryReadmeInjector?.event?.(input)); - await Promise.resolve(hooks.rulesInjector?.event?.(input)); - await Promise.resolve(hooks.thinkMode?.event?.(input)); - await Promise.resolve(hooks.anthropicContextWindowLimitRecovery?.event?.(input)); - await Promise.resolve(hooks.runtimeFallback?.event?.(input)); - await Promise.resolve(hooks.agentUsageReminder?.event?.(input)); - await Promise.resolve(hooks.categorySkillReminder?.event?.(input)); - await Promise.resolve(hooks.interactiveBashSession?.event?.(input as EventInput)); - await Promise.resolve(hooks.ralphLoop?.event?.(input)); - await Promise.resolve(hooks.stopContinuationGuard?.event?.(input)); - await Promise.resolve(hooks.compactionTodoPreserver?.event?.(input)); - await Promise.resolve(hooks.writeExistingFileGuard?.event?.(input)); - await Promise.resolve(hooks.atlasHook?.handler?.(input)); - await Promise.resolve(hooks.agentSwitchHook?.event?.(input)); + // Keep agent switch early and resilient so queued handoffs are not blocked + // by unrelated hook failures in the same idle cycle. + await runHookSafely("agent-switch", () => Promise.resolve(hooks.agentSwitchHook?.event?.(input))); + await runHookSafely("auto-update-checker", () => Promise.resolve(hooks.autoUpdateChecker?.event?.(input))); + await runHookSafely("claude-code-hooks", () => Promise.resolve(hooks.claudeCodeHooks?.event?.(input))); + await runHookSafely("background-notification", () => Promise.resolve(hooks.backgroundNotificationHook?.event?.(input))); + await runHookSafely("session-notification", () => Promise.resolve(hooks.sessionNotification?.(input))); + await runHookSafely("todo-continuation-enforcer", () => Promise.resolve(hooks.todoContinuationEnforcer?.handler?.(input))); + await runHookSafely("unstable-agent-babysitter", () => Promise.resolve(hooks.unstableAgentBabysitter?.event?.(input))); + await runHookSafely("context-window-monitor", () => Promise.resolve(hooks.contextWindowMonitor?.event?.(input))); + await runHookSafely("directory-agents-injector", () => Promise.resolve(hooks.directoryAgentsInjector?.event?.(input))); + await runHookSafely("directory-readme-injector", () => Promise.resolve(hooks.directoryReadmeInjector?.event?.(input))); + await runHookSafely("rules-injector", () => Promise.resolve(hooks.rulesInjector?.event?.(input))); + await runHookSafely("think-mode", () => Promise.resolve(hooks.thinkMode?.event?.(input))); + await runHookSafely("anthropic-context-window-limit-recovery", () => Promise.resolve(hooks.anthropicContextWindowLimitRecovery?.event?.(input))); + await runHookSafely("runtime-fallback", () => Promise.resolve(hooks.runtimeFallback?.event?.(input))); + await runHookSafely("agent-usage-reminder", () => Promise.resolve(hooks.agentUsageReminder?.event?.(input))); + await runHookSafely("category-skill-reminder", () => Promise.resolve(hooks.categorySkillReminder?.event?.(input))); + await runHookSafely("interactive-bash-session", () => Promise.resolve(hooks.interactiveBashSession?.event?.(input as EventInput))); + await runHookSafely("ralph-loop", () => Promise.resolve(hooks.ralphLoop?.event?.(input))); + await runHookSafely("stop-continuation-guard", () => Promise.resolve(hooks.stopContinuationGuard?.event?.(input))); + await runHookSafely("compaction-todo-preserver", () => Promise.resolve(hooks.compactionTodoPreserver?.event?.(input))); + await runHookSafely("write-existing-file-guard", () => Promise.resolve(hooks.writeExistingFileGuard?.event?.(input))); + await runHookSafely("atlas", () => Promise.resolve(hooks.atlasHook?.handler?.(input))); }; const recentSyntheticIdles = new Map(); diff --git a/src/plugin/tool-registry.ts b/src/plugin/tool-registry.ts index 787afaf3a..98fe73f35 100644 --- a/src/plugin/tool-registry.ts +++ b/src/plugin/tool-registry.ts @@ -55,7 +55,6 @@ export function createToolRegistry(args: { const athenaCouncilTool = createAthenaCouncilTool({ backgroundManager: managers.backgroundManager, councilConfig: athenaCouncilConfig, - client: ctx.client, }) const isMultimodalLookerEnabled = !(pluginConfig.disabled_agents ?? []).some( @@ -135,7 +134,9 @@ export function createToolRegistry(args: { ...backgroundTools, call_omo_agent: callOmoAgent, athena_council: athenaCouncilTool, - switch_agent: createSwitchAgentTool(), + switch_agent: createSwitchAgentTool({ + client: ctx.client, + }), ...(lookAt ? { look_at: lookAt } : {}), task: delegateTask, skill_mcp: skillMcpTool, diff --git a/src/tools/switch-agent/tools.test.ts b/src/tools/switch-agent/tools.test.ts index 0b97da2a4..2a91780aa 100644 --- a/src/tools/switch-agent/tools.test.ts +++ b/src/tools/switch-agent/tools.test.ts @@ -1,3 +1,5 @@ +/// + import { describe, test, expect, beforeEach } from "bun:test" import { createSwitchAgentTool } from "./tools" import { consumePendingSwitch, _resetForTesting as resetSwitch } from "../../features/agent-switch" @@ -20,11 +22,36 @@ describe("switch_agent tool", () => { resetSession() }) + function createToolWithMockClient(promptImpl?: () => Promise) { + const client = { + session: { + promptAsync: + promptImpl ?? + (async () => { + return undefined + }), + messages: async () => ({ data: [] }), + }, + } + + return createSwitchAgentTool({ + client: client as unknown as { + session: { + promptAsync: (input: { + path: { id: string } + body: { agent: string; parts: Array<{ type: "text"; text: string }> } + }) => Promise + messages: (input: { path: { id: string } }) => Promise + } + }, + }) + } + //#given valid atlas switch args //#when execute is called //#then it stores pending switch and updates session agent test("should queue switch to atlas", async () => { - const tool = createSwitchAgentTool() + const tool = createToolWithMockClient() const result = await tool.execute( { agent: "atlas", context: "Fix the auth bug based on council findings" }, toolContext @@ -46,7 +73,7 @@ describe("switch_agent tool", () => { //#when execute is called //#then it stores pending switch for prometheus test("should queue switch to prometheus", async () => { - const tool = createSwitchAgentTool() + const tool = createToolWithMockClient() const result = await tool.execute( { agent: "Prometheus", context: "Create a plan for the refactoring" }, toolContext @@ -63,7 +90,7 @@ describe("switch_agent tool", () => { //#when execute is called //#then it returns an error test("should reject invalid agent names", async () => { - const tool = createSwitchAgentTool() + const tool = createToolWithMockClient() const result = await tool.execute( { agent: "librarian", context: "Some context" }, toolContext @@ -78,7 +105,7 @@ describe("switch_agent tool", () => { //#when execute is called //#then it normalizes to lowercase test("should handle case-insensitive agent names", async () => { - const tool = createSwitchAgentTool() + const tool = createToolWithMockClient() await tool.execute( { agent: "ATLAS", context: "Fix things" }, toolContext diff --git a/src/tools/switch-agent/tools.ts b/src/tools/switch-agent/tools.ts index ea0c00853..239c894be 100644 --- a/src/tools/switch-agent/tools.ts +++ b/src/tools/switch-agent/tools.ts @@ -1,5 +1,6 @@ import { tool, type ToolDefinition } from "@opencode-ai/plugin" import { setPendingSwitch } from "../../features/agent-switch" +import { schedulePendingSwitchApply } from "../../features/agent-switch/scheduler" import { updateSessionAgent } from "../../features/claude-code-session-state" import type { SwitchAgentArgs } from "./types" @@ -10,7 +11,26 @@ const DESCRIPTION = const ALLOWED_AGENTS = new Set(["atlas", "prometheus", "sisyphus", "hephaestus"]) -export function createSwitchAgentTool(): ToolDefinition { +type SessionClient = { + session: { + prompt?: (input: { + path: { id: string } + body: { agent: string; parts: Array<{ type: "text"; text: string }> } + }) => Promise + promptAsync: (input: { + path: { id: string } + body: { agent: string; parts: Array<{ type: "text"; text: string }> } + }) => Promise + messages: (input: { path: { id: string } }) => Promise + status?: () => Promise + } +} + +export function createSwitchAgentTool(args: { + client: SessionClient +}): ToolDefinition { + const { client } = args + return tool({ description: DESCRIPTION, args: { @@ -30,6 +50,10 @@ export function createSwitchAgentTool(): ToolDefinition { updateSessionAgent(toolContext.sessionID, agentName) setPendingSwitch(toolContext.sessionID, agentName, args.context) + schedulePendingSwitchApply({ + sessionID: toolContext.sessionID, + client, + }) return `Agent switch queued. Session will switch to ${agentName} when your turn completes.` },