Unify dynamic fallback chains for background subagents
This commit is contained in:
committed by
YeonGyu-Kim
parent
38c925697b
commit
86c6bc7716
@@ -21,7 +21,7 @@ import { resetMessageCursor } from "../shared";
|
||||
import { getAgentConfigKey } from "../shared/agent-display-names";
|
||||
import { log } from "../shared/logger";
|
||||
import { shouldRetryError } from "../shared/model-error-classifier";
|
||||
import type { FallbackEntry } from "../shared/model-requirements";
|
||||
import { buildFallbackChainFromModels } from "../shared/fallback-chain-from-models";
|
||||
import { clearSessionModel, setSessionModel } from "../shared/session-model-state";
|
||||
import { deleteSessionTools } from "../shared/session-tools-store";
|
||||
import { lspManager } from "../tools";
|
||||
@@ -123,30 +123,6 @@ function extractProviderModelFromErrorMessage(message: string): { providerID?: s
|
||||
|
||||
return {};
|
||||
}
|
||||
function parseFallbackModelEntry(
|
||||
model: string,
|
||||
defaultProviderID: string,
|
||||
): FallbackEntry | undefined {
|
||||
const trimmed = model.trim();
|
||||
if (!trimmed) return undefined;
|
||||
|
||||
const parts = trimmed.split("/");
|
||||
const providerID = parts.length >= 2 ? parts[0].trim() : defaultProviderID;
|
||||
const rawModelID = parts.length >= 2 ? parts.slice(1).join("/").trim() : trimmed;
|
||||
if (!providerID || !rawModelID) return undefined;
|
||||
|
||||
const variantMatch = rawModelID.match(/^(.*)\(([^()]+)\)\s*$/);
|
||||
if (variantMatch) {
|
||||
const parsedModelID = variantMatch[1]?.trim();
|
||||
const parsedVariant = variantMatch[2]?.trim();
|
||||
if (parsedModelID && parsedVariant) {
|
||||
return { providers: [providerID], model: parsedModelID, variant: parsedVariant };
|
||||
}
|
||||
}
|
||||
|
||||
return { providers: [providerID], model: rawModelID };
|
||||
}
|
||||
|
||||
function applyUserConfiguredFallbackChain(
|
||||
sessionID: string,
|
||||
agentName: string,
|
||||
@@ -157,11 +133,9 @@ function applyUserConfiguredFallbackChain(
|
||||
const configuredFallbackModels = getFallbackModelsForSession(sessionID, agentKey, pluginConfig);
|
||||
if (configuredFallbackModels.length === 0) return;
|
||||
|
||||
const fallbackChain = configuredFallbackModels
|
||||
.map((model) => parseFallbackModelEntry(model, currentProviderID))
|
||||
.filter((entry): entry is FallbackEntry => entry !== undefined);
|
||||
const fallbackChain = buildFallbackChainFromModels(configuredFallbackModels, currentProviderID);
|
||||
|
||||
if (fallbackChain.length > 0) {
|
||||
if (fallbackChain && fallbackChain.length > 0) {
|
||||
setSessionFallbackChain(sessionID, fallbackChain);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,13 @@ export function createToolRegistry(args: {
|
||||
const { ctx, pluginConfig, managers, skillContext, availableCategories } = args
|
||||
|
||||
const backgroundTools = createBackgroundTools(managers.backgroundManager, ctx.client)
|
||||
const callOmoAgent = createCallOmoAgent(ctx, managers.backgroundManager, pluginConfig.disabled_agents ?? [])
|
||||
const callOmoAgent = createCallOmoAgent(
|
||||
ctx,
|
||||
managers.backgroundManager,
|
||||
pluginConfig.disabled_agents ?? [],
|
||||
pluginConfig.agents,
|
||||
pluginConfig.categories,
|
||||
)
|
||||
|
||||
const isMultimodalLookerEnabled = !(pluginConfig.disabled_agents ?? []).some(
|
||||
(agent) => agent.toLowerCase() === "multimodal-looker",
|
||||
|
||||
48
src/shared/fallback-chain-from-models.test.ts
Normal file
48
src/shared/fallback-chain-from-models.test.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
import { describe, test, expect } from "bun:test"
|
||||
import { buildFallbackChainFromModels, parseFallbackModelEntry } from "./fallback-chain-from-models"
|
||||
|
||||
describe("fallback-chain-from-models", () => {
|
||||
test("parses provider/model entry with parenthesized variant", () => {
|
||||
//#given
|
||||
const fallbackModel = "openai/gpt-5.2(high)"
|
||||
|
||||
//#when
|
||||
const parsed = parseFallbackModelEntry(fallbackModel, "quotio")
|
||||
|
||||
//#then
|
||||
expect(parsed).toEqual({
|
||||
providers: ["openai"],
|
||||
model: "gpt-5.2",
|
||||
variant: "high",
|
||||
})
|
||||
})
|
||||
|
||||
test("uses default provider when fallback model omits provider prefix", () => {
|
||||
//#given
|
||||
const fallbackModel = "glm-5"
|
||||
|
||||
//#when
|
||||
const parsed = parseFallbackModelEntry(fallbackModel, "quotio")
|
||||
|
||||
//#then
|
||||
expect(parsed).toEqual({
|
||||
providers: ["quotio"],
|
||||
model: "glm-5",
|
||||
variant: undefined,
|
||||
})
|
||||
})
|
||||
|
||||
test("builds fallback chain from normalized fallback_models input", () => {
|
||||
//#given
|
||||
const fallbackModels = ["quotio/kimi-k2.5", "gpt-5.2 medium"]
|
||||
|
||||
//#when
|
||||
const chain = buildFallbackChainFromModels(fallbackModels, "quotio")
|
||||
|
||||
//#then
|
||||
expect(chain).toEqual([
|
||||
{ providers: ["quotio"], model: "kimi-k2.5", variant: undefined },
|
||||
{ providers: ["quotio"], model: "gpt-5.2", variant: "medium" },
|
||||
])
|
||||
})
|
||||
})
|
||||
75
src/shared/fallback-chain-from-models.ts
Normal file
75
src/shared/fallback-chain-from-models.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import type { FallbackEntry } from "./model-requirements"
|
||||
import { normalizeFallbackModels } from "./model-resolver"
|
||||
|
||||
const KNOWN_VARIANTS = new Set([
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"xhigh",
|
||||
"max",
|
||||
"none",
|
||||
"auto",
|
||||
"thinking",
|
||||
])
|
||||
|
||||
function parseVariantFromModel(rawModel: string): { modelID: string; variant?: string } {
|
||||
const trimmedModel = rawModel.trim()
|
||||
if (!trimmedModel) {
|
||||
return { modelID: "" }
|
||||
}
|
||||
|
||||
const parenthesizedVariant = trimmedModel.match(/^(.*)\(([^()]+)\)\s*$/)
|
||||
if (parenthesizedVariant) {
|
||||
const modelID = parenthesizedVariant[1]?.trim() ?? ""
|
||||
const variant = parenthesizedVariant[2]?.trim()
|
||||
return variant ? { modelID, variant } : { modelID }
|
||||
}
|
||||
|
||||
const spaceVariant = trimmedModel.match(/^(.*\S)\s+([a-z][a-z0-9_-]*)$/i)
|
||||
if (spaceVariant) {
|
||||
const modelID = spaceVariant[1]?.trim() ?? ""
|
||||
const variant = spaceVariant[2]?.trim().toLowerCase()
|
||||
if (variant && KNOWN_VARIANTS.has(variant)) {
|
||||
return { modelID, variant }
|
||||
}
|
||||
}
|
||||
|
||||
return { modelID: trimmedModel }
|
||||
}
|
||||
|
||||
export function parseFallbackModelEntry(
|
||||
model: string,
|
||||
defaultProviderID: string,
|
||||
): FallbackEntry | undefined {
|
||||
const trimmed = model.trim()
|
||||
if (!trimmed) return undefined
|
||||
|
||||
const parts = trimmed.split("/")
|
||||
const providerID = parts.length >= 2 ? parts[0].trim() : defaultProviderID
|
||||
const rawModelID = parts.length >= 2 ? parts.slice(1).join("/").trim() : trimmed
|
||||
if (!providerID || !rawModelID) return undefined
|
||||
|
||||
const parsed = parseVariantFromModel(rawModelID)
|
||||
if (!parsed.modelID) return undefined
|
||||
|
||||
return {
|
||||
providers: [providerID],
|
||||
model: parsed.modelID,
|
||||
variant: parsed.variant,
|
||||
}
|
||||
}
|
||||
|
||||
export function buildFallbackChainFromModels(
|
||||
fallbackModels: string | string[] | undefined,
|
||||
defaultProviderID: string,
|
||||
): FallbackEntry[] | undefined {
|
||||
const normalized = normalizeFallbackModels(fallbackModels)
|
||||
if (!normalized || normalized.length === 0) return undefined
|
||||
|
||||
const parsed = normalized
|
||||
.map((model) => parseFallbackModelEntry(model, defaultProviderID))
|
||||
.filter((entry): entry is FallbackEntry => entry !== undefined)
|
||||
|
||||
if (parsed.length === 0) return undefined
|
||||
return parsed
|
||||
}
|
||||
@@ -64,4 +64,26 @@ describe("executeBackground", () => {
|
||||
expect(result).toContain("interrupt")
|
||||
expect(result).toContain("test-task-id")
|
||||
})
|
||||
|
||||
test("passes fallbackChain to background manager launch", async () => {
|
||||
//#given
|
||||
const fallbackChain = [
|
||||
{ providers: ["quotio"], model: "kimi-k2.5", variant: undefined },
|
||||
{ providers: ["openai"], model: "gpt-5.2", variant: "high" },
|
||||
]
|
||||
launchMock.mockResolvedValueOnce({
|
||||
id: "test-task-id",
|
||||
sessionID: "sub-session",
|
||||
description: "Test task",
|
||||
agent: "test-agent",
|
||||
status: "pending",
|
||||
})
|
||||
|
||||
//#when
|
||||
await executeBackground(testArgs, testContext, mockManager, mockClient, fallbackChain)
|
||||
|
||||
//#then
|
||||
const launchArgs = launchMock.mock.calls.at(-1)?.[0]
|
||||
expect(launchArgs.fallbackChain).toEqual(fallbackChain)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { CallOmoAgentArgs } from "./types"
|
||||
import type { BackgroundManager } from "../../features/background-agent"
|
||||
import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import { log } from "../../shared"
|
||||
import type { FallbackEntry } from "../../shared/model-requirements"
|
||||
import { resolveMessageContext } from "../../features/hook-message-injector"
|
||||
import { getSessionAgent } from "../../features/claude-code-session-state"
|
||||
import { getMessageDir } from "./message-dir"
|
||||
@@ -17,7 +18,8 @@ export async function executeBackground(
|
||||
metadata?: (input: { title?: string; metadata?: Record<string, unknown> }) => void
|
||||
},
|
||||
manager: BackgroundManager,
|
||||
client: PluginInput["client"]
|
||||
client: PluginInput["client"],
|
||||
fallbackChain?: FallbackEntry[],
|
||||
): Promise<string> {
|
||||
try {
|
||||
const messageDir = getMessageDir(toolContext.sessionID)
|
||||
@@ -48,6 +50,7 @@ export async function executeBackground(
|
||||
parentMessageID: toolContext.messageID,
|
||||
parentAgent,
|
||||
parentTools: getSessionTools(toolContext.sessionID),
|
||||
fallbackChain,
|
||||
})
|
||||
|
||||
const WAIT_FOR_SESSION_INTERVAL_MS = 50
|
||||
|
||||
@@ -99,4 +99,48 @@ describe("createCallOmoAgent", () => {
|
||||
//#then
|
||||
expect(result).not.toContain("disabled via disabled_agents")
|
||||
})
|
||||
|
||||
test("uses agent override fallback_models when launching background subagent", async () => {
|
||||
//#given
|
||||
const launch = mock(() => Promise.resolve({
|
||||
id: "task-fallback",
|
||||
sessionID: "sub-session",
|
||||
description: "Test task",
|
||||
agent: "explore",
|
||||
status: "pending",
|
||||
}))
|
||||
const managerWithLaunch = {
|
||||
launch,
|
||||
getTask: mock(() => undefined),
|
||||
} as unknown as BackgroundManager
|
||||
const toolDef = createCallOmoAgent(
|
||||
mockCtx,
|
||||
managerWithLaunch,
|
||||
[],
|
||||
{
|
||||
explore: {
|
||||
fallback_models: ["quotio/kimi-k2.5", "openai/gpt-5.2(high)"],
|
||||
},
|
||||
},
|
||||
)
|
||||
const executeFunc = toolDef.execute as Function
|
||||
|
||||
//#when
|
||||
await executeFunc(
|
||||
{
|
||||
description: "Test fallback",
|
||||
prompt: "Test prompt",
|
||||
subagent_type: "explore",
|
||||
run_in_background: true,
|
||||
},
|
||||
{ sessionID: "test", messageID: "msg", agent: "test", abort: new AbortController().signal }
|
||||
)
|
||||
|
||||
//#then
|
||||
const launchArgs = launch.mock.calls[0]?.[0]
|
||||
expect(launchArgs.fallbackChain).toEqual([
|
||||
{ providers: ["quotio"], model: "kimi-k2.5", variant: undefined },
|
||||
{ providers: ["openai"], model: "gpt-5.2", variant: "high" },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,14 +2,46 @@ import { tool, type PluginInput, type ToolDefinition } from "@opencode-ai/plugin
|
||||
import { ALLOWED_AGENTS, CALL_OMO_AGENT_DESCRIPTION } from "./constants"
|
||||
import type { AllowedAgentType, CallOmoAgentArgs, ToolContextWithMetadata } from "./types"
|
||||
import type { BackgroundManager } from "../../features/background-agent"
|
||||
import type { CategoriesConfig, AgentOverrides } from "../../config/schema"
|
||||
import type { FallbackEntry } from "../../shared/model-requirements"
|
||||
import { AGENT_MODEL_REQUIREMENTS } from "../../shared/model-requirements"
|
||||
import { getAgentConfigKey } from "../../shared/agent-display-names"
|
||||
import { normalizeFallbackModels } from "../../shared/model-resolver"
|
||||
import { buildFallbackChainFromModels } from "../../shared/fallback-chain-from-models"
|
||||
import { log } from "../../shared"
|
||||
import { executeBackground } from "./background-executor"
|
||||
import { executeSync } from "./sync-executor"
|
||||
|
||||
function resolveFallbackChainForCallOmoAgent(args: {
|
||||
subagentType: string
|
||||
agentOverrides?: AgentOverrides
|
||||
userCategories?: CategoriesConfig
|
||||
}): FallbackEntry[] | undefined {
|
||||
const { subagentType, agentOverrides, userCategories } = args
|
||||
const agentConfigKey = getAgentConfigKey(subagentType)
|
||||
const agentRequirement = AGENT_MODEL_REQUIREMENTS[agentConfigKey]
|
||||
|
||||
const agentOverride = agentOverrides?.[agentConfigKey as keyof AgentOverrides]
|
||||
?? (agentOverrides
|
||||
? Object.entries(agentOverrides).find(([key]) => key.toLowerCase() === agentConfigKey)?.[1]
|
||||
: undefined)
|
||||
|
||||
const normalizedFallbackModels = normalizeFallbackModels(
|
||||
agentOverride?.fallback_models
|
||||
?? (agentOverride?.category ? userCategories?.[agentOverride.category]?.fallback_models : undefined)
|
||||
)
|
||||
const defaultProviderID = agentRequirement?.fallbackChain?.[0]?.providers?.[0] ?? "opencode"
|
||||
const configuredFallbackChain = buildFallbackChainFromModels(normalizedFallbackModels, defaultProviderID)
|
||||
|
||||
return configuredFallbackChain ?? agentRequirement?.fallbackChain
|
||||
}
|
||||
|
||||
export function createCallOmoAgent(
|
||||
ctx: PluginInput,
|
||||
backgroundManager: BackgroundManager,
|
||||
disabledAgents: string[] = []
|
||||
disabledAgents: string[] = [],
|
||||
agentOverrides?: AgentOverrides,
|
||||
userCategories?: CategoriesConfig,
|
||||
): ToolDefinition {
|
||||
const agentDescriptions = ALLOWED_AGENTS.map(
|
||||
(name) => `- ${name}: Specialized agent for ${name} tasks`
|
||||
@@ -54,7 +86,12 @@ export function createCallOmoAgent(
|
||||
if (args.session_id) {
|
||||
return `Error: session_id is not supported in background mode. Use run_in_background=false to continue an existing session.`
|
||||
}
|
||||
return await executeBackground(args, toolCtx, backgroundManager, ctx.client)
|
||||
const fallbackChain = resolveFallbackChainForCallOmoAgent({
|
||||
subagentType: args.subagent_type,
|
||||
agentOverrides,
|
||||
userCategories,
|
||||
})
|
||||
return await executeBackground(args, toolCtx, backgroundManager, ctx.client, fallbackChain)
|
||||
}
|
||||
|
||||
return await executeSync(args, toolCtx, ctx)
|
||||
|
||||
@@ -75,4 +75,34 @@ describe("resolveCategoryExecution", () => {
|
||||
expect(result.error).toContain("Unknown category")
|
||||
expect(result.error).toContain("definitely-not-a-real-category-xyz123")
|
||||
})
|
||||
|
||||
test("uses category fallback_models for background/runtime fallback chain", async () => {
|
||||
//#given
|
||||
const args = {
|
||||
category: "deep",
|
||||
prompt: "test prompt",
|
||||
description: "Test task",
|
||||
run_in_background: false,
|
||||
load_skills: [],
|
||||
blockedBy: undefined,
|
||||
enableSkillTools: false,
|
||||
}
|
||||
const executorCtx = createMockExecutorContext()
|
||||
executorCtx.userCategories = {
|
||||
deep: {
|
||||
model: "quotio/claude-opus-4-6",
|
||||
fallback_models: ["quotio/kimi-k2.5", "openai/gpt-5.2(high)"],
|
||||
},
|
||||
}
|
||||
|
||||
//#when
|
||||
const result = await resolveCategoryExecution(args, executorCtx, undefined, "anthropic/claude-sonnet-4-6")
|
||||
|
||||
//#then
|
||||
expect(result.error).toBeUndefined()
|
||||
expect(result.fallbackChain).toEqual([
|
||||
{ providers: ["quotio"], model: "kimi-k2.5", variant: undefined },
|
||||
{ providers: ["openai"], model: "gpt-5.2", variant: "high" },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -7,6 +7,8 @@ import { SISYPHUS_JUNIOR_AGENT } from "./sisyphus-junior-agent"
|
||||
import { resolveCategoryConfig } from "./categories"
|
||||
import { parseModelString } from "./model-string-parser"
|
||||
import { CATEGORY_MODEL_REQUIREMENTS } from "../../shared/model-requirements"
|
||||
import { normalizeFallbackModels } from "../../shared/model-resolver"
|
||||
import { buildFallbackChainFromModels } from "../../shared/fallback-chain-from-models"
|
||||
import { getAvailableModelsForDelegateTask } from "./available-models"
|
||||
import { resolveModelForDelegateTask } from "./model-selection"
|
||||
|
||||
@@ -79,6 +81,7 @@ Available categories: ${allCategoryNames}`,
|
||||
}
|
||||
|
||||
const requirement = CATEGORY_MODEL_REQUIREMENTS[args.category!]
|
||||
const normalizedConfiguredFallbackModels = normalizeFallbackModels(resolved.config.fallback_models)
|
||||
let actualModel: string | undefined
|
||||
let modelInfo: ModelFallbackInfo | undefined
|
||||
let categoryModel: { providerID: string; modelID: string; variant?: string } | undefined
|
||||
@@ -99,6 +102,7 @@ Available categories: ${allCategoryNames}`,
|
||||
} else {
|
||||
const resolution = resolveModelForDelegateTask({
|
||||
userModel: explicitCategoryModel ?? overrideModel,
|
||||
userFallbackModels: normalizedConfiguredFallbackModels,
|
||||
categoryDefaultModel: resolved.model,
|
||||
fallbackChain: requirement.fallbackChain,
|
||||
availableModels,
|
||||
@@ -178,6 +182,14 @@ Available categories: ${categoryNames.join(", ")}`,
|
||||
const categoryConfigModel = resolved.config.model?.toLowerCase()
|
||||
const isUnstableAgent = resolved.config.is_unstable_agent === true || [unstableModel, categoryConfigModel].some(m => m ? m.includes("gemini") || m.includes("minimax") || m.includes("kimi") : false)
|
||||
|
||||
const defaultProviderID = categoryModel?.providerID
|
||||
?? parseModelString(actualModel ?? "")?.providerID
|
||||
?? "opencode"
|
||||
const configuredFallbackChain = buildFallbackChainFromModels(
|
||||
normalizedConfiguredFallbackModels,
|
||||
defaultProviderID,
|
||||
)
|
||||
|
||||
return {
|
||||
agentToUse: SISYPHUS_JUNIOR_AGENT,
|
||||
categoryModel,
|
||||
@@ -186,6 +198,6 @@ Available categories: ${categoryNames.join(", ")}`,
|
||||
modelInfo,
|
||||
actualModel,
|
||||
isUnstableAgent,
|
||||
fallbackChain: requirement?.fallbackChain,
|
||||
fallbackChain: configuredFallbackChain ?? requirement?.fallbackChain,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ function getExplicitHighBaseModel(model: string): string | null {
|
||||
|
||||
export function resolveModelForDelegateTask(input: {
|
||||
userModel?: string
|
||||
userFallbackModels?: string[]
|
||||
categoryDefaultModel?: string
|
||||
fallbackChain?: FallbackEntry[]
|
||||
availableModels: Set<string>
|
||||
@@ -44,6 +45,28 @@ export function resolveModelForDelegateTask(input: {
|
||||
}
|
||||
}
|
||||
|
||||
const userFallbackModels = input.userFallbackModels
|
||||
if (userFallbackModels && userFallbackModels.length > 0) {
|
||||
if (input.availableModels.size === 0) {
|
||||
const first = normalizeModel(userFallbackModels[0])
|
||||
if (first) {
|
||||
return { model: first }
|
||||
}
|
||||
} else {
|
||||
for (const fallbackModel of userFallbackModels) {
|
||||
const normalizedFallback = normalizeModel(fallbackModel)
|
||||
if (!normalizedFallback) continue
|
||||
|
||||
const parts = normalizedFallback.split("/")
|
||||
const providerHint = parts.length >= 2 ? [parts[0]] : undefined
|
||||
const match = fuzzyMatchModel(normalizedFallback, input.availableModels, providerHint)
|
||||
if (match) {
|
||||
return { model: match }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fallbackChain = input.fallbackChain
|
||||
if (fallbackChain && fallbackChain.length > 0) {
|
||||
if (input.availableModels.size === 0) {
|
||||
|
||||
@@ -17,7 +17,10 @@ function createBaseArgs(overrides?: Partial<DelegateTaskArgs>): DelegateTaskArgs
|
||||
}
|
||||
}
|
||||
|
||||
function createExecutorContext(agentsFn: () => Promise<unknown>): ExecutorContext {
|
||||
function createExecutorContext(
|
||||
agentsFn: () => Promise<unknown>,
|
||||
overrides?: Partial<ExecutorContext>,
|
||||
): ExecutorContext {
|
||||
const client = {
|
||||
app: {
|
||||
agents: agentsFn,
|
||||
@@ -28,6 +31,7 @@ function createExecutorContext(agentsFn: () => Promise<unknown>): ExecutorContex
|
||||
client,
|
||||
manager: {} as ExecutorContext["manager"],
|
||||
directory: "/tmp/test",
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,4 +105,62 @@ describe("resolveSubagentExecution", () => {
|
||||
expect(result.categoryModel).toEqual({ providerID: "openai", modelID: "gpt-5.3-codex" })
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
|
||||
test("uses agent override fallback_models for subagent runtime fallback chain", async () => {
|
||||
//#given
|
||||
const args = createBaseArgs({ subagent_type: "explore" })
|
||||
const executorCtx = createExecutorContext(
|
||||
async () => ([
|
||||
{ name: "explore", mode: "subagent", model: "quotio/claude-haiku-4-5" },
|
||||
]),
|
||||
{
|
||||
agentOverrides: {
|
||||
explore: {
|
||||
fallback_models: ["quotio/gpt-5.2", "glm-5(max)"],
|
||||
},
|
||||
} as ExecutorContext["agentOverrides"],
|
||||
}
|
||||
)
|
||||
|
||||
//#when
|
||||
const result = await resolveSubagentExecution(args, executorCtx, "sisyphus", "deep")
|
||||
|
||||
//#then
|
||||
expect(result.error).toBeUndefined()
|
||||
expect(result.fallbackChain).toEqual([
|
||||
{ providers: ["quotio"], model: "gpt-5.2", variant: undefined },
|
||||
{ providers: ["quotio"], model: "glm-5", variant: "max" },
|
||||
])
|
||||
})
|
||||
|
||||
test("uses category fallback_models when agent override points at category", async () => {
|
||||
//#given
|
||||
const args = createBaseArgs({ subagent_type: "explore" })
|
||||
const executorCtx = createExecutorContext(
|
||||
async () => ([
|
||||
{ name: "explore", mode: "subagent", model: "quotio/claude-haiku-4-5" },
|
||||
]),
|
||||
{
|
||||
agentOverrides: {
|
||||
explore: {
|
||||
category: "research",
|
||||
},
|
||||
} as ExecutorContext["agentOverrides"],
|
||||
userCategories: {
|
||||
research: {
|
||||
fallback_models: ["anthropic/claude-haiku-4-5"],
|
||||
},
|
||||
} as ExecutorContext["userCategories"],
|
||||
}
|
||||
)
|
||||
|
||||
//#when
|
||||
const result = await resolveSubagentExecution(args, executorCtx, "sisyphus", "deep")
|
||||
|
||||
//#then
|
||||
expect(result.error).toBeUndefined()
|
||||
expect(result.fallbackChain).toEqual([
|
||||
{ providers: ["anthropic"], model: "claude-haiku-4-5", variant: undefined },
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,6 +4,8 @@ import { isPlanFamily } from "./constants"
|
||||
import { SISYPHUS_JUNIOR_AGENT } from "./sisyphus-junior-agent"
|
||||
import { normalizeModelFormat } from "../../shared/model-format-normalizer"
|
||||
import { AGENT_MODEL_REQUIREMENTS } from "../../shared/model-requirements"
|
||||
import { normalizeFallbackModels } from "../../shared/model-resolver"
|
||||
import { buildFallbackChainFromModels } from "../../shared/fallback-chain-from-models"
|
||||
import { getAgentDisplayName, getAgentConfigKey } from "../../shared/agent-display-names"
|
||||
import { normalizeSDKResponse } from "../../shared"
|
||||
import { log } from "../../shared/logger"
|
||||
@@ -17,7 +19,7 @@ export async function resolveSubagentExecution(
|
||||
parentAgent: string | undefined,
|
||||
categoryExamples: string
|
||||
): Promise<{ agentToUse: string; categoryModel: { providerID: string; modelID: string; variant?: string } | undefined; fallbackChain?: FallbackEntry[]; error?: string }> {
|
||||
const { client, agentOverrides } = executorCtx
|
||||
const { client, agentOverrides, userCategories } = executorCtx
|
||||
|
||||
if (!args.subagent_type?.trim()) {
|
||||
return { agentToUse: "", categoryModel: undefined, error: `Agent name cannot be empty.` }
|
||||
@@ -98,7 +100,10 @@ Create the work plan directly - that's your job as the planning agent.`,
|
||||
const agentOverride = agentOverrides?.[agentConfigKey as keyof typeof agentOverrides]
|
||||
?? (agentOverrides ? Object.entries(agentOverrides).find(([key]) => key.toLowerCase() === agentConfigKey)?.[1] : undefined)
|
||||
const agentRequirement = AGENT_MODEL_REQUIREMENTS[agentConfigKey]
|
||||
fallbackChain = agentRequirement?.fallbackChain
|
||||
const normalizedAgentFallbackModels = normalizeFallbackModels(
|
||||
agentOverride?.fallback_models
|
||||
?? (agentOverride?.category ? userCategories?.[agentOverride.category]?.fallback_models : undefined)
|
||||
)
|
||||
|
||||
if (agentOverride?.model || agentRequirement || matchedAgent.model) {
|
||||
const availableModels = await getAvailableModelsForDelegateTask(client)
|
||||
@@ -112,6 +117,7 @@ Create the work plan directly - that's your job as the planning agent.`,
|
||||
|
||||
const resolution = resolveModelForDelegateTask({
|
||||
userModel: agentOverride?.model,
|
||||
userFallbackModels: normalizedAgentFallbackModels,
|
||||
categoryDefaultModel: matchedAgentModelStr,
|
||||
fallbackChain: agentRequirement?.fallbackChain,
|
||||
availableModels,
|
||||
@@ -125,6 +131,15 @@ Create the work plan directly - that's your job as the planning agent.`,
|
||||
categoryModel = variantToUse ? { ...normalized, variant: variantToUse } : normalized
|
||||
}
|
||||
}
|
||||
|
||||
const defaultProviderID = categoryModel?.providerID
|
||||
?? normalizedMatchedModel?.providerID
|
||||
?? "opencode"
|
||||
const configuredFallbackChain = buildFallbackChainFromModels(
|
||||
normalizedAgentFallbackModels,
|
||||
defaultProviderID,
|
||||
)
|
||||
fallbackChain = configuredFallbackChain ?? agentRequirement?.fallbackChain
|
||||
}
|
||||
|
||||
if (!categoryModel && matchedAgent.model) {
|
||||
|
||||
Reference in New Issue
Block a user