diff --git a/src/hooks/todo-continuation-enforcer.test.ts b/src/hooks/todo-continuation-enforcer.test.ts index 8f6c6f7e4..32e28bf2b 100644 --- a/src/hooks/todo-continuation-enforcer.test.ts +++ b/src/hooks/todo-continuation-enforcer.test.ts @@ -349,6 +349,25 @@ describe("todo-continuation-enforcer", () => { expect(promptCalls).toHaveLength(0) }) + test("should accept skipAgents option without error", async () => { + // #given - session with skipAgents configured for Prometheus + const sessionID = "main-prometheus-option" + setMainSession(sessionID) + + // #when - create hook with skipAgents option (should not throw) + const hook = createTodoContinuationEnforcer(createMockPluginInput(), { + skipAgents: ["Prometheus (Planner)", "custom-agent"], + }) + + // #then - handler works without error + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 100)) + expect(toastCalls.length).toBeGreaterThanOrEqual(1) + }) + test("should show countdown toast updates", async () => { // #given - session with incomplete todos const sessionID = "main-toast" diff --git a/src/hooks/todo-continuation-enforcer.ts b/src/hooks/todo-continuation-enforcer.ts index 5e16354d7..9f843f06c 100644 --- a/src/hooks/todo-continuation-enforcer.ts +++ b/src/hooks/todo-continuation-enforcer.ts @@ -11,8 +11,11 @@ import { log } from "../shared/logger" const HOOK_NAME = "todo-continuation-enforcer" +const DEFAULT_SKIP_AGENTS = ["Prometheus (Planner)"] + export interface TodoContinuationEnforcerOptions { backgroundManager?: BackgroundManager + skipAgents?: string[] } export interface TodoContinuationEnforcer { @@ -89,7 +92,7 @@ export function createTodoContinuationEnforcer( ctx: PluginInput, options: TodoContinuationEnforcerOptions = {} ): TodoContinuationEnforcer { - const { backgroundManager } = options + const { backgroundManager, skipAgents = DEFAULT_SKIP_AGENTS } = options const sessions = new Map() function getState(sessionID: string): SessionState { @@ -184,17 +187,19 @@ export function createTodoContinuationEnforcer( const messageDir = getMessageDir(sessionID) const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null - const hasWritePermission = !prevMessage?.tools || - (prevMessage.tools.write !== false && prevMessage.tools.edit !== false) - - if (!hasWritePermission) { - log(`[${HOOK_NAME}] Skipped: agent lacks write permission`, { sessionID, agent: prevMessage?.agent }) + const agentName = prevMessage?.agent + if (agentName && skipAgents.includes(agentName)) { + log(`[${HOOK_NAME}] Skipped: agent in skipAgents list`, { sessionID, agent: agentName }) return } - const agentName = prevMessage?.agent?.toLowerCase() ?? "" - if (agentName === "plan" || agentName === "planner-sisyphus") { - log(`[${HOOK_NAME}] Skipped: plan mode agent`, { sessionID, agent: prevMessage?.agent }) + const editPermission = prevMessage?.tools?.edit + const writePermission = prevMessage?.tools?.write + const hasWritePermission = !prevMessage?.tools || + ((editPermission !== false && editPermission !== "deny") && + (writePermission !== false && writePermission !== "deny")) + if (!hasWritePermission) { + log(`[${HOOK_NAME}] Skipped: agent lacks write permission`, { sessionID, agent: prevMessage?.agent }) return } @@ -324,6 +329,28 @@ export function createTodoContinuationEnforcer( return } + let agentName: string | undefined + try { + const messagesResp = await ctx.client.session.messages({ + path: { id: sessionID }, + }) + const messages = (messagesResp.data ?? []) as Array<{ info?: { agent?: string } }> + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].info?.agent) { + agentName = messages[i].info?.agent + break + } + } + } catch (err) { + log(`[${HOOK_NAME}] Failed to fetch messages for agent check`, { sessionID, error: String(err) }) + } + + log(`[${HOOK_NAME}] Agent check`, { sessionID, agentName, skipAgents }) + if (agentName && skipAgents.includes(agentName)) { + log(`[${HOOK_NAME}] Skipped: agent in skipAgents list`, { sessionID, agent: agentName }) + return + } + startCountdown(sessionID, incompleteCount, todos.length) return }