diff --git a/src/hooks/start-work/index.ts b/src/hooks/start-work/index.ts index ee270861a..56470022e 100644 --- a/src/hooks/start-work/index.ts +++ b/src/hooks/start-work/index.ts @@ -1,4 +1,4 @@ export { HOOK_NAME, createStartWorkHook } from "./start-work-hook" -export { detectWorktreePath } from "./worktree-detector" +export { detectWorktreePath, listWorktrees, parseWorktreeListPorcelain } from "./worktree-detector" export type { ParsedUserRequest } from "./parse-user-request" export { parseUserRequest } from "./parse-user-request" diff --git a/src/hooks/start-work/worktree-detector.test.ts b/src/hooks/start-work/worktree-detector.test.ts index b02d5af1b..d68e99682 100644 --- a/src/hooks/start-work/worktree-detector.test.ts +++ b/src/hooks/start-work/worktree-detector.test.ts @@ -2,7 +2,7 @@ import { describe, expect, test, spyOn, beforeEach, afterEach } from "bun:test" import * as childProcess from "node:child_process" -import { detectWorktreePath } from "./worktree-detector" +import { detectWorktreePath, parseWorktreeListPorcelain, listWorktrees } from "./worktree-detector" describe("detectWorktreePath", () => { let execFileSyncSpy: ReturnType @@ -77,3 +77,113 @@ describe("detectWorktreePath", () => { }) }) }) + +describe("parseWorktreeListPorcelain", () => { + test("#given porcelain output with multiple worktrees #when parsing #then returns all entries", () => { + // given + const output = [ + "worktree /home/user/main-repo", + "HEAD abc1234", + "branch refs/heads/main", + "", + "worktree /home/user/worktrees/feature-a", + "HEAD def5678", + "branch refs/heads/feature-a", + "", + ].join("\n") + + // when + const result = parseWorktreeListPorcelain(output) + + // then + expect(result).toEqual([ + { path: "/home/user/main-repo", branch: "main", bare: false }, + { path: "/home/user/worktrees/feature-a", branch: "feature-a", bare: false }, + ]) + }) + + test("#given bare worktree #when parsing #then marks bare flag", () => { + // given + const output = [ + "worktree /home/user/bare-repo", + "HEAD abc1234", + "bare", + "", + ].join("\n") + + // when + const result = parseWorktreeListPorcelain(output) + + // then + expect(result).toEqual([ + { path: "/home/user/bare-repo", branch: undefined, bare: true }, + ]) + }) + + test("#given empty output #when parsing #then returns empty array", () => { + expect(parseWorktreeListPorcelain("")).toEqual([]) + }) + + test("#given output without trailing newline #when parsing #then still captures last entry", () => { + // given + const output = [ + "worktree /repo", + "HEAD abc1234", + "branch refs/heads/dev", + ].join("\n") + + // when + const result = parseWorktreeListPorcelain(output) + + // then + expect(result).toEqual([ + { path: "/repo", branch: "dev", bare: false }, + ]) + }) +}) + +describe("listWorktrees", () => { + let execFileSyncSpy: ReturnType + + beforeEach(() => { + execFileSyncSpy = spyOn(childProcess, "execFileSync").mockImplementation( + ((_file: string, _args: string[]) => "") as typeof childProcess.execFileSync, + ) + }) + + afterEach(() => { + execFileSyncSpy.mockRestore() + }) + + test("#given valid git repo #when listing #then returns parsed worktree entries", () => { + // given + execFileSyncSpy.mockImplementation( + ((_file: string, _args: string[]) => + "worktree /repo\nHEAD abc\nbranch refs/heads/main\n\n") as typeof childProcess.execFileSync, + ) + + // when + const result = listWorktrees("/repo") + + // then + expect(result).toEqual([{ path: "/repo", branch: "main", bare: false }]) + expect(execFileSyncSpy).toHaveBeenCalledWith( + "git", + ["worktree", "list", "--porcelain"], + expect.objectContaining({ cwd: "/repo" }), + ) + }) + + test("#given non-git directory #when listing #then returns empty array", () => { + // given + execFileSyncSpy.mockImplementation((_file: string, _args: string[]) => { + throw new Error("not a git repository") + }) + + // when + const result = listWorktrees("/tmp/not-a-repo") + + // then + expect(result).toEqual([]) + }) +}) diff --git a/src/hooks/start-work/worktree-detector.ts b/src/hooks/start-work/worktree-detector.ts index 74c919593..fe9567b7a 100644 --- a/src/hooks/start-work/worktree-detector.ts +++ b/src/hooks/start-work/worktree-detector.ts @@ -1,5 +1,68 @@ import { execFileSync } from "node:child_process" +export type WorktreeEntry = { + path: string + branch: string | undefined + bare: boolean +} + +export function parseWorktreeListPorcelain(output: string): WorktreeEntry[] { + const lines = output.split("\n").map((line) => line.trim()) + const entries: WorktreeEntry[] = [] + let current: Partial | undefined + + for (const line of lines) { + if (!line) { + if (current?.path) { + entries.push({ + path: current.path, + branch: current.branch, + bare: current.bare ?? false, + }) + } + current = undefined + continue + } + + if (line.startsWith("worktree ")) { + current = { path: line.slice("worktree ".length).trim() } + continue + } + + if (!current) continue + + if (line.startsWith("branch ")) { + current.branch = line.slice("branch ".length).trim().replace(/^refs\/heads\//, "") + } else if (line === "bare") { + current.bare = true + } + } + + if (current?.path) { + entries.push({ + path: current.path, + branch: current.branch, + bare: current.bare ?? false, + }) + } + + return entries +} + +export function listWorktrees(directory: string): WorktreeEntry[] { + try { + const output = execFileSync("git", ["worktree", "list", "--porcelain"], { + cwd: directory, + encoding: "utf-8", + timeout: 5000, + stdio: ["pipe", "pipe", "pipe"], + }) + return parseWorktreeListPorcelain(output) + } catch { + return [] + } +} + export function detectWorktreePath(directory: string): string | null { try { return execFileSync("git", ["rev-parse", "--show-toplevel"], { diff --git a/src/plugin/ultrawork-model-override.test.ts b/src/plugin/ultrawork-model-override.test.ts index 26dae2415..9ceb42aae 100644 --- a/src/plugin/ultrawork-model-override.test.ts +++ b/src/plugin/ultrawork-model-override.test.ts @@ -262,24 +262,24 @@ describe("applyUltraworkModelOverrideOnMessage", () => { } as unknown as Parameters[0] } - test("should schedule deferred DB override when message ID present", () => { + test("should schedule deferred DB override without variant when SDK unavailable", () => { //#given const config = createConfig("sisyphus", { model: "anthropic/claude-opus-4-6", variant: "max" }) const output = createOutput("ultrawork do something", { messageId: "msg_123" }) const tui = createMockTui() - //#when + //#when - no client passed, SDK validation unavailable applyUltraworkModelOverrideOnMessage(config, "sisyphus", output, tui) - //#then + //#then - variant should NOT be applied without SDK validation expect(dbOverrideSpy).toHaveBeenCalledWith( "msg_123", { providerID: "anthropic", modelID: "claude-opus-4-6" }, - "max", + undefined, ) }) - test("should override keyword-detector variant with configured ultrawork variant on deferred path", () => { + test("should NOT override variant when SDK unavailable even if config specifies variant", () => { //#given const config = createConfig("sisyphus", { model: "anthropic/claude-opus-4-6", @@ -290,17 +290,17 @@ describe("applyUltraworkModelOverrideOnMessage", () => { output.message["thinking"] = "max" const tui = createMockTui() - //#when + //#when - no client, SDK unavailable applyUltraworkModelOverrideOnMessage(config, "sisyphus", output, tui) - //#then + //#then - existing variant preserved, not overridden to "extended" expect(dbOverrideSpy).toHaveBeenCalledWith( "msg_123", { providerID: "anthropic", modelID: "claude-opus-4-6" }, - "extended", + undefined, ) - expect(output.message["variant"]).toBe("extended") - expect(output.message["thinking"]).toBe("extended") + expect(output.message["variant"]).toBe("max") + expect(output.message["thinking"]).toBe("max") }) test("should NOT mutate output.message.model when message ID present", () => { @@ -320,7 +320,7 @@ describe("applyUltraworkModelOverrideOnMessage", () => { expect(output.message.model).toEqual(sonnetModel) }) - test("should fall back to direct mutation when no message ID", () => { + test("should fall back to direct model mutation without variant when no message ID and no SDK", () => { //#given const config = createConfig("sisyphus", { model: "anthropic/claude-opus-4-6", variant: "max" }) const output = createOutput("ultrawork do something") @@ -329,24 +329,24 @@ describe("applyUltraworkModelOverrideOnMessage", () => { //#when applyUltraworkModelOverrideOnMessage(config, "sisyphus", output, tui) - //#then + //#then - model is set but variant is NOT applied without SDK validation expect(output.message.model).toEqual({ providerID: "anthropic", modelID: "claude-opus-4-6" }) - expect(output.message["variant"]).toBe("max") + expect(output.message["variant"]).toBeUndefined() expect(dbOverrideSpy).not.toHaveBeenCalled() }) - test("should apply variant-only override when no message ID", () => { + test("should not apply variant-only override when no SDK available", () => { //#given const config = createConfig("sisyphus", { variant: "high" }) const output = createOutput("ultrawork do something") const tui = createMockTui() - //#when + //#when - variant-only override, no SDK = no-op applyUltraworkModelOverrideOnMessage(config, "sisyphus", output, tui) - //#then + //#then - nothing applied since no model and variant requires SDK expect(output.message.model).toBeUndefined() - expect(output.message["variant"]).toBe("high") + expect(output.message["variant"]).toBeUndefined() expect(dbOverrideSpy).not.toHaveBeenCalled() }) @@ -414,7 +414,7 @@ describe("applyUltraworkModelOverrideOnMessage", () => { expect(dbOverrideSpy).toHaveBeenCalledWith( "msg_123", { providerID: "anthropic", modelID: "claude-opus-4-6" }, - "max", + undefined, ) }) @@ -439,4 +439,48 @@ describe("applyUltraworkModelOverrideOnMessage", () => { expect(dbOverrideSpy).not.toHaveBeenCalled() expect(toastCalled).toBe(false) }) + + test("should apply validated variant when SDK confirms model supports it", async () => { + //#given + const config = createConfig("sisyphus", { model: "anthropic/claude-opus-4-6", variant: "max" }) + const output = createOutput("ultrawork do something", { messageId: "msg_123" }) + const tui = createMockTui() + const mockClient = { + provider: { + list: async () => ({ + data: { all: [{ id: "anthropic", models: { "claude-opus-4-6": { variants: { max: {} } } } }] }, + }), + }, + } + + //#when + await applyUltraworkModelOverrideOnMessage(config, "sisyphus", output, tui, undefined, mockClient) + + //#then - SDK confirmed max exists, so variant is applied + expect(dbOverrideSpy).toHaveBeenCalledWith( + "msg_123", + { providerID: "anthropic", modelID: "claude-opus-4-6" }, + "max", + ) + }) + + test("should NOT apply variant when SDK confirms model does NOT have it", async () => { + //#given + const config = createConfig("sisyphus", { model: "anthropic/claude-haiku-4-5", variant: "max" }) + const output = createOutput("ultrawork do something", { messageId: "msg_123" }) + const tui = createMockTui() + const mockClient = { + provider: { + list: async () => ({ + data: { all: [{ id: "anthropic", models: { "claude-haiku-4-5": { variants: { high: {} } } } }] }, + }), + }, + } + + //#when + await applyUltraworkModelOverrideOnMessage(config, "sisyphus", output, tui, undefined, mockClient) + + //#then - SDK says haiku has no max variant, so variant is NOT applied + expect(output.message["variant"]).toBeUndefined() + }) }) diff --git a/src/plugin/ultrawork-model-override.ts b/src/plugin/ultrawork-model-override.ts index 980de1752..55d90c066 100644 --- a/src/plugin/ultrawork-model-override.ts +++ b/src/plugin/ultrawork-model-override.ts @@ -161,7 +161,10 @@ export function applyUltraworkModelOverrideOnMessage( : currentModel if (!client || typeof (client as { provider?: { list?: unknown } }).provider?.list !== "function") { - applyResolvedUltraworkOverride({ override, validatedVariant: override.variant, output, inputAgentName, tui }) + log("[ultrawork-model-override] SDK validation unavailable, skipping variant override", { + variant: override.variant, + }) + applyResolvedUltraworkOverride({ override, validatedVariant: undefined, output, inputAgentName, tui }) return } diff --git a/src/shared/connected-providers-cache.test.ts b/src/shared/connected-providers-cache.test.ts index a170c90b9..10774c9b3 100644 --- a/src/shared/connected-providers-cache.test.ts +++ b/src/shared/connected-providers-cache.test.ts @@ -2,7 +2,7 @@ import { beforeAll, beforeEach, afterEach, describe, expect, mock, test } from "bun:test" -import { existsSync, mkdtempSync, rmSync } from "node:fs" +import { existsSync, mkdirSync, mkdtempSync, readFileSync, rmSync, writeFileSync } from "node:fs" import { tmpdir } from "node:os" import { join } from "node:path" import * as dataPath from "./data-path" @@ -15,6 +15,16 @@ const getOmoOpenCodeCacheDirMock = mock(() => testCacheDir) let updateConnectedProvidersCache: typeof import("./connected-providers-cache").updateConnectedProvidersCache let readProviderModelsCache: typeof import("./connected-providers-cache").readProviderModelsCache +async function prepareConnectedProvidersCacheTestModule(): Promise { + testCacheDir = mkdtempSync(join(tmpdir(), "connected-providers-cache-test-")) + getOmoOpenCodeCacheDirMock.mockClear() + mock.module("./data-path", () => ({ + getOmoOpenCodeCacheDir: getOmoOpenCodeCacheDirMock, + })) + moduleImportCounter += 1 + ;({ updateConnectedProvidersCache, readProviderModelsCache } = await import(`./connected-providers-cache?test=${moduleImportCounter}`)) +} + describe("updateConnectedProvidersCache", () => { beforeAll(() => { mock.restore() @@ -22,18 +32,7 @@ describe("updateConnectedProvidersCache", () => { beforeEach(async () => { mock.restore() - const realCacheDir = join(dataPath.getCacheDir(), "oh-my-opencode") - if (existsSync(realCacheDir)) { - rmSync(realCacheDir, { recursive: true, force: true }) - } - - testCacheDir = mkdtempSync(join(tmpdir(), "connected-providers-cache-test-")) - getOmoOpenCodeCacheDirMock.mockClear() - mock.module("./data-path", () => ({ - getOmoOpenCodeCacheDir: getOmoOpenCodeCacheDirMock, - })) - moduleImportCounter += 1 - ;({ updateConnectedProvidersCache, readProviderModelsCache } = await import(`./connected-providers-cache?test=${moduleImportCounter}`)) + await prepareConnectedProvidersCacheTestModule() }) afterEach(() => { @@ -150,4 +149,25 @@ describe("updateConnectedProvidersCache", () => { const cache = readProviderModelsCache() expect(cache).toBeNull() }) + + test("does not remove the user's real cache directory during test setup", async () => { + //#given + const realCacheDir = join(dataPath.getCacheDir(), "oh-my-opencode") + const sentinelPath = join(realCacheDir, "connected-providers-cache.test-sentinel.json") + mkdirSync(realCacheDir, { recursive: true }) + writeFileSync(sentinelPath, JSON.stringify({ keep: true })) + + try { + //#when + await prepareConnectedProvidersCacheTestModule() + + //#then + expect(existsSync(sentinelPath)).toBe(true) + expect(readFileSync(sentinelPath, "utf-8")).toBe(JSON.stringify({ keep: true })) + } finally { + if (existsSync(sentinelPath)) { + rmSync(sentinelPath, { force: true }) + } + } + }) })