fix: address review-work round 2 findings
- MCP teardown race: add shutdownGeneration counter to prevent in-flight connections from resurrecting after disconnectAll - MCP multi-key disconnect race: replace disconnectedSessions Set with generation-based Map to track per-session disconnect events - MCP clients: check shutdownGeneration in stdio/http client creators before inserting into state.clients - BackgroundManager: call clearTaskHistoryWhenParentTasksGone after timer-based task removal in scheduleTaskRemoval and notifyParentSession - BackgroundManager: clean completedTaskSummaries when parent has no remaining tasks - Plugin dispose: remove duplicate tmuxSessionManager.cleanup call since BackgroundManager.shutdown already handles it via onShutdown
This commit is contained in:
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1149,6 +1149,7 @@ export class BackgroundManager {
|
||||
if (!parentSessionID) return
|
||||
if (this.getTasksByParentSession(parentSessionID).length > 0) return
|
||||
this.taskHistory.clearSession(parentSessionID)
|
||||
this.completedTaskSummaries.delete(parentSessionID)
|
||||
}
|
||||
|
||||
private scheduleTaskRemoval(taskId: string): void {
|
||||
@@ -1170,6 +1171,7 @@ export class BackgroundManager {
|
||||
SessionCategoryRegistry.remove(task.sessionID)
|
||||
}
|
||||
log("[background-agent] Removed completed task from memory:", taskId)
|
||||
this.clearTaskHistoryWhenParentTasksGone(task?.parentSessionID)
|
||||
}
|
||||
}, TASK_CLEANUP_DELAY_MS)
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ export async function disconnectSession(state: SkillMcpManagerState, sessionID:
|
||||
}
|
||||
}
|
||||
if (hasPendingForSession) {
|
||||
state.disconnectedSessions.add(sessionID)
|
||||
state.disconnectedSessions.set(sessionID, (state.disconnectedSessions.get(sessionID) ?? 0) + 1)
|
||||
}
|
||||
const keysToRemove: string[] = []
|
||||
|
||||
@@ -125,6 +125,7 @@ export async function disconnectSession(state: SkillMcpManagerState, sessionID:
|
||||
}
|
||||
|
||||
export async function disconnectAll(state: SkillMcpManagerState): Promise<void> {
|
||||
state.shutdownGeneration++
|
||||
stopCleanupTimer(state)
|
||||
unregisterProcessCleanup(state)
|
||||
|
||||
|
||||
@@ -73,12 +73,13 @@ function createState(): SkillMcpManagerState {
|
||||
const state: SkillMcpManagerState = {
|
||||
clients: new Map(),
|
||||
pendingConnections: new Map(),
|
||||
disconnectedSessions: new Set(),
|
||||
disconnectedSessions: new Map(),
|
||||
authProviders: new Map(),
|
||||
cleanupRegistered: false,
|
||||
cleanupInterval: null,
|
||||
cleanupHandlers: [],
|
||||
idleTimeoutMs: 5 * 60 * 1000,
|
||||
shutdownGeneration: 0,
|
||||
}
|
||||
|
||||
trackedStates.push(state)
|
||||
@@ -145,11 +146,11 @@ describe("getOrCreateClient disconnect race", () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-a")
|
||||
const clientKey = createClientKey(info)
|
||||
state.disconnectedSessions.add(info.sessionID)
|
||||
state.disconnectedSessions.set(info.sessionID, 1)
|
||||
|
||||
const client = await getOrCreateClient({ state, clientKey, info, config: stdioConfig })
|
||||
|
||||
expect(state.disconnectedSessions.has(info.sessionID)).toBe(false)
|
||||
expect(state.disconnectedSessions.has(info.sessionID)).toBe(true)
|
||||
expect(state.clients.get(clientKey)?.client).toBe(client)
|
||||
expect(createdClients[0]?.close).not.toHaveBeenCalled()
|
||||
})
|
||||
@@ -163,3 +164,51 @@ describe("getOrCreateClient disconnect race", () => {
|
||||
expect(state.clients.size).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe("getOrCreateClient disconnectAll race", () => {
|
||||
it("#given pending connection #when disconnectAll() is called before connection completes #then client is not added to state.clients", async () => {
|
||||
const state = createState()
|
||||
const info = createClientInfo("session-a")
|
||||
const clientKey = createClientKey(info)
|
||||
const pendingConnect = createDeferred<void>()
|
||||
pendingConnects.push(pendingConnect)
|
||||
|
||||
const clientPromise = getOrCreateClient({ state, clientKey, info, config: stdioConfig })
|
||||
expect(state.pendingConnections.has(clientKey)).toBe(true)
|
||||
|
||||
await disconnectAll(state)
|
||||
pendingConnect.resolve(undefined)
|
||||
|
||||
await expect(clientPromise).rejects.toThrow(/connection completed after shutdown/)
|
||||
expect(state.clients.has(clientKey)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("getOrCreateClient multi-key disconnect race", () => {
|
||||
it("#given 2 pending connections for session A #when disconnectSession(A) before both complete #then both old connections are rejected", async () => {
|
||||
const state = createState()
|
||||
const infoKey1 = createClientInfo("session-a")
|
||||
const infoKey2 = { ...createClientInfo("session-a"), serverName: "server-2" }
|
||||
const clientKey1 = createClientKey(infoKey1)
|
||||
const clientKey2 = `${infoKey2.sessionID}:${infoKey2.skillName}:${infoKey2.serverName}`
|
||||
const pendingConnect1 = createDeferred<void>()
|
||||
const pendingConnect2 = createDeferred<void>()
|
||||
pendingConnects.push(pendingConnect1)
|
||||
pendingConnects.push(pendingConnect2)
|
||||
|
||||
const promise1 = getOrCreateClient({ state, clientKey: clientKey1, info: infoKey1, config: stdioConfig })
|
||||
const promise2 = getOrCreateClient({ state, clientKey: clientKey2, info: infoKey2, config: stdioConfig })
|
||||
expect(state.pendingConnections.size).toBe(2)
|
||||
|
||||
await disconnectSession(state, "session-a")
|
||||
|
||||
pendingConnect1.resolve(undefined)
|
||||
await expect(promise1).rejects.toThrow(/disconnected during MCP connection setup/)
|
||||
|
||||
pendingConnect2.resolve(undefined)
|
||||
await expect(promise2).rejects.toThrow(/disconnected during MCP connection setup/)
|
||||
|
||||
expect(state.clients.has(clientKey1)).toBe(false)
|
||||
expect(state.clients.has(clientKey2)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -14,7 +14,6 @@ export async function getOrCreateClient(params: {
|
||||
config: ClaudeCodeMcpServer
|
||||
}): Promise<Client> {
|
||||
const { state, clientKey, info, config } = params
|
||||
state.disconnectedSessions.delete(info.sessionID)
|
||||
|
||||
const existing = state.clients.get(clientKey)
|
||||
if (existing) {
|
||||
@@ -31,6 +30,9 @@ export async function getOrCreateClient(params: {
|
||||
const expandedConfig = expandEnvVarsInObject(config)
|
||||
let currentConnectionPromise!: Promise<Client>
|
||||
currentConnectionPromise = (async () => {
|
||||
const disconnectGenAtStart = state.disconnectedSessions.get(info.sessionID) ?? 0
|
||||
const shutdownGenAtStart = state.shutdownGeneration
|
||||
|
||||
const client = await createClient({ state, clientKey, info, config: expandedConfig })
|
||||
|
||||
const isStale = state.pendingConnections.has(clientKey) && state.pendingConnections.get(clientKey) !== currentConnectionPromise
|
||||
@@ -39,7 +41,13 @@ export async function getOrCreateClient(params: {
|
||||
throw new Error(`Connection for "${info.sessionID}" was superseded by a newer connection attempt.`)
|
||||
}
|
||||
|
||||
if (state.disconnectedSessions.has(info.sessionID)) {
|
||||
if (state.shutdownGeneration !== shutdownGenAtStart) {
|
||||
try { await client.close() } catch {}
|
||||
throw new Error(`Shutdown occurred during MCP connection for "${info.sessionID}"`)
|
||||
}
|
||||
|
||||
const currentDisconnectGen = state.disconnectedSessions.get(info.sessionID) ?? 0
|
||||
if (currentDisconnectGen > disconnectGenAtStart) {
|
||||
await forceReconnect(state, clientKey)
|
||||
throw new Error(`Session "${info.sessionID}" disconnected during MCP connection setup.`)
|
||||
}
|
||||
|
||||
@@ -18,12 +18,13 @@ function createState(): SkillMcpManagerState {
|
||||
const state: SkillMcpManagerState = {
|
||||
clients: new Map(),
|
||||
pendingConnections: new Map(),
|
||||
disconnectedSessions: new Set(),
|
||||
disconnectedSessions: new Map(),
|
||||
authProviders: new Map(),
|
||||
cleanupRegistered: false,
|
||||
cleanupInterval: null,
|
||||
cleanupHandlers: [],
|
||||
idleTimeoutMs: 5 * 60 * 1000,
|
||||
shutdownGeneration: 0,
|
||||
}
|
||||
|
||||
trackedStates.push(state)
|
||||
|
||||
@@ -24,6 +24,7 @@ function redactUrl(urlStr: string): string {
|
||||
|
||||
export async function createHttpClient(params: SkillMcpClientConnectionParams): Promise<Client> {
|
||||
const { state, clientKey, info, config } = params
|
||||
const shutdownGenAtStart = state.shutdownGeneration
|
||||
|
||||
if (!config.url) {
|
||||
throw new Error(`MCP server "${info.serverName}" is configured for HTTP but missing 'url' field.`)
|
||||
@@ -72,6 +73,12 @@ export async function createHttpClient(params: SkillMcpClientConnectionParams):
|
||||
)
|
||||
}
|
||||
|
||||
if (state.shutdownGeneration !== shutdownGenAtStart) {
|
||||
try { await client.close() } catch {}
|
||||
try { await transport.close() } catch {}
|
||||
throw new Error(`MCP server "${info.serverName}" connection completed after shutdown`)
|
||||
}
|
||||
|
||||
const managedClient = {
|
||||
client,
|
||||
transport,
|
||||
|
||||
@@ -10,12 +10,13 @@ export class SkillMcpManager {
|
||||
private readonly state: SkillMcpManagerState = {
|
||||
clients: new Map(),
|
||||
pendingConnections: new Map(),
|
||||
disconnectedSessions: new Set(),
|
||||
disconnectedSessions: new Map(),
|
||||
authProviders: new Map(),
|
||||
cleanupRegistered: false,
|
||||
cleanupInterval: null,
|
||||
cleanupHandlers: [],
|
||||
idleTimeoutMs: 5 * 60 * 1000,
|
||||
shutdownGeneration: 0,
|
||||
}
|
||||
|
||||
private getClientKey(info: SkillMcpClientInfo): string {
|
||||
|
||||
@@ -14,6 +14,7 @@ function getStdioCommand(config: ClaudeCodeMcpServer, serverName: string): strin
|
||||
|
||||
export async function createStdioClient(params: SkillMcpClientConnectionParams): Promise<Client> {
|
||||
const { state, clientKey, info, config } = params
|
||||
const shutdownGenAtStart = state.shutdownGeneration
|
||||
|
||||
const command = getStdioCommand(config, info.serverName)
|
||||
const args = config.args ?? []
|
||||
@@ -55,6 +56,12 @@ export async function createStdioClient(params: SkillMcpClientConnectionParams):
|
||||
)
|
||||
}
|
||||
|
||||
if (state.shutdownGeneration !== shutdownGenAtStart) {
|
||||
try { await client.close() } catch {}
|
||||
try { await transport.close() } catch {}
|
||||
throw new Error(`MCP server "${info.serverName}" connection completed after shutdown`)
|
||||
}
|
||||
|
||||
const managedClient = {
|
||||
client,
|
||||
transport,
|
||||
|
||||
@@ -51,12 +51,13 @@ export interface ProcessCleanupHandler {
|
||||
export interface SkillMcpManagerState {
|
||||
clients: Map<string, ManagedClient>
|
||||
pendingConnections: Map<string, Promise<Client>>
|
||||
disconnectedSessions: Set<string>
|
||||
disconnectedSessions: Map<string, number>
|
||||
authProviders: Map<string, McpOAuthProvider>
|
||||
cleanupRegistered: boolean
|
||||
cleanupInterval: ReturnType<typeof setInterval> | null
|
||||
cleanupHandlers: ProcessCleanupHandler[]
|
||||
idleTimeoutMs: number
|
||||
shutdownGeneration: number
|
||||
}
|
||||
|
||||
export interface SkillMcpClientConnectionParams {
|
||||
|
||||
@@ -74,7 +74,6 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager: managers.backgroundManager,
|
||||
skillMcpManager: managers.skillMcpManager,
|
||||
tmuxSessionManager: managers.tmuxSessionManager,
|
||||
disposeHooks: hooks.disposeHooks,
|
||||
})
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ describe("createPluginDispose", () => {
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
tmuxSessionManager: { cleanup: async (): Promise<void> => {} },
|
||||
disposeHooks: (): void => {},
|
||||
})
|
||||
|
||||
@@ -39,7 +38,6 @@ describe("createPluginDispose", () => {
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
tmuxSessionManager: { cleanup: async (): Promise<void> => {} },
|
||||
disposeHooks: (): void => {},
|
||||
})
|
||||
|
||||
@@ -71,7 +69,6 @@ describe("createPluginDispose", () => {
|
||||
skillMcpManager: {
|
||||
disconnectAll: async (): Promise<void> => {},
|
||||
},
|
||||
tmuxSessionManager: { cleanup: async (): Promise<void> => {} },
|
||||
disposeHooks: (): void => {
|
||||
disposeCreatedHooks({
|
||||
runtimeFallback,
|
||||
@@ -107,7 +104,6 @@ describe("createPluginDispose", () => {
|
||||
const dispose = createPluginDispose({
|
||||
backgroundManager,
|
||||
skillMcpManager,
|
||||
tmuxSessionManager: { cleanup: async (): Promise<void> => {} },
|
||||
disposeHooks: disposeHooks.run,
|
||||
})
|
||||
|
||||
|
||||
@@ -7,12 +7,9 @@ export function createPluginDispose(args: {
|
||||
skillMcpManager: {
|
||||
disconnectAll: () => Promise<void>
|
||||
}
|
||||
tmuxSessionManager: {
|
||||
cleanup: () => Promise<void>
|
||||
}
|
||||
disposeHooks: () => void
|
||||
}): PluginDispose {
|
||||
const { backgroundManager, skillMcpManager, tmuxSessionManager, disposeHooks } = args
|
||||
const { backgroundManager, skillMcpManager, disposeHooks } = args
|
||||
let disposePromise: Promise<void> | null = null
|
||||
|
||||
return async (): Promise<void> => {
|
||||
@@ -23,10 +20,7 @@ export function createPluginDispose(args: {
|
||||
|
||||
disposePromise = (async (): Promise<void> => {
|
||||
backgroundManager.shutdown()
|
||||
await Promise.all([
|
||||
skillMcpManager.disconnectAll(),
|
||||
tmuxSessionManager.cleanup(),
|
||||
])
|
||||
await skillMcpManager.disconnectAll()
|
||||
disposeHooks()
|
||||
})()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user