fix(run): inherit main-session tool permissions for continuation prompts
This commit is contained in:
@@ -6,7 +6,14 @@ import type {
|
||||
ResumeInput,
|
||||
} from "./types"
|
||||
import { TaskHistory } from "./task-history"
|
||||
import { log, getAgentToolRestrictions, normalizeSDKResponse, promptWithModelSuggestionRetry } from "../../shared"
|
||||
import {
|
||||
log,
|
||||
getAgentToolRestrictions,
|
||||
normalizePromptTools,
|
||||
normalizeSDKResponse,
|
||||
promptWithModelSuggestionRetry,
|
||||
resolveInheritedPromptTools,
|
||||
} from "../../shared"
|
||||
import { setSessionTools } from "../../shared/session-tools-store"
|
||||
import { ConcurrencyManager } from "./concurrency"
|
||||
import type { BackgroundTaskConfig, TmuxConfig } from "../../config/schema"
|
||||
@@ -1246,12 +1253,19 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
|
||||
let agent: string | undefined = task.parentAgent
|
||||
let model: { providerID: string; modelID: string } | undefined
|
||||
let tools: Record<string, boolean> | undefined = task.parentTools
|
||||
|
||||
if (this.enableParentSessionNotifications) {
|
||||
try {
|
||||
const messagesResp = await this.client.session.messages({ path: { id: task.parentSessionID } })
|
||||
const messages = normalizeSDKResponse(messagesResp, [] as Array<{
|
||||
info?: { agent?: string; model?: { providerID: string; modelID: string }; modelID?: string; providerID?: string }
|
||||
info?: {
|
||||
agent?: string
|
||||
model?: { providerID: string; modelID: string }
|
||||
modelID?: string
|
||||
providerID?: string
|
||||
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||
}
|
||||
}>)
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const info = messages[i].info
|
||||
@@ -1261,6 +1275,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
if (info?.agent || info?.model || (info?.modelID && info?.providerID)) {
|
||||
agent = info.agent ?? task.parentAgent
|
||||
model = info.model ?? (info.providerID && info.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined)
|
||||
tools = normalizePromptTools(info.tools) ?? tools
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -1277,8 +1292,11 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
model = currentMessage?.model?.providerID && currentMessage?.model?.modelID
|
||||
? { providerID: currentMessage.model.providerID, modelID: currentMessage.model.modelID }
|
||||
: undefined
|
||||
tools = normalizePromptTools(currentMessage?.tools) ?? tools
|
||||
}
|
||||
|
||||
tools = resolveInheritedPromptTools(task.parentSessionID, tools)
|
||||
|
||||
log("[background-agent] notifyParentSession context:", {
|
||||
taskId: task.id,
|
||||
resolvedAgent: agent,
|
||||
@@ -1292,7 +1310,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
noReply: !allComplete,
|
||||
...(agent !== undefined ? { agent } : {}),
|
||||
...(model !== undefined ? { model } : {}),
|
||||
...(task.parentTools ? { tools: task.parentTools } : {}),
|
||||
...(tools ? { tools } : {}),
|
||||
parts: [{ type: "text", text: notification }],
|
||||
},
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { OpencodeClient } from "./constants"
|
||||
import type { BackgroundTask } from "./types"
|
||||
import { findNearestMessageWithFields } from "../hook-message-injector"
|
||||
import { getMessageDir } from "../../shared"
|
||||
import { normalizePromptTools, resolveInheritedPromptTools } from "../../shared"
|
||||
|
||||
type AgentModel = { providerID: string; modelID: string }
|
||||
|
||||
@@ -12,6 +13,7 @@ function isObject(value: unknown): value is Record<string, unknown> {
|
||||
function extractAgentAndModelFromMessage(message: unknown): {
|
||||
agent?: string
|
||||
model?: AgentModel
|
||||
tools?: Record<string, boolean>
|
||||
} {
|
||||
if (!isObject(message)) return {}
|
||||
const info = message["info"]
|
||||
@@ -19,31 +21,33 @@ function extractAgentAndModelFromMessage(message: unknown): {
|
||||
|
||||
const agent = typeof info["agent"] === "string" ? info["agent"] : undefined
|
||||
const modelObj = info["model"]
|
||||
const tools = normalizePromptTools(isObject(info["tools"]) ? info["tools"] as Record<string, unknown> as Record<string, boolean | "allow" | "deny" | "ask"> : undefined)
|
||||
if (isObject(modelObj)) {
|
||||
const providerID = modelObj["providerID"]
|
||||
const modelID = modelObj["modelID"]
|
||||
if (typeof providerID === "string" && typeof modelID === "string") {
|
||||
return { agent, model: { providerID, modelID } }
|
||||
return { agent, model: { providerID, modelID }, tools }
|
||||
}
|
||||
}
|
||||
|
||||
const providerID = info["providerID"]
|
||||
const modelID = info["modelID"]
|
||||
if (typeof providerID === "string" && typeof modelID === "string") {
|
||||
return { agent, model: { providerID, modelID } }
|
||||
return { agent, model: { providerID, modelID }, tools }
|
||||
}
|
||||
|
||||
return { agent }
|
||||
return { agent, tools }
|
||||
}
|
||||
|
||||
export async function resolveParentSessionAgentAndModel(input: {
|
||||
client: OpencodeClient
|
||||
task: BackgroundTask
|
||||
}): Promise<{ agent?: string; model?: AgentModel }> {
|
||||
}): Promise<{ agent?: string; model?: AgentModel; tools?: Record<string, boolean> }> {
|
||||
const { client, task } = input
|
||||
|
||||
let agent: string | undefined = task.parentAgent
|
||||
let model: AgentModel | undefined
|
||||
let tools: Record<string, boolean> | undefined = task.parentTools
|
||||
|
||||
try {
|
||||
const messagesResp = await client.session.messages({
|
||||
@@ -55,9 +59,10 @@ export async function resolveParentSessionAgentAndModel(input: {
|
||||
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const extracted = extractAgentAndModelFromMessage(messages[i])
|
||||
if (extracted.agent || extracted.model) {
|
||||
if (extracted.agent || extracted.model || extracted.tools) {
|
||||
agent = extracted.agent ?? task.parentAgent
|
||||
model = extracted.model
|
||||
tools = extracted.tools ?? tools
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -69,7 +74,8 @@ export async function resolveParentSessionAgentAndModel(input: {
|
||||
currentMessage?.model?.providerID && currentMessage?.model?.modelID
|
||||
? { providerID: currentMessage.model.providerID, modelID: currentMessage.model.modelID }
|
||||
: undefined
|
||||
tools = normalizePromptTools(currentMessage?.tools) ?? tools
|
||||
}
|
||||
|
||||
return { agent, model }
|
||||
return { agent, model, tools: resolveInheritedPromptTools(task.parentSessionID, tools) }
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ export async function notifyParentSession(
|
||||
completedTasks,
|
||||
})
|
||||
|
||||
const { agent, model } = await resolveParentSessionAgentAndModel({ client, task })
|
||||
const { agent, model, tools } = await resolveParentSessionAgentAndModel({ client, task })
|
||||
|
||||
log("[background-agent] notifyParentSession context:", {
|
||||
taskId: task.id,
|
||||
@@ -71,7 +71,7 @@ export async function notifyParentSession(
|
||||
noReply: !allComplete,
|
||||
...(agent !== undefined ? { agent } : {}),
|
||||
...(model !== undefined ? { model } : {}),
|
||||
...(task.parentTools ? { tools: task.parentTools } : {}),
|
||||
...(tools ? { tools } : {}),
|
||||
parts: [{ type: "text", text: notification }],
|
||||
},
|
||||
})
|
||||
|
||||
@@ -5,6 +5,7 @@ import type { Client } from "./client"
|
||||
import { clearSessionState } from "./state"
|
||||
import { formatBytes } from "./message-builder"
|
||||
import { log } from "../../shared/logger"
|
||||
import { resolveInheritedPromptTools } from "../../shared"
|
||||
|
||||
export async function runAggressiveTruncationStrategy(params: {
|
||||
sessionID: string
|
||||
@@ -61,9 +62,13 @@ export async function runAggressiveTruncationStrategy(params: {
|
||||
clearSessionState(params.autoCompactState, params.sessionID)
|
||||
setTimeout(async () => {
|
||||
try {
|
||||
const inheritedTools = resolveInheritedPromptTools(params.sessionID)
|
||||
await params.client.session.promptAsync({
|
||||
path: { id: params.sessionID },
|
||||
body: { auto: true } as never,
|
||||
body: {
|
||||
auto: true,
|
||||
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||
} as never,
|
||||
query: { directory: params.directory },
|
||||
})
|
||||
} catch {}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import type { BackgroundManager } from "../../features/background-agent"
|
||||
import { log } from "../../shared/logger"
|
||||
import { resolveInheritedPromptTools } from "../../shared"
|
||||
import { HOOK_NAME } from "./hook-name"
|
||||
import { BOULDER_CONTINUATION_PROMPT } from "./system-reminder-templates"
|
||||
import { resolveRecentModelForSession } from "./recent-model-resolver"
|
||||
import { resolveRecentPromptContextForSession } from "./recent-model-resolver"
|
||||
import type { SessionState } from "./types"
|
||||
|
||||
export async function injectBoulderContinuation(input: {
|
||||
@@ -43,13 +44,15 @@ export async function injectBoulderContinuation(input: {
|
||||
try {
|
||||
log(`[${HOOK_NAME}] Injecting boulder continuation`, { sessionID, planName, remaining })
|
||||
|
||||
const model = await resolveRecentModelForSession(ctx, sessionID)
|
||||
const promptContext = await resolveRecentPromptContextForSession(ctx, sessionID)
|
||||
const inheritedTools = resolveInheritedPromptTools(sessionID, promptContext.tools)
|
||||
|
||||
await ctx.client.session.promptAsync({
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
agent: agent ?? "atlas",
|
||||
...(model !== undefined ? { model } : {}),
|
||||
...(promptContext.model !== undefined ? { model: promptContext.model } : {}),
|
||||
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||
parts: [{ type: "text", text: prompt }],
|
||||
},
|
||||
query: { directory: ctx.directory },
|
||||
|
||||
@@ -3,28 +3,39 @@ import {
|
||||
findNearestMessageWithFields,
|
||||
findNearestMessageWithFieldsFromSDK,
|
||||
} from "../../features/hook-message-injector"
|
||||
import { getMessageDir, isSqliteBackend, normalizeSDKResponse } from "../../shared"
|
||||
import { getMessageDir, isSqliteBackend, normalizePromptTools, normalizeSDKResponse } from "../../shared"
|
||||
import type { ModelInfo } from "./types"
|
||||
|
||||
export async function resolveRecentModelForSession(
|
||||
type PromptContext = {
|
||||
model?: ModelInfo
|
||||
tools?: Record<string, boolean>
|
||||
}
|
||||
|
||||
export async function resolveRecentPromptContextForSession(
|
||||
ctx: PluginInput,
|
||||
sessionID: string
|
||||
): Promise<ModelInfo | undefined> {
|
||||
): Promise<PromptContext> {
|
||||
try {
|
||||
const messagesResp = await ctx.client.session.messages({ path: { id: sessionID } })
|
||||
const messages = normalizeSDKResponse(messagesResp, [] as Array<{
|
||||
info?: { model?: ModelInfo; modelID?: string; providerID?: string }
|
||||
info?: {
|
||||
model?: ModelInfo
|
||||
modelID?: string
|
||||
providerID?: string
|
||||
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||
}
|
||||
}>)
|
||||
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const info = messages[i].info
|
||||
const model = info?.model
|
||||
const tools = normalizePromptTools(info?.tools)
|
||||
if (model?.providerID && model?.modelID) {
|
||||
return { providerID: model.providerID, modelID: model.modelID }
|
||||
return { model: { providerID: model.providerID, modelID: model.modelID }, tools }
|
||||
}
|
||||
|
||||
if (info?.providerID && info?.modelID) {
|
||||
return { providerID: info.providerID, modelID: info.modelID }
|
||||
return { model: { providerID: info.providerID, modelID: info.modelID }, tools }
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
@@ -39,8 +50,17 @@ export async function resolveRecentModelForSession(
|
||||
currentMessage = messageDir ? findNearestMessageWithFields(messageDir) : null
|
||||
}
|
||||
const model = currentMessage?.model
|
||||
const tools = normalizePromptTools(currentMessage?.tools)
|
||||
if (!model?.providerID || !model?.modelID) {
|
||||
return undefined
|
||||
return { tools }
|
||||
}
|
||||
return { providerID: model.providerID, modelID: model.modelID }
|
||||
return { model: { providerID: model.providerID, modelID: model.modelID }, tools }
|
||||
}
|
||||
|
||||
export async function resolveRecentModelForSession(
|
||||
ctx: PluginInput,
|
||||
sessionID: string
|
||||
): Promise<ModelInfo | undefined> {
|
||||
const context = await resolveRecentPromptContextForSession(ctx, sessionID)
|
||||
return context.model
|
||||
}
|
||||
|
||||
@@ -3,13 +3,14 @@ import { log } from "../../shared/logger"
|
||||
import { findNearestMessageWithFields } from "../../features/hook-message-injector"
|
||||
import { getMessageDir } from "./message-storage-directory"
|
||||
import { withTimeout } from "./with-timeout"
|
||||
import { normalizeSDKResponse } from "../../shared"
|
||||
import { normalizeSDKResponse, resolveInheritedPromptTools } from "../../shared"
|
||||
|
||||
type MessageInfo = {
|
||||
agent?: string
|
||||
model?: { providerID: string; modelID: string }
|
||||
modelID?: string
|
||||
providerID?: string
|
||||
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||
}
|
||||
|
||||
export async function injectContinuationPrompt(
|
||||
@@ -18,6 +19,7 @@ export async function injectContinuationPrompt(
|
||||
): Promise<void> {
|
||||
let agent: string | undefined
|
||||
let model: { providerID: string; modelID: string } | undefined
|
||||
let tools: Record<string, boolean | "allow" | "deny" | "ask"> | undefined
|
||||
|
||||
try {
|
||||
const messagesResp = await withTimeout(
|
||||
@@ -36,6 +38,7 @@ export async function injectContinuationPrompt(
|
||||
(info.providerID && info.modelID
|
||||
? { providerID: info.providerID, modelID: info.modelID }
|
||||
: undefined)
|
||||
tools = info.tools
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -50,13 +53,17 @@ export async function injectContinuationPrompt(
|
||||
modelID: currentMessage.model.modelID,
|
||||
}
|
||||
: undefined
|
||||
tools = currentMessage?.tools
|
||||
}
|
||||
|
||||
const inheritedTools = resolveInheritedPromptTools(options.sessionID, tools)
|
||||
|
||||
await ctx.client.session.promptAsync({
|
||||
path: { id: options.sessionID },
|
||||
body: {
|
||||
...(agent !== undefined ? { agent } : {}),
|
||||
...(model !== undefined ? { model } : {}),
|
||||
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||
parts: [{ type: "text", text: options.prompt }],
|
||||
},
|
||||
query: { directory: options.directory },
|
||||
|
||||
48
src/hooks/session-recovery/resume.test.ts
Normal file
48
src/hooks/session-recovery/resume.test.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
declare const require: (name: string) => any
|
||||
const { describe, expect, test } = require("bun:test")
|
||||
import { extractResumeConfig, resumeSession } from "./resume"
|
||||
import type { MessageData } from "./types"
|
||||
|
||||
describe("session-recovery resume", () => {
|
||||
test("extractResumeConfig carries tools from last user message", () => {
|
||||
// given
|
||||
const userMessage: MessageData = {
|
||||
info: {
|
||||
agent: "Hephaestus",
|
||||
model: { providerID: "openai", modelID: "gpt-5.3-codex" },
|
||||
tools: { question: false, bash: true },
|
||||
},
|
||||
}
|
||||
|
||||
// when
|
||||
const config = extractResumeConfig(userMessage, "ses_resume_tools")
|
||||
|
||||
// then
|
||||
expect(config.tools).toEqual({ question: false, bash: true })
|
||||
})
|
||||
|
||||
test("resumeSession sends inherited tools with continuation prompt", async () => {
|
||||
// given
|
||||
let promptBody: Record<string, unknown> | undefined
|
||||
const client = {
|
||||
session: {
|
||||
promptAsync: async (input: { body: Record<string, unknown> }) => {
|
||||
promptBody = input.body
|
||||
return {}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// when
|
||||
const ok = await resumeSession(client as never, {
|
||||
sessionID: "ses_resume_prompt",
|
||||
agent: "Hephaestus",
|
||||
model: { providerID: "openai", modelID: "gpt-5.3-codex" },
|
||||
tools: { question: false, bash: true },
|
||||
})
|
||||
|
||||
// then
|
||||
expect(ok).toBe(true)
|
||||
expect(promptBody?.tools).toEqual({ question: false, bash: true })
|
||||
})
|
||||
})
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { createOpencodeClient } from "@opencode-ai/sdk"
|
||||
import type { MessageData, ResumeConfig } from "./types"
|
||||
import { resolveInheritedPromptTools } from "../../shared"
|
||||
|
||||
const RECOVERY_RESUME_TEXT = "[session recovered - continuing previous task]"
|
||||
|
||||
@@ -19,17 +20,20 @@ export function extractResumeConfig(userMessage: MessageData | undefined, sessio
|
||||
sessionID,
|
||||
agent: userMessage?.info?.agent,
|
||||
model: userMessage?.info?.model,
|
||||
tools: userMessage?.info?.tools,
|
||||
}
|
||||
}
|
||||
|
||||
export async function resumeSession(client: Client, config: ResumeConfig): Promise<boolean> {
|
||||
try {
|
||||
const inheritedTools = resolveInheritedPromptTools(config.sessionID, config.tools)
|
||||
await client.session.promptAsync({
|
||||
path: { id: config.sessionID },
|
||||
body: {
|
||||
parts: [{ type: "text", text: RECOVERY_RESUME_TEXT }],
|
||||
agent: config.agent,
|
||||
model: config.model,
|
||||
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||
},
|
||||
})
|
||||
return true
|
||||
|
||||
@@ -95,4 +95,5 @@ export interface ResumeConfig {
|
||||
providerID: string
|
||||
modelID: string
|
||||
}
|
||||
tools?: Record<string, boolean>
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
declare const require: (name: string) => any
|
||||
const { describe, expect, test } = require("bun:test")
|
||||
|
||||
import { injectContinuation } from "./continuation-injection"
|
||||
|
||||
describe("injectContinuation", () => {
|
||||
test("inherits tools from resolved message info when reinjecting", async () => {
|
||||
// given
|
||||
let capturedTools: Record<string, boolean> | undefined
|
||||
const ctx = {
|
||||
directory: "/tmp/test",
|
||||
client: {
|
||||
session: {
|
||||
todo: async () => ({ data: [{ id: "1", content: "todo", status: "pending", priority: "high" }] }),
|
||||
promptAsync: async (input: { body: { tools?: Record<string, boolean> } }) => {
|
||||
capturedTools = input.body.tools
|
||||
return {}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
const sessionStateStore = {
|
||||
getExistingState: () => ({ inFlight: false, lastInjectedAt: 0, consecutiveFailures: 0 }),
|
||||
}
|
||||
|
||||
// when
|
||||
await injectContinuation({
|
||||
ctx: ctx as never,
|
||||
sessionID: "ses_continuation_tools",
|
||||
resolvedInfo: {
|
||||
agent: "Hephaestus",
|
||||
model: { providerID: "openai", modelID: "gpt-5.3-codex" },
|
||||
tools: { question: "deny", bash: "allow" },
|
||||
},
|
||||
sessionStateStore: sessionStateStore as never,
|
||||
})
|
||||
|
||||
// then
|
||||
expect(capturedTools).toEqual({ question: false, bash: true })
|
||||
})
|
||||
})
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
|
||||
import type { BackgroundManager } from "../../features/background-agent"
|
||||
import { normalizeSDKResponse } from "../../shared"
|
||||
import { normalizeSDKResponse, resolveInheritedPromptTools } from "../../shared"
|
||||
import {
|
||||
findNearestMessageWithFields,
|
||||
findNearestMessageWithFieldsFromSDK,
|
||||
@@ -136,11 +136,14 @@ ${todoList}`
|
||||
incompleteCount: freshIncompleteCount,
|
||||
})
|
||||
|
||||
const inheritedTools = resolveInheritedPromptTools(sessionID, tools)
|
||||
|
||||
await ctx.client.session.promptAsync({
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
agent: agentName,
|
||||
...(model !== undefined ? { model } : {}),
|
||||
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||
parts: [{ type: "text", text: prompt }],
|
||||
},
|
||||
query: { directory: ctx.directory },
|
||||
|
||||
@@ -8,6 +8,7 @@ type MessageInfo = {
|
||||
model?: { providerID: string; modelID: string }
|
||||
providerID?: string
|
||||
modelID?: string
|
||||
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||
}
|
||||
|
||||
type MessagePart = {
|
||||
@@ -40,6 +41,20 @@ export function getMessageInfo(value: unknown): MessageInfo | undefined {
|
||||
model,
|
||||
providerID: typeof info.providerID === "string" ? info.providerID : undefined,
|
||||
modelID: typeof info.modelID === "string" ? info.modelID : undefined,
|
||||
tools: isRecord(info.tools)
|
||||
? Object.entries(info.tools).reduce<Record<string, boolean | "allow" | "deny" | "ask">>((acc, [key, value]) => {
|
||||
if (
|
||||
value === true ||
|
||||
value === false ||
|
||||
value === "allow" ||
|
||||
value === "deny" ||
|
||||
value === "ask"
|
||||
) {
|
||||
acc[key] = value
|
||||
}
|
||||
return acc
|
||||
}, {})
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { BackgroundManager } from "../../features/background-agent"
|
||||
import { getMainSessionID, getSessionAgent } from "../../features/claude-code-session-state"
|
||||
import { log } from "../../shared/logger"
|
||||
import { resolveInheritedPromptTools } from "../../shared"
|
||||
import {
|
||||
buildReminder,
|
||||
extractMessages,
|
||||
@@ -29,6 +30,7 @@ type BabysitterContext = {
|
||||
parts: Array<{ type: "text"; text: string }>
|
||||
agent?: string
|
||||
model?: { providerID: string; modelID: string }
|
||||
tools?: Record<string, boolean>
|
||||
}
|
||||
query?: { directory?: string }
|
||||
}) => Promise<unknown>
|
||||
@@ -38,6 +40,7 @@ type BabysitterContext = {
|
||||
parts: Array<{ type: "text"; text: string }>
|
||||
agent?: string
|
||||
model?: { providerID: string; modelID: string }
|
||||
tools?: Record<string, boolean>
|
||||
}
|
||||
query?: { directory?: string }
|
||||
}) => Promise<unknown>
|
||||
@@ -54,9 +57,10 @@ type BabysitterOptions = {
|
||||
async function resolveMainSessionTarget(
|
||||
ctx: BabysitterContext,
|
||||
sessionID: string
|
||||
): Promise<{ agent?: string; model?: { providerID: string; modelID: string } }> {
|
||||
): Promise<{ agent?: string; model?: { providerID: string; modelID: string }; tools?: Record<string, boolean> }> {
|
||||
let agent = getSessionAgent(sessionID)
|
||||
let model: { providerID: string; modelID: string } | undefined
|
||||
let tools: Record<string, boolean> | undefined
|
||||
|
||||
try {
|
||||
const messagesResp = await ctx.client.session.messages({
|
||||
@@ -68,6 +72,7 @@ async function resolveMainSessionTarget(
|
||||
if (info?.agent || info?.model || (info?.providerID && info?.modelID)) {
|
||||
agent = agent ?? info?.agent
|
||||
model = info?.model ?? (info?.providerID && info?.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined)
|
||||
tools = resolveInheritedPromptTools(sessionID, info?.tools) ?? tools
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -75,7 +80,7 @@ async function resolveMainSessionTarget(
|
||||
log(`[${HOOK_NAME}] Failed to resolve main session agent`, { sessionID, error: String(error) })
|
||||
}
|
||||
|
||||
return { agent, model }
|
||||
return { agent, model, tools: resolveInheritedPromptTools(sessionID, tools) }
|
||||
}
|
||||
|
||||
async function getThinkingSummary(ctx: BabysitterContext, sessionID: string): Promise<string | null> {
|
||||
@@ -144,7 +149,7 @@ export function createUnstableAgentBabysitterHook(ctx: BabysitterContext, option
|
||||
|
||||
const summary = task.sessionID ? await getThinkingSummary(ctx, task.sessionID) : null
|
||||
const reminder = buildReminder(task, summary, idleMs)
|
||||
const { agent, model } = await resolveMainSessionTarget(ctx, mainSessionID)
|
||||
const { agent, model, tools } = await resolveMainSessionTarget(ctx, mainSessionID)
|
||||
|
||||
try {
|
||||
await ctx.client.session.promptAsync({
|
||||
@@ -152,6 +157,7 @@ export function createUnstableAgentBabysitterHook(ctx: BabysitterContext, option
|
||||
body: {
|
||||
...(agent ? { agent } : {}),
|
||||
...(model ? { model } : {}),
|
||||
...(tools ? { tools } : {}),
|
||||
parts: [{ type: "text", text: reminder }],
|
||||
},
|
||||
query: { directory: ctx.directory },
|
||||
|
||||
@@ -56,3 +56,4 @@ export * from "./opencode-storage-paths"
|
||||
export * from "./opencode-message-dir"
|
||||
export * from "./normalize-sdk-response"
|
||||
export * from "./session-directory-resolver"
|
||||
export * from "./prompt-tools"
|
||||
|
||||
56
src/shared/prompt-tools.test.ts
Normal file
56
src/shared/prompt-tools.test.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
declare const require: (name: string) => any
|
||||
const { afterEach, describe, expect, test } = require("bun:test")
|
||||
import { clearSessionTools, setSessionTools } from "./session-tools-store"
|
||||
import { normalizePromptTools, resolveInheritedPromptTools } from "./prompt-tools"
|
||||
|
||||
describe("prompt-tools", () => {
|
||||
afterEach(() => {
|
||||
clearSessionTools()
|
||||
})
|
||||
|
||||
test("normalizes allow/deny style permissions to boolean tools", () => {
|
||||
// given
|
||||
const tools = {
|
||||
question: "deny",
|
||||
bash: "allow",
|
||||
task: "ask",
|
||||
read: true,
|
||||
edit: false,
|
||||
} as const
|
||||
|
||||
// when
|
||||
const normalized = normalizePromptTools(tools)
|
||||
|
||||
// then
|
||||
expect(normalized).toEqual({
|
||||
question: false,
|
||||
bash: true,
|
||||
task: true,
|
||||
read: true,
|
||||
edit: false,
|
||||
})
|
||||
})
|
||||
|
||||
test("prefers per-session stored tools over fallback tools", () => {
|
||||
// given
|
||||
const sessionID = "ses_prompt_tools"
|
||||
setSessionTools(sessionID, { question: false, bash: true })
|
||||
|
||||
// when
|
||||
const resolved = resolveInheritedPromptTools(sessionID, { question: true, bash: false })
|
||||
|
||||
// then
|
||||
expect(resolved).toEqual({ question: false, bash: true })
|
||||
})
|
||||
|
||||
test("uses fallback tools when no per-session tools exist", () => {
|
||||
// given
|
||||
const sessionID = "ses_fallback_only"
|
||||
|
||||
// when
|
||||
const resolved = resolveInheritedPromptTools(sessionID, { question: "deny", write: "allow" })
|
||||
|
||||
// then
|
||||
expect(resolved).toEqual({ question: false, write: true })
|
||||
})
|
||||
})
|
||||
35
src/shared/prompt-tools.ts
Normal file
35
src/shared/prompt-tools.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { getSessionTools } from "./session-tools-store"
|
||||
|
||||
export type PromptToolPermission = boolean | "allow" | "deny" | "ask"
|
||||
|
||||
export function normalizePromptTools(
|
||||
tools: Record<string, PromptToolPermission> | undefined
|
||||
): Record<string, boolean> | undefined {
|
||||
if (!tools) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const normalized: Record<string, boolean> = {}
|
||||
for (const [toolName, permission] of Object.entries(tools)) {
|
||||
if (permission === false || permission === "deny") {
|
||||
normalized[toolName] = false
|
||||
continue
|
||||
}
|
||||
if (permission === true || permission === "allow" || permission === "ask") {
|
||||
normalized[toolName] = true
|
||||
}
|
||||
}
|
||||
|
||||
return Object.keys(normalized).length > 0 ? normalized : undefined
|
||||
}
|
||||
|
||||
export function resolveInheritedPromptTools(
|
||||
sessionID: string,
|
||||
fallbackTools?: Record<string, PromptToolPermission>
|
||||
): Record<string, boolean> | undefined {
|
||||
const sessionTools = getSessionTools(sessionID)
|
||||
if (sessionTools && Object.keys(sessionTools).length > 0) {
|
||||
return { ...sessionTools }
|
||||
}
|
||||
return normalizePromptTools(fallbackTools)
|
||||
}
|
||||
Reference in New Issue
Block a user