diff --git a/src/hooks/stop-continuation-guard/hook.ts b/src/hooks/stop-continuation-guard/hook.ts index f7c49a563..747b7a9b6 100644 --- a/src/hooks/stop-continuation-guard/hook.ts +++ b/src/hooks/stop-continuation-guard/hook.ts @@ -1,4 +1,5 @@ import type { PluginInput } from "@opencode-ai/plugin" +import type { BackgroundManager } from "../../features/background-agent" import { clearContinuationMarker, @@ -8,6 +9,11 @@ import { log } from "../../shared/logger" const HOOK_NAME = "stop-continuation-guard" +type StopContinuationBackgroundManager = Pick< + BackgroundManager, + "getAllDescendantTasks" | "cancelTask" +> + export interface StopContinuationGuard { event: (input: { event: { type: string; properties?: unknown } }) => Promise "chat.message": (input: { sessionID?: string }) => Promise @@ -17,7 +23,10 @@ export interface StopContinuationGuard { } export function createStopContinuationGuardHook( - ctx: PluginInput + ctx: PluginInput, + options?: { + backgroundManager?: StopContinuationBackgroundManager + } ): StopContinuationGuard { const stoppedSessions = new Set() @@ -25,6 +34,38 @@ export function createStopContinuationGuardHook( stoppedSessions.add(sessionID) setContinuationMarkerSource(ctx.directory, sessionID, "stop", "stopped", "continuation stopped") log(`[${HOOK_NAME}] Continuation stopped for session`, { sessionID }) + + const backgroundManager = options?.backgroundManager + if (!backgroundManager) { + return + } + + const cancellableTasks = backgroundManager + .getAllDescendantTasks(sessionID) + .filter((task) => task.status === "running" || task.status === "pending") + + if (cancellableTasks.length === 0) { + return + } + + void Promise.allSettled( + cancellableTasks.map(async (task) => { + await backgroundManager.cancelTask(task.id, { + source: "stop-continuation", + reason: "Continuation stopped via /stop-continuation", + abortSession: task.status === "running", + skipNotification: true, + }) + }) + ).then((results) => { + const cancelledCount = results.filter((result) => result.status === "fulfilled").length + const failedCount = results.length - cancelledCount + log(`[${HOOK_NAME}] Cancelled background tasks for stopped session`, { + sessionID, + cancelledCount, + failedCount, + }) + }) } const isStopped = (sessionID: string): boolean => { diff --git a/src/hooks/stop-continuation-guard/index.test.ts b/src/hooks/stop-continuation-guard/index.test.ts index 9547accf2..a0d08f217 100644 --- a/src/hooks/stop-continuation-guard/index.test.ts +++ b/src/hooks/stop-continuation-guard/index.test.ts @@ -2,9 +2,15 @@ import { afterEach, describe, expect, test } from "bun:test" import { mkdtempSync, rmSync } from "node:fs" import { join } from "node:path" import { tmpdir } from "node:os" +import type { BackgroundManager, BackgroundTask } from "../../features/background-agent" import { readContinuationMarker } from "../../features/run-continuation-state" import { createStopContinuationGuardHook } from "./index" +type CancelCall = { + taskId: string + options?: Parameters[1] +} + describe("stop-continuation-guard", () => { const tempDirs: string[] = [] @@ -34,6 +40,33 @@ describe("stop-continuation-guard", () => { } as any } + function createBackgroundTask(status: BackgroundTask["status"], id: string): BackgroundTask { + return { + id, + status, + description: `${id} description`, + parentSessionID: "parent-session", + parentMessageID: "parent-message", + prompt: "prompt", + agent: "sisyphus-junior", + } + } + + function createMockBackgroundManager(tasks: BackgroundTask[], cancelCalls: CancelCall[]): Pick { + return { + getAllDescendantTasks: () => tasks, + cancelTask: async (taskId: string, options?: Parameters[1]) => { + cancelCalls.push({ taskId, options }) + return true + }, + } + } + + async function flushMicrotasks(): Promise { + await Promise.resolve() + await Promise.resolve() + } + test("should mark session as stopped", () => { // given - a guard hook with no stopped sessions const input = createMockPluginInput() @@ -166,4 +199,31 @@ describe("stop-continuation-guard", () => { // then - should not throw and stopped session remains stopped expect(guard.isStopped("some-session")).toBe(true) }) + + test("should cancel only running and pending background tasks on stop", async () => { + // given - a background manager with mixed task statuses + const cancelCalls: CancelCall[] = [] + const backgroundManager = createMockBackgroundManager( + [ + createBackgroundTask("running", "task-running"), + createBackgroundTask("pending", "task-pending"), + createBackgroundTask("completed", "task-completed"), + ], + cancelCalls, + ) + const guard = createStopContinuationGuardHook(createMockPluginInput(), { + backgroundManager, + }) + + // when - stop continuation is triggered + guard.stop("test-session-bg") + await flushMicrotasks() + + // then - only running and pending tasks are cancelled + expect(cancelCalls).toHaveLength(2) + expect(cancelCalls[0]?.taskId).toBe("task-running") + expect(cancelCalls[0]?.options?.abortSession).toBe(true) + expect(cancelCalls[1]?.taskId).toBe("task-pending") + expect(cancelCalls[1]?.options?.abortSession).toBe(false) + }) }) diff --git a/src/plugin/hooks/create-continuation-hooks.ts b/src/plugin/hooks/create-continuation-hooks.ts index 96bf5de0c..da453f58d 100644 --- a/src/plugin/hooks/create-continuation-hooks.ts +++ b/src/plugin/hooks/create-continuation-hooks.ts @@ -49,7 +49,10 @@ export function createContinuationHooks(args: { safeCreateHook(hookName, factory, { enabled: safeHookEnabled }) const stopContinuationGuard = isHookEnabled("stop-continuation-guard") - ? safeHook("stop-continuation-guard", () => createStopContinuationGuardHook(ctx)) + ? safeHook("stop-continuation-guard", () => + createStopContinuationGuardHook(ctx, { + backgroundManager, + })) : null const compactionContextInjector = isHookEnabled("compaction-context-injector")