From f4a9d0c3aa3649327978ca211cf3a395197069aa Mon Sep 17 00:00:00 2001 From: YeonGyu-Kim Date: Fri, 6 Feb 2026 11:21:45 +0900 Subject: [PATCH] feat(hooks): implement task-continuation-enforcer with TDD Mirrors todo-continuation-enforcer but reads from file-based task storage instead of OpenCode's todo API. Includes 19 tests covering all skip conditions, abort detection, countdown, and recovery scenarios. --- src/hooks/task-continuation-enforcer.test.ts | 763 +++++++++++++++++++ src/hooks/task-continuation-enforcer.ts | 530 +++++++++++++ 2 files changed, 1293 insertions(+) create mode 100644 src/hooks/task-continuation-enforcer.test.ts create mode 100644 src/hooks/task-continuation-enforcer.ts diff --git a/src/hooks/task-continuation-enforcer.test.ts b/src/hooks/task-continuation-enforcer.test.ts new file mode 100644 index 000000000..1a0cbc75d --- /dev/null +++ b/src/hooks/task-continuation-enforcer.test.ts @@ -0,0 +1,763 @@ +import { afterEach, beforeEach, describe, expect, test } from "bun:test" + +import { mkdtempSync, rmSync, writeFileSync } from "node:fs" +import { tmpdir } from "node:os" +import { join } from "node:path" + +import { BackgroundManager } from "../features/background-agent" +import { setMainSession, subagentSessions, _resetForTesting } from "../features/claude-code-session-state" +import type { OhMyOpenCodeConfig } from "../config/schema" +import { TaskObjectSchema } from "../tools/task/types" +import type { TaskObject } from "../tools/task/types" +import { createTaskContinuationEnforcer } from "./task-continuation-enforcer" + +type TimerCallback = (...args: any[]) => void + +interface FakeTimers { + advanceBy: (ms: number, advanceClock?: boolean) => Promise + restore: () => void +} + +function createFakeTimers(): FakeTimers { + const originalNow = Date.now() + let clockNow = originalNow + let timerNow = 0 + let nextId = 1 + const timers = new Map() + const cleared = new Set() + + const original = { + setTimeout: globalThis.setTimeout, + clearTimeout: globalThis.clearTimeout, + setInterval: globalThis.setInterval, + clearInterval: globalThis.clearInterval, + dateNow: Date.now, + } + + const normalizeDelay = (delay?: number) => { + if (typeof delay !== "number" || !Number.isFinite(delay)) return 0 + return delay < 0 ? 0 : delay + } + + const schedule = (callback: TimerCallback, delay: number | undefined, interval: number | null, args: any[]) => { + const id = nextId++ + timers.set(id, { + id, + time: timerNow + normalizeDelay(delay), + interval, + callback, + args, + }) + return id + } + + const clear = (id: number | undefined) => { + if (typeof id !== "number") return + cleared.add(id) + timers.delete(id) + } + + globalThis.setTimeout = ((callback: TimerCallback, delay?: number, ...args: any[]) => { + return schedule(callback, delay, null, args) as unknown as ReturnType + }) as typeof setTimeout + + globalThis.setInterval = ((callback: TimerCallback, delay?: number, ...args: any[]) => { + const interval = normalizeDelay(delay) + return schedule(callback, delay, interval, args) as unknown as ReturnType + }) as typeof setInterval + + globalThis.clearTimeout = ((id?: number) => { + clear(id) + }) as typeof clearTimeout + + globalThis.clearInterval = ((id?: number) => { + clear(id) + }) as typeof clearInterval + + Date.now = () => clockNow + + const advanceBy = async (ms: number, advanceClock: boolean = false) => { + const clamped = Math.max(0, ms) + const target = timerNow + clamped + if (advanceClock) { + clockNow += clamped + } + while (true) { + let next: { id: number; time: number; interval: number | null; callback: TimerCallback; args: any[] } | undefined + for (const timer of timers.values()) { + if (timer.time <= target && (!next || timer.time < next.time)) { + next = timer + } + } + if (!next) break + + timerNow = next.time + timers.delete(next.id) + next.callback(...next.args) + + if (next.interval !== null && !cleared.has(next.id)) { + timers.set(next.id, { + id: next.id, + time: timerNow + next.interval, + interval: next.interval, + callback: next.callback, + args: next.args, + }) + } else { + cleared.delete(next.id) + } + + await Promise.resolve() + } + timerNow = target + await Promise.resolve() + } + + const restore = () => { + globalThis.setTimeout = original.setTimeout + globalThis.clearTimeout = original.clearTimeout + globalThis.setInterval = original.setInterval + globalThis.clearInterval = original.clearInterval + Date.now = original.dateNow + } + + return { advanceBy, restore } +} + +const wait = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)) + +describe("task-continuation-enforcer", () => { + let promptCalls: Array<{ sessionID: string; agent?: string; model?: { providerID?: string; modelID?: string }; text: string }> + let toastCalls: Array<{ title: string; message: string }> + let fakeTimers: FakeTimers + let taskDir: string + + interface MockMessage { + info: { + id: string + role: "user" | "assistant" + error?: { name: string; data?: { message: string } } + } + } + + let mockMessages: MockMessage[] = [] + + function createMockPluginInput() { + return { + client: { + session: { + messages: async () => ({ data: mockMessages }), + prompt: async (opts: any) => { + promptCalls.push({ + sessionID: opts.path.id, + agent: opts.body.agent, + model: opts.body.model, + text: opts.body.parts[0].text, + }) + return {} + }, + }, + tui: { + showToast: async (opts: any) => { + toastCalls.push({ + title: opts.body.title, + message: opts.body.message, + }) + return {} + }, + }, + }, + directory: "/tmp/test", + } as any + } + + function createTempTaskDir(): string { + return mkdtempSync(join(tmpdir(), "omo-task-continuation-")) + } + + function writeTaskFile(dir: string, task: TaskObject): void { + const parsed = TaskObjectSchema.safeParse(task) + expect(parsed.success).toBe(true) + if (!parsed.success) return + writeFileSync(join(dir, `${parsed.data.id}.json`), JSON.stringify(parsed.data), "utf-8") + } + + function writeCorruptedTaskFile(dir: string, taskId: string): void { + writeFileSync(join(dir, `${taskId}.json`), "{ this is not valid json", "utf-8") + } + + function createConfig(dir: string): Partial { + return { + sisyphus: { + tasks: { + claude_code_compat: true, + storage_path: dir, + }, + }, + } + } + + function createMockBackgroundManager(runningTasks: boolean = false): BackgroundManager { + return { + getTasksByParentSession: () => (runningTasks ? [{ status: "running" }] : []), + } as any + } + + beforeEach(() => { + fakeTimers = createFakeTimers() + _resetForTesting() + promptCalls = [] + toastCalls = [] + mockMessages = [] + taskDir = createTempTaskDir() + }) + + afterEach(() => { + fakeTimers.restore() + _resetForTesting() + rmSync(taskDir, { recursive: true, force: true }) + }) + + test("should inject continuation when idle with incomplete tasks on disk", async () => { + fakeTimers.restore() + // given - main session with incomplete tasks + const sessionID = "main-123" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + writeTaskFile(taskDir, { + id: "T-2", + subject: "Task 2", + description: "", + status: "completed", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), { + backgroundManager: new BackgroundManager(createMockPluginInput()), + }) + + // when - session goes idle + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + + // then - countdown toast shown + await wait(50) + expect(toastCalls.length).toBeGreaterThanOrEqual(1) + expect(toastCalls[0].title).toBe("Task Continuation") + + // then - after countdown, continuation injected + await wait(2500) + expect(promptCalls.length).toBe(1) + expect(promptCalls[0].text).toContain("TASK CONTINUATION") + }, { timeout: 15000 }) + + test("should NOT inject when all tasks are completed", async () => { + // given - session with all tasks completed + const sessionID = "main-456" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "completed", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when - session goes idle + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then - no continuation injected + expect(promptCalls).toHaveLength(0) + }) + + test("should NOT inject when all tasks are deleted", async () => { + // given - session with all tasks deleted + const sessionID = "main-deleted" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "deleted", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should NOT inject when no task files exist", async () => { + // given - empty task directory + const sessionID = "main-none" + setMainSession(sessionID) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should NOT inject when background tasks are running", async () => { + // given - session with incomplete tasks and running background tasks + const sessionID = "main-bg-running" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), { + backgroundManager: createMockBackgroundManager(true), + }) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should NOT inject for non-main session", async () => { + // given - main session set, different session goes idle + setMainSession("main-session") + const otherSession = "other-session" + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID: otherSession } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should inject for background task session (subagent)", async () => { + fakeTimers.restore() + // given - main session set, background task session registered + setMainSession("main-session") + const bgTaskSession = "bg-task-session" + subagentSessions.add(bgTaskSession) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID: bgTaskSession } } }) + + // then + await wait(2500) + expect(promptCalls.length).toBe(1) + expect(promptCalls[0].sessionID).toBe(bgTaskSession) + }, { timeout: 15000 }) + + test("should cancel countdown on user message after grace period", async () => { + // given + const sessionID = "main-cancel" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when - session goes idle + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + + // when - wait past grace period (500ms), then user sends message + await fakeTimers.advanceBy(600, true) + await hook.handler({ + event: { + type: "message.updated", + properties: { info: { sessionID, role: "user" } }, + }, + }) + + // then + await fakeTimers.advanceBy(2500) + expect(promptCalls).toHaveLength(0) + }) + + test("should ignore user message within grace period", async () => { + fakeTimers.restore() + // given + const sessionID = "main-grace" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await hook.handler({ + event: { + type: "message.updated", + properties: { info: { sessionID, role: "user" } }, + }, + }) + + // then - countdown should continue + await wait(2500) + expect(promptCalls).toHaveLength(1) + }, { timeout: 15000 }) + + test("should cancel countdown on assistant activity", async () => { + // given + const sessionID = "main-assistant" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(500) + await hook.handler({ + event: { + type: "message.part.updated", + properties: { info: { sessionID, role: "assistant" } }, + }, + }) + + // then + await fakeTimers.advanceBy(3000) + expect(promptCalls).toHaveLength(0) + }) + + test("should cancel countdown on tool execution", async () => { + // given + const sessionID = "main-tool" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(500) + await hook.handler({ event: { type: "tool.execute.before", properties: { sessionID } } }) + + // then + await fakeTimers.advanceBy(3000) + expect(promptCalls).toHaveLength(0) + }) + + test("should skip injection during recovery mode", async () => { + // given + const sessionID = "main-recovery" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + hook.markRecovering(sessionID) + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should inject after recovery complete", async () => { + fakeTimers.restore() + // given + const sessionID = "main-recovery-done" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + hook.markRecovering(sessionID) + hook.markRecoveryComplete(sessionID) + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + + // then + await wait(3000) + expect(promptCalls.length).toBe(1) + }, { timeout: 15000 }) + + test("should cleanup on session deleted", async () => { + // given + const sessionID = "main-delete" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(500) + await hook.handler({ event: { type: "session.deleted", properties: { info: { id: sessionID } } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should skip when last assistant message was aborted (API fallback)", async () => { + // given + const sessionID = "main-api-abort" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant", error: { name: "MessageAbortedError", data: { message: "aborted" } } } }, + ] + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should skip when abort detected via session.error event", async () => { + // given + const sessionID = "main-event-abort" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when - abort error event fires + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "MessageAbortedError" } }, + }, + }) + + // when - session goes idle immediately after + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should handle corrupted task files gracefully (readJsonSafe returns null)", async () => { + fakeTimers.restore() + // given + const sessionID = "main-corrupt" + setMainSession(sessionID) + + writeCorruptedTaskFile(taskDir, "T-corrupt") + writeTaskFile(taskDir, { + id: "T-ok", + subject: "Task OK", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await wait(2500) + + // then + expect(promptCalls).toHaveLength(1) + }, { timeout: 15000 }) + + test("should NOT inject when isContinuationStopped returns true", async () => { + // given + const sessionID = "main-stopped" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), { + isContinuationStopped: (id) => id === sessionID, + }) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) + + test("should cancel all countdowns via cancelAllCountdowns", async () => { + // given + const sessionID = "main-cancel-all" + setMainSession(sessionID) + + writeTaskFile(taskDir, { + id: "T-1", + subject: "Task 1", + description: "", + status: "pending", + blocks: [], + blockedBy: [], + threadID: "test", + }) + + const hook = createTaskContinuationEnforcer(createMockPluginInput(), createConfig(taskDir), {}) + + // when + await hook.handler({ event: { type: "session.idle", properties: { sessionID } } }) + await fakeTimers.advanceBy(500) + hook.cancelAllCountdowns() + await fakeTimers.advanceBy(3000) + + // then + expect(promptCalls).toHaveLength(0) + }) +}) diff --git a/src/hooks/task-continuation-enforcer.ts b/src/hooks/task-continuation-enforcer.ts new file mode 100644 index 000000000..f3b7f9c54 --- /dev/null +++ b/src/hooks/task-continuation-enforcer.ts @@ -0,0 +1,530 @@ +import type { PluginInput } from "@opencode-ai/plugin" +import { existsSync, readdirSync } from "node:fs" +import { join } from "node:path" + +import type { BackgroundManager } from "../features/background-agent" +import { getMainSessionID, subagentSessions } from "../features/claude-code-session-state" +import { + findNearestMessageWithFields, + MESSAGE_STORAGE, + type ToolPermission, +} from "../features/hook-message-injector" +import { listTaskFiles, readJsonSafe, getTaskDir } from "../features/claude-tasks/storage" +import type { OhMyOpenCodeConfig } from "../config/schema" +import { TaskObjectSchema } from "../tools/task/types" +import type { TaskObject } from "../tools/task/types" +import { log } from "../shared/logger" +import { createSystemDirective, SystemDirectiveTypes } from "../shared/system-directive" + +const HOOK_NAME = "task-continuation-enforcer" + +const DEFAULT_SKIP_AGENTS = ["prometheus", "compaction"] + +export interface TaskContinuationEnforcerOptions { + backgroundManager?: BackgroundManager + skipAgents?: string[] + isContinuationStopped?: (sessionID: string) => boolean +} + +export interface TaskContinuationEnforcer { + handler: (input: { event: { type: string; properties?: unknown } }) => Promise + markRecovering: (sessionID: string) => void + markRecoveryComplete: (sessionID: string) => void + cancelAllCountdowns: () => void +} + +interface SessionState { + countdownTimer?: ReturnType + countdownInterval?: ReturnType + isRecovering?: boolean + countdownStartedAt?: number + abortDetectedAt?: number +} + +const CONTINUATION_PROMPT = `${createSystemDirective(SystemDirectiveTypes.TASK_CONTINUATION)} + +Incomplete tasks remain in your task list. Continue working on the next pending task. + +- Proceed without asking for permission +- Mark each task complete when finished +- Do not stop until all tasks are done` + +const COUNTDOWN_SECONDS = 2 +const TOAST_DURATION_MS = 900 +const COUNTDOWN_GRACE_PERIOD_MS = 500 + +function getMessageDir(sessionID: string): string | null { + if (!existsSync(MESSAGE_STORAGE)) return null + + const directPath = join(MESSAGE_STORAGE, sessionID) + if (existsSync(directPath)) return directPath + + for (const dir of readdirSync(MESSAGE_STORAGE)) { + const sessionPath = join(MESSAGE_STORAGE, dir, sessionID) + if (existsSync(sessionPath)) return sessionPath + } + + return null +} + +function getIncompleteCount(tasks: TaskObject[]): number { + return tasks.filter(t => t.status !== "completed" && t.status !== "deleted").length +} + +interface MessageInfo { + id?: string + role?: string + error?: { name?: string; data?: unknown } +} + +function isLastAssistantMessageAborted(messages: Array<{ info?: MessageInfo }>): boolean { + if (!messages || messages.length === 0) return false + + const assistantMessages = messages.filter(m => m.info?.role === "assistant") + if (assistantMessages.length === 0) return false + + const lastAssistant = assistantMessages[assistantMessages.length - 1] + const errorName = lastAssistant.info?.error?.name + + if (!errorName) return false + + return errorName === "MessageAbortedError" || errorName === "AbortError" +} + +function loadTasksFromDisk(config: Partial): TaskObject[] { + const taskIds = listTaskFiles(config) + const taskDirectory = getTaskDir(config) + const tasks: TaskObject[] = [] + + for (const id of taskIds) { + const task = readJsonSafe(join(taskDirectory, `${id}.json`), TaskObjectSchema) + if (task) tasks.push(task) + } + + return tasks +} + +export function createTaskContinuationEnforcer( + ctx: PluginInput, + config: Partial, + options: TaskContinuationEnforcerOptions = {} +): TaskContinuationEnforcer { + const { backgroundManager, skipAgents = DEFAULT_SKIP_AGENTS, isContinuationStopped } = options + const sessions = new Map() + + function getState(sessionID: string): SessionState { + let state = sessions.get(sessionID) + if (!state) { + state = {} + sessions.set(sessionID, state) + } + return state + } + + function cancelCountdown(sessionID: string): void { + const state = sessions.get(sessionID) + if (!state) return + + if (state.countdownTimer) { + clearTimeout(state.countdownTimer) + state.countdownTimer = undefined + } + if (state.countdownInterval) { + clearInterval(state.countdownInterval) + state.countdownInterval = undefined + } + state.countdownStartedAt = undefined + } + + function cleanup(sessionID: string): void { + cancelCountdown(sessionID) + sessions.delete(sessionID) + } + + const markRecovering = (sessionID: string): void => { + const state = getState(sessionID) + state.isRecovering = true + cancelCountdown(sessionID) + log(`[${HOOK_NAME}] Session marked as recovering`, { sessionID }) + } + + const markRecoveryComplete = (sessionID: string): void => { + const state = sessions.get(sessionID) + if (state) { + state.isRecovering = false + log(`[${HOOK_NAME}] Session recovery complete`, { sessionID }) + } + } + + async function showCountdownToast(seconds: number, incompleteCount: number): Promise { + await ctx.client.tui + .showToast({ + body: { + title: "Task Continuation", + message: `Resuming in ${seconds}s... (${incompleteCount} tasks remaining)`, + variant: "warning" as const, + duration: TOAST_DURATION_MS, + }, + }) + .catch(() => {}) + } + + interface ResolvedMessageInfo { + agent?: string + model?: { providerID: string; modelID: string } + tools?: Record + } + + async function injectContinuation( + sessionID: string, + incompleteCount: number, + total: number, + resolvedInfo?: ResolvedMessageInfo + ): Promise { + const state = sessions.get(sessionID) + + if (state?.isRecovering) { + log(`[${HOOK_NAME}] Skipped injection: in recovery`, { sessionID }) + return + } + + const hasRunningBgTasks = backgroundManager + ? backgroundManager.getTasksByParentSession(sessionID).some(t => t.status === "running") + : false + + if (hasRunningBgTasks) { + log(`[${HOOK_NAME}] Skipped injection: background tasks running`, { sessionID }) + return + } + + const tasks = loadTasksFromDisk(config) + const freshIncompleteCount = getIncompleteCount(tasks) + if (freshIncompleteCount === 0) { + log(`[${HOOK_NAME}] Skipped injection: no incomplete tasks`, { sessionID }) + return + } + + let agentName = resolvedInfo?.agent + let model = resolvedInfo?.model + let tools = resolvedInfo?.tools + + if (!agentName || !model) { + const messageDir = getMessageDir(sessionID) + const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null + agentName = agentName ?? prevMessage?.agent + model = + model ?? + (prevMessage?.model?.providerID && prevMessage?.model?.modelID + ? { + providerID: prevMessage.model.providerID, + modelID: prevMessage.model.modelID, + ...(prevMessage.model.variant ? { variant: prevMessage.model.variant } : {}), + } + : undefined) + tools = tools ?? prevMessage?.tools + } + + if (agentName && skipAgents.includes(agentName)) { + log(`[${HOOK_NAME}] Skipped: agent in skipAgents list`, { sessionID, agent: agentName }) + return + } + + const editPermission = tools?.edit + const writePermission = tools?.write + const hasWritePermission = + !tools || + (editPermission !== false && editPermission !== "deny" && writePermission !== false && writePermission !== "deny") + if (!hasWritePermission) { + log(`[${HOOK_NAME}] Skipped: agent lacks write permission`, { sessionID, agent: agentName }) + return + } + + const incompleteTasks = tasks.filter(t => t.status !== "completed" && t.status !== "deleted") + const taskList = incompleteTasks.map(t => `- [${t.status}] ${t.subject}`).join("\n") + const prompt = `${CONTINUATION_PROMPT} + +[Status: ${tasks.length - freshIncompleteCount}/${tasks.length} completed, ${freshIncompleteCount} remaining] + +Remaining tasks: +${taskList}` + + try { + log(`[${HOOK_NAME}] Injecting continuation`, { + sessionID, + agent: agentName, + model, + incompleteCount: freshIncompleteCount, + }) + + await ctx.client.session.prompt({ + path: { id: sessionID }, + body: { + agent: agentName, + ...(model !== undefined ? { model } : {}), + parts: [{ type: "text", text: prompt }], + }, + query: { directory: ctx.directory }, + }) + + log(`[${HOOK_NAME}] Injection successful`, { sessionID }) + } catch (err) { + log(`[${HOOK_NAME}] Injection failed`, { sessionID, error: String(err) }) + } + } + + function startCountdown( + sessionID: string, + incompleteCount: number, + total: number, + resolvedInfo?: ResolvedMessageInfo + ): void { + const state = getState(sessionID) + cancelCountdown(sessionID) + + let secondsRemaining = COUNTDOWN_SECONDS + showCountdownToast(secondsRemaining, incompleteCount) + state.countdownStartedAt = Date.now() + + state.countdownInterval = setInterval(() => { + secondsRemaining-- + if (secondsRemaining > 0) { + showCountdownToast(secondsRemaining, incompleteCount) + } + }, 1000) + + state.countdownTimer = setTimeout(() => { + cancelCountdown(sessionID) + injectContinuation(sessionID, incompleteCount, total, resolvedInfo) + }, COUNTDOWN_SECONDS * 1000) + + log(`[${HOOK_NAME}] Countdown started`, { sessionID, seconds: COUNTDOWN_SECONDS, incompleteCount }) + } + + const handler = async ({ event }: { event: { type: string; properties?: unknown } }): Promise => { + const props = event.properties as Record | undefined + + if (event.type === "session.error") { + const sessionID = props?.sessionID as string | undefined + if (!sessionID) return + + const error = props?.error as { name?: string } | undefined + if (error?.name === "MessageAbortedError" || error?.name === "AbortError") { + const state = getState(sessionID) + state.abortDetectedAt = Date.now() + log(`[${HOOK_NAME}] Abort detected via session.error`, { sessionID, errorName: error.name }) + } + + cancelCountdown(sessionID) + log(`[${HOOK_NAME}] session.error`, { sessionID }) + return + } + + if (event.type === "session.idle") { + const sessionID = props?.sessionID as string | undefined + if (!sessionID) return + + log(`[${HOOK_NAME}] session.idle`, { sessionID }) + + const mainSessionID = getMainSessionID() + const isMainSession = sessionID === mainSessionID + const isBackgroundTaskSession = subagentSessions.has(sessionID) + + if (mainSessionID && !isMainSession && !isBackgroundTaskSession) { + log(`[${HOOK_NAME}] Skipped: not main or background task session`, { sessionID }) + return + } + + const state = getState(sessionID) + + if (state.isRecovering) { + log(`[${HOOK_NAME}] Skipped: in recovery`, { sessionID }) + return + } + + // Check 1: Event-based abort detection (primary, most reliable) + if (state.abortDetectedAt) { + const timeSinceAbort = Date.now() - state.abortDetectedAt + const ABORT_WINDOW_MS = 3000 + if (timeSinceAbort < ABORT_WINDOW_MS) { + log(`[${HOOK_NAME}] Skipped: abort detected via event ${timeSinceAbort}ms ago`, { sessionID }) + state.abortDetectedAt = undefined + return + } + state.abortDetectedAt = undefined + } + + const hasRunningBgTasks = backgroundManager + ? backgroundManager.getTasksByParentSession(sessionID).some(t => t.status === "running") + : false + + if (hasRunningBgTasks) { + log(`[${HOOK_NAME}] Skipped: background tasks running`, { sessionID }) + return + } + + // Check 2: API-based abort detection (fallback, for cases where event was missed) + try { + const messagesResp = await ctx.client.session.messages({ + path: { id: sessionID }, + query: { directory: ctx.directory }, + }) + const messages = (messagesResp as { data?: Array<{ info?: MessageInfo }> }).data ?? [] + + if (isLastAssistantMessageAborted(messages)) { + log(`[${HOOK_NAME}] Skipped: last assistant message was aborted (API fallback)`, { sessionID }) + return + } + } catch (err) { + log(`[${HOOK_NAME}] Messages fetch failed, continuing`, { sessionID, error: String(err) }) + } + + const tasks = loadTasksFromDisk(config) + + if (!tasks || tasks.length === 0) { + log(`[${HOOK_NAME}] No tasks`, { sessionID }) + return + } + + const incompleteCount = getIncompleteCount(tasks) + if (incompleteCount === 0) { + log(`[${HOOK_NAME}] All tasks complete`, { sessionID, total: tasks.length }) + return + } + + let resolvedInfo: ResolvedMessageInfo | undefined + let hasCompactionMessage = false + try { + const messagesResp = await ctx.client.session.messages({ + path: { id: sessionID }, + }) + const messages = (messagesResp.data ?? []) as Array<{ + info?: { + agent?: string + model?: { providerID: string; modelID: string } + modelID?: string + providerID?: string + tools?: Record + } + }> + for (let i = messages.length - 1; i >= 0; i--) { + const info = messages[i].info + if (info?.agent === "compaction") { + hasCompactionMessage = true + continue + } + if (info?.agent || info?.model || (info?.modelID && info?.providerID)) { + resolvedInfo = { + agent: info.agent, + model: + info.model ?? + (info.providerID && info.modelID + ? { providerID: info.providerID, modelID: info.modelID } + : undefined), + tools: info.tools, + } + break + } + } + } catch (err) { + log(`[${HOOK_NAME}] Failed to fetch messages for agent check`, { sessionID, error: String(err) }) + } + + log(`[${HOOK_NAME}] Agent check`, { + sessionID, + agentName: resolvedInfo?.agent, + skipAgents, + hasCompactionMessage, + }) + if (resolvedInfo?.agent && skipAgents.includes(resolvedInfo.agent)) { + log(`[${HOOK_NAME}] Skipped: agent in skipAgents list`, { sessionID, agent: resolvedInfo.agent }) + return + } + if (hasCompactionMessage && !resolvedInfo?.agent) { + log(`[${HOOK_NAME}] Skipped: compaction occurred but no agent info resolved`, { sessionID }) + return + } + + if (isContinuationStopped?.(sessionID)) { + log(`[${HOOK_NAME}] Skipped: continuation stopped for session`, { sessionID }) + return + } + + startCountdown(sessionID, incompleteCount, tasks.length, resolvedInfo) + return + } + + if (event.type === "message.updated") { + const info = props?.info as Record | undefined + const sessionID = info?.sessionID as string | undefined + const role = info?.role as string | undefined + + if (!sessionID) return + + if (role === "user") { + const state = sessions.get(sessionID) + if (state?.countdownStartedAt) { + const elapsed = Date.now() - state.countdownStartedAt + if (elapsed < COUNTDOWN_GRACE_PERIOD_MS) { + log(`[${HOOK_NAME}] Ignoring user message in grace period`, { sessionID, elapsed }) + return + } + } + if (state) state.abortDetectedAt = undefined + cancelCountdown(sessionID) + } + + if (role === "assistant") { + const state = sessions.get(sessionID) + if (state) state.abortDetectedAt = undefined + cancelCountdown(sessionID) + } + return + } + + if (event.type === "message.part.updated") { + const info = props?.info as Record | undefined + const sessionID = info?.sessionID as string | undefined + const role = info?.role as string | undefined + + if (sessionID && role === "assistant") { + const state = sessions.get(sessionID) + if (state) state.abortDetectedAt = undefined + cancelCountdown(sessionID) + } + return + } + + if (event.type === "tool.execute.before" || event.type === "tool.execute.after") { + const sessionID = props?.sessionID as string | undefined + if (sessionID) { + const state = sessions.get(sessionID) + if (state) state.abortDetectedAt = undefined + cancelCountdown(sessionID) + } + return + } + + if (event.type === "session.deleted") { + const sessionInfo = props?.info as { id?: string } | undefined + if (sessionInfo?.id) { + cleanup(sessionInfo.id) + log(`[${HOOK_NAME}] Session deleted: cleaned up`, { sessionID: sessionInfo.id }) + } + return + } + } + + const cancelAllCountdowns = (): void => { + for (const sessionID of sessions.keys()) { + cancelCountdown(sessionID) + } + log(`[${HOOK_NAME}] All countdowns cancelled`) + } + + return { + handler, + markRecovering, + markRecoveryComplete, + cancelAllCountdowns, + } +}