Merge pull request #2458 from code-yeongyu/fix/memory-leaks
fix: resolve 12 memory leaks (3 critical + 9 high)
This commit is contained in:
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -11,6 +11,20 @@ import { createSkillHooks } from "./plugin/hooks/create-skill-hooks"
|
||||
|
||||
export type CreatedHooks = ReturnType<typeof createHooks>
|
||||
|
||||
type DisposableHook = { dispose?: () => void } | null | undefined
|
||||
|
||||
export type DisposableCreatedHooks = {
|
||||
runtimeFallback?: DisposableHook
|
||||
todoContinuationEnforcer?: DisposableHook
|
||||
autoSlashCommand?: DisposableHook
|
||||
}
|
||||
|
||||
export function disposeCreatedHooks(hooks: DisposableCreatedHooks): void {
|
||||
hooks.runtimeFallback?.dispose?.()
|
||||
hooks.todoContinuationEnforcer?.dispose?.()
|
||||
hooks.autoSlashCommand?.dispose?.()
|
||||
}
|
||||
|
||||
export function createHooks(args: {
|
||||
ctx: PluginContext
|
||||
pluginConfig: OhMyOpenCodeConfig
|
||||
@@ -58,9 +72,16 @@ export function createHooks(args: {
|
||||
availableSkills,
|
||||
})
|
||||
|
||||
return {
|
||||
const hooks = {
|
||||
...core,
|
||||
...continuation,
|
||||
...skill,
|
||||
}
|
||||
|
||||
return {
|
||||
...hooks,
|
||||
disposeHooks: (): void => {
|
||||
disposeCreatedHooks(hooks)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,8 +53,8 @@ export function createManagers(args: {
|
||||
|
||||
log("[index] onSubagentSessionCreated callback completed")
|
||||
},
|
||||
onShutdown: () => {
|
||||
tmuxSessionManager.cleanup().catch((error) => {
|
||||
onShutdown: async () => {
|
||||
await tmuxSessionManager.cleanup().catch((error) => {
|
||||
log("[index] tmux cleanup error during shutdown:", error)
|
||||
})
|
||||
},
|
||||
|
||||
193
src/features/background-agent/cancel-task-cleanup.test.ts
Normal file
193
src/features/background-agent/cancel-task-cleanup.test.ts
Normal file
@@ -0,0 +1,193 @@
|
||||
import { tmpdir } from "node:os"
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import { afterEach, describe, expect, test } from "bun:test"
|
||||
import { ConcurrencyManager } from "./concurrency"
|
||||
import { BackgroundManager } from "./manager"
|
||||
import type { BackgroundTask, LaunchInput } from "./types"
|
||||
|
||||
const managersToShutdown: BackgroundManager[] = []
|
||||
|
||||
afterEach(() => {
|
||||
while (managersToShutdown.length > 0) managersToShutdown.pop()?.shutdown()
|
||||
})
|
||||
|
||||
function createBackgroundManager(config?: { defaultConcurrency?: number }): BackgroundManager {
|
||||
const directory = tmpdir()
|
||||
const client = { session: {} as PluginInput["client"]["session"] } as PluginInput["client"]
|
||||
|
||||
Reflect.set(client.session, "abort", async () => ({ data: true }))
|
||||
Reflect.set(client.session, "create", async () => ({ data: { id: `session-${crypto.randomUUID().slice(0, 8)}` } }))
|
||||
Reflect.set(client.session, "get", async () => ({ data: { directory } }))
|
||||
Reflect.set(client.session, "messages", async () => ({ data: [] }))
|
||||
Reflect.set(client.session, "prompt", async () => ({ data: { info: {}, parts: [] } }))
|
||||
Reflect.set(client.session, "promptAsync", async () => ({ data: undefined }))
|
||||
|
||||
const manager = new BackgroundManager({
|
||||
$: {} as PluginInput["$"],
|
||||
client,
|
||||
directory,
|
||||
project: {} as PluginInput["project"],
|
||||
serverUrl: new URL("http://localhost"),
|
||||
worktree: directory,
|
||||
}, config)
|
||||
managersToShutdown.push(manager)
|
||||
return manager
|
||||
}
|
||||
|
||||
function createMockTask(overrides: Partial<BackgroundTask> & { id: string; parentSessionID: string }): BackgroundTask {
|
||||
return {
|
||||
id: overrides.id,
|
||||
sessionID: overrides.sessionID,
|
||||
parentSessionID: overrides.parentSessionID,
|
||||
parentMessageID: overrides.parentMessageID ?? "parent-message-id",
|
||||
description: overrides.description ?? "test task",
|
||||
prompt: overrides.prompt ?? "test prompt",
|
||||
agent: overrides.agent ?? "test-agent",
|
||||
status: overrides.status ?? "running",
|
||||
queuedAt: overrides.queuedAt,
|
||||
startedAt: overrides.startedAt ?? new Date(),
|
||||
completedAt: overrides.completedAt,
|
||||
error: overrides.error,
|
||||
model: overrides.model,
|
||||
concurrencyKey: overrides.concurrencyKey,
|
||||
concurrencyGroup: overrides.concurrencyGroup,
|
||||
progress: overrides.progress,
|
||||
}
|
||||
}
|
||||
|
||||
function getTaskMap(manager: BackgroundManager): Map<string, BackgroundTask> { return Reflect.get(manager, "tasks") as Map<string, BackgroundTask> }
|
||||
|
||||
function getPendingByParent(manager: BackgroundManager): Map<string, Set<string>> { return Reflect.get(manager, "pendingByParent") as Map<string, Set<string>> }
|
||||
|
||||
function getQueuesByKey(manager: BackgroundManager): Map<string, Array<{ task: BackgroundTask; input: LaunchInput }>> { return Reflect.get(manager, "queuesByKey") as Map<string, Array<{ task: BackgroundTask; input: LaunchInput }>> }
|
||||
|
||||
function getConcurrencyManager(manager: BackgroundManager): ConcurrencyManager { return Reflect.get(manager, "concurrencyManager") as ConcurrencyManager }
|
||||
|
||||
function getCompletionTimers(manager: BackgroundManager): Map<string, ReturnType<typeof setTimeout>> { return Reflect.get(manager, "completionTimers") as Map<string, ReturnType<typeof setTimeout>> }
|
||||
|
||||
async function processKeyForTest(manager: BackgroundManager, key: string): Promise<void> {
|
||||
const processKey = Reflect.get(manager, "processKey") as (key: string) => Promise<void>
|
||||
await processKey.call(manager, key)
|
||||
}
|
||||
|
||||
function runScheduledCleanup(manager: BackgroundManager, taskId: string): void {
|
||||
const timer = getCompletionTimers(manager).get(taskId)
|
||||
if (!timer) {
|
||||
throw new Error(`Expected cleanup timer for task ${taskId}`)
|
||||
}
|
||||
|
||||
const onTimeout = Reflect.get(timer, "_onTimeout") as (() => void) | undefined
|
||||
if (!onTimeout) {
|
||||
throw new Error(`Expected cleanup callback for task ${taskId}`)
|
||||
}
|
||||
|
||||
onTimeout()
|
||||
}
|
||||
|
||||
describe("BackgroundManager.cancelTask cleanup", () => {
|
||||
test("#given a running task in BackgroundManager #when cancelTask called with skipNotification=true #then task is eventually removed from this.tasks Map", async () => {
|
||||
// given
|
||||
const manager = createBackgroundManager()
|
||||
const task = createMockTask({
|
||||
id: "task-skip-notification-cleanup",
|
||||
parentSessionID: "parent-session-skip-notification-cleanup",
|
||||
sessionID: "session-skip-notification-cleanup",
|
||||
})
|
||||
|
||||
getTaskMap(manager).set(task.id, task)
|
||||
getPendingByParent(manager).set(task.parentSessionID, new Set([task.id]))
|
||||
|
||||
// when
|
||||
const cancelled = await manager.cancelTask(task.id, {
|
||||
skipNotification: true,
|
||||
source: "test",
|
||||
})
|
||||
|
||||
// then
|
||||
expect(cancelled).toBe(true)
|
||||
expect(getPendingByParent(manager).get(task.parentSessionID)).toBeUndefined()
|
||||
runScheduledCleanup(manager, task.id)
|
||||
expect(manager.getTask(task.id)).toBeUndefined()
|
||||
})
|
||||
|
||||
test("#given a running task #when cancelTask called with skipNotification=false #then task is also eventually removed", async () => {
|
||||
// given
|
||||
const manager = createBackgroundManager()
|
||||
const task = createMockTask({
|
||||
id: "task-notify-cleanup",
|
||||
parentSessionID: "parent-session-notify-cleanup",
|
||||
sessionID: "session-notify-cleanup",
|
||||
})
|
||||
|
||||
getTaskMap(manager).set(task.id, task)
|
||||
getPendingByParent(manager).set(task.parentSessionID, new Set([task.id]))
|
||||
|
||||
// when
|
||||
const cancelled = await manager.cancelTask(task.id, {
|
||||
skipNotification: false,
|
||||
source: "test",
|
||||
})
|
||||
|
||||
// then
|
||||
expect(cancelled).toBe(true)
|
||||
runScheduledCleanup(manager, task.id)
|
||||
expect(manager.getTask(task.id)).toBeUndefined()
|
||||
})
|
||||
|
||||
test("#given a running task #when cancelTask called with skipNotification=true #then concurrency slot is freed and pending tasks can start", async () => {
|
||||
// given
|
||||
const manager = createBackgroundManager({ defaultConcurrency: 1 })
|
||||
const concurrencyManager = getConcurrencyManager(manager)
|
||||
const concurrencyKey = "test-provider/test-model"
|
||||
await concurrencyManager.acquire(concurrencyKey)
|
||||
|
||||
const runningTask = createMockTask({
|
||||
id: "task-running-before-cancel",
|
||||
parentSessionID: "parent-session-concurrency-cleanup",
|
||||
sessionID: "session-running-before-cancel",
|
||||
concurrencyKey,
|
||||
})
|
||||
const pendingTask = createMockTask({
|
||||
id: "task-pending-after-cancel",
|
||||
parentSessionID: runningTask.parentSessionID,
|
||||
status: "pending",
|
||||
startedAt: undefined,
|
||||
queuedAt: new Date(),
|
||||
model: { providerID: "test-provider", modelID: "test-model" },
|
||||
})
|
||||
const queuedInput: LaunchInput = {
|
||||
agent: pendingTask.agent,
|
||||
description: pendingTask.description,
|
||||
model: pendingTask.model,
|
||||
parentMessageID: pendingTask.parentMessageID,
|
||||
parentSessionID: pendingTask.parentSessionID,
|
||||
prompt: pendingTask.prompt,
|
||||
}
|
||||
|
||||
getTaskMap(manager).set(runningTask.id, runningTask)
|
||||
getTaskMap(manager).set(pendingTask.id, pendingTask)
|
||||
getPendingByParent(manager).set(runningTask.parentSessionID, new Set([runningTask.id, pendingTask.id]))
|
||||
getQueuesByKey(manager).set(concurrencyKey, [{ input: queuedInput, task: pendingTask }])
|
||||
|
||||
Reflect.set(manager, "startTask", async ({ task }: { task: BackgroundTask; input: LaunchInput }) => {
|
||||
task.status = "running"
|
||||
task.startedAt = new Date()
|
||||
task.sessionID = "session-started-after-cancel"
|
||||
task.concurrencyKey = concurrencyKey
|
||||
task.concurrencyGroup = concurrencyKey
|
||||
})
|
||||
|
||||
// when
|
||||
const cancelled = await manager.cancelTask(runningTask.id, {
|
||||
abortSession: false,
|
||||
skipNotification: true,
|
||||
source: "test",
|
||||
})
|
||||
await processKeyForTest(manager, concurrencyKey)
|
||||
|
||||
// then
|
||||
expect(cancelled).toBe(true)
|
||||
expect(concurrencyManager.getCount(concurrencyKey)).toBe(1)
|
||||
expect(manager.getTask(pendingTask.id)?.status).toBe("running")
|
||||
})
|
||||
})
|
||||
@@ -3785,7 +3785,7 @@ describe("BackgroundManager.completionTimers - Memory Leak Fix", () => {
|
||||
manager.shutdown()
|
||||
})
|
||||
|
||||
test("should start cleanup timers only after all tasks complete", async () => {
|
||||
test("should start per-task cleanup timers independently of sibling completion", async () => {
|
||||
// given
|
||||
const client = {
|
||||
session: {
|
||||
@@ -3832,7 +3832,7 @@ describe("BackgroundManager.completionTimers - Memory Leak Fix", () => {
|
||||
|
||||
// then
|
||||
const completionTimers = getCompletionTimers(manager)
|
||||
expect(completionTimers.size).toBe(0)
|
||||
expect(completionTimers.size).toBe(1)
|
||||
|
||||
// when
|
||||
await (manager as unknown as { notifyParentSession: (task: BackgroundTask) => Promise<void> })
|
||||
|
||||
@@ -116,11 +116,12 @@ export class BackgroundManager {
|
||||
private config?: BackgroundTaskConfig
|
||||
private tmuxEnabled: boolean
|
||||
private onSubagentSessionCreated?: OnSubagentSessionCreated
|
||||
private onShutdown?: () => void
|
||||
private onShutdown?: () => void | Promise<void>
|
||||
|
||||
private queuesByKey: Map<string, QueueItem[]> = new Map()
|
||||
private processingKeys: Set<string> = new Set()
|
||||
private completionTimers: Map<string, ReturnType<typeof setTimeout>> = new Map()
|
||||
private completedTaskSummaries: Map<string, Array<{id: string, description: string}>> = new Map()
|
||||
private idleDeferralTimers: Map<string, ReturnType<typeof setTimeout>> = new Map()
|
||||
private notificationQueueByParent: Map<string, Promise<void>> = new Map()
|
||||
private rootDescendantCounts: Map<string, number>
|
||||
@@ -133,7 +134,7 @@ export class BackgroundManager {
|
||||
options?: {
|
||||
tmuxConfig?: TmuxConfig
|
||||
onSubagentSessionCreated?: OnSubagentSessionCreated
|
||||
onShutdown?: () => void
|
||||
onShutdown?: () => void | Promise<void>
|
||||
enableParentSessionNotifications?: boolean
|
||||
}
|
||||
) {
|
||||
@@ -906,6 +907,13 @@ export class BackgroundManager {
|
||||
this.idleDeferralTimers.delete(task.id)
|
||||
}
|
||||
|
||||
this.cleanupPendingByParent(task)
|
||||
this.clearNotificationsForTask(task.id)
|
||||
const toastManager = getTaskToastManager()
|
||||
if (toastManager) {
|
||||
toastManager.removeTask(task.id)
|
||||
}
|
||||
this.scheduleTaskRemoval(task.id)
|
||||
if (task.sessionID) {
|
||||
SessionCategoryRegistry.remove(task.sessionID)
|
||||
}
|
||||
@@ -932,7 +940,12 @@ export class BackgroundManager {
|
||||
|
||||
this.pendingNotifications.delete(sessionID)
|
||||
|
||||
if (tasksToCancel.size === 0) return
|
||||
if (tasksToCancel.size === 0) {
|
||||
this.clearTaskHistoryWhenParentTasksGone(sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
const parentSessionsToClear = new Set<string>()
|
||||
|
||||
const deletedSessionIDs = new Set<string>([sessionID])
|
||||
for (const task of tasksToCancel.values()) {
|
||||
@@ -942,6 +955,8 @@ export class BackgroundManager {
|
||||
}
|
||||
|
||||
for (const task of tasksToCancel.values()) {
|
||||
parentSessionsToClear.add(task.parentSessionID)
|
||||
|
||||
if (task.status === "running" || task.status === "pending") {
|
||||
void this.cancelTask(task.id, {
|
||||
source: "session.deleted",
|
||||
@@ -959,6 +974,10 @@ export class BackgroundManager {
|
||||
}
|
||||
}
|
||||
|
||||
for (const parentSessionID of parentSessionsToClear) {
|
||||
this.clearTaskHistoryWhenParentTasksGone(parentSessionID)
|
||||
}
|
||||
|
||||
this.rootDescendantCounts.delete(sessionID)
|
||||
SessionCategoryRegistry.remove(sessionID)
|
||||
}
|
||||
@@ -1125,6 +1144,39 @@ export class BackgroundManager {
|
||||
}
|
||||
}
|
||||
|
||||
private clearTaskHistoryWhenParentTasksGone(parentSessionID: string | undefined): void {
|
||||
if (!parentSessionID) return
|
||||
if (this.getTasksByParentSession(parentSessionID).length > 0) return
|
||||
this.taskHistory.clearSession(parentSessionID)
|
||||
this.completedTaskSummaries.delete(parentSessionID)
|
||||
}
|
||||
|
||||
private scheduleTaskRemoval(taskId: string): void {
|
||||
const existingTimer = this.completionTimers.get(taskId)
|
||||
if (existingTimer) {
|
||||
clearTimeout(existingTimer)
|
||||
this.completionTimers.delete(taskId)
|
||||
}
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
this.completionTimers.delete(taskId)
|
||||
const task = this.tasks.get(taskId)
|
||||
if (task) {
|
||||
this.clearNotificationsForTask(taskId)
|
||||
this.tasks.delete(taskId)
|
||||
this.clearTaskHistoryWhenParentTasksGone(task.parentSessionID)
|
||||
if (task.sessionID) {
|
||||
subagentSessions.delete(task.sessionID)
|
||||
SessionCategoryRegistry.remove(task.sessionID)
|
||||
}
|
||||
log("[background-agent] Removed completed task from memory:", taskId)
|
||||
this.clearTaskHistoryWhenParentTasksGone(task?.parentSessionID)
|
||||
}
|
||||
}, TASK_CLEANUP_DELAY_MS)
|
||||
|
||||
this.completionTimers.set(taskId, timer)
|
||||
}
|
||||
|
||||
async cancelTask(
|
||||
taskId: string,
|
||||
options?: { source?: string; reason?: string; abortSession?: boolean; skipNotification?: boolean }
|
||||
@@ -1190,6 +1242,8 @@ export class BackgroundManager {
|
||||
removeTaskToastTracking(task.id)
|
||||
|
||||
if (options?.skipNotification) {
|
||||
this.cleanupPendingByParent(task)
|
||||
this.scheduleTaskRemoval(task.id)
|
||||
log(`[background-agent] Task cancelled via ${source} (notification skipped):`, task.id)
|
||||
return true
|
||||
}
|
||||
@@ -1328,6 +1382,14 @@ export class BackgroundManager {
|
||||
})
|
||||
}
|
||||
|
||||
if (!this.completedTaskSummaries.has(task.parentSessionID)) {
|
||||
this.completedTaskSummaries.set(task.parentSessionID, [])
|
||||
}
|
||||
this.completedTaskSummaries.get(task.parentSessionID)!.push({
|
||||
id: task.id,
|
||||
description: task.description,
|
||||
})
|
||||
|
||||
// Update pending tracking and check if all tasks complete
|
||||
const pendingSet = this.pendingByParent.get(task.parentSessionID)
|
||||
let allComplete = false
|
||||
@@ -1347,10 +1409,13 @@ export class BackgroundManager {
|
||||
}
|
||||
|
||||
const completedTasks = allComplete
|
||||
? Array.from(this.tasks.values())
|
||||
.filter(t => t.parentSessionID === task.parentSessionID && t.status !== "running" && t.status !== "pending")
|
||||
? (this.completedTaskSummaries.get(task.parentSessionID) ?? [{ id: task.id, description: task.description }])
|
||||
: []
|
||||
|
||||
if (allComplete) {
|
||||
this.completedTaskSummaries.delete(task.parentSessionID)
|
||||
}
|
||||
|
||||
const statusText = task.status === "completed"
|
||||
? "COMPLETED"
|
||||
: task.status === "interrupt"
|
||||
@@ -1480,29 +1545,8 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
})
|
||||
}
|
||||
|
||||
if (allComplete) {
|
||||
for (const completedTask of completedTasks) {
|
||||
const taskId = completedTask.id
|
||||
const existingTimer = this.completionTimers.get(taskId)
|
||||
if (existingTimer) {
|
||||
clearTimeout(existingTimer)
|
||||
this.completionTimers.delete(taskId)
|
||||
}
|
||||
const timer = setTimeout(() => {
|
||||
this.completionTimers.delete(taskId)
|
||||
const taskToRemove = this.tasks.get(taskId)
|
||||
if (taskToRemove) {
|
||||
this.clearNotificationsForTask(taskId)
|
||||
if (taskToRemove.sessionID) {
|
||||
subagentSessions.delete(taskToRemove.sessionID)
|
||||
SessionCategoryRegistry.remove(taskToRemove.sessionID)
|
||||
}
|
||||
this.tasks.delete(taskId)
|
||||
log("[background-agent] Removed completed task from memory:", taskId)
|
||||
}
|
||||
}, TASK_CLEANUP_DELAY_MS)
|
||||
this.completionTimers.set(taskId, timer)
|
||||
}
|
||||
if (task.status !== "running" && task.status !== "pending") {
|
||||
this.scheduleTaskRemoval(task.id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1554,6 +1598,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
}
|
||||
}
|
||||
}
|
||||
this.cleanupPendingByParent(task)
|
||||
this.markForNotification(task)
|
||||
this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)).catch(err => {
|
||||
log("[background-agent] Error in notifyParentSession for stale-pruned task:", { taskId: task.id, error: err })
|
||||
@@ -1657,7 +1702,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
* Cancels all pending concurrency waiters and clears timers.
|
||||
* Should be called when the plugin is unloaded.
|
||||
*/
|
||||
shutdown(): void {
|
||||
async shutdown(): Promise<void> {
|
||||
if (this.shutdownTriggered) return
|
||||
this.shutdownTriggered = true
|
||||
log("[background-agent] Shutting down BackgroundManager")
|
||||
@@ -1675,7 +1720,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
// Notify shutdown listeners (e.g., tmux cleanup)
|
||||
if (this.onShutdown) {
|
||||
try {
|
||||
this.onShutdown()
|
||||
await this.onShutdown()
|
||||
} catch (error) {
|
||||
log("[background-agent] Error in onShutdown callback:", error)
|
||||
}
|
||||
@@ -1708,6 +1753,8 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
this.rootDescendantCounts.clear()
|
||||
this.queuesByKey.clear()
|
||||
this.processingKeys.clear()
|
||||
this.taskHistory.clearAll()
|
||||
this.completedTaskSummaries.clear()
|
||||
this.unregisterProcessCleanup()
|
||||
log("[background-agent] Shutdown complete")
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ function registerProcessSignal(
|
||||
}
|
||||
|
||||
interface CleanupTarget {
|
||||
shutdown(): void
|
||||
shutdown(): void | Promise<void>
|
||||
}
|
||||
|
||||
const cleanupManagers = new Set<CleanupTarget>()
|
||||
@@ -35,7 +35,9 @@ export function registerManagerForCleanup(manager: CleanupTarget): void {
|
||||
const cleanupAll = () => {
|
||||
for (const m of cleanupManagers) {
|
||||
try {
|
||||
m.shutdown()
|
||||
void Promise.resolve(m.shutdown()).catch((error) => {
|
||||
log("[background-agent] Error during async shutdown cleanup:", error)
|
||||
})
|
||||
} catch (error) {
|
||||
log("[background-agent] Error during shutdown cleanup:", error)
|
||||
}
|
||||
|
||||
245
src/features/background-agent/task-completion-cleanup.test.ts
Normal file
245
src/features/background-agent/task-completion-cleanup.test.ts
Normal file
@@ -0,0 +1,245 @@
|
||||
declare const require: (name: string) => any
|
||||
const { describe, test, expect, afterEach } = require("bun:test")
|
||||
import { tmpdir } from "node:os"
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import { TASK_CLEANUP_DELAY_MS } from "./constants"
|
||||
import { BackgroundManager } from "./manager"
|
||||
import type { BackgroundTask } from "./types"
|
||||
|
||||
type PromptAsyncCall = {
|
||||
path: { id: string }
|
||||
body: {
|
||||
noReply?: boolean
|
||||
parts?: unknown[]
|
||||
}
|
||||
}
|
||||
|
||||
type FakeTimers = {
|
||||
getDelay: (timer: ReturnType<typeof setTimeout>) => number | undefined
|
||||
run: (timer: ReturnType<typeof setTimeout>) => void
|
||||
restore: () => void
|
||||
}
|
||||
|
||||
let managerUnderTest: BackgroundManager | undefined
|
||||
let fakeTimers: FakeTimers | undefined
|
||||
|
||||
afterEach(() => {
|
||||
managerUnderTest?.shutdown()
|
||||
fakeTimers?.restore()
|
||||
managerUnderTest = undefined
|
||||
fakeTimers = undefined
|
||||
})
|
||||
|
||||
function createTask(overrides: Partial<BackgroundTask> & { id: string; parentSessionID: string }): BackgroundTask {
|
||||
const id = overrides.id
|
||||
const parentSessionID = overrides.parentSessionID
|
||||
const { id: _ignoredID, parentSessionID: _ignoredParentSessionID, ...rest } = overrides
|
||||
|
||||
return {
|
||||
parentMessageID: overrides.parentMessageID ?? "parent-message-id",
|
||||
description: overrides.description ?? overrides.id,
|
||||
prompt: overrides.prompt ?? `Prompt for ${overrides.id}`,
|
||||
agent: overrides.agent ?? "test-agent",
|
||||
status: overrides.status ?? "running",
|
||||
startedAt: overrides.startedAt ?? new Date("2026-03-11T00:00:00.000Z"),
|
||||
...rest,
|
||||
id,
|
||||
parentSessionID,
|
||||
}
|
||||
}
|
||||
|
||||
function createManager(enableParentSessionNotifications: boolean): {
|
||||
manager: BackgroundManager
|
||||
promptAsyncCalls: PromptAsyncCall[]
|
||||
} {
|
||||
const promptAsyncCalls: PromptAsyncCall[] = []
|
||||
const client = {
|
||||
session: {
|
||||
messages: async () => [],
|
||||
prompt: async () => ({}),
|
||||
promptAsync: async (call: PromptAsyncCall) => {
|
||||
promptAsyncCalls.push(call)
|
||||
return {}
|
||||
},
|
||||
abort: async () => ({}),
|
||||
},
|
||||
}
|
||||
const placeholderClient = {} as PluginInput["client"]
|
||||
const ctx: PluginInput = {
|
||||
client: placeholderClient,
|
||||
project: {} as PluginInput["project"],
|
||||
directory: tmpdir(),
|
||||
worktree: tmpdir(),
|
||||
serverUrl: new URL("http://localhost"),
|
||||
$: {} as PluginInput["$"],
|
||||
}
|
||||
|
||||
const manager = new BackgroundManager(
|
||||
ctx,
|
||||
undefined,
|
||||
{ enableParentSessionNotifications }
|
||||
)
|
||||
Reflect.set(manager, "client", client)
|
||||
|
||||
return { manager, promptAsyncCalls }
|
||||
}
|
||||
|
||||
function installFakeTimers(): FakeTimers {
|
||||
const originalSetTimeout = globalThis.setTimeout
|
||||
const originalClearTimeout = globalThis.clearTimeout
|
||||
const callbacks = new Map<ReturnType<typeof setTimeout>, () => void>()
|
||||
const delays = new Map<ReturnType<typeof setTimeout>, number>()
|
||||
|
||||
globalThis.setTimeout = ((handler: Parameters<typeof setTimeout>[0], delay?: number, ...args: unknown[]): ReturnType<typeof setTimeout> => {
|
||||
if (typeof handler !== "function") {
|
||||
throw new Error("Expected function timeout handler")
|
||||
}
|
||||
|
||||
const timer = originalSetTimeout(() => {}, 60_000)
|
||||
originalClearTimeout(timer)
|
||||
const callback = handler as (...callbackArgs: Array<unknown>) => void
|
||||
callbacks.set(timer, () => callback(...args))
|
||||
delays.set(timer, delay ?? 0)
|
||||
return timer
|
||||
}) as typeof setTimeout
|
||||
|
||||
globalThis.clearTimeout = ((timer: ReturnType<typeof setTimeout>): void => {
|
||||
callbacks.delete(timer)
|
||||
delays.delete(timer)
|
||||
}) as typeof clearTimeout
|
||||
|
||||
return {
|
||||
getDelay(timer) {
|
||||
return delays.get(timer)
|
||||
},
|
||||
run(timer) {
|
||||
const callback = callbacks.get(timer)
|
||||
if (!callback) {
|
||||
throw new Error(`Timer not found: ${String(timer)}`)
|
||||
}
|
||||
|
||||
callbacks.delete(timer)
|
||||
delays.delete(timer)
|
||||
callback()
|
||||
},
|
||||
restore() {
|
||||
globalThis.setTimeout = originalSetTimeout
|
||||
globalThis.clearTimeout = originalClearTimeout
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
function getTasks(manager: BackgroundManager): Map<string, BackgroundTask> {
|
||||
return Reflect.get(manager, "tasks") as Map<string, BackgroundTask>
|
||||
}
|
||||
|
||||
function getPendingByParent(manager: BackgroundManager): Map<string, Set<string>> {
|
||||
return Reflect.get(manager, "pendingByParent") as Map<string, Set<string>>
|
||||
}
|
||||
|
||||
function getCompletionTimers(manager: BackgroundManager): Map<string, ReturnType<typeof setTimeout>> {
|
||||
return Reflect.get(manager, "completionTimers") as Map<string, ReturnType<typeof setTimeout>>
|
||||
}
|
||||
|
||||
async function notifyParentSessionForTest(manager: BackgroundManager, task: BackgroundTask): Promise<void> {
|
||||
const notifyParentSession = Reflect.get(manager, "notifyParentSession") as (task: BackgroundTask) => Promise<void>
|
||||
return notifyParentSession.call(manager, task)
|
||||
}
|
||||
|
||||
function getRequiredTimer(manager: BackgroundManager, taskID: string): ReturnType<typeof setTimeout> {
|
||||
const timer = getCompletionTimers(manager).get(taskID)
|
||||
expect(timer).toBeDefined()
|
||||
if (timer === undefined) {
|
||||
throw new Error(`Missing completion timer for ${taskID}`)
|
||||
}
|
||||
|
||||
return timer
|
||||
}
|
||||
|
||||
describe("BackgroundManager.notifyParentSession cleanup scheduling", () => {
|
||||
describe("#given 2 tasks for same parent and task A completed", () => {
|
||||
test("#when task B is still running #then task A is cleaned up from this.tasks after delay even though task B is not done", async () => {
|
||||
// given
|
||||
const { manager } = createManager(false)
|
||||
managerUnderTest = manager
|
||||
fakeTimers = installFakeTimers()
|
||||
const taskA = createTask({ id: "task-a", parentSessionID: "parent-1", description: "task A", status: "completed", completedAt: new Date("2026-03-11T00:01:00.000Z") })
|
||||
const taskB = createTask({ id: "task-b", parentSessionID: "parent-1", description: "task B", status: "running" })
|
||||
getTasks(manager).set(taskA.id, taskA)
|
||||
getTasks(manager).set(taskB.id, taskB)
|
||||
getPendingByParent(manager).set(taskA.parentSessionID, new Set([taskA.id, taskB.id]))
|
||||
|
||||
// when
|
||||
await notifyParentSessionForTest(manager, taskA)
|
||||
const taskATimer = getRequiredTimer(manager, taskA.id)
|
||||
expect(fakeTimers.getDelay(taskATimer)).toBe(TASK_CLEANUP_DELAY_MS)
|
||||
fakeTimers.run(taskATimer)
|
||||
|
||||
// then
|
||||
expect(fakeTimers.getDelay(taskATimer)).toBeUndefined()
|
||||
expect(getTasks(manager).has(taskA.id)).toBe(false)
|
||||
expect(getTasks(manager).get(taskB.id)).toBe(taskB)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given 2 tasks for same parent and both completed", () => {
|
||||
test("#when the second completion notification is sent #then ALL BACKGROUND TASKS COMPLETE notification still works correctly", async () => {
|
||||
// given
|
||||
const { manager, promptAsyncCalls } = createManager(true)
|
||||
managerUnderTest = manager
|
||||
fakeTimers = installFakeTimers()
|
||||
const taskA = createTask({ id: "task-a", parentSessionID: "parent-1", description: "task A", status: "completed", completedAt: new Date("2026-03-11T00:01:00.000Z") })
|
||||
const taskB = createTask({ id: "task-b", parentSessionID: "parent-1", description: "task B", status: "running" })
|
||||
getTasks(manager).set(taskA.id, taskA)
|
||||
getTasks(manager).set(taskB.id, taskB)
|
||||
getPendingByParent(manager).set(taskA.parentSessionID, new Set([taskA.id, taskB.id]))
|
||||
|
||||
await notifyParentSessionForTest(manager, taskA)
|
||||
taskB.status = "completed"
|
||||
taskB.completedAt = new Date("2026-03-11T00:02:00.000Z")
|
||||
|
||||
// when
|
||||
await notifyParentSessionForTest(manager, taskB)
|
||||
|
||||
// then
|
||||
expect(promptAsyncCalls).toHaveLength(2)
|
||||
expect(getCompletionTimers(manager).size).toBe(2)
|
||||
const allCompleteCall = promptAsyncCalls[1]
|
||||
expect(allCompleteCall).toBeDefined()
|
||||
if (!allCompleteCall) {
|
||||
throw new Error("Missing all-complete notification call")
|
||||
}
|
||||
|
||||
expect(allCompleteCall.body.noReply).toBe(false)
|
||||
const allCompletePayload = JSON.stringify(allCompleteCall.body.parts)
|
||||
expect(allCompletePayload).toContain("ALL BACKGROUND TASKS COMPLETE")
|
||||
expect(allCompletePayload).toContain(taskA.id)
|
||||
expect(allCompletePayload).toContain(taskB.id)
|
||||
expect(allCompletePayload).toContain(taskA.description)
|
||||
expect(allCompletePayload).toContain(taskB.description)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given a completed task with cleanup timer scheduled", () => {
|
||||
test("#when cleanup timer fires #then task is deleted from this.tasks Map", async () => {
|
||||
// given
|
||||
const { manager } = createManager(false)
|
||||
managerUnderTest = manager
|
||||
fakeTimers = installFakeTimers()
|
||||
const task = createTask({ id: "task-a", parentSessionID: "parent-1", description: "task A", status: "completed", completedAt: new Date("2026-03-11T00:01:00.000Z") })
|
||||
getTasks(manager).set(task.id, task)
|
||||
getPendingByParent(manager).set(task.parentSessionID, new Set([task.id]))
|
||||
|
||||
await notifyParentSessionForTest(manager, task)
|
||||
const cleanupTimer = getRequiredTimer(manager, task.id)
|
||||
|
||||
// when
|
||||
expect(fakeTimers.getDelay(cleanupTimer)).toBe(TASK_CLEANUP_DELAY_MS)
|
||||
fakeTimers.run(cleanupTimer)
|
||||
|
||||
// then
|
||||
expect(getCompletionTimers(manager).has(task.id)).toBe(false)
|
||||
expect(getTasks(manager).has(task.id)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
142
src/features/background-agent/task-history-cleanup.test.ts
Normal file
142
src/features/background-agent/task-history-cleanup.test.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
import { afterEach, describe, expect, test } from "bun:test"
|
||||
import { tmpdir } from "node:os"
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import { BackgroundManager } from "./manager"
|
||||
import { TaskHistory } from "./task-history"
|
||||
import type { BackgroundTask } from "./types"
|
||||
|
||||
let managerUnderTest: BackgroundManager | undefined
|
||||
|
||||
afterEach(() => {
|
||||
managerUnderTest?.shutdown()
|
||||
managerUnderTest = undefined
|
||||
})
|
||||
|
||||
function createManager(): BackgroundManager {
|
||||
const client = {
|
||||
session: {
|
||||
abort: async () => ({}),
|
||||
},
|
||||
}
|
||||
|
||||
const placeholderClient = {} as PluginInput["client"]
|
||||
const ctx: PluginInput = {
|
||||
client: placeholderClient,
|
||||
project: {} as PluginInput["project"],
|
||||
directory: tmpdir(),
|
||||
worktree: tmpdir(),
|
||||
serverUrl: new URL("http://localhost"),
|
||||
$: {} as PluginInput["$"],
|
||||
}
|
||||
|
||||
const manager = new BackgroundManager(ctx)
|
||||
Reflect.set(manager, "client", client)
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
function createTask(overrides: Partial<BackgroundTask> & { id: string; parentSessionID: string }): BackgroundTask {
|
||||
const { id, parentSessionID, ...rest } = overrides
|
||||
|
||||
return {
|
||||
...rest,
|
||||
id,
|
||||
parentSessionID,
|
||||
parentMessageID: rest.parentMessageID ?? "parent-message-id",
|
||||
description: rest.description ?? id,
|
||||
prompt: rest.prompt ?? `Prompt for ${id}`,
|
||||
agent: rest.agent ?? "test-agent",
|
||||
status: rest.status ?? "running",
|
||||
startedAt: rest.startedAt ?? new Date("2026-03-11T00:00:00.000Z"),
|
||||
}
|
||||
}
|
||||
|
||||
function getTaskMap(manager: BackgroundManager): Map<string, BackgroundTask> {
|
||||
return Reflect.get(manager, "tasks") as Map<string, BackgroundTask>
|
||||
}
|
||||
|
||||
function pruneStaleTasksAndNotificationsForTest(manager: BackgroundManager): void {
|
||||
const pruneStaleTasksAndNotifications = Reflect.get(manager, "pruneStaleTasksAndNotifications") as () => void
|
||||
pruneStaleTasksAndNotifications.call(manager)
|
||||
}
|
||||
|
||||
describe("task history cleanup", () => {
|
||||
test("#given TaskHistory with entries for multiple parents #when clearSession called for one parent #then only that parent's entries are removed, others remain", () => {
|
||||
// given
|
||||
const history = new TaskHistory()
|
||||
history.record("parent-1", { id: "task-1", agent: "explore", description: "task 1", status: "pending" })
|
||||
history.record("parent-2", { id: "task-2", agent: "oracle", description: "task 2", status: "running" })
|
||||
|
||||
// when
|
||||
history.clearSession("parent-1")
|
||||
|
||||
// then
|
||||
expect(history.getByParentSession("parent-1")).toHaveLength(0)
|
||||
expect(history.getByParentSession("parent-2")).toHaveLength(1)
|
||||
})
|
||||
|
||||
test("#given TaskHistory with entries for multiple parents #when clearAll called #then all entries are removed", () => {
|
||||
// given
|
||||
const history = new TaskHistory()
|
||||
history.record("parent-1", { id: "task-1", agent: "explore", description: "task 1", status: "pending" })
|
||||
history.record("parent-2", { id: "task-2", agent: "oracle", description: "task 2", status: "running" })
|
||||
|
||||
// when
|
||||
history.clearAll()
|
||||
|
||||
// then
|
||||
expect(history.getByParentSession("parent-1")).toHaveLength(0)
|
||||
expect(history.getByParentSession("parent-2")).toHaveLength(0)
|
||||
})
|
||||
|
||||
test("#given BackgroundManager with taskHistory entries #when shutdown() called #then taskHistory is cleared via clearAll()", () => {
|
||||
// given
|
||||
const manager = createManager()
|
||||
managerUnderTest = manager
|
||||
manager.taskHistory.record("parent-1", { id: "task-1", agent: "explore", description: "task 1", status: "pending" })
|
||||
|
||||
let clearAllCalls = 0
|
||||
const originalClearAll = manager.taskHistory.clearAll.bind(manager.taskHistory)
|
||||
manager.taskHistory.clearAll = (): void => {
|
||||
clearAllCalls += 1
|
||||
originalClearAll()
|
||||
}
|
||||
|
||||
// when
|
||||
manager.shutdown()
|
||||
|
||||
// then
|
||||
expect(clearAllCalls).toBe(1)
|
||||
expect(manager.taskHistory.getByParentSession("parent-1")).toHaveLength(0)
|
||||
|
||||
managerUnderTest = undefined
|
||||
})
|
||||
|
||||
test("#given BackgroundManager with stale tasks for one parent #when pruneStaleTasksAndNotifications() runs #then history is preserved until delayed cleanup", () => {
|
||||
// given
|
||||
const manager = createManager()
|
||||
managerUnderTest = manager
|
||||
const staleTask = createTask({
|
||||
id: "task-stale",
|
||||
parentSessionID: "parent-1",
|
||||
startedAt: new Date(Date.now() - 31 * 60 * 1000),
|
||||
})
|
||||
const liveTask = createTask({
|
||||
id: "task-live",
|
||||
parentSessionID: "parent-2",
|
||||
startedAt: new Date(),
|
||||
})
|
||||
|
||||
getTaskMap(manager).set(staleTask.id, staleTask)
|
||||
getTaskMap(manager).set(liveTask.id, liveTask)
|
||||
manager.taskHistory.record("parent-1", { id: staleTask.id, agent: staleTask.agent, description: staleTask.description, status: staleTask.status })
|
||||
manager.taskHistory.record("parent-2", { id: liveTask.id, agent: liveTask.agent, description: liveTask.description, status: liveTask.status })
|
||||
|
||||
// when
|
||||
pruneStaleTasksAndNotificationsForTest(manager)
|
||||
|
||||
// then
|
||||
expect(manager.taskHistory.getByParentSession("parent-1")).toHaveLength(1)
|
||||
expect(manager.taskHistory.getByParentSession("parent-2")).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
@@ -54,6 +54,10 @@ export class TaskHistory {
|
||||
this.entries.delete(parentSessionID)
|
||||
}
|
||||
|
||||
clearAll(): void {
|
||||
this.entries.clear()
|
||||
}
|
||||
|
||||
formatForCompaction(parentSessionID: string): string | null {
|
||||
const list = this.getByParentSession(parentSessionID)
|
||||
if (list.length === 0) return null
|
||||
|
||||
@@ -19,11 +19,13 @@ export function registerProcessCleanup(state: SkillMcpManagerState): void {
|
||||
state.cleanupRegistered = true
|
||||
|
||||
const cleanup = async (): Promise<void> => {
|
||||
state.shutdownGeneration++
|
||||
for (const managed of state.clients.values()) {
|
||||
await closeManagedClient(managed)
|
||||
}
|
||||
state.clients.clear()
|
||||
state.pendingConnections.clear()
|
||||
state.disconnectedSessions.clear()
|
||||
}
|
||||
|
||||
// Note: Node's 'exit' event is synchronous-only, so we rely on signal handlers for async cleanup.
|
||||
@@ -79,12 +81,23 @@ async function cleanupIdleClients(state: SkillMcpManagerState): Promise<void> {
|
||||
}
|
||||
}
|
||||
|
||||
if (state.clients.size === 0) {
|
||||
if (state.clients.size === 0 && state.pendingConnections.size === 0) {
|
||||
stopCleanupTimer(state)
|
||||
unregisterProcessCleanup(state)
|
||||
}
|
||||
}
|
||||
|
||||
export async function disconnectSession(state: SkillMcpManagerState, sessionID: string): Promise<void> {
|
||||
let hasPendingForSession = false
|
||||
for (const key of state.pendingConnections.keys()) {
|
||||
if (key.startsWith(`${sessionID}:`)) {
|
||||
hasPendingForSession = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if (hasPendingForSession) {
|
||||
state.disconnectedSessions.set(sessionID, (state.disconnectedSessions.get(sessionID) ?? 0) + 1)
|
||||
}
|
||||
const keysToRemove: string[] = []
|
||||
|
||||
for (const [key, managed] of state.clients.entries()) {
|
||||
@@ -96,22 +109,33 @@ export async function disconnectSession(state: SkillMcpManagerState, sessionID:
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of state.pendingConnections.keys()) {
|
||||
if (key.startsWith(`${sessionID}:`)) {
|
||||
keysToRemove.push(key)
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of keysToRemove) {
|
||||
state.pendingConnections.delete(key)
|
||||
}
|
||||
|
||||
if (state.clients.size === 0) {
|
||||
if (state.clients.size === 0 && state.pendingConnections.size === 0) {
|
||||
stopCleanupTimer(state)
|
||||
unregisterProcessCleanup(state)
|
||||
}
|
||||
}
|
||||
|
||||
export async function disconnectAll(state: SkillMcpManagerState): Promise<void> {
|
||||
state.shutdownGeneration++
|
||||
state.disposed = true
|
||||
stopCleanupTimer(state)
|
||||
unregisterProcessCleanup(state)
|
||||
|
||||
const clients = Array.from(state.clients.values())
|
||||
state.clients.clear()
|
||||
state.pendingConnections.clear()
|
||||
state.disconnectedSessions.clear()
|
||||
state.inFlightConnections.clear()
|
||||
state.authProviders.clear()
|
||||
|
||||
for (const managed of clients) {
|
||||
|
||||
291
src/features/skill-mcp-manager/connection-race.test.ts
Normal file
291
src/features/skill-mcp-manager/connection-race.test.ts
Normal file
@@ -0,0 +1,291 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"
|
||||
import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types"
|
||||
import type { SkillMcpClientInfo, SkillMcpManagerState } from "./types"
|
||||
|
||||
type Deferred<TValue> = {
|
||||
promise: Promise<TValue>
|
||||
resolve: (value: TValue) => void
|
||||
reject: (error: Error) => void
|
||||
}
|
||||
|
||||
const pendingConnects: Deferred<void>[] = []
|
||||
const trackedStates: SkillMcpManagerState[] = []
|
||||
const createdClients: MockClient[] = []
|
||||
const createdTransports: MockStdioClientTransport[] = []
|
||||
|
||||
class MockClient {
|
||||
readonly close = mock(async () => {})
|
||||
|
||||
constructor(
|
||||
_clientInfo: { name: string; version: string },
|
||||
_options: { capabilities: Record<string, never> }
|
||||
) {
|
||||
createdClients.push(this)
|
||||
}
|
||||
|
||||
async connect(_transport: MockStdioClientTransport): Promise<void> {
|
||||
const pendingConnect = pendingConnects.shift()
|
||||
if (pendingConnect) {
|
||||
await pendingConnect.promise
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class MockStdioClientTransport {
|
||||
readonly close = mock(async () => {})
|
||||
|
||||
constructor(_options: { command: string; args?: string[]; env?: Record<string, string>; stderr?: string }) {
|
||||
createdTransports.push(this)
|
||||
}
|
||||
}
|
||||
|
||||
mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({
|
||||
Client: MockClient,
|
||||
}))
|
||||
|
||||
mock.module("@modelcontextprotocol/sdk/client/stdio.js", () => ({
|
||||
StdioClientTransport: MockStdioClientTransport,
|
||||
}))
|
||||
|
||||
const { disconnectAll, disconnectSession } = await import("./cleanup")
|
||||
const { getOrCreateClient } = await import("./connection")
|
||||
|
||||
function createDeferred<TValue>(): Deferred<TValue> {
|
||||
let resolvePromise: ((value: TValue) => void) | null = null
|
||||
let rejectPromise: ((error: Error) => void) | null = null
|
||||
const promise = new Promise<TValue>((resolve, reject) => {
|
||||
resolvePromise = resolve
|
||||
rejectPromise = reject
|
||||
})
|
||||
|
||||
if (!resolvePromise || !rejectPromise) {
|
||||
throw new Error("Failed to create deferred promise")
|
||||
}
|
||||
|
||||
return {
|
||||
promise,
|
||||
resolve: resolvePromise,
|
||||
reject: rejectPromise,
|
||||
}
|
||||
}
|
||||
|
||||
function createState(): SkillMcpManagerState {
|
||||
const state: SkillMcpManagerState = {
|
||||
clients: new Map(),
|
||||
pendingConnections: new Map(),
|
||||
disconnectedSessions: new Map(),
|
||||
authProviders: new Map(),
|
||||
cleanupRegistered: false,
|
||||
cleanupInterval: null,
|
||||
cleanupHandlers: [],
|
||||
idleTimeoutMs: 5 * 60 * 1000,
|
||||
shutdownGeneration: 0,
|
||||
inFlightConnections: new Map(),
|
||||
disposed: false,
|
||||
}
|
||||
|
||||
trackedStates.push(state)
|
||||
return state
|
||||
}
|
||||
|
||||
function createClientInfo(sessionID: string): SkillMcpClientInfo {
|
||||
return {
|
||||
serverName: "race-server",
|
||||
skillName: "race-skill",
|
||||
sessionID,
|
||||
}
|
||||
}
|
||||
|
||||
function createClientKey(info: SkillMcpClientInfo): string {
|
||||
return `${info.sessionID}:${info.skillName}:${info.serverName}`
|
||||
}
|
||||
|
||||
const stdioConfig: ClaudeCodeMcpServer = {
|
||||
command: "mock-mcp-server",
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
pendingConnects.length = 0
|
||||
createdClients.length = 0
|
||||
createdTransports.length = 0
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
for (const state of trackedStates) {
|
||||
await disconnectAll(state)
|
||||
}
|
||||
|
||||
trackedStates.length = 0
|
||||
pendingConnects.length = 0
|
||||
createdClients.length = 0
|
||||
createdTransports.length = 0
|
||||
})
|
||||
|
||||
describe("getOrCreateClient disconnect race", () => {
|
||||
it("#given pending connection for session A #when disconnectSession(A) is called before connection completes #then completed client is not added to state.clients", async () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-a")
|
||||
const clientKey = createClientKey(info)
|
||||
const pendingConnect = createDeferred<void>()
|
||||
pendingConnects.push(pendingConnect)
|
||||
|
||||
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
|
||||
expect(state.pendingConnections.has(clientKey)).toBe(true)
|
||||
|
||||
await disconnectSession(state, info.sessionID)
|
||||
pendingConnect.resolve(undefined)
|
||||
|
||||
await expect(clientPromise).rejects.toThrow(/disconnected during MCP connection setup/)
|
||||
expect(state.clients.has(clientKey)).toBe(false)
|
||||
expect(state.pendingConnections.has(clientKey)).toBe(false)
|
||||
expect(state.disconnectedSessions.has(info.sessionID)).toBe(false)
|
||||
expect(createdClients).toHaveLength(1)
|
||||
expect(createdClients[0]?.close).toHaveBeenCalledTimes(1)
|
||||
expect(createdTransports[0]?.close).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it("#given session A in disconnectedSessions #when new connection completes with no remaining pending #then disconnectedSessions entry is cleaned up", async () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-a")
|
||||
const clientKey = createClientKey(info)
|
||||
state.disconnectedSessions.set(info.sessionID, 1)
|
||||
|
||||
const client = await getOrCreateClient({ state, clientKey, info, config: stdioConfig })
|
||||
|
||||
expect(state.disconnectedSessions.has(info.sessionID)).toBe(false)
|
||||
expect(state.clients.get(clientKey)?.client).toBe(client)
|
||||
expect(createdClients[0]?.close).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it("#given no pending connections #when disconnectSession is called #then no errors occur and session is not added to disconnectedSessions", async () => {
|
||||
const state = createState()
|
||||
|
||||
await expect(disconnectSession(state, "session-a")).resolves.toBeUndefined()
|
||||
expect(state.disconnectedSessions.has("session-a")).toBe(false)
|
||||
expect(state.pendingConnections.size).toBe(0)
|
||||
expect(state.clients.size).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe("getOrCreateClient disconnectAll race", () => {
|
||||
it("#given pending connection #when disconnectAll() is called before connection completes #then client is not added to state.clients", async () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-a")
|
||||
const clientKey = createClientKey(info)
|
||||
const pendingConnect = createDeferred<void>()
|
||||
pendingConnects.push(pendingConnect)
|
||||
|
||||
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
|
||||
expect(state.pendingConnections.has(clientKey)).toBe(true)
|
||||
|
||||
await disconnectAll(state)
|
||||
pendingConnect.resolve(undefined)
|
||||
|
||||
await expect(clientPromise).rejects.toThrow(/connection completed after shutdown/)
|
||||
expect(state.clients.has(clientKey)).toBe(false)
|
||||
})
|
||||
|
||||
it("#given state after disconnectAll() completed #when getOrCreateClient() is called #then it throws shut down error and registers nothing", async () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-b")
|
||||
const clientKey = createClientKey(info)
|
||||
|
||||
await disconnectAll(state)
|
||||
|
||||
await expect(getOrCreateClient({ state, clientKey, info, config: stdioConfig })).rejects.toThrow(/has been shut down/)
|
||||
expect(state.clients.size).toBe(0)
|
||||
expect(state.pendingConnections.size).toBe(0)
|
||||
expect(state.inFlightConnections.size).toBe(0)
|
||||
expect(state.disposed).toBe(true)
|
||||
expect(createdClients).toHaveLength(0)
|
||||
expect(createdTransports).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe("getOrCreateClient multi-key disconnect race", () => {
|
||||
it("#given 2 pending connections for session A #when disconnectSession(A) before both complete #then both old connections are rejected", async () => {
|
||||
const state = createState()
|
||||
const infoKey1 = createClientInfo("session-a")
|
||||
const infoKey2 = { ...createClientInfo("session-a"), serverName: "server-2" }
|
||||
const clientKey1 = createClientKey(infoKey1)
|
||||
const clientKey2 = `${infoKey2.sessionID}:${infoKey2.skillName}:${infoKey2.serverName}`
|
||||
const pendingConnect1 = createDeferred<void>()
|
||||
const pendingConnect2 = createDeferred<void>()
|
||||
pendingConnects.push(pendingConnect1)
|
||||
pendingConnects.push(pendingConnect2)
|
||||
|
||||
const promise1 = getOrCreateClient({ state, clientKey: clientKey1, info: infoKey1, config: stdioConfig })
|
||||
const promise2 = getOrCreateClient({ state, clientKey: clientKey2, info: infoKey2, config: stdioConfig })
|
||||
expect(state.pendingConnections.size).toBe(2)
|
||||
|
||||
await disconnectSession(state, "session-a")
|
||||
|
||||
pendingConnect1.resolve(undefined)
|
||||
await expect(promise1).rejects.toThrow(/disconnected during MCP connection setup/)
|
||||
|
||||
pendingConnect2.resolve(undefined)
|
||||
await expect(promise2).rejects.toThrow(/disconnected during MCP connection setup/)
|
||||
|
||||
expect(state.clients.has(clientKey1)).toBe(false)
|
||||
expect(state.clients.has(clientKey2)).toBe(false)
|
||||
expect(state.disconnectedSessions.has("session-a")).toBe(false)
|
||||
})
|
||||
|
||||
it("#given a superseded pending connection #when the old connection completes #then the stale client is removed from state.clients", async () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-a")
|
||||
const clientKey = createClientKey(info)
|
||||
const pendingConnect = createDeferred<void>()
|
||||
const supersedingConnection = createDeferred<Awaited<ReturnType<typeof getOrCreateClient>>>()
|
||||
pendingConnects.push(pendingConnect)
|
||||
|
||||
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
|
||||
state.pendingConnections.set(clientKey, supersedingConnection.promise)
|
||||
|
||||
pendingConnect.resolve(undefined)
|
||||
|
||||
await expect(clientPromise).rejects.toThrow(/superseded by a newer connection attempt/)
|
||||
expect(state.clients.has(clientKey)).toBe(false)
|
||||
expect(createdClients[0]?.close).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it("#given a superseded pending connection #when a newer client already replaced the map entry #then the stale cleanup does not delete the newer client", async () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-a")
|
||||
const clientKey = createClientKey(info)
|
||||
const pendingConnect = createDeferred<void>()
|
||||
const supersedingConnection = createDeferred<Awaited<ReturnType<typeof getOrCreateClient>>>()
|
||||
pendingConnects.push(pendingConnect)
|
||||
|
||||
const newerClient = new MockClient(
|
||||
{ name: "newer-client", version: "1.0.0" },
|
||||
{ capabilities: {} },
|
||||
)
|
||||
const newerTransport = new MockStdioClientTransport({ command: "mock-mcp-server" })
|
||||
let replacedEntry = false
|
||||
const originalSet = state.clients.set.bind(state.clients)
|
||||
Reflect.set(state.clients, "set", (key: string, value: SkillMcpManagerState["clients"] extends Map<string, infer TValue> ? TValue : never) => {
|
||||
originalSet(key, value)
|
||||
if (!replacedEntry && key === clientKey) {
|
||||
replacedEntry = true
|
||||
originalSet(key, {
|
||||
client: newerClient as never,
|
||||
transport: newerTransport as never,
|
||||
skillName: info.skillName,
|
||||
lastUsedAt: Date.now(),
|
||||
connectionType: "stdio",
|
||||
})
|
||||
}
|
||||
return state.clients
|
||||
})
|
||||
|
||||
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
|
||||
state.pendingConnections.set(clientKey, supersedingConnection.promise)
|
||||
|
||||
pendingConnect.resolve(undefined)
|
||||
|
||||
await expect(clientPromise).rejects.toThrow(/superseded by a newer connection attempt/)
|
||||
expect(state.clients.get(clientKey)?.client.close).toBe(newerClient.close)
|
||||
expect(newerClient.close).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -7,6 +7,13 @@ import { createHttpClient } from "./http-client"
|
||||
import { createStdioClient } from "./stdio-client"
|
||||
import type { SkillMcpClientConnectionParams, SkillMcpClientInfo, SkillMcpManagerState } from "./types"
|
||||
|
||||
function removeClientIfCurrent(state: SkillMcpManagerState, clientKey: string, client: Client): void {
|
||||
const managed = state.clients.get(clientKey)
|
||||
if (managed?.client === client) {
|
||||
state.clients.delete(clientKey)
|
||||
}
|
||||
}
|
||||
|
||||
export async function getOrCreateClient(params: {
|
||||
state: SkillMcpManagerState
|
||||
clientKey: string
|
||||
@@ -15,6 +22,10 @@ export async function getOrCreateClient(params: {
|
||||
}): Promise<Client> {
|
||||
const { state, clientKey, info, config } = params
|
||||
|
||||
if (state.disposed) {
|
||||
throw new Error(`MCP manager for "${info.sessionID}" has been shut down, cannot create new connections.`)
|
||||
}
|
||||
|
||||
const existing = state.clients.get(clientKey)
|
||||
if (existing) {
|
||||
existing.lastUsedAt = Date.now()
|
||||
@@ -28,14 +39,52 @@ export async function getOrCreateClient(params: {
|
||||
}
|
||||
|
||||
const expandedConfig = expandEnvVarsInObject(config)
|
||||
const connectionPromise = createClient({ state, clientKey, info, config: expandedConfig })
|
||||
state.pendingConnections.set(clientKey, connectionPromise)
|
||||
let currentConnectionPromise!: Promise<Client>
|
||||
state.inFlightConnections.set(info.sessionID, (state.inFlightConnections.get(info.sessionID) ?? 0) + 1)
|
||||
currentConnectionPromise = (async () => {
|
||||
const disconnectGenAtStart = state.disconnectedSessions.get(info.sessionID) ?? 0
|
||||
const shutdownGenAtStart = state.shutdownGeneration
|
||||
|
||||
const client = await createClient({ state, clientKey, info, config: expandedConfig })
|
||||
|
||||
const isStale = state.pendingConnections.has(clientKey) && state.pendingConnections.get(clientKey) !== currentConnectionPromise
|
||||
if (isStale) {
|
||||
removeClientIfCurrent(state, clientKey, client)
|
||||
try { await client.close() } catch {}
|
||||
throw new Error(`Connection for "${info.sessionID}" was superseded by a newer connection attempt.`)
|
||||
}
|
||||
|
||||
if (state.shutdownGeneration !== shutdownGenAtStart) {
|
||||
removeClientIfCurrent(state, clientKey, client)
|
||||
try { await client.close() } catch {}
|
||||
throw new Error(`Shutdown occurred during MCP connection for "${info.sessionID}"`)
|
||||
}
|
||||
|
||||
const currentDisconnectGen = state.disconnectedSessions.get(info.sessionID) ?? 0
|
||||
if (currentDisconnectGen > disconnectGenAtStart) {
|
||||
await forceReconnect(state, clientKey)
|
||||
throw new Error(`Session "${info.sessionID}" disconnected during MCP connection setup.`)
|
||||
}
|
||||
|
||||
return client
|
||||
})()
|
||||
|
||||
state.pendingConnections.set(clientKey, currentConnectionPromise)
|
||||
|
||||
try {
|
||||
const client = await connectionPromise
|
||||
const client = await currentConnectionPromise
|
||||
return client
|
||||
} finally {
|
||||
state.pendingConnections.delete(clientKey)
|
||||
if (state.pendingConnections.get(clientKey) === currentConnectionPromise) {
|
||||
state.pendingConnections.delete(clientKey)
|
||||
}
|
||||
const remaining = (state.inFlightConnections.get(info.sessionID) ?? 1) - 1
|
||||
if (remaining <= 0) {
|
||||
state.inFlightConnections.delete(info.sessionID)
|
||||
state.disconnectedSessions.delete(info.sessionID)
|
||||
} else {
|
||||
state.inFlightConnections.set(info.sessionID, remaining)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
133
src/features/skill-mcp-manager/disconnect-cleanup.test.ts
Normal file
133
src/features/skill-mcp-manager/disconnect-cleanup.test.ts
Normal file
@@ -0,0 +1,133 @@
|
||||
import { Client } from "@modelcontextprotocol/sdk/client/index.js"
|
||||
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
|
||||
import { afterEach, describe, expect, it } from "bun:test"
|
||||
import { disconnectSession, registerProcessCleanup, unregisterProcessCleanup } from "./cleanup"
|
||||
import type { ManagedClient, SkillMcpManagerState } from "./types"
|
||||
|
||||
const trackedStates: SkillMcpManagerState[] = []
|
||||
|
||||
afterEach(() => {
|
||||
for (const state of trackedStates) {
|
||||
unregisterProcessCleanup(state)
|
||||
}
|
||||
|
||||
trackedStates.length = 0
|
||||
})
|
||||
|
||||
const expectedCleanupHandlerCount = process.platform === "win32" ? 3 : 2
|
||||
|
||||
function createState(): SkillMcpManagerState {
|
||||
const state: SkillMcpManagerState = {
|
||||
clients: new Map(),
|
||||
pendingConnections: new Map(),
|
||||
disconnectedSessions: new Map(),
|
||||
authProviders: new Map(),
|
||||
cleanupRegistered: false,
|
||||
cleanupInterval: null,
|
||||
cleanupHandlers: [],
|
||||
idleTimeoutMs: 5 * 60 * 1000,
|
||||
shutdownGeneration: 0,
|
||||
inFlightConnections: new Map(),
|
||||
disposed: false,
|
||||
}
|
||||
|
||||
trackedStates.push(state)
|
||||
return state
|
||||
}
|
||||
|
||||
function createManagedClient(skillName: string): ManagedClient {
|
||||
return {
|
||||
client: new Client(
|
||||
{ name: `test-${skillName}`, version: "1.0.0" },
|
||||
{ capabilities: {} }
|
||||
),
|
||||
transport: new StreamableHTTPClientTransport(new URL("https://example.com/mcp")),
|
||||
skillName,
|
||||
lastUsedAt: Date.now(),
|
||||
connectionType: "http",
|
||||
}
|
||||
}
|
||||
|
||||
describe("disconnectSession cleanup registration", () => {
|
||||
it("#given state with 1 client and cleanup registered #when disconnectSession removes last client #then process cleanup handlers are unregistered", async () => {
|
||||
// given
|
||||
const state = createState()
|
||||
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
|
||||
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
|
||||
|
||||
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
|
||||
registerProcessCleanup(state)
|
||||
|
||||
// when
|
||||
await disconnectSession(state, "session-1")
|
||||
|
||||
// then
|
||||
expect(state.cleanupRegistered).toBe(false)
|
||||
expect(state.cleanupHandlers).toEqual([])
|
||||
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister)
|
||||
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister)
|
||||
})
|
||||
|
||||
it("#given state with 2 clients in different sessions #when disconnectSession removes one session #then process cleanup handlers remain registered", async () => {
|
||||
// given
|
||||
const state = createState()
|
||||
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
|
||||
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
|
||||
|
||||
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
|
||||
state.clients.set("session-2:skill-2:server-2", createManagedClient("skill-2"))
|
||||
registerProcessCleanup(state)
|
||||
|
||||
// when
|
||||
await disconnectSession(state, "session-1")
|
||||
|
||||
// then
|
||||
expect(state.clients.has("session-2:skill-2:server-2")).toBe(true)
|
||||
expect(state.cleanupRegistered).toBe(true)
|
||||
expect(state.cleanupHandlers).toHaveLength(expectedCleanupHandlerCount)
|
||||
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister + 1)
|
||||
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister + 1)
|
||||
})
|
||||
|
||||
it("#given state with 2 clients in different sessions #when both sessions disconnected #then process cleanup handlers are unregistered", async () => {
|
||||
// given
|
||||
const state = createState()
|
||||
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
|
||||
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
|
||||
|
||||
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
|
||||
state.clients.set("session-2:skill-2:server-2", createManagedClient("skill-2"))
|
||||
registerProcessCleanup(state)
|
||||
|
||||
// when
|
||||
await disconnectSession(state, "session-1")
|
||||
await disconnectSession(state, "session-2")
|
||||
|
||||
// then
|
||||
expect(state.clients.size).toBe(0)
|
||||
expect(state.cleanupRegistered).toBe(false)
|
||||
expect(state.cleanupHandlers).toEqual([])
|
||||
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister)
|
||||
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister)
|
||||
})
|
||||
|
||||
it("#given state with 1 client and pending connection for different session and cleanup registered #when disconnectSession removes last client but pendingConnections remain #then process cleanup handlers stay registered", async () => {
|
||||
const state = createState()
|
||||
const signalIntCountBeforeRegister = process.listenerCount("SIGINT")
|
||||
const signalTermCountBeforeRegister = process.listenerCount("SIGTERM")
|
||||
const pendingClient = createManagedClient("skill-pending").client
|
||||
|
||||
state.clients.set("session-1:skill-1:server-1", createManagedClient("skill-1"))
|
||||
state.pendingConnections.set("session-2:skill-2:server-2", Promise.resolve(pendingClient))
|
||||
registerProcessCleanup(state)
|
||||
|
||||
await disconnectSession(state, "session-1")
|
||||
|
||||
expect(state.clients.size).toBe(0)
|
||||
expect(state.pendingConnections.size).toBe(1)
|
||||
expect(state.cleanupRegistered).toBe(true)
|
||||
expect(state.cleanupHandlers).toHaveLength(expectedCleanupHandlerCount)
|
||||
expect(process.listenerCount("SIGINT")).toBe(signalIntCountBeforeRegister + 1)
|
||||
expect(process.listenerCount("SIGTERM")).toBe(signalTermCountBeforeRegister + 1)
|
||||
})
|
||||
})
|
||||
@@ -24,6 +24,7 @@ function redactUrl(urlStr: string): string {
|
||||
|
||||
export async function createHttpClient(params: SkillMcpClientConnectionParams): Promise<Client> {
|
||||
const { state, clientKey, info, config } = params
|
||||
const shutdownGenAtStart = state.shutdownGeneration
|
||||
|
||||
if (!config.url) {
|
||||
throw new Error(`MCP server "${info.serverName}" is configured for HTTP but missing 'url' field.`)
|
||||
@@ -72,6 +73,12 @@ export async function createHttpClient(params: SkillMcpClientConnectionParams):
|
||||
)
|
||||
}
|
||||
|
||||
if (state.shutdownGeneration !== shutdownGenAtStart) {
|
||||
try { await client.close() } catch {}
|
||||
try { await transport.close() } catch {}
|
||||
throw new Error(`MCP server "${info.serverName}" connection completed after shutdown`)
|
||||
}
|
||||
|
||||
const managedClient = {
|
||||
client,
|
||||
transport,
|
||||
|
||||
@@ -10,11 +10,15 @@ export class SkillMcpManager {
|
||||
private readonly state: SkillMcpManagerState = {
|
||||
clients: new Map(),
|
||||
pendingConnections: new Map(),
|
||||
disconnectedSessions: new Map(),
|
||||
authProviders: new Map(),
|
||||
cleanupRegistered: false,
|
||||
cleanupInterval: null,
|
||||
cleanupHandlers: [],
|
||||
idleTimeoutMs: 5 * 60 * 1000,
|
||||
shutdownGeneration: 0,
|
||||
inFlightConnections: new Map(),
|
||||
disposed: false,
|
||||
}
|
||||
|
||||
private getClientKey(info: SkillMcpClientInfo): string {
|
||||
|
||||
@@ -14,6 +14,7 @@ function getStdioCommand(config: ClaudeCodeMcpServer, serverName: string): strin
|
||||
|
||||
export async function createStdioClient(params: SkillMcpClientConnectionParams): Promise<Client> {
|
||||
const { state, clientKey, info, config } = params
|
||||
const shutdownGenAtStart = state.shutdownGeneration
|
||||
|
||||
const command = getStdioCommand(config, info.serverName)
|
||||
const args = config.args ?? []
|
||||
@@ -55,6 +56,12 @@ export async function createStdioClient(params: SkillMcpClientConnectionParams):
|
||||
)
|
||||
}
|
||||
|
||||
if (state.shutdownGeneration !== shutdownGenAtStart) {
|
||||
try { await client.close() } catch {}
|
||||
try { await transport.close() } catch {}
|
||||
throw new Error(`MCP server "${info.serverName}" connection completed after shutdown`)
|
||||
}
|
||||
|
||||
const managedClient = {
|
||||
client,
|
||||
transport,
|
||||
|
||||
@@ -51,11 +51,15 @@ export interface ProcessCleanupHandler {
|
||||
export interface SkillMcpManagerState {
|
||||
clients: Map<string, ManagedClient>
|
||||
pendingConnections: Map<string, Promise<Client>>
|
||||
disconnectedSessions: Map<string, number>
|
||||
authProviders: Map<string, McpOAuthProvider>
|
||||
cleanupRegistered: boolean
|
||||
cleanupInterval: ReturnType<typeof setInterval> | null
|
||||
cleanupHandlers: ProcessCleanupHandler[]
|
||||
idleTimeoutMs: number
|
||||
shutdownGeneration: number
|
||||
inFlightConnections: Map<string, number>
|
||||
disposed: boolean
|
||||
}
|
||||
|
||||
export interface SkillMcpClientConnectionParams {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import type { TmuxConfig } from "../../config/schema"
|
||||
import type { TrackedSession, CapacityConfig } from "./types"
|
||||
import type { TrackedSession, CapacityConfig, WindowState } from "./types"
|
||||
import { log, normalizeSDKResponse } from "../../shared"
|
||||
import {
|
||||
isInsideTmux as defaultIsInsideTmux,
|
||||
@@ -13,6 +13,7 @@ import { queryWindowState } from "./pane-state-querier"
|
||||
import { decideSpawnActions, decideCloseAction, type SessionMapping } from "./decision-engine"
|
||||
import { executeActions, executeAction } from "./action-executor"
|
||||
import { TmuxPollingManager } from "./polling-manager"
|
||||
import { createTrackedSession, markTrackedSessionClosePending } from "./tracked-session-state"
|
||||
type OpencodeClient = PluginInput["client"]
|
||||
|
||||
interface SessionCreatedEvent {
|
||||
@@ -38,6 +39,7 @@ const defaultTmuxDeps: TmuxUtilDeps = {
|
||||
|
||||
const DEFERRED_SESSION_TTL_MS = 5 * 60 * 1000
|
||||
const MAX_DEFERRED_QUEUE_SIZE = 20
|
||||
const MAX_CLOSE_RETRY_COUNT = 3
|
||||
|
||||
/**
|
||||
* State-first Tmux Session Manager
|
||||
@@ -106,6 +108,123 @@ export class TmuxSessionManager {
|
||||
}))
|
||||
}
|
||||
|
||||
private removeTrackedSession(sessionId: string): void {
|
||||
this.sessions.delete(sessionId)
|
||||
|
||||
if (this.sessions.size === 0) {
|
||||
this.pollingManager.stopPolling()
|
||||
}
|
||||
}
|
||||
|
||||
private markSessionClosePending(sessionId: string): void {
|
||||
const tracked = this.sessions.get(sessionId)
|
||||
if (!tracked) return
|
||||
|
||||
this.sessions.set(sessionId, markTrackedSessionClosePending(tracked))
|
||||
log("[tmux-session-manager] marked session close pending", {
|
||||
sessionId,
|
||||
paneId: tracked.paneId,
|
||||
closeRetryCount: tracked.closeRetryCount,
|
||||
})
|
||||
}
|
||||
|
||||
private async queryWindowStateSafely(): Promise<WindowState | null> {
|
||||
if (!this.sourcePaneId) return null
|
||||
|
||||
try {
|
||||
return await queryWindowState(this.sourcePaneId)
|
||||
} catch (error) {
|
||||
log("[tmux-session-manager] failed to query window state for close", {
|
||||
error: String(error),
|
||||
})
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
private async tryCloseTrackedSession(tracked: TrackedSession): Promise<boolean> {
|
||||
const state = await this.queryWindowStateSafely()
|
||||
if (!state) return false
|
||||
|
||||
try {
|
||||
const result = await executeAction(
|
||||
{ type: "close", paneId: tracked.paneId, sessionId: tracked.sessionId },
|
||||
{
|
||||
config: this.tmuxConfig,
|
||||
serverUrl: this.serverUrl,
|
||||
windowState: state,
|
||||
sourcePaneId: this.sourcePaneId,
|
||||
}
|
||||
)
|
||||
|
||||
return result.success
|
||||
} catch (error) {
|
||||
log("[tmux-session-manager] close session pane failed", {
|
||||
sessionId: tracked.sessionId,
|
||||
paneId: tracked.paneId,
|
||||
error: String(error),
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private async retryPendingCloses(): Promise<void> {
|
||||
const pendingSessions = Array.from(this.sessions.values()).filter(
|
||||
(tracked) => tracked.closePending,
|
||||
)
|
||||
|
||||
for (const tracked of pendingSessions) {
|
||||
if (!this.sessions.has(tracked.sessionId)) continue
|
||||
|
||||
if (tracked.closeRetryCount >= MAX_CLOSE_RETRY_COUNT) {
|
||||
log("[tmux-session-manager] force removing close-pending session after max retries", {
|
||||
sessionId: tracked.sessionId,
|
||||
paneId: tracked.paneId,
|
||||
closeRetryCount: tracked.closeRetryCount,
|
||||
})
|
||||
this.removeTrackedSession(tracked.sessionId)
|
||||
continue
|
||||
}
|
||||
|
||||
const closed = await this.tryCloseTrackedSession(tracked)
|
||||
if (closed) {
|
||||
log("[tmux-session-manager] retried close succeeded", {
|
||||
sessionId: tracked.sessionId,
|
||||
paneId: tracked.paneId,
|
||||
closeRetryCount: tracked.closeRetryCount,
|
||||
})
|
||||
this.removeTrackedSession(tracked.sessionId)
|
||||
continue
|
||||
}
|
||||
|
||||
const currentTracked = this.sessions.get(tracked.sessionId)
|
||||
if (!currentTracked || !currentTracked.closePending) {
|
||||
continue
|
||||
}
|
||||
|
||||
const nextRetryCount = currentTracked.closeRetryCount + 1
|
||||
if (nextRetryCount >= MAX_CLOSE_RETRY_COUNT) {
|
||||
log("[tmux-session-manager] force removing close-pending session after failed retry", {
|
||||
sessionId: currentTracked.sessionId,
|
||||
paneId: currentTracked.paneId,
|
||||
closeRetryCount: nextRetryCount,
|
||||
})
|
||||
this.removeTrackedSession(currentTracked.sessionId)
|
||||
continue
|
||||
}
|
||||
|
||||
this.sessions.set(currentTracked.sessionId, {
|
||||
...currentTracked,
|
||||
closePending: true,
|
||||
closeRetryCount: nextRetryCount,
|
||||
})
|
||||
log("[tmux-session-manager] retried close failed", {
|
||||
sessionId: currentTracked.sessionId,
|
||||
paneId: currentTracked.paneId,
|
||||
closeRetryCount: nextRetryCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
private enqueueDeferredSession(sessionId: string, title: string): void {
|
||||
if (this.deferredSessions.has(sessionId)) return
|
||||
if (this.deferredQueue.length >= MAX_DEFERRED_QUEUE_SIZE) {
|
||||
@@ -257,14 +376,14 @@ export class TmuxSessionManager {
|
||||
})
|
||||
}
|
||||
|
||||
const now = Date.now()
|
||||
this.sessions.set(sessionId, {
|
||||
this.sessions.set(
|
||||
sessionId,
|
||||
paneId: result.spawnedPaneId,
|
||||
description: deferred.title,
|
||||
createdAt: new Date(now),
|
||||
lastSeenAt: new Date(now),
|
||||
})
|
||||
createTrackedSession({
|
||||
sessionId,
|
||||
paneId: result.spawnedPaneId,
|
||||
description: deferred.title,
|
||||
}),
|
||||
)
|
||||
this.removeDeferredSession(sessionId)
|
||||
this.pollingManager.startPolling()
|
||||
log("[tmux-session-manager] deferred session attached", {
|
||||
@@ -324,6 +443,13 @@ export class TmuxSessionManager {
|
||||
const sessionId = info.id
|
||||
const title = info.title ?? "Subagent"
|
||||
|
||||
if (!this.sourcePaneId) {
|
||||
log("[tmux-session-manager] no source pane id")
|
||||
return
|
||||
}
|
||||
|
||||
await this.retryPendingCloses()
|
||||
|
||||
if (
|
||||
this.sessions.has(sessionId) ||
|
||||
this.pendingSessions.has(sessionId) ||
|
||||
@@ -332,11 +458,6 @@ export class TmuxSessionManager {
|
||||
log("[tmux-session-manager] session already tracked or pending", { sessionId })
|
||||
return
|
||||
}
|
||||
|
||||
if (!this.sourcePaneId) {
|
||||
log("[tmux-session-manager] no source pane id")
|
||||
return
|
||||
}
|
||||
const sourcePaneId = this.sourcePaneId
|
||||
|
||||
this.pendingSessions.add(sessionId)
|
||||
@@ -418,14 +539,14 @@ export class TmuxSessionManager {
|
||||
})
|
||||
}
|
||||
|
||||
const now = Date.now()
|
||||
this.sessions.set(sessionId, {
|
||||
this.sessions.set(
|
||||
sessionId,
|
||||
paneId: result.spawnedPaneId,
|
||||
description: title,
|
||||
createdAt: new Date(now),
|
||||
lastSeenAt: new Date(now),
|
||||
})
|
||||
createTrackedSession({
|
||||
sessionId,
|
||||
paneId: result.spawnedPaneId,
|
||||
description: title,
|
||||
}),
|
||||
)
|
||||
log("[tmux-session-manager] pane spawned and tracked", {
|
||||
sessionId,
|
||||
paneId: result.spawnedPaneId,
|
||||
@@ -485,27 +606,40 @@ export class TmuxSessionManager {
|
||||
|
||||
log("[tmux-session-manager] onSessionDeleted", { sessionId: event.sessionID })
|
||||
|
||||
const state = await queryWindowState(this.sourcePaneId)
|
||||
const state = await this.queryWindowStateSafely()
|
||||
if (!state) {
|
||||
this.sessions.delete(event.sessionID)
|
||||
this.markSessionClosePending(event.sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
const closeAction = decideCloseAction(state, event.sessionID, this.getSessionMappings())
|
||||
if (closeAction) {
|
||||
await executeAction(closeAction, {
|
||||
if (!closeAction) {
|
||||
this.removeTrackedSession(event.sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await executeAction(closeAction, {
|
||||
config: this.tmuxConfig,
|
||||
serverUrl: this.serverUrl,
|
||||
windowState: state,
|
||||
sourcePaneId: this.sourcePaneId,
|
||||
})
|
||||
|
||||
if (!result.success) {
|
||||
this.markSessionClosePending(event.sessionID)
|
||||
return
|
||||
}
|
||||
} catch (error) {
|
||||
log("[tmux-session-manager] failed to close pane for deleted session", {
|
||||
sessionId: event.sessionID,
|
||||
error: String(error),
|
||||
})
|
||||
this.markSessionClosePending(event.sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
this.sessions.delete(event.sessionID)
|
||||
|
||||
if (this.sessions.size === 0) {
|
||||
this.pollingManager.stopPolling()
|
||||
}
|
||||
this.removeTrackedSession(event.sessionID)
|
||||
}
|
||||
|
||||
|
||||
@@ -513,29 +647,28 @@ export class TmuxSessionManager {
|
||||
const tracked = this.sessions.get(sessionId)
|
||||
if (!tracked) return
|
||||
|
||||
if (tracked.closePending && tracked.closeRetryCount >= MAX_CLOSE_RETRY_COUNT) {
|
||||
log("[tmux-session-manager] force removing close-pending session after max retries", {
|
||||
sessionId,
|
||||
paneId: tracked.paneId,
|
||||
closeRetryCount: tracked.closeRetryCount,
|
||||
})
|
||||
this.removeTrackedSession(sessionId)
|
||||
return
|
||||
}
|
||||
|
||||
log("[tmux-session-manager] closing session pane", {
|
||||
sessionId,
|
||||
paneId: tracked.paneId,
|
||||
})
|
||||
|
||||
const state = this.sourcePaneId ? await queryWindowState(this.sourcePaneId) : null
|
||||
if (state) {
|
||||
await executeAction(
|
||||
{ type: "close", paneId: tracked.paneId, sessionId },
|
||||
{
|
||||
config: this.tmuxConfig,
|
||||
serverUrl: this.serverUrl,
|
||||
windowState: state,
|
||||
sourcePaneId: this.sourcePaneId,
|
||||
}
|
||||
)
|
||||
const closed = await this.tryCloseTrackedSession(tracked)
|
||||
if (!closed) {
|
||||
this.markSessionClosePending(sessionId)
|
||||
return
|
||||
}
|
||||
|
||||
this.sessions.delete(sessionId)
|
||||
|
||||
if (this.sessions.size === 0) {
|
||||
this.pollingManager.stopPolling()
|
||||
}
|
||||
this.removeTrackedSession(sessionId)
|
||||
}
|
||||
|
||||
createEventHandler(): (input: { event: { type: string; properties?: unknown } }) => Promise<void> {
|
||||
@@ -552,30 +685,22 @@ export class TmuxSessionManager {
|
||||
|
||||
if (this.sessions.size > 0) {
|
||||
log("[tmux-session-manager] closing all panes", { count: this.sessions.size })
|
||||
const state = this.sourcePaneId ? await queryWindowState(this.sourcePaneId) : null
|
||||
|
||||
if (state) {
|
||||
const closePromises = Array.from(this.sessions.values()).map((s) =>
|
||||
executeAction(
|
||||
{ type: "close", paneId: s.paneId, sessionId: s.sessionId },
|
||||
{
|
||||
config: this.tmuxConfig,
|
||||
serverUrl: this.serverUrl,
|
||||
windowState: state,
|
||||
sourcePaneId: this.sourcePaneId,
|
||||
}
|
||||
).catch((err) =>
|
||||
log("[tmux-session-manager] cleanup error for pane", {
|
||||
paneId: s.paneId,
|
||||
error: String(err),
|
||||
}),
|
||||
),
|
||||
)
|
||||
await Promise.all(closePromises)
|
||||
|
||||
const sessionIds = Array.from(this.sessions.keys())
|
||||
for (const sessionId of sessionIds) {
|
||||
try {
|
||||
await this.closeSessionById(sessionId)
|
||||
} catch (error) {
|
||||
log("[tmux-session-manager] cleanup error for pane", {
|
||||
sessionId,
|
||||
error: String(error),
|
||||
})
|
||||
}
|
||||
}
|
||||
this.sessions.clear()
|
||||
}
|
||||
|
||||
await this.retryPendingCloses()
|
||||
|
||||
log("[tmux-session-manager] cleanup complete")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ describe("TmuxPollingManager overlap", () => {
|
||||
description: "test",
|
||||
createdAt: new Date(),
|
||||
lastSeenAt: new Date(),
|
||||
closePending: false,
|
||||
closeRetryCount: 0,
|
||||
})
|
||||
|
||||
let activeCalls = 0
|
||||
|
||||
@@ -6,6 +6,7 @@ import { queryWindowState } from "./pane-state-querier"
|
||||
import { decideSpawnActions, type SessionMapping } from "./decision-engine"
|
||||
import { executeActions } from "./action-executor"
|
||||
import type { SessionCreatedEvent } from "./session-created-event"
|
||||
import { createTrackedSession } from "./tracked-session-state"
|
||||
|
||||
type OpencodeClient = PluginInput["client"]
|
||||
|
||||
@@ -152,14 +153,14 @@ export async function handleSessionCreated(
|
||||
return
|
||||
}
|
||||
|
||||
const now = Date.now()
|
||||
deps.sessions.set(sessionId, {
|
||||
deps.sessions.set(
|
||||
sessionId,
|
||||
paneId: result.spawnedPaneId,
|
||||
description: title,
|
||||
createdAt: new Date(now),
|
||||
lastSeenAt: new Date(now),
|
||||
})
|
||||
createTrackedSession({
|
||||
sessionId,
|
||||
paneId: result.spawnedPaneId,
|
||||
description: title,
|
||||
}),
|
||||
)
|
||||
|
||||
log("[tmux-session-manager] pane spawned and tracked", {
|
||||
sessionId,
|
||||
|
||||
28
src/features/tmux-subagent/tracked-session-state.ts
Normal file
28
src/features/tmux-subagent/tracked-session-state.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import type { TrackedSession } from "./types"
|
||||
|
||||
export function createTrackedSession(params: {
|
||||
sessionId: string
|
||||
paneId: string
|
||||
description: string
|
||||
now?: Date
|
||||
}): TrackedSession {
|
||||
const now = params.now ?? new Date()
|
||||
|
||||
return {
|
||||
sessionId: params.sessionId,
|
||||
paneId: params.paneId,
|
||||
description: params.description,
|
||||
createdAt: now,
|
||||
lastSeenAt: now,
|
||||
closePending: false,
|
||||
closeRetryCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
export function markTrackedSessionClosePending(tracked: TrackedSession): TrackedSession {
|
||||
return {
|
||||
...tracked,
|
||||
closePending: true,
|
||||
closeRetryCount: tracked.closePending ? tracked.closeRetryCount + 1 : tracked.closeRetryCount,
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,8 @@ export interface TrackedSession {
|
||||
description: string
|
||||
createdAt: Date
|
||||
lastSeenAt: Date
|
||||
closePending: boolean
|
||||
closeRetryCount: number
|
||||
// Stability detection fields (prevents premature closure)
|
||||
lastMessageCount?: number
|
||||
stableIdlePolls?: number
|
||||
|
||||
271
src/features/tmux-subagent/zombie-pane.test.ts
Normal file
271
src/features/tmux-subagent/zombie-pane.test.ts
Normal file
@@ -0,0 +1,271 @@
|
||||
import { beforeEach, describe, expect, mock, test } from "bun:test"
|
||||
import type { TmuxConfig } from "../../config/schema"
|
||||
import type { ActionResult, ExecuteContext, ExecuteActionsResult } from "./action-executor"
|
||||
import type { TmuxUtilDeps } from "./manager"
|
||||
import type { TrackedSession, WindowState } from "./types"
|
||||
|
||||
const mockQueryWindowState = mock<(paneId: string) => Promise<WindowState | null>>(async () => ({
|
||||
windowWidth: 220,
|
||||
windowHeight: 44,
|
||||
mainPane: { paneId: "%0", width: 110, height: 44, left: 0, top: 0, title: "main", isActive: true },
|
||||
agentPanes: [],
|
||||
}))
|
||||
|
||||
const mockExecuteAction = mock<(
|
||||
action: { type: string },
|
||||
ctx: ExecuteContext,
|
||||
) => Promise<ActionResult>>(async () => ({ success: true }))
|
||||
|
||||
const mockExecuteActions = mock<(
|
||||
actions: unknown[],
|
||||
ctx: ExecuteContext,
|
||||
) => Promise<ExecuteActionsResult>>(async () => ({
|
||||
success: true,
|
||||
spawnedPaneId: "%1",
|
||||
results: [],
|
||||
}))
|
||||
|
||||
const mockIsInsideTmux = mock<() => boolean>(() => true)
|
||||
const mockGetCurrentPaneId = mock<() => string | undefined>(() => "%0")
|
||||
|
||||
mock.module("./pane-state-querier", () => ({
|
||||
queryWindowState: mockQueryWindowState,
|
||||
}))
|
||||
|
||||
mock.module("./action-executor", () => ({
|
||||
executeAction: mockExecuteAction,
|
||||
executeActions: mockExecuteActions,
|
||||
}))
|
||||
|
||||
mock.module("../../shared/tmux", () => ({
|
||||
isInsideTmux: mockIsInsideTmux,
|
||||
getCurrentPaneId: mockGetCurrentPaneId,
|
||||
POLL_INTERVAL_BACKGROUND_MS: 10,
|
||||
SESSION_READY_POLL_INTERVAL_MS: 10,
|
||||
SESSION_READY_TIMEOUT_MS: 50,
|
||||
SESSION_MISSING_GRACE_MS: 1_000,
|
||||
}))
|
||||
|
||||
const mockTmuxDeps: TmuxUtilDeps = {
|
||||
isInsideTmux: mockIsInsideTmux,
|
||||
getCurrentPaneId: mockGetCurrentPaneId,
|
||||
}
|
||||
|
||||
function createConfig(): TmuxConfig {
|
||||
return {
|
||||
enabled: true,
|
||||
layout: "main-vertical",
|
||||
main_pane_size: 60,
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
}
|
||||
|
||||
function createContext() {
|
||||
const shell = Object.assign(
|
||||
() => {
|
||||
throw new Error("shell should not be called in this test")
|
||||
},
|
||||
{
|
||||
braces: () => [],
|
||||
escape: (input: string) => input,
|
||||
env() {
|
||||
return shell
|
||||
},
|
||||
cwd() {
|
||||
return shell
|
||||
},
|
||||
nothrow() {
|
||||
return shell
|
||||
},
|
||||
throws() {
|
||||
return shell
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
project: {
|
||||
id: "project-id",
|
||||
worktree: "/tmp/omo-fix-memory-leaks",
|
||||
time: { created: Date.now() },
|
||||
},
|
||||
directory: "/tmp/omo-fix-memory-leaks",
|
||||
worktree: "/tmp/omo-fix-memory-leaks",
|
||||
serverUrl: new URL("http://localhost:4096"),
|
||||
$: shell,
|
||||
client: {
|
||||
session: {
|
||||
status: mock(async () => ({ data: {} })),
|
||||
messages: mock(async () => ({ data: [] })),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
function createTrackedSession(overrides?: Partial<TrackedSession>): TrackedSession {
|
||||
return {
|
||||
sessionId: "ses_pending",
|
||||
paneId: "%1",
|
||||
description: "Pending pane",
|
||||
createdAt: new Date(),
|
||||
lastSeenAt: new Date(),
|
||||
closePending: false,
|
||||
closeRetryCount: 0,
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
function getTrackedSessions(target: object): Map<string, TrackedSession> {
|
||||
const sessions = Reflect.get(target, "sessions")
|
||||
if (!(sessions instanceof Map)) {
|
||||
throw new Error("Expected sessions map")
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
function getRetryPendingCloses(target: object): () => Promise<void> {
|
||||
const retryPendingCloses = Reflect.get(target, "retryPendingCloses")
|
||||
if (typeof retryPendingCloses !== "function") {
|
||||
throw new Error("Expected retryPendingCloses method")
|
||||
}
|
||||
|
||||
return retryPendingCloses.bind(target)
|
||||
}
|
||||
|
||||
function getCloseSessionById(target: object): (sessionId: string) => Promise<void> {
|
||||
const closeSessionById = Reflect.get(target, "closeSessionById")
|
||||
if (typeof closeSessionById !== "function") {
|
||||
throw new Error("Expected closeSessionById method")
|
||||
}
|
||||
|
||||
return closeSessionById.bind(target)
|
||||
}
|
||||
|
||||
function createManager(
|
||||
TmuxSessionManager: typeof import("./manager").TmuxSessionManager,
|
||||
): import("./manager").TmuxSessionManager {
|
||||
return Reflect.construct(TmuxSessionManager, [createContext(), createConfig(), mockTmuxDeps])
|
||||
}
|
||||
|
||||
describe("TmuxSessionManager zombie pane handling", () => {
|
||||
beforeEach(() => {
|
||||
mockQueryWindowState.mockClear()
|
||||
mockExecuteAction.mockClear()
|
||||
mockExecuteActions.mockClear()
|
||||
mockIsInsideTmux.mockClear()
|
||||
mockGetCurrentPaneId.mockClear()
|
||||
|
||||
mockQueryWindowState.mockImplementation(async () => ({
|
||||
windowWidth: 220,
|
||||
windowHeight: 44,
|
||||
mainPane: { paneId: "%0", width: 110, height: 44, left: 0, top: 0, title: "main", isActive: true },
|
||||
agentPanes: [],
|
||||
}))
|
||||
mockExecuteAction.mockImplementation(async () => ({ success: true }))
|
||||
mockExecuteActions.mockImplementation(async () => ({
|
||||
success: true,
|
||||
spawnedPaneId: "%1",
|
||||
results: [],
|
||||
}))
|
||||
mockIsInsideTmux.mockReturnValue(true)
|
||||
mockGetCurrentPaneId.mockReturnValue("%0")
|
||||
})
|
||||
|
||||
test("#given session in sessions Map #when onSessionDeleted called with null window state #then session stays in Map with closePending true", async () => {
|
||||
// given
|
||||
mockQueryWindowState.mockImplementation(async () => null)
|
||||
const { TmuxSessionManager } = await import("./manager")
|
||||
const manager = createManager(TmuxSessionManager)
|
||||
const sessions = getTrackedSessions(manager)
|
||||
sessions.set("ses_pending", createTrackedSession())
|
||||
|
||||
// when
|
||||
await manager.onSessionDeleted({ sessionID: "ses_pending" })
|
||||
|
||||
// then
|
||||
const tracked = sessions.get("ses_pending")
|
||||
expect(tracked).toBeDefined()
|
||||
expect(tracked?.closePending).toBe(true)
|
||||
expect(tracked?.closeRetryCount).toBe(0)
|
||||
expect(mockExecuteAction).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
test("#given session with closePending true #when retryPendingCloses succeeds #then session is removed from Map", async () => {
|
||||
// given
|
||||
const { TmuxSessionManager } = await import("./manager")
|
||||
const manager = createManager(TmuxSessionManager)
|
||||
const sessions = getTrackedSessions(manager)
|
||||
sessions.set(
|
||||
"ses_pending",
|
||||
createTrackedSession({ closePending: true, closeRetryCount: 0 }),
|
||||
)
|
||||
|
||||
// when
|
||||
await getRetryPendingCloses(manager)()
|
||||
|
||||
// then
|
||||
expect(sessions.has("ses_pending")).toBe(false)
|
||||
expect(mockExecuteAction).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
test("#given session with closePending true and closeRetryCount >= 3 #when retryPendingCloses called #then session is force-removed from Map", async () => {
|
||||
// given
|
||||
const { TmuxSessionManager } = await import("./manager")
|
||||
const manager = createManager(TmuxSessionManager)
|
||||
const sessions = getTrackedSessions(manager)
|
||||
sessions.set(
|
||||
"ses_pending",
|
||||
createTrackedSession({ closePending: true, closeRetryCount: 3 }),
|
||||
)
|
||||
|
||||
// when
|
||||
await getRetryPendingCloses(manager)()
|
||||
|
||||
// then
|
||||
expect(sessions.has("ses_pending")).toBe(false)
|
||||
expect(mockQueryWindowState).not.toHaveBeenCalled()
|
||||
expect(mockExecuteAction).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
test("#given session with closePending true and closeRetryCount >= 3 #when closeSessionById called #then session is force-removed without retrying close", async () => {
|
||||
// given
|
||||
const { TmuxSessionManager } = await import("./manager")
|
||||
const manager = createManager(TmuxSessionManager)
|
||||
const sessions = getTrackedSessions(manager)
|
||||
sessions.set(
|
||||
"ses_pending",
|
||||
createTrackedSession({ closePending: true, closeRetryCount: 3 }),
|
||||
)
|
||||
|
||||
// when
|
||||
await getCloseSessionById(manager)("ses_pending")
|
||||
|
||||
// then
|
||||
expect(sessions.has("ses_pending")).toBe(false)
|
||||
expect(mockQueryWindowState).not.toHaveBeenCalled()
|
||||
expect(mockExecuteAction).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
test("#given close-pending session removed during async close #when retryPendingCloses fails #then it does not resurrect stale session state", async () => {
|
||||
// given
|
||||
const { TmuxSessionManager } = await import("./manager")
|
||||
const manager = createManager(TmuxSessionManager)
|
||||
const sessions = getTrackedSessions(manager)
|
||||
sessions.set(
|
||||
"ses_pending",
|
||||
createTrackedSession({ closePending: true, closeRetryCount: 0 }),
|
||||
)
|
||||
mockExecuteAction.mockImplementationOnce(async () => {
|
||||
sessions.delete("ses_pending")
|
||||
return { success: false }
|
||||
})
|
||||
|
||||
// when
|
||||
await getRetryPendingCloses(manager)()
|
||||
|
||||
// then
|
||||
expect(sessions.has("ses_pending")).toBe(false)
|
||||
})
|
||||
})
|
||||
142
src/hooks/auto-slash-command/auto-slash-command-leak.test.ts
Normal file
142
src/hooks/auto-slash-command/auto-slash-command-leak.test.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
import { beforeEach, describe, expect, it, mock, spyOn } from "bun:test"
|
||||
import { AUTO_SLASH_COMMAND_TAG_OPEN } from "./constants"
|
||||
import type {
|
||||
AutoSlashCommandHookInput,
|
||||
AutoSlashCommandHookOutput,
|
||||
CommandExecuteBeforeInput,
|
||||
CommandExecuteBeforeOutput,
|
||||
} from "./types"
|
||||
import * as shared from "../../shared"
|
||||
|
||||
const executeSlashCommandMock = mock(
|
||||
async (parsed: { command: string; args: string; raw: string }) => ({
|
||||
success: true,
|
||||
replacementText: parsed.raw,
|
||||
})
|
||||
)
|
||||
|
||||
mock.module("./executor", () => ({
|
||||
executeSlashCommand: executeSlashCommandMock,
|
||||
}))
|
||||
|
||||
const logMock = spyOn(shared, "log").mockImplementation(() => {})
|
||||
|
||||
const { createAutoSlashCommandHook } = await import("./hook")
|
||||
|
||||
function createChatInput(sessionID: string, messageID: string): AutoSlashCommandHookInput {
|
||||
return {
|
||||
sessionID,
|
||||
messageID,
|
||||
}
|
||||
}
|
||||
|
||||
function createChatOutput(text: string): AutoSlashCommandHookOutput {
|
||||
return {
|
||||
message: {},
|
||||
parts: [{ type: "text", text }],
|
||||
}
|
||||
}
|
||||
|
||||
function createCommandInput(sessionID: string, command: string): CommandExecuteBeforeInput {
|
||||
return {
|
||||
sessionID,
|
||||
command,
|
||||
arguments: "",
|
||||
}
|
||||
}
|
||||
|
||||
function createCommandOutput(text: string): CommandExecuteBeforeOutput {
|
||||
return {
|
||||
parts: [{ type: "text", text }],
|
||||
}
|
||||
}
|
||||
|
||||
describe("createAutoSlashCommandHook leak prevention", () => {
|
||||
beforeEach(() => {
|
||||
executeSlashCommandMock.mockClear()
|
||||
logMock.mockClear()
|
||||
})
|
||||
|
||||
describe("#given hook with sessionProcessedCommandExecutions", () => {
|
||||
describe("#when same command executed twice for same session", () => {
|
||||
it("#then second execution is deduplicated", async () => {
|
||||
const hook = createAutoSlashCommandHook()
|
||||
const input = createCommandInput("session-dedup", "leak-test-command")
|
||||
const firstOutput = createCommandOutput("first")
|
||||
const secondOutput = createCommandOutput("second")
|
||||
|
||||
await hook["command.execute.before"](input, firstOutput)
|
||||
await hook["command.execute.before"](input, secondOutput)
|
||||
|
||||
expect(executeSlashCommandMock).toHaveBeenCalledTimes(1)
|
||||
expect(firstOutput.parts[0].text).toContain(AUTO_SLASH_COMMAND_TAG_OPEN)
|
||||
expect(secondOutput.parts[0].text).toBe("second")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given hook with entries from multiple sessions", () => {
|
||||
describe("#when dispose() is called", () => {
|
||||
it("#then both Sets are empty", async () => {
|
||||
const hook = createAutoSlashCommandHook()
|
||||
await hook["chat.message"](
|
||||
createChatInput("session-chat", "message-chat"),
|
||||
createChatOutput("/leak-chat")
|
||||
)
|
||||
await hook["command.execute.before"](
|
||||
createCommandInput("session-command", "leak-command"),
|
||||
createCommandOutput("before")
|
||||
)
|
||||
executeSlashCommandMock.mockClear()
|
||||
|
||||
hook.dispose()
|
||||
const chatOutputAfterDispose = createChatOutput("/leak-chat")
|
||||
const commandOutputAfterDispose = createCommandOutput("after")
|
||||
await hook["chat.message"](
|
||||
createChatInput("session-chat", "message-chat"),
|
||||
chatOutputAfterDispose
|
||||
)
|
||||
await hook["command.execute.before"](
|
||||
createCommandInput("session-command", "leak-command"),
|
||||
commandOutputAfterDispose
|
||||
)
|
||||
|
||||
expect(executeSlashCommandMock).toHaveBeenCalledTimes(2)
|
||||
expect(chatOutputAfterDispose.parts[0].text).toContain(AUTO_SLASH_COMMAND_TAG_OPEN)
|
||||
expect(commandOutputAfterDispose.parts[0].text).toContain(
|
||||
AUTO_SLASH_COMMAND_TAG_OPEN
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given Set with more than 10000 entries", () => {
|
||||
describe("#when new entry added", () => {
|
||||
it("#then Set size is reduced", async () => {
|
||||
const hook = createAutoSlashCommandHook()
|
||||
const oldestInput = createChatInput("session-oldest", "message-oldest")
|
||||
await hook["chat.message"](oldestInput, createChatOutput("/leak-oldest"))
|
||||
|
||||
for (let index = 0; index < 10000; index += 1) {
|
||||
await hook["chat.message"](
|
||||
createChatInput(`session-${index}`, `message-${index}`),
|
||||
createChatOutput(`/leak-${index}`)
|
||||
)
|
||||
}
|
||||
|
||||
const newestInput = createChatInput("session-newest", "message-newest")
|
||||
await hook["chat.message"](newestInput, createChatOutput("/leak-newest"))
|
||||
executeSlashCommandMock.mockClear()
|
||||
const oldestRetryOutput = createChatOutput("/leak-oldest")
|
||||
const newestRetryOutput = createChatOutput("/leak-newest")
|
||||
|
||||
await hook["chat.message"](oldestInput, oldestRetryOutput)
|
||||
await hook["chat.message"](newestInput, newestRetryOutput)
|
||||
|
||||
expect(executeSlashCommandMock).toHaveBeenCalledTimes(1)
|
||||
expect(oldestRetryOutput.parts[0].text).toContain(AUTO_SLASH_COMMAND_TAG_OPEN)
|
||||
expect(newestRetryOutput.parts[0].text).toBe("/leak-newest")
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
AUTO_SLASH_COMMAND_TAG_CLOSE,
|
||||
AUTO_SLASH_COMMAND_TAG_OPEN,
|
||||
} from "./constants"
|
||||
import { createProcessedCommandStore } from "./processed-command-store"
|
||||
import type {
|
||||
AutoSlashCommandHookInput,
|
||||
AutoSlashCommandHookOutput,
|
||||
@@ -17,8 +18,22 @@ import type {
|
||||
} from "./types"
|
||||
import type { LoadedSkill } from "../../features/opencode-skill-loader"
|
||||
|
||||
const sessionProcessedCommands = new Set<string>()
|
||||
const sessionProcessedCommandExecutions = new Set<string>()
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
|
||||
function getDeletedSessionID(properties: unknown): string | null {
|
||||
if (!isRecord(properties)) {
|
||||
return null
|
||||
}
|
||||
|
||||
const info = properties.info
|
||||
if (!isRecord(info)) {
|
||||
return null
|
||||
}
|
||||
|
||||
return typeof info.id === "string" ? info.id : null
|
||||
}
|
||||
|
||||
export interface AutoSlashCommandHookOptions {
|
||||
skills?: LoadedSkill[]
|
||||
@@ -32,6 +47,13 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
|
||||
pluginsEnabled: options?.pluginsEnabled,
|
||||
enabledPluginsOverride: options?.enabledPluginsOverride,
|
||||
}
|
||||
const sessionProcessedCommands = createProcessedCommandStore()
|
||||
const sessionProcessedCommandExecutions = createProcessedCommandStore()
|
||||
|
||||
const dispose = (): void => {
|
||||
sessionProcessedCommands.clear()
|
||||
sessionProcessedCommandExecutions.clear()
|
||||
}
|
||||
|
||||
return {
|
||||
"chat.message": async (
|
||||
@@ -61,7 +83,9 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
|
||||
return
|
||||
}
|
||||
|
||||
const commandKey = `${input.sessionID}:${input.messageID}:${parsed.command}`
|
||||
const commandKey = input.messageID
|
||||
? `${input.sessionID}:${input.messageID}:${parsed.command}`
|
||||
: `${input.sessionID}:${parsed.command}`
|
||||
if (sessionProcessedCommands.has(commandKey)) {
|
||||
return
|
||||
}
|
||||
@@ -101,7 +125,7 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
|
||||
input: CommandExecuteBeforeInput,
|
||||
output: CommandExecuteBeforeOutput
|
||||
): Promise<void> => {
|
||||
const commandKey = `${input.sessionID}:${input.command}:${Date.now()}`
|
||||
const commandKey = `${input.sessionID}:${input.command.toLowerCase()}:${input.arguments || ""}`
|
||||
if (sessionProcessedCommandExecutions.has(commandKey)) {
|
||||
return
|
||||
}
|
||||
@@ -145,5 +169,23 @@ export function createAutoSlashCommandHook(options?: AutoSlashCommandHookOptions
|
||||
command: input.command,
|
||||
})
|
||||
},
|
||||
event: async ({
|
||||
event,
|
||||
}: {
|
||||
event: { type: string; properties?: unknown }
|
||||
}): Promise<void> => {
|
||||
if (event.type !== "session.deleted") {
|
||||
return
|
||||
}
|
||||
|
||||
const sessionID = getDeletedSessionID(event.properties)
|
||||
if (!sessionID) {
|
||||
return
|
||||
}
|
||||
|
||||
sessionProcessedCommands.cleanupSession(sessionID)
|
||||
sessionProcessedCommandExecutions.cleanupSession(sessionID)
|
||||
},
|
||||
dispose,
|
||||
}
|
||||
}
|
||||
|
||||
41
src/hooks/auto-slash-command/processed-command-store.ts
Normal file
41
src/hooks/auto-slash-command/processed-command-store.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
const MAX_PROCESSED_ENTRY_COUNT = 10_000
|
||||
|
||||
function trimProcessedEntries(entries: Set<string>): Set<string> {
|
||||
if (entries.size <= MAX_PROCESSED_ENTRY_COUNT) {
|
||||
return entries
|
||||
}
|
||||
|
||||
return new Set(Array.from(entries).slice(Math.floor(entries.size / 2)))
|
||||
}
|
||||
|
||||
function removeSessionEntries(entries: Set<string>, sessionID: string): Set<string> {
|
||||
const sessionPrefix = `${sessionID}:`
|
||||
return new Set(Array.from(entries).filter((entry) => !entry.startsWith(sessionPrefix)))
|
||||
}
|
||||
|
||||
export interface ProcessedCommandStore {
|
||||
has(commandKey: string): boolean
|
||||
add(commandKey: string): void
|
||||
cleanupSession(sessionID: string): void
|
||||
clear(): void
|
||||
}
|
||||
|
||||
export function createProcessedCommandStore(): ProcessedCommandStore {
|
||||
let entries = new Set<string>()
|
||||
|
||||
return {
|
||||
has(commandKey: string): boolean {
|
||||
return entries.has(commandKey)
|
||||
},
|
||||
add(commandKey: string): void {
|
||||
entries.add(commandKey)
|
||||
entries = trimProcessedEntries(entries)
|
||||
},
|
||||
cleanupSession(sessionID: string): void {
|
||||
entries = removeSessionEntries(entries, sessionID)
|
||||
},
|
||||
clear(): void {
|
||||
entries.clear()
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { HookDeps } from "./types"
|
||||
import type { HookDeps, RuntimeFallbackTimeout } from "./types"
|
||||
import { HOOK_NAME } from "./constants"
|
||||
import { log } from "../../shared/logger"
|
||||
import { normalizeAgentName, resolveAgentForSession } from "./agent-resolver"
|
||||
@@ -9,8 +9,8 @@ import { SessionCategoryRegistry } from "../../shared/session-category-registry"
|
||||
|
||||
const SESSION_TTL_MS = 30 * 60 * 1000
|
||||
|
||||
declare function setTimeout(callback: () => void | Promise<void>, delay?: number): ReturnType<typeof globalThis.setTimeout>
|
||||
declare function clearTimeout(timeout: ReturnType<typeof globalThis.setTimeout>): void
|
||||
declare function setTimeout(callback: () => void | Promise<void>, delay?: number): RuntimeFallbackTimeout
|
||||
declare function clearTimeout(timeout: RuntimeFallbackTimeout): void
|
||||
|
||||
export function createAutoRetryHelpers(deps: HookDeps) {
|
||||
const { ctx, config, options, sessionStates, sessionLastAccess, sessionRetryInFlight, sessionAwaitingFallbackResult, sessionFallbackTimeouts, pluginConfig } = deps
|
||||
|
||||
160
src/hooks/runtime-fallback/dispose.test.ts
Normal file
160
src/hooks/runtime-fallback/dispose.test.ts
Normal file
@@ -0,0 +1,160 @@
|
||||
import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"
|
||||
import type { HookDeps, RuntimeFallbackPluginInput } from "./types"
|
||||
|
||||
let capturedDeps: HookDeps | undefined
|
||||
|
||||
const mockCreateAutoRetryHelpers = mock((deps: HookDeps) => {
|
||||
capturedDeps = deps
|
||||
|
||||
return {
|
||||
abortSessionRequest: async () => {},
|
||||
clearSessionFallbackTimeout: () => {},
|
||||
scheduleSessionFallbackTimeout: () => {},
|
||||
autoRetryWithFallback: async () => {},
|
||||
resolveAgentForSessionFromContext: async () => undefined,
|
||||
cleanupStaleSessions: () => {},
|
||||
}
|
||||
})
|
||||
|
||||
const mockCreateEventHandler = mock(() => async () => {})
|
||||
const mockCreateMessageUpdateHandler = mock(() => async () => {})
|
||||
const mockCreateChatMessageHandler = mock(() => async () => {})
|
||||
|
||||
mock.module("./auto-retry", () => ({
|
||||
createAutoRetryHelpers: mockCreateAutoRetryHelpers,
|
||||
}))
|
||||
|
||||
mock.module("./event-handler", () => ({
|
||||
createEventHandler: mockCreateEventHandler,
|
||||
}))
|
||||
|
||||
mock.module("./message-update-handler", () => ({
|
||||
createMessageUpdateHandler: mockCreateMessageUpdateHandler,
|
||||
}))
|
||||
|
||||
mock.module("./chat-message-handler", () => ({
|
||||
createChatMessageHandler: mockCreateChatMessageHandler,
|
||||
}))
|
||||
|
||||
const { createRuntimeFallbackHook } = await import("./hook")
|
||||
|
||||
function createMockContext(): RuntimeFallbackPluginInput {
|
||||
return {
|
||||
client: {
|
||||
session: {
|
||||
abort: async () => ({}),
|
||||
messages: async () => ({}),
|
||||
promptAsync: async () => ({}),
|
||||
},
|
||||
tui: {
|
||||
showToast: async () => ({}),
|
||||
},
|
||||
},
|
||||
directory: "/test",
|
||||
}
|
||||
}
|
||||
|
||||
describe("createRuntimeFallbackHook dispose", () => {
|
||||
const originalSetInterval = globalThis.setInterval
|
||||
const originalClearInterval = globalThis.clearInterval
|
||||
const originalClearTimeout = globalThis.clearTimeout
|
||||
const createdIntervals: Array<ReturnType<typeof originalSetInterval>> = []
|
||||
const clearedIntervals: Array<Parameters<typeof originalClearInterval>[0]> = []
|
||||
const clearedTimeouts: Array<Parameters<typeof originalClearTimeout>[0]> = []
|
||||
const timeoutMapSizesDuringClear: number[] = []
|
||||
|
||||
beforeEach(() => {
|
||||
capturedDeps = undefined
|
||||
createdIntervals.length = 0
|
||||
clearedIntervals.length = 0
|
||||
clearedTimeouts.length = 0
|
||||
timeoutMapSizesDuringClear.length = 0
|
||||
|
||||
mockCreateAutoRetryHelpers.mockClear()
|
||||
mockCreateEventHandler.mockClear()
|
||||
mockCreateMessageUpdateHandler.mockClear()
|
||||
mockCreateChatMessageHandler.mockClear()
|
||||
|
||||
const wrappedSetInterval = ((handler: () => void, timeout?: number) => {
|
||||
const interval = originalSetInterval(handler, timeout)
|
||||
createdIntervals.push(interval)
|
||||
return interval
|
||||
}) as typeof globalThis.setInterval
|
||||
|
||||
const wrappedClearInterval = ((interval?: Parameters<typeof clearInterval>[0]) => {
|
||||
clearedIntervals.push(interval)
|
||||
return originalClearInterval(interval)
|
||||
}) as typeof globalThis.clearInterval
|
||||
|
||||
const wrappedClearTimeout = ((timeout?: Parameters<typeof clearTimeout>[0]) => {
|
||||
timeoutMapSizesDuringClear.push(capturedDeps?.sessionFallbackTimeouts.size ?? -1)
|
||||
clearedTimeouts.push(timeout)
|
||||
return originalClearTimeout(timeout)
|
||||
}) as typeof globalThis.clearTimeout
|
||||
|
||||
globalThis.setInterval = wrappedSetInterval
|
||||
globalThis.clearInterval = wrappedClearInterval
|
||||
globalThis.clearTimeout = wrappedClearTimeout
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.setInterval = originalSetInterval
|
||||
globalThis.clearInterval = originalClearInterval
|
||||
globalThis.clearTimeout = originalClearTimeout
|
||||
})
|
||||
|
||||
test("#given runtime-fallback hook created #when dispose() is called #then cleanup interval is cleared", () => {
|
||||
// given
|
||||
const hook = createRuntimeFallbackHook(createMockContext(), { pluginConfig: {} })
|
||||
|
||||
// when
|
||||
hook.dispose?.()
|
||||
|
||||
// then
|
||||
expect(createdIntervals).toHaveLength(1)
|
||||
expect(clearedIntervals).toEqual([createdIntervals[0]])
|
||||
})
|
||||
|
||||
test("#given hook with session state data #when dispose() is called #then all Maps and Sets are empty", () => {
|
||||
// given
|
||||
const hook = createRuntimeFallbackHook(createMockContext(), { pluginConfig: {} })
|
||||
const fallbackTimeout = setTimeout(() => {}, 60_000)
|
||||
|
||||
capturedDeps?.sessionStates.set("session-1", {
|
||||
originalModel: "anthropic/claude-opus-4-6",
|
||||
currentModel: "openai/gpt-5.4",
|
||||
fallbackIndex: 1,
|
||||
failedModels: new Map([["anthropic/claude-opus-4-6", 1]]),
|
||||
attemptCount: 1,
|
||||
})
|
||||
capturedDeps?.sessionLastAccess.set("session-1", Date.now())
|
||||
capturedDeps?.sessionRetryInFlight.add("session-1")
|
||||
capturedDeps?.sessionAwaitingFallbackResult.add("session-1")
|
||||
capturedDeps?.sessionFallbackTimeouts.set("session-1", fallbackTimeout)
|
||||
|
||||
// when
|
||||
hook.dispose?.()
|
||||
|
||||
// then
|
||||
expect(capturedDeps?.sessionStates.size).toBe(0)
|
||||
expect(capturedDeps?.sessionLastAccess.size).toBe(0)
|
||||
expect(capturedDeps?.sessionRetryInFlight.size).toBe(0)
|
||||
expect(capturedDeps?.sessionAwaitingFallbackResult.size).toBe(0)
|
||||
expect(capturedDeps?.sessionFallbackTimeouts.size).toBe(0)
|
||||
})
|
||||
|
||||
test("#given hook with pending fallback timeouts #when dispose() is called #then timeouts are cleared before Map is emptied", () => {
|
||||
// given
|
||||
const hook = createRuntimeFallbackHook(createMockContext(), { pluginConfig: {} })
|
||||
const fallbackTimeout = setTimeout(() => {}, 60_000)
|
||||
capturedDeps?.sessionFallbackTimeouts.set("session-1", fallbackTimeout)
|
||||
|
||||
// when
|
||||
hook.dispose?.()
|
||||
|
||||
// then
|
||||
expect(clearedTimeouts).toEqual([fallbackTimeout])
|
||||
expect(timeoutMapSizesDuringClear).toEqual([1])
|
||||
expect(capturedDeps?.sessionFallbackTimeouts.size).toBe(0)
|
||||
})
|
||||
})
|
||||
@@ -1,5 +1,4 @@
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import type { HookDeps, RuntimeFallbackHook, RuntimeFallbackOptions } from "./types"
|
||||
import type { HookDeps, RuntimeFallbackHook, RuntimeFallbackInterval, RuntimeFallbackOptions, RuntimeFallbackPluginInput, RuntimeFallbackTimeout } from "./types"
|
||||
import { DEFAULT_CONFIG, HOOK_NAME } from "./constants"
|
||||
import { log } from "../../shared/logger"
|
||||
import { loadPluginConfig } from "../../plugin-config"
|
||||
@@ -8,8 +7,12 @@ import { createEventHandler } from "./event-handler"
|
||||
import { createMessageUpdateHandler } from "./message-update-handler"
|
||||
import { createChatMessageHandler } from "./chat-message-handler"
|
||||
|
||||
declare function setInterval(callback: () => void, delay?: number): RuntimeFallbackInterval
|
||||
declare function clearInterval(interval: RuntimeFallbackInterval): void
|
||||
declare function clearTimeout(timeout: RuntimeFallbackTimeout): void
|
||||
|
||||
export function createRuntimeFallbackHook(
|
||||
ctx: PluginInput,
|
||||
ctx: RuntimeFallbackPluginInput,
|
||||
options?: RuntimeFallbackOptions
|
||||
): RuntimeFallbackHook {
|
||||
const config = {
|
||||
@@ -60,8 +63,23 @@ export function createRuntimeFallbackHook(
|
||||
await baseEventHandler({ event })
|
||||
}
|
||||
|
||||
const dispose = () => {
|
||||
clearInterval(cleanupInterval)
|
||||
|
||||
for (const fallbackTimeout of deps.sessionFallbackTimeouts.values()) {
|
||||
clearTimeout(fallbackTimeout)
|
||||
}
|
||||
|
||||
deps.sessionStates.clear()
|
||||
deps.sessionLastAccess.clear()
|
||||
deps.sessionRetryInFlight.clear()
|
||||
deps.sessionAwaitingFallbackResult.clear()
|
||||
deps.sessionFallbackTimeouts.clear()
|
||||
}
|
||||
|
||||
return {
|
||||
event: eventHandler,
|
||||
"chat.message": chatMessageHandler,
|
||||
dispose,
|
||||
} as RuntimeFallbackHook
|
||||
}
|
||||
|
||||
@@ -1,6 +1,40 @@
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import type { RuntimeFallbackConfig, OhMyOpenCodeConfig } from "../../config"
|
||||
|
||||
export interface RuntimeFallbackInterval {
|
||||
unref: () => void
|
||||
}
|
||||
|
||||
export type RuntimeFallbackTimeout = object | number
|
||||
|
||||
export interface RuntimeFallbackPluginInput {
|
||||
client: {
|
||||
session: {
|
||||
abort: (input: { path: { id: string } }) => Promise<unknown>
|
||||
messages: (input: { path: { id: string }; query: { directory: string } }) => Promise<unknown>
|
||||
promptAsync: (input: {
|
||||
path: { id: string }
|
||||
body: {
|
||||
agent?: string
|
||||
model: { providerID: string; modelID: string }
|
||||
parts: Array<{ type: "text"; text: string }>
|
||||
}
|
||||
query: { directory: string }
|
||||
}) => Promise<unknown>
|
||||
}
|
||||
tui: {
|
||||
showToast: (input: {
|
||||
body: {
|
||||
title: string
|
||||
message: string
|
||||
variant: "success" | "error" | "info" | "warning"
|
||||
duration: number
|
||||
}
|
||||
}) => Promise<unknown>
|
||||
}
|
||||
}
|
||||
directory: string
|
||||
}
|
||||
|
||||
export interface FallbackState {
|
||||
originalModel: string
|
||||
currentModel: string
|
||||
@@ -26,10 +60,11 @@ export interface RuntimeFallbackOptions {
|
||||
export interface RuntimeFallbackHook {
|
||||
event: (input: { event: { type: string; properties?: unknown } }) => Promise<void>
|
||||
"chat.message"?: (input: { sessionID: string; agent?: string; model?: { providerID: string; modelID: string } }, output: { message: { model?: { providerID: string; modelID: string } }; parts?: Array<{ type: string; text?: string }> }) => Promise<void>
|
||||
dispose?: () => void
|
||||
}
|
||||
|
||||
export interface HookDeps {
|
||||
ctx: PluginInput
|
||||
ctx: RuntimeFallbackPluginInput
|
||||
config: Required<RuntimeFallbackConfig>
|
||||
options: RuntimeFallbackOptions | undefined
|
||||
pluginConfig: OhMyOpenCodeConfig | undefined
|
||||
@@ -37,5 +72,5 @@ export interface HookDeps {
|
||||
sessionLastAccess: Map<string, number>
|
||||
sessionRetryInFlight: Set<string>
|
||||
sessionAwaitingFallbackResult: Set<string>
|
||||
sessionFallbackTimeouts: Map<string, ReturnType<typeof setTimeout>>
|
||||
sessionFallbackTimeouts: Map<string, RuntimeFallbackTimeout>
|
||||
}
|
||||
|
||||
101
src/hooks/todo-continuation-enforcer/dispose.test.ts
Normal file
101
src/hooks/todo-continuation-enforcer/dispose.test.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
declare module "bun:test" {
|
||||
export interface Matchers {
|
||||
toBeDefined(): void
|
||||
toBeUndefined(): void
|
||||
toHaveLength(expected: number): void
|
||||
}
|
||||
}
|
||||
|
||||
import { afterAll, afterEach, describe, expect, it, mock } from "bun:test"
|
||||
|
||||
import * as actualSessionStateModule from "./session-state"
|
||||
import type { SessionStateStore } from "./session-state"
|
||||
|
||||
let createdSessionStateStore: SessionStateStore | undefined
|
||||
const createActualSessionStateStore = actualSessionStateModule.createSessionStateStore
|
||||
|
||||
const mockModule = mock as typeof mock & {
|
||||
module: (specifier: string, factory: () => unknown) => void
|
||||
}
|
||||
|
||||
mockModule.module("./session-state", () => ({
|
||||
...actualSessionStateModule,
|
||||
createSessionStateStore: () => {
|
||||
const sessionStateStore = createActualSessionStateStore()
|
||||
createdSessionStateStore = sessionStateStore
|
||||
return sessionStateStore
|
||||
},
|
||||
}))
|
||||
|
||||
const { createTodoContinuationEnforcer } = await import(".")
|
||||
|
||||
type PluginInput = Parameters<typeof createTodoContinuationEnforcer>[0]
|
||||
|
||||
function createMockPluginInput(): PluginInput {
|
||||
return {
|
||||
directory: "/tmp/test",
|
||||
} as PluginInput
|
||||
}
|
||||
|
||||
function getCreatedSessionStateStore(): SessionStateStore {
|
||||
if (!createdSessionStateStore) {
|
||||
throw new Error("expected session state store to be created")
|
||||
}
|
||||
|
||||
return createdSessionStateStore
|
||||
}
|
||||
|
||||
describe("todo-continuation-enforcer dispose", () => {
|
||||
afterEach(() => {
|
||||
createdSessionStateStore?.shutdown()
|
||||
createdSessionStateStore = undefined
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
mockModule.module("./session-state", () => actualSessionStateModule)
|
||||
})
|
||||
|
||||
it("#given todo-continuation-enforcer created #when dispose exists on return value #then it is a function", () => {
|
||||
// given
|
||||
const enforcer = createTodoContinuationEnforcer(createMockPluginInput())
|
||||
|
||||
// when
|
||||
const { dispose } = enforcer
|
||||
|
||||
// then
|
||||
expect(typeof dispose).toBe("function")
|
||||
|
||||
enforcer.dispose()
|
||||
})
|
||||
|
||||
it("#given enforcer with active session states #when dispose is called #then internal session state store is shut down", () => {
|
||||
// given
|
||||
const originalClearInterval = globalThis.clearInterval
|
||||
const clearIntervalCalls: Array<Parameters<typeof clearInterval>[0]> = []
|
||||
globalThis.clearInterval = ((timer?: Parameters<typeof clearInterval>[0]) => {
|
||||
clearIntervalCalls.push(timer)
|
||||
return originalClearInterval(timer)
|
||||
}) as typeof clearInterval
|
||||
|
||||
try {
|
||||
const enforcer = createTodoContinuationEnforcer(createMockPluginInput())
|
||||
const sessionStateStore = getCreatedSessionStateStore()
|
||||
|
||||
enforcer.markRecovering("session-1")
|
||||
enforcer.markRecovering("session-2")
|
||||
|
||||
expect(sessionStateStore.getExistingState("session-1")).toBeDefined()
|
||||
expect(sessionStateStore.getExistingState("session-2")).toBeDefined()
|
||||
|
||||
// when
|
||||
enforcer.dispose()
|
||||
|
||||
// then
|
||||
expect(clearIntervalCalls).toHaveLength(1)
|
||||
expect(sessionStateStore.getExistingState("session-1")).toBeUndefined()
|
||||
expect(sessionStateStore.getExistingState("session-2")).toBeUndefined()
|
||||
} finally {
|
||||
globalThis.clearInterval = originalClearInterval
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -56,5 +56,6 @@ export function createTodoContinuationEnforcer(
|
||||
markRecovering,
|
||||
markRecoveryComplete,
|
||||
cancelAllCountdowns,
|
||||
dispose: () => sessionStateStore.shutdown(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ export interface TodoContinuationEnforcer {
|
||||
markRecovering: (sessionID: string) => void
|
||||
markRecoveryComplete: (sessionID: string) => void
|
||||
cancelAllCountdowns: () => void
|
||||
dispose: () => void
|
||||
}
|
||||
|
||||
export interface Todo {
|
||||
|
||||
12
src/index.ts
12
src/index.ts
@@ -7,6 +7,7 @@ import { createHooks } from "./create-hooks"
|
||||
import { createManagers } from "./create-managers"
|
||||
import { createTools } from "./create-tools"
|
||||
import { createPluginInterface } from "./plugin-interface"
|
||||
import { createPluginDispose, type PluginDispose } from "./plugin-dispose"
|
||||
|
||||
import { loadPluginConfig } from "./plugin-config"
|
||||
import { createModelCacheState } from "./plugin-state"
|
||||
@@ -14,6 +15,8 @@ import { createFirstMessageVariantGate } from "./shared/first-message-variant"
|
||||
import { injectServerAuthIntoClient, log } from "./shared"
|
||||
import { startTmuxCheck } from "./tools"
|
||||
|
||||
let activePluginDispose: PluginDispose | null = null
|
||||
|
||||
const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
// Initialize config context for plugin runtime (prevents warnings from hooks)
|
||||
initConfigContext("opencode", null)
|
||||
@@ -23,6 +26,7 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
|
||||
injectServerAuthIntoClient(ctx.client)
|
||||
startTmuxCheck()
|
||||
await activePluginDispose?.()
|
||||
|
||||
const pluginConfig = loadPluginConfig(ctx.directory, ctx)
|
||||
const disabledHooks = new Set(pluginConfig.disabled_hooks ?? [])
|
||||
@@ -67,6 +71,12 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
availableSkills: toolsResult.availableSkills,
|
||||
})
|
||||
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager: managers.backgroundManager,
|
||||
skillMcpManager: managers.skillMcpManager,
|
||||
disposeHooks: hooks.disposeHooks,
|
||||
})
|
||||
|
||||
const pluginInterface = createPluginInterface({
|
||||
ctx,
|
||||
pluginConfig,
|
||||
@@ -76,6 +86,8 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
tools: toolsResult.filteredTools,
|
||||
})
|
||||
|
||||
activePluginDispose = dispose
|
||||
|
||||
return {
|
||||
...pluginInterface,
|
||||
|
||||
|
||||
175
src/plugin-dispose.test.ts
Normal file
175
src/plugin-dispose.test.ts
Normal file
@@ -0,0 +1,175 @@
|
||||
import { describe, expect, spyOn, test } from "bun:test"
|
||||
|
||||
import { disposeCreatedHooks } from "./create-hooks"
|
||||
import { createPluginDispose } from "./plugin-dispose"
|
||||
|
||||
describe("createPluginDispose", () => {
|
||||
test("#given plugin with active managers and hooks #when dispose() is called #then backgroundManager.shutdown() is called", async () => {
|
||||
// given
|
||||
const backgroundManager = {
|
||||
shutdown: async (): Promise<void> => {},
|
||||
}
|
||||
const skillMcpManager = {
|
||||
disconnectAll: async (): Promise<void> => {},
|
||||
}
|
||||
const shutdownSpy = spyOn(backgroundManager, "shutdown")
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
disposeHooks: (): void => {},
|
||||
})
|
||||
|
||||
// when
|
||||
await dispose()
|
||||
|
||||
// then
|
||||
expect(shutdownSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
test("#given plugin with active MCP connections #when dispose() is called #then skillMcpManager.disconnectAll() is called", async () => {
|
||||
// given
|
||||
const backgroundManager = {
|
||||
shutdown: async (): Promise<void> => {},
|
||||
}
|
||||
const skillMcpManager = {
|
||||
disconnectAll: async (): Promise<void> => {},
|
||||
}
|
||||
const disconnectAllSpy = spyOn(skillMcpManager, "disconnectAll")
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
disposeHooks: (): void => {},
|
||||
})
|
||||
|
||||
// when
|
||||
await dispose()
|
||||
|
||||
// then
|
||||
expect(disconnectAllSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
test("#given plugin with hooks that have dispose #when dispose() is called #then each hook's dispose is called", async () => {
|
||||
// given
|
||||
const runtimeFallback = {
|
||||
dispose: (): void => {},
|
||||
}
|
||||
const todoContinuationEnforcer = {
|
||||
dispose: (): void => {},
|
||||
}
|
||||
const autoSlashCommand = {
|
||||
dispose: (): void => {},
|
||||
}
|
||||
const runtimeFallbackDisposeSpy = spyOn(runtimeFallback, "dispose")
|
||||
const todoContinuationEnforcerDisposeSpy = spyOn(todoContinuationEnforcer, "dispose")
|
||||
const autoSlashCommandDisposeSpy = spyOn(autoSlashCommand, "dispose")
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager: {
|
||||
shutdown: async (): Promise<void> => {},
|
||||
},
|
||||
skillMcpManager: {
|
||||
disconnectAll: async (): Promise<void> => {},
|
||||
},
|
||||
disposeHooks: (): void => {
|
||||
disposeCreatedHooks({
|
||||
runtimeFallback,
|
||||
todoContinuationEnforcer,
|
||||
autoSlashCommand,
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
// when
|
||||
await dispose()
|
||||
|
||||
// then
|
||||
expect(runtimeFallbackDisposeSpy).toHaveBeenCalledTimes(1)
|
||||
expect(todoContinuationEnforcerDisposeSpy).toHaveBeenCalledTimes(1)
|
||||
expect(autoSlashCommandDisposeSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
test("#given dispose already called #when dispose() called again #then no errors", async () => {
|
||||
// given
|
||||
const backgroundManager = {
|
||||
shutdown: async (): Promise<void> => {},
|
||||
}
|
||||
const skillMcpManager = {
|
||||
disconnectAll: async (): Promise<void> => {},
|
||||
}
|
||||
const disposeHooks = {
|
||||
run: (): void => {},
|
||||
}
|
||||
const shutdownSpy = spyOn(backgroundManager, "shutdown")
|
||||
const disconnectAllSpy = spyOn(skillMcpManager, "disconnectAll")
|
||||
const disposeHooksSpy = spyOn(disposeHooks, "run")
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
disposeHooks: disposeHooks.run,
|
||||
})
|
||||
|
||||
// when
|
||||
await dispose()
|
||||
await dispose()
|
||||
|
||||
// then
|
||||
expect(shutdownSpy).toHaveBeenCalledTimes(1)
|
||||
expect(disconnectAllSpy).toHaveBeenCalledTimes(1)
|
||||
expect(disposeHooksSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
test("#given backgroundManager.shutdown() throws #when dispose() is called #then skillMcpManager.disconnectAll() and disposeHooks() are still called", async () => {
|
||||
// given
|
||||
const backgroundManager = {
|
||||
shutdown: async (): Promise<void> => {
|
||||
throw new Error("shutdown failed")
|
||||
},
|
||||
}
|
||||
const skillMcpManager = {
|
||||
disconnectAll: async (): Promise<void> => {},
|
||||
}
|
||||
const disposeHooksCalls: number[] = []
|
||||
const disconnectAllSpy = spyOn(skillMcpManager, "disconnectAll")
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
disposeHooks: (): void => {
|
||||
disposeHooksCalls.push(1)
|
||||
},
|
||||
})
|
||||
|
||||
// when
|
||||
await dispose()
|
||||
|
||||
// then
|
||||
expect(disconnectAllSpy).toHaveBeenCalledTimes(1)
|
||||
expect(disposeHooksCalls).toHaveLength(1)
|
||||
})
|
||||
|
||||
test("#given skillMcpManager.disconnectAll() throws #when dispose() is called #then disposeHooks() is still called", async () => {
|
||||
// given
|
||||
const backgroundManager = {
|
||||
shutdown: async (): Promise<void> => {},
|
||||
}
|
||||
const skillMcpManager = {
|
||||
disconnectAll: async (): Promise<void> => {
|
||||
throw new Error("disconnectAll failed")
|
||||
},
|
||||
}
|
||||
const disposeHooksCalls: number[] = []
|
||||
const shutdownSpy = spyOn(backgroundManager, "shutdown")
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
disposeHooks: (): void => {
|
||||
disposeHooksCalls.push(1)
|
||||
},
|
||||
})
|
||||
|
||||
// when
|
||||
await dispose()
|
||||
|
||||
// then
|
||||
expect(shutdownSpy).toHaveBeenCalledTimes(1)
|
||||
expect(disposeHooksCalls).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
43
src/plugin-dispose.ts
Normal file
43
src/plugin-dispose.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { log } from "./shared"
|
||||
|
||||
export type PluginDispose = () => Promise<void>
|
||||
|
||||
export function createPluginDispose(args: {
|
||||
backgroundManager: {
|
||||
shutdown: () => void | Promise<void>
|
||||
}
|
||||
skillMcpManager: {
|
||||
disconnectAll: () => Promise<void>
|
||||
}
|
||||
disposeHooks: () => void
|
||||
}): PluginDispose {
|
||||
const { backgroundManager, skillMcpManager, disposeHooks } = args
|
||||
let disposePromise: Promise<void> | null = null
|
||||
|
||||
return async (): Promise<void> => {
|
||||
if (disposePromise) {
|
||||
await disposePromise
|
||||
return
|
||||
}
|
||||
|
||||
disposePromise = (async (): Promise<void> => {
|
||||
try {
|
||||
await backgroundManager.shutdown()
|
||||
} catch (error) {
|
||||
log("[plugin-dispose] backgroundManager.shutdown() error:", error)
|
||||
}
|
||||
try {
|
||||
await skillMcpManager.disconnectAll()
|
||||
} catch (error) {
|
||||
log("[plugin-dispose] skillMcpManager.disconnectAll() error:", error)
|
||||
}
|
||||
try {
|
||||
disposeHooks()
|
||||
} catch (error) {
|
||||
log("[plugin-dispose] disposeHooks() error:", error)
|
||||
}
|
||||
})()
|
||||
|
||||
await disposePromise
|
||||
}
|
||||
}
|
||||
@@ -190,6 +190,7 @@ export function createEventHandler(args: {
|
||||
await Promise.resolve(hooks.compactionTodoPreserver?.event?.(input));
|
||||
await Promise.resolve(hooks.writeExistingFileGuard?.event?.(input));
|
||||
await Promise.resolve(hooks.atlasHook?.handler?.(input));
|
||||
await Promise.resolve(hooks.autoSlashCommand?.event?.(input));
|
||||
};
|
||||
|
||||
const recentSyntheticIdles = new Map<string, number>();
|
||||
|
||||
184
src/tools/call-omo-agent/sync-executor-leak.test.ts
Normal file
184
src/tools/call-omo-agent/sync-executor-leak.test.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"
|
||||
import {
|
||||
_resetForTesting,
|
||||
subagentSessions,
|
||||
syncSubagentSessions,
|
||||
} from "../../features/claude-code-session-state"
|
||||
import { executeSync } from "./sync-executor"
|
||||
|
||||
type ExecuteSyncArgs = Parameters<typeof executeSync>[0]
|
||||
type ExecuteSyncToolContext = Parameters<typeof executeSync>[1]
|
||||
type ExecuteSyncDeps = NonNullable<Parameters<typeof executeSync>[3]>
|
||||
|
||||
function createArgs(): ExecuteSyncArgs {
|
||||
return {
|
||||
subagent_type: "explore",
|
||||
description: "cleanup leak",
|
||||
prompt: "find something",
|
||||
run_in_background: false,
|
||||
}
|
||||
}
|
||||
|
||||
function createToolContext(): ExecuteSyncToolContext {
|
||||
return {
|
||||
sessionID: "parent-session",
|
||||
messageID: "msg-1",
|
||||
agent: "sisyphus",
|
||||
abort: new AbortController().signal,
|
||||
metadata: mock(async () => {}),
|
||||
}
|
||||
}
|
||||
|
||||
function createContext(promptAsync: ReturnType<typeof mock>) {
|
||||
return {
|
||||
client: {
|
||||
session: {
|
||||
promptAsync,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
function createDependencies(overrides?: Partial<ExecuteSyncDeps>): ExecuteSyncDeps {
|
||||
return {
|
||||
createOrGetSession: mock(async () => ({ sessionID: "ses-default", isNew: true })),
|
||||
waitForCompletion: mock(async () => {}),
|
||||
processMessages: mock(async () => "agent response"),
|
||||
setSessionFallbackChain: mock(() => {}),
|
||||
clearSessionFallbackChain: mock(() => {}),
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
describe("executeSync session cleanup", () => {
|
||||
beforeEach(() => {
|
||||
_resetForTesting()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
_resetForTesting()
|
||||
})
|
||||
|
||||
describe("#given executeSync creates a session", () => {
|
||||
test("#when execution completes successfully #then sessionID is removed from subagentSessions and syncSubagentSessions", async () => {
|
||||
// given
|
||||
const sessionID = "ses-cleanup-success"
|
||||
const args = createArgs()
|
||||
const toolContext = createToolContext()
|
||||
const promptAsync = mock(async () => ({ data: {} }))
|
||||
const deps = createDependencies({
|
||||
createOrGetSession: mock(async () => {
|
||||
subagentSessions.add(sessionID)
|
||||
syncSubagentSessions.add(sessionID)
|
||||
return { sessionID, isNew: true }
|
||||
}),
|
||||
waitForCompletion: mock(async (createdSessionID: string) => {
|
||||
expect(createdSessionID).toBe(sessionID)
|
||||
expect(subagentSessions.has(sessionID)).toBe(true)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(true)
|
||||
}),
|
||||
})
|
||||
|
||||
expect(subagentSessions.has(sessionID)).toBe(false)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(false)
|
||||
|
||||
// when
|
||||
const result = await executeSync(args, toolContext, createContext(promptAsync) as never, deps)
|
||||
|
||||
// then
|
||||
expect(result).toContain(`session_id: ${sessionID}`)
|
||||
expect(subagentSessions.has(sessionID)).toBe(false)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(false)
|
||||
})
|
||||
|
||||
test("#when execution throws an error #then sessionID is still removed from both Sets", async () => {
|
||||
// given
|
||||
const sessionID = "ses-cleanup-error"
|
||||
const args = createArgs()
|
||||
const toolContext = createToolContext()
|
||||
const promptAsync = mock(async () => ({ data: {} }))
|
||||
const deps = createDependencies({
|
||||
createOrGetSession: mock(async () => {
|
||||
subagentSessions.add(sessionID)
|
||||
syncSubagentSessions.add(sessionID)
|
||||
return { sessionID, isNew: true }
|
||||
}),
|
||||
waitForCompletion: mock(async (createdSessionID: string) => {
|
||||
expect(createdSessionID).toBe(sessionID)
|
||||
expect(subagentSessions.has(sessionID)).toBe(true)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(true)
|
||||
throw new Error("poll exploded")
|
||||
}),
|
||||
})
|
||||
|
||||
// when
|
||||
const resultPromise = executeSync(args, toolContext, createContext(promptAsync) as never, deps)
|
||||
|
||||
// then
|
||||
let thrownError: Error | undefined
|
||||
|
||||
try {
|
||||
await resultPromise
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
thrownError = error
|
||||
} else {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
expect(thrownError?.message).toBe("poll exploded")
|
||||
expect(subagentSessions.has(sessionID)).toBe(false)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given executeSync reuses an existing session", () => {
|
||||
test("#when execution completes successfully #then the reused session is tracked in both Sets", async () => {
|
||||
// given
|
||||
const sessionID = "ses-reused"
|
||||
const args = { ...createArgs(), session_id: sessionID }
|
||||
const toolContext = createToolContext()
|
||||
const promptAsync = mock(async () => ({ data: {} }))
|
||||
const deps = createDependencies({
|
||||
createOrGetSession: mock(async () => ({ sessionID, isNew: false })),
|
||||
waitForCompletion: mock(async (createdSessionID: string) => {
|
||||
expect(createdSessionID).toBe(sessionID)
|
||||
expect(subagentSessions.has(sessionID)).toBe(true)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(true)
|
||||
}),
|
||||
})
|
||||
|
||||
expect(subagentSessions.has(sessionID)).toBe(false)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(false)
|
||||
|
||||
// when
|
||||
const result = await executeSync(args, toolContext, createContext(promptAsync) as never, deps)
|
||||
|
||||
// then
|
||||
expect(result).toContain(`session_id: ${sessionID}`)
|
||||
expect(subagentSessions.has(sessionID)).toBe(true)
|
||||
expect(syncSubagentSessions.has(sessionID)).toBe(true)
|
||||
})
|
||||
|
||||
test("#when execution applies a fallback chain #then it clears that chain in finally", async () => {
|
||||
// given
|
||||
const sessionID = "ses-reused-fallback"
|
||||
const args = { ...createArgs(), session_id: sessionID }
|
||||
const toolContext = createToolContext()
|
||||
const promptAsync = mock(async () => ({ data: {} }))
|
||||
const clearSessionFallbackChain = mock(() => {})
|
||||
const deps = createDependencies({
|
||||
createOrGetSession: mock(async () => ({ sessionID, isNew: false })),
|
||||
clearSessionFallbackChain,
|
||||
})
|
||||
const fallbackChain = [{ providers: ["openai"], model: "gpt-5.4" }]
|
||||
|
||||
// when
|
||||
await executeSync(args, toolContext, createContext(promptAsync) as never, deps, fallbackChain)
|
||||
|
||||
// then
|
||||
expect(clearSessionFallbackChain).toHaveBeenCalledWith(sessionID)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -24,6 +24,7 @@ type Dependencies = {
|
||||
waitForCompletion: ReturnType<typeof mock>
|
||||
processMessages: ReturnType<typeof mock>
|
||||
setSessionFallbackChain: ReturnType<typeof mock>
|
||||
clearSessionFallbackChain: ReturnType<typeof mock>
|
||||
}
|
||||
|
||||
async function importExecuteSync(): Promise<ExecuteSync> {
|
||||
@@ -37,6 +38,7 @@ function createDependencies(overrides?: Partial<Dependencies>): Dependencies {
|
||||
waitForCompletion: mock(async () => {}),
|
||||
processMessages: mock(async () => "agent response"),
|
||||
setSessionFallbackChain: mock(() => {}),
|
||||
clearSessionFallbackChain: mock(() => {}),
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
@@ -259,6 +261,7 @@ describe("executeSync", () => {
|
||||
waitForCompletion: mock(async () => {}),
|
||||
processMessages: mock(async () => "agent response"),
|
||||
setSessionFallbackChain: mock(() => {}),
|
||||
clearSessionFallbackChain: mock(() => {}),
|
||||
}
|
||||
|
||||
const spawnReservation = {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import type { CallOmoAgentArgs } from "./types"
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import { log } from "../../shared"
|
||||
import { subagentSessions, syncSubagentSessions } from "../../features/claude-code-session-state"
|
||||
import { clearSessionFallbackChain, setSessionFallbackChain } from "../../hooks/model-fallback/hook"
|
||||
import { getAgentToolRestrictions, log } from "../../shared"
|
||||
import type { FallbackEntry } from "../../shared/model-requirements"
|
||||
import { getAgentToolRestrictions } from "../../shared"
|
||||
import { setSessionFallbackChain } from "../../hooks/model-fallback/hook"
|
||||
import { createOrGetSession } from "./session-creator"
|
||||
import { waitForCompletion } from "./completion-poller"
|
||||
import { processMessages } from "./message-processor"
|
||||
import { createOrGetSession } from "./session-creator"
|
||||
|
||||
type SessionWithPromptAsync = {
|
||||
promptAsync: (opts: { path: { id: string }; body: Record<string, unknown> }) => Promise<unknown>
|
||||
@@ -17,6 +17,7 @@ type ExecuteSyncDeps = {
|
||||
waitForCompletion: typeof waitForCompletion
|
||||
processMessages: typeof processMessages
|
||||
setSessionFallbackChain: typeof setSessionFallbackChain
|
||||
clearSessionFallbackChain: typeof clearSessionFallbackChain
|
||||
}
|
||||
|
||||
type SpawnReservation = {
|
||||
@@ -29,6 +30,7 @@ const defaultDeps: ExecuteSyncDeps = {
|
||||
waitForCompletion,
|
||||
processMessages,
|
||||
setSessionFallbackChain,
|
||||
clearSessionFallbackChain,
|
||||
}
|
||||
|
||||
export async function executeSync(
|
||||
@@ -46,10 +48,15 @@ export async function executeSync(
|
||||
spawnReservation?: SpawnReservation,
|
||||
): Promise<string> {
|
||||
let sessionID: string | undefined
|
||||
let createdSessionForExecution = false
|
||||
let appliedFallbackChain = false
|
||||
|
||||
try {
|
||||
const session = await deps.createOrGetSession(args, toolContext, ctx)
|
||||
sessionID = session.sessionID
|
||||
createdSessionForExecution = session.isNew
|
||||
subagentSessions.add(sessionID)
|
||||
syncSubagentSessions.add(sessionID)
|
||||
|
||||
if (session.isNew) {
|
||||
spawnReservation?.commit()
|
||||
@@ -57,12 +64,15 @@ export async function executeSync(
|
||||
|
||||
if (fallbackChain && fallbackChain.length > 0) {
|
||||
deps.setSessionFallbackChain(sessionID, fallbackChain)
|
||||
appliedFallbackChain = true
|
||||
}
|
||||
|
||||
await toolContext.metadata?.({
|
||||
title: args.description,
|
||||
metadata: { sessionId: sessionID },
|
||||
})
|
||||
await Promise.resolve(
|
||||
toolContext.metadata?.({
|
||||
title: args.description,
|
||||
metadata: { sessionId: sessionID },
|
||||
})
|
||||
)
|
||||
|
||||
log(`[call_omo_agent] Sending prompt to session ${sessionID}`)
|
||||
log(`[call_omo_agent] Prompt text:`, args.prompt.substring(0, 100))
|
||||
@@ -93,12 +103,18 @@ export async function executeSync(
|
||||
|
||||
const responseText = await deps.processMessages(sessionID, ctx)
|
||||
|
||||
const output =
|
||||
responseText + "\n\n" + ["<task_metadata>", `session_id: ${sessionID}`, "</task_metadata>"].join("\n")
|
||||
|
||||
return output
|
||||
return responseText + "\n\n" + ["<task_metadata>", `session_id: ${sessionID}`, "</task_metadata>"].join("\n")
|
||||
} catch (error) {
|
||||
spawnReservation?.rollback()
|
||||
throw error
|
||||
} finally {
|
||||
if (sessionID && appliedFallbackChain) {
|
||||
deps.clearSessionFallbackChain(sessionID)
|
||||
}
|
||||
|
||||
if (sessionID && createdSessionForExecution) {
|
||||
subagentSessions.delete(sessionID)
|
||||
syncSubagentSessions.delete(sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user