fix(agent-switch): make handoff durable and sync CLI TUI selection
This commit is contained in:
226
src/features/agent-switch/applier.test.ts
Normal file
226
src/features/agent-switch/applier.test.ts
Normal file
@@ -0,0 +1,226 @@
|
||||
/// <reference types="bun-types" />
|
||||
|
||||
import { beforeEach, describe, expect, test } from "bun:test"
|
||||
import { _resetForTesting, getPendingSwitch, setPendingSwitch } from "./state"
|
||||
import {
|
||||
_resetApplierForTesting,
|
||||
applyPendingSwitch,
|
||||
clearPendingSwitchRuntime,
|
||||
} from "./applier"
|
||||
import { schedulePendingSwitchApply } from "./scheduler"
|
||||
|
||||
describe("agent-switch applier", () => {
|
||||
beforeEach(() => {
|
||||
_resetForTesting()
|
||||
_resetApplierForTesting()
|
||||
})
|
||||
|
||||
test("scheduled apply works without idle event", async () => {
|
||||
const calls: string[] = []
|
||||
let switched = false
|
||||
const client = {
|
||||
session: {
|
||||
promptAsync: async (input: { body: { agent: string } }) => {
|
||||
calls.push(input.body.agent)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] })
|
||||
: ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
setPendingSwitch("ses-1", "prometheus", "create plan")
|
||||
schedulePendingSwitchApply({
|
||||
sessionID: "ses-1",
|
||||
client: client as any,
|
||||
})
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 300))
|
||||
|
||||
expect(calls).toEqual(["Prometheus (Plan Builder)"])
|
||||
expect(getPendingSwitch("ses-1")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("normalizes pending agent to canonical prompt display name", async () => {
|
||||
const calls: string[] = []
|
||||
let switched = false
|
||||
const client = {
|
||||
session: {
|
||||
promptAsync: async (input: { body: { agent: string } }) => {
|
||||
calls.push(input.body.agent)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] })
|
||||
: ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
setPendingSwitch("ses-2", "Prometheus (Plan Builder)", "create plan")
|
||||
await applyPendingSwitch({
|
||||
sessionID: "ses-2",
|
||||
client: client as any,
|
||||
source: "idle",
|
||||
})
|
||||
|
||||
expect(calls).toEqual(["Prometheus (Plan Builder)"])
|
||||
expect(getPendingSwitch("ses-2")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("retries transient failures and eventually clears pending switch", async () => {
|
||||
let attempts = 0
|
||||
let switched = false
|
||||
const client = {
|
||||
session: {
|
||||
promptAsync: async () => {
|
||||
attempts += 1
|
||||
if (attempts < 3) {
|
||||
throw new Error("temporary failure")
|
||||
}
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] })
|
||||
: ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
setPendingSwitch("ses-3", "atlas", "fix this")
|
||||
await applyPendingSwitch({
|
||||
sessionID: "ses-3",
|
||||
client: client as any,
|
||||
source: "idle",
|
||||
})
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 800))
|
||||
|
||||
expect(attempts).toBe(3)
|
||||
expect(getPendingSwitch("ses-3")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("waits for session idle before applying switch", async () => {
|
||||
let statusChecks = 0
|
||||
let promptCalls = 0
|
||||
let switched = false
|
||||
const client = {
|
||||
session: {
|
||||
status: async () => {
|
||||
statusChecks += 1
|
||||
return {
|
||||
"ses-5": { type: statusChecks < 3 ? "running" : "idle" },
|
||||
}
|
||||
},
|
||||
promptAsync: async () => {
|
||||
promptCalls += 1
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] })
|
||||
: ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
setPendingSwitch("ses-5", "atlas", "fix now")
|
||||
await applyPendingSwitch({
|
||||
sessionID: "ses-5",
|
||||
client: client as any,
|
||||
source: "idle",
|
||||
})
|
||||
|
||||
expect(statusChecks).toBeGreaterThanOrEqual(3)
|
||||
expect(promptCalls).toBe(1)
|
||||
expect(getPendingSwitch("ses-5")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("clearPendingSwitchRuntime cancels pending retries", async () => {
|
||||
let attempts = 0
|
||||
const client = {
|
||||
session: {
|
||||
promptAsync: async () => {
|
||||
attempts += 1
|
||||
throw new Error("always failing")
|
||||
},
|
||||
messages: async () => ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
setPendingSwitch("ses-4", "atlas", "fix this")
|
||||
await applyPendingSwitch({
|
||||
sessionID: "ses-4",
|
||||
client: client as any,
|
||||
source: "idle",
|
||||
})
|
||||
|
||||
clearPendingSwitchRuntime("ses-4")
|
||||
|
||||
const attemptsAfterClear = attempts
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 300))
|
||||
|
||||
expect(attempts).toBe(attemptsAfterClear)
|
||||
expect(getPendingSwitch("ses-4")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("syncs CLI TUI agent selection for athena-to-atlas handoff", async () => {
|
||||
const originalClientEnv = process.env["OPENCODE_CLIENT"]
|
||||
process.env["OPENCODE_CLIENT"] = "cli"
|
||||
|
||||
try {
|
||||
const promptCalls: string[] = []
|
||||
const tuiCommands: string[] = []
|
||||
let switched = false
|
||||
const client = {
|
||||
session: {
|
||||
promptAsync: async (input: { body: { agent: string } }) => {
|
||||
promptCalls.push(input.body.agent)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({
|
||||
data: [
|
||||
{ info: { role: "user", agent: "Athena (Council)" } },
|
||||
{ info: { role: "user", agent: "Atlas (Plan Executor)" } },
|
||||
],
|
||||
})
|
||||
: ({
|
||||
data: [{ info: { role: "user", agent: "Athena (Council)" } }],
|
||||
}),
|
||||
},
|
||||
app: {
|
||||
agents: async () => ({
|
||||
data: [
|
||||
{ name: "Sisyphus (Ultraworker)", mode: "primary" },
|
||||
{ name: "Hephaestus (Deep Agent)", mode: "primary" },
|
||||
{ name: "Prometheus (Plan Builder)", mode: "primary" },
|
||||
{ name: "Atlas (Plan Executor)", mode: "primary" },
|
||||
{ name: "Athena (Council)", mode: "primary" },
|
||||
],
|
||||
}),
|
||||
},
|
||||
tui: {
|
||||
publish: async (input: { body: { properties: { command: string } } }) => {
|
||||
tuiCommands.push(input.body.properties.command)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
setPendingSwitch("ses-6", "atlas", "fix now")
|
||||
await applyPendingSwitch({
|
||||
sessionID: "ses-6",
|
||||
client: client as any,
|
||||
source: "message-updated",
|
||||
})
|
||||
|
||||
expect(promptCalls).toEqual(["Atlas (Plan Executor)"])
|
||||
expect(tuiCommands).toEqual(["agent.cycle.reverse"])
|
||||
expect(getPendingSwitch("ses-6")).toBeUndefined()
|
||||
} finally {
|
||||
if (originalClientEnv === undefined) {
|
||||
delete process.env["OPENCODE_CLIENT"]
|
||||
} else {
|
||||
process.env["OPENCODE_CLIENT"] = originalClientEnv
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
211
src/features/agent-switch/applier.ts
Normal file
211
src/features/agent-switch/applier.ts
Normal file
@@ -0,0 +1,211 @@
|
||||
import { normalizeAgentForPrompt } from "../../shared/agent-display-names"
|
||||
import { log } from "../../shared/logger"
|
||||
import { clearPendingSwitch, getPendingSwitch } from "./state"
|
||||
import { waitForSessionIdle } from "./session-status"
|
||||
import { fetchMessages, shouldClearAsAlreadyApplied, verifySwitchObserved } from "./apply-verification"
|
||||
import { getLatestUserAgent } from "./message-inspection"
|
||||
import { syncCliTuiAgentSelectionAfterSwitch } from "./tui-agent-sync"
|
||||
import {
|
||||
clearInFlight,
|
||||
clearRetryState,
|
||||
isApplyInFlight,
|
||||
markApplyInFlight,
|
||||
resetRetryStateForTesting,
|
||||
scheduleRetry,
|
||||
} from "./retry-state"
|
||||
|
||||
type SessionClient = {
|
||||
session: {
|
||||
prompt?: (input: {
|
||||
path: { id: string }
|
||||
body: { agent: string; parts: Array<{ type: "text"; text: string }> }
|
||||
}) => Promise<unknown>
|
||||
promptAsync: (input: {
|
||||
path: { id: string }
|
||||
body: { agent: string; parts: Array<{ type: "text"; text: string }> }
|
||||
}) => Promise<unknown>
|
||||
messages: (input: { path: { id: string } }) => Promise<unknown>
|
||||
status?: () => Promise<unknown>
|
||||
}
|
||||
app?: {
|
||||
agents?: () => Promise<unknown>
|
||||
}
|
||||
tui?: {
|
||||
publish?: (input: {
|
||||
body: {
|
||||
type: "tui.command.execute"
|
||||
properties: { command: string }
|
||||
}
|
||||
}) => Promise<unknown>
|
||||
}
|
||||
}
|
||||
|
||||
async function tryPromptWithCandidates(args: {
|
||||
client: SessionClient
|
||||
sessionID: string
|
||||
agent: string
|
||||
context: string
|
||||
source: string
|
||||
}): Promise<string> {
|
||||
const { client, sessionID, agent, context, source } = args
|
||||
const targetAgent = normalizeAgentForPrompt(agent)
|
||||
if (!targetAgent) {
|
||||
throw new Error(`invalid target agent for switch prompt: ${agent}`)
|
||||
}
|
||||
|
||||
try {
|
||||
const promptInput = {
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
agent: targetAgent,
|
||||
parts: [{ type: "text" as const, text: context }],
|
||||
},
|
||||
}
|
||||
|
||||
if (client.session.prompt) {
|
||||
await client.session.prompt(promptInput)
|
||||
} else {
|
||||
await client.session.promptAsync(promptInput)
|
||||
}
|
||||
|
||||
if (targetAgent !== agent) {
|
||||
log("[agent-switch] Normalized pending switch agent for prompt", {
|
||||
sessionID,
|
||||
source,
|
||||
requestedAgent: agent,
|
||||
usedAgent: targetAgent,
|
||||
})
|
||||
}
|
||||
|
||||
return targetAgent
|
||||
} catch (error) {
|
||||
log("[agent-switch] Prompt attempt failed", {
|
||||
sessionID,
|
||||
source,
|
||||
requestedAgent: agent,
|
||||
attemptedAgent: targetAgent,
|
||||
error: String(error),
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
export async function applyPendingSwitch(args: {
|
||||
sessionID: string
|
||||
client: SessionClient
|
||||
source: string
|
||||
}): Promise<void> {
|
||||
const { sessionID, client, source } = args
|
||||
const pending = getPendingSwitch(sessionID)
|
||||
if (!pending) {
|
||||
clearRetryState(sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
if (isApplyInFlight(sessionID)) {
|
||||
return
|
||||
}
|
||||
|
||||
markApplyInFlight(sessionID)
|
||||
log("[agent-switch] Applying pending switch", {
|
||||
sessionID,
|
||||
source,
|
||||
agent: pending.agent,
|
||||
})
|
||||
|
||||
try {
|
||||
const alreadyApplied = await shouldClearAsAlreadyApplied({
|
||||
client,
|
||||
sessionID,
|
||||
targetAgent: pending.agent,
|
||||
})
|
||||
if (alreadyApplied) {
|
||||
clearPendingSwitch(sessionID)
|
||||
clearRetryState(sessionID)
|
||||
log("[agent-switch] Pending switch already applied by user-turn evidence; clearing state", {
|
||||
sessionID,
|
||||
source,
|
||||
agent: pending.agent,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const idleReady = await waitForSessionIdle({ client, sessionID })
|
||||
if (!idleReady) {
|
||||
throw new Error("session not idle before applying agent switch")
|
||||
}
|
||||
|
||||
const beforeMessages = await fetchMessages({ client, sessionID })
|
||||
const sourceUserAgent = getLatestUserAgent(beforeMessages)
|
||||
|
||||
const usedAgent = await tryPromptWithCandidates({
|
||||
client,
|
||||
sessionID,
|
||||
agent: pending.agent,
|
||||
context: pending.context,
|
||||
source,
|
||||
})
|
||||
|
||||
const verified = await verifySwitchObserved({
|
||||
client,
|
||||
sessionID,
|
||||
targetAgent: pending.agent,
|
||||
baselineCount: beforeMessages.length,
|
||||
})
|
||||
if (!verified) {
|
||||
throw new Error(`agent switch not observed after prompt (attempted ${usedAgent})`)
|
||||
}
|
||||
|
||||
clearPendingSwitch(sessionID)
|
||||
clearRetryState(sessionID)
|
||||
|
||||
await syncCliTuiAgentSelectionAfterSwitch({
|
||||
client,
|
||||
sessionID,
|
||||
source,
|
||||
sourceAgent: sourceUserAgent,
|
||||
targetAgent: pending.agent,
|
||||
})
|
||||
|
||||
log("[agent-switch] Pending switch applied", {
|
||||
sessionID,
|
||||
source,
|
||||
agent: pending.agent,
|
||||
})
|
||||
} catch (error) {
|
||||
clearInFlight(sessionID)
|
||||
log("[agent-switch] Pending switch apply failed", {
|
||||
sessionID,
|
||||
source,
|
||||
error: String(error),
|
||||
})
|
||||
scheduleRetry({
|
||||
sessionID,
|
||||
source,
|
||||
onLimitReached: (attempts) => {
|
||||
log("[agent-switch] Retry limit reached; waiting for next trigger", {
|
||||
sessionID,
|
||||
attempts,
|
||||
source,
|
||||
})
|
||||
},
|
||||
retryFn: (attemptNumber) => {
|
||||
void applyPendingSwitch({
|
||||
sessionID,
|
||||
client,
|
||||
source: `retry:${attemptNumber}`,
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export function clearPendingSwitchRuntime(sessionID: string): void {
|
||||
clearPendingSwitch(sessionID)
|
||||
clearRetryState(sessionID)
|
||||
}
|
||||
|
||||
/** @internal For testing only */
|
||||
export function _resetApplierForTesting(): void {
|
||||
resetRetryStateForTesting()
|
||||
}
|
||||
59
src/features/agent-switch/apply-verification.ts
Normal file
59
src/features/agent-switch/apply-verification.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
import { extractMessageList, hasNewUserTurnForTargetAgent, hasRecentUserTurnForTargetAgent } from "./message-inspection"
|
||||
import { log } from "../../shared/logger"
|
||||
import { sleepWithDelay } from "./session-status"
|
||||
|
||||
type SessionClient = {
|
||||
session: {
|
||||
messages: (input: { path: { id: string } }) => Promise<unknown>
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchMessages(args: {
|
||||
client: SessionClient
|
||||
sessionID: string
|
||||
}): Promise<Array<Record<string, unknown>>> {
|
||||
const response = await args.client.session.messages({ path: { id: args.sessionID } })
|
||||
return extractMessageList(response)
|
||||
}
|
||||
|
||||
export async function verifySwitchObserved(args: {
|
||||
client: SessionClient
|
||||
sessionID: string
|
||||
targetAgent: string
|
||||
baselineCount: number
|
||||
}): Promise<boolean> {
|
||||
const { client, sessionID, targetAgent, baselineCount } = args
|
||||
const delays = [100, 300, 800, 1500] as const
|
||||
|
||||
for (const delay of delays) {
|
||||
await sleepWithDelay(delay)
|
||||
try {
|
||||
const messages = await fetchMessages({ client, sessionID })
|
||||
if (hasNewUserTurnForTargetAgent({ messages, targetAgent, baselineCount })) {
|
||||
return true
|
||||
}
|
||||
} catch (error) {
|
||||
log("[agent-switch] Verification read failed", {
|
||||
sessionID,
|
||||
error: String(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
export async function shouldClearAsAlreadyApplied(args: {
|
||||
client: SessionClient
|
||||
sessionID: string
|
||||
targetAgent: string
|
||||
}): Promise<boolean> {
|
||||
const { client, sessionID, targetAgent } = args
|
||||
|
||||
try {
|
||||
const messages = await fetchMessages({ client, sessionID })
|
||||
return hasRecentUserTurnForTargetAgent({ messages, targetAgent })
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,8 @@
|
||||
export { setPendingSwitch, consumePendingSwitch, _resetForTesting } from "./state"
|
||||
export {
|
||||
setPendingSwitch,
|
||||
getPendingSwitch,
|
||||
clearPendingSwitch,
|
||||
consumePendingSwitch,
|
||||
_resetForTesting,
|
||||
} from "./state"
|
||||
export type { PendingSwitch } from "./state"
|
||||
|
||||
107
src/features/agent-switch/message-inspection.ts
Normal file
107
src/features/agent-switch/message-inspection.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
import { getAgentConfigKey } from "../../shared/agent-display-names"
|
||||
|
||||
export interface MessageRoleAgent {
|
||||
role: string
|
||||
agent: string
|
||||
}
|
||||
|
||||
export function extractMessageList(response: unknown): Array<Record<string, unknown>> {
|
||||
if (Array.isArray(response)) {
|
||||
return response.filter((item): item is Record<string, unknown> => typeof item === "object" && item !== null)
|
||||
}
|
||||
if (typeof response === "object" && response !== null) {
|
||||
const data = (response as Record<string, unknown>).data
|
||||
if (Array.isArray(data)) {
|
||||
return data.filter((item): item is Record<string, unknown> => typeof item === "object" && item !== null)
|
||||
}
|
||||
}
|
||||
return []
|
||||
}
|
||||
|
||||
function getRoleAgent(message: Record<string, unknown>): MessageRoleAgent | undefined {
|
||||
const info = message.info
|
||||
if (typeof info !== "object" || info === null) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const role = (info as Record<string, unknown>).role
|
||||
const agent = (info as Record<string, unknown>).agent
|
||||
if (typeof role !== "string" || typeof agent !== "string") {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return { role, agent }
|
||||
}
|
||||
|
||||
export function getLatestUserAgent(messages: Array<Record<string, unknown>>): string | undefined {
|
||||
for (let index = messages.length - 1; index >= 0; index -= 1) {
|
||||
const message = messages[index]
|
||||
if (!message) {
|
||||
continue
|
||||
}
|
||||
|
||||
const roleAgent = getRoleAgent(message)
|
||||
if (!roleAgent || roleAgent.role !== "user") {
|
||||
continue
|
||||
}
|
||||
|
||||
return roleAgent.agent
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
export function hasRecentUserTurnForTargetAgent(args: {
|
||||
messages: Array<Record<string, unknown>>
|
||||
targetAgent: string
|
||||
lookback?: number
|
||||
}): boolean {
|
||||
const { messages, targetAgent, lookback = 8 } = args
|
||||
const targetKey = getAgentConfigKey(targetAgent)
|
||||
const start = Math.max(0, messages.length - lookback)
|
||||
|
||||
for (let index = messages.length - 1; index >= start; index -= 1) {
|
||||
const message = messages[index]
|
||||
if (!message) {
|
||||
continue
|
||||
}
|
||||
|
||||
const roleAgent = getRoleAgent(message)
|
||||
if (!roleAgent || roleAgent.role !== "user") {
|
||||
continue
|
||||
}
|
||||
|
||||
if (getAgentConfigKey(roleAgent.agent) === targetKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
export function hasNewUserTurnForTargetAgent(args: {
|
||||
messages: Array<Record<string, unknown>>
|
||||
targetAgent: string
|
||||
baselineCount: number
|
||||
}): boolean {
|
||||
const { messages, targetAgent, baselineCount } = args
|
||||
const targetKey = getAgentConfigKey(targetAgent)
|
||||
|
||||
if (messages.length <= baselineCount) {
|
||||
return false
|
||||
}
|
||||
|
||||
const newMessages = messages.slice(Math.max(0, baselineCount))
|
||||
for (const message of newMessages) {
|
||||
const roleAgent = getRoleAgent(message)
|
||||
if (!roleAgent || roleAgent.role !== "user") {
|
||||
continue
|
||||
}
|
||||
|
||||
if (getAgentConfigKey(roleAgent.agent) === targetKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
66
src/features/agent-switch/retry-state.ts
Normal file
66
src/features/agent-switch/retry-state.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
const RETRY_DELAYS_MS = [50, 250, 500, 1000, 2000, 5000] as const
|
||||
|
||||
const inFlightSessions = new Set<string>()
|
||||
const retryAttempts = new Map<string, number>()
|
||||
const retryTimers = new Map<string, ReturnType<typeof setTimeout>>()
|
||||
|
||||
export function isApplyInFlight(sessionID: string): boolean {
|
||||
return inFlightSessions.has(sessionID)
|
||||
}
|
||||
|
||||
export function markApplyInFlight(sessionID: string): void {
|
||||
inFlightSessions.add(sessionID)
|
||||
}
|
||||
|
||||
export function clearRetryState(sessionID: string): void {
|
||||
const timer = retryTimers.get(sessionID)
|
||||
if (timer) {
|
||||
clearTimeout(timer)
|
||||
retryTimers.delete(sessionID)
|
||||
}
|
||||
retryAttempts.delete(sessionID)
|
||||
inFlightSessions.delete(sessionID)
|
||||
}
|
||||
|
||||
export function clearInFlight(sessionID: string): void {
|
||||
inFlightSessions.delete(sessionID)
|
||||
}
|
||||
|
||||
export function scheduleRetry(args: {
|
||||
sessionID: string
|
||||
source: string
|
||||
retryFn: (attemptNumber: number) => void
|
||||
onLimitReached: (attempts: number) => void
|
||||
}): void {
|
||||
const { sessionID, retryFn, onLimitReached } = args
|
||||
const attempts = retryAttempts.get(sessionID) ?? 0
|
||||
if (attempts >= RETRY_DELAYS_MS.length) {
|
||||
onLimitReached(attempts)
|
||||
return
|
||||
}
|
||||
|
||||
const delay = RETRY_DELAYS_MS[attempts]
|
||||
retryAttempts.set(sessionID, attempts + 1)
|
||||
|
||||
const existing = retryTimers.get(sessionID)
|
||||
if (existing) {
|
||||
clearTimeout(existing)
|
||||
}
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
retryTimers.delete(sessionID)
|
||||
retryFn(attempts + 1)
|
||||
}, delay)
|
||||
|
||||
retryTimers.set(sessionID, timer)
|
||||
}
|
||||
|
||||
/** @internal For testing only */
|
||||
export function resetRetryStateForTesting(): void {
|
||||
for (const timer of retryTimers.values()) {
|
||||
clearTimeout(timer)
|
||||
}
|
||||
retryTimers.clear()
|
||||
retryAttempts.clear()
|
||||
inFlightSessions.clear()
|
||||
}
|
||||
43
src/features/agent-switch/scheduler.ts
Normal file
43
src/features/agent-switch/scheduler.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { log } from "../../shared/logger"
|
||||
import { scheduleRetry } from "./retry-state"
|
||||
import { applyPendingSwitch } from "./applier"
|
||||
|
||||
type SessionClient = {
|
||||
session: {
|
||||
prompt?: (input: {
|
||||
path: { id: string }
|
||||
body: { agent: string; parts: Array<{ type: "text"; text: string }> }
|
||||
}) => Promise<unknown>
|
||||
promptAsync: (input: {
|
||||
path: { id: string }
|
||||
body: { agent: string; parts: Array<{ type: "text"; text: string }> }
|
||||
}) => Promise<unknown>
|
||||
messages: (input: { path: { id: string } }) => Promise<unknown>
|
||||
status?: () => Promise<unknown>
|
||||
}
|
||||
}
|
||||
|
||||
export function schedulePendingSwitchApply(args: {
|
||||
sessionID: string
|
||||
client: SessionClient
|
||||
}): void {
|
||||
const { sessionID, client } = args
|
||||
scheduleRetry({
|
||||
sessionID,
|
||||
source: "tool",
|
||||
onLimitReached: (attempts) => {
|
||||
log("[agent-switch] Retry limit reached; waiting for next trigger", {
|
||||
sessionID,
|
||||
attempts,
|
||||
source: "tool",
|
||||
})
|
||||
},
|
||||
retryFn: (attemptNumber) => {
|
||||
void applyPendingSwitch({
|
||||
sessionID,
|
||||
client,
|
||||
source: `retry:${attemptNumber}`,
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
68
src/features/agent-switch/session-status.ts
Normal file
68
src/features/agent-switch/session-status.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
import { log } from "../../shared/logger"
|
||||
|
||||
type SessionClient = {
|
||||
session: {
|
||||
status?: () => Promise<unknown>
|
||||
}
|
||||
}
|
||||
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms))
|
||||
}
|
||||
|
||||
function getSessionStatusType(statusResponse: unknown, sessionID: string): string | undefined {
|
||||
if (typeof statusResponse !== "object" || statusResponse === null) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const root = statusResponse as Record<string, unknown>
|
||||
const data = (typeof root.data === "object" && root.data !== null)
|
||||
? root.data as Record<string, unknown>
|
||||
: root
|
||||
|
||||
const entry = data[sessionID]
|
||||
if (typeof entry !== "object" || entry === null) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const entryType = (entry as Record<string, unknown>).type
|
||||
return typeof entryType === "string" ? entryType : undefined
|
||||
}
|
||||
|
||||
export async function waitForSessionIdle(args: {
|
||||
client: SessionClient
|
||||
sessionID: string
|
||||
timeoutMs?: number
|
||||
}): Promise<boolean> {
|
||||
const { client, sessionID, timeoutMs = 15000 } = args
|
||||
if (!client.session.status) {
|
||||
return true
|
||||
}
|
||||
|
||||
const start = Date.now()
|
||||
while (Date.now() - start < timeoutMs) {
|
||||
try {
|
||||
const statusResponse = await client.session.status()
|
||||
const statusType = getSessionStatusType(statusResponse, sessionID)
|
||||
// /session/status only tracks non-idle sessions in SessionStatus.list().
|
||||
// Missing entry means idle.
|
||||
if (!statusType || statusType === "idle") {
|
||||
return true
|
||||
}
|
||||
} catch (error) {
|
||||
log("[agent-switch] Session status check failed", {
|
||||
sessionID,
|
||||
error: String(error),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
await sleep(200)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
export async function sleepWithDelay(ms: number): Promise<void> {
|
||||
await sleep(ms)
|
||||
}
|
||||
@@ -1,5 +1,11 @@
|
||||
import { describe, test, expect, beforeEach } from "bun:test"
|
||||
import { setPendingSwitch, consumePendingSwitch, _resetForTesting } from "./state"
|
||||
const { describe, test, expect, beforeEach } = require("bun:test")
|
||||
import {
|
||||
setPendingSwitch,
|
||||
getPendingSwitch,
|
||||
clearPendingSwitch,
|
||||
consumePendingSwitch,
|
||||
_resetForTesting,
|
||||
} from "./state"
|
||||
|
||||
describe("agent-switch state", () => {
|
||||
beforeEach(() => {
|
||||
@@ -47,4 +53,21 @@ describe("agent-switch state", () => {
|
||||
expect(consumePendingSwitch("session-1")).toEqual({ agent: "atlas", context: "Fix A" })
|
||||
expect(consumePendingSwitch("session-2")).toEqual({ agent: "prometheus", context: "Plan B" })
|
||||
})
|
||||
|
||||
test("should allow reading without consuming", () => {
|
||||
setPendingSwitch("session-1", "atlas", "Fix A")
|
||||
|
||||
expect(getPendingSwitch("session-1")).toEqual({ agent: "atlas", context: "Fix A" })
|
||||
expect(getPendingSwitch("session-1")).toEqual({ agent: "atlas", context: "Fix A" })
|
||||
})
|
||||
|
||||
test("should clear pending switch explicitly", () => {
|
||||
setPendingSwitch("session-1", "atlas", "Fix A")
|
||||
|
||||
clearPendingSwitch("session-1")
|
||||
|
||||
expect(getPendingSwitch("session-1")).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
export {}
|
||||
|
||||
@@ -1,23 +1,102 @@
|
||||
import { existsSync, readFileSync, rmSync, writeFileSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
import { tmpdir } from "node:os"
|
||||
|
||||
export interface PendingSwitch {
|
||||
agent: string
|
||||
context: string
|
||||
}
|
||||
|
||||
const PENDING_SWITCH_STATE_FILE = process.platform === "win32"
|
||||
? join(tmpdir(), "oh-my-opencode-agent-switch.json")
|
||||
: "/tmp/oh-my-opencode-agent-switch.json"
|
||||
|
||||
const pendingSwitches = new Map<string, PendingSwitch>()
|
||||
|
||||
function isPendingSwitch(value: unknown): value is PendingSwitch {
|
||||
if (typeof value !== "object" || value === null) return false
|
||||
const entry = value as Record<string, unknown>
|
||||
return typeof entry.agent === "string" && typeof entry.context === "string"
|
||||
}
|
||||
|
||||
function readPersistentState(): Record<string, PendingSwitch> {
|
||||
try {
|
||||
if (!existsSync(PENDING_SWITCH_STATE_FILE)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const raw = readFileSync(PENDING_SWITCH_STATE_FILE, "utf8")
|
||||
const parsed = JSON.parse(raw)
|
||||
if (typeof parsed !== "object" || parsed === null) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const state: Record<string, PendingSwitch> = {}
|
||||
for (const [sessionID, value] of Object.entries(parsed)) {
|
||||
if (isPendingSwitch(value)) {
|
||||
state[sessionID] = value
|
||||
}
|
||||
}
|
||||
|
||||
return state
|
||||
} catch {
|
||||
return {}
|
||||
}
|
||||
}
|
||||
|
||||
function writePersistentState(state: Record<string, PendingSwitch>): void {
|
||||
try {
|
||||
const keys = Object.keys(state)
|
||||
if (keys.length === 0) {
|
||||
rmSync(PENDING_SWITCH_STATE_FILE, { force: true })
|
||||
return
|
||||
}
|
||||
|
||||
writeFileSync(PENDING_SWITCH_STATE_FILE, JSON.stringify(state), "utf8")
|
||||
} catch {
|
||||
// ignore persistence errors
|
||||
}
|
||||
}
|
||||
|
||||
export function setPendingSwitch(sessionID: string, agent: string, context: string): void {
|
||||
pendingSwitches.set(sessionID, { agent, context })
|
||||
const entry = { agent, context }
|
||||
pendingSwitches.set(sessionID, entry)
|
||||
|
||||
const state = readPersistentState()
|
||||
state[sessionID] = entry
|
||||
writePersistentState(state)
|
||||
}
|
||||
|
||||
export function getPendingSwitch(sessionID: string): PendingSwitch | undefined {
|
||||
const inMemory = pendingSwitches.get(sessionID)
|
||||
if (inMemory) {
|
||||
return inMemory
|
||||
}
|
||||
|
||||
const state = readPersistentState()
|
||||
const fromDisk = state[sessionID]
|
||||
if (fromDisk) {
|
||||
pendingSwitches.set(sessionID, fromDisk)
|
||||
}
|
||||
return fromDisk
|
||||
}
|
||||
|
||||
export function clearPendingSwitch(sessionID: string): void {
|
||||
pendingSwitches.delete(sessionID)
|
||||
|
||||
const state = readPersistentState()
|
||||
delete state[sessionID]
|
||||
writePersistentState(state)
|
||||
}
|
||||
|
||||
export function consumePendingSwitch(sessionID: string): PendingSwitch | undefined {
|
||||
const entry = pendingSwitches.get(sessionID)
|
||||
if (entry) {
|
||||
pendingSwitches.delete(sessionID)
|
||||
}
|
||||
const entry = getPendingSwitch(sessionID)
|
||||
clearPendingSwitch(sessionID)
|
||||
return entry
|
||||
}
|
||||
|
||||
/** @internal For testing only */
|
||||
export function _resetForTesting(): void {
|
||||
pendingSwitches.clear()
|
||||
rmSync(PENDING_SWITCH_STATE_FILE, { force: true })
|
||||
}
|
||||
|
||||
132
src/features/agent-switch/tui-agent-sync.ts
Normal file
132
src/features/agent-switch/tui-agent-sync.ts
Normal file
@@ -0,0 +1,132 @@
|
||||
import { getAgentConfigKey } from "../../shared/agent-display-names"
|
||||
import { log, normalizeSDKResponse } from "../../shared"
|
||||
|
||||
type TuiClient = {
|
||||
app?: {
|
||||
agents?: () => Promise<unknown>
|
||||
}
|
||||
tui?: {
|
||||
publish?: (input: {
|
||||
body: {
|
||||
type: "tui.command.execute"
|
||||
properties: { command: string }
|
||||
}
|
||||
}) => Promise<unknown>
|
||||
}
|
||||
}
|
||||
|
||||
type AgentInfo = {
|
||||
name?: string
|
||||
mode?: "subagent" | "primary" | "all"
|
||||
hidden?: boolean
|
||||
}
|
||||
|
||||
function isCliClient(): boolean {
|
||||
return (process.env["OPENCODE_CLIENT"] ?? "cli") === "cli"
|
||||
}
|
||||
|
||||
function resolveCyclePlan(args: {
|
||||
orderedAgentNames: string[]
|
||||
sourceAgent: string
|
||||
targetAgent: string
|
||||
}): { command: "agent.cycle" | "agent.cycle.reverse"; steps: number } | undefined {
|
||||
const { orderedAgentNames, sourceAgent, targetAgent } = args
|
||||
if (orderedAgentNames.length < 2) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const orderedKeys = orderedAgentNames.map((name) => getAgentConfigKey(name))
|
||||
const sourceKey = getAgentConfigKey(sourceAgent)
|
||||
const targetKey = getAgentConfigKey(targetAgent)
|
||||
|
||||
const sourceIndex = orderedKeys.indexOf(sourceKey)
|
||||
const targetIndex = orderedKeys.indexOf(targetKey)
|
||||
if (sourceIndex < 0 || targetIndex < 0 || sourceIndex === targetIndex) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const size = orderedKeys.length
|
||||
const forward = (targetIndex - sourceIndex + size) % size
|
||||
const backward = (sourceIndex - targetIndex + size) % size
|
||||
|
||||
if (forward <= backward) {
|
||||
return { command: "agent.cycle", steps: forward }
|
||||
}
|
||||
|
||||
return { command: "agent.cycle.reverse", steps: backward }
|
||||
}
|
||||
|
||||
export async function syncCliTuiAgentSelectionAfterSwitch(args: {
|
||||
client: TuiClient
|
||||
sessionID: string
|
||||
sourceAgent: string | undefined
|
||||
targetAgent: string
|
||||
source: string
|
||||
}): Promise<void> {
|
||||
const { client, sessionID, sourceAgent, targetAgent, source } = args
|
||||
|
||||
if (!isCliClient()) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!sourceAgent || !client.app?.agents || !client.tui?.publish) {
|
||||
return
|
||||
}
|
||||
|
||||
const sourceKey = getAgentConfigKey(sourceAgent)
|
||||
const targetKey = getAgentConfigKey(targetAgent)
|
||||
|
||||
// Scope to Athena handoffs where CLI TUI can show stale local-agent selection.
|
||||
if (sourceKey !== "athena" || (targetKey !== "atlas" && targetKey !== "prometheus")) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await client.app.agents()
|
||||
const agents = normalizeSDKResponse(response, [] as AgentInfo[], {
|
||||
preferResponseOnMissingData: true,
|
||||
})
|
||||
|
||||
const orderedPrimaryAgents = agents
|
||||
.filter((agent) => typeof agent.name === "string" && agent.mode !== "subagent" && agent.hidden !== true)
|
||||
.map((agent) => agent.name as string)
|
||||
|
||||
const plan = resolveCyclePlan({
|
||||
orderedAgentNames: orderedPrimaryAgents,
|
||||
sourceAgent,
|
||||
targetAgent,
|
||||
})
|
||||
|
||||
if (!plan || plan.steps <= 0) {
|
||||
return
|
||||
}
|
||||
|
||||
for (let step = 0; step < plan.steps; step += 1) {
|
||||
await client.tui.publish({
|
||||
body: {
|
||||
type: "tui.command.execute",
|
||||
properties: {
|
||||
command: plan.command,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
log("[agent-switch] Synced CLI TUI local agent after handoff", {
|
||||
sessionID,
|
||||
source,
|
||||
sourceAgent,
|
||||
targetAgent,
|
||||
command: plan.command,
|
||||
steps: plan.steps,
|
||||
})
|
||||
} catch (error) {
|
||||
log("[agent-switch] Failed syncing CLI TUI local agent after handoff", {
|
||||
sessionID,
|
||||
source,
|
||||
sourceAgent,
|
||||
targetAgent,
|
||||
error: String(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
75
src/hooks/agent-switch/fallback-handoff.ts
Normal file
75
src/hooks/agent-switch/fallback-handoff.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
export function isTerminalFinishValue(finish: unknown): boolean {
|
||||
if (typeof finish === "boolean") {
|
||||
return finish
|
||||
}
|
||||
|
||||
if (typeof finish === "string") {
|
||||
const normalized = finish.toLowerCase()
|
||||
return normalized !== "" && normalized !== "tool-calls" && normalized !== "unknown"
|
||||
}
|
||||
|
||||
if (typeof finish === "object" && finish !== null) {
|
||||
const record = finish as Record<string, unknown>
|
||||
const kind = record.type ?? record.reason
|
||||
if (typeof kind === "string") {
|
||||
const normalized = kind.toLowerCase()
|
||||
return normalized !== "" && normalized !== "tool-calls" && normalized !== "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
export function isTerminalStepFinishPart(part: unknown): boolean {
|
||||
if (typeof part !== "object" || part === null) {
|
||||
return false
|
||||
}
|
||||
|
||||
const record = part as Record<string, unknown>
|
||||
if (record.type !== "step-finish") {
|
||||
return false
|
||||
}
|
||||
|
||||
return isTerminalFinishValue(record.reason)
|
||||
}
|
||||
|
||||
export function extractTextPartsFromMessageResponse(response: unknown): string {
|
||||
if (typeof response !== "object" || response === null) return ""
|
||||
const data = (response as Record<string, unknown>).data
|
||||
if (typeof data !== "object" || data === null) return ""
|
||||
const parts = (data as Record<string, unknown>).parts
|
||||
if (!Array.isArray(parts)) return ""
|
||||
|
||||
return parts
|
||||
.map((part) => {
|
||||
if (typeof part !== "object" || part === null) return ""
|
||||
const partRecord = part as Record<string, unknown>
|
||||
if (partRecord.type !== "text") return ""
|
||||
return typeof partRecord.text === "string" ? partRecord.text : ""
|
||||
})
|
||||
.filter((text) => text.length > 0)
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
export function detectFallbackHandoffTarget(messageText: string): "atlas" | "prometheus" | undefined {
|
||||
if (!messageText) return undefined
|
||||
|
||||
const normalized = messageText.toLowerCase()
|
||||
|
||||
if (/switching\s+to\s+\*{0,2}\s*prometheus\b/.test(normalized) || /handing\s+off\s+to\s+\*{0,2}\s*prometheus\b/.test(normalized)) {
|
||||
return "prometheus"
|
||||
}
|
||||
|
||||
if (/switching\s+to\s+\*{0,2}\s*atlas\b/.test(normalized) || /handing\s+off\s+to\s+\*{0,2}\s*atlas\b/.test(normalized)) {
|
||||
return "atlas"
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
export function buildFallbackContext(target: "atlas" | "prometheus"): string {
|
||||
if (target === "prometheus") {
|
||||
return "Athena indicated handoff to Prometheus. Continue from the current session context and produce the requested phased plan based on the council findings already gathered."
|
||||
}
|
||||
return "Athena indicated handoff to Atlas. Continue from the current session context and implement the agreed fixes from the council findings."
|
||||
}
|
||||
358
src/hooks/agent-switch/hook.test.ts
Normal file
358
src/hooks/agent-switch/hook.test.ts
Normal file
@@ -0,0 +1,358 @@
|
||||
/// <reference types="bun-types" />
|
||||
|
||||
import { beforeEach, describe, expect, test } from "bun:test"
|
||||
import { createAgentSwitchHook } from "./hook"
|
||||
import {
|
||||
_resetForTesting,
|
||||
getPendingSwitch,
|
||||
setPendingSwitch,
|
||||
} from "../../features/agent-switch"
|
||||
import { _resetApplierForTesting, clearPendingSwitchRuntime } from "../../features/agent-switch/applier"
|
||||
|
||||
describe("agent-switch hook", () => {
|
||||
beforeEach(() => {
|
||||
_resetForTesting()
|
||||
_resetApplierForTesting()
|
||||
})
|
||||
|
||||
test("consumes pending switch only after successful promptAsync", async () => {
|
||||
const promptAsyncCalls: Array<Record<string, unknown>> = []
|
||||
let switched = false
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async (args: Record<string, unknown>) => {
|
||||
promptAsyncCalls.push(args)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] })
|
||||
: ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-1", "prometheus", "plan this")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.idle",
|
||||
properties: { sessionID: "ses-1" },
|
||||
},
|
||||
})
|
||||
|
||||
expect(promptAsyncCalls).toHaveLength(1)
|
||||
expect(getPendingSwitch("ses-1")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("keeps pending switch when promptAsync fails", async () => {
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async () => {
|
||||
throw new Error("temporary failure")
|
||||
},
|
||||
messages: async () => ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-2", "atlas", "fix this")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.idle",
|
||||
properties: { sessionID: "ses-2" },
|
||||
},
|
||||
})
|
||||
|
||||
expect(getPendingSwitch("ses-2")).toEqual({
|
||||
agent: "atlas",
|
||||
context: "fix this",
|
||||
})
|
||||
|
||||
clearPendingSwitchRuntime("ses-2")
|
||||
})
|
||||
|
||||
test("retries after transient failure and eventually clears pending switch", async () => {
|
||||
let attempts = 0
|
||||
let switched = false
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async () => {
|
||||
attempts += 1
|
||||
if (attempts === 1) {
|
||||
throw new Error("temporary failure")
|
||||
}
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] })
|
||||
: ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-3", "prometheus", "plan this")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.idle",
|
||||
properties: { sessionID: "ses-3" },
|
||||
},
|
||||
})
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 350))
|
||||
|
||||
expect(attempts).toBe(2)
|
||||
expect(getPendingSwitch("ses-3")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("clears pending switch on session.deleted", async () => {
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async () => {},
|
||||
messages: async () => ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-4", "atlas", "fix this")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.deleted",
|
||||
properties: { info: { id: "ses-4" } },
|
||||
},
|
||||
})
|
||||
|
||||
expect(getPendingSwitch("ses-4")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("recovers missing switch_agent tool call from Athena handoff text", async () => {
|
||||
const promptAsyncCalls: Array<Record<string, unknown>> = []
|
||||
let switched = false
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async (args: Record<string, unknown>) => {
|
||||
promptAsyncCalls.push(args)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Prometheus (Plan Builder)" } }] })
|
||||
: ({ data: [] }),
|
||||
message: async () => ({
|
||||
data: {
|
||||
parts: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Switching to **Prometheus** now — they'll take it from here and craft a plan for you!",
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "message.updated",
|
||||
properties: {
|
||||
info: {
|
||||
id: "msg-athena-1",
|
||||
sessionID: "ses-5",
|
||||
role: "assistant",
|
||||
agent: "Athena (Council)",
|
||||
finish: "stop",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(promptAsyncCalls).toHaveLength(1)
|
||||
const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined
|
||||
expect(body?.agent).toBe("Prometheus (Plan Builder)")
|
||||
expect(getPendingSwitch("ses-5")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("applies queued pending switch on terminal message.updated", async () => {
|
||||
const promptAsyncCalls: Array<Record<string, unknown>> = []
|
||||
let switched = false
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async (args: Record<string, unknown>) => {
|
||||
promptAsyncCalls.push(args)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] })
|
||||
: ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-6", "atlas", "fix now")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "message.updated",
|
||||
properties: {
|
||||
info: {
|
||||
id: "msg-6",
|
||||
sessionID: "ses-6",
|
||||
role: "assistant",
|
||||
agent: "Athena (Council)",
|
||||
finish: "stop",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(promptAsyncCalls).toHaveLength(1)
|
||||
const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined
|
||||
expect(body?.agent).toBe("Atlas (Plan Executor)")
|
||||
expect(getPendingSwitch("ses-6")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("applies queued pending switch on terminal message.updated even when role is missing", async () => {
|
||||
const promptAsyncCalls: Array<Record<string, unknown>> = []
|
||||
let switched = false
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async (args: Record<string, unknown>) => {
|
||||
promptAsyncCalls.push(args)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] })
|
||||
: ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-8", "atlas", "fix now")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "message.updated",
|
||||
properties: {
|
||||
info: {
|
||||
id: "msg-8",
|
||||
sessionID: "ses-8",
|
||||
agent: "Athena (Council)",
|
||||
finish: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(promptAsyncCalls).toHaveLength(1)
|
||||
const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined
|
||||
expect(body?.agent).toBe("Atlas (Plan Executor)")
|
||||
expect(getPendingSwitch("ses-8")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("applies queued pending switch on terminal message.part.updated step-finish", async () => {
|
||||
const promptAsyncCalls: Array<Record<string, unknown>> = []
|
||||
let switched = false
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async (args: Record<string, unknown>) => {
|
||||
promptAsyncCalls.push(args)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] })
|
||||
: ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-7", "atlas", "fix now")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "message.part.updated",
|
||||
properties: {
|
||||
info: {
|
||||
sessionID: "ses-7",
|
||||
role: "assistant",
|
||||
},
|
||||
part: {
|
||||
id: "part-finish-1",
|
||||
sessionID: "ses-7",
|
||||
type: "step-finish",
|
||||
reason: "stop",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(promptAsyncCalls).toHaveLength(1)
|
||||
const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined
|
||||
expect(body?.agent).toBe("Atlas (Plan Executor)")
|
||||
expect(getPendingSwitch("ses-7")).toBeUndefined()
|
||||
})
|
||||
|
||||
test("applies queued pending switch on session.status idle", async () => {
|
||||
const promptAsyncCalls: Array<Record<string, unknown>> = []
|
||||
let switched = false
|
||||
const ctx = {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync: async (args: Record<string, unknown>) => {
|
||||
promptAsyncCalls.push(args)
|
||||
switched = true
|
||||
},
|
||||
messages: async () => switched
|
||||
? ({ data: [{ info: { role: "user", agent: "Atlas (Plan Executor)" } }] })
|
||||
: ({ data: [] }),
|
||||
message: async () => ({ data: { parts: [] } }),
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
setPendingSwitch("ses-9", "atlas", "fix now")
|
||||
const hook = createAgentSwitchHook(ctx)
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.status",
|
||||
properties: {
|
||||
sessionID: "ses-9",
|
||||
status: {
|
||||
type: "idle",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(promptAsyncCalls).toHaveLength(1)
|
||||
const body = promptAsyncCalls[0]?.body as { agent?: string } | undefined
|
||||
expect(body?.agent).toBe("Atlas (Plan Executor)")
|
||||
expect(getPendingSwitch("ses-9")).toBeUndefined()
|
||||
})
|
||||
})
|
||||
@@ -1,36 +1,194 @@
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import { consumePendingSwitch } from "../../features/agent-switch"
|
||||
import { getPendingSwitch, setPendingSwitch } from "../../features/agent-switch"
|
||||
import { applyPendingSwitch, clearPendingSwitchRuntime } from "../../features/agent-switch/applier"
|
||||
import { getAgentConfigKey } from "../../shared/agent-display-names"
|
||||
import { log } from "../../shared/logger"
|
||||
import {
|
||||
buildFallbackContext,
|
||||
detectFallbackHandoffTarget,
|
||||
extractTextPartsFromMessageResponse,
|
||||
isTerminalFinishValue,
|
||||
isTerminalStepFinishPart,
|
||||
} from "./fallback-handoff"
|
||||
|
||||
const HOOK_NAME = "agent-switch" as const
|
||||
const processedFallbackMessages = new Set<string>()
|
||||
|
||||
function getSessionIDFromStatusEvent(input: { event: { properties?: Record<string, unknown> } }): string | undefined {
|
||||
const props = input.event.properties as Record<string, unknown> | undefined
|
||||
const fromProps = typeof props?.sessionID === "string" ? props.sessionID : undefined
|
||||
if (fromProps) {
|
||||
return fromProps
|
||||
}
|
||||
|
||||
const status = props?.status as Record<string, unknown> | undefined
|
||||
const fromStatus = typeof status?.sessionID === "string" ? status.sessionID : undefined
|
||||
return fromStatus
|
||||
}
|
||||
|
||||
function getStatusTypeFromEvent(input: { event: { properties?: Record<string, unknown> } }): string | undefined {
|
||||
const props = input.event.properties as Record<string, unknown> | undefined
|
||||
const directType = typeof props?.type === "string" ? props.type : undefined
|
||||
if (directType) {
|
||||
return directType
|
||||
}
|
||||
|
||||
const status = props?.status as Record<string, unknown> | undefined
|
||||
const statusType = typeof status?.type === "string" ? status.type : undefined
|
||||
return statusType
|
||||
}
|
||||
|
||||
export function createAgentSwitchHook(ctx: PluginInput) {
|
||||
return {
|
||||
event: async (input: { event: { type: string; properties?: Record<string, unknown> } }): Promise<void> => {
|
||||
if (input.event.type !== "session.idle") return
|
||||
if (input.event.type === "session.deleted") {
|
||||
const props = input.event.properties as Record<string, unknown> | undefined
|
||||
const info = props?.info as Record<string, unknown> | undefined
|
||||
const deletedSessionID = info?.id
|
||||
if (typeof deletedSessionID === "string") {
|
||||
clearPendingSwitchRuntime(deletedSessionID)
|
||||
for (const key of Array.from(processedFallbackMessages)) {
|
||||
if (key.startsWith(`${deletedSessionID}:`)) {
|
||||
processedFallbackMessages.delete(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const props = input.event.properties as Record<string, unknown> | undefined
|
||||
const sessionID = props?.sessionID as string | undefined
|
||||
if (!sessionID) return
|
||||
if (input.event.type === "message.updated") {
|
||||
const props = input.event.properties as Record<string, unknown> | undefined
|
||||
const info = props?.info as Record<string, unknown> | undefined
|
||||
const sessionID = typeof info?.sessionID === "string" ? info.sessionID : undefined
|
||||
const messageID = typeof info?.id === "string" ? info.id : undefined
|
||||
const agent = typeof info?.agent === "string" ? info.agent : undefined
|
||||
const finish = info?.finish
|
||||
|
||||
const pending = consumePendingSwitch(sessionID)
|
||||
if (!pending) return
|
||||
if (!sessionID) {
|
||||
return
|
||||
}
|
||||
|
||||
log(`[${HOOK_NAME}] Switching to ${pending.agent}`, { sessionID })
|
||||
const isTerminalAssistantUpdate = isTerminalFinishValue(finish)
|
||||
if (!isTerminalAssistantUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
await ctx.client.session.promptAsync({
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
agent: pending.agent,
|
||||
parts: [{ type: "text", text: pending.context }],
|
||||
},
|
||||
query: { directory: ctx.directory },
|
||||
// Primary path: if switch_agent queued a pending switch, apply it as soon as
|
||||
// assistant turn is terminal (no reliance on session.idle timing).
|
||||
if (getPendingSwitch(sessionID)) {
|
||||
await applyPendingSwitch({
|
||||
sessionID,
|
||||
client: ctx.client,
|
||||
source: "message-updated",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if (!messageID) {
|
||||
return
|
||||
}
|
||||
|
||||
if (getAgentConfigKey(agent ?? "") !== "athena") {
|
||||
return
|
||||
}
|
||||
|
||||
const marker = `${sessionID}:${messageID}`
|
||||
if (processedFallbackMessages.has(marker)) {
|
||||
return
|
||||
}
|
||||
processedFallbackMessages.add(marker)
|
||||
|
||||
// If switch_agent already queued a handoff, do not synthesize fallback behavior.
|
||||
if (getPendingSwitch(sessionID)) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await ctx.client.session.message({
|
||||
path: { id: sessionID, messageID },
|
||||
})
|
||||
const text = extractTextPartsFromMessageResponse(response)
|
||||
const target = detectFallbackHandoffTarget(text)
|
||||
if (!target) {
|
||||
return
|
||||
}
|
||||
|
||||
setPendingSwitch(sessionID, target, buildFallbackContext(target))
|
||||
log("[agent-switch] Recovered missing switch_agent tool call from Athena handoff text", {
|
||||
sessionID,
|
||||
messageID,
|
||||
target,
|
||||
})
|
||||
|
||||
await applyPendingSwitch({
|
||||
sessionID,
|
||||
client: ctx.client,
|
||||
source: "athena-message-fallback",
|
||||
})
|
||||
} catch (error) {
|
||||
log("[agent-switch] Failed to recover fallback handoff from Athena message", {
|
||||
sessionID,
|
||||
messageID,
|
||||
error: String(error),
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if (input.event.type === "message.part.updated") {
|
||||
const props = input.event.properties as Record<string, unknown> | undefined
|
||||
const part = props?.part
|
||||
const info = props?.info as Record<string, unknown> | undefined
|
||||
const sessionIDFromPart = typeof (part as Record<string, unknown> | undefined)?.sessionID === "string"
|
||||
? ((part as Record<string, unknown>).sessionID as string)
|
||||
: undefined
|
||||
const sessionIDFromInfo = typeof info?.sessionID === "string" ? info.sessionID : undefined
|
||||
const sessionID = sessionIDFromPart ?? sessionIDFromInfo
|
||||
if (!sessionID) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!isTerminalStepFinishPart(part)) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!getPendingSwitch(sessionID)) {
|
||||
return
|
||||
}
|
||||
|
||||
await applyPendingSwitch({
|
||||
sessionID,
|
||||
client: ctx.client,
|
||||
source: "message-part-step-finish",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
log(`[${HOOK_NAME}] Switch to ${pending.agent} complete`, { sessionID })
|
||||
} catch (err) {
|
||||
log(`[${HOOK_NAME}] Switch failed`, { sessionID, error: String(err) })
|
||||
if (input.event.type === "session.idle") {
|
||||
const props = input.event.properties as Record<string, unknown> | undefined
|
||||
const sessionID = props?.sessionID as string | undefined
|
||||
if (!sessionID) return
|
||||
|
||||
await applyPendingSwitch({
|
||||
sessionID,
|
||||
client: ctx.client,
|
||||
source: "idle",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if (input.event.type === "session.status") {
|
||||
const sessionID = getSessionIDFromStatusEvent(input)
|
||||
const statusType = getStatusTypeFromEvent(input)
|
||||
if (!sessionID || statusType !== "idle") {
|
||||
return
|
||||
}
|
||||
|
||||
await applyPendingSwitch({
|
||||
sessionID,
|
||||
client: ctx.client,
|
||||
source: "status-idle",
|
||||
})
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -10,17 +10,6 @@ import {
|
||||
syncSubagentSessions,
|
||||
updateSessionAgent,
|
||||
} from "../features/claude-code-session-state";
|
||||
import {
|
||||
clearPendingModelFallback,
|
||||
clearSessionFallbackChain,
|
||||
setPendingModelFallback,
|
||||
} from "../hooks/model-fallback/hook";
|
||||
import { resetMessageCursor } from "../shared";
|
||||
import { log } from "../shared/logger";
|
||||
import { shouldRetryError } from "../shared/model-error-classifier";
|
||||
import { clearSessionModel, setSessionModel } from "../shared/session-model-state";
|
||||
import { deleteSessionTools } from "../shared/session-tools-store";
|
||||
import { lspManager } from "../tools";
|
||||
|
||||
import type { CreatedHooks } from "../create-hooks";
|
||||
import type { Managers } from "../create-managers";
|
||||
@@ -134,29 +123,42 @@ export function createEventHandler(args: {
|
||||
const lastHandledRetryStatusKey = new Map<string, string>();
|
||||
const lastKnownModelBySession = new Map<string, { providerID: string; modelID: string }>();
|
||||
|
||||
async function runHookSafely(hookName: string, runner: () => Promise<unknown>): Promise<void> {
|
||||
try {
|
||||
await runner()
|
||||
} catch (error) {
|
||||
log("[event] Hook execution failed", {
|
||||
hookName,
|
||||
error: String(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const dispatchToHooks = async (input: EventInput): Promise<void> => {
|
||||
await Promise.resolve(hooks.autoUpdateChecker?.event?.(input));
|
||||
await Promise.resolve(hooks.claudeCodeHooks?.event?.(input));
|
||||
await Promise.resolve(hooks.backgroundNotificationHook?.event?.(input));
|
||||
await Promise.resolve(hooks.sessionNotification?.(input));
|
||||
await Promise.resolve(hooks.todoContinuationEnforcer?.handler?.(input));
|
||||
await Promise.resolve(hooks.unstableAgentBabysitter?.event?.(input));
|
||||
await Promise.resolve(hooks.contextWindowMonitor?.event?.(input));
|
||||
await Promise.resolve(hooks.directoryAgentsInjector?.event?.(input));
|
||||
await Promise.resolve(hooks.directoryReadmeInjector?.event?.(input));
|
||||
await Promise.resolve(hooks.rulesInjector?.event?.(input));
|
||||
await Promise.resolve(hooks.thinkMode?.event?.(input));
|
||||
await Promise.resolve(hooks.anthropicContextWindowLimitRecovery?.event?.(input));
|
||||
await Promise.resolve(hooks.runtimeFallback?.event?.(input));
|
||||
await Promise.resolve(hooks.agentUsageReminder?.event?.(input));
|
||||
await Promise.resolve(hooks.categorySkillReminder?.event?.(input));
|
||||
await Promise.resolve(hooks.interactiveBashSession?.event?.(input as EventInput));
|
||||
await Promise.resolve(hooks.ralphLoop?.event?.(input));
|
||||
await Promise.resolve(hooks.stopContinuationGuard?.event?.(input));
|
||||
await Promise.resolve(hooks.compactionTodoPreserver?.event?.(input));
|
||||
await Promise.resolve(hooks.writeExistingFileGuard?.event?.(input));
|
||||
await Promise.resolve(hooks.atlasHook?.handler?.(input));
|
||||
await Promise.resolve(hooks.agentSwitchHook?.event?.(input));
|
||||
// Keep agent switch early and resilient so queued handoffs are not blocked
|
||||
// by unrelated hook failures in the same idle cycle.
|
||||
await runHookSafely("agent-switch", () => Promise.resolve(hooks.agentSwitchHook?.event?.(input)));
|
||||
await runHookSafely("auto-update-checker", () => Promise.resolve(hooks.autoUpdateChecker?.event?.(input)));
|
||||
await runHookSafely("claude-code-hooks", () => Promise.resolve(hooks.claudeCodeHooks?.event?.(input)));
|
||||
await runHookSafely("background-notification", () => Promise.resolve(hooks.backgroundNotificationHook?.event?.(input)));
|
||||
await runHookSafely("session-notification", () => Promise.resolve(hooks.sessionNotification?.(input)));
|
||||
await runHookSafely("todo-continuation-enforcer", () => Promise.resolve(hooks.todoContinuationEnforcer?.handler?.(input)));
|
||||
await runHookSafely("unstable-agent-babysitter", () => Promise.resolve(hooks.unstableAgentBabysitter?.event?.(input)));
|
||||
await runHookSafely("context-window-monitor", () => Promise.resolve(hooks.contextWindowMonitor?.event?.(input)));
|
||||
await runHookSafely("directory-agents-injector", () => Promise.resolve(hooks.directoryAgentsInjector?.event?.(input)));
|
||||
await runHookSafely("directory-readme-injector", () => Promise.resolve(hooks.directoryReadmeInjector?.event?.(input)));
|
||||
await runHookSafely("rules-injector", () => Promise.resolve(hooks.rulesInjector?.event?.(input)));
|
||||
await runHookSafely("think-mode", () => Promise.resolve(hooks.thinkMode?.event?.(input)));
|
||||
await runHookSafely("anthropic-context-window-limit-recovery", () => Promise.resolve(hooks.anthropicContextWindowLimitRecovery?.event?.(input)));
|
||||
await runHookSafely("runtime-fallback", () => Promise.resolve(hooks.runtimeFallback?.event?.(input)));
|
||||
await runHookSafely("agent-usage-reminder", () => Promise.resolve(hooks.agentUsageReminder?.event?.(input)));
|
||||
await runHookSafely("category-skill-reminder", () => Promise.resolve(hooks.categorySkillReminder?.event?.(input)));
|
||||
await runHookSafely("interactive-bash-session", () => Promise.resolve(hooks.interactiveBashSession?.event?.(input as EventInput)));
|
||||
await runHookSafely("ralph-loop", () => Promise.resolve(hooks.ralphLoop?.event?.(input)));
|
||||
await runHookSafely("stop-continuation-guard", () => Promise.resolve(hooks.stopContinuationGuard?.event?.(input)));
|
||||
await runHookSafely("compaction-todo-preserver", () => Promise.resolve(hooks.compactionTodoPreserver?.event?.(input)));
|
||||
await runHookSafely("write-existing-file-guard", () => Promise.resolve(hooks.writeExistingFileGuard?.event?.(input)));
|
||||
await runHookSafely("atlas", () => Promise.resolve(hooks.atlasHook?.handler?.(input)));
|
||||
};
|
||||
|
||||
const recentSyntheticIdles = new Map<string, number>();
|
||||
|
||||
@@ -55,7 +55,6 @@ export function createToolRegistry(args: {
|
||||
const athenaCouncilTool = createAthenaCouncilTool({
|
||||
backgroundManager: managers.backgroundManager,
|
||||
councilConfig: athenaCouncilConfig,
|
||||
client: ctx.client,
|
||||
})
|
||||
|
||||
const isMultimodalLookerEnabled = !(pluginConfig.disabled_agents ?? []).some(
|
||||
@@ -135,7 +134,9 @@ export function createToolRegistry(args: {
|
||||
...backgroundTools,
|
||||
call_omo_agent: callOmoAgent,
|
||||
athena_council: athenaCouncilTool,
|
||||
switch_agent: createSwitchAgentTool(),
|
||||
switch_agent: createSwitchAgentTool({
|
||||
client: ctx.client,
|
||||
}),
|
||||
...(lookAt ? { look_at: lookAt } : {}),
|
||||
task: delegateTask,
|
||||
skill_mcp: skillMcpTool,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
/// <reference types="bun-types" />
|
||||
|
||||
import { describe, test, expect, beforeEach } from "bun:test"
|
||||
import { createSwitchAgentTool } from "./tools"
|
||||
import { consumePendingSwitch, _resetForTesting as resetSwitch } from "../../features/agent-switch"
|
||||
@@ -20,11 +22,36 @@ describe("switch_agent tool", () => {
|
||||
resetSession()
|
||||
})
|
||||
|
||||
function createToolWithMockClient(promptImpl?: () => Promise<unknown>) {
|
||||
const client = {
|
||||
session: {
|
||||
promptAsync:
|
||||
promptImpl ??
|
||||
(async () => {
|
||||
return undefined
|
||||
}),
|
||||
messages: async () => ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
return createSwitchAgentTool({
|
||||
client: client as unknown as {
|
||||
session: {
|
||||
promptAsync: (input: {
|
||||
path: { id: string }
|
||||
body: { agent: string; parts: Array<{ type: "text"; text: string }> }
|
||||
}) => Promise<unknown>
|
||||
messages: (input: { path: { id: string } }) => Promise<unknown>
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
//#given valid atlas switch args
|
||||
//#when execute is called
|
||||
//#then it stores pending switch and updates session agent
|
||||
test("should queue switch to atlas", async () => {
|
||||
const tool = createSwitchAgentTool()
|
||||
const tool = createToolWithMockClient()
|
||||
const result = await tool.execute(
|
||||
{ agent: "atlas", context: "Fix the auth bug based on council findings" },
|
||||
toolContext
|
||||
@@ -46,7 +73,7 @@ describe("switch_agent tool", () => {
|
||||
//#when execute is called
|
||||
//#then it stores pending switch for prometheus
|
||||
test("should queue switch to prometheus", async () => {
|
||||
const tool = createSwitchAgentTool()
|
||||
const tool = createToolWithMockClient()
|
||||
const result = await tool.execute(
|
||||
{ agent: "Prometheus", context: "Create a plan for the refactoring" },
|
||||
toolContext
|
||||
@@ -63,7 +90,7 @@ describe("switch_agent tool", () => {
|
||||
//#when execute is called
|
||||
//#then it returns an error
|
||||
test("should reject invalid agent names", async () => {
|
||||
const tool = createSwitchAgentTool()
|
||||
const tool = createToolWithMockClient()
|
||||
const result = await tool.execute(
|
||||
{ agent: "librarian", context: "Some context" },
|
||||
toolContext
|
||||
@@ -78,7 +105,7 @@ describe("switch_agent tool", () => {
|
||||
//#when execute is called
|
||||
//#then it normalizes to lowercase
|
||||
test("should handle case-insensitive agent names", async () => {
|
||||
const tool = createSwitchAgentTool()
|
||||
const tool = createToolWithMockClient()
|
||||
await tool.execute(
|
||||
{ agent: "ATLAS", context: "Fix things" },
|
||||
toolContext
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { tool, type ToolDefinition } from "@opencode-ai/plugin"
|
||||
import { setPendingSwitch } from "../../features/agent-switch"
|
||||
import { schedulePendingSwitchApply } from "../../features/agent-switch/scheduler"
|
||||
import { updateSessionAgent } from "../../features/claude-code-session-state"
|
||||
import type { SwitchAgentArgs } from "./types"
|
||||
|
||||
@@ -10,7 +11,26 @@ const DESCRIPTION =
|
||||
|
||||
const ALLOWED_AGENTS = new Set(["atlas", "prometheus", "sisyphus", "hephaestus"])
|
||||
|
||||
export function createSwitchAgentTool(): ToolDefinition {
|
||||
type SessionClient = {
|
||||
session: {
|
||||
prompt?: (input: {
|
||||
path: { id: string }
|
||||
body: { agent: string; parts: Array<{ type: "text"; text: string }> }
|
||||
}) => Promise<unknown>
|
||||
promptAsync: (input: {
|
||||
path: { id: string }
|
||||
body: { agent: string; parts: Array<{ type: "text"; text: string }> }
|
||||
}) => Promise<unknown>
|
||||
messages: (input: { path: { id: string } }) => Promise<unknown>
|
||||
status?: () => Promise<unknown>
|
||||
}
|
||||
}
|
||||
|
||||
export function createSwitchAgentTool(args: {
|
||||
client: SessionClient
|
||||
}): ToolDefinition {
|
||||
const { client } = args
|
||||
|
||||
return tool({
|
||||
description: DESCRIPTION,
|
||||
args: {
|
||||
@@ -30,6 +50,10 @@ export function createSwitchAgentTool(): ToolDefinition {
|
||||
|
||||
updateSessionAgent(toolContext.sessionID, agentName)
|
||||
setPendingSwitch(toolContext.sessionID, agentName, args.context)
|
||||
schedulePendingSwitchApply({
|
||||
sessionID: toolContext.sessionID,
|
||||
client,
|
||||
})
|
||||
|
||||
return `Agent switch queued. Session will switch to ${agentName} when your turn completes.`
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user