Merge pull request #2458 from code-yeongyu/fix/memory-leaks

fix: resolve 12 memory leaks (3 critical + 9 high)
This commit is contained in:
YeonGyu-Kim
2026-03-12 11:21:13 +09:00
committed by GitHub
51 changed files with 2883 additions and 262 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -11,6 +11,20 @@ import { createSkillHooks } from "./plugin/hooks/create-skill-hooks"
export type CreatedHooks = ReturnType<typeof createHooks>
type DisposableHook = { dispose?: () => void } | null | undefined
export type DisposableCreatedHooks = {
runtimeFallback?: DisposableHook
todoContinuationEnforcer?: DisposableHook
autoSlashCommand?: DisposableHook
}
export function disposeCreatedHooks(hooks: DisposableCreatedHooks): void {
hooks.runtimeFallback?.dispose?.()
hooks.todoContinuationEnforcer?.dispose?.()
hooks.autoSlashCommand?.dispose?.()
}
export function createHooks(args: {
ctx: PluginContext
pluginConfig: OhMyOpenCodeConfig
@@ -58,9 +72,16 @@ export function createHooks(args: {
availableSkills,
})
return {
const hooks = {
...core,
...continuation,
...skill,
}
return {
...hooks,
disposeHooks: (): void => {
disposeCreatedHooks(hooks)
},
}
}

View File

@@ -53,8 +53,8 @@ export function createManagers(args: {
log("[index] onSubagentSessionCreated callback completed")
},
onShutdown: () => {
tmuxSessionManager.cleanup().catch((error) => {
onShutdown: async () => {
await tmuxSessionManager.cleanup().catch((error) => {
log("[index] tmux cleanup error during shutdown:", error)
})
},

View File

@@ -0,0 +1,193 @@
import { tmpdir } from "node:os"
import type { PluginInput } from "@opencode-ai/plugin"
import { afterEach, describe, expect, test } from "bun:test"
import { ConcurrencyManager } from "./concurrency"
import { BackgroundManager } from "./manager"
import type { BackgroundTask, LaunchInput } from "./types"
const managersToShutdown: BackgroundManager[] = []
afterEach(() => {
while (managersToShutdown.length > 0) managersToShutdown.pop()?.shutdown()
})
function createBackgroundManager(config?: { defaultConcurrency?: number }): BackgroundManager {
const directory = tmpdir()
const client = { session: {} as PluginInput["client"]["session"] } as PluginInput["client"]
Reflect.set(client.session, "abort", async () => ({ data: true }))
Reflect.set(client.session, "create", async () => ({ data: { id: `session-${crypto.randomUUID().slice(0, 8)}` } }))
Reflect.set(client.session, "get", async () => ({ data: { directory } }))
Reflect.set(client.session, "messages", async () => ({ data: [] }))
Reflect.set(client.session, "prompt", async () => ({ data: { info: {}, parts: [] } }))
Reflect.set(client.session, "promptAsync", async () => ({ data: undefined }))
const manager = new BackgroundManager({
$: {} as PluginInput["$"],
client,
directory,
project: {} as PluginInput["project"],
serverUrl: new URL("http://localhost"),
worktree: directory,
}, config)
managersToShutdown.push(manager)
return manager
}
function createMockTask(overrides: Partial<BackgroundTask> & { id: string; parentSessionID: string }): BackgroundTask {
return {
id: overrides.id,
sessionID: overrides.sessionID,
parentSessionID: overrides.parentSessionID,
parentMessageID: overrides.parentMessageID ?? "parent-message-id",
description: overrides.description ?? "test task",
prompt: overrides.prompt ?? "test prompt",
agent: overrides.agent ?? "test-agent",
status: overrides.status ?? "running",
queuedAt: overrides.queuedAt,
startedAt: overrides.startedAt ?? new Date(),
completedAt: overrides.completedAt,
error: overrides.error,
model: overrides.model,
concurrencyKey: overrides.concurrencyKey,
concurrencyGroup: overrides.concurrencyGroup,
progress: overrides.progress,
}
}
function getTaskMap(manager: BackgroundManager): Map<string, BackgroundTask> { return Reflect.get(manager, "tasks") as Map<string, BackgroundTask> }
function getPendingByParent(manager: BackgroundManager): Map<string, Set<string>> { return Reflect.get(manager, "pendingByParent") as Map<string, Set<string>> }
function getQueuesByKey(manager: BackgroundManager): Map<string, Array<{ task: BackgroundTask; input: LaunchInput }>> { return Reflect.get(manager, "queuesByKey") as Map<string, Array<{ task: BackgroundTask; input: LaunchInput }>> }
function getConcurrencyManager(manager: BackgroundManager): ConcurrencyManager { return Reflect.get(manager, "concurrencyManager") as ConcurrencyManager }
function getCompletionTimers(manager: BackgroundManager): Map<string, ReturnType<typeof setTimeout>> { return Reflect.get(manager, "completionTimers") as Map<string, ReturnType<typeof setTimeout>> }
async function processKeyForTest(manager: BackgroundManager, key: string): Promise<void> {
const processKey = Reflect.get(manager, "processKey") as (key: string) => Promise<void>
await processKey.call(manager, key)
}
function runScheduledCleanup(manager: BackgroundManager, taskId: string): void {
const timer = getCompletionTimers(manager).get(taskId)
if (!timer) {
throw new Error(`Expected cleanup timer for task ${taskId}`)
}
const onTimeout = Reflect.get(timer, "_onTimeout") as (() => void) | undefined
if (!onTimeout) {
throw new Error(`Expected cleanup callback for task ${taskId}`)
}
onTimeout()
}
describe("BackgroundManager.cancelTask cleanup", () => {
test("#given a running task in BackgroundManager #when cancelTask called with skipNotification=true #then task is eventually removed from this.tasks Map", async () => {
// given
const manager = createBackgroundManager()
const task = createMockTask({
id: "task-skip-notification-cleanup",
parentSessionID: "parent-session-skip-notification-cleanup",
sessionID: "session-skip-notification-cleanup",
})
getTaskMap(manager).set(task.id, task)
getPendingByParent(manager).set(task.parentSessionID, new Set([task.id]))
// when
const cancelled = await manager.cancelTask(task.id, {
skipNotification: true,
source: "test",
})
// then
expect(cancelled).toBe(true)
expect(getPendingByParent(manager).get(task.parentSessionID)).toBeUndefined()
runScheduledCleanup(manager, task.id)
expect(manager.getTask(task.id)).toBeUndefined()
})
test("#given a running task #when cancelTask called with skipNotification=false #then task is also eventually removed", async () => {
// given
const manager = createBackgroundManager()
const task = createMockTask({
id: "task-notify-cleanup",
parentSessionID: "parent-session-notify-cleanup",
sessionID: "session-notify-cleanup",
})
getTaskMap(manager).set(task.id, task)
getPendingByParent(manager).set(task.parentSessionID, new Set([task.id]))
// when
const cancelled = await manager.cancelTask(task.id, {
skipNotification: false,
source: "test",
})
// then
expect(cancelled).toBe(true)
runScheduledCleanup(manager, task.id)
expect(manager.getTask(task.id)).toBeUndefined()
})
test("#given a running task #when cancelTask called with skipNotification=true #then concurrency slot is freed and pending tasks can start", async () => {
// given
const manager = createBackgroundManager({ defaultConcurrency: 1 })
const concurrencyManager = getConcurrencyManager(manager)
const concurrencyKey = "test-provider/test-model"
await concurrencyManager.acquire(concurrencyKey)
const runningTask = createMockTask({
id: "task-running-before-cancel",
parentSessionID: "parent-session-concurrency-cleanup",
sessionID: "session-running-before-cancel",
concurrencyKey,
})
const pendingTask = createMockTask({
id: "task-pending-after-cancel",
parentSessionID: runningTask.parentSessionID,
status: "pending",
startedAt: undefined,
queuedAt: new Date(),
model: { providerID: "test-provider", modelID: "test-model" },
})
const queuedInput: LaunchInput = {
agent: pendingTask.agent,
description: pendingTask.description,
model: pendingTask.model,
parentMessageID: pendingTask.parentMessageID,
parentSessionID: pendingTask.parentSessionID,
prompt: pendingTask.prompt,
}
getTaskMap(manager).set(runningTask.id, runningTask)
getTaskMap(manager).set(pendingTask.id, pendingTask)
getPendingByParent(manager).set(runningTask.parentSessionID, new Set([runningTask.id, pendingTask.id]))
getQueuesByKey(manager).set(concurrencyKey, [{ input: queuedInput, task: pendingTask }])
Reflect.set(manager, "startTask", async ({ task }: { task: BackgroundTask; input: LaunchInput }) => {
task.status = "running"
task.startedAt = new Date()
task.sessionID = "session-started-after-cancel"
task.concurrencyKey = concurrencyKey
task.concurrencyGroup = concurrencyKey
})
// when
const cancelled = await manager.cancelTask(runningTask.id, {
abortSession: false,
skipNotification: true,
source: "test",
})
await processKeyForTest(manager, concurrencyKey)
// then
expect(cancelled).toBe(true)
expect(concurrencyManager.getCount(concurrencyKey)).toBe(1)
expect(manager.getTask(pendingTask.id)?.status).toBe("running")
})
})

View File

@@ -3785,7 +3785,7 @@ describe("BackgroundManager.completionTimers - Memory Leak Fix", () => {
manager.shutdown()
})
test("should start cleanup timers only after all tasks complete", async () => {
test("should start per-task cleanup timers independently of sibling completion", async () => {
// given
const client = {
session: {
@@ -3832,7 +3832,7 @@ describe("BackgroundManager.completionTimers - Memory Leak Fix", () => {
// then
const completionTimers = getCompletionTimers(manager)
expect(completionTimers.size).toBe(0)
expect(completionTimers.size).toBe(1)
// when
await (manager as unknown as { notifyParentSession: (task: BackgroundTask) => Promise<void> })

View File

@@ -116,11 +116,12 @@ export class BackgroundManager {
private config?: BackgroundTaskConfig
private tmuxEnabled: boolean
private onSubagentSessionCreated?: OnSubagentSessionCreated
private onShutdown?: () => void
private onShutdown?: () => void | Promise<void>
private queuesByKey: Map<string, QueueItem[]> = new Map()
private processingKeys: Set<string> = new Set()
private completionTimers: Map<string, ReturnType<typeof setTimeout>> = new Map()
private completedTaskSummaries: Map<string, Array<{id: string, description: string}>> = new Map()
private idleDeferralTimers: Map<string, ReturnType<typeof setTimeout>> = new Map()
private notificationQueueByParent: Map<string, Promise<void>> = new Map()
private rootDescendantCounts: Map<string, number>
@@ -133,7 +134,7 @@ export class BackgroundManager {
options?: {
tmuxConfig?: TmuxConfig
onSubagentSessionCreated?: OnSubagentSessionCreated
onShutdown?: () => void
onShutdown?: () => void | Promise<void>
enableParentSessionNotifications?: boolean
}
) {
@@ -906,6 +907,13 @@ export class BackgroundManager {
this.idleDeferralTimers.delete(task.id)
}
this.cleanupPendingByParent(task)
this.clearNotificationsForTask(task.id)
const toastManager = getTaskToastManager()
if (toastManager) {
toastManager.removeTask(task.id)
}
this.scheduleTaskRemoval(task.id)
if (task.sessionID) {
SessionCategoryRegistry.remove(task.sessionID)
}
@@ -932,7 +940,12 @@ export class BackgroundManager {
this.pendingNotifications.delete(sessionID)
if (tasksToCancel.size === 0) return
if (tasksToCancel.size === 0) {
this.clearTaskHistoryWhenParentTasksGone(sessionID)
return
}
const parentSessionsToClear = new Set<string>()
const deletedSessionIDs = new Set<string>([sessionID])
for (const task of tasksToCancel.values()) {
@@ -942,6 +955,8 @@ export class BackgroundManager {
}
for (const task of tasksToCancel.values()) {
parentSessionsToClear.add(task.parentSessionID)
if (task.status === "running" || task.status === "pending") {
void this.cancelTask(task.id, {
source: "session.deleted",
@@ -959,6 +974,10 @@ export class BackgroundManager {
}
}
for (const parentSessionID of parentSessionsToClear) {
this.clearTaskHistoryWhenParentTasksGone(parentSessionID)
}
this.rootDescendantCounts.delete(sessionID)
SessionCategoryRegistry.remove(sessionID)
}
@@ -1125,6 +1144,39 @@ export class BackgroundManager {
}
}
private clearTaskHistoryWhenParentTasksGone(parentSessionID: string | undefined): void {
if (!parentSessionID) return
if (this.getTasksByParentSession(parentSessionID).length > 0) return
this.taskHistory.clearSession(parentSessionID)
this.completedTaskSummaries.delete(parentSessionID)
}
private scheduleTaskRemoval(taskId: string): void {
const existingTimer = this.completionTimers.get(taskId)
if (existingTimer) {
clearTimeout(existingTimer)
this.completionTimers.delete(taskId)
}
const timer = setTimeout(() => {
this.completionTimers.delete(taskId)
const task = this.tasks.get(taskId)
if (task) {
this.clearNotificationsForTask(taskId)
this.tasks.delete(taskId)
this.clearTaskHistoryWhenParentTasksGone(task.parentSessionID)
if (task.sessionID) {
subagentSessions.delete(task.sessionID)
SessionCategoryRegistry.remove(task.sessionID)
}
log("[background-agent] Removed completed task from memory:", taskId)
this.clearTaskHistoryWhenParentTasksGone(task?.parentSessionID)
}
}, TASK_CLEANUP_DELAY_MS)
this.completionTimers.set(taskId, timer)
}
async cancelTask(
taskId: string,
options?: { source?: string; reason?: string; abortSession?: boolean; skipNotification?: boolean }
@@ -1190,6 +1242,8 @@ export class BackgroundManager {
removeTaskToastTracking(task.id)
if (options?.skipNotification) {
this.cleanupPendingByParent(task)
this.scheduleTaskRemoval(task.id)
log(`[background-agent] Task cancelled via ${source} (notification skipped):`, task.id)
return true
}
@@ -1328,6 +1382,14 @@ export class BackgroundManager {
})
}
if (!this.completedTaskSummaries.has(task.parentSessionID)) {
this.completedTaskSummaries.set(task.parentSessionID, [])
}
this.completedTaskSummaries.get(task.parentSessionID)!.push({
id: task.id,
description: task.description,
})
// Update pending tracking and check if all tasks complete
const pendingSet = this.pendingByParent.get(task.parentSessionID)
let allComplete = false
@@ -1347,10 +1409,13 @@ export class BackgroundManager {
}
const completedTasks = allComplete
? Array.from(this.tasks.values())
.filter(t => t.parentSessionID === task.parentSessionID && t.status !== "running" && t.status !== "pending")
? (this.completedTaskSummaries.get(task.parentSessionID) ?? [{ id: task.id, description: task.description }])
: []
if (allComplete) {
this.completedTaskSummaries.delete(task.parentSessionID)
}
const statusText = task.status === "completed"
? "COMPLETED"
: task.status === "interrupt"
@@ -1480,29 +1545,8 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
})
}
if (allComplete) {
for (const completedTask of completedTasks) {
const taskId = completedTask.id
const existingTimer = this.completionTimers.get(taskId)
if (existingTimer) {
clearTimeout(existingTimer)
this.completionTimers.delete(taskId)
}
const timer = setTimeout(() => {
this.completionTimers.delete(taskId)
const taskToRemove = this.tasks.get(taskId)
if (taskToRemove) {
this.clearNotificationsForTask(taskId)
if (taskToRemove.sessionID) {
subagentSessions.delete(taskToRemove.sessionID)
SessionCategoryRegistry.remove(taskToRemove.sessionID)
}
this.tasks.delete(taskId)
log("[background-agent] Removed completed task from memory:", taskId)
}
}, TASK_CLEANUP_DELAY_MS)
this.completionTimers.set(taskId, timer)
}
if (task.status !== "running" && task.status !== "pending") {
this.scheduleTaskRemoval(task.id)
}
}
@@ -1554,6 +1598,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
}
}
}
this.cleanupPendingByParent(task)
this.markForNotification(task)
this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)).catch(err => {
log("[background-agent] Error in notifyParentSession for stale-pruned task:", { taskId: task.id, error: err })
@@ -1657,7 +1702,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
* Cancels all pending concurrency waiters and clears timers.
* Should be called when the plugin is unloaded.
*/
shutdown(): void {
async shutdown(): Promise<void> {
if (this.shutdownTriggered) return
this.shutdownTriggered = true
log("[background-agent] Shutting down BackgroundManager")
@@ -1675,7 +1720,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
// Notify shutdown listeners (e.g., tmux cleanup)
if (this.onShutdown) {
try {
this.onShutdown()
await this.onShutdown()
} catch (error) {
log("[background-agent] Error in onShutdown callback:", error)
}
@@ -1708,6 +1753,8 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
this.rootDescendantCounts.clear()
this.queuesByKey.clear()
this.processingKeys.clear()
this.taskHistory.clearAll()
this.completedTaskSummaries.clear()
this.unregisterProcessCleanup()
log("[background-agent] Shutdown complete")

View File

@@ -19,7 +19,7 @@ function registerProcessSignal(
}
interface CleanupTarget {
shutdown(): void
shutdown(): void | Promise<void>
}
const cleanupManagers = new Set<CleanupTarget>()
@@ -35,7 +35,9 @@ export function registerManagerForCleanup(manager: CleanupTarget): void {
const cleanupAll = () => {
for (const m of cleanupManagers) {
try {
m.shutdown()
void Promise.resolve(m.shutdown()).catch((error) => {
log("[background-agent] Error during async shutdown cleanup:", error)
})
} catch (error) {
log("[background-agent] Error during shutdown cleanup:", error)
}

View File

@@ -0,0 +1,245 @@
declare const require: (name: string) => any
const { describe, test, expect, afterEach } = require("bun:test")
import { tmpdir } from "node:os"
import type { PluginInput } from "@opencode-ai/plugin"
import { TASK_CLEANUP_DELAY_MS } from "./constants"
import { BackgroundManager } from "./manager"
import type { BackgroundTask } from "./types"
type PromptAsyncCall = {
path: { id: string }
body: {
noReply?: boolean
parts?: unknown[]
}
}
type FakeTimers = {
getDelay: (timer: ReturnType<typeof setTimeout>) => number | undefined
run: (timer: ReturnType<typeof setTimeout>) => void
restore: () => void
}
let managerUnderTest: BackgroundManager | undefined
let fakeTimers: FakeTimers | undefined
afterEach(() => {
managerUnderTest?.shutdown()
fakeTimers?.restore()
managerUnderTest = undefined
fakeTimers = undefined
})
function createTask(overrides: Partial<BackgroundTask> & { id: string; parentSessionID: string }): BackgroundTask {
const id = overrides.id
const parentSessionID = overrides.parentSessionID
const { id: _ignoredID, parentSessionID: _ignoredParentSessionID, ...rest } = overrides
return {
parentMessageID: overrides.parentMessageID ?? "parent-message-id",
description: overrides.description ?? overrides.id,
prompt: overrides.prompt ?? `Prompt for ${overrides.id}`,
agent: overrides.agent ?? "test-agent",
status: overrides.status ?? "running",
startedAt: overrides.startedAt ?? new Date("2026-03-11T00:00:00.000Z"),
...rest,
id,
parentSessionID,
}
}
function createManager(enableParentSessionNotifications: boolean): {
manager: BackgroundManager
promptAsyncCalls: PromptAsyncCall[]
} {
const promptAsyncCalls: PromptAsyncCall[] = []
const client = {
session: {
messages: async () => [],
prompt: async () => ({}),
promptAsync: async (call: PromptAsyncCall) => {
promptAsyncCalls.push(call)
return {}
},
abort: async () => ({}),
},
}
const placeholderClient = {} as PluginInput["client"]
const ctx: PluginInput = {
client: placeholderClient,
project: {} as PluginInput["project"],
directory: tmpdir(),
worktree: tmpdir(),
serverUrl: new URL("http://localhost"),
$: {} as PluginInput["$"],
}
const manager = new BackgroundManager(
ctx,
undefined,
{ enableParentSessionNotifications }
)
Reflect.set(manager, "client", client)
return { manager, promptAsyncCalls }
}
function installFakeTimers(): FakeTimers {
const originalSetTimeout = globalThis.setTimeout
const originalClearTimeout = globalThis.clearTimeout
const callbacks = new Map<ReturnType<typeof setTimeout>, () => void>()
const delays = new Map<ReturnType<typeof setTimeout>, number>()
globalThis.setTimeout = ((handler: Parameters<typeof setTimeout>[0], delay?: number, ...args: unknown[]): ReturnType<typeof setTimeout> => {
if (typeof handler !== "function") {
throw new Error("Expected function timeout handler")
}
const timer = originalSetTimeout(() => {}, 60_000)
originalClearTimeout(timer)
const callback = handler as (...callbackArgs: Array<unknown>) => void
callbacks.set(timer, () => callback(...args))
delays.set(timer, delay ?? 0)
return timer
}) as typeof setTimeout
globalThis.clearTimeout = ((timer: ReturnType<typeof setTimeout>): void => {
callbacks.delete(timer)
delays.delete(timer)
}) as typeof clearTimeout
return {
getDelay(timer) {
return delays.get(timer)
},
run(timer) {
const callback = callbacks.get(timer)
if (!callback) {
throw new Error(`Timer not found: ${String(timer)}`)
}
callbacks.delete(timer)
delays.delete(timer)
callback()
},
restore() {
globalThis.setTimeout = originalSetTimeout
globalThis.clearTimeout = originalClearTimeout
},
}
}
function getTasks(manager: BackgroundManager): Map<string, BackgroundTask> {
return Reflect.get(manager, "tasks") as Map<string, BackgroundTask>
}
function getPendingByParent(manager: BackgroundManager): Map<string, Set<string>> {
return Reflect.get(manager, "pendingByParent") as Map<string, Set<string>>
}
function getCompletionTimers(manager: BackgroundManager): Map<string, ReturnType<typeof setTimeout>> {
return Reflect.get(manager, "completionTimers") as Map<string, ReturnType<typeof setTimeout>>
}
async function notifyParentSessionForTest(manager: BackgroundManager, task: BackgroundTask): Promise<void> {
const notifyParentSession = Reflect.get(manager, "notifyParentSession") as (task: BackgroundTask) => Promise<void>
return notifyParentSession.call(manager, task)
}
function getRequiredTimer(manager: BackgroundManager, taskID: string): ReturnType<typeof setTimeout> {
const timer = getCompletionTimers(manager).get(taskID)
expect(timer).toBeDefined()
if (timer === undefined) {
throw new Error(`Missing completion timer for ${taskID}`)
}
return timer
}
describe("BackgroundManager.notifyParentSession cleanup scheduling", () => {
describe("#given 2 tasks for same parent and task A completed", () => {
test("#when task B is still running #then task A is cleaned up from this.tasks after delay even though task B is not done", async () => {
// given
const { manager } = createManager(false)
managerUnderTest = manager
fakeTimers = installFakeTimers()
const taskA = createTask({ id: "task-a", parentSessionID: "parent-1", description: "task A", status: "completed", completedAt: new Date("2026-03-11T00:01:00.000Z") })
const taskB = createTask({ id: "task-b", parentSessionID: "parent-1", description: "task B", status: "running" })
getTasks(manager).set(taskA.id, taskA)
getTasks(manager).set(taskB.id, taskB)
getPendingByParent(manager).set(taskA.parentSessionID, new Set([taskA.id, taskB.id]))
// when
await notifyParentSessionForTest(manager, taskA)
const taskATimer = getRequiredTimer(manager, taskA.id)
expect(fakeTimers.getDelay(taskATimer)).toBe(TASK_CLEANUP_DELAY_MS)
fakeTimers.run(taskATimer)
// then
expect(fakeTimers.getDelay(taskATimer)).toBeUndefined()
expect(getTasks(manager).has(taskA.id)).toBe(false)
expect(getTasks(manager).get(taskB.id)).toBe(taskB)
})
})
describe("#given 2 tasks for same parent and both completed", () => {
test("#when the second completion notification is sent #then ALL BACKGROUND TASKS COMPLETE notification still works correctly", async () => {
// given
const { manager, promptAsyncCalls } = createManager(true)
managerUnderTest = manager
fakeTimers = installFakeTimers()
const taskA = createTask({ id: "task-a", parentSessionID: "parent-1", description: "task A", status: "completed", completedAt: new Date("2026-03-11T00:01:00.000Z") })
const taskB = createTask({ id: "task-b", parentSessionID: "parent-1", description: "task B", status: "running" })
getTasks(manager).set(taskA.id, taskA)
getTasks(manager).set(taskB.id, taskB)
getPendingByParent(manager).set(taskA.parentSessionID, new Set([taskA.id, taskB.id]))
await notifyParentSessionForTest(manager, taskA)
taskB.status = "completed"
taskB.completedAt = new Date("2026-03-11T00:02:00.000Z")
// when
await notifyParentSessionForTest(manager, taskB)
// then
expect(promptAsyncCalls).toHaveLength(2)
expect(getCompletionTimers(manager).size).toBe(2)
const allCompleteCall = promptAsyncCalls[1]
expect(allCompleteCall).toBeDefined()
if (!allCompleteCall) {
throw new Error("Missing all-complete notification call")
}
expect(allCompleteCall.body.noReply).toBe(false)
const allCompletePayload = JSON.stringify(allCompleteCall.body.parts)
expect(allCompletePayload).toContain("ALL BACKGROUND TASKS COMPLETE")
expect(allCompletePayload).toContain(taskA.id)
expect(allCompletePayload).toContain(taskB.id)
expect(allCompletePayload).toContain(taskA.description)
expect(allCompletePayload).toContain(taskB.description)
})
})
describe("#given a completed task with cleanup timer scheduled", () => {
test("#when cleanup timer fires #then task is deleted from this.tasks Map", async () => {
// given
const { manager } = createManager(false)
managerUnderTest = manager
fakeTimers = installFakeTimers()
const task = createTask({ id: "task-a", parentSessionID: "parent-1", description: "task A", status: "completed", completedAt: new Date("2026-03-11T00:01:00.000Z") })
getTasks(manager).set(task.id, task)
getPendingByParent(manager).set(task.parentSessionID, new Set([task.id]))
await notifyParentSessionForTest(manager, task)
const cleanupTimer = getRequiredTimer(manager, task.id)
// when
expect(fakeTimers.getDelay(cleanupTimer)).toBe(TASK_CLEANUP_DELAY_MS)
fakeTimers.run(cleanupTimer)
// then
expect(getCompletionTimers(manager).has(task.id)).toBe(false)
expect(getTasks(manager).has(task.id)).toBe(false)
})
})
})

View File

@@ -0,0 +1,142 @@
import { afterEach, describe, expect, test } from "bun:test"
import { tmpdir } from "node:os"
import type { PluginInput } from "@opencode-ai/plugin"
import { BackgroundManager } from "./manager"
import { TaskHistory } from "./task-history"
import type { BackgroundTask } from "./types"
let managerUnderTest: BackgroundManager | undefined
afterEach(() => {
managerUnderTest?.shutdown()
managerUnderTest = undefined
})
function createManager(): BackgroundManager {
const client = {
session: {
abort: async () => ({}),
},
}
const placeholderClient = {} as PluginInput["client"]
const ctx: PluginInput = {
client: placeholderClient,
project: {} as PluginInput["project"],
directory: tmpdir(),
worktree: tmpdir(),
serverUrl: new URL("http://localhost"),
$: {} as PluginInput["$"],
}
const manager = new BackgroundManager(ctx)
Reflect.set(manager, "client", client)
return manager
}
function createTask(overrides: Partial<BackgroundTask> & { id: string; parentSessionID: string }): BackgroundTask {
const { id, parentSessionID, ...rest } = overrides
return {
...rest,
id,
parentSessionID,
parentMessageID: rest.parentMessageID ?? "parent-message-id",
description: rest.description ?? id,
prompt: rest.prompt ?? `Prompt for ${id}`,
agent: rest.agent ?? "test-agent",
status: rest.status ?? "running",
startedAt: rest.startedAt ?? new Date("2026-03-11T00:00:00.000Z"),
}
}
function getTaskMap(manager: BackgroundManager): Map<string, BackgroundTask> {
return Reflect.get(manager, "tasks") as Map<string, BackgroundTask>
}
function pruneStaleTasksAndNotificationsForTest(manager: BackgroundManager): void {
const pruneStaleTasksAndNotifications = Reflect.get(manager, "pruneStaleTasksAndNotifications") as () => void
pruneStaleTasksAndNotifications.call(manager)
}
describe("task history cleanup", () => {
test("#given TaskHistory with entries for multiple parents #when clearSession called for one parent #then only that parent's entries are removed, others remain", () => {
// given
const history = new TaskHistory()
history.record("parent-1", { id: "task-1", agent: "explore", description: "task 1", status: "pending" })
history.record("parent-2", { id: "task-2", agent: "oracle", description: "task 2", status: "running" })
// when
history.clearSession("parent-1")
// then
expect(history.getByParentSession("parent-1")).toHaveLength(0)
expect(history.getByParentSession("parent-2")).toHaveLength(1)
})
test("#given TaskHistory with entries for multiple parents #when clearAll called #then all entries are removed", () => {
// given
const history = new TaskHistory()
history.record("parent-1", { id: "task-1", agent: "explore", description: "task 1", status: "pending" })
history.record("parent-2", { id: "task-2", agent: "oracle", description: "task 2", status: "running" })
// when
history.clearAll()
// then
expect(history.getByParentSession("parent-1")).toHaveLength(0)
expect(history.getByParentSession("parent-2")).toHaveLength(0)
})
test("#given BackgroundManager with taskHistory entries #when shutdown() called #then taskHistory is cleared via clearAll()", () => {
// given
const manager = createManager()
managerUnderTest = manager
manager.taskHistory.record("parent-1", { id: "task-1", agent: "explore", description: "task 1", status: "pending" })
let clearAllCalls = 0
const originalClearAll = manager.taskHistory.clearAll.bind(manager.taskHistory)
manager.taskHistory.clearAll = (): void => {
clearAllCalls += 1
originalClearAll()
}
// when
manager.shutdown()
// then
expect(clearAllCalls).toBe(1)
expect(manager.taskHistory.getByParentSession("parent-1")).toHaveLength(0)
managerUnderTest = undefined
})
test("#given BackgroundManager with stale tasks for one parent #when pruneStaleTasksAndNotifications() runs #then history is preserved until delayed cleanup", () => {
// given
const manager = createManager()
managerUnderTest = manager
const staleTask = createTask({
id: "task-stale",
parentSessionID: "parent-1",
startedAt: new Date(Date.now() - 31 * 60 * 1000),
})
const liveTask = createTask({
id: "task-live",
parentSessionID: "parent-2",
startedAt: new Date(),
})
getTaskMap(manager).set(staleTask.id, staleTask)
getTaskMap(manager).set(liveTask.id, liveTask)
manager.taskHistory.record("parent-1", { id: staleTask.id, agent: staleTask.agent, description: staleTask.description, status: staleTask.status })
manager.taskHistory.record("parent-2", { id: liveTask.id, agent: liveTask.agent, description: liveTask.description, status: liveTask.status })
// when
pruneStaleTasksAndNotificationsForTest(manager)
// then
expect(manager.taskHistory.getByParentSession("parent-1")).toHaveLength(1)
expect(manager.taskHistory.getByParentSession("parent-2")).toHaveLength(1)
})
})

View File

@@ -54,6 +54,10 @@ export class TaskHistory {
this.entries.delete(parentSessionID)
}
clearAll(): void {
this.entries.clear()
}
formatForCompaction(parentSessionID: string): string | null {
const list = this.getByParentSession(parentSessionID)
if (list.length === 0) return null

View File

@@ -19,11 +19,13 @@ export function registerProcessCleanup(state: SkillMcpManagerState): void {
state.cleanupRegistered = true
const cleanup = async (): Promise<void> => {
state.shutdownGeneration++
for (const managed of state.clients.values()) {
await closeManagedClient(managed)
}
state.clients.clear()
state.pendingConnections.clear()
state.disconnectedSessions.clear()
}
// Note: Node's 'exit' event is synchronous-only, so we rely on signal handlers for async cleanup.
@@ -79,12 +81,23 @@ async function cleanupIdleClients(state: SkillMcpManagerState): Promise<void> {
}
}
if (state.clients.size === 0) {
if (state.clients.size === 0 && state.pendingConnections.size === 0) {
stopCleanupTimer(state)
unregisterProcessCleanup(state)
}
}
export async function disconnectSession(state: SkillMcpManagerState, sessionID: string): Promise<void> {
let hasPendingForSession = false
for (const key of state.pendingConnections.keys()) {
if (key.startsWith(`${sessionID}:`)) {
hasPendingForSession = true
break
}
}
if (hasPendingForSession) {
state.disconnectedSessions.set(sessionID, (state.disconnectedSessions.get(sessionID) ?? 0) + 1)
}
const keysToRemove: string[] = []
for (const [key, managed] of state.clients.entries()) {
@@ -96,22 +109,33 @@ export async function disconnectSession(state: SkillMcpManagerState, sessionID:
}
}
for (const key of state.pendingConnections.keys()) {
if (key.startsWith(`${sessionID}:`)) {
keysToRemove.push(key)
}
}
for (const key of keysToRemove) {
state.pendingConnections.delete(key)
}
if (state.clients.size === 0) {
if (state.clients.size === 0 && state.pendingConnections.size === 0) {
stopCleanupTimer(state)
unregisterProcessCleanup(state)
}
}
export async function disconnectAll(state: SkillMcpManagerState): Promise<void> {
state.shutdownGeneration++
state.disposed = true
stopCleanupTimer(state)
unregisterProcessCleanup(state)
const clients = Array.from(state.clients.values())
state.clients.clear()
state.pendingConnections.clear()
state.disconnectedSessions.clear()
state.inFlightConnections.clear()
state.authProviders.clear()
for (const managed of clients) {

View File

@@ -0,0 +1,291 @@
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types"
import type { SkillMcpClientInfo, SkillMcpManagerState } from "./types"
type Deferred<TValue> = {
promise: Promise<TValue>
resolve: (value: TValue) => void
reject: (error: Error) => void
}
const pendingConnects: Deferred<void>[] = []
const trackedStates: SkillMcpManagerState[] = []
const createdClients: MockClient[] = []
const createdTransports: MockStdioClientTransport[] = []
class MockClient {
readonly close = mock(async () => {})
constructor(
_clientInfo: { name: string; version: string },
_options: { capabilities: Record<string, never> }
) {
createdClients.push(this)
}
async connect(_transport: MockStdioClientTransport): Promise<void> {
const pendingConnect = pendingConnects.shift()
if (pendingConnect) {
await pendingConnect.promise
}
}
}
class MockStdioClientTransport {
readonly close = mock(async () => {})
constructor(_options: { command: string; args?: string[]; env?: Record<string, string>; stderr?: string }) {
createdTransports.push(this)
}
}
mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({
Client: MockClient,
}))
mock.module("@modelcontextprotocol/sdk/client/stdio.js", () => ({
StdioClientTransport: MockStdioClientTransport,
}))
const { disconnectAll, disconnectSession } = await import("./cleanup")
const { getOrCreateClient } = await import("./connection")
function createDeferred<TValue>(): Deferred<TValue> {
let resolvePromise: ((value: TValue) => void) | null = null
let rejectPromise: ((error: Error) => void) | null = null
const promise = new Promise<TValue>((resolve, reject) => {
resolvePromise = resolve
rejectPromise = reject
})
if (!resolvePromise || !rejectPromise) {
throw new Error("Failed to create deferred promise")
}
return {
promise,
resolve: resolvePromise,
reject: rejectPromise,
}
}
function createState(): SkillMcpManagerState {
const state: SkillMcpManagerState = {
clients: new Map(),
pendingConnections: new Map(),
disconnectedSessions: new Map(),
authProviders: new Map(),
cleanupRegistered: false,
cleanupInterval: null,
cleanupHandlers: [],
idleTimeoutMs: 5 * 60 * 1000,
shutdownGeneration: 0,
inFlightConnections: new Map(),
disposed: false,
}
trackedStates.push(state)
return state
}
function createClientInfo(sessionID: string): SkillMcpClientInfo {
return {
serverName: "race-server",
skillName: "race-skill",
sessionID,
}
}
function createClientKey(info: SkillMcpClientInfo): string {
return `${info.sessionID}:${info.skillName}:${info.serverName}`
}
const stdioConfig: ClaudeCodeMcpServer = {
command: "mock-mcp-server",
}
beforeEach(() => {
pendingConnects.length = 0
createdClients.length = 0
createdTransports.length = 0
})
afterEach(async () => {
for (const state of trackedStates) {
await disconnectAll(state)
}
trackedStates.length = 0
pendingConnects.length = 0
createdClients.length = 0
createdTransports.length = 0
})
describe("getOrCreateClient disconnect race", () => {
it("#given pending connection for session A #when disconnectSession(A) is called before connection completes #then completed client is not added to state.clients", async () => {
const state = createState()
const info = createClientInfo("session-a")
const clientKey = createClientKey(info)
const pendingConnect = createDeferred<void>()
pendingConnects.push(pendingConnect)
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
expect(state.pendingConnections.has(clientKey)).toBe(true)
await disconnectSession(state, info.sessionID)
pendingConnect.resolve(undefined)
await expect(clientPromise).rejects.toThrow(/disconnected during MCP connection setup/)
expect(state.clients.has(clientKey)).toBe(false)
expect(state.pendingConnections.has(clientKey)).toBe(false)
expect(state.disconnectedSessions.has(info.sessionID)).toBe(false)
expect(createdClients).toHaveLength(1)
expect(createdClients[0]?.close).toHaveBeenCalledTimes(1)
expect(createdTransports[0]?.close).toHaveBeenCalledTimes(1)
})
it("#given session A in disconnectedSessions #when new connection completes with no remaining pending #then disconnectedSessions entry is cleaned up", async () => {
const state = createState()
const info = createClientInfo("session-a")
const clientKey = createClientKey(info)
state.disconnectedSessions.set(info.sessionID, 1)
const client = await getOrCreateClient({ state, clientKey, info, config: stdioConfig })
expect(state.disconnectedSessions.has(info.sessionID)).toBe(false)
expect(state.clients.get(clientKey)?.client).toBe(client)
expect(createdClients[0]?.close).not.toHaveBeenCalled()
})
it("#given no pending connections #when disconnectSession is called #then no errors occur and session is not added to disconnectedSessions", async () => {
const state = createState()
await expect(disconnectSession(state, "session-a")).resolves.toBeUndefined()
expect(state.disconnectedSessions.has("session-a")).toBe(false)
expect(state.pendingConnections.size).toBe(0)
expect(state.clients.size).toBe(0)
})
})
describe("getOrCreateClient disconnectAll race", () => {
it("#given pending connection #when disconnectAll() is called before connection completes #then client is not added to state.clients", async () => {
const state = createState()
const info = createClientInfo("session-a")
const clientKey = createClientKey(info)
const pendingConnect = createDeferred<void>()
pendingConnects.push(pendingConnect)
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
expect(state.pendingConnections.has(clientKey)).toBe(true)
await disconnectAll(state)
pendingConnect.resolve(undefined)
await expect(clientPromise).rejects.toThrow(/connection completed after shutdown/)
expect(state.clients.has(clientKey)).toBe(false)
})
it("#given state after disconnectAll() completed #when getOrCreateClient() is called #then it throws shut down error and registers nothing", async () => {
const state = createState()
const info = createClientInfo("session-b")
const clientKey = createClientKey(info)
await disconnectAll(state)
await expect(getOrCreateClient({ state, clientKey, info, config: stdioConfig })).rejects.toThrow(/has been shut down/)
expect(state.clients.size).toBe(0)
expect(state.pendingConnections.size).toBe(0)
expect(state.inFlightConnections.size).toBe(0)
expect(state.disposed).toBe(true)
expect(createdClients).toHaveLength(0)
expect(createdTransports).toHaveLength(0)
})
})
describe("getOrCreateClient multi-key disconnect race", () => {
it("#given 2 pending connections for session A #when disconnectSession(A) before both complete #then both old connections are rejected", async () => {
const state = createState()
const infoKey1 = createClientInfo("session-a")
const infoKey2 = { ...createClientInfo("session-a"), serverName: "server-2" }
const clientKey1 = createClientKey(infoKey1)
const clientKey2 = `${infoKey2.sessionID}:${infoKey2.skillName}:${infoKey2.serverName}`
const pendingConnect1 = createDeferred<void>()
const pendingConnect2 = createDeferred<void>()
pendingConnects.push(pendingConnect1)
pendingConnects.push(pendingConnect2)
const promise1 = getOrCreateClient({ state, clientKey: clientKey1, info: infoKey1, config: stdioConfig })
const promise2 = getOrCreateClient({ state, clientKey: clientKey2, info: infoKey2, config: stdioConfig })
expect(state.pendingConnections.size).toBe(2)
await disconnectSession(state, "session-a")
pendingConnect1.resolve(undefined)
await expect(promise1).rejects.toThrow(/disconnected during MCP connection setup/)
pendingConnect2.resolve(undefined)
await expect(promise2).rejects.toThrow(/disconnected during MCP connection setup/)
expect(state.clients.has(clientKey1)).toBe(false)
expect(state.clients.has(clientKey2)).toBe(false)
expect(state.disconnectedSessions.has("session-a")).toBe(false)
})
it("#given a superseded pending connection #when the old connection completes #then the stale client is removed from state.clients", async () => {
const state = createState()
const info = createClientInfo("session-a")
const clientKey = createClientKey(info)
const pendingConnect = createDeferred<void>()
const supersedingConnection = createDeferred<Awaited<ReturnType<typeof getOrCreateClient>>>()
pendingConnects.push(pendingConnect)
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
state.pendingConnections.set(clientKey, supersedingConnection.promise)
pendingConnect.resolve(undefined)
await expect(clientPromise).rejects.toThrow(/superseded by a newer connection attempt/)
expect(state.clients.has(clientKey)).toBe(false)
expect(createdClients[0]?.close).toHaveBeenCalledTimes(1)
})
it("#given a superseded pending connection #when a newer client already replaced the map entry #then the stale cleanup does not delete the newer client", async () => {
const state = createState()
const info = createClientInfo("session-a")
const clientKey = createClientKey(info)
const pendingConnect = createDeferred<void>()
const supersedingConnection = createDeferred<Awaited<ReturnType<typeof getOrCreateClient>>>()
pendingConnects.push(pendingConnect)
const newerClient = new MockClient(
{ name: "newer-client", version: "1.0.0" },
{ capabilities: {} },
)
const newerTransport = new MockStdioClientTransport({ command: "mock-mcp-server" })
let replacedEntry = false
const originalSet = state.clients.set.bind(state.clients)
Reflect.set(state.clients, "set", (key: string, value: SkillMcpManagerState["clients"] extends Map<string, infer TValue> ? TValue : never) => {
originalSet(key, value)
if (!replacedEntry && key === clientKey) {
replacedEntry = true
originalSet(key, {
client: newerClient as never,
transport: newerTransport as never,
skillName: info.skillName,
lastUsedAt: Date.now(),
connectionType: "stdio",
})
}
return state.clients
})
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
state.pendingConnections.set(clientKey, supersedingConnection.promise)
pendingConnect.resolve(undefined)
await expect(clientPromise).rejects.toThrow(/superseded by a newer connection attempt/)
expect(state.clients.get(clientKey)?.client.close).toBe(newerClient.close)
expect(newerClient.close).not.toHaveBeenCalled()
})
})

View File

@@ -7,6 +7,13 @@ import { createHttpClient } from "./http-client"
import { createStdioClient } from "./stdio-client"
import type { SkillMcpClientConnectionParams, SkillMcpClientInfo, SkillMcpManagerState } from "./types"
function removeClientIfCurrent(state: SkillMcpManagerState, clientKey: string, client: Client): void {
const managed = state.clients.get(clientKey)
if (managed?.client === client) {
state.clients.delete(clientKey)
}
}
export async function getOrCreateClient(params: {
state: SkillMcpManagerState
clientKey: string
@@ -15,6 +22,10 @@ export async function getOrCreateClient(params: {
}): Promise<Client> {
const { state, clientKey, info, config } = params
if (state.disposed) {
throw new Error(`MCP manager for "${info.sessionID}" has been shut down, cannot create new connections.`)
}
const existing = state.clients.get(clientKey)
if (existing) {
existing.lastUsedAt = Date.now()
@@ -28,14 +39,52 @@ export async function getOrCreateClient(params: {
}
const expandedConfig = expandEnvVarsInObject(config)
const connectionPromise = createClient({ state, clientKey, info, config: expandedConfig })
state.pendingConnections.set(clientKey, connectionPromise)
let currentConnectionPromise!: Promise<Client>
state.inFlightConnections.set(info.sessionID, (state.inFlightConnections.get(info.sessionID) ?? 0) + 1)
currentConnectionPromise = (async () => {
const disconnectGenAtStart = state.disconnectedSessions.get(info.sessionID) ?? 0
const shutdownGenAtStart = state.shutdownGeneration
const client = await createClient({ state, clientKey, info, config: expandedConfig })
const isStale = state.pendingConnections.has(clientKey) && state.pendingConnections.get(clientKey) !== currentConnectionPromise
if (isStale) {
removeClientIfCurrent(state, clientKey, client)
try { await client.close() } catch {}
throw new Error(`Connection for "${info.sessionID}" was superseded by a newer connection attempt.`)
}
if (state.shutdownGeneration !== shutdownGenAtStart) {
removeClientIfCurrent(state, clientKey, client)
try { await client.close() } catch {}
throw new Error(`Shutdown occurred during MCP connection for "${info.sessionID}"`)
}
const currentDisconnectGen = state.disconnectedSessions.get(info.sessionID) ?? 0
if (currentDisconnectGen > disconnectGenAtStart) {
await forceReconnect(state, clientKey)
throw new Error(`Session "${info.sessionID}" disconnected during MCP connection setup.`)
}
return client
})()
state.pendingConnections.set(clientKey, currentConnectionPromise)
try {
const client = await connectionPromise
const client = await currentConnectionPromise
return client
} finally {
state.pendingConnections.delete(clientKey)
if (state.pendingConnections.get(clientKey) === currentConnectionPromise) {
state.pendingConnections.delete(clientKey)
}
const remaining = (state.inFlightConnections.get(info.sessionID) ?? 1) - 1
if (remaining <= 0) {
state.inFlightConnections.delete(info.sessionID)
state.disconnectedSessions.delete(info.sessionID)
} else {
state.inFlightConnections.set(info.sessionID, remaining)
}
}
}

View File

@@ -0,0 +1,133 @@
import { Client } from "@modelcontextprotocol/sdk/client/index.js"
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
import { afterEach, describe, expect, it } from "bun:test"
import { disconnectSession, registerProcessCleanup, unregisterProcessCleanup } from "./cleanup"
import type { ManagedClient, SkillMcpManagerState } from "./types"
const trackedStates: SkillMcpManagerState[] = []
afterEach(() => {
for (const state of trackedStates) {
unregisterProcessCleanup(state)
}
trackedStates.length = 0
})
const expectedCleanupHandlerCount = process.platform === "win32" ? 3 : 2
function createState(): SkillMcpManagerState {
const state: SkillMcpManagerState = {
clients: new Map(),
pendingConnections: new Map(),
disconnectedSessions: new Map(),
authProviders: new Map(),
cleanupRegistered: false,
cleanupInterval: null,
cleanupHandlers: [],
idleTimeoutMs: 5 * 60 * 1000,
shutdownGeneration: 0,
inFlightConnections: new Map(),
disposed: false,
}
trackedStates.push(state)
return state
}
function createManagedClient(skillName: string): ManagedClient {
return {
client: new Client(
{ name: `test-${skillName}`, version: "1.0.0" },
{ capabilities: {} }
),
transport: new StreamableHTTPClientTransport(new URL("https://example.com/mcp")),
skillName,
lastUsedAt: Date.now(),
connectionType: "http",
}
}
describe("disconnectSession cleanup registration", () => {
it("#given state with 1 client and cleanup registered #when disconnectSession removes last client #then process cleanup handlers are unregistered", async () => {
// given
const state = createState()
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
registerProcessCleanup(state)
// when
await disconnectSession(state, "session-1")
// then
expect(state.cleanupRegistered).toBe(false)
expect(state.cleanupHandlers).toEqual([])
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister)
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister)
})
it("#given state with 2 clients in different sessions #when disconnectSession removes one session #then process cleanup handlers remain registered", async () => {
// given
const state = createState()
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
state.clients.set("session-2:skill-2:server-2", createManagedClient("skill-2"))
registerProcessCleanup(state)
// when
await disconnectSession(state, "session-1")
// then
expect(state.clients.has("session-2:skill-2:server-2")).toBe(true)
expect(state.cleanupRegistered).toBe(true)
expect(state.cleanupHandlers).toHaveLength(expectedCleanupHandlerCount)
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister + 1)
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister + 1)
})
it("#given state with 2 clients in different sessions #when both sessions disconnected #then process cleanup handlers are unregistered", async () => {
// given
const state = createState()
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
state.clients.set("session-2:skill-2:server-2", createManagedClient("skill-2"))
registerProcessCleanup(state)
// when
await disconnectSession(state, "session-1")
await disconnectSession(state, "session-2")
// then
expect(state.clients.size).toBe(0)
expect(state.cleanupRegistered).toBe(false)
expect(state.cleanupHandlers).toEqual([])
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister)
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister)
})
it("#given state with 1 client and pending connection for different session and cleanup registered #when disconnectSession removes last client but pendingConnections remain #then process cleanup handlers stay registered", async () => {
const state = createState()
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
const pendingClient = createManagedClient("skill-pending").client
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
state.pendingConnections.set("session-2:skill-2:server-2", Promise.resolve(pendingClient))
registerProcessCleanup(state)
await disconnectSession(state, "session-1")
expect(state.clients.size).toBe(0)
expect(state.pendingConnections.size).toBe(1)
expect(state.cleanupRegistered).toBe(true)
expect(state.cleanupHandlers).toHaveLength(expectedCleanupHandlerCount)
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister + 1)
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister + 1)
})
})

View File

@@ -24,6 +24,7 @@ function redactUrl(urlStr: string): string {
export async function createHttpClient(params: SkillMcpClientConnectionParams): Promise<Client> {
const { state, clientKey, info, config } = params
const shutdownGenAtStart = state.shutdownGeneration
if (!config.url) {
throw new Error(`MCP server "${info.serverName}" is configured for HTTP but missing 'url' field.`)
@@ -72,6 +73,12 @@ export async function createHttpClient(params: SkillMcpClientConnectionParams):
)
}
if (state.shutdownGeneration !== shutdownGenAtStart) {
try { await client.close() } catch {}
try { await transport.close() } catch {}
throw new Error(`MCP server "${info.serverName}" connection completed after shutdown`)
}
const managedClient = {
client,
transport,

View File

@@ -10,11 +10,15 @@ export class SkillMcpManager {
private readonly state: SkillMcpManagerState = {
clients: new Map(),
pendingConnections: new Map(),
disconnectedSessions: new Map(),
authProviders: new Map(),
cleanupRegistered: false,
cleanupInterval: null,
cleanupHandlers: [],
idleTimeoutMs: 5 * 60 * 1000,
shutdownGeneration: 0,
inFlightConnections: new Map(),
disposed: false,
}
private getClientKey(info: SkillMcpClientInfo): string {

View File

@@ -14,6 +14,7 @@ function getStdioCommand(config: ClaudeCodeMcpServer, serverName: string): strin
export async function createStdioClient(params: SkillMcpClientConnectionParams): Promise<Client> {
const { state, clientKey, info, config } = params
const shutdownGenAtStart = state.shutdownGeneration
const command = getStdioCommand(config, info.serverName)
const args = config.args ?? []
@@ -55,6 +56,12 @@ export async function createStdioClient(params: SkillMcpClientConnectionParams):
)
}
if (state.shutdownGeneration !== shutdownGenAtStart) {
try { await client.close() } catch {}
try { await transport.close() } catch {}
throw new Error(`MCP server "${info.serverName}" connection completed after shutdown`)
}
const managedClient = {
client,
transport,

View File

@@ -51,11 +51,15 @@ export interface ProcessCleanupHandler {
export interface SkillMcpManagerState {
clients: Map<string, ManagedClient>
pendingConnections: Map<string, Promise<Client>>
disconnectedSessions: Map<string, number>
authProviders: Map<string, McpOAuthProvider>
cleanupRegistered: boolean
cleanupInterval: ReturnType<typeof setInterval> | null
cleanupHandlers: ProcessCleanupHandler[]
idleTimeoutMs: number
shutdownGeneration: number
inFlightConnections: Map<string, number>
disposed: boolean
}
export interface SkillMcpClientConnectionParams {

View File

@@ -1,6 +1,6 @@
import type { PluginInput } from "@opencode-ai/plugin"
import type { TmuxConfig } from "../../config/schema"
import type { TrackedSession, CapacityConfig } from "./types"
import type { TrackedSession, CapacityConfig, WindowState } from "./types"
import { log, normalizeSDKResponse } from "../../shared"
import {
isInsideTmux as defaultIsInsideTmux,
@@ -13,6 +13,7 @@ import { queryWindowState } from "./pane-state-querier"
import { decideSpawnActions, decideCloseAction, type SessionMapping } from "./decision-engine"
import { executeActions, executeAction } from "./action-executor"
import { TmuxPollingManager } from "./polling-manager"
import { createTrackedSession, markTrackedSessionClosePending } from "./tracked-session-state"
type OpencodeClient = PluginInput["client"]
interface SessionCreatedEvent {
@@ -38,6 +39,7 @@ const defaultTmuxDeps: TmuxUtilDeps = {
const DEFERRED_SESSION_TTL_MS = 5 * 60 * 1000
const MAX_DEFERRED_QUEUE_SIZE = 20
const MAX_CLOSE_RETRY_COUNT = 3
/**
* State-first Tmux Session Manager
@@ -106,6 +108,123 @@ export class TmuxSessionManager {
}))
}
private removeTrackedSession(sessionId: string): void {
this.sessions.delete(sessionId)
if (this.sessions.size === 0) {
this.pollingManager.stopPolling()
}
}
private markSessionClosePending(sessionId: string): void {
const tracked = this.sessions.get(sessionId)
if (!tracked) return
this.sessions.set(sessionId, markTrackedSessionClosePending(tracked))
log("[tmux-session-manager] marked session close pending", {
sessionId,
paneId: tracked.paneId,
closeRetryCount: tracked.closeRetryCount,
})
}
private async queryWindowStateSafely(): Promise<WindowState | null> {
if (!this.sourcePaneId) return null
try {
return await queryWindowState(this.sourcePaneId)
} catch (error) {
log("[tmux-session-manager] failed to query window state for close", {
error: String(error),
})
return null
}
}
private async tryCloseTrackedSession(tracked: TrackedSession): Promise<boolean> {
const state = await this.queryWindowStateSafely()
if (!state) return false
try {
const result = await executeAction(
{ type: "close", paneId: tracked.paneId, sessionId: tracked.sessionId },
{
config: this.tmuxConfig,
serverUrl: this.serverUrl,
windowState: state,
sourcePaneId: this.sourcePaneId,
}
)
return result.success
} catch (error) {
log("[tmux-session-manager] close session pane failed", {
sessionId: tracked.sessionId,
paneId: tracked.paneId,
error: String(error),
})
return false
}
}
private async retryPendingCloses(): Promise<void> {
const pendingSessions = Array.from(this.sessions.values()).filter(
(tracked) => tracked.closePending,
)
for (const tracked of pendingSessions) {
if (!this.sessions.has(tracked.sessionId)) continue
if (tracked.closeRetryCount >= MAX_CLOSE_RETRY_COUNT) {
log("[tmux-session-manager] force removing close-pending session after max retries", {
sessionId: tracked.sessionId,
paneId: tracked.paneId,
closeRetryCount: tracked.closeRetryCount,
})
this.removeTrackedSession(tracked.sessionId)
continue
}
const closed = await this.tryCloseTrackedSession(tracked)
if (closed) {
log("[tmux-session-manager] retried close succeeded", {
sessionId: tracked.sessionId,
paneId: tracked.paneId,
closeRetryCount: tracked.closeRetryCount,
})
this.removeTrackedSession(tracked.sessionId)
continue
}
const currentTracked = this.sessions.get(tracked.sessionId)
if (!currentTracked || !currentTracked.closePending) {
continue
}
const nextRetryCount = currentTracked.closeRetryCount + 1
if (nextRetryCount >= MAX_CLOSE_RETRY_COUNT) {
log("[tmux-session-manager] force removing close-pending session after failed retry", {
sessionId: currentTracked.sessionId,
paneId: currentTracked.paneId,
closeRetryCount: nextRetryCount,
})
this.removeTrackedSession(currentTracked.sessionId)
continue
}
this.sessions.set(currentTracked.sessionId, {
...currentTracked,
closePending: true,
closeRetryCount: nextRetryCount,
})
log("[tmux-session-manager] retried close failed", {
sessionId: currentTracked.sessionId,
paneId: currentTracked.paneId,
closeRetryCount: nextRetryCount,
})
}
}
private enqueueDeferredSession(sessionId: string, title: string): void {
if (this.deferredSessions.has(sessionId)) return
if (this.deferredQueue.length >= MAX_DEFERRED_QUEUE_SIZE) {
@@ -257,14 +376,14 @@ export class TmuxSessionManager {
})
}
const now = Date.now()
this.sessions.set(sessionId, {
this.sessions.set(
sessionId,
paneId: result.spawnedPaneId,
description: deferred.title,
createdAt: new Date(now),
lastSeenAt: new Date(now),
})
createTrackedSession({
sessionId,
paneId: result.spawnedPaneId,
description: deferred.title,
}),
)
this.removeDeferredSession(sessionId)
this.pollingManager.startPolling()
log("[tmux-session-manager] deferred session attached", {
@@ -324,6 +443,13 @@ export class TmuxSessionManager {
const sessionId = info.id
const title = info.title ?? "Subagent"
if (!this.sourcePaneId) {
log("[tmux-session-manager] no source pane id")
return
}
await this.retryPendingCloses()
if (
this.sessions.has(sessionId) ||
this.pendingSessions.has(sessionId) ||
@@ -332,11 +458,6 @@ export class TmuxSessionManager {
log("[tmux-session-manager] session already tracked or pending", { sessionId })
return
}
if (!this.sourcePaneId) {
log("[tmux-session-manager] no source pane id")
return
}
const sourcePaneId = this.sourcePaneId
this.pendingSessions.add(sessionId)
@@ -418,14 +539,14 @@ export class TmuxSessionManager {
})
}
const now = Date.now()
this.sessions.set(sessionId, {
this.sessions.set(
sessionId,
paneId: result.spawnedPaneId,
description: title,
createdAt: new Date(now),
lastSeenAt: new Date(now),
})
createTrackedSession({
sessionId,
paneId: result.spawnedPaneId,
description: title,
}),
)
log("[tmux-session-manager] pane spawned and tracked", {
sessionId,
paneId: result.spawnedPaneId,
@@ -485,27 +606,40 @@ export class TmuxSessionManager {
log("[tmux-session-manager] onSessionDeleted", { sessionId: event.sessionID })
const state = await queryWindowState(this.sourcePaneId)
const state = await this.queryWindowStateSafely()
if (!state) {
this.sessions.delete(event.sessionID)
this.markSessionClosePending(event.sessionID)
return
}
const closeAction = decideCloseAction(state, event.sessionID, this.getSessionMappings())
if (closeAction) {
await executeAction(closeAction, {
if (!closeAction) {
this.removeTrackedSession(event.sessionID)
return
}
try {
const result = await executeAction(closeAction, {
config: this.tmuxConfig,
serverUrl: this.serverUrl,
windowState: state,
sourcePaneId: this.sourcePaneId,
})
if (!result.success) {
this.markSessionClosePending(event.sessionID)
return
}
} catch (error) {
log("[tmux-session-manager] failed to close pane for deleted session", {
sessionId: event.sessionID,
error: String(error),
})
this.markSessionClosePending(event.sessionID)
return
}
this.sessions.delete(event.sessionID)
if (this.sessions.size === 0) {
this.pollingManager.stopPolling()
}
this.removeTrackedSession(event.sessionID)
}
@@ -513,29 +647,28 @@ export class TmuxSessionManager {
const tracked = this.sessions.get(sessionId)
if (!tracked) return
if (tracked.closePending && tracked.closeRetryCount >= MAX_CLOSE_RETRY_COUNT) {
log("[tmux-session-manager] force removing close-pending session after max retries", {
sessionId,
paneId: tracked.paneId,
closeRetryCount: tracked.closeRetryCount,
})
this.removeTrackedSession(sessionId)
return
}
log("[tmux-session-manager] closing session pane", {
sessionId,
paneId: tracked.paneId,
})
const state = this.sourcePaneId ? await queryWindowState(this.sourcePaneId) : null
if (state) {
await executeAction(
{ type: "close", paneId: tracked.paneId, sessionId },
{
config: this.tmuxConfig,
serverUrl: this.serverUrl,
windowState: state,
sourcePaneId: this.sourcePaneId,
}
)
const closed = await this.tryCloseTrackedSession(tracked)
if (!closed) {
this.markSessionClosePending(sessionId)
return
}
this.sessions.delete(sessionId)
if (this.sessions.size === 0) {
this.pollingManager.stopPolling()
}
this.removeTrackedSession(sessionId)
}
createEventHandler(): (input: { event: { type: string; properties?: unknown } }) => Promise<void> {
@@ -552,30 +685,22 @@ export class TmuxSessionManager {
if (this.sessions.size > 0) {
log("[tmux-session-manager] closing all panes", { count: this.sessions.size })
const state = this.sourcePaneId ? await queryWindowState(this.sourcePaneId) : null
if (state) {
const closePromises = Array.from(this.sessions.values()).map((s) =>
executeAction(
{ type: "close", paneId: s.paneId, sessionId: s.sessionId },
{
config: this.tmuxConfig,
serverUrl: this.serverUrl,
windowState: state,
sourcePaneId: this.sourcePaneId,
}
).catch((err) =>
log("[tmux-session-manager] cleanup error for pane", {
paneId: s.paneId,
error: String(err),
}),
),
)
await Promise.all(closePromises)
const sessionIds = Array.from(this.sessions.keys())
for (const sessionId of sessionIds) {
try {
await this.closeSessionById(sessionId)
} catch (error) {
log("[tmux-session-manager] cleanup error for pane", {
sessionId,
error: String(error),
})
}
}
this.sessions.clear()
}
await this.retryPendingCloses()
log("[tmux-session-manager] cleanup complete")
}
}

View File

@@ -12,6 +12,8 @@ describe("TmuxPollingManager overlap", () => {
description: "test",
createdAt: new Date(),
lastSeenAt: new Date(),
closePending: false,
closeRetryCount: 0,
})
let activeCalls = 0

View File

@@ -6,6 +6,7 @@ import { queryWindowState } from "./pane-state-querier"
import { decideSpawnActions, type SessionMapping } from "./decision-engine"
import { executeActions } from "./action-executor"
import type { SessionCreatedEvent } from "./session-created-event"
import { createTrackedSession } from "./tracked-session-state"
type OpencodeClient = PluginInput["client"]
@@ -152,14 +153,14 @@ export async function handleSessionCreated(
return
}
const now = Date.now()
deps.sessions.set(sessionId, {
deps.sessions.set(
sessionId,
paneId: result.spawnedPaneId,
description: title,
createdAt: new Date(now),
lastSeenAt: new Date(now),
})
createTrackedSession({
sessionId,
paneId: result.spawnedPaneId,
description: title,
}),
)
log("[tmux-session-manager] pane spawned and tracked", {
sessionId,

View File

@@ -0,0 +1,28 @@
import type { TrackedSession } from "./types"
export function createTrackedSession(params: {
sessionId: string
paneId: string
description: string
now?: Date
}): TrackedSession {
const now = params.now ?? new Date()
return {
sessionId: params.sessionId,
paneId: params.paneId,
description: params.description,
createdAt: now,
lastSeenAt: now,
closePending: false,
closeRetryCount: 0,
}
}
export function markTrackedSessionClosePending(tracked: TrackedSession): TrackedSession {
return {
...tracked,
closePending: true,
closeRetryCount: tracked.closePending ? tracked.closeRetryCount + 1 : tracked.closeRetryCount,
}
}

View File

@@ -4,6 +4,8 @@ export interface TrackedSession {
description: string
createdAt: Date
lastSeenAt: Date
closePending: boolean
closeRetryCount: number
// Stability detection fields (prevents premature closure)
lastMessageCount?: number
stableIdlePolls?: number

View File

@@ -0,0 +1,271 @@
import { beforeEach, describe, expect, mock, test } from "bun:test"
import type { TmuxConfig } from "../../config/schema"
import type { ActionResult, ExecuteContext, ExecuteActionsResult } from "./action-executor"
import type { TmuxUtilDeps } from "./manager"
import type { TrackedSession, WindowState } from "./types"
const mockQueryWindowState = mock<(paneId: string) => Promise<WindowState | null>>(async () => ({
windowWidth: 220,
windowHeight: 44,
mainPane: { paneId: "%0", width: 110, height: 44, left: 0, top: 0, title: "main", isActive: true },
agentPanes: [],
}))
const mockExecuteAction = mock<(
action: { type: string },
ctx: ExecuteContext,
) => Promise<ActionResult>>(async () => ({ success: true }))
const mockExecuteActions = mock<(
actions: unknown[],
ctx: ExecuteContext,
) => Promise<ExecuteActionsResult>>(async () => ({
success: true,
spawnedPaneId: "%1",
results: [],
}))
const mockIsInsideTmux = mock<() => boolean>(() => true)
const mockGetCurrentPaneId = mock<() => string | undefined>(() => "%0")
mock.module("./pane-state-querier", () => ({
queryWindowState: mockQueryWindowState,
}))
mock.module("./action-executor", () => ({
executeAction: mockExecuteAction,
executeActions: mockExecuteActions,
}))
mock.module("../../shared/tmux", () => ({
isInsideTmux: mockIsInsideTmux,
getCurrentPaneId: mockGetCurrentPaneId,
POLL_INTERVAL_BACKGROUND_MS: 10,
SESSION_READY_POLL_INTERVAL_MS: 10,
SESSION_READY_TIMEOUT_MS: 50,
SESSION_MISSING_GRACE_MS: 1_000,
}))
const mockTmuxDeps: TmuxUtilDeps = {
isInsideTmux: mockIsInsideTmux,
getCurrentPaneId: mockGetCurrentPaneId,
}
function createConfig(): TmuxConfig {
return {
enabled: true,
layout: "main-vertical",
main_pane_size: 60,
main_pane_min_width: 80,
agent_pane_min_width: 40,
}
}
function createContext() {
const shell = Object.assign(
() => {
throw new Error("shell should not be called in this test")
},
{
braces: () => [],
escape: (input: string) => input,
env() {
return shell
},
cwd() {
return shell
},
nothrow() {
return shell
},
throws() {
return shell
},
},
)
return {
project: {
id: "project-id",
worktree: "/tmp/omo-fix-memory-leaks",
time: { created: Date.now() },
},
directory: "/tmp/omo-fix-memory-leaks",
worktree: "/tmp/omo-fix-memory-leaks",
serverUrl: new URL("http://localhost:4096"),
$: shell,
client: {
session: {
status: mock(async () => ({ data: {} })),
messages: mock(async () => ({ data: [] })),
},
},
}
}
function createTrackedSession(overrides?: Partial<TrackedSession>): TrackedSession {
return {
sessionId: "ses_pending",
paneId: "%1",
description: "Pending pane",
createdAt: new Date(),
lastSeenAt: new Date(),
closePending: false,
closeRetryCount: 0,
...overrides,
}
}
function getTrackedSessions(target: object): Map<string, TrackedSession> {
const sessions = Reflect.get(target, "sessions")
if (!(sessions instanceof Map)) {
throw new Error("Expected sessions map")
}
return sessions
}
function getRetryPendingCloses(target: object): () => Promise<void> {
const retryPendingCloses = Reflect.get(target, "retryPendingCloses")
if (typeof retryPendingCloses !== "function") {
throw new Error("Expected retryPendingCloses method")
}
return retryPendingCloses.bind(target)
}
function getCloseSessionById(target: object): (sessionId: string) => Promise<void> {
const closeSessionById = Reflect.get(target, "closeSessionById")
if (typeof closeSessionById !== "function") {
throw new Error("Expected closeSessionById method")
}
return closeSessionById.bind(target)
}
function createManager(
TmuxSessionManager: typeof import("./manager").TmuxSessionManager,
): import("./manager").TmuxSessionManager {
return Reflect.construct(TmuxSessionManager, [createContext(), createConfig(), mockTmuxDeps])
}
describe("TmuxSessionManager zombie pane handling", () => {
beforeEach(() => {
mockQueryWindowState.mockClear()
mockExecuteAction.mockClear()
mockExecuteActions.mockClear()
mockIsInsideTmux.mockClear()
mockGetCurrentPaneId.mockClear()
mockQueryWindowState.mockImplementation(async () => ({
windowWidth: 220,
windowHeight: 44,
mainPane: { paneId: "%0", width: 110, height: 44, left: 0, top: 0, title: "main", isActive: true },
agentPanes: [],
}))
mockExecuteAction.mockImplementation(async () => ({ success: true }))
mockExecuteActions.mockImplementation(async () => ({
success: true,
spawnedPaneId: "%1",
results: [],
}))
mockIsInsideTmux.mockReturnValue(true)
mockGetCurrentPaneId.mockReturnValue("%0")
})
test("#given session in sessions Map #when onSessionDeleted called with null window state #then session stays in Map with closePending true", async () => {
// given
mockQueryWindowState.mockImplementation(async () => null)
const { TmuxSessionManager } = await import("./manager")
const manager = createManager(TmuxSessionManager)
const sessions = getTrackedSessions(manager)
sessions.set("ses_pending", createTrackedSession())
// when
await manager.onSessionDeleted({ sessionID: "ses_pending" })
// then
const tracked = sessions.get("ses_pending")
expect(tracked).toBeDefined()
expect(tracked?.closePending).toBe(true)
expect(tracked?.closeRetryCount).toBe(0)
expect(mockExecuteAction).not.toHaveBeenCalled()
})
test("#given session with closePending true #when retryPendingCloses succeeds #then session is removed from Map", async () => {
// given
const { TmuxSessionManager } = await import("./manager")
const manager = createManager(TmuxSessionManager)
const sessions = getTrackedSessions(manager)
sessions.set(
"ses_pending",
createTrackedSession({ closePending: true, closeRetryCount: 0 }),
)
// when
await getRetryPendingCloses(manager)()
// then
expect(sessions.has("ses_pending")).toBe(false)
expect(mockExecuteAction).toHaveBeenCalledTimes(1)
})
test("#given session with closePending true and closeRetryCount >= 3 #when retryPendingCloses called #then session is force-removed from Map", async () => {
// given
const { TmuxSessionManager } = await import("./manager")
const manager = createManager(TmuxSessionManager)
const sessions = getTrackedSessions(manager)
sessions.set(
"ses_pending",
createTrackedSession({ closePending: true, closeRetryCount: 3 }),
)
// when
await getRetryPendingCloses(manager)()
// then
expect(sessions.has("ses_pending")).toBe(false)
expect(mockQueryWindowState).not.toHaveBeenCalled()
expect(mockExecuteAction).not.toHaveBeenCalled()
})
test("#given session with closePending true and closeRetryCount >= 3 #when closeSessionById called #then session is force-removed without retrying close", async () => {
// given
const { TmuxSessionManager } = await import("./manager")
const manager = createManager(TmuxSessionManager)
const sessions = getTrackedSessions(manager)
sessions.set(
"ses_pending",
createTrackedSession({ closePending: true, closeRetryCount: 3 }),
)
// when
await getCloseSessionById(manager)("ses_pending")
// then
expect(sessions.has("ses_pending")).toBe(false)
expect(mockQueryWindowState).not.toHaveBeenCalled()
expect(mockExecuteAction).not.toHaveBeenCalled()
})
test("#given close-pending session removed during async close #when retryPendingCloses fails #then it does not resurrect stale session state", async () => {
// given
const { TmuxSessionManager } = await import("./manager")
const manager = createManager(TmuxSessionManager)
const sessions = getTrackedSessions(manager)
sessions.set(
"ses_pending",
createTrackedSession({ closePending: true, closeRetryCount: 0 }),
)
mockExecuteAction.mockImplementationOnce(async () => {
sessions.delete("ses_pending")
return { success: false }
})
// when
await getRetryPendingCloses(manager)()
// then
expect(sessions.has("ses_pending")).toBe(false)
})
})

View File

@@ -0,0 +1,142 @@
import { beforeEach, describe, expect, it, mock, spyOn } from "bun:test"
import { AUTO_SLASH_COMMAND_TAG_OPEN } from "./constants"
import type {
AutoSlashCommandHookInput,
AutoSlashCommandHookOutput,
CommandExecuteBeforeInput,
CommandExecuteBeforeOutput,
} from "./types"
import * as shared from "../../shared"
const executeSlashCommandMock = mock(
async (parsed: { command: string; args: string; raw: string }) => ({
success: true,
replacementText: parsed.raw,
})
)
mock.module("./executor", () => ({
executeSlashCommand: executeSlashCommandMock,
}))
const logMock = spyOn(shared, "log").mockImplementation(() => {})
const { createAutoSlashCommandHook } = await import("./hook")
function createChatInput(sessionID: string, messageID: string): AutoSlashCommandHookInput {
return {
sessionID,
messageID,
}
}
function createChatOutput(text: string): AutoSlashCommandHookOutput {
return {
message: {},
parts: [{ type: "text", text }],
}
}
function createCommandInput(sessionID: string, command: string): CommandExecuteBeforeInput {
return {
sessionID,
command,
arguments: "",
}
}
function createCommandOutput(text: string): CommandExecuteBeforeOutput {
return {
parts: [{ type: "text", text }],
}
}
describe("createAutoSlashCommandHook leak prevention", () => {
beforeEach(() => {
executeSlashCommandMock.mockClear()
logMock.mockClear()
})
describe("#given hook with sessionProcessedCommandExecutions", () => {
describe("#when same command executed twice for same session", () => {
it("#then second execution is deduplicated", async () => {
const hook = createAutoSlashCommandHook()
const input = createCommandInput("session-dedup", "leak-test-command")
const firstOutput = createCommandOutput("first")
const secondOutput = createCommandOutput("second")
await hook["command.execute.before"](input, firstOutput)
await hook["command.execute.before"](input, secondOutput)
expect(executeSlashCommandMock).toHaveBeenCalledTimes(1)
expect(firstOutput.parts[0].text).toContain(AUTO_SLASH_COMMAND_TAG_OPEN)
expect(secondOutput.parts[0].text).toBe("second")
})
})
})
describe("#given hook with entries from multiple sessions", () => {
describe("#when dispose() is called", () => {
it("#then both Sets are empty", async () => {
const hook = createAutoSlashCommandHook()
await hook["chat.message"](
createChatInput("session-chat", "message-chat"),
createChatOutput("/leak-chat")
)
await hook["command.execute.before"](
createCommandInput("session-command", "leak-command"),
createCommandOutput("before")
)
executeSlashCommandMock.mockClear()
hook.dispose()
const chatOutputAfterDispose = createChatOutput("/leak-chat")
const commandOutputAfterDispose = createCommandOutput("after")
await hook["chat.message"](
createChatInput("session-chat", "message-chat"),
chatOutputAfterDispose
)
await hook["command.execute.before"](
createCommandInput("session-command", "leak-command"),
commandOutputAfterDispose
)
expect(executeSlashCommandMock).toHaveBeenCalledTimes(2)
expect(chatOutputAfterDispose.parts[0].text).toContain(AUTO_SLASH_COMMAND_TAG_OPEN)
expect(commandOutputAfterDispose.parts[0].text).toContain(
AUTO_SLASH_COMMAND_TAG_OPEN
)
})
})
})
describe("#given Set with more than 10000 entries", () => {
describe("#when new entry added", () => {
it("#then Set size is reduced", async () => {
const hook = createAutoSlashCommandHook()
const oldestInput = createChatInput("session-oldest", "message-oldest")
await hook["chat.message"](oldestInput, createChatOutput("/leak-oldest"))
for (let index = 0; index < 10000; index += 1) {
await hook["chat.message"](
createChatInput(`session-${index}`, `message-${index}`),
createChatOutput(`/leak-${index}`)
)
}
const newestInput = createChatInput("session-newest", "message-newest")
await hook["chat.message"](newestInput, createChatOutput("/leak-newest"))
executeSlashCommandMock.mockClear()
const oldestRetryOutput = createChatOutput("/leak-oldest")
const newestRetryOutput = createChatOutput("/leak-newest")
await hook["chat.message"](oldestInput, oldestRetryOutput)
await hook["chat.message"](newestInput, newestRetryOutput)
expect(executeSlashCommandMock).toHaveBeenCalledTimes(1)
expect(oldestRetryOutput.parts[0].text).toContain(AUTO_SLASH_COMMAND_TAG_OPEN)
expect(newestRetryOutput.parts[0].text).toBe("/leak-newest")
})
})
})
})

View File

@@ -9,6 +9,7 @@ import {
AUTO_SLASH_COMMAND_TAG_CLOSE,
AUTO_SLASH_COMMAND_TAG_OPEN,
} from "./constants"
import { createProcessedCommandStore } from "./processed-command-store"
import type {
AutoSlashCommandHookInput,
AutoSlashCommandHookOutput,
@@ -17,8 +18,22 @@ import type {
} from "./types"
import type { LoadedSkill } from "../../features/opencode-skill-loader"
const sessionProcessedCommands = new Set<string>()
const sessionProcessedCommandExecutions = new Set<string>()
function isRecord(value: unknown): value is Record<string, unknown> {
return typeof value === "object" && value !== null
}
function getDeletedSessionID(properties: unknown): string | null {
if (!isRecord(properties)) {
return null
}
const info = properties.info
if (!isRecord(info)) {
return null
}
return typeof info.id === "string" ? info.id : null
}
export interface AutoSlashCommandHookOptions {
skills?: LoadedSkill[]
@@ -32,6 +47,13 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
pluginsEnabled: options?.pluginsEnabled,
enabledPluginsOverride: options?.enabledPluginsOverride,
}
const sessionProcessedCommands = createProcessedCommandStore()
const sessionProcessedCommandExecutions = createProcessedCommandStore()
const dispose = (): void => {
sessionProcessedCommands.clear()
sessionProcessedCommandExecutions.clear()
}
return {
"chat.message": async (
@@ -61,7 +83,9 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
return
}
const commandKey = `${input.sessionID}:${input.messageID}:${parsed.command}`
const commandKey = input.messageID
? `${input.sessionID}:${input.messageID}:${parsed.command}`
: `${input.sessionID}:${parsed.command}`
if (sessionProcessedCommands.has(commandKey)) {
return
}
@@ -101,7 +125,7 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
input: CommandExecuteBeforeInput,
output: CommandExecuteBeforeOutput
): Promise<void> => {
const commandKey = `${input.sessionID}:${input.command}:${Date.now()}`
const commandKey = `${input.sessionID}:${input.command.toLowerCase()}:${input.arguments || ""}`
if (sessionProcessedCommandExecutions.has(commandKey)) {
return
}
@@ -145,5 +169,23 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
command: input.command,
})
},
event: async ({
event,
}: {
event: { type: string; properties?: unknown }
}): Promise<void> => {
if (event.type !== "session.deleted") {
return
}
const sessionID = getDeletedSessionID(event.properties)
if (!sessionID) {
return
}
sessionProcessedCommands.cleanupSession(sessionID)
sessionProcessedCommandExecutions.cleanupSession(sessionID)
},
dispose,
}
}

View File

@@ -0,0 +1,41 @@
const MAX_PROCESSED_ENTRY_COUNT = 10_000
function trimProcessedEntries(entries: Set<string>): Set<string> {
if (entries.size <= MAX_PROCESSED_ENTRY_COUNT) {
return entries
}
return new Set(Array.from(entries).slice(Math.floor(entries.size / 2)))
}
function removeSessionEntries(entries: Set<string>, sessionID: string): Set<string> {
const sessionPrefix = `${sessionID}:`
return new Set(Array.from(entries).filter((entry) => !entry.startsWith(sessionPrefix)))
}
export interface ProcessedCommandStore {
has(commandKey: string): boolean
add(commandKey: string): void
cleanupSession(sessionID: string): void
clear(): void
}
export function createProcessedCommandStore(): ProcessedCommandStore {
let entries = new Set<string>()
return {
has(commandKey: string): boolean {
return entries.has(commandKey)
},
add(commandKey: string): void {
entries.add(commandKey)
entries = trimProcessedEntries(entries)
},
cleanupSession(sessionID: string): void {
entries = removeSessionEntries(entries, sessionID)
},
clear(): void {
entries.clear()
},
}
}

View File

@@ -1,4 +1,4 @@
import type { HookDeps } from "./types"
import type { HookDeps, RuntimeFallbackTimeout } from "./types"
import { HOOK_NAME } from "./constants"
import { log } from "../../shared/logger"
import { normalizeAgentName, resolveAgentForSession } from "./agent-resolver"
@@ -9,8 +9,8 @@ import { SessionCategoryRegistry } from "../../shared/session-category-registry"
const SESSION_TTL_MS = 30 * 60 * 1000
declare function setTimeout(callback: () => void | Promise<void>, delay?: number): ReturnType<typeof globalThis.setTimeout>
declare function clearTimeout(timeout: ReturnType<typeof globalThis.setTimeout>): void
declare function setTimeout(callback: () => void | Promise<void>, delay?: number): RuntimeFallbackTimeout
declare function clearTimeout(timeout: RuntimeFallbackTimeout): void
export function createAutoRetryHelpers(deps: HookDeps) {
const { ctx, config, options, sessionStates, sessionLastAccess, sessionRetryInFlight, sessionAwaitingFallbackResult, sessionFallbackTimeouts, pluginConfig } = deps

View File

@@ -0,0 +1,160 @@
import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"
import type { HookDeps, RuntimeFallbackPluginInput } from "./types"
let capturedDeps: HookDeps | undefined
const mockCreateAutoRetryHelpers = mock((deps: HookDeps) => {
capturedDeps = deps
return {
abortSessionRequest: async () => {},
clearSessionFallbackTimeout: () => {},
scheduleSessionFallbackTimeout: () => {},
autoRetryWithFallback: async () => {},
resolveAgentForSessionFromContext: async () => undefined,
cleanupStaleSessions: () => {},
}
})
const mockCreateEventHandler = mock(() => async () => {})
const mockCreateMessageUpdateHandler = mock(() => async () => {})
const mockCreateChatMessageHandler = mock(() => async () => {})
mock.module("./auto-retry", () => ({
createAutoRetryHelpers: mockCreateAutoRetryHelpers,
}))
mock.module("./event-handler", () => ({
createEventHandler: mockCreateEventHandler,
}))
mock.module("./message-update-handler", () => ({
createMessageUpdateHandler: mockCreateMessageUpdateHandler,
}))
mock.module("./chat-message-handler", () => ({
createChatMessageHandler: mockCreateChatMessageHandler,
}))
const { createRuntimeFallbackHook } = await import("./hook")
function createMockContext(): RuntimeFallbackPluginInput {
return {
client: {
session: {
abort: async () => ({}),
messages: async () => ({}),
promptAsync: async () => ({}),
},
tui: {
showToast: async () => ({}),
},
},
directory: "/test",
}
}
describe("createRuntimeFallbackHook dispose", () => {
const originalSetInterval = globalThis.setInterval
const originalClearInterval = globalThis.clearInterval
const originalClearTimeout = globalThis.clearTimeout
const createdIntervals: Array<ReturnType<typeof originalSetInterval>> = []
const clearedIntervals: Array<Parameters<typeof originalClearInterval>[0]> = []
const clearedTimeouts: Array<Parameters<typeof originalClearTimeout>[0]> = []
const timeoutMapSizesDuringClear: number[] = []
beforeEach(() => {
capturedDeps = undefined
createdIntervals.length = 0
clearedIntervals.length = 0
clearedTimeouts.length = 0
timeoutMapSizesDuringClear.length = 0
mockCreateAutoRetryHelpers.mockClear()
mockCreateEventHandler.mockClear()
mockCreateMessageUpdateHandler.mockClear()
mockCreateChatMessageHandler.mockClear()
const wrappedSetInterval = ((handler: () => void, timeout?: number) => {
const interval = originalSetInterval(handler, timeout)
createdIntervals.push(interval)
return interval
}) as typeof globalThis.setInterval
const wrappedClearInterval = ((interval?: Parameters<typeof clearInterval>[0]) => {
clearedIntervals.push(interval)
return originalClearInterval(interval)
}) as typeof globalThis.clearInterval
const wrappedClearTimeout = ((timeout?: Parameters<typeof clearTimeout>[0]) => {
timeoutMapSizesDuringClear.push(capturedDeps?.sessionFallbackTimeouts.size ?? -1)
clearedTimeouts.push(timeout)
return originalClearTimeout(timeout)
}) as typeof globalThis.clearTimeout
globalThis.setInterval = wrappedSetInterval
globalThis.clearInterval = wrappedClearInterval
globalThis.clearTimeout = wrappedClearTimeout
})
afterEach(() => {
globalThis.setInterval = originalSetInterval
globalThis.clearInterval = originalClearInterval
globalThis.clearTimeout = originalClearTimeout
})
test("#given runtime-fallback hook created #when dispose() is called #then cleanup interval is cleared", () => {
// given
const hook = createRuntimeFallbackHook(createMockContext(), { pluginConfig: {} })
// when
hook.dispose?.()
// then
expect(createdIntervals).toHaveLength(1)
expect(clearedIntervals).toEqual([createdIntervals[0]])
})
test("#given hook with session state data #when dispose() is called #then all Maps and Sets are empty", () => {
// given
const hook = createRuntimeFallbackHook(createMockContext(), { pluginConfig: {} })
const fallbackTimeout = setTimeout(() => {}, 60_000)
capturedDeps?.sessionStates.set("session-1", {
originalModel: "anthropic/claude-opus-4-6",
currentModel: "openai/gpt-5.4",
fallbackIndex: 1,
failedModels: new Map([["anthropic/claude-opus-4-6", 1]]),
attemptCount: 1,
})
capturedDeps?.sessionLastAccess.set("session-1", Date.now())
capturedDeps?.sessionRetryInFlight.add("session-1")
capturedDeps?.sessionAwaitingFallbackResult.add("session-1")
capturedDeps?.sessionFallbackTimeouts.set("session-1", fallbackTimeout)
// when
hook.dispose?.()
// then
expect(capturedDeps?.sessionStates.size).toBe(0)
expect(capturedDeps?.sessionLastAccess.size).toBe(0)
expect(capturedDeps?.sessionRetryInFlight.size).toBe(0)
expect(capturedDeps?.sessionAwaitingFallbackResult.size).toBe(0)
expect(capturedDeps?.sessionFallbackTimeouts.size).toBe(0)
})
test("#given hook with pending fallback timeouts #when dispose() is called #then timeouts are cleared before Map is emptied", () => {
// given
const hook = createRuntimeFallbackHook(createMockContext(), { pluginConfig: {} })
const fallbackTimeout = setTimeout(() => {}, 60_000)
capturedDeps?.sessionFallbackTimeouts.set("session-1", fallbackTimeout)
// when
hook.dispose?.()
// then
expect(clearedTimeouts).toEqual([fallbackTimeout])
expect(timeoutMapSizesDuringClear).toEqual([1])
expect(capturedDeps?.sessionFallbackTimeouts.size).toBe(0)
})
})

View File

@@ -1,5 +1,4 @@
import type { PluginInput } from "@opencode-ai/plugin"
import type { HookDeps, RuntimeFallbackHook, RuntimeFallbackOptions } from "./types"
import type { HookDeps, RuntimeFallbackHook, RuntimeFallbackInterval, RuntimeFallbackOptions, RuntimeFallbackPluginInput, RuntimeFallbackTimeout } from "./types"
import { DEFAULT_CONFIG, HOOK_NAME } from "./constants"
import { log } from "../../shared/logger"
import { loadPluginConfig } from "../../plugin-config"
@@ -8,8 +7,12 @@ import { createEventHandler } from "./event-handler"
import { createMessageUpdateHandler } from "./message-update-handler"
import { createChatMessageHandler } from "./chat-message-handler"
declare function setInterval(callback: () => void, delay?: number): RuntimeFallbackInterval
declare function clearInterval(interval: RuntimeFallbackInterval): void
declare function clearTimeout(timeout: RuntimeFallbackTimeout): void
export function createRuntimeFallbackHook(
ctx: PluginInput,
ctx: RuntimeFallbackPluginInput,
options?: RuntimeFallbackOptions
): RuntimeFallbackHook {
const config = {
@@ -60,8 +63,23 @@ export function createRuntimeFallbackHook(
await baseEventHandler({ event })
}
const dispose = () => {
clearInterval(cleanupInterval)
for (const fallbackTimeout of deps.sessionFallbackTimeouts.values()) {
clearTimeout(fallbackTimeout)
}
deps.sessionStates.clear()
deps.sessionLastAccess.clear()
deps.sessionRetryInFlight.clear()
deps.sessionAwaitingFallbackResult.clear()
deps.sessionFallbackTimeouts.clear()
}
return {
event: eventHandler,
"chat.message": chatMessageHandler,
dispose,
} as RuntimeFallbackHook
}

View File

@@ -1,6 +1,40 @@
import type { PluginInput } from "@opencode-ai/plugin"
import type { RuntimeFallbackConfig, OhMyOpenCodeConfig } from "../../config"
export interface RuntimeFallbackInterval {
unref: () => void
}
export type RuntimeFallbackTimeout = object | number
export interface RuntimeFallbackPluginInput {
client: {
session: {
abort: (input: { path: { id: string } }) => Promise<unknown>
messages: (input: { path: { id: string }; query: { directory: string } }) => Promise<unknown>
promptAsync: (input: {
path: { id: string }
body: {
agent?: string
model: { providerID: string; modelID: string }
parts: Array<{ type: "text"; text: string }>
}
query: { directory: string }
}) => Promise<unknown>
}
tui: {
showToast: (input: {
body: {
title: string
message: string
variant: "success" | "error" | "info" | "warning"
duration: number
}
}) => Promise<unknown>
}
}
directory: string
}
export interface FallbackState {
originalModel: string
currentModel: string
@@ -26,10 +60,11 @@ export interface RuntimeFallbackOptions {
export interface RuntimeFallbackHook {
event: (input: { event: { type: string; properties?: unknown } }) => Promise<void>
"chat.message"?: (input: { sessionID: string; agent?: string; model?: { providerID: string; modelID: string } }, output: { message: { model?: { providerID: string; modelID: string } }; parts?: Array<{ type: string; text?: string }> }) => Promise<void>
dispose?: () => void
}
export interface HookDeps {
ctx: PluginInput
ctx: RuntimeFallbackPluginInput
config: Required<RuntimeFallbackConfig>
options: RuntimeFallbackOptions | undefined
pluginConfig: OhMyOpenCodeConfig | undefined
@@ -37,5 +72,5 @@ export interface HookDeps {
sessionLastAccess: Map<string, number>
sessionRetryInFlight: Set<string>
sessionAwaitingFallbackResult: Set<string>
sessionFallbackTimeouts: Map<string, ReturnType<typeof setTimeout>>
sessionFallbackTimeouts: Map<string, RuntimeFallbackTimeout>
}

View File

@@ -0,0 +1,101 @@
declare module "bun:test" {
export interface Matchers {
toBeDefined(): void
toBeUndefined(): void
toHaveLength(expected: number): void
}
}
import { afterAll, afterEach, describe, expect, it, mock } from "bun:test"
import * as actualSessionStateModule from "./session-state"
import type { SessionStateStore } from "./session-state"
let createdSessionStateStore: SessionStateStore | undefined
const createActualSessionStateStore = actualSessionStateModule.createSessionStateStore
const mockModule = mock as typeof mock & {
module: (specifier: string, factory: () => unknown) => void
}
mockModule.module("./session-state", () => ({
...actualSessionStateModule,
createSessionStateStore: () => {
const sessionStateStore = createActualSessionStateStore()
createdSessionStateStore = sessionStateStore
return sessionStateStore
},
}))
const { createTodoContinuationEnforcer } = await import(".")
type PluginInput = Parameters<typeof createTodoContinuationEnforcer>[0]
function createMockPluginInput(): PluginInput {
return {
directory: "/tmp/test",
} as PluginInput
}
function getCreatedSessionStateStore(): SessionStateStore {
if (!createdSessionStateStore) {
throw new Error("expected session state store to be created")
}
return createdSessionStateStore
}
describe("todo-continuation-enforcer dispose", () => {
afterEach(() => {
createdSessionStateStore?.shutdown()
createdSessionStateStore = undefined
})
afterAll(() => {
mockModule.module("./session-state", () => actualSessionStateModule)
})
it("#given todo-continuation-enforcer created #when dispose exists on return value #then it is a function", () => {
// given
const enforcer = createTodoContinuationEnforcer(createMockPluginInput())
// when
const { dispose } = enforcer
// then
expect(typeof dispose).toBe("function")
enforcer.dispose()
})
it("#given enforcer with active session states #when dispose is called #then internal session state store is shut down", () => {
// given
const originalClearInterval = globalThis.clearInterval
const clearIntervalCalls: Array<Parameters<typeof clearInterval>[0]> = []
globalThis.clearInterval = ((timer?: Parameters<typeof clearInterval>[0]) => {
clearIntervalCalls.push(timer)
return originalClearInterval(timer)
}) as typeof clearInterval
try {
const enforcer = createTodoContinuationEnforcer(createMockPluginInput())
const sessionStateStore = getCreatedSessionStateStore()
enforcer.markRecovering("session-1")
enforcer.markRecovering("session-2")
expect(sessionStateStore.getExistingState("session-1")).toBeDefined()
expect(sessionStateStore.getExistingState("session-2")).toBeDefined()
// when
enforcer.dispose()
// then
expect(clearIntervalCalls).toHaveLength(1)
expect(sessionStateStore.getExistingState("session-1")).toBeUndefined()
expect(sessionStateStore.getExistingState("session-2")).toBeUndefined()
} finally {
globalThis.clearInterval = originalClearInterval
}
})
})

View File

@@ -56,5 +56,6 @@ export function createTodoContinuationEnforcer(
markRecovering,
markRecoveryComplete,
cancelAllCountdowns,
dispose: () => sessionStateStore.shutdown(),
}
}

View File

@@ -13,6 +13,7 @@ export interface TodoContinuationEnforcer {
markRecovering: (sessionID: string) => void
markRecoveryComplete: (sessionID: string) => void
cancelAllCountdowns: () => void
dispose: () => void
}
export interface Todo {

View File

@@ -7,6 +7,7 @@ import { createHooks } from "./create-hooks"
import { createManagers } from "./create-managers"
import { createTools } from "./create-tools"
import { createPluginInterface } from "./plugin-interface"
import { createPluginDispose, type PluginDispose } from "./plugin-dispose"
import { loadPluginConfig } from "./plugin-config"
import { createModelCacheState } from "./plugin-state"
@@ -14,6 +15,8 @@ import { createFirstMessageVariantGate } from "./shared/first-message-variant"
import { injectServerAuthIntoClient, log } from "./shared"
import { startTmuxCheck } from "./tools"
let activePluginDispose: PluginDispose | null = null
const OhMyOpenCodePlugin: Plugin = async (ctx) => {
// Initialize config context for plugin runtime (prevents warnings from hooks)
initConfigContext("opencode", null)
@@ -23,6 +26,7 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
injectServerAuthIntoClient(ctx.client)
startTmuxCheck()
await activePluginDispose?.()
const pluginConfig = loadPluginConfig(ctx.directory, ctx)
const disabledHooks = new Set(pluginConfig.disabled_hooks ?? [])
@@ -67,6 +71,12 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
availableSkills: toolsResult.availableSkills,
})
const dispose = createPluginDispose({
backgroundManager: managers.backgroundManager,
skillMcpManager: managers.skillMcpManager,
disposeHooks: hooks.disposeHooks,
})
const pluginInterface = createPluginInterface({
ctx,
pluginConfig,
@@ -76,6 +86,8 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
tools: toolsResult.filteredTools,
})
activePluginDispose = dispose
return {
...pluginInterface,

175
src/plugin-dispose.test.ts Normal file
View File

@@ -0,0 +1,175 @@
import { describe, expect, spyOn, test } from "bun:test"
import { disposeCreatedHooks } from "./create-hooks"
import { createPluginDispose } from "./plugin-dispose"
describe("createPluginDispose", () => {
test("#given plugin with active managers and hooks #when dispose() is called #then backgroundManager.shutdown() is called", async () => {
// given
const backgroundManager = {
shutdown: async (): Promise<void> => {},
}
const skillMcpManager = {
disconnectAll: async (): Promise<void> => {},
}
const shutdownSpy = spyOn(backgroundManager, "shutdown")
const dispose = createPluginDispose({
backgroundManager,
skillMcpManager,
disposeHooks: (): void => {},
})
// when
await dispose()
// then
expect(shutdownSpy).toHaveBeenCalledTimes(1)
})
test("#given plugin with active MCP connections #when dispose() is called #then skillMcpManager.disconnectAll() is called", async () => {
// given
const backgroundManager = {
shutdown: async (): Promise<void> => {},
}
const skillMcpManager = {
disconnectAll: async (): Promise<void> => {},
}
const disconnectAllSpy = spyOn(skillMcpManager, "disconnectAll")
const dispose = createPluginDispose({
backgroundManager,
skillMcpManager,
disposeHooks: (): void => {},
})
// when
await dispose()
// then
expect(disconnectAllSpy).toHaveBeenCalledTimes(1)
})
test("#given plugin with hooks that have dispose #when dispose() is called #then each hook's dispose is called", async () => {
// given
const runtimeFallback = {
dispose: (): void => {},
}
const todoContinuationEnforcer = {
dispose: (): void => {},
}
const autoSlashCommand = {
dispose: (): void => {},
}
const runtimeFallbackDisposeSpy = spyOn(runtimeFallback, "dispose")
const todoContinuationEnforcerDisposeSpy = spyOn(todoContinuationEnforcer, "dispose")
const autoSlashCommandDisposeSpy = spyOn(autoSlashCommand, "dispose")
const dispose = createPluginDispose({
backgroundManager: {
shutdown: async (): Promise<void> => {},
},
skillMcpManager: {
disconnectAll: async (): Promise<void> => {},
},
disposeHooks: (): void => {
disposeCreatedHooks({
runtimeFallback,
todoContinuationEnforcer,
autoSlashCommand,
})
},
})
// when
await dispose()
// then
expect(runtimeFallbackDisposeSpy).toHaveBeenCalledTimes(1)
expect(todoContinuationEnforcerDisposeSpy).toHaveBeenCalledTimes(1)
expect(autoSlashCommandDisposeSpy).toHaveBeenCalledTimes(1)
})
test("#given dispose already called #when dispose() called again #then no errors", async () => {
// given
const backgroundManager = {
shutdown: async (): Promise<void> => {},
}
const skillMcpManager = {
disconnectAll: async (): Promise<void> => {},
}
const disposeHooks = {
run: (): void => {},
}
const shutdownSpy = spyOn(backgroundManager, "shutdown")
const disconnectAllSpy = spyOn(skillMcpManager, "disconnectAll")
const disposeHooksSpy = spyOn(disposeHooks, "run")
const dispose = createPluginDispose({
backgroundManager,
skillMcpManager,
disposeHooks: disposeHooks.run,
})
// when
await dispose()
await dispose()
// then
expect(shutdownSpy).toHaveBeenCalledTimes(1)
expect(disconnectAllSpy).toHaveBeenCalledTimes(1)
expect(disposeHooksSpy).toHaveBeenCalledTimes(1)
})
test("#given backgroundManager.shutdown() throws #when dispose() is called #then skillMcpManager.disconnectAll() and disposeHooks() are still called", async () => {
// given
const backgroundManager = {
shutdown: async (): Promise<void> => {
throw new Error("shutdown failed")
},
}
const skillMcpManager = {
disconnectAll: async (): Promise<void> => {},
}
const disposeHooksCalls: number[] = []
const disconnectAllSpy = spyOn(skillMcpManager, "disconnectAll")
const dispose = createPluginDispose({
backgroundManager,
skillMcpManager,
disposeHooks: (): void => {
disposeHooksCalls.push(1)
},
})
// when
await dispose()
// then
expect(disconnectAllSpy).toHaveBeenCalledTimes(1)
expect(disposeHooksCalls).toHaveLength(1)
})
test("#given skillMcpManager.disconnectAll() throws #when dispose() is called #then disposeHooks() is still called", async () => {
// given
const backgroundManager = {
shutdown: async (): Promise<void> => {},
}
const skillMcpManager = {
disconnectAll: async (): Promise<void> => {
throw new Error("disconnectAll failed")
},
}
const disposeHooksCalls: number[] = []
const shutdownSpy = spyOn(backgroundManager, "shutdown")
const dispose = createPluginDispose({
backgroundManager,
skillMcpManager,
disposeHooks: (): void => {
disposeHooksCalls.push(1)
},
})
// when
await dispose()
// then
expect(shutdownSpy).toHaveBeenCalledTimes(1)
expect(disposeHooksCalls).toHaveLength(1)
})
})

43
src/plugin-dispose.ts Normal file
View File

@@ -0,0 +1,43 @@
import { log } from "./shared"
export type PluginDispose = () => Promise<void>
export function createPluginDispose(args: {
backgroundManager: {
shutdown: () => void | Promise<void>
}
skillMcpManager: {
disconnectAll: () => Promise<void>
}
disposeHooks: () => void
}): PluginDispose {
const { backgroundManager, skillMcpManager, disposeHooks } = args
let disposePromise: Promise<void> | null = null
return async (): Promise<void> => {
if (disposePromise) {
await disposePromise
return
}
disposePromise = (async (): Promise<void> => {
try {
await backgroundManager.shutdown()
} catch (error) {
log("[plugin-dispose] backgroundManager.shutdown() error:", error)
}
try {
await skillMcpManager.disconnectAll()
} catch (error) {
log("[plugin-dispose] skillMcpManager.disconnectAll() error:", error)
}
try {
disposeHooks()
} catch (error) {
log("[plugin-dispose] disposeHooks() error:", error)
}
})()
await disposePromise
}
}

View File

@@ -190,6 +190,7 @@ export function createEventHandler(args: {
await Promise.resolve(hooks.compactionTodoPreserver?.event?.(input));
await Promise.resolve(hooks.writeExistingFileGuard?.event?.(input));
await Promise.resolve(hooks.atlasHook?.handler?.(input));
await Promise.resolve(hooks.autoSlashCommand?.event?.(input));
};
const recentSyntheticIdles = new Map<string, number>();

View File

@@ -0,0 +1,184 @@
import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"
import {
_resetForTesting,
subagentSessions,
syncSubagentSessions,
} from "../../features/claude-code-session-state"
import { executeSync } from "./sync-executor"
type ExecuteSyncArgs = Parameters<typeof executeSync>[0]
type ExecuteSyncToolContext = Parameters<typeof executeSync>[1]
type ExecuteSyncDeps = NonNullable<Parameters<typeof executeSync>[3]>
function createArgs(): ExecuteSyncArgs {
return {
subagent_type: "explore",
description: "cleanup leak",
prompt: "find something",
run_in_background: false,
}
}
function createToolContext(): ExecuteSyncToolContext {
return {
sessionID: "parent-session",
messageID: "msg-1",
agent: "sisyphus",
abort: new AbortController().signal,
metadata: mock(async () => {}),
}
}
function createContext(promptAsync: ReturnType<typeof mock>) {
return {
client: {
session: {
promptAsync,
},
},
}
}
function createDependencies(overrides?: Partial<ExecuteSyncDeps>): ExecuteSyncDeps {
return {
createOrGetSession: mock(async () => ({ sessionID: "ses-default", isNew: true })),
waitForCompletion: mock(async () => {}),
processMessages: mock(async () => "agent response"),
setSessionFallbackChain: mock(() => {}),
clearSessionFallbackChain: mock(() => {}),
...overrides,
}
}
describe("executeSync session cleanup", () => {
beforeEach(() => {
_resetForTesting()
})
afterEach(() => {
_resetForTesting()
})
describe("#given executeSync creates a session", () => {
test("#when execution completes successfully #then sessionID is removed from subagentSessions and syncSubagentSessions", async () => {
// given
const sessionID = "ses-cleanup-success"
const args = createArgs()
const toolContext = createToolContext()
const promptAsync = mock(async () => ({ data: {} }))
const deps = createDependencies({
createOrGetSession: mock(async () => {
subagentSessions.add(sessionID)
syncSubagentSessions.add(sessionID)
return { sessionID, isNew: true }
}),
waitForCompletion: mock(async (createdSessionID: string) => {
expect(createdSessionID).toBe(sessionID)
expect(subagentSessions.has(sessionID)).toBe(true)
expect(syncSubagentSessions.has(sessionID)).toBe(true)
}),
})
expect(subagentSessions.has(sessionID)).toBe(false)
expect(syncSubagentSessions.has(sessionID)).toBe(false)
// when
const result = await executeSync(args, toolContext, createContext(promptAsync) as never, deps)
// then
expect(result).toContain(`session_id: ${sessionID}`)
expect(subagentSessions.has(sessionID)).toBe(false)
expect(syncSubagentSessions.has(sessionID)).toBe(false)
})
test("#when execution throws an error #then sessionID is still removed from both Sets", async () => {
// given
const sessionID = "ses-cleanup-error"
const args = createArgs()
const toolContext = createToolContext()
const promptAsync = mock(async () => ({ data: {} }))
const deps = createDependencies({
createOrGetSession: mock(async () => {
subagentSessions.add(sessionID)
syncSubagentSessions.add(sessionID)
return { sessionID, isNew: true }
}),
waitForCompletion: mock(async (createdSessionID: string) => {
expect(createdSessionID).toBe(sessionID)
expect(subagentSessions.has(sessionID)).toBe(true)
expect(syncSubagentSessions.has(sessionID)).toBe(true)
throw new Error("poll exploded")
}),
})
// when
const resultPromise = executeSync(args, toolContext, createContext(promptAsync) as never, deps)
// then
let thrownError: Error | undefined
try {
await resultPromise
} catch (error) {
if (error instanceof Error) {
thrownError = error
} else {
throw error
}
}
expect(thrownError?.message).toBe("poll exploded")
expect(subagentSessions.has(sessionID)).toBe(false)
expect(syncSubagentSessions.has(sessionID)).toBe(false)
})
})
describe("#given executeSync reuses an existing session", () => {
test("#when execution completes successfully #then the reused session is tracked in both Sets", async () => {
// given
const sessionID = "ses-reused"
const args = { ...createArgs(), session_id: sessionID }
const toolContext = createToolContext()
const promptAsync = mock(async () => ({ data: {} }))
const deps = createDependencies({
createOrGetSession: mock(async () => ({ sessionID, isNew: false })),
waitForCompletion: mock(async (createdSessionID: string) => {
expect(createdSessionID).toBe(sessionID)
expect(subagentSessions.has(sessionID)).toBe(true)
expect(syncSubagentSessions.has(sessionID)).toBe(true)
}),
})
expect(subagentSessions.has(sessionID)).toBe(false)
expect(syncSubagentSessions.has(sessionID)).toBe(false)
// when
const result = await executeSync(args, toolContext, createContext(promptAsync) as never, deps)
// then
expect(result).toContain(`session_id: ${sessionID}`)
expect(subagentSessions.has(sessionID)).toBe(true)
expect(syncSubagentSessions.has(sessionID)).toBe(true)
})
test("#when execution applies a fallback chain #then it clears that chain in finally", async () => {
// given
const sessionID = "ses-reused-fallback"
const args = { ...createArgs(), session_id: sessionID }
const toolContext = createToolContext()
const promptAsync = mock(async () => ({ data: {} }))
const clearSessionFallbackChain = mock(() => {})
const deps = createDependencies({
createOrGetSession: mock(async () => ({ sessionID, isNew: false })),
clearSessionFallbackChain,
})
const fallbackChain = [{ providers: ["openai"], model: "gpt-5.4" }]
// when
await executeSync(args, toolContext, createContext(promptAsync) as never, deps, fallbackChain)
// then
expect(clearSessionFallbackChain).toHaveBeenCalledWith(sessionID)
})
})
})

View File

@@ -24,6 +24,7 @@ type Dependencies = {
waitForCompletion: ReturnType<typeof mock>
processMessages: ReturnType<typeof mock>
setSessionFallbackChain: ReturnType<typeof mock>
clearSessionFallbackChain: ReturnType<typeof mock>
}
async function importExecuteSync(): Promise<ExecuteSync> {
@@ -37,6 +38,7 @@ function createDependencies(overrides?: Partial<Dependencies>): Dependencies {
waitForCompletion: mock(async () => {}),
processMessages: mock(async () => "agent response"),
setSessionFallbackChain: mock(() => {}),
clearSessionFallbackChain: mock(() => {}),
...overrides,
}
}
@@ -259,6 +261,7 @@ describe("executeSync", () => {
waitForCompletion: mock(async () => {}),
processMessages: mock(async () => "agent response"),
setSessionFallbackChain: mock(() => {}),
clearSessionFallbackChain: mock(() => {}),
}
const spawnReservation = {

View File

@@ -1,12 +1,12 @@
import type { CallOmoAgentArgs } from "./types"
import type { PluginInput } from "@opencode-ai/plugin"
import { log } from "../../shared"
import { subagentSessions, syncSubagentSessions } from "../../features/claude-code-session-state"
import { clearSessionFallbackChain, setSessionFallbackChain } from "../../hooks/model-fallback/hook"
import { getAgentToolRestrictions, log } from "../../shared"
import type { FallbackEntry } from "../../shared/model-requirements"
import { getAgentToolRestrictions } from "../../shared"
import { setSessionFallbackChain } from "../../hooks/model-fallback/hook"
import { createOrGetSession } from "./session-creator"
import { waitForCompletion } from "./completion-poller"
import { processMessages } from "./message-processor"
import { createOrGetSession } from "./session-creator"
type SessionWithPromptAsync = {
promptAsync: (opts: { path: { id: string }; body: Record<string, unknown> }) => Promise<unknown>
@@ -17,6 +17,7 @@ type ExecuteSyncDeps = {
waitForCompletion: typeof waitForCompletion
processMessages: typeof processMessages
setSessionFallbackChain: typeof setSessionFallbackChain
clearSessionFallbackChain: typeof clearSessionFallbackChain
}
type SpawnReservation = {
@@ -29,6 +30,7 @@ const defaultDeps: ExecuteSyncDeps = {
waitForCompletion,
processMessages,
setSessionFallbackChain,
clearSessionFallbackChain,
}
export async function executeSync(
@@ -46,10 +48,15 @@ export async function executeSync(
spawnReservation?: SpawnReservation,
): Promise<string> {
let sessionID: string | undefined
let createdSessionForExecution = false
let appliedFallbackChain = false
try {
const session = await deps.createOrGetSession(args, toolContext, ctx)
sessionID = session.sessionID
createdSessionForExecution = session.isNew
subagentSessions.add(sessionID)
syncSubagentSessions.add(sessionID)
if (session.isNew) {
spawnReservation?.commit()
@@ -57,12 +64,15 @@ export async function executeSync(
if (fallbackChain && fallbackChain.length > 0) {
deps.setSessionFallbackChain(sessionID, fallbackChain)
appliedFallbackChain = true
}
await toolContext.metadata?.({
title: args.description,
metadata: { sessionId: sessionID },
})
await Promise.resolve(
toolContext.metadata?.({
title: args.description,
metadata: { sessionId: sessionID },
})
)
log(`[call_omo_agent] Sending prompt to session ${sessionID}`)
log(`[call_omo_agent] Prompt text:`, args.prompt.substring(0, 100))
@@ -93,12 +103,18 @@ export async function executeSync(
const responseText = await deps.processMessages(sessionID, ctx)
const output =
responseText + "\n\n" + ["<task_metadata>", `session_id: ${sessionID}`, "</task_metadata>"].join("\n")
return output
return responseText + "\n\n" + ["<task_metadata>", `session_id: ${sessionID}`, "</task_metadata>"].join("\n")
} catch (error) {
spawnReservation?.rollback()
throw error
} finally {
if (sessionID && appliedFallbackChain) {
deps.clearSessionFallbackChain(sessionID)
}
if (sessionID && createdSessionForExecution) {
subagentSessions.delete(sessionID)
syncSubagentSessions.delete(sessionID)
}
}
}