Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
456d9cea65 | ||
|
|
30f893b766 | ||
|
|
c905e1cb7a | ||
|
|
d3e2b36e3d | ||
|
|
5f0b6d49f5 | ||
|
|
b45408dd9c | ||
|
|
6c8527f29b | ||
|
|
cd4da93bf2 | ||
|
|
71b2f1518a | ||
|
|
dcda8769cc | ||
|
|
a94fbadd57 | ||
|
|
23b49c4a5c | ||
|
|
b4973954e3 | ||
|
|
6d50fbe563 | ||
|
|
9850dd0f6e | ||
|
|
34aaef2219 | ||
|
|
faca80caa9 | ||
|
|
0c3fbd724b | ||
|
|
c7455708f8 | ||
|
|
bffa1ad43d | ||
|
|
6560dedd4c | ||
|
|
b7e32a99f2 | ||
|
|
a06e656565 | ||
|
|
30ed086c40 | ||
|
|
7c15b06da7 | ||
|
|
0e7ee2ac30 | ||
|
|
ca93d2f0fe | ||
|
|
3ab4529bc7 | ||
|
|
9d3e152b19 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -33,3 +33,4 @@ yarn.lock
|
||||
test-injection/
|
||||
notepad.md
|
||||
oauth-success.html
|
||||
.188e87dbff6e7fd9-00000000.bun-build
|
||||
|
||||
31
bun.lock
31
bun.lock
@@ -18,6 +18,7 @@
|
||||
"jsonc-parser": "^3.3.1",
|
||||
"picocolors": "^1.1.1",
|
||||
"picomatch": "^4.0.2",
|
||||
"vscode-jsonrpc": "^8.2.0",
|
||||
"zod": "^4.1.8",
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -27,13 +28,13 @@
|
||||
"typescript": "^5.7.3",
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"oh-my-opencode-darwin-arm64": "3.1.2",
|
||||
"oh-my-opencode-darwin-x64": "3.1.2",
|
||||
"oh-my-opencode-linux-arm64": "3.1.2",
|
||||
"oh-my-opencode-linux-arm64-musl": "3.1.2",
|
||||
"oh-my-opencode-linux-x64": "3.1.2",
|
||||
"oh-my-opencode-linux-x64-musl": "3.1.2",
|
||||
"oh-my-opencode-windows-x64": "3.1.2",
|
||||
"oh-my-opencode-darwin-arm64": "3.1.6",
|
||||
"oh-my-opencode-darwin-x64": "3.1.6",
|
||||
"oh-my-opencode-linux-arm64": "3.1.6",
|
||||
"oh-my-opencode-linux-arm64-musl": "3.1.6",
|
||||
"oh-my-opencode-linux-x64": "3.1.6",
|
||||
"oh-my-opencode-linux-x64-musl": "3.1.6",
|
||||
"oh-my-opencode-windows-x64": "3.1.6",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -225,6 +226,20 @@
|
||||
|
||||
"object-inspect": ["object-inspect@1.13.4", "", {}, "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew=="],
|
||||
|
||||
"oh-my-opencode-darwin-arm64": ["oh-my-opencode-darwin-arm64@3.1.6", "", { "os": "darwin", "cpu": "arm64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-KK+ptnkBigvDYbRtF/B5izEC4IoXDS8mAnRHWFBSCINhzQR2No6AtEcwijd6vKBPR+/r71ofq/8mTsIeb1PEVQ=="],
|
||||
|
||||
"oh-my-opencode-darwin-x64": ["oh-my-opencode-darwin-x64@3.1.6", "", { "os": "darwin", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-UkPI/RUi7INarFasBUZ4Rous6RUQXsU2nr0V8KFJp+70END43D/96dDUwX+zmPtpDhD+DfWkejuwzqfkZJ2ZDQ=="],
|
||||
|
||||
"oh-my-opencode-linux-arm64": ["oh-my-opencode-linux-arm64@3.1.6", "", { "os": "linux", "cpu": "arm64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-gvmvgh7WtTtcHiCbG7z43DOYfY/jrf2S6TX/jBMX2/e1AGkcLKwz30NjGhZxeK5SyzxRVypgfZZK1IuriRgbdA=="],
|
||||
|
||||
"oh-my-opencode-linux-arm64-musl": ["oh-my-opencode-linux-arm64-musl@3.1.6", "", { "os": "linux", "cpu": "arm64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-j3R76pmQ4HGVGFJUMMCeF/1lO3Jg7xFdpcBUKCeFh42N1jMgn1aeyxkAaJYB9RwCF/p6+P8B6gVDLCEDu2mxjA=="],
|
||||
|
||||
"oh-my-opencode-linux-x64": ["oh-my-opencode-linux-x64@3.1.6", "", { "os": "linux", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-VDdo0tHCOr5nm7ajd652u798nPNOLRSTcPOnVh6vIPddkZ+ujRke+enOKOw9Pd5e+4AkthqHBwFXNm2VFgnEKg=="],
|
||||
|
||||
"oh-my-opencode-linux-x64-musl": ["oh-my-opencode-linux-x64-musl@3.1.6", "", { "os": "linux", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode" } }, "sha512-hBG/dhsr8PZelUlYsPBruSLnelB9ocB7H92I+S9svTpDVo67rAmXOoR04twKQ9TeCO4ShOa6hhMhbQnuI8fgNw=="],
|
||||
|
||||
"oh-my-opencode-windows-x64": ["oh-my-opencode-windows-x64@3.1.6", "", { "os": "win32", "cpu": "x64", "bin": { "oh-my-opencode": "bin/oh-my-opencode.exe" } }, "sha512-c8Awp03p2DsbS0G589nzveRCeJPgJRJ0vQrha4ChRmmo31Qc5OSmJ5xuMaF8L4nM+/trbTgAQMFMtCMLgtC8IQ=="],
|
||||
|
||||
"on-finished": ["on-finished@2.4.1", "", { "dependencies": { "ee-first": "1.1.1" } }, "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg=="],
|
||||
|
||||
"once": ["once@1.4.0", "", { "dependencies": { "wrappy": "1" } }, "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w=="],
|
||||
@@ -289,6 +304,8 @@
|
||||
|
||||
"vary": ["vary@1.1.2", "", {}, "sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg=="],
|
||||
|
||||
"vscode-jsonrpc": ["vscode-jsonrpc@8.2.1", "", {}, "sha512-kdjOSJ2lLIn7r1rtrMbbNCHjyMPfRnowdKjBQ+mGq6NAW5QY2bEZC/khaC5OR8svbbjvLEaIXkOq45e2X9BIbQ=="],
|
||||
|
||||
"which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="],
|
||||
|
||||
"wrappy": ["wrappy@1.0.2", "", {}, "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="],
|
||||
|
||||
@@ -134,7 +134,41 @@ bunx oh-my-opencode run [prompt]
|
||||
|
||||
---
|
||||
|
||||
## 6. `auth` - Authentication Management
|
||||
## 6. `mcp oauth` - MCP OAuth Management
|
||||
|
||||
Manages OAuth 2.1 authentication for remote MCP servers.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
# Login to an OAuth-protected MCP server
|
||||
bunx oh-my-opencode mcp oauth login <server-name> --server-url https://api.example.com
|
||||
|
||||
# Login with explicit client ID and scopes
|
||||
bunx oh-my-opencode mcp oauth login my-api --server-url https://api.example.com --client-id my-client --scopes "read,write"
|
||||
|
||||
# Remove stored OAuth tokens
|
||||
bunx oh-my-opencode mcp oauth logout <server-name>
|
||||
|
||||
# Check OAuth token status
|
||||
bunx oh-my-opencode mcp oauth status [server-name]
|
||||
```
|
||||
|
||||
### Options
|
||||
|
||||
| Option | Description |
|
||||
|--------|-------------|
|
||||
| `--server-url <url>` | MCP server URL (required for login) |
|
||||
| `--client-id <id>` | OAuth client ID (optional if server supports Dynamic Client Registration) |
|
||||
| `--scopes <scopes>` | Comma-separated OAuth scopes |
|
||||
|
||||
### Token Storage
|
||||
|
||||
Tokens are stored in `~/.config/opencode/mcp-oauth.json` with `0600` permissions (owner read/write only). Key format: `{serverHost}/{resource}`.
|
||||
|
||||
---
|
||||
|
||||
## 7. `auth` - Authentication Management
|
||||
|
||||
Manages Google Antigravity OAuth authentication. Required for using Gemini models.
|
||||
|
||||
@@ -153,7 +187,7 @@ bunx oh-my-opencode auth status
|
||||
|
||||
---
|
||||
|
||||
## 7. Configuration Files
|
||||
## 8. Configuration Files
|
||||
|
||||
The CLI searches for configuration files in the following locations (in priority order):
|
||||
|
||||
@@ -183,7 +217,7 @@ Configuration files support **JSONC (JSON with Comments)** format. You can use c
|
||||
|
||||
---
|
||||
|
||||
## 8. Troubleshooting
|
||||
## 9. Troubleshooting
|
||||
|
||||
### "OpenCode version too old" Error
|
||||
|
||||
@@ -213,7 +247,7 @@ bunx oh-my-opencode doctor --category authentication
|
||||
|
||||
---
|
||||
|
||||
## 9. Non-Interactive Mode
|
||||
## 10. Non-Interactive Mode
|
||||
|
||||
Use the `--no-tui` option for CI/CD environments.
|
||||
|
||||
@@ -227,7 +261,7 @@ bunx oh-my-opencode doctor --json > doctor-report.json
|
||||
|
||||
---
|
||||
|
||||
## 10. Developer Information
|
||||
## 11. Developer Information
|
||||
|
||||
### CLI Structure
|
||||
|
||||
|
||||
@@ -163,7 +163,39 @@ Override built-in agent settings:
|
||||
}
|
||||
```
|
||||
|
||||
Each agent supports: `model`, `temperature`, `top_p`, `prompt`, `prompt_append`, `tools`, `disable`, `description`, `mode`, `color`, `permission`.
|
||||
Each agent supports: `model`, `temperature`, `top_p`, `prompt`, `prompt_append`, `tools`, `disable`, `description`, `mode`, `color`, `permission`, `category`, `variant`, `maxTokens`, `thinking`, `reasoningEffort`, `textVerbosity`, `providerOptions`.
|
||||
|
||||
### Additional Agent Options
|
||||
|
||||
| Option | Type | Description |
|
||||
| ------------------- | ------- | ----------------------------------------------------------------------------------------------- |
|
||||
| `category` | string | Category name to inherit model and other settings from category defaults |
|
||||
| `variant` | string | Model variant (e.g., `max`, `high`, `medium`, `low`, `xhigh`) |
|
||||
| `maxTokens` | number | Maximum tokens for response. Passed directly to OpenCode SDK. |
|
||||
| `thinking` | object | Extended thinking configuration for Anthropic models. See [Thinking Options](#thinking-options) below. |
|
||||
| `reasoningEffort` | string | OpenAI reasoning effort level. Values: `low`, `medium`, `high`, `xhigh`. |
|
||||
| `textVerbosity` | string | Text verbosity level. Values: `low`, `medium`, `high`. |
|
||||
| `providerOptions` | object | Provider-specific options passed directly to OpenCode SDK. |
|
||||
|
||||
#### Thinking Options (Anthropic)
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"oracle": {
|
||||
"thinking": {
|
||||
"type": "enabled",
|
||||
"budgetTokens": 200000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
| ------------- | ------- | ------- | -------------------------------------------- |
|
||||
| `type` | string | - | `enabled` or `disabled` |
|
||||
| `budgetTokens`| number | - | Maximum budget tokens for extended thinking |
|
||||
|
||||
Use `prompt_append` to add extra instructions without replacing the default system prompt:
|
||||
|
||||
@@ -213,7 +245,7 @@ Or disable via `disabled_agents` in `~/.config/opencode/oh-my-opencode.json` or
|
||||
}
|
||||
```
|
||||
|
||||
Available agents: `oracle`, `librarian`, `explore`, `multimodal-looker`
|
||||
Available agents: `sisyphus`, `prometheus`, `oracle`, `librarian`, `explore`, `multimodal-looker`, `metis`, `momus`, `atlas`
|
||||
|
||||
## Built-in Skills
|
||||
|
||||
@@ -232,6 +264,105 @@ Disable built-in skills via `disabled_skills` in `~/.config/opencode/oh-my-openc
|
||||
|
||||
Available built-in skills: `playwright`, `agent-browser`, `git-master`
|
||||
|
||||
## Skills Configuration
|
||||
|
||||
Configure advanced skills settings including custom skill sources, enabling/disabling specific skills, and defining custom skills.
|
||||
|
||||
```json
|
||||
{
|
||||
"skills": {
|
||||
"sources": [
|
||||
{ "path": "./custom-skills", "recursive": true },
|
||||
"https://example.com/skill.yaml"
|
||||
],
|
||||
"enable": ["my-custom-skill"],
|
||||
"disable": ["other-skill"],
|
||||
"my-skill": {
|
||||
"description": "Custom skill description",
|
||||
"template": "Custom prompt template",
|
||||
"from": "source-file.ts",
|
||||
"model": "custom/model",
|
||||
"agent": "custom-agent",
|
||||
"subtask": true,
|
||||
"argument-hint": "usage hint",
|
||||
"license": "MIT",
|
||||
"compatibility": ">= 3.0.0",
|
||||
"metadata": {
|
||||
"author": "Your Name"
|
||||
},
|
||||
"allowed-tools": ["tool1", "tool2"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Sources
|
||||
|
||||
Load skills from local directories or remote URLs:
|
||||
|
||||
```json
|
||||
{
|
||||
"skills": {
|
||||
"sources": [
|
||||
{ "path": "./custom-skills", "recursive": true },
|
||||
{ "path": "./single-skill.yaml" },
|
||||
"https://example.com/skill.yaml",
|
||||
"https://raw.githubusercontent.com/user/repo/main/skills/*"
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
| ----------- | ------- | ---------------------------------------------- |
|
||||
| `path` | - | Local file/directory path or remote URL |
|
||||
| `recursive` | `false` | Recursively load from directory |
|
||||
| `glob` | - | Glob pattern for file selection |
|
||||
|
||||
### Enable/Disable Skills
|
||||
|
||||
```json
|
||||
{
|
||||
"skills": {
|
||||
"enable": ["skill-1", "skill-2"],
|
||||
"disable": ["disabled-skill"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Skill Definition
|
||||
|
||||
Define custom skills directly in your config:
|
||||
|
||||
| Option | Default | Description |
|
||||
| ---------------- | ------- | ------------------------------------------------------------------------------------ |
|
||||
| `description` | - | Human-readable description of the skill |
|
||||
| `template` | - | Custom prompt template for the skill |
|
||||
| `from` | - | Source file to load template from |
|
||||
| `model` | - | Override model for this skill |
|
||||
| `agent` | - | Override agent for this skill |
|
||||
| `subtask` | `false` | Whether to run as a subtask |
|
||||
| `argument-hint` | - | Hint for how to use the skill |
|
||||
| `license` | - | Skill license |
|
||||
| `compatibility` | - | Required oh-my-opencode version compatibility |
|
||||
| `metadata` | - | Additional metadata as key-value pairs |
|
||||
| `allowed-tools` | - | Array of tools this skill is allowed to use |
|
||||
|
||||
**Example: Custom skill**
|
||||
|
||||
```json
|
||||
{
|
||||
"skills": {
|
||||
"data-analyst": {
|
||||
"description": "Specialized for data analysis tasks",
|
||||
"template": "You are a data analyst. Focus on statistical analysis, visualization, and data interpretation.",
|
||||
"model": "openai/gpt-5.2",
|
||||
"allowed-tools": ["read", "bash", "lsp_diagnostics"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Browser Automation
|
||||
|
||||
Choose between two browser automation providers:
|
||||
@@ -555,6 +686,7 @@ Configure concurrency limits for background agent tasks. This controls how many
|
||||
{
|
||||
"background_task": {
|
||||
"defaultConcurrency": 5,
|
||||
"staleTimeoutMs": 180000,
|
||||
"providerConcurrency": {
|
||||
"anthropic": 3,
|
||||
"openai": 5,
|
||||
@@ -571,6 +703,7 @@ Configure concurrency limits for background agent tasks. This controls how many
|
||||
| Option | Default | Description |
|
||||
| --------------------- | ------- | ----------------------------------------------------------------------------------------------------------------------- |
|
||||
| `defaultConcurrency` | - | Default maximum concurrent background tasks for all providers/models |
|
||||
| `staleTimeoutMs` | `180000` | Stale timeout in milliseconds - interrupt tasks with no activity for this duration (minimum: 60000 = 1 minute) |
|
||||
| `providerConcurrency` | - | Per-provider concurrency limits. Keys are provider names (e.g., `anthropic`, `openai`, `google`) |
|
||||
| `modelConcurrency` | - | Per-model concurrency limits. Keys are full model names (e.g., `anthropic/claude-opus-4-5`). Overrides provider limits. |
|
||||
|
||||
@@ -692,7 +825,14 @@ Add your own categories or override built-in ones:
|
||||
}
|
||||
```
|
||||
|
||||
Each category supports: `model`, `temperature`, `top_p`, `maxTokens`, `thinking`, `reasoningEffort`, `textVerbosity`, `tools`, `prompt_append`, `variant`.
|
||||
Each category supports: `model`, `temperature`, `top_p`, `maxTokens`, `thinking`, `reasoningEffort`, `textVerbosity`, `tools`, `prompt_append`, `variant`, `description`, `is_unstable_agent`.
|
||||
|
||||
### Additional Category Options
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
| ------------------ | ------- | ------- | --------------------------------------------------------------------------------------------------- |
|
||||
| `description` | string | - | Human-readable description of the category's purpose. Shown in delegate_task prompt. |
|
||||
| `is_unstable_agent`| boolean | `false` | Mark agent as unstable - forces background mode for monitoring. Auto-enabled for gemini models. |
|
||||
|
||||
## Model Resolution System
|
||||
|
||||
@@ -826,12 +966,93 @@ Disable specific built-in hooks via `disabled_hooks` in `~/.config/opencode/oh-m
|
||||
}
|
||||
```
|
||||
|
||||
Available hooks: `todo-continuation-enforcer`, `context-window-monitor`, `session-recovery`, `session-notification`, `comment-checker`, `grep-output-truncator`, `tool-output-truncator`, `directory-agents-injector`, `directory-readme-injector`, `empty-task-response-detector`, `think-mode`, `anthropic-context-window-limit-recovery`, `rules-injector`, `background-notification`, `auto-update-checker`, `startup-toast`, `keyword-detector`, `agent-usage-reminder`, `non-interactive-env`, `interactive-bash-session`, `compaction-context-injector`, `thinking-block-validator`, `claude-code-hooks`, `ralph-loop`, `preemptive-compaction`
|
||||
Available hooks: `todo-continuation-enforcer`, `context-window-monitor`, `session-recovery`, `session-notification`, `comment-checker`, `grep-output-truncator`, `tool-output-truncator`, `directory-agents-injector`, `directory-readme-injector`, `empty-task-response-detector`, `think-mode`, `anthropic-context-window-limit-recovery`, `rules-injector`, `background-notification`, `auto-update-checker`, `startup-toast`, `keyword-detector`, `agent-usage-reminder`, `non-interactive-env`, `interactive-bash-session`, `compaction-context-injector`, `thinking-block-validator`, `claude-code-hooks`, `ralph-loop`, `preemptive-compaction`, `auto-slash-command`, `sisyphus-junior-notepad`, `start-work`
|
||||
|
||||
**Note on `directory-agents-injector`**: This hook is **automatically disabled** when running on OpenCode 1.1.37+ because OpenCode now has native support for dynamically resolving AGENTS.md files from subdirectories (PR #10678). This prevents duplicate AGENTS.md injection. For older OpenCode versions, the hook remains active to provide the same functionality.
|
||||
|
||||
**Note on `auto-update-checker` and `startup-toast`**: The `startup-toast` hook is a sub-feature of `auto-update-checker`. To disable only the startup toast notification while keeping update checking enabled, add `"startup-toast"` to `disabled_hooks`. To disable all update checking features (including the toast), add `"auto-update-checker"` to `disabled_hooks`.
|
||||
|
||||
## Disabled Commands
|
||||
|
||||
Disable specific built-in commands via `disabled_commands` in `~/.config/opencode/oh-my-opencode.json` or `.opencode/oh-my-opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"disabled_commands": ["init-deep", "start-work"]
|
||||
}
|
||||
```
|
||||
|
||||
Available commands: `init-deep`, `start-work`
|
||||
|
||||
## Comment Checker
|
||||
|
||||
Configure comment-checker hook behavior. The comment checker warns when excessive comments are added to code.
|
||||
|
||||
```json
|
||||
{
|
||||
"comment_checker": {
|
||||
"custom_prompt": "Your custom warning message. Use {{comments}} placeholder for detected comments XML."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
| ------------- | ------- | -------------------------------------------------------------------------- |
|
||||
| `custom_prompt` | - | Custom warning message to replace the default. Use `{{comments}}` placeholder. |
|
||||
|
||||
## Notification
|
||||
|
||||
Configure notification behavior for background task completion.
|
||||
|
||||
```json
|
||||
{
|
||||
"notification": {
|
||||
"force_enable": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------- | ------- | ---------------------------------------------------------------------------------------------- |
|
||||
| `force_enable` | `false` | Force enable session-notification even if external notification plugins are detected. Default: `false`. |
|
||||
|
||||
## Sisyphus Tasks & Swarm
|
||||
|
||||
Configure Sisyphus Tasks and Swarm systems for advanced task management and multi-agent orchestration.
|
||||
|
||||
```json
|
||||
{
|
||||
"sisyphus": {
|
||||
"tasks": {
|
||||
"enabled": false,
|
||||
"storage_path": ".sisyphus/tasks",
|
||||
"claude_code_compat": false
|
||||
},
|
||||
"swarm": {
|
||||
"enabled": false,
|
||||
"storage_path": ".sisyphus/teams",
|
||||
"ui_mode": "toast"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Tasks Configuration
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------------- | ------------------ | ------------------------------------------------------------------------- |
|
||||
| `enabled` | `false` | Enable Sisyphus Tasks system |
|
||||
| `storage_path` | `.sisyphus/tasks` | Storage path for tasks (relative to project root) |
|
||||
| `claude_code_compat` | `false` | Enable Claude Code path compatibility mode |
|
||||
|
||||
### Swarm Configuration
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------- | ------------------ | -------------------------------------------------------------- |
|
||||
| `enabled` | `false` | Enable Sisyphus Swarm system for multi-agent orchestration |
|
||||
| `storage_path` | `.sisyphus/teams` | Storage path for teams (relative to project root) |
|
||||
| `ui_mode` | `toast` | UI mode: `toast` (notifications), `tmux` (panes), or `both` |
|
||||
|
||||
## MCPs
|
||||
|
||||
Exa, Context7 and grep.app MCP enabled by default.
|
||||
@@ -873,6 +1094,38 @@ Add LSP servers via the `lsp` option in `~/.config/opencode/oh-my-opencode.json`
|
||||
|
||||
Each server supports: `command`, `extensions`, `priority`, `env`, `initialization`, `disabled`.
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
| -------------- | -------- | ------- | ---------------------------------------------------------------------- |
|
||||
| `command` | array | - | Command to start the LSP server (executable + args) |
|
||||
| `extensions` | array | - | File extensions this server handles (e.g., `[".ts", ".tsx"]`) |
|
||||
| `priority` | number | - | Server priority when multiple servers match a file |
|
||||
| `env` | object | - | Environment variables for the LSP server (key-value pairs) |
|
||||
| `initialization`| object | - | Custom initialization options passed to the LSP server |
|
||||
| `disabled` | boolean | `false` | Whether to disable this LSP server |
|
||||
|
||||
**Example with advanced options:**
|
||||
|
||||
```json
|
||||
{
|
||||
"lsp": {
|
||||
"typescript-language-server": {
|
||||
"command": ["typescript-language-server", "--stdio"],
|
||||
"extensions": [".ts", ".tsx"],
|
||||
"priority": 10,
|
||||
"env": {
|
||||
"NODE_OPTIONS": "--max-old-space-size=4096"
|
||||
},
|
||||
"initialization": {
|
||||
"preferences": {
|
||||
"includeInlayParameterNameHints": "all",
|
||||
"includeInlayFunctionParameterTypeHints": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Experimental
|
||||
|
||||
Opt-in experimental features that may change or be removed in future versions. Use with caution.
|
||||
@@ -882,7 +1135,29 @@ Opt-in experimental features that may change or be removed in future versions. U
|
||||
"experimental": {
|
||||
"truncate_all_tool_outputs": true,
|
||||
"aggressive_truncation": true,
|
||||
"auto_resume": true
|
||||
"auto_resume": true,
|
||||
"dynamic_context_pruning": {
|
||||
"enabled": false,
|
||||
"notification": "detailed",
|
||||
"turn_protection": {
|
||||
"enabled": true,
|
||||
"turns": 3
|
||||
},
|
||||
"protected_tools": ["task", "todowrite", "lsp_rename"],
|
||||
"strategies": {
|
||||
"deduplication": {
|
||||
"enabled": true
|
||||
},
|
||||
"supersede_writes": {
|
||||
"enabled": true,
|
||||
"aggressive": false
|
||||
},
|
||||
"purge_errors": {
|
||||
"enabled": true,
|
||||
"turns": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -891,7 +1166,72 @@ Opt-in experimental features that may change or be removed in future versions. U
|
||||
| --------------------------- | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `truncate_all_tool_outputs` | `false` | Truncates ALL tool outputs instead of just whitelisted tools (Grep, Glob, LSP, AST-grep). Tool output truncator is enabled by default - disable via `disabled_hooks`. |
|
||||
| `aggressive_truncation` | `false` | When token limit is exceeded, aggressively truncates tool outputs to fit within limits. More aggressive than the default truncation behavior. Falls back to summarize/revert if insufficient. |
|
||||
| `auto_resume` | `false` | Automatically resumes session after successful recovery from thinking block errors or thinking disabled violations. Extracts the last user message and continues. |
|
||||
| `auto_resume` | `false` | Automatically resumes session after successful recovery from thinking block errors or thinking disabled violations. Extracts last user message and continues. |
|
||||
| `dynamic_context_pruning` | See below | Dynamic context pruning configuration for managing context window usage automatically. See [Dynamic Context Pruning](#dynamic-context-pruning) below. |
|
||||
|
||||
### Dynamic Context Pruning
|
||||
|
||||
Dynamic context pruning automatically manages context window by intelligently pruning old tool outputs. This feature helps maintain performance in long sessions.
|
||||
|
||||
```json
|
||||
{
|
||||
"experimental": {
|
||||
"dynamic_context_pruning": {
|
||||
"enabled": false,
|
||||
"notification": "detailed",
|
||||
"turn_protection": {
|
||||
"enabled": true,
|
||||
"turns": 3
|
||||
},
|
||||
"protected_tools": ["task", "todowrite", "todoread", "lsp_rename", "session_read", "session_write", "session_search"],
|
||||
"strategies": {
|
||||
"deduplication": {
|
||||
"enabled": true
|
||||
},
|
||||
"supersede_writes": {
|
||||
"enabled": true,
|
||||
"aggressive": false
|
||||
},
|
||||
"purge_errors": {
|
||||
"enabled": true,
|
||||
"turns": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
| ----------------- | ------- | ----------------------------------------------------------------------------------------- |
|
||||
| `enabled` | `false` | Enable dynamic context pruning |
|
||||
| `notification` | `detailed` | Notification level: `off`, `minimal`, or `detailed` |
|
||||
| `turn_protection` | See below | Turn protection settings - prevent pruning recent tool outputs |
|
||||
|
||||
#### Turn Protection
|
||||
|
||||
| Option | Default | Description |
|
||||
| --------- | ------- | ------------------------------------------------------------ |
|
||||
| `enabled` | `true` | Enable turn protection |
|
||||
| `turns` | `3` | Number of recent turns to protect from pruning (1-10) |
|
||||
|
||||
#### Protected Tools
|
||||
|
||||
Tools that should never be pruned (default):
|
||||
|
||||
```json
|
||||
["task", "todowrite", "todoread", "lsp_rename", "session_read", "session_write", "session_search"]
|
||||
```
|
||||
|
||||
#### Pruning Strategies
|
||||
|
||||
| Strategy | Option | Default | Description |
|
||||
| ------------------- | ------------ | ------- | ---------------------------------------------------------------------------- |
|
||||
| **deduplication** | `enabled` | `true` | Remove duplicate tool calls (same tool + same args) |
|
||||
| **supersede_writes**| `enabled` | `true` | Prune write inputs when file subsequently read |
|
||||
| | `aggressive` | `false` | Aggressive mode: prune any write if ANY subsequent read |
|
||||
| **purge_errors** | `enabled` | `true` | Prune errored tool inputs after N turns |
|
||||
| | `turns` | `5` | Number of turns before pruning errors (1-20) |
|
||||
|
||||
**Warning**: These features are experimental and may cause unexpected behavior. Enable only if you understand the implications.
|
||||
|
||||
|
||||
@@ -521,6 +521,37 @@ mcp:
|
||||
|
||||
The `skill_mcp` tool invokes these operations with full schema discovery.
|
||||
|
||||
#### OAuth-Enabled MCPs
|
||||
|
||||
Skills can define OAuth-protected remote MCP servers. OAuth 2.1 with full RFC compliance (RFC 9728, 8414, 8707, 7591) is supported:
|
||||
|
||||
```yaml
|
||||
---
|
||||
description: My API skill
|
||||
mcp:
|
||||
my-api:
|
||||
url: https://api.example.com/mcp
|
||||
oauth:
|
||||
clientId: ${CLIENT_ID}
|
||||
scopes: ["read", "write"]
|
||||
---
|
||||
```
|
||||
|
||||
When a skill MCP has `oauth` configured:
|
||||
- **Auto-discovery**: Fetches `/.well-known/oauth-protected-resource` (RFC 9728), falls back to `/.well-known/oauth-authorization-server` (RFC 8414)
|
||||
- **Dynamic Client Registration**: Auto-registers with servers supporting RFC 7591 (clientId becomes optional)
|
||||
- **PKCE**: Mandatory for all flows
|
||||
- **Resource Indicators**: Auto-generated from MCP URL per RFC 8707
|
||||
- **Token Storage**: Persisted in `~/.config/opencode/mcp-oauth.json` (chmod 0600)
|
||||
- **Auto-refresh**: Tokens refresh on 401; step-up authorization on 403 with `WWW-Authenticate`
|
||||
- **Dynamic Port**: OAuth callback server uses an auto-discovered available port
|
||||
|
||||
Pre-authenticate via CLI:
|
||||
|
||||
```bash
|
||||
bunx oh-my-opencode mcp oauth login <server-name> --server-url https://api.example.com
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Context Injection
|
||||
|
||||
17
package.json
17
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "The Best AI Agent Harness - Batteries-Included OpenCode Plugin with Multi-Model Orchestration, Parallel Background Agents, and Crafted LSP/AST Tools",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -64,6 +64,7 @@
|
||||
"jsonc-parser": "^3.3.1",
|
||||
"picocolors": "^1.1.1",
|
||||
"picomatch": "^4.0.2",
|
||||
"vscode-jsonrpc": "^8.2.0",
|
||||
"zod": "^4.1.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -73,13 +74,13 @@
|
||||
"typescript": "^5.7.3"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"oh-my-opencode-darwin-arm64": "3.1.5",
|
||||
"oh-my-opencode-darwin-x64": "3.1.5",
|
||||
"oh-my-opencode-linux-arm64": "3.1.5",
|
||||
"oh-my-opencode-linux-arm64-musl": "3.1.5",
|
||||
"oh-my-opencode-linux-x64": "3.1.5",
|
||||
"oh-my-opencode-linux-x64-musl": "3.1.5",
|
||||
"oh-my-opencode-windows-x64": "3.1.5"
|
||||
"oh-my-opencode-darwin-arm64": "3.1.8",
|
||||
"oh-my-opencode-darwin-x64": "3.1.8",
|
||||
"oh-my-opencode-linux-arm64": "3.1.8",
|
||||
"oh-my-opencode-linux-arm64-musl": "3.1.8",
|
||||
"oh-my-opencode-linux-x64": "3.1.8",
|
||||
"oh-my-opencode-linux-x64-musl": "3.1.8",
|
||||
"oh-my-opencode-windows-x64": "3.1.8"
|
||||
},
|
||||
"trustedDependencies": [
|
||||
"@ast-grep/cli",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode-darwin-arm64",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "Platform-specific binary for oh-my-opencode (darwin-arm64)",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode-darwin-x64",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "Platform-specific binary for oh-my-opencode (darwin-x64)",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode-linux-arm64-musl",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "Platform-specific binary for oh-my-opencode (linux-arm64-musl)",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode-linux-arm64",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "Platform-specific binary for oh-my-opencode (linux-arm64)",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode-linux-x64-musl",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "Platform-specific binary for oh-my-opencode (linux-x64-musl)",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode-linux-x64",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "Platform-specific binary for oh-my-opencode (linux-x64)",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "oh-my-opencode-windows-x64",
|
||||
"version": "3.1.5",
|
||||
"version": "3.1.8",
|
||||
"description": "Platform-specific binary for oh-my-opencode (windows-x64)",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
||||
@@ -943,6 +943,38 @@
|
||||
"created_at": "2026-01-28T13:04:16Z",
|
||||
"repoId": 1108837393,
|
||||
"pullRequestNo": 1203
|
||||
},
|
||||
{
|
||||
"name": "KennyDizi",
|
||||
"id": 16578966,
|
||||
"comment_id": 3811619818,
|
||||
"created_at": "2026-01-28T14:26:10Z",
|
||||
"repoId": 1108837393,
|
||||
"pullRequestNo": 1214
|
||||
},
|
||||
{
|
||||
"name": "mrdavidlaing",
|
||||
"id": 227505,
|
||||
"comment_id": 3813542625,
|
||||
"created_at": "2026-01-28T19:51:34Z",
|
||||
"repoId": 1108837393,
|
||||
"pullRequestNo": 1226
|
||||
},
|
||||
{
|
||||
"name": "Lynricsy",
|
||||
"id": 62173814,
|
||||
"comment_id": 3816370548,
|
||||
"created_at": "2026-01-29T09:00:28Z",
|
||||
"repoId": 1108837393,
|
||||
"pullRequestNo": 1241
|
||||
},
|
||||
{
|
||||
"name": "LeekJay",
|
||||
"id": 39609783,
|
||||
"comment_id": 3819009761,
|
||||
"created_at": "2026-01-29T17:03:24Z",
|
||||
"repoId": 1108837393,
|
||||
"pullRequestNo": 1254
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -529,7 +529,7 @@ export function createAtlasAgent(ctx: OrchestratorContext): AgentConfig {
|
||||
])
|
||||
return {
|
||||
description:
|
||||
"Orchestrates work via delegate_task() to complete ALL tasks in a todo list until fully done",
|
||||
"Orchestrates work via delegate_task() to complete ALL tasks in a todo list until fully done. (Atlas - OhMyOpenCode)",
|
||||
mode: "primary" as const,
|
||||
...(ctx.model ? { model: ctx.model } : {}),
|
||||
temperature: 0.1,
|
||||
|
||||
@@ -33,7 +33,7 @@ export function createExploreAgent(model: string): AgentConfig {
|
||||
|
||||
return {
|
||||
description:
|
||||
'Contextual grep for codebases. Answers "Where is X?", "Which file has Y?", "Find the code that does Z". Fire multiple in parallel for broad searches. Specify thoroughness: "quick" for basic, "medium" for moderate, "very thorough" for comprehensive analysis.',
|
||||
'Contextual grep for codebases. Answers "Where is X?", "Which file has Y?", "Find the code that does Z". Fire multiple in parallel for broad searches. Specify thoroughness: "quick" for basic, "medium" for moderate, "very thorough" for comprehensive analysis. (Explore - OhMyOpenCode)',
|
||||
mode: "subagent" as const,
|
||||
model,
|
||||
temperature: 0.1,
|
||||
|
||||
@@ -30,7 +30,7 @@ export function createLibrarianAgent(model: string): AgentConfig {
|
||||
|
||||
return {
|
||||
description:
|
||||
"Specialized codebase understanding agent for multi-repository analysis, searching remote codebases, retrieving official documentation, and finding implementation examples using GitHub CLI, Context7, and Web Search. MUST BE USED when users ask to look up code in remote repositories, explain library internals, or find usage examples in open source.",
|
||||
"Specialized codebase understanding agent for multi-repository analysis, searching remote codebases, retrieving official documentation, and finding implementation examples using GitHub CLI, Context7, and Web Search. MUST BE USED when users ask to look up code in remote repositories, explain library internals, or find usage examples in open source. (Librarian - OhMyOpenCode)",
|
||||
mode: "subagent" as const,
|
||||
model,
|
||||
temperature: 0.1,
|
||||
|
||||
@@ -310,7 +310,7 @@ const metisRestrictions = createAgentToolRestrictions([
|
||||
export function createMetisAgent(model: string): AgentConfig {
|
||||
return {
|
||||
description:
|
||||
"Pre-planning consultant that analyzes requests to identify hidden intentions, ambiguities, and AI failure points.",
|
||||
"Pre-planning consultant that analyzes requests to identify hidden intentions, ambiguities, and AI failure points. (Metis - OhMyOpenCode)",
|
||||
mode: "subagent" as const,
|
||||
model,
|
||||
temperature: 0.3,
|
||||
|
||||
@@ -399,7 +399,7 @@ export function createMomusAgent(model: string): AgentConfig {
|
||||
|
||||
const base = {
|
||||
description:
|
||||
"Expert reviewer for evaluating work plans against rigorous clarity, verifiability, and completeness standards.",
|
||||
"Expert reviewer for evaluating work plans against rigorous clarity, verifiability, and completeness standards. (Momus - OhMyOpenCode)",
|
||||
mode: "subagent" as const,
|
||||
model,
|
||||
temperature: 0.1,
|
||||
|
||||
@@ -14,7 +14,7 @@ export function createMultimodalLookerAgent(model: string): AgentConfig {
|
||||
|
||||
return {
|
||||
description:
|
||||
"Analyze media files (PDFs, images, diagrams) that require interpretation beyond raw text. Extracts specific information or summaries from documents, describes visual content. Use when you need analyzed/extracted data rather than literal file contents.",
|
||||
"Analyze media files (PDFs, images, diagrams) that require interpretation beyond raw text. Extracts specific information or summaries from documents, describes visual content. Use when you need analyzed/extracted data rather than literal file contents. (Multimodal-Looker - OhMyOpenCode)",
|
||||
mode: "subagent" as const,
|
||||
model,
|
||||
temperature: 0.1,
|
||||
|
||||
@@ -105,7 +105,7 @@ export function createOracleAgent(model: string): AgentConfig {
|
||||
|
||||
const base = {
|
||||
description:
|
||||
"Read-only consultation agent. High-IQ reasoning specialist for debugging hard problems and high-difficulty architecture design.",
|
||||
"Read-only consultation agent. High-IQ reasoning specialist for debugging hard problems and high-difficulty architecture design. (Oracle - OhMyOpenCode)",
|
||||
mode: "subagent" as const,
|
||||
model,
|
||||
temperature: 0.1,
|
||||
|
||||
@@ -84,7 +84,7 @@ export function createSisyphusJuniorAgentWithOverrides(
|
||||
|
||||
const base: AgentConfig = {
|
||||
description: override?.description ??
|
||||
"Sisyphus-Junior - Focused task executor. Same discipline, no delegation.",
|
||||
"Focused task executor. Same discipline, no delegation. (Sisyphus-Junior - OhMyOpenCode)",
|
||||
mode: "subagent" as const,
|
||||
model,
|
||||
temperature,
|
||||
|
||||
@@ -433,7 +433,7 @@ export function createSisyphusAgent(
|
||||
const permission = { question: "allow", call_omo_agent: "deny" } as AgentConfig["permission"]
|
||||
const base = {
|
||||
description:
|
||||
"Sisyphus - Powerful AI orchestrator from OhMyOpenCode. Plans obsessively with todos, assesses search complexity before exploration, delegates strategically via category+skills combinations. Uses explore for internal code (parallel-friendly), librarian for external docs.",
|
||||
"Powerful AI orchestrator. Plans obsessively with todos, assesses search complexity before exploration, delegates strategically via category+skills combinations. Uses explore for internal code (parallel-friendly), librarian for external docs. (Sisyphus - OhMyOpenCode)",
|
||||
mode: "primary" as const,
|
||||
model,
|
||||
maxTokens: 64000,
|
||||
|
||||
@@ -47,17 +47,16 @@ describe("createBuiltinAgents with model overrides", () => {
|
||||
expect(agents.sisyphus.reasoningEffort).toBeUndefined()
|
||||
})
|
||||
|
||||
test("Oracle uses connected provider when no availableModels but connected cache exists", async () => {
|
||||
// #given - connected providers cache exists with openai
|
||||
test("Oracle uses connected provider fallback when availableModels is empty and cache exists", async () => {
|
||||
// #given - connected providers cache has "openai", which matches oracle's first fallback entry
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"])
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], {}, undefined, TEST_DEFAULT_MODEL)
|
||||
|
||||
// #then - uses openai from connected cache
|
||||
// #then - oracle resolves via connected cache fallback to openai/gpt-5.2 (not system default)
|
||||
expect(agents.oracle.model).toBe("openai/gpt-5.2")
|
||||
expect(agents.oracle.reasoningEffort).toBe("medium")
|
||||
expect(agents.oracle.textVerbosity).toBe("high")
|
||||
expect(agents.oracle.thinking).toBeUndefined()
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
@@ -123,39 +122,39 @@ describe("createBuiltinAgents with model overrides", () => {
|
||||
})
|
||||
|
||||
describe("createBuiltinAgents without systemDefaultModel", () => {
|
||||
test("creates agents with connected provider when cache exists", async () => {
|
||||
// #given - connected providers cache exists
|
||||
test("agents created via connected cache fallback even without systemDefaultModel", async () => {
|
||||
// #given - connected cache has "openai", which matches oracle's fallback chain
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai"])
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], {}, undefined, undefined)
|
||||
|
||||
// #then - agents should use connected provider from fallback chain
|
||||
// #then - connected cache enables model resolution despite no systemDefaultModel
|
||||
expect(agents.oracle).toBeDefined()
|
||||
expect(agents.oracle.model).toBe("openai/gpt-5.2")
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
|
||||
test("agents NOT created when no cache and no systemDefaultModel (first run without defaults)", async () => {
|
||||
// #given - no cache and no system default
|
||||
// #given
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(null)
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], {}, undefined, undefined)
|
||||
|
||||
// #then - oracle should NOT be created (resolveModelWithFallback returns undefined)
|
||||
// #then
|
||||
expect(agents.oracle).toBeUndefined()
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
|
||||
test("sisyphus uses connected provider when cache exists", async () => {
|
||||
// #given - connected providers cache exists with anthropic
|
||||
test("sisyphus created via connected cache fallback even without systemDefaultModel", async () => {
|
||||
// #given - connected cache has "anthropic", which matches sisyphus's first fallback entry
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["anthropic"])
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], {}, undefined, undefined)
|
||||
|
||||
// #then - sisyphus should use anthropic from connected cache
|
||||
// #then - connected cache enables model resolution despite no systemDefaultModel
|
||||
expect(agents.sisyphus).toBeDefined()
|
||||
expect(agents.sisyphus.model).toBe("anthropic/claude-opus-4-5")
|
||||
cacheSpy.mockRestore()
|
||||
@@ -408,3 +407,119 @@ describe("buildAgent with category and skills", () => {
|
||||
expect(agent.prompt).not.toContain("agent-browser open")
|
||||
})
|
||||
})
|
||||
|
||||
describe("override.category expansion in createBuiltinAgents", () => {
|
||||
test("standard agent override with category expands category properties", async () => {
|
||||
// #given
|
||||
const overrides = {
|
||||
oracle: { category: "ultrabrain" } as any,
|
||||
}
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], overrides, undefined, TEST_DEFAULT_MODEL)
|
||||
|
||||
// #then - ultrabrain category: model=openai/gpt-5.2-codex, variant=xhigh
|
||||
expect(agents.oracle).toBeDefined()
|
||||
expect(agents.oracle.model).toBe("openai/gpt-5.2-codex")
|
||||
expect(agents.oracle.variant).toBe("xhigh")
|
||||
})
|
||||
|
||||
test("standard agent override with category AND direct variant - direct wins", async () => {
|
||||
// #given - ultrabrain has variant=xhigh, but direct override says "max"
|
||||
const overrides = {
|
||||
oracle: { category: "ultrabrain", variant: "max" } as any,
|
||||
}
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], overrides, undefined, TEST_DEFAULT_MODEL)
|
||||
|
||||
// #then - direct variant overrides category variant
|
||||
expect(agents.oracle).toBeDefined()
|
||||
expect(agents.oracle.variant).toBe("max")
|
||||
})
|
||||
|
||||
test("standard agent override with category AND direct reasoningEffort - direct wins", async () => {
|
||||
// #given - custom category has reasoningEffort=xhigh, direct override says "low"
|
||||
const categories = {
|
||||
"test-cat": {
|
||||
model: "openai/gpt-5.2",
|
||||
reasoningEffort: "xhigh" as const,
|
||||
},
|
||||
}
|
||||
const overrides = {
|
||||
oracle: { category: "test-cat", reasoningEffort: "low" } as any,
|
||||
}
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], overrides, undefined, TEST_DEFAULT_MODEL, categories)
|
||||
|
||||
// #then - direct reasoningEffort wins over category
|
||||
expect(agents.oracle).toBeDefined()
|
||||
expect(agents.oracle.reasoningEffort).toBe("low")
|
||||
})
|
||||
|
||||
test("standard agent override with category applies reasoningEffort from category when no direct override", async () => {
|
||||
// #given - custom category has reasoningEffort, no direct reasoningEffort in override
|
||||
const categories = {
|
||||
"reasoning-cat": {
|
||||
model: "openai/gpt-5.2",
|
||||
reasoningEffort: "high" as const,
|
||||
},
|
||||
}
|
||||
const overrides = {
|
||||
oracle: { category: "reasoning-cat" } as any,
|
||||
}
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], overrides, undefined, TEST_DEFAULT_MODEL, categories)
|
||||
|
||||
// #then - category reasoningEffort is applied
|
||||
expect(agents.oracle).toBeDefined()
|
||||
expect(agents.oracle.reasoningEffort).toBe("high")
|
||||
})
|
||||
|
||||
test("sisyphus override with category expands category properties", async () => {
|
||||
// #given
|
||||
const overrides = {
|
||||
sisyphus: { category: "ultrabrain" } as any,
|
||||
}
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], overrides, undefined, TEST_DEFAULT_MODEL)
|
||||
|
||||
// #then - ultrabrain category: model=openai/gpt-5.2-codex, variant=xhigh
|
||||
expect(agents.sisyphus).toBeDefined()
|
||||
expect(agents.sisyphus.model).toBe("openai/gpt-5.2-codex")
|
||||
expect(agents.sisyphus.variant).toBe("xhigh")
|
||||
})
|
||||
|
||||
test("atlas override with category expands category properties", async () => {
|
||||
// #given
|
||||
const overrides = {
|
||||
atlas: { category: "ultrabrain" } as any,
|
||||
}
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], overrides, undefined, TEST_DEFAULT_MODEL)
|
||||
|
||||
// #then - ultrabrain category: model=openai/gpt-5.2-codex, variant=xhigh
|
||||
expect(agents.atlas).toBeDefined()
|
||||
expect(agents.atlas.model).toBe("openai/gpt-5.2-codex")
|
||||
expect(agents.atlas.variant).toBe("xhigh")
|
||||
})
|
||||
|
||||
test("override with non-existent category has no effect on config", async () => {
|
||||
// #given
|
||||
const overrides = {
|
||||
oracle: { category: "non-existent-category" } as any,
|
||||
}
|
||||
|
||||
// #when
|
||||
const agents = await createBuiltinAgents([], overrides, undefined, TEST_DEFAULT_MODEL)
|
||||
|
||||
// #then - no category-specific variant/reasoningEffort applied from non-existent category
|
||||
expect(agents.oracle).toBeDefined()
|
||||
const agentsWithoutOverride = await createBuiltinAgents([], {}, undefined, TEST_DEFAULT_MODEL)
|
||||
expect(agents.oracle.model).toBe(agentsWithoutOverride.oracle.model)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -120,6 +120,33 @@ export function createEnvContext(): string {
|
||||
</omo-env>`
|
||||
}
|
||||
|
||||
/**
|
||||
* Expands a category reference from an agent override into concrete config properties.
|
||||
* Category properties are applied unconditionally (overwriting factory defaults),
|
||||
* because the user's chosen category should take priority over factory base values.
|
||||
* Direct override properties applied later via mergeAgentConfig() will supersede these.
|
||||
*/
|
||||
function applyCategoryOverride(
|
||||
config: AgentConfig,
|
||||
categoryName: string,
|
||||
mergedCategories: Record<string, CategoryConfig>
|
||||
): AgentConfig {
|
||||
const categoryConfig = mergedCategories[categoryName]
|
||||
if (!categoryConfig) return config
|
||||
|
||||
const result = { ...config } as AgentConfig & Record<string, unknown>
|
||||
if (categoryConfig.model) result.model = categoryConfig.model
|
||||
if (categoryConfig.variant !== undefined) result.variant = categoryConfig.variant
|
||||
if (categoryConfig.temperature !== undefined) result.temperature = categoryConfig.temperature
|
||||
if (categoryConfig.reasoningEffort !== undefined) result.reasoningEffort = categoryConfig.reasoningEffort
|
||||
if (categoryConfig.textVerbosity !== undefined) result.textVerbosity = categoryConfig.textVerbosity
|
||||
if (categoryConfig.thinking !== undefined) result.thinking = categoryConfig.thinking
|
||||
if (categoryConfig.top_p !== undefined) result.top_p = categoryConfig.top_p
|
||||
if (categoryConfig.maxTokens !== undefined) result.maxTokens = categoryConfig.maxTokens
|
||||
|
||||
return result as AgentConfig
|
||||
}
|
||||
|
||||
function mergeAgentConfig(
|
||||
base: AgentConfig,
|
||||
override: AgentOverrideConfig
|
||||
@@ -149,7 +176,8 @@ export async function createBuiltinAgents(
|
||||
gitMasterConfig?: GitMasterConfig,
|
||||
discoveredSkills: LoadedSkill[] = [],
|
||||
client?: any,
|
||||
browserProvider?: BrowserAutomationProvider
|
||||
browserProvider?: BrowserAutomationProvider,
|
||||
uiSelectedModel?: string
|
||||
): Promise<Record<string, AgentConfig>> {
|
||||
const connectedProviders = readConnectedProvidersCache()
|
||||
const availableModels = client
|
||||
@@ -198,6 +226,7 @@ export async function createBuiltinAgents(
|
||||
const requirement = AGENT_MODEL_REQUIREMENTS[agentName]
|
||||
|
||||
const resolution = resolveModelWithFallback({
|
||||
uiSelectedModel,
|
||||
userModel: override?.model,
|
||||
fallbackChain: requirement?.fallbackChain,
|
||||
availableModels,
|
||||
@@ -208,18 +237,23 @@ export async function createBuiltinAgents(
|
||||
|
||||
let config = buildAgent(source, model, mergedCategories, gitMasterConfig, browserProvider)
|
||||
|
||||
// Apply variant from override or resolved fallback chain
|
||||
if (override?.variant) {
|
||||
config = { ...config, variant: override.variant }
|
||||
} else if (resolvedVariant) {
|
||||
// Apply resolved variant from model fallback chain
|
||||
if (resolvedVariant) {
|
||||
config = { ...config, variant: resolvedVariant }
|
||||
}
|
||||
|
||||
// Expand override.category into concrete properties (higher priority than factory/resolved)
|
||||
const overrideCategory = (override as Record<string, unknown> | undefined)?.category as string | undefined
|
||||
if (overrideCategory) {
|
||||
config = applyCategoryOverride(config, overrideCategory, mergedCategories)
|
||||
}
|
||||
|
||||
if (agentName === "librarian" && directory && config.prompt) {
|
||||
const envContext = createEnvContext()
|
||||
config = { ...config, prompt: config.prompt + envContext }
|
||||
}
|
||||
|
||||
// Direct override properties take highest priority
|
||||
if (override) {
|
||||
config = mergeAgentConfig(config, override)
|
||||
}
|
||||
@@ -241,6 +275,7 @@ export async function createBuiltinAgents(
|
||||
const sisyphusRequirement = AGENT_MODEL_REQUIREMENTS["sisyphus"]
|
||||
|
||||
const sisyphusResolution = resolveModelWithFallback({
|
||||
uiSelectedModel,
|
||||
userModel: sisyphusOverride?.model,
|
||||
fallbackChain: sisyphusRequirement?.fallbackChain,
|
||||
availableModels,
|
||||
@@ -258,12 +293,15 @@ export async function createBuiltinAgents(
|
||||
availableCategories
|
||||
)
|
||||
|
||||
if (sisyphusOverride?.variant) {
|
||||
sisyphusConfig = { ...sisyphusConfig, variant: sisyphusOverride.variant }
|
||||
} else if (sisyphusResolvedVariant) {
|
||||
if (sisyphusResolvedVariant) {
|
||||
sisyphusConfig = { ...sisyphusConfig, variant: sisyphusResolvedVariant }
|
||||
}
|
||||
|
||||
const sisOverrideCategory = (sisyphusOverride as Record<string, unknown> | undefined)?.category as string | undefined
|
||||
if (sisOverrideCategory) {
|
||||
sisyphusConfig = applyCategoryOverride(sisyphusConfig, sisOverrideCategory, mergedCategories)
|
||||
}
|
||||
|
||||
if (directory && sisyphusConfig.prompt) {
|
||||
const envContext = createEnvContext()
|
||||
sisyphusConfig = { ...sisyphusConfig, prompt: sisyphusConfig.prompt + envContext }
|
||||
@@ -282,6 +320,7 @@ export async function createBuiltinAgents(
|
||||
const atlasRequirement = AGENT_MODEL_REQUIREMENTS["atlas"]
|
||||
|
||||
const atlasResolution = resolveModelWithFallback({
|
||||
uiSelectedModel,
|
||||
userModel: orchestratorOverride?.model,
|
||||
fallbackChain: atlasRequirement?.fallbackChain,
|
||||
availableModels,
|
||||
@@ -298,12 +337,15 @@ export async function createBuiltinAgents(
|
||||
userCategories: categories,
|
||||
})
|
||||
|
||||
if (orchestratorOverride?.variant) {
|
||||
orchestratorConfig = { ...orchestratorConfig, variant: orchestratorOverride.variant }
|
||||
} else if (atlasResolvedVariant) {
|
||||
if (atlasResolvedVariant) {
|
||||
orchestratorConfig = { ...orchestratorConfig, variant: atlasResolvedVariant }
|
||||
}
|
||||
|
||||
const atlasOverrideCategory = (orchestratorOverride as Record<string, unknown> | undefined)?.category as string | undefined
|
||||
if (atlasOverrideCategory) {
|
||||
orchestratorConfig = applyCategoryOverride(orchestratorConfig, atlasOverrideCategory, mergedCategories)
|
||||
}
|
||||
|
||||
if (orchestratorOverride) {
|
||||
orchestratorConfig = mergeAgentConfig(orchestratorConfig, orchestratorOverride)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import { getDependencyCheckDefinitions } from "./dependencies"
|
||||
import { getGhCliCheckDefinition } from "./gh"
|
||||
import { getLspCheckDefinition } from "./lsp"
|
||||
import { getMcpCheckDefinitions } from "./mcp"
|
||||
import { getMcpOAuthCheckDefinition } from "./mcp-oauth"
|
||||
import { getVersionCheckDefinition } from "./version"
|
||||
|
||||
export * from "./opencode"
|
||||
@@ -19,6 +20,7 @@ export * from "./dependencies"
|
||||
export * from "./gh"
|
||||
export * from "./lsp"
|
||||
export * from "./mcp"
|
||||
export * from "./mcp-oauth"
|
||||
export * from "./version"
|
||||
|
||||
export function getAllCheckDefinitions(): CheckDefinition[] {
|
||||
@@ -32,6 +34,7 @@ export function getAllCheckDefinitions(): CheckDefinition[] {
|
||||
getGhCliCheckDefinition(),
|
||||
getLspCheckDefinition(),
|
||||
...getMcpCheckDefinitions(),
|
||||
getMcpOAuthCheckDefinition(),
|
||||
getVersionCheckDefinition(),
|
||||
]
|
||||
}
|
||||
|
||||
133
src/cli/doctor/checks/mcp-oauth.test.ts
Normal file
133
src/cli/doctor/checks/mcp-oauth.test.ts
Normal file
@@ -0,0 +1,133 @@
|
||||
import { describe, it, expect, spyOn, afterEach } from "bun:test"
|
||||
import * as mcpOauth from "./mcp-oauth"
|
||||
|
||||
describe("mcp-oauth check", () => {
|
||||
describe("getMcpOAuthCheckDefinition", () => {
|
||||
it("returns check definition with correct properties", () => {
|
||||
// #given
|
||||
// #when getting definition
|
||||
const def = mcpOauth.getMcpOAuthCheckDefinition()
|
||||
|
||||
// #then should have correct structure
|
||||
expect(def.id).toBe("mcp-oauth-tokens")
|
||||
expect(def.name).toBe("MCP OAuth Tokens")
|
||||
expect(def.category).toBe("tools")
|
||||
expect(def.critical).toBe(false)
|
||||
expect(typeof def.check).toBe("function")
|
||||
})
|
||||
})
|
||||
|
||||
describe("checkMcpOAuthTokens", () => {
|
||||
let readStoreSpy: ReturnType<typeof spyOn>
|
||||
|
||||
afterEach(() => {
|
||||
readStoreSpy?.mockRestore()
|
||||
})
|
||||
|
||||
it("returns skip when no tokens stored", async () => {
|
||||
// #given no OAuth tokens configured
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue(null)
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should skip
|
||||
expect(result.status).toBe("skip")
|
||||
expect(result.message).toContain("No OAuth")
|
||||
})
|
||||
|
||||
it("returns pass when all tokens valid", async () => {
|
||||
// #given valid tokens with future expiry (expiresAt is in epoch seconds)
|
||||
const futureTime = Math.floor(Date.now() / 1000) + 3600
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"example.com/resource1": {
|
||||
accessToken: "token1",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
"example.com/resource2": {
|
||||
accessToken: "token2",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should pass
|
||||
expect(result.status).toBe("pass")
|
||||
expect(result.message).toContain("2")
|
||||
expect(result.message).toContain("valid")
|
||||
})
|
||||
|
||||
it("returns warn when some tokens expired", async () => {
|
||||
// #given mix of valid and expired tokens (expiresAt is in epoch seconds)
|
||||
const futureTime = Math.floor(Date.now() / 1000) + 3600
|
||||
const pastTime = Math.floor(Date.now() / 1000) - 3600
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"example.com/resource1": {
|
||||
accessToken: "token1",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
"example.com/resource2": {
|
||||
accessToken: "token2",
|
||||
expiresAt: pastTime,
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should warn
|
||||
expect(result.status).toBe("warn")
|
||||
expect(result.message).toContain("1")
|
||||
expect(result.message).toContain("expired")
|
||||
expect(result.details?.some((d: string) => d.includes("Expired"))).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it("returns pass when tokens have no expiry", async () => {
|
||||
// #given tokens without expiry info
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"example.com/resource1": {
|
||||
accessToken: "token1",
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should pass (no expiry = assume valid)
|
||||
expect(result.status).toBe("pass")
|
||||
expect(result.message).toContain("1")
|
||||
})
|
||||
|
||||
it("includes token details in output", async () => {
|
||||
// #given multiple tokens
|
||||
const futureTime = Math.floor(Date.now() / 1000) + 3600
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"api.example.com/v1": {
|
||||
accessToken: "token1",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
"auth.example.com/oauth": {
|
||||
accessToken: "token2",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should list tokens in details
|
||||
expect(result.details).toBeDefined()
|
||||
expect(result.details?.length).toBeGreaterThan(0)
|
||||
expect(
|
||||
result.details?.some((d: string) => d.includes("api.example.com"))
|
||||
).toBe(true)
|
||||
expect(
|
||||
result.details?.some((d: string) => d.includes("auth.example.com"))
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
80
src/cli/doctor/checks/mcp-oauth.ts
Normal file
80
src/cli/doctor/checks/mcp-oauth.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import type { CheckResult, CheckDefinition } from "../types"
|
||||
import { CHECK_IDS, CHECK_NAMES } from "../constants"
|
||||
import { getMcpOauthStoragePath } from "../../../features/mcp-oauth/storage"
|
||||
import { existsSync, readFileSync } from "node:fs"
|
||||
|
||||
interface OAuthTokenData {
|
||||
accessToken: string
|
||||
refreshToken?: string
|
||||
expiresAt?: number
|
||||
clientInfo?: {
|
||||
clientId: string
|
||||
clientSecret?: string
|
||||
}
|
||||
}
|
||||
|
||||
type TokenStore = Record<string, OAuthTokenData>
|
||||
|
||||
export function readTokenStore(): TokenStore | null {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
if (!existsSync(filePath)) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(filePath, "utf-8")
|
||||
return JSON.parse(content) as TokenStore
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkMcpOAuthTokens(): Promise<CheckResult> {
|
||||
const store = readTokenStore()
|
||||
|
||||
if (!store || Object.keys(store).length === 0) {
|
||||
return {
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
status: "skip",
|
||||
message: "No OAuth tokens configured",
|
||||
details: ["Optional: Configure OAuth tokens for MCP servers"],
|
||||
}
|
||||
}
|
||||
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
const tokens = Object.entries(store)
|
||||
const expiredTokens = tokens.filter(
|
||||
([, token]) => token.expiresAt && token.expiresAt < now
|
||||
)
|
||||
|
||||
if (expiredTokens.length > 0) {
|
||||
return {
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
status: "warn",
|
||||
message: `${expiredTokens.length} of ${tokens.length} token(s) expired`,
|
||||
details: [
|
||||
...tokens
|
||||
.filter(([, token]) => !token.expiresAt || token.expiresAt >= now)
|
||||
.map(([key]) => `Valid: ${key}`),
|
||||
...expiredTokens.map(([key]) => `Expired: ${key}`),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
status: "pass",
|
||||
message: `${tokens.length} OAuth token(s) valid`,
|
||||
details: tokens.map(([key]) => `Configured: ${key}`),
|
||||
}
|
||||
}
|
||||
|
||||
export function getMcpOAuthCheckDefinition(): CheckDefinition {
|
||||
return {
|
||||
id: CHECK_IDS.MCP_OAUTH_TOKENS,
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
category: "tools",
|
||||
check: checkMcpOAuthTokens,
|
||||
critical: false,
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,7 @@ export const CHECK_IDS = {
|
||||
LSP_SERVERS: "lsp-servers",
|
||||
MCP_BUILTIN: "mcp-builtin",
|
||||
MCP_USER: "mcp-user",
|
||||
MCP_OAUTH_TOKENS: "mcp-oauth-tokens",
|
||||
VERSION_STATUS: "version-status",
|
||||
} as const
|
||||
|
||||
@@ -50,6 +51,7 @@ export const CHECK_NAMES: Record<string, string> = {
|
||||
[CHECK_IDS.LSP_SERVERS]: "LSP Servers",
|
||||
[CHECK_IDS.MCP_BUILTIN]: "Built-in MCP Servers",
|
||||
[CHECK_IDS.MCP_USER]: "User MCP Configuration",
|
||||
[CHECK_IDS.MCP_OAUTH_TOKENS]: "MCP OAuth Tokens",
|
||||
[CHECK_IDS.VERSION_STATUS]: "Version Status",
|
||||
} as const
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import { install } from "./install"
|
||||
import { run } from "./run"
|
||||
import { getLocalVersion } from "./get-local-version"
|
||||
import { doctor } from "./doctor"
|
||||
import { createMcpOAuthCommand } from "./mcp-oauth"
|
||||
import type { InstallArgs } from "./types"
|
||||
import type { RunOptions } from "./run"
|
||||
import type { GetLocalVersionOptions } from "./get-local-version/types"
|
||||
@@ -150,4 +151,6 @@ program
|
||||
console.log(`oh-my-opencode v${VERSION}`)
|
||||
})
|
||||
|
||||
program.addCommand(createMcpOAuthCommand())
|
||||
|
||||
program.parse()
|
||||
|
||||
123
src/cli/mcp-oauth/index.test.ts
Normal file
123
src/cli/mcp-oauth/index.test.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import { describe, it, expect } from "bun:test"
|
||||
import { Command } from "commander"
|
||||
import { createMcpOAuthCommand } from "./index"
|
||||
|
||||
describe("mcp oauth command", () => {
|
||||
|
||||
describe("command structure", () => {
|
||||
it("creates mcp command group with oauth subcommand", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
|
||||
// when
|
||||
const subcommands = mcpCommand.commands.map((cmd: Command) => cmd.name())
|
||||
|
||||
// then
|
||||
expect(subcommands).toContain("oauth")
|
||||
})
|
||||
|
||||
it("oauth subcommand has login, logout, and status subcommands", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
|
||||
// when
|
||||
const subcommands = oauthCommand?.commands.map((cmd: Command) => cmd.name()) ?? []
|
||||
|
||||
// then
|
||||
expect(subcommands).toContain("login")
|
||||
expect(subcommands).toContain("logout")
|
||||
expect(subcommands).toContain("status")
|
||||
})
|
||||
})
|
||||
|
||||
describe("login subcommand", () => {
|
||||
it("exists and has description", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const description = loginCommand?.description() ?? ""
|
||||
|
||||
// then
|
||||
expect(loginCommand).toBeDefined()
|
||||
expect(description).toContain("OAuth")
|
||||
})
|
||||
|
||||
it("accepts --server-url option", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const options = loginCommand?.options ?? []
|
||||
const serverUrlOption = options.find((opt: { long?: string }) => opt.long === "--server-url")
|
||||
|
||||
// then
|
||||
expect(serverUrlOption).toBeDefined()
|
||||
})
|
||||
|
||||
it("accepts --client-id option", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const options = loginCommand?.options ?? []
|
||||
const clientIdOption = options.find((opt: { long?: string }) => opt.long === "--client-id")
|
||||
|
||||
// then
|
||||
expect(clientIdOption).toBeDefined()
|
||||
})
|
||||
|
||||
it("accepts --scopes option", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const options = loginCommand?.options ?? []
|
||||
const scopesOption = options.find((opt: { long?: string }) => opt.long === "--scopes")
|
||||
|
||||
// then
|
||||
expect(scopesOption).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe("logout subcommand", () => {
|
||||
it("exists and has description", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const logoutCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "logout")
|
||||
|
||||
// when
|
||||
const description = logoutCommand?.description() ?? ""
|
||||
|
||||
// then
|
||||
expect(logoutCommand).toBeDefined()
|
||||
expect(description).toContain("tokens")
|
||||
})
|
||||
})
|
||||
|
||||
describe("status subcommand", () => {
|
||||
it("exists and has description", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const statusCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "status")
|
||||
|
||||
// when
|
||||
const description = statusCommand?.description() ?? ""
|
||||
|
||||
// then
|
||||
expect(statusCommand).toBeDefined()
|
||||
expect(description).toContain("status")
|
||||
})
|
||||
})
|
||||
})
|
||||
43
src/cli/mcp-oauth/index.ts
Normal file
43
src/cli/mcp-oauth/index.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { Command } from "commander"
|
||||
import { login } from "./login"
|
||||
import { logout } from "./logout"
|
||||
import { status } from "./status"
|
||||
|
||||
export function createMcpOAuthCommand(): Command {
|
||||
const mcp = new Command("mcp").description("MCP server management")
|
||||
|
||||
const oauth = new Command("oauth").description("OAuth token management for MCP servers")
|
||||
|
||||
oauth
|
||||
.command("login <server-name>")
|
||||
.description("Authenticate with an MCP server using OAuth")
|
||||
.option("--server-url <url>", "OAuth server URL (required if not in config)")
|
||||
.option("--client-id <id>", "OAuth client ID (optional, uses DCR if not provided)")
|
||||
.option("--scopes <scopes...>", "OAuth scopes to request")
|
||||
.action(async (serverName: string, options) => {
|
||||
const exitCode = await login(serverName, options)
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
oauth
|
||||
.command("logout <server-name>")
|
||||
.description("Remove stored OAuth tokens for an MCP server")
|
||||
.option("--server-url <url>", "OAuth server URL (use if server name differs from URL)")
|
||||
.action(async (serverName: string, options) => {
|
||||
const exitCode = await logout(serverName, options)
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
oauth
|
||||
.command("status [server-name]")
|
||||
.description("Show OAuth token status for MCP servers")
|
||||
.action(async (serverName: string | undefined) => {
|
||||
const exitCode = await status(serverName)
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
mcp.addCommand(oauth)
|
||||
return mcp
|
||||
}
|
||||
|
||||
export { login, logout, status }
|
||||
80
src/cli/mcp-oauth/login.test.ts
Normal file
80
src/cli/mcp-oauth/login.test.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test"
|
||||
|
||||
const mockLogin = mock(() => Promise.resolve({ accessToken: "test-token", expiresAt: 1710000000 }))
|
||||
|
||||
mock.module("../../features/mcp-oauth/provider", () => ({
|
||||
McpOAuthProvider: class MockMcpOAuthProvider {
|
||||
constructor(public options: { serverUrl: string; clientId?: string; scopes?: string[] }) {}
|
||||
async login() {
|
||||
return mockLogin()
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
const { login } = await import("./login")
|
||||
|
||||
describe("login command", () => {
|
||||
beforeEach(() => {
|
||||
mockLogin.mockClear()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
// cleanup
|
||||
})
|
||||
|
||||
it("returns error code when server-url is not provided", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {}
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
|
||||
it("returns success code when login succeeds", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {
|
||||
serverUrl: "https://oauth.example.com",
|
||||
}
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(0)
|
||||
expect(mockLogin).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it("returns error code when login throws", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {
|
||||
serverUrl: "https://oauth.example.com",
|
||||
}
|
||||
mockLogin.mockRejectedValueOnce(new Error("Network error"))
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
|
||||
it("returns error code when server-url is missing", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {
|
||||
clientId: "test-client-id",
|
||||
}
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
})
|
||||
38
src/cli/mcp-oauth/login.ts
Normal file
38
src/cli/mcp-oauth/login.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { McpOAuthProvider } from "../../features/mcp-oauth/provider"
|
||||
|
||||
export interface LoginOptions {
|
||||
serverUrl?: string
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
|
||||
export async function login(serverName: string, options: LoginOptions): Promise<number> {
|
||||
try {
|
||||
const serverUrl = options.serverUrl
|
||||
if (!serverUrl) {
|
||||
console.error(`Error: --server-url is required for server "${serverName}"`)
|
||||
return 1
|
||||
}
|
||||
|
||||
const provider = new McpOAuthProvider({
|
||||
serverUrl,
|
||||
clientId: options.clientId,
|
||||
scopes: options.scopes,
|
||||
})
|
||||
|
||||
console.log(`Authenticating with ${serverName}...`)
|
||||
const tokenData = await provider.login()
|
||||
|
||||
console.log(`✓ Successfully authenticated with ${serverName}`)
|
||||
if (tokenData.expiresAt) {
|
||||
const expiryDate = new Date(tokenData.expiresAt * 1000)
|
||||
console.log(` Token expires at: ${expiryDate.toISOString()}`)
|
||||
}
|
||||
|
||||
return 0
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
console.error(`Error: Failed to authenticate with ${serverName}: ${message}`)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
65
src/cli/mcp-oauth/logout.test.ts
Normal file
65
src/cli/mcp-oauth/logout.test.ts
Normal file
@@ -0,0 +1,65 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test"
|
||||
import { existsSync, mkdirSync, rmSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
import { tmpdir } from "node:os"
|
||||
import { saveToken } from "../../features/mcp-oauth/storage"
|
||||
|
||||
const { logout } = await import("./logout")
|
||||
|
||||
describe("logout command", () => {
|
||||
const TEST_CONFIG_DIR = join(tmpdir(), "mcp-oauth-logout-test-" + Date.now())
|
||||
let originalConfigDir: string | undefined
|
||||
|
||||
beforeEach(() => {
|
||||
originalConfigDir = process.env.OPENCODE_CONFIG_DIR
|
||||
process.env.OPENCODE_CONFIG_DIR = TEST_CONFIG_DIR
|
||||
if (!existsSync(TEST_CONFIG_DIR)) {
|
||||
mkdirSync(TEST_CONFIG_DIR, { recursive: true })
|
||||
}
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
if (originalConfigDir === undefined) {
|
||||
delete process.env.OPENCODE_CONFIG_DIR
|
||||
} else {
|
||||
process.env.OPENCODE_CONFIG_DIR = originalConfigDir
|
||||
}
|
||||
if (existsSync(TEST_CONFIG_DIR)) {
|
||||
rmSync(TEST_CONFIG_DIR, { recursive: true, force: true })
|
||||
}
|
||||
})
|
||||
|
||||
it("returns success code when logout succeeds", async () => {
|
||||
// given
|
||||
const serverUrl = "https://test-server.example.com"
|
||||
saveToken(serverUrl, serverUrl, { accessToken: "test-token" })
|
||||
|
||||
// when
|
||||
const exitCode = await logout("test-server", { serverUrl })
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("handles non-existent server gracefully", async () => {
|
||||
// given
|
||||
const serverName = "non-existent-server"
|
||||
|
||||
// when
|
||||
const exitCode = await logout(serverName, { serverUrl: "https://nonexistent.example.com" })
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("returns error when --server-url is not provided", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
|
||||
// when
|
||||
const exitCode = await logout(serverName)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
})
|
||||
30
src/cli/mcp-oauth/logout.ts
Normal file
30
src/cli/mcp-oauth/logout.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import { deleteToken } from "../../features/mcp-oauth/storage"
|
||||
|
||||
export interface LogoutOptions {
|
||||
serverUrl?: string
|
||||
}
|
||||
|
||||
export async function logout(serverName: string, options?: LogoutOptions): Promise<number> {
|
||||
try {
|
||||
const serverUrl = options?.serverUrl
|
||||
if (!serverUrl) {
|
||||
console.error(`Error: --server-url is required for logout. Token storage uses server URLs, not names.`)
|
||||
console.error(` Usage: mcp oauth logout ${serverName} --server-url https://your-server.example.com`)
|
||||
return 1
|
||||
}
|
||||
|
||||
const success = deleteToken(serverUrl, serverUrl)
|
||||
|
||||
if (success) {
|
||||
console.log(`✓ Successfully removed tokens for ${serverName}`)
|
||||
return 0
|
||||
}
|
||||
|
||||
console.error(`Error: Failed to remove tokens for ${serverName}`)
|
||||
return 1
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
console.error(`Error: Failed to remove tokens for ${serverName}: ${message}`)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
48
src/cli/mcp-oauth/status.test.ts
Normal file
48
src/cli/mcp-oauth/status.test.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
import { describe, it, expect, beforeEach, afterEach } from "bun:test"
|
||||
import { status } from "./status"
|
||||
|
||||
describe("status command", () => {
|
||||
beforeEach(() => {
|
||||
// setup
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
// cleanup
|
||||
})
|
||||
|
||||
it("returns success code when checking status for specific server", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
|
||||
// when
|
||||
const exitCode = await status(serverName)
|
||||
|
||||
// then
|
||||
expect(typeof exitCode).toBe("number")
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("returns success code when checking status for all servers", async () => {
|
||||
// given
|
||||
const serverName = undefined
|
||||
|
||||
// when
|
||||
const exitCode = await status(serverName)
|
||||
|
||||
// then
|
||||
expect(typeof exitCode).toBe("number")
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("handles non-existent server gracefully", async () => {
|
||||
// given
|
||||
const serverName = "non-existent-server"
|
||||
|
||||
// when
|
||||
const exitCode = await status(serverName)
|
||||
|
||||
// then
|
||||
expect(typeof exitCode).toBe("number")
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
})
|
||||
50
src/cli/mcp-oauth/status.ts
Normal file
50
src/cli/mcp-oauth/status.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import { listAllTokens, listTokensByHost } from "../../features/mcp-oauth/storage"
|
||||
|
||||
export async function status(serverName: string | undefined): Promise<number> {
|
||||
try {
|
||||
if (serverName) {
|
||||
const tokens = listTokensByHost(serverName)
|
||||
|
||||
if (Object.keys(tokens).length === 0) {
|
||||
console.log(`No tokens found for ${serverName}`)
|
||||
return 0
|
||||
}
|
||||
|
||||
console.log(`OAuth Status for ${serverName}:`)
|
||||
for (const [key, token] of Object.entries(tokens)) {
|
||||
console.log(` ${key}:`)
|
||||
console.log(` Access Token: [REDACTED]`)
|
||||
if (token.refreshToken) {
|
||||
console.log(` Refresh Token: [REDACTED]`)
|
||||
}
|
||||
if (token.expiresAt) {
|
||||
const expiryDate = new Date(token.expiresAt * 1000)
|
||||
const now = Date.now() / 1000
|
||||
const isExpired = token.expiresAt < now
|
||||
const tokenStatus = isExpired ? "EXPIRED" : "VALID"
|
||||
console.log(` Expiry: ${expiryDate.toISOString()} (${tokenStatus})`)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const tokens = listAllTokens()
|
||||
if (Object.keys(tokens).length === 0) {
|
||||
console.log("No OAuth tokens stored")
|
||||
return 0
|
||||
}
|
||||
|
||||
console.log("Stored OAuth Tokens:")
|
||||
for (const [key, token] of Object.entries(tokens)) {
|
||||
const isExpired = token.expiresAt && token.expiresAt < Date.now() / 1000
|
||||
const tokenStatus = isExpired ? "EXPIRED" : "VALID"
|
||||
console.log(` ${key}: ${tokenStatus}`)
|
||||
}
|
||||
|
||||
return 0
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
console.error(`Error: Failed to get token status: ${message}`)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
@@ -82,6 +82,7 @@ describe("createEventState", () => {
|
||||
expect(state.lastOutput).toBe("")
|
||||
expect(state.lastPartText).toBe("")
|
||||
expect(state.currentTool).toBe(null)
|
||||
expect(state.hasReceivedMeaningfulWork).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -126,6 +127,119 @@ describe("event handling", () => {
|
||||
expect(state.mainSessionIdle).toBe(false)
|
||||
})
|
||||
|
||||
it("hasReceivedMeaningfulWork is false initially after session.idle", async () => {
|
||||
// #given - session goes idle without any assistant output (race condition scenario)
|
||||
const ctx = createMockContext("my-session")
|
||||
const state = createEventState()
|
||||
|
||||
const payload: EventPayload = {
|
||||
type: "session.idle",
|
||||
properties: { sessionID: "my-session" },
|
||||
}
|
||||
|
||||
const events = toAsyncIterable([payload])
|
||||
const { processEvents } = await import("./events")
|
||||
|
||||
// #when
|
||||
await processEvents(ctx, events, state)
|
||||
|
||||
// #then - idle but no meaningful work yet
|
||||
expect(state.mainSessionIdle).toBe(true)
|
||||
expect(state.hasReceivedMeaningfulWork).toBe(false)
|
||||
})
|
||||
|
||||
it("message.updated with assistant role sets hasReceivedMeaningfulWork", async () => {
|
||||
// #given
|
||||
const ctx = createMockContext("my-session")
|
||||
const state = createEventState()
|
||||
|
||||
const payload: EventPayload = {
|
||||
type: "message.updated",
|
||||
properties: {
|
||||
info: { sessionID: "my-session", role: "assistant" },
|
||||
},
|
||||
}
|
||||
|
||||
const events = toAsyncIterable([payload])
|
||||
const { processEvents } = await import("./events")
|
||||
|
||||
// #when
|
||||
await processEvents(ctx, events, state)
|
||||
|
||||
// #then
|
||||
expect(state.hasReceivedMeaningfulWork).toBe(true)
|
||||
})
|
||||
|
||||
it("message.updated with user role does not set hasReceivedMeaningfulWork", async () => {
|
||||
// #given - user message should not count as meaningful work
|
||||
const ctx = createMockContext("my-session")
|
||||
const state = createEventState()
|
||||
|
||||
const payload: EventPayload = {
|
||||
type: "message.updated",
|
||||
properties: {
|
||||
info: { sessionID: "my-session", role: "user" },
|
||||
},
|
||||
}
|
||||
|
||||
const events = toAsyncIterable([payload])
|
||||
const { processEvents } = await import("./events")
|
||||
|
||||
// #when
|
||||
await processEvents(ctx, events, state)
|
||||
|
||||
// #then - user role should not count as meaningful work
|
||||
expect(state.hasReceivedMeaningfulWork).toBe(false)
|
||||
})
|
||||
|
||||
it("tool.execute sets hasReceivedMeaningfulWork", async () => {
|
||||
// #given
|
||||
const ctx = createMockContext("my-session")
|
||||
const state = createEventState()
|
||||
|
||||
const payload: EventPayload = {
|
||||
type: "tool.execute",
|
||||
properties: {
|
||||
sessionID: "my-session",
|
||||
name: "read_file",
|
||||
input: { filePath: "/src/index.ts" },
|
||||
},
|
||||
}
|
||||
|
||||
const events = toAsyncIterable([payload])
|
||||
const { processEvents } = await import("./events")
|
||||
|
||||
// #when
|
||||
await processEvents(ctx, events, state)
|
||||
|
||||
// #then
|
||||
expect(state.hasReceivedMeaningfulWork).toBe(true)
|
||||
})
|
||||
|
||||
it("tool.execute from different session does not set hasReceivedMeaningfulWork", async () => {
|
||||
// #given
|
||||
const ctx = createMockContext("my-session")
|
||||
const state = createEventState()
|
||||
|
||||
const payload: EventPayload = {
|
||||
type: "tool.execute",
|
||||
properties: {
|
||||
sessionID: "other-session",
|
||||
name: "read_file",
|
||||
input: { filePath: "/src/index.ts" },
|
||||
},
|
||||
}
|
||||
|
||||
const events = toAsyncIterable([payload])
|
||||
const { processEvents } = await import("./events")
|
||||
|
||||
// #when
|
||||
await processEvents(ctx, events, state)
|
||||
|
||||
// #then - different session's tool call shouldn't count
|
||||
expect(state.hasReceivedMeaningfulWork).toBe(false)
|
||||
})
|
||||
|
||||
it("session.status with busy type sets mainSessionIdle to false", async () => {
|
||||
// #given
|
||||
const ctx = createMockContext("my-session")
|
||||
@@ -136,6 +250,7 @@ describe("event handling", () => {
|
||||
lastOutput: "",
|
||||
lastPartText: "",
|
||||
currentTool: null,
|
||||
hasReceivedMeaningfulWork: false,
|
||||
}
|
||||
|
||||
const payload: EventPayload = {
|
||||
|
||||
@@ -63,6 +63,8 @@ export interface EventState {
|
||||
lastOutput: string
|
||||
lastPartText: string
|
||||
currentTool: string | null
|
||||
/** Set to true when the main session has produced meaningful work (text, tool call, or tool result) */
|
||||
hasReceivedMeaningfulWork: boolean
|
||||
}
|
||||
|
||||
export function createEventState(): EventState {
|
||||
@@ -73,6 +75,7 @@ export function createEventState(): EventState {
|
||||
lastOutput: "",
|
||||
lastPartText: "",
|
||||
currentTool: null,
|
||||
hasReceivedMeaningfulWork: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +116,9 @@ function logEventVerbose(ctx: RunContext, payload: EventPayload): void {
|
||||
const isMainSession = sessionID === ctx.sessionID
|
||||
const sessionTag = isMainSession
|
||||
? pc.green("[MAIN]")
|
||||
: pc.yellow(`[${String(sessionID).slice(0, 8)}]`)
|
||||
: sessionID
|
||||
? pc.yellow(`[${String(sessionID).slice(0, 8)}]`)
|
||||
: pc.dim("[system]")
|
||||
|
||||
switch (payload.type) {
|
||||
case "session.idle":
|
||||
@@ -124,8 +129,6 @@ function logEventVerbose(ctx: RunContext, payload: EventPayload): void {
|
||||
}
|
||||
|
||||
case "message.part.updated": {
|
||||
// Skip verbose logging for partial message updates
|
||||
// Only log tool invocation state changes, not text streaming
|
||||
const partProps = props as MessagePartUpdatedProps | undefined
|
||||
const part = partProps?.part
|
||||
if (part?.type === "tool-invocation") {
|
||||
@@ -133,6 +136,11 @@ function logEventVerbose(ctx: RunContext, payload: EventPayload): void {
|
||||
console.error(
|
||||
pc.dim(`${sessionTag} message.part (tool): ${toolPart.toolName} [${toolPart.state}]`)
|
||||
)
|
||||
} else if (part?.type === "text" && part.text) {
|
||||
const preview = part.text.slice(0, 80).replace(/\n/g, "\\n")
|
||||
console.error(
|
||||
pc.dim(`${sessionTag} message.part (text): "${preview}${part.text.length > 80 ? "..." : ""}"`)
|
||||
)
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -140,11 +148,10 @@ function logEventVerbose(ctx: RunContext, payload: EventPayload): void {
|
||||
case "message.updated": {
|
||||
const msgProps = props as MessageUpdatedProps | undefined
|
||||
const role = msgProps?.info?.role ?? "unknown"
|
||||
const content = msgProps?.content ?? ""
|
||||
const preview = content.slice(0, 100).replace(/\n/g, "\\n")
|
||||
console.error(
|
||||
pc.dim(`${sessionTag} message.updated (${role}): "${preview}${content.length > 100 ? "..." : ""}"`)
|
||||
)
|
||||
const model = msgProps?.info?.modelID
|
||||
const agent = msgProps?.info?.agent
|
||||
const details = [role, agent, model].filter(Boolean).join(", ")
|
||||
console.error(pc.dim(`${sessionTag} message.updated (${details})`))
|
||||
break
|
||||
}
|
||||
|
||||
@@ -241,6 +248,7 @@ function handleMessagePartUpdated(
|
||||
const newText = part.text.slice(state.lastPartText.length)
|
||||
if (newText) {
|
||||
process.stdout.write(newText)
|
||||
state.hasReceivedMeaningfulWork = true
|
||||
}
|
||||
state.lastPartText = part.text
|
||||
}
|
||||
@@ -257,16 +265,7 @@ function handleMessageUpdated(
|
||||
if (props?.info?.sessionID !== ctx.sessionID) return
|
||||
if (props?.info?.role !== "assistant") return
|
||||
|
||||
const content = props.content
|
||||
if (!content || content === state.lastOutput) return
|
||||
|
||||
if (state.lastPartText.length === 0) {
|
||||
const newContent = content.slice(state.lastOutput.length)
|
||||
if (newContent) {
|
||||
process.stdout.write(newContent)
|
||||
}
|
||||
}
|
||||
state.lastOutput = content
|
||||
state.hasReceivedMeaningfulWork = true
|
||||
}
|
||||
|
||||
function handleToolExecute(
|
||||
@@ -296,6 +295,7 @@ function handleToolExecute(
|
||||
}
|
||||
}
|
||||
|
||||
state.hasReceivedMeaningfulWork = true
|
||||
process.stdout.write(`\n${pc.cyan(">")} ${pc.bold(toolName)}${inputPreview}\n`)
|
||||
}
|
||||
|
||||
|
||||
@@ -143,6 +143,14 @@ export async function run(options: RunOptions): Promise<number> {
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
// Guard against premature completion: don't check completion until the
|
||||
// session has produced meaningful work (text output, tool call, or tool result).
|
||||
// Without this, a session that goes busy->idle before the LLM responds
|
||||
// would exit immediately because 0 todos + 0 children = "complete".
|
||||
if (!eventState.hasReceivedMeaningfulWork) {
|
||||
continue
|
||||
}
|
||||
|
||||
const shouldExit = await checkCompletionConditions(ctx)
|
||||
if (shouldExit) {
|
||||
console.log(pc.green("\n\nAll tasks completed."))
|
||||
|
||||
@@ -44,8 +44,13 @@ export interface SessionStatusProps {
|
||||
}
|
||||
|
||||
export interface MessageUpdatedProps {
|
||||
info?: { sessionID?: string; role?: string }
|
||||
content?: string
|
||||
info?: {
|
||||
sessionID?: string
|
||||
role?: string
|
||||
modelID?: string
|
||||
providerID?: string
|
||||
agent?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface MessagePartUpdatedProps {
|
||||
|
||||
@@ -170,6 +170,7 @@ function createBackgroundManager(): BackgroundManager {
|
||||
const client = {
|
||||
session: {
|
||||
prompt: async () => ({}),
|
||||
abort: async () => ({}),
|
||||
},
|
||||
}
|
||||
return new BackgroundManager({ client, directory: tmpdir() } as unknown as PluginInput)
|
||||
@@ -1053,6 +1054,7 @@ describe("BackgroundManager.resume model persistence", () => {
|
||||
promptCalls.push(args)
|
||||
return {}
|
||||
},
|
||||
abort: async () => ({}),
|
||||
},
|
||||
}
|
||||
manager = new BackgroundManager({ client, directory: tmpdir() } as unknown as PluginInput)
|
||||
@@ -1926,3 +1928,162 @@ describe("BackgroundManager.checkAndInterruptStaleTasks", () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe("BackgroundManager.shutdown session abort", () => {
|
||||
test("should call session.abort for all running tasks during shutdown", () => {
|
||||
// #given
|
||||
const abortedSessionIDs: string[] = []
|
||||
const client = {
|
||||
session: {
|
||||
prompt: async () => ({}),
|
||||
abort: async (args: { path: { id: string } }) => {
|
||||
abortedSessionIDs.push(args.path.id)
|
||||
return {}
|
||||
},
|
||||
},
|
||||
}
|
||||
const manager = new BackgroundManager({ client, directory: tmpdir() } as unknown as PluginInput)
|
||||
|
||||
const task1: BackgroundTask = {
|
||||
id: "task-1",
|
||||
sessionID: "session-1",
|
||||
parentSessionID: "parent-1",
|
||||
parentMessageID: "msg-1",
|
||||
description: "Running task 1",
|
||||
prompt: "Test",
|
||||
agent: "test-agent",
|
||||
status: "running",
|
||||
startedAt: new Date(),
|
||||
}
|
||||
const task2: BackgroundTask = {
|
||||
id: "task-2",
|
||||
sessionID: "session-2",
|
||||
parentSessionID: "parent-2",
|
||||
parentMessageID: "msg-2",
|
||||
description: "Running task 2",
|
||||
prompt: "Test",
|
||||
agent: "test-agent",
|
||||
status: "running",
|
||||
startedAt: new Date(),
|
||||
}
|
||||
|
||||
getTaskMap(manager).set(task1.id, task1)
|
||||
getTaskMap(manager).set(task2.id, task2)
|
||||
|
||||
// #when
|
||||
manager.shutdown()
|
||||
|
||||
// #then
|
||||
expect(abortedSessionIDs).toContain("session-1")
|
||||
expect(abortedSessionIDs).toContain("session-2")
|
||||
expect(abortedSessionIDs).toHaveLength(2)
|
||||
})
|
||||
|
||||
test("should not call session.abort for completed or cancelled tasks", () => {
|
||||
// #given
|
||||
const abortedSessionIDs: string[] = []
|
||||
const client = {
|
||||
session: {
|
||||
prompt: async () => ({}),
|
||||
abort: async (args: { path: { id: string } }) => {
|
||||
abortedSessionIDs.push(args.path.id)
|
||||
return {}
|
||||
},
|
||||
},
|
||||
}
|
||||
const manager = new BackgroundManager({ client, directory: tmpdir() } as unknown as PluginInput)
|
||||
|
||||
const completedTask: BackgroundTask = {
|
||||
id: "task-completed",
|
||||
sessionID: "session-completed",
|
||||
parentSessionID: "parent-1",
|
||||
parentMessageID: "msg-1",
|
||||
description: "Completed task",
|
||||
prompt: "Test",
|
||||
agent: "test-agent",
|
||||
status: "completed",
|
||||
startedAt: new Date(),
|
||||
completedAt: new Date(),
|
||||
}
|
||||
const cancelledTask: BackgroundTask = {
|
||||
id: "task-cancelled",
|
||||
sessionID: "session-cancelled",
|
||||
parentSessionID: "parent-2",
|
||||
parentMessageID: "msg-2",
|
||||
description: "Cancelled task",
|
||||
prompt: "Test",
|
||||
agent: "test-agent",
|
||||
status: "cancelled",
|
||||
startedAt: new Date(),
|
||||
completedAt: new Date(),
|
||||
}
|
||||
const pendingTask: BackgroundTask = {
|
||||
id: "task-pending",
|
||||
parentSessionID: "parent-3",
|
||||
parentMessageID: "msg-3",
|
||||
description: "Pending task",
|
||||
prompt: "Test",
|
||||
agent: "test-agent",
|
||||
status: "pending",
|
||||
queuedAt: new Date(),
|
||||
}
|
||||
|
||||
getTaskMap(manager).set(completedTask.id, completedTask)
|
||||
getTaskMap(manager).set(cancelledTask.id, cancelledTask)
|
||||
getTaskMap(manager).set(pendingTask.id, pendingTask)
|
||||
|
||||
// #when
|
||||
manager.shutdown()
|
||||
|
||||
// #then
|
||||
expect(abortedSessionIDs).toHaveLength(0)
|
||||
})
|
||||
|
||||
test("should call onShutdown callback during shutdown", () => {
|
||||
// #given
|
||||
let shutdownCalled = false
|
||||
const client = {
|
||||
session: {
|
||||
prompt: async () => ({}),
|
||||
abort: async () => ({}),
|
||||
},
|
||||
}
|
||||
const manager = new BackgroundManager(
|
||||
{ client, directory: tmpdir() } as unknown as PluginInput,
|
||||
undefined,
|
||||
{
|
||||
onShutdown: () => {
|
||||
shutdownCalled = true
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// #when
|
||||
manager.shutdown()
|
||||
|
||||
// #then
|
||||
expect(shutdownCalled).toBe(true)
|
||||
})
|
||||
|
||||
test("should not throw when onShutdown callback throws", () => {
|
||||
// #given
|
||||
const client = {
|
||||
session: {
|
||||
prompt: async () => ({}),
|
||||
abort: async () => ({}),
|
||||
},
|
||||
}
|
||||
const manager = new BackgroundManager(
|
||||
{ client, directory: tmpdir() } as unknown as PluginInput,
|
||||
undefined,
|
||||
{
|
||||
onShutdown: () => {
|
||||
throw new Error("cleanup failed")
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// #when / #then
|
||||
expect(() => manager.shutdown()).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ export class BackgroundManager {
|
||||
private config?: BackgroundTaskConfig
|
||||
private tmuxEnabled: boolean
|
||||
private onSubagentSessionCreated?: OnSubagentSessionCreated
|
||||
private onShutdown?: () => void
|
||||
|
||||
private queuesByKey: Map<string, QueueItem[]> = new Map()
|
||||
private processingKeys: Set<string> = new Set()
|
||||
@@ -89,6 +90,7 @@ export class BackgroundManager {
|
||||
options?: {
|
||||
tmuxConfig?: TmuxConfig
|
||||
onSubagentSessionCreated?: OnSubagentSessionCreated
|
||||
onShutdown?: () => void
|
||||
}
|
||||
) {
|
||||
this.tasks = new Map()
|
||||
@@ -100,6 +102,7 @@ export class BackgroundManager {
|
||||
this.config = config
|
||||
this.tmuxEnabled = options?.tmuxConfig?.enabled ?? false
|
||||
this.onSubagentSessionCreated = options?.onSubagentSessionCreated
|
||||
this.onShutdown = options?.onShutdown
|
||||
this.registerProcessCleanup()
|
||||
}
|
||||
|
||||
@@ -1346,7 +1349,25 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
log("[background-agent] Shutting down BackgroundManager")
|
||||
this.stopPolling()
|
||||
|
||||
// Release concurrency for all running tasks first
|
||||
// Abort all running sessions to prevent zombie processes (#1240)
|
||||
for (const task of this.tasks.values()) {
|
||||
if (task.status === "running" && task.sessionID) {
|
||||
this.client.session.abort({
|
||||
path: { id: task.sessionID },
|
||||
}).catch(() => {})
|
||||
}
|
||||
}
|
||||
|
||||
// Notify shutdown listeners (e.g., tmux cleanup)
|
||||
if (this.onShutdown) {
|
||||
try {
|
||||
this.onShutdown()
|
||||
} catch (error) {
|
||||
log("[background-agent] Error in onShutdown callback:", error)
|
||||
}
|
||||
}
|
||||
|
||||
// Release concurrency for all running tasks
|
||||
for (const task of this.tasks.values()) {
|
||||
if (task.concurrencyKey) {
|
||||
this.concurrencyManager.release(task.concurrencyKey)
|
||||
|
||||
@@ -55,7 +55,6 @@ ${REFACTOR_TEMPLATE}
|
||||
},
|
||||
"start-work": {
|
||||
description: "(builtin) Start Sisyphus work session from Prometheus plan",
|
||||
agent: "atlas",
|
||||
template: `<command-instruction>
|
||||
${START_WORK_TEMPLATE}
|
||||
</command-instruction>
|
||||
|
||||
@@ -7,6 +7,10 @@ export interface ClaudeCodeMcpServer {
|
||||
args?: string[]
|
||||
env?: Record<string, string>
|
||||
headers?: Record<string, string>
|
||||
oauth?: {
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
disabled?: boolean
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { MessageMeta, OriginalMessageContext, TextPart, ToolPermission } fr
|
||||
|
||||
export interface StoredMessage {
|
||||
agent?: string
|
||||
model?: { providerID?: string; modelID?: string }
|
||||
model?: { providerID?: string; modelID?: string; variant?: string }
|
||||
tools?: Record<string, ToolPermission>
|
||||
}
|
||||
|
||||
@@ -141,9 +141,17 @@ export function injectHookMessage(
|
||||
const resolvedAgent = originalMessage.agent ?? fallback?.agent ?? "general"
|
||||
const resolvedModel =
|
||||
originalMessage.model?.providerID && originalMessage.model?.modelID
|
||||
? { providerID: originalMessage.model.providerID, modelID: originalMessage.model.modelID }
|
||||
? {
|
||||
providerID: originalMessage.model.providerID,
|
||||
modelID: originalMessage.model.modelID,
|
||||
...(originalMessage.model.variant ? { variant: originalMessage.model.variant } : {})
|
||||
}
|
||||
: fallback?.model?.providerID && fallback?.model?.modelID
|
||||
? { providerID: fallback.model.providerID, modelID: fallback.model.modelID }
|
||||
? {
|
||||
providerID: fallback.model.providerID,
|
||||
modelID: fallback.model.modelID,
|
||||
...(fallback.model.variant ? { variant: fallback.model.variant } : {})
|
||||
}
|
||||
: undefined
|
||||
const resolvedTools = originalMessage.tools ?? fallback?.tools
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ export interface MessageMeta {
|
||||
model?: {
|
||||
providerID: string
|
||||
modelID: string
|
||||
variant?: string
|
||||
}
|
||||
path?: {
|
||||
cwd: string
|
||||
@@ -25,6 +26,7 @@ export interface OriginalMessageContext {
|
||||
model?: {
|
||||
providerID?: string
|
||||
modelID?: string
|
||||
variant?: string
|
||||
}
|
||||
path?: {
|
||||
cwd?: string
|
||||
|
||||
129
src/features/mcp-oauth/callback-server.test.ts
Normal file
129
src/features/mcp-oauth/callback-server.test.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
import { afterEach, describe, expect, it } from "bun:test"
|
||||
import { findAvailablePort, startCallbackServer, type CallbackServer } from "./callback-server"
|
||||
|
||||
describe("findAvailablePort", () => {
|
||||
it("returns the start port when it is available", async () => {
|
||||
//#given
|
||||
const startPort = 19877
|
||||
|
||||
//#when
|
||||
const port = await findAvailablePort(startPort)
|
||||
|
||||
//#then
|
||||
expect(port).toBeGreaterThanOrEqual(startPort)
|
||||
expect(port).toBeLessThan(startPort + 20)
|
||||
})
|
||||
|
||||
it("skips busy ports and returns next available", async () => {
|
||||
//#given
|
||||
const blocker = Bun.serve({
|
||||
port: 19877,
|
||||
hostname: "127.0.0.1",
|
||||
fetch: () => new Response(),
|
||||
})
|
||||
|
||||
//#when
|
||||
const port = await findAvailablePort(19877)
|
||||
|
||||
//#then
|
||||
expect(port).toBeGreaterThan(19877)
|
||||
blocker.stop(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe("startCallbackServer", () => {
|
||||
let server: CallbackServer | null = null
|
||||
|
||||
afterEach(() => {
|
||||
server?.close()
|
||||
server = null
|
||||
})
|
||||
|
||||
it("starts server and returns port", async () => {
|
||||
//#given - no preconditions
|
||||
|
||||
//#when
|
||||
server = await startCallbackServer()
|
||||
|
||||
//#then
|
||||
expect(server.port).toBeGreaterThanOrEqual(19877)
|
||||
expect(typeof server.waitForCallback).toBe("function")
|
||||
expect(typeof server.close).toBe("function")
|
||||
})
|
||||
|
||||
it("resolves callback with code and state from query params", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const callbackUrl = `http://127.0.0.1:${server.port}/oauth/callback?code=test-code&state=test-state`
|
||||
|
||||
//#when
|
||||
const fetchPromise = fetch(callbackUrl)
|
||||
const result = await server.waitForCallback()
|
||||
const response = await fetchPromise
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({ code: "test-code", state: "test-state" })
|
||||
expect(response.status).toBe(200)
|
||||
const html = await response.text()
|
||||
expect(html).toContain("Authorization successful")
|
||||
})
|
||||
|
||||
it("returns 404 for non-callback routes", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
|
||||
//#when
|
||||
const response = await fetch(`http://127.0.0.1:${server.port}/other`)
|
||||
|
||||
//#then
|
||||
expect(response.status).toBe(404)
|
||||
})
|
||||
|
||||
it("returns 400 and rejects when code is missing", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const callbackRejection = server.waitForCallback().catch((e: Error) => e)
|
||||
|
||||
//#when
|
||||
const response = await fetch(`http://127.0.0.1:${server.port}/oauth/callback?state=s`)
|
||||
|
||||
//#then
|
||||
expect(response.status).toBe(400)
|
||||
const error = await callbackRejection
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
expect((error as Error).message).toContain("missing code or state")
|
||||
})
|
||||
|
||||
it("returns 400 and rejects when state is missing", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const callbackRejection = server.waitForCallback().catch((e: Error) => e)
|
||||
|
||||
//#when
|
||||
const response = await fetch(`http://127.0.0.1:${server.port}/oauth/callback?code=c`)
|
||||
|
||||
//#then
|
||||
expect(response.status).toBe(400)
|
||||
const error = await callbackRejection
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
expect((error as Error).message).toContain("missing code or state")
|
||||
})
|
||||
|
||||
it("close stops the server immediately", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const port = server.port
|
||||
|
||||
//#when
|
||||
server.close()
|
||||
server = null
|
||||
|
||||
//#then
|
||||
try {
|
||||
await fetch(`http://127.0.0.1:${port}/oauth/callback?code=c&state=s`)
|
||||
expect(true).toBe(false)
|
||||
} catch (error) {
|
||||
expect(error).toBeDefined()
|
||||
}
|
||||
})
|
||||
})
|
||||
124
src/features/mcp-oauth/callback-server.ts
Normal file
124
src/features/mcp-oauth/callback-server.ts
Normal file
@@ -0,0 +1,124 @@
|
||||
const DEFAULT_PORT = 19877
|
||||
const MAX_PORT_ATTEMPTS = 20
|
||||
const TIMEOUT_MS = 5 * 60 * 1000
|
||||
|
||||
export type OAuthCallbackResult = {
|
||||
code: string
|
||||
state: string
|
||||
}
|
||||
|
||||
export type CallbackServer = {
|
||||
port: number
|
||||
waitForCallback: () => Promise<OAuthCallbackResult>
|
||||
close: () => void
|
||||
}
|
||||
|
||||
const SUCCESS_HTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>OAuth Authorized</title>
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #0a0a0a; color: #fafafa; }
|
||||
.container { text-align: center; }
|
||||
h1 { font-size: 1.5rem; margin-bottom: 0.5rem; }
|
||||
p { color: #888; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Authorization successful</h1>
|
||||
<p>You can close this window and return to your terminal.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
async function isPortAvailable(port: number): Promise<boolean> {
|
||||
try {
|
||||
const server = Bun.serve({
|
||||
port,
|
||||
hostname: "127.0.0.1",
|
||||
fetch: () => new Response(),
|
||||
})
|
||||
server.stop(true)
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export async function findAvailablePort(startPort: number = DEFAULT_PORT): Promise<number> {
|
||||
for (let attempt = 0; attempt < MAX_PORT_ATTEMPTS; attempt++) {
|
||||
const port = startPort + attempt
|
||||
if (await isPortAvailable(port)) {
|
||||
return port
|
||||
}
|
||||
}
|
||||
throw new Error(`No available port found in range ${startPort}-${startPort + MAX_PORT_ATTEMPTS - 1}`)
|
||||
}
|
||||
|
||||
export async function startCallbackServer(startPort: number = DEFAULT_PORT): Promise<CallbackServer> {
|
||||
const port = await findAvailablePort(startPort)
|
||||
|
||||
let resolveCallback: ((result: OAuthCallbackResult) => void) | null = null
|
||||
let rejectCallback: ((error: Error) => void) | null = null
|
||||
|
||||
const callbackPromise = new Promise<OAuthCallbackResult>((resolve, reject) => {
|
||||
resolveCallback = resolve
|
||||
rejectCallback = reject
|
||||
})
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
rejectCallback?.(new Error("OAuth callback timed out after 5 minutes"))
|
||||
server.stop(true)
|
||||
}, TIMEOUT_MS)
|
||||
|
||||
const server = Bun.serve({
|
||||
port,
|
||||
hostname: "127.0.0.1",
|
||||
fetch(request: Request): Response {
|
||||
const url = new URL(request.url)
|
||||
|
||||
if (url.pathname !== "/oauth/callback") {
|
||||
return new Response("Not Found", { status: 404 })
|
||||
}
|
||||
|
||||
const oauthError = url.searchParams.get("error")
|
||||
if (oauthError) {
|
||||
const description = url.searchParams.get("error_description") ?? oauthError
|
||||
clearTimeout(timeoutId)
|
||||
rejectCallback?.(new Error(`OAuth authorization failed: ${description}`))
|
||||
setTimeout(() => server.stop(true), 100)
|
||||
return new Response(`Authorization failed: ${description}`, { status: 400 })
|
||||
}
|
||||
|
||||
const code = url.searchParams.get("code")
|
||||
const state = url.searchParams.get("state")
|
||||
|
||||
if (!code || !state) {
|
||||
clearTimeout(timeoutId)
|
||||
rejectCallback?.(new Error("OAuth callback missing code or state parameter"))
|
||||
setTimeout(() => server.stop(true), 100)
|
||||
return new Response("Missing code or state parameter", { status: 400 })
|
||||
}
|
||||
|
||||
resolveCallback?.({ code, state })
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
setTimeout(() => server.stop(true), 100)
|
||||
|
||||
return new Response(SUCCESS_HTML, {
|
||||
headers: { "content-type": "text/html; charset=utf-8" },
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
port,
|
||||
waitForCallback: () => callbackPromise,
|
||||
close: () => {
|
||||
clearTimeout(timeoutId)
|
||||
server.stop(true)
|
||||
},
|
||||
}
|
||||
}
|
||||
164
src/features/mcp-oauth/dcr.test.ts
Normal file
164
src/features/mcp-oauth/dcr.test.ts
Normal file
@@ -0,0 +1,164 @@
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import {
|
||||
getOrRegisterClient,
|
||||
type ClientCredentials,
|
||||
type ClientRegistrationStorage,
|
||||
type DcrFetch,
|
||||
} from "./dcr"
|
||||
|
||||
function createStorage(initial: ClientCredentials | null):
|
||||
& ClientRegistrationStorage
|
||||
& { getLastKey: () => string | null; getLastSet: () => ClientCredentials | null } {
|
||||
let stored = initial
|
||||
let lastKey: string | null = null
|
||||
let lastSet: ClientCredentials | null = null
|
||||
|
||||
return {
|
||||
getClientRegistration: () => stored,
|
||||
setClientRegistration: (serverIdentifier: string, credentials: ClientCredentials) => {
|
||||
lastKey = serverIdentifier
|
||||
lastSet = credentials
|
||||
stored = credentials
|
||||
},
|
||||
getLastKey: () => lastKey,
|
||||
getLastSet: () => lastSet,
|
||||
}
|
||||
}
|
||||
|
||||
describe("getOrRegisterClient", () => {
|
||||
it("returns cached registration when available", async () => {
|
||||
// #given
|
||||
const storage = createStorage({
|
||||
clientId: "cached-client",
|
||||
clientSecret: "cached-secret",
|
||||
})
|
||||
const fetchMock: DcrFetch = async () => {
|
||||
throw new Error("fetch should not be called")
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: "https://server.example.com/register",
|
||||
serverIdentifier: "server-1",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
clientId: "cached-client",
|
||||
clientSecret: "cached-secret",
|
||||
})
|
||||
})
|
||||
|
||||
it("registers client and stores credentials when endpoint available", async () => {
|
||||
// #given
|
||||
const storage = createStorage(null)
|
||||
let fetchCalled = false
|
||||
const fetchMock: DcrFetch = async (
|
||||
input: string,
|
||||
init?: { method?: string; headers?: Record<string, string>; body?: string }
|
||||
) => {
|
||||
fetchCalled = true
|
||||
expect(input).toBe("https://server.example.com/register")
|
||||
if (typeof init?.body !== "string") {
|
||||
throw new Error("Expected request body string")
|
||||
}
|
||||
const payload = JSON.parse(init.body)
|
||||
expect(payload).toEqual({
|
||||
redirect_uris: ["https://app.example.com/callback"],
|
||||
client_name: "Test Client",
|
||||
grant_types: ["authorization_code", "refresh_token"],
|
||||
response_types: ["code"],
|
||||
token_endpoint_auth_method: "client_secret_post",
|
||||
})
|
||||
|
||||
return {
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
client_id: "registered-client",
|
||||
client_secret: "registered-secret",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: "https://server.example.com/register",
|
||||
serverIdentifier: "server-2",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(fetchCalled).toBe(true)
|
||||
expect(result).toEqual({
|
||||
clientId: "registered-client",
|
||||
clientSecret: "registered-secret",
|
||||
})
|
||||
expect(storage.getLastKey()).toBe("server-2")
|
||||
expect(storage.getLastSet()).toEqual({
|
||||
clientId: "registered-client",
|
||||
clientSecret: "registered-secret",
|
||||
})
|
||||
})
|
||||
|
||||
it("uses config client id when registration endpoint missing", async () => {
|
||||
// #given
|
||||
const storage = createStorage(null)
|
||||
let fetchCalled = false
|
||||
const fetchMock: DcrFetch = async () => {
|
||||
fetchCalled = true
|
||||
return {
|
||||
ok: false,
|
||||
json: async () => ({}),
|
||||
}
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: undefined,
|
||||
serverIdentifier: "server-3",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
clientId: "config-client",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(fetchCalled).toBe(false)
|
||||
expect(result).toEqual({ clientId: "config-client" })
|
||||
})
|
||||
|
||||
it("falls back to config client id when registration fails", async () => {
|
||||
// #given
|
||||
const storage = createStorage(null)
|
||||
const fetchMock: DcrFetch = async () => {
|
||||
throw new Error("network error")
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: "https://server.example.com/register",
|
||||
serverIdentifier: "server-4",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
clientId: "fallback-client",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ clientId: "fallback-client" })
|
||||
expect(storage.getLastSet()).toBeNull()
|
||||
})
|
||||
})
|
||||
98
src/features/mcp-oauth/dcr.ts
Normal file
98
src/features/mcp-oauth/dcr.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
export type ClientRegistrationRequest = {
|
||||
redirect_uris: string[]
|
||||
client_name: string
|
||||
grant_types: ["authorization_code", "refresh_token"]
|
||||
response_types: ["code"]
|
||||
token_endpoint_auth_method: "none" | "client_secret_post"
|
||||
}
|
||||
|
||||
export type ClientCredentials = {
|
||||
clientId: string
|
||||
clientSecret?: string
|
||||
}
|
||||
|
||||
export type ClientRegistrationStorage = {
|
||||
getClientRegistration: (serverIdentifier: string) => ClientCredentials | null
|
||||
setClientRegistration: (
|
||||
serverIdentifier: string,
|
||||
credentials: ClientCredentials
|
||||
) => void
|
||||
}
|
||||
|
||||
export type DynamicClientRegistrationOptions = {
|
||||
registrationEndpoint?: string | null
|
||||
serverIdentifier?: string
|
||||
clientName: string
|
||||
redirectUris: string[]
|
||||
tokenEndpointAuthMethod: "none" | "client_secret_post"
|
||||
clientId?: string | null
|
||||
storage: ClientRegistrationStorage
|
||||
fetch?: DcrFetch
|
||||
}
|
||||
|
||||
export type DcrFetch = (
|
||||
input: string,
|
||||
init?: { method?: string; headers?: Record<string, string>; body?: string }
|
||||
) => Promise<{ ok: boolean; json: () => Promise<unknown> }>
|
||||
|
||||
export async function getOrRegisterClient(
|
||||
options: DynamicClientRegistrationOptions
|
||||
): Promise<ClientCredentials | null> {
|
||||
const serverIdentifier =
|
||||
options.serverIdentifier ?? options.registrationEndpoint ?? "default"
|
||||
const existing = options.storage.getClientRegistration(serverIdentifier)
|
||||
if (existing) return existing
|
||||
|
||||
if (!options.registrationEndpoint) {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
|
||||
const fetchImpl = options.fetch ?? globalThis.fetch
|
||||
const request: ClientRegistrationRequest = {
|
||||
redirect_uris: options.redirectUris,
|
||||
client_name: options.clientName,
|
||||
grant_types: ["authorization_code", "refresh_token"],
|
||||
response_types: ["code"],
|
||||
token_endpoint_auth_method: options.tokenEndpointAuthMethod,
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetchImpl(options.registrationEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "content-type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
|
||||
const data: unknown = await response.json()
|
||||
const parsed = parseRegistrationResponse(data)
|
||||
if (!parsed) {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
|
||||
options.storage.setClientRegistration(serverIdentifier, parsed)
|
||||
return parsed
|
||||
} catch {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
}
|
||||
|
||||
function parseRegistrationResponse(data: unknown): ClientCredentials | null {
|
||||
if (!isRecord(data)) return null
|
||||
const clientId = data.client_id
|
||||
if (typeof clientId !== "string" || clientId.length === 0) return null
|
||||
|
||||
const clientSecret = data.client_secret
|
||||
if (typeof clientSecret === "string" && clientSecret.length > 0) {
|
||||
return { clientId, clientSecret }
|
||||
}
|
||||
|
||||
return { clientId }
|
||||
}
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
175
src/features/mcp-oauth/discovery.test.ts
Normal file
175
src/features/mcp-oauth/discovery.test.ts
Normal file
@@ -0,0 +1,175 @@
|
||||
import { describe, test, expect, beforeEach, afterEach } from "bun:test"
|
||||
import { discoverOAuthServerMetadata, resetDiscoveryCache } from "./discovery"
|
||||
|
||||
describe("discoverOAuthServerMetadata", () => {
|
||||
const originalFetch = globalThis.fetch
|
||||
|
||||
beforeEach(() => {
|
||||
resetDiscoveryCache()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
Object.defineProperty(globalThis, "fetch", { value: originalFetch, configurable: true })
|
||||
})
|
||||
|
||||
test("returns endpoints from PRM + AS discovery", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const authServer = "https://auth.example.com"
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString()
|
||||
const calls: string[] = []
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
calls.push(url)
|
||||
if (url === prmUrl) {
|
||||
return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: "https://auth.example.com/authorize",
|
||||
token_endpoint: "https://auth.example.com/token",
|
||||
registration_endpoint: "https://auth.example.com/register",
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
return discoverOAuthServerMetadata(resource).then((result) => {
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
authorizationEndpoint: "https://auth.example.com/authorize",
|
||||
tokenEndpoint: "https://auth.example.com/token",
|
||||
registrationEndpoint: "https://auth.example.com/register",
|
||||
resource,
|
||||
})
|
||||
expect(calls).toEqual([prmUrl, asUrl])
|
||||
})
|
||||
})
|
||||
|
||||
test("falls back to RFC 8414 when PRM returns 404", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", resource).toString()
|
||||
const calls: string[] = []
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
calls.push(url)
|
||||
if (url === prmUrl) {
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: "https://mcp.example.com/authorize",
|
||||
token_endpoint: "https://mcp.example.com/token",
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
return discoverOAuthServerMetadata(resource).then((result) => {
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
authorizationEndpoint: "https://mcp.example.com/authorize",
|
||||
tokenEndpoint: "https://mcp.example.com/token",
|
||||
registrationEndpoint: undefined,
|
||||
resource,
|
||||
})
|
||||
expect(calls).toEqual([prmUrl, asUrl])
|
||||
})
|
||||
})
|
||||
|
||||
test("throws when both PRM and AS discovery return 404", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", resource).toString()
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
if (url === prmUrl || url === asUrl) {
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
const result = discoverOAuthServerMetadata(resource)
|
||||
|
||||
// #then
|
||||
return expect(result).rejects.toThrow("OAuth authorization server metadata not found")
|
||||
})
|
||||
|
||||
test("throws when AS metadata is malformed", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const authServer = "https://auth.example.com"
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString()
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
if (url === prmUrl) {
|
||||
return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(JSON.stringify({ authorization_endpoint: "https://auth.example.com/authorize" }), {
|
||||
status: 200,
|
||||
})
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
const result = discoverOAuthServerMetadata(resource)
|
||||
|
||||
// #then
|
||||
return expect(result).rejects.toThrow("token_endpoint")
|
||||
})
|
||||
|
||||
test("caches discovery results per resource URL", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const authServer = "https://auth.example.com"
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString()
|
||||
const calls: string[] = []
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
calls.push(url)
|
||||
if (url === prmUrl) {
|
||||
return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: "https://auth.example.com/authorize",
|
||||
token_endpoint: "https://auth.example.com/token",
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
return discoverOAuthServerMetadata(resource)
|
||||
.then(() => discoverOAuthServerMetadata(resource))
|
||||
.then(() => {
|
||||
// #then
|
||||
expect(calls).toEqual([prmUrl, asUrl])
|
||||
})
|
||||
})
|
||||
})
|
||||
123
src/features/mcp-oauth/discovery.ts
Normal file
123
src/features/mcp-oauth/discovery.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
export interface OAuthServerMetadata {
|
||||
authorizationEndpoint: string
|
||||
tokenEndpoint: string
|
||||
registrationEndpoint?: string
|
||||
resource: string
|
||||
}
|
||||
|
||||
const discoveryCache = new Map<string, OAuthServerMetadata>()
|
||||
const pendingDiscovery = new Map<string, Promise<OAuthServerMetadata>>()
|
||||
|
||||
function parseHttpsUrl(value: string, label: string): URL {
|
||||
const parsed = new URL(value)
|
||||
if (parsed.protocol !== "https:") {
|
||||
throw new Error(`${label} must use https`)
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
function readStringField(source: Record<string, unknown>, field: string): string {
|
||||
const value = source[field]
|
||||
if (typeof value !== "string" || value.length === 0) {
|
||||
throw new Error(`OAuth metadata missing ${field}`)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
async function fetchMetadata(url: string): Promise<{ ok: true; json: Record<string, unknown> } | { ok: false; status: number }> {
|
||||
const response = await fetch(url, { headers: { accept: "application/json" } })
|
||||
if (!response.ok) {
|
||||
return { ok: false, status: response.status }
|
||||
}
|
||||
const json = (await response.json().catch(() => null)) as Record<string, unknown> | null
|
||||
if (!json || typeof json !== "object") {
|
||||
throw new Error("OAuth metadata response is not valid JSON")
|
||||
}
|
||||
return { ok: true, json }
|
||||
}
|
||||
|
||||
async function fetchAuthorizationServerMetadata(issuer: string, resource: string): Promise<OAuthServerMetadata> {
|
||||
const issuerUrl = parseHttpsUrl(issuer, "Authorization server URL")
|
||||
const issuerPath = issuerUrl.pathname.replace(/\/+$/, "")
|
||||
const metadataUrl = new URL(`/.well-known/oauth-authorization-server${issuerPath}`, issuerUrl).toString()
|
||||
const metadata = await fetchMetadata(metadataUrl)
|
||||
|
||||
if (!metadata.ok) {
|
||||
if (metadata.status === 404) {
|
||||
throw new Error("OAuth authorization server metadata not found")
|
||||
}
|
||||
throw new Error(`OAuth authorization server metadata fetch failed (${metadata.status})`)
|
||||
}
|
||||
|
||||
const authorizationEndpoint = parseHttpsUrl(
|
||||
readStringField(metadata.json, "authorization_endpoint"),
|
||||
"authorization_endpoint"
|
||||
).toString()
|
||||
const tokenEndpoint = parseHttpsUrl(
|
||||
readStringField(metadata.json, "token_endpoint"),
|
||||
"token_endpoint"
|
||||
).toString()
|
||||
const registrationEndpointValue = metadata.json.registration_endpoint
|
||||
const registrationEndpoint =
|
||||
typeof registrationEndpointValue === "string" && registrationEndpointValue.length > 0
|
||||
? parseHttpsUrl(registrationEndpointValue, "registration_endpoint").toString()
|
||||
: undefined
|
||||
|
||||
return {
|
||||
authorizationEndpoint,
|
||||
tokenEndpoint,
|
||||
registrationEndpoint,
|
||||
resource,
|
||||
}
|
||||
}
|
||||
|
||||
function parseAuthorizationServers(metadata: Record<string, unknown>): string[] {
|
||||
const servers = metadata.authorization_servers
|
||||
if (!Array.isArray(servers)) return []
|
||||
return servers.filter((server): server is string => typeof server === "string" && server.length > 0)
|
||||
}
|
||||
|
||||
export async function discoverOAuthServerMetadata(resource: string): Promise<OAuthServerMetadata> {
|
||||
const resourceUrl = parseHttpsUrl(resource, "Resource server URL")
|
||||
const resourceKey = resourceUrl.toString()
|
||||
|
||||
const cached = discoveryCache.get(resourceKey)
|
||||
if (cached) return cached
|
||||
|
||||
const pending = pendingDiscovery.get(resourceKey)
|
||||
if (pending) return pending
|
||||
|
||||
const discoveryPromise = (async () => {
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resourceUrl).toString()
|
||||
const prmResponse = await fetchMetadata(prmUrl)
|
||||
|
||||
if (prmResponse.ok) {
|
||||
const authServers = parseAuthorizationServers(prmResponse.json)
|
||||
if (authServers.length === 0) {
|
||||
throw new Error("OAuth protected resource metadata missing authorization_servers")
|
||||
}
|
||||
return fetchAuthorizationServerMetadata(authServers[0], resource)
|
||||
}
|
||||
|
||||
if (prmResponse.status !== 404) {
|
||||
throw new Error(`OAuth protected resource metadata fetch failed (${prmResponse.status})`)
|
||||
}
|
||||
|
||||
return fetchAuthorizationServerMetadata(resourceKey, resource)
|
||||
})()
|
||||
|
||||
pendingDiscovery.set(resourceKey, discoveryPromise)
|
||||
|
||||
try {
|
||||
const result = await discoveryPromise
|
||||
discoveryCache.set(resourceKey, result)
|
||||
return result
|
||||
} finally {
|
||||
pendingDiscovery.delete(resourceKey)
|
||||
}
|
||||
}
|
||||
|
||||
export function resetDiscoveryCache(): void {
|
||||
discoveryCache.clear()
|
||||
pendingDiscovery.clear()
|
||||
}
|
||||
1
src/features/mcp-oauth/index.ts
Normal file
1
src/features/mcp-oauth/index.ts
Normal file
@@ -0,0 +1 @@
|
||||
export * from "./schema"
|
||||
223
src/features/mcp-oauth/provider.test.ts
Normal file
223
src/features/mcp-oauth/provider.test.ts
Normal file
@@ -0,0 +1,223 @@
|
||||
import { describe, expect, it, beforeEach, afterEach, mock } from "bun:test"
|
||||
import { createHash, randomBytes } from "node:crypto"
|
||||
import { McpOAuthProvider, generateCodeVerifier, generateCodeChallenge, buildAuthorizationUrl } from "./provider"
|
||||
import type { OAuthTokenData } from "./storage"
|
||||
|
||||
describe("McpOAuthProvider", () => {
|
||||
describe("generateCodeVerifier", () => {
|
||||
it("returns a base64url-encoded 32-byte random string", () => {
|
||||
//#given
|
||||
const verifier = generateCodeVerifier()
|
||||
|
||||
//#when
|
||||
const decoded = Buffer.from(verifier, "base64url")
|
||||
|
||||
//#then
|
||||
expect(decoded.length).toBe(32)
|
||||
expect(verifier).toMatch(/^[A-Za-z0-9_-]+$/)
|
||||
})
|
||||
|
||||
it("produces unique values on each call", () => {
|
||||
//#given
|
||||
const first = generateCodeVerifier()
|
||||
|
||||
//#when
|
||||
const second = generateCodeVerifier()
|
||||
|
||||
//#then
|
||||
expect(first).not.toBe(second)
|
||||
})
|
||||
})
|
||||
|
||||
describe("generateCodeChallenge", () => {
|
||||
it("returns SHA256 base64url digest of the verifier", () => {
|
||||
//#given
|
||||
const verifier = "test-verifier-value"
|
||||
const expected = createHash("sha256").update(verifier).digest("base64url")
|
||||
|
||||
//#when
|
||||
const challenge = generateCodeChallenge(verifier)
|
||||
|
||||
//#then
|
||||
expect(challenge).toBe(expected)
|
||||
})
|
||||
})
|
||||
|
||||
describe("buildAuthorizationUrl", () => {
|
||||
it("builds URL with all required PKCE parameters", () => {
|
||||
//#given
|
||||
const endpoint = "https://auth.example.com/authorize"
|
||||
|
||||
//#when
|
||||
const url = buildAuthorizationUrl(endpoint, {
|
||||
clientId: "my-client",
|
||||
redirectUri: "http://127.0.0.1:8912/callback",
|
||||
codeChallenge: "challenge-value",
|
||||
state: "state-value",
|
||||
scopes: ["openid", "profile"],
|
||||
resource: "https://mcp.example.com",
|
||||
})
|
||||
|
||||
//#then
|
||||
const parsed = new URL(url)
|
||||
expect(parsed.origin + parsed.pathname).toBe("https://auth.example.com/authorize")
|
||||
expect(parsed.searchParams.get("response_type")).toBe("code")
|
||||
expect(parsed.searchParams.get("client_id")).toBe("my-client")
|
||||
expect(parsed.searchParams.get("redirect_uri")).toBe("http://127.0.0.1:8912/callback")
|
||||
expect(parsed.searchParams.get("code_challenge")).toBe("challenge-value")
|
||||
expect(parsed.searchParams.get("code_challenge_method")).toBe("S256")
|
||||
expect(parsed.searchParams.get("state")).toBe("state-value")
|
||||
expect(parsed.searchParams.get("scope")).toBe("openid profile")
|
||||
expect(parsed.searchParams.get("resource")).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("omits scope when empty", () => {
|
||||
//#given
|
||||
const endpoint = "https://auth.example.com/authorize"
|
||||
|
||||
//#when
|
||||
const url = buildAuthorizationUrl(endpoint, {
|
||||
clientId: "my-client",
|
||||
redirectUri: "http://127.0.0.1:8912/callback",
|
||||
codeChallenge: "challenge-value",
|
||||
state: "state-value",
|
||||
scopes: [],
|
||||
})
|
||||
|
||||
//#then
|
||||
const parsed = new URL(url)
|
||||
expect(parsed.searchParams.has("scope")).toBe(false)
|
||||
})
|
||||
|
||||
it("omits resource when undefined", () => {
|
||||
//#given
|
||||
const endpoint = "https://auth.example.com/authorize"
|
||||
|
||||
//#when
|
||||
const url = buildAuthorizationUrl(endpoint, {
|
||||
clientId: "my-client",
|
||||
redirectUri: "http://127.0.0.1:8912/callback",
|
||||
codeChallenge: "challenge-value",
|
||||
state: "state-value",
|
||||
})
|
||||
|
||||
//#then
|
||||
const parsed = new URL(url)
|
||||
expect(parsed.searchParams.has("resource")).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("constructor and basic methods", () => {
|
||||
it("stores serverUrl and optional clientId and scopes", () => {
|
||||
//#given
|
||||
const options = {
|
||||
serverUrl: "https://mcp.example.com",
|
||||
clientId: "my-client",
|
||||
scopes: ["openid"],
|
||||
}
|
||||
|
||||
//#when
|
||||
const provider = new McpOAuthProvider(options)
|
||||
|
||||
//#then
|
||||
expect(provider.tokens()).toBeNull()
|
||||
expect(provider.clientInformation()).toBeNull()
|
||||
expect(provider.codeVerifier()).toBeNull()
|
||||
})
|
||||
|
||||
it("defaults scopes to empty array", () => {
|
||||
//#given
|
||||
const options = { serverUrl: "https://mcp.example.com" }
|
||||
|
||||
//#when
|
||||
const provider = new McpOAuthProvider(options)
|
||||
|
||||
//#then
|
||||
expect(provider.redirectUrl()).toBe("http://127.0.0.1:19877/callback")
|
||||
})
|
||||
})
|
||||
|
||||
describe("saveCodeVerifier / codeVerifier", () => {
|
||||
it("stores and retrieves code verifier", () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
|
||||
//#when
|
||||
provider.saveCodeVerifier("my-verifier")
|
||||
|
||||
//#then
|
||||
expect(provider.codeVerifier()).toBe("my-verifier")
|
||||
})
|
||||
})
|
||||
|
||||
describe("saveTokens / tokens", () => {
|
||||
let originalEnv: string | undefined
|
||||
|
||||
beforeEach(() => {
|
||||
originalEnv = process.env.OPENCODE_CONFIG_DIR
|
||||
const { mkdirSync } = require("node:fs")
|
||||
const { tmpdir } = require("node:os")
|
||||
const { join } = require("node:path")
|
||||
const testDir = join(tmpdir(), "mcp-oauth-provider-test-" + Date.now())
|
||||
mkdirSync(testDir, { recursive: true })
|
||||
process.env.OPENCODE_CONFIG_DIR = testDir
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
if (originalEnv === undefined) {
|
||||
delete process.env.OPENCODE_CONFIG_DIR
|
||||
} else {
|
||||
process.env.OPENCODE_CONFIG_DIR = originalEnv
|
||||
}
|
||||
})
|
||||
|
||||
it("persists and loads token data via storage", () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
const tokenData: OAuthTokenData = {
|
||||
accessToken: "access-token-123",
|
||||
refreshToken: "refresh-token-456",
|
||||
expiresAt: 1710000000,
|
||||
}
|
||||
|
||||
//#when
|
||||
const saved = provider.saveTokens(tokenData)
|
||||
const loaded = provider.tokens()
|
||||
|
||||
//#then
|
||||
expect(saved).toBe(true)
|
||||
expect(loaded).toEqual(tokenData)
|
||||
})
|
||||
})
|
||||
|
||||
describe("redirectToAuthorization", () => {
|
||||
it("throws when no client information is set", async () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
const metadata = {
|
||||
authorizationEndpoint: "https://auth.example.com/authorize",
|
||||
tokenEndpoint: "https://auth.example.com/token",
|
||||
resource: "https://mcp.example.com",
|
||||
}
|
||||
|
||||
//#when
|
||||
const result = provider.redirectToAuthorization(metadata)
|
||||
|
||||
//#then
|
||||
await expect(result).rejects.toThrow("No client information available")
|
||||
})
|
||||
})
|
||||
|
||||
describe("redirectUrl", () => {
|
||||
it("returns localhost callback URL with default port", () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
|
||||
//#when
|
||||
const url = provider.redirectUrl()
|
||||
|
||||
//#then
|
||||
expect(url).toBe("http://127.0.0.1:19877/callback")
|
||||
})
|
||||
})
|
||||
})
|
||||
295
src/features/mcp-oauth/provider.ts
Normal file
295
src/features/mcp-oauth/provider.ts
Normal file
@@ -0,0 +1,295 @@
|
||||
import { createHash, randomBytes } from "node:crypto"
|
||||
import { createServer } from "node:http"
|
||||
import { spawn } from "node:child_process"
|
||||
import type { OAuthTokenData } from "./storage"
|
||||
import { loadToken, saveToken } from "./storage"
|
||||
import { discoverOAuthServerMetadata } from "./discovery"
|
||||
import type { OAuthServerMetadata } from "./discovery"
|
||||
import { getOrRegisterClient } from "./dcr"
|
||||
import type { ClientCredentials, ClientRegistrationStorage } from "./dcr"
|
||||
import { findAvailablePort } from "./callback-server"
|
||||
|
||||
export type McpOAuthProviderOptions = {
|
||||
serverUrl: string
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
|
||||
type CallbackResult = {
|
||||
code: string
|
||||
state: string
|
||||
}
|
||||
|
||||
function generateCodeVerifier(): string {
|
||||
return randomBytes(32).toString("base64url")
|
||||
}
|
||||
|
||||
function generateCodeChallenge(verifier: string): string {
|
||||
return createHash("sha256").update(verifier).digest("base64url")
|
||||
}
|
||||
|
||||
function buildAuthorizationUrl(
|
||||
authorizationEndpoint: string,
|
||||
options: {
|
||||
clientId: string
|
||||
redirectUri: string
|
||||
codeChallenge: string
|
||||
state: string
|
||||
scopes?: string[]
|
||||
resource?: string
|
||||
}
|
||||
): string {
|
||||
const url = new URL(authorizationEndpoint)
|
||||
url.searchParams.set("response_type", "code")
|
||||
url.searchParams.set("client_id", options.clientId)
|
||||
url.searchParams.set("redirect_uri", options.redirectUri)
|
||||
url.searchParams.set("code_challenge", options.codeChallenge)
|
||||
url.searchParams.set("code_challenge_method", "S256")
|
||||
url.searchParams.set("state", options.state)
|
||||
if (options.scopes && options.scopes.length > 0) {
|
||||
url.searchParams.set("scope", options.scopes.join(" "))
|
||||
}
|
||||
if (options.resource) {
|
||||
url.searchParams.set("resource", options.resource)
|
||||
}
|
||||
return url.toString()
|
||||
}
|
||||
|
||||
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000
|
||||
|
||||
function startCallbackServer(port: number): Promise<CallbackResult> {
|
||||
return new Promise((resolve, reject) => {
|
||||
let timeoutId: ReturnType<typeof setTimeout>
|
||||
|
||||
const server = createServer((request, response) => {
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
const requestUrl = new URL(request.url ?? "/", `http://localhost:${port}`)
|
||||
const code = requestUrl.searchParams.get("code")
|
||||
const state = requestUrl.searchParams.get("state")
|
||||
const error = requestUrl.searchParams.get("error")
|
||||
|
||||
if (error) {
|
||||
const errorDescription = requestUrl.searchParams.get("error_description") ?? error
|
||||
response.writeHead(400, { "content-type": "text/html" })
|
||||
response.end("<html><body><h1>Authorization failed</h1></body></html>")
|
||||
server.close()
|
||||
reject(new Error(`OAuth authorization error: ${errorDescription}`))
|
||||
return
|
||||
}
|
||||
|
||||
if (!code || !state) {
|
||||
response.writeHead(400, { "content-type": "text/html" })
|
||||
response.end("<html><body><h1>Missing code or state</h1></body></html>")
|
||||
server.close()
|
||||
reject(new Error("OAuth callback missing code or state parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
response.writeHead(200, { "content-type": "text/html" })
|
||||
response.end("<html><body><h1>Authorization successful. You can close this tab.</h1></body></html>")
|
||||
server.close()
|
||||
resolve({ code, state })
|
||||
})
|
||||
|
||||
timeoutId = setTimeout(() => {
|
||||
server.close()
|
||||
reject(new Error("OAuth callback timed out after 5 minutes"))
|
||||
}, CALLBACK_TIMEOUT_MS)
|
||||
|
||||
server.listen(port, "127.0.0.1")
|
||||
server.on("error", (err) => {
|
||||
clearTimeout(timeoutId)
|
||||
reject(err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
function openBrowser(url: string): void {
|
||||
const platform = process.platform
|
||||
let cmd: string
|
||||
let args: string[]
|
||||
|
||||
if (platform === "darwin") {
|
||||
cmd = "open"
|
||||
args = [url]
|
||||
} else if (platform === "win32") {
|
||||
cmd = "explorer"
|
||||
args = [url]
|
||||
} else {
|
||||
cmd = "xdg-open"
|
||||
args = [url]
|
||||
}
|
||||
|
||||
try {
|
||||
const child = spawn(cmd, args, { stdio: "ignore", detached: true })
|
||||
child.on("error", () => {})
|
||||
child.unref()
|
||||
} catch {
|
||||
// Browser open failed — user must navigate manually
|
||||
}
|
||||
}
|
||||
|
||||
export class McpOAuthProvider {
|
||||
private readonly serverUrl: string
|
||||
private readonly configClientId: string | undefined
|
||||
private readonly scopes: string[]
|
||||
private storedCodeVerifier: string | null = null
|
||||
private storedClientInfo: ClientCredentials | null = null
|
||||
private callbackPort: number | null = null
|
||||
|
||||
constructor(options: McpOAuthProviderOptions) {
|
||||
this.serverUrl = options.serverUrl
|
||||
this.configClientId = options.clientId
|
||||
this.scopes = options.scopes ?? []
|
||||
}
|
||||
|
||||
tokens(): OAuthTokenData | null {
|
||||
return loadToken(this.serverUrl, this.serverUrl)
|
||||
}
|
||||
|
||||
saveTokens(tokenData: OAuthTokenData): boolean {
|
||||
return saveToken(this.serverUrl, this.serverUrl, tokenData)
|
||||
}
|
||||
|
||||
clientInformation(): ClientCredentials | null {
|
||||
if (this.storedClientInfo) return this.storedClientInfo
|
||||
const tokenData = this.tokens()
|
||||
if (tokenData?.clientInfo) {
|
||||
this.storedClientInfo = tokenData.clientInfo
|
||||
return this.storedClientInfo
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
redirectUrl(): string {
|
||||
return `http://127.0.0.1:${this.callbackPort ?? 19877}/callback`
|
||||
}
|
||||
|
||||
saveCodeVerifier(verifier: string): void {
|
||||
this.storedCodeVerifier = verifier
|
||||
}
|
||||
|
||||
codeVerifier(): string | null {
|
||||
return this.storedCodeVerifier
|
||||
}
|
||||
|
||||
async redirectToAuthorization(metadata: OAuthServerMetadata): Promise<CallbackResult> {
|
||||
const verifier = generateCodeVerifier()
|
||||
this.saveCodeVerifier(verifier)
|
||||
const challenge = generateCodeChallenge(verifier)
|
||||
const state = randomBytes(16).toString("hex")
|
||||
|
||||
const clientInfo = this.clientInformation()
|
||||
if (!clientInfo) {
|
||||
throw new Error("No client information available. Run login() or register a client first.")
|
||||
}
|
||||
|
||||
if (this.callbackPort === null) {
|
||||
this.callbackPort = await findAvailablePort()
|
||||
}
|
||||
|
||||
const authUrl = buildAuthorizationUrl(metadata.authorizationEndpoint, {
|
||||
clientId: clientInfo.clientId,
|
||||
redirectUri: this.redirectUrl(),
|
||||
codeChallenge: challenge,
|
||||
state,
|
||||
scopes: this.scopes,
|
||||
resource: metadata.resource,
|
||||
})
|
||||
|
||||
const callbackPromise = startCallbackServer(this.callbackPort)
|
||||
openBrowser(authUrl)
|
||||
|
||||
const result = await callbackPromise
|
||||
if (result.state !== state) {
|
||||
throw new Error("OAuth state mismatch")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
async login(): Promise<OAuthTokenData> {
|
||||
const metadata = await discoverOAuthServerMetadata(this.serverUrl)
|
||||
|
||||
const clientRegistrationStorage: ClientRegistrationStorage = {
|
||||
getClientRegistration: () => this.storedClientInfo,
|
||||
setClientRegistration: (_serverIdentifier: string, credentials: ClientCredentials) => {
|
||||
this.storedClientInfo = credentials
|
||||
},
|
||||
}
|
||||
|
||||
const clientInfo = await getOrRegisterClient({
|
||||
registrationEndpoint: metadata.registrationEndpoint,
|
||||
serverIdentifier: this.serverUrl,
|
||||
clientName: "oh-my-opencode",
|
||||
redirectUris: [this.redirectUrl()],
|
||||
tokenEndpointAuthMethod: "none",
|
||||
clientId: this.configClientId,
|
||||
storage: clientRegistrationStorage,
|
||||
})
|
||||
|
||||
if (!clientInfo) {
|
||||
throw new Error("Failed to obtain client credentials. Provide a clientId or ensure the server supports DCR.")
|
||||
}
|
||||
|
||||
this.storedClientInfo = clientInfo
|
||||
|
||||
const { code } = await this.redirectToAuthorization(metadata)
|
||||
const verifier = this.codeVerifier()
|
||||
if (!verifier) {
|
||||
throw new Error("Code verifier not found")
|
||||
}
|
||||
|
||||
const tokenResponse = await fetch(metadata.tokenEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "content-type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "authorization_code",
|
||||
code,
|
||||
redirect_uri: this.redirectUrl(),
|
||||
client_id: clientInfo.clientId,
|
||||
code_verifier: verifier,
|
||||
...(metadata.resource ? { resource: metadata.resource } : {}),
|
||||
}).toString(),
|
||||
})
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
let errorDetail = `${tokenResponse.status}`
|
||||
try {
|
||||
const body = (await tokenResponse.json()) as Record<string, unknown>
|
||||
if (body.error) {
|
||||
errorDetail = `${tokenResponse.status} ${body.error}`
|
||||
if (body.error_description) {
|
||||
errorDetail += `: ${body.error_description}`
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Response body not JSON
|
||||
}
|
||||
throw new Error(`Token exchange failed: ${errorDetail}`)
|
||||
}
|
||||
|
||||
const tokenData = (await tokenResponse.json()) as Record<string, unknown>
|
||||
const accessToken = tokenData.access_token
|
||||
if (typeof accessToken !== "string") {
|
||||
throw new Error("Token response missing access_token")
|
||||
}
|
||||
|
||||
const oauthTokenData: OAuthTokenData = {
|
||||
accessToken,
|
||||
refreshToken: typeof tokenData.refresh_token === "string" ? tokenData.refresh_token : undefined,
|
||||
expiresAt:
|
||||
typeof tokenData.expires_in === "number" ? Math.floor(Date.now() / 1000) + tokenData.expires_in : undefined,
|
||||
clientInfo: {
|
||||
clientId: clientInfo.clientId,
|
||||
clientSecret: clientInfo.clientSecret,
|
||||
},
|
||||
}
|
||||
|
||||
this.saveTokens(oauthTokenData)
|
||||
return oauthTokenData
|
||||
}
|
||||
}
|
||||
|
||||
export { generateCodeVerifier, generateCodeChallenge, buildAuthorizationUrl, startCallbackServer }
|
||||
121
src/features/mcp-oauth/resource-indicator.test.ts
Normal file
121
src/features/mcp-oauth/resource-indicator.test.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import { addResourceToParams, getResourceIndicator } from "./resource-indicator"
|
||||
|
||||
describe("getResourceIndicator", () => {
|
||||
it("returns URL unchanged when already normalized", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("strips trailing slash", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("strips query parameters", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/v1?token=abc&debug=true"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/v1")
|
||||
})
|
||||
|
||||
it("strips fragment", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/v1#section"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/v1")
|
||||
})
|
||||
|
||||
it("strips query and trailing slash together", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/api/?key=val"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/api")
|
||||
})
|
||||
|
||||
it("preserves path segments", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/org/project/v2"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/org/project/v2")
|
||||
})
|
||||
|
||||
it("preserves port number", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com:8443/api/"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com:8443/api")
|
||||
})
|
||||
})
|
||||
|
||||
describe("addResourceToParams", () => {
|
||||
it("sets resource parameter on empty params", () => {
|
||||
// #given
|
||||
const params = new URLSearchParams()
|
||||
const resource = "https://mcp.example.com"
|
||||
|
||||
// #when
|
||||
addResourceToParams(params, resource)
|
||||
|
||||
// #then
|
||||
expect(params.get("resource")).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("adds resource alongside existing parameters", () => {
|
||||
// #given
|
||||
const params = new URLSearchParams({ grant_type: "authorization_code" })
|
||||
const resource = "https://mcp.example.com/v1"
|
||||
|
||||
// #when
|
||||
addResourceToParams(params, resource)
|
||||
|
||||
// #then
|
||||
expect(params.get("grant_type")).toBe("authorization_code")
|
||||
expect(params.get("resource")).toBe("https://mcp.example.com/v1")
|
||||
})
|
||||
|
||||
it("overwrites existing resource parameter", () => {
|
||||
// #given
|
||||
const params = new URLSearchParams({ resource: "https://old.example.com" })
|
||||
const resource = "https://new.example.com"
|
||||
|
||||
// #when
|
||||
addResourceToParams(params, resource)
|
||||
|
||||
// #then
|
||||
expect(params.get("resource")).toBe("https://new.example.com")
|
||||
expect(params.getAll("resource")).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
16
src/features/mcp-oauth/resource-indicator.ts
Normal file
16
src/features/mcp-oauth/resource-indicator.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
export function getResourceIndicator(url: string): string {
|
||||
const parsed = new URL(url)
|
||||
parsed.search = ""
|
||||
parsed.hash = ""
|
||||
|
||||
let normalized = parsed.toString()
|
||||
if (normalized.endsWith("/")) {
|
||||
normalized = normalized.slice(0, -1)
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
export function addResourceToParams(params: URLSearchParams, resource: string): void {
|
||||
params.set("resource", resource)
|
||||
}
|
||||
60
src/features/mcp-oauth/schema.test.ts
Normal file
60
src/features/mcp-oauth/schema.test.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
/// <reference types="bun-types" />
|
||||
import { describe, expect, test } from "bun:test"
|
||||
import { McpOauthSchema } from "./schema"
|
||||
|
||||
describe("McpOauthSchema", () => {
|
||||
test("parses empty oauth config", () => {
|
||||
//#given
|
||||
const input = {}
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.parse(input)
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
test("parses oauth config with clientId", () => {
|
||||
//#given
|
||||
const input = { clientId: "client-123" }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.parse(input)
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({ clientId: "client-123" })
|
||||
})
|
||||
|
||||
test("parses oauth config with scopes", () => {
|
||||
//#given
|
||||
const input = { scopes: ["openid", "profile"] }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.parse(input)
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({ scopes: ["openid", "profile"] })
|
||||
})
|
||||
|
||||
test("rejects non-string clientId", () => {
|
||||
//#given
|
||||
const input = { clientId: 123 }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.safeParse(input)
|
||||
|
||||
//#then
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
test("rejects non-string scopes", () => {
|
||||
//#given
|
||||
const input = { scopes: ["openid", 42] }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.safeParse(input)
|
||||
|
||||
//#then
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
})
|
||||
8
src/features/mcp-oauth/schema.ts
Normal file
8
src/features/mcp-oauth/schema.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { z } from "zod"
|
||||
|
||||
export const McpOauthSchema = z.object({
|
||||
clientId: z.string().optional(),
|
||||
scopes: z.array(z.string()).optional(),
|
||||
})
|
||||
|
||||
export type McpOauth = z.infer<typeof McpOauthSchema>
|
||||
223
src/features/mcp-oauth/step-up.test.ts
Normal file
223
src/features/mcp-oauth/step-up.test.ts
Normal file
@@ -0,0 +1,223 @@
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import { isStepUpRequired, mergeScopes, parseWwwAuthenticate } from "./step-up"
|
||||
|
||||
describe("parseWwwAuthenticate", () => {
|
||||
it("parses scope from simple Bearer header", () => {
|
||||
// #given
|
||||
const header = 'Bearer scope="read write"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["read", "write"] })
|
||||
})
|
||||
|
||||
it("parses scope with error fields", () => {
|
||||
// #given
|
||||
const header = 'Bearer error="insufficient_scope", scope="admin"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
requiredScopes: ["admin"],
|
||||
error: "insufficient_scope",
|
||||
})
|
||||
})
|
||||
|
||||
it("parses all fields including error_description", () => {
|
||||
// #given
|
||||
const header =
|
||||
'Bearer realm="example", error="insufficient_scope", error_description="Need admin access", scope="admin write"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
requiredScopes: ["admin", "write"],
|
||||
error: "insufficient_scope",
|
||||
errorDescription: "Need admin access",
|
||||
})
|
||||
})
|
||||
|
||||
it("returns null for non-Bearer scheme", () => {
|
||||
// #given
|
||||
const header = 'Basic realm="example"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null when no scope parameter present", () => {
|
||||
// #given
|
||||
const header = 'Bearer error="invalid_token"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null for empty scope value", () => {
|
||||
// #given
|
||||
const header = 'Bearer scope=""'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null for bare Bearer with no params", () => {
|
||||
// #given
|
||||
const header = "Bearer"
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("handles case-insensitive Bearer prefix", () => {
|
||||
// #given
|
||||
const header = 'bearer scope="read"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["read"] })
|
||||
})
|
||||
|
||||
it("parses single scope value", () => {
|
||||
// #given
|
||||
const header = 'Bearer scope="admin"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["admin"] })
|
||||
})
|
||||
})
|
||||
|
||||
describe("mergeScopes", () => {
|
||||
it("merges new scopes into existing", () => {
|
||||
// #given
|
||||
const existing = ["read", "write"]
|
||||
const required = ["admin", "write"]
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read", "write", "admin"])
|
||||
})
|
||||
|
||||
it("returns required when existing is empty", () => {
|
||||
// #given
|
||||
const existing: string[] = []
|
||||
const required = ["read", "write"]
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read", "write"])
|
||||
})
|
||||
|
||||
it("returns existing when required is empty", () => {
|
||||
// #given
|
||||
const existing = ["read"]
|
||||
const required: string[] = []
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read"])
|
||||
})
|
||||
|
||||
it("deduplicates identical scopes", () => {
|
||||
// #given
|
||||
const existing = ["read", "write"]
|
||||
const required = ["read", "write"]
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read", "write"])
|
||||
})
|
||||
})
|
||||
|
||||
describe("isStepUpRequired", () => {
|
||||
it("returns step-up info for 403 with WWW-Authenticate", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "www-authenticate": 'Bearer scope="admin"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["admin"] })
|
||||
})
|
||||
|
||||
it("returns null for non-403 status", () => {
|
||||
// #given
|
||||
const statusCode = 401
|
||||
const headers = { "www-authenticate": 'Bearer scope="admin"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null when no WWW-Authenticate header", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "content-type": "application/json" }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("handles capitalized WWW-Authenticate header", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "WWW-Authenticate": 'Bearer scope="read write"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["read", "write"] })
|
||||
})
|
||||
|
||||
it("returns null for 403 with unparseable WWW-Authenticate", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "www-authenticate": 'Basic realm="example"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
79
src/features/mcp-oauth/step-up.ts
Normal file
79
src/features/mcp-oauth/step-up.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
export interface StepUpInfo {
|
||||
requiredScopes: string[]
|
||||
error?: string
|
||||
errorDescription?: string
|
||||
}
|
||||
|
||||
export function parseWwwAuthenticate(header: string): StepUpInfo | null {
|
||||
const trimmed = header.trim()
|
||||
const lowerHeader = trimmed.toLowerCase()
|
||||
const bearerIndex = lowerHeader.indexOf("bearer")
|
||||
if (bearerIndex === -1) {
|
||||
return null
|
||||
}
|
||||
|
||||
const params = trimmed.slice(bearerIndex + "bearer".length).trim()
|
||||
if (params.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const scope = extractParam(params, "scope")
|
||||
if (scope === null) {
|
||||
return null
|
||||
}
|
||||
|
||||
const requiredScopes = scope
|
||||
.split(/\s+/)
|
||||
.filter((s) => s.length > 0)
|
||||
|
||||
if (requiredScopes.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const info: StepUpInfo = { requiredScopes }
|
||||
|
||||
const error = extractParam(params, "error")
|
||||
if (error !== null) {
|
||||
info.error = error
|
||||
}
|
||||
|
||||
const errorDescription = extractParam(params, "error_description")
|
||||
if (errorDescription !== null) {
|
||||
info.errorDescription = errorDescription
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
function extractParam(params: string, name: string): string | null {
|
||||
const quotedPattern = new RegExp(`${name}="([^"]*)"`)
|
||||
const quotedMatch = quotedPattern.exec(params)
|
||||
if (quotedMatch) {
|
||||
return quotedMatch[1]
|
||||
}
|
||||
|
||||
const unquotedPattern = new RegExp(`${name}=([^\\s,]+)`)
|
||||
const unquotedMatch = unquotedPattern.exec(params)
|
||||
return unquotedMatch?.[1] ?? null
|
||||
}
|
||||
|
||||
export function mergeScopes(existing: string[], required: string[]): string[] {
|
||||
const set = new Set(existing)
|
||||
for (const scope of required) {
|
||||
set.add(scope)
|
||||
}
|
||||
return [...set]
|
||||
}
|
||||
|
||||
export function isStepUpRequired(statusCode: number, headers: Record<string, string>): StepUpInfo | null {
|
||||
if (statusCode !== 403) {
|
||||
return null
|
||||
}
|
||||
|
||||
const wwwAuth = headers["www-authenticate"] ?? headers["WWW-Authenticate"]
|
||||
if (!wwwAuth) {
|
||||
return null
|
||||
}
|
||||
|
||||
return parseWwwAuthenticate(wwwAuth)
|
||||
}
|
||||
136
src/features/mcp-oauth/storage.test.ts
Normal file
136
src/features/mcp-oauth/storage.test.ts
Normal file
@@ -0,0 +1,136 @@
|
||||
import { describe, expect, test, beforeEach, afterEach } from "bun:test"
|
||||
import { existsSync, mkdirSync, rmSync, readFileSync, statSync, writeFileSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
import { tmpdir } from "node:os"
|
||||
import {
|
||||
deleteToken,
|
||||
getMcpOauthStoragePath,
|
||||
listAllTokens,
|
||||
listTokensByHost,
|
||||
loadToken,
|
||||
saveToken,
|
||||
} from "./storage"
|
||||
import type { OAuthTokenData } from "./storage"
|
||||
|
||||
describe("mcp-oauth storage", () => {
|
||||
const TEST_CONFIG_DIR = join(tmpdir(), "mcp-oauth-test-" + Date.now())
|
||||
let originalConfigDir: string | undefined
|
||||
|
||||
beforeEach(() => {
|
||||
originalConfigDir = process.env.OPENCODE_CONFIG_DIR
|
||||
process.env.OPENCODE_CONFIG_DIR = TEST_CONFIG_DIR
|
||||
if (!existsSync(TEST_CONFIG_DIR)) {
|
||||
mkdirSync(TEST_CONFIG_DIR, { recursive: true })
|
||||
}
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
if (originalConfigDir === undefined) {
|
||||
delete process.env.OPENCODE_CONFIG_DIR
|
||||
} else {
|
||||
process.env.OPENCODE_CONFIG_DIR = originalConfigDir
|
||||
}
|
||||
if (existsSync(TEST_CONFIG_DIR)) {
|
||||
rmSync(TEST_CONFIG_DIR, { recursive: true, force: true })
|
||||
}
|
||||
})
|
||||
|
||||
test("should save tokens with {host}/{resource} key and set 0600 permissions", () => {
|
||||
// #given
|
||||
const token: OAuthTokenData = {
|
||||
accessToken: "access-1",
|
||||
refreshToken: "refresh-1",
|
||||
expiresAt: 1710000000,
|
||||
clientInfo: { clientId: "client-1", clientSecret: "secret-1" },
|
||||
}
|
||||
|
||||
// #when
|
||||
const success = saveToken("https://example.com:443", "mcp/v1", token)
|
||||
const storagePath = getMcpOauthStoragePath()
|
||||
const parsed = JSON.parse(readFileSync(storagePath, "utf-8")) as Record<string, OAuthTokenData>
|
||||
const mode = statSync(storagePath).mode & 0o777
|
||||
|
||||
// #then
|
||||
expect(success).toBe(true)
|
||||
expect(Object.keys(parsed)).toEqual(["example.com/mcp/v1"])
|
||||
expect(parsed["example.com/mcp/v1"].accessToken).toBe("access-1")
|
||||
expect(mode).toBe(0o600)
|
||||
})
|
||||
|
||||
test("should load a saved token", () => {
|
||||
// #given
|
||||
const token: OAuthTokenData = { accessToken: "access-2", refreshToken: "refresh-2" }
|
||||
saveToken("api.example.com", "resource-a", token)
|
||||
|
||||
// #when
|
||||
const loaded = loadToken("api.example.com:8443", "resource-a")
|
||||
|
||||
// #then
|
||||
expect(loaded).toEqual(token)
|
||||
})
|
||||
|
||||
test("should delete a token", () => {
|
||||
// #given
|
||||
const token: OAuthTokenData = { accessToken: "access-3" }
|
||||
saveToken("api.example.com", "resource-b", token)
|
||||
|
||||
// #when
|
||||
const success = deleteToken("api.example.com", "resource-b")
|
||||
const loaded = loadToken("api.example.com", "resource-b")
|
||||
|
||||
// #then
|
||||
expect(success).toBe(true)
|
||||
expect(loaded).toBeNull()
|
||||
})
|
||||
|
||||
test("should list tokens by host", () => {
|
||||
// #given
|
||||
saveToken("api.example.com", "resource-a", { accessToken: "access-a" })
|
||||
saveToken("api.example.com", "resource-b", { accessToken: "access-b" })
|
||||
saveToken("other.example.com", "resource-c", { accessToken: "access-c" })
|
||||
|
||||
// #when
|
||||
const entries = listTokensByHost("api.example.com:5555")
|
||||
|
||||
// #then
|
||||
expect(Object.keys(entries).sort()).toEqual([
|
||||
"api.example.com/resource-a",
|
||||
"api.example.com/resource-b",
|
||||
])
|
||||
expect(entries["api.example.com/resource-a"].accessToken).toBe("access-a")
|
||||
})
|
||||
|
||||
test("should handle missing storage file", () => {
|
||||
// #given
|
||||
const storagePath = getMcpOauthStoragePath()
|
||||
if (existsSync(storagePath)) {
|
||||
rmSync(storagePath, { force: true })
|
||||
}
|
||||
|
||||
// #when
|
||||
const loaded = loadToken("api.example.com", "resource-a")
|
||||
const entries = listTokensByHost("api.example.com")
|
||||
|
||||
// #then
|
||||
expect(loaded).toBeNull()
|
||||
expect(entries).toEqual({})
|
||||
})
|
||||
|
||||
test("should handle invalid JSON", () => {
|
||||
// #given
|
||||
const storagePath = getMcpOauthStoragePath()
|
||||
const dir = join(storagePath, "..")
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true })
|
||||
}
|
||||
writeFileSync(storagePath, "{not-valid-json", "utf-8")
|
||||
|
||||
// #when
|
||||
const loaded = loadToken("api.example.com", "resource-a")
|
||||
const entries = listTokensByHost("api.example.com")
|
||||
|
||||
// #then
|
||||
expect(loaded).toBeNull()
|
||||
expect(entries).toEqual({})
|
||||
})
|
||||
})
|
||||
153
src/features/mcp-oauth/storage.ts
Normal file
153
src/features/mcp-oauth/storage.ts
Normal file
@@ -0,0 +1,153 @@
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, unlinkSync, writeFileSync } from "node:fs"
|
||||
import { dirname, join } from "node:path"
|
||||
import { getOpenCodeConfigDir } from "../../shared"
|
||||
|
||||
export interface OAuthTokenData {
|
||||
accessToken: string
|
||||
refreshToken?: string
|
||||
expiresAt?: number
|
||||
clientInfo?: {
|
||||
clientId: string
|
||||
clientSecret?: string
|
||||
}
|
||||
}
|
||||
|
||||
type TokenStore = Record<string, OAuthTokenData>
|
||||
|
||||
const STORAGE_FILE_NAME = "mcp-oauth.json"
|
||||
|
||||
export function getMcpOauthStoragePath(): string {
|
||||
return join(getOpenCodeConfigDir({ binary: "opencode" }), STORAGE_FILE_NAME)
|
||||
}
|
||||
|
||||
function normalizeHost(serverHost: string): string {
|
||||
let host = serverHost.trim()
|
||||
if (!host) return host
|
||||
|
||||
if (host.includes("://")) {
|
||||
try {
|
||||
host = new URL(host).hostname
|
||||
} catch {
|
||||
host = host.split("/")[0]
|
||||
}
|
||||
} else {
|
||||
host = host.split("/")[0]
|
||||
}
|
||||
|
||||
if (host.startsWith("[")) {
|
||||
const closing = host.indexOf("]")
|
||||
if (closing !== -1) {
|
||||
host = host.slice(0, closing + 1)
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
if (host.includes(":")) {
|
||||
host = host.split(":")[0]
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
function normalizeResource(resource: string): string {
|
||||
return resource.replace(/^\/+/, "")
|
||||
}
|
||||
|
||||
function buildKey(serverHost: string, resource: string): string {
|
||||
const host = normalizeHost(serverHost)
|
||||
const normalizedResource = normalizeResource(resource)
|
||||
return `${host}/${normalizedResource}`
|
||||
}
|
||||
|
||||
function readStore(): TokenStore | null {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
if (!existsSync(filePath)) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(filePath, "utf-8")
|
||||
return JSON.parse(content) as TokenStore
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function writeStore(store: TokenStore): boolean {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
|
||||
try {
|
||||
const dir = dirname(filePath)
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true })
|
||||
}
|
||||
|
||||
writeFileSync(filePath, JSON.stringify(store, null, 2), { encoding: "utf-8", mode: 0o600 })
|
||||
chmodSync(filePath, 0o600)
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export function loadToken(serverHost: string, resource: string): OAuthTokenData | null {
|
||||
const store = readStore()
|
||||
if (!store) return null
|
||||
|
||||
const key = buildKey(serverHost, resource)
|
||||
return store[key] ?? null
|
||||
}
|
||||
|
||||
export function saveToken(serverHost: string, resource: string, token: OAuthTokenData): boolean {
|
||||
const store = readStore() ?? {}
|
||||
const key = buildKey(serverHost, resource)
|
||||
store[key] = token
|
||||
return writeStore(store)
|
||||
}
|
||||
|
||||
export function deleteToken(serverHost: string, resource: string): boolean {
|
||||
const store = readStore()
|
||||
if (!store) return true
|
||||
|
||||
const key = buildKey(serverHost, resource)
|
||||
if (!(key in store)) {
|
||||
return true
|
||||
}
|
||||
|
||||
delete store[key]
|
||||
|
||||
if (Object.keys(store).length === 0) {
|
||||
try {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
if (existsSync(filePath)) {
|
||||
unlinkSync(filePath)
|
||||
}
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return writeStore(store)
|
||||
}
|
||||
|
||||
export function listTokensByHost(serverHost: string): TokenStore {
|
||||
const store = readStore()
|
||||
if (!store) return {}
|
||||
|
||||
const host = normalizeHost(serverHost)
|
||||
const prefix = `${host}/`
|
||||
const result: TokenStore = {}
|
||||
|
||||
for (const [key, value] of Object.entries(store)) {
|
||||
if (key.startsWith(prefix)) {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
export function listAllTokens(): TokenStore {
|
||||
return readStore() ?? {}
|
||||
}
|
||||
@@ -3,8 +3,6 @@ import { SkillMcpManager } from "./manager"
|
||||
import type { SkillMcpClientInfo, SkillMcpServerContext } from "./types"
|
||||
import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types"
|
||||
|
||||
|
||||
|
||||
// Mock the MCP SDK transports to avoid network calls
|
||||
const mockHttpConnect = mock(() => Promise.reject(new Error("Mocked HTTP connection failure")))
|
||||
const mockHttpClose = mock(() => Promise.resolve())
|
||||
@@ -24,6 +22,21 @@ mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
const mockTokens = mock(() => null as { accessToken: string; refreshToken?: string; expiresAt?: number } | null)
|
||||
const mockLogin = mock(() => Promise.resolve({ accessToken: "new-token" }))
|
||||
|
||||
mock.module("../mcp-oauth/provider", () => ({
|
||||
McpOAuthProvider: class MockMcpOAuthProvider {
|
||||
constructor(public options: { serverUrl: string; clientId?: string; scopes?: string[] }) {}
|
||||
tokens() {
|
||||
return mockTokens()
|
||||
}
|
||||
async login() {
|
||||
return mockLogin()
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -518,7 +531,6 @@ describe("SkillMcpManager", () => {
|
||||
skillName: "retry-skill",
|
||||
}
|
||||
|
||||
// Mock client that fails first time with "Not connected", then succeeds
|
||||
let callCount = 0
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
@@ -531,7 +543,6 @@ describe("SkillMcpManager", () => {
|
||||
close: mock(() => Promise.resolve()),
|
||||
}
|
||||
|
||||
// Spy on getOrCreateClientWithRetry to inject mock client
|
||||
const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry")
|
||||
getOrCreateSpy.mockResolvedValue(mockClient)
|
||||
|
||||
@@ -539,9 +550,9 @@ describe("SkillMcpManager", () => {
|
||||
const result = await manager.callTool(info, context, "test-tool", {})
|
||||
|
||||
// #then
|
||||
expect(callCount).toBe(2) // First call fails, second succeeds
|
||||
expect(callCount).toBe(2)
|
||||
expect(result).toEqual([{ type: "text", text: "success" }])
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(2) // Called twice due to retry
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it("should fail after 3 retry attempts", async () => {
|
||||
@@ -558,7 +569,6 @@ describe("SkillMcpManager", () => {
|
||||
skillName: "fail-skill",
|
||||
}
|
||||
|
||||
// Mock client that always fails with "Not connected"
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
throw new Error("Not connected")
|
||||
@@ -573,7 +583,7 @@ describe("SkillMcpManager", () => {
|
||||
await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow(
|
||||
/Failed after 3 reconnection attempts/
|
||||
)
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(3) // Initial + 2 retries
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it("should not retry on non-connection errors", async () => {
|
||||
@@ -590,7 +600,6 @@ describe("SkillMcpManager", () => {
|
||||
skillName: "error-skill",
|
||||
}
|
||||
|
||||
// Mock client that fails with non-connection error
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
throw new Error("Tool not found")
|
||||
@@ -605,7 +614,194 @@ describe("SkillMcpManager", () => {
|
||||
await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow(
|
||||
"Tool not found"
|
||||
)
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(1) // No retry
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe("OAuth integration", () => {
|
||||
beforeEach(() => {
|
||||
mockTokens.mockClear()
|
||||
mockLogin.mockClear()
|
||||
})
|
||||
|
||||
it("injects Authorization header when oauth config has stored tokens", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "oauth-server",
|
||||
skillName: "oauth-skill",
|
||||
sessionID: "session-oauth-1",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
scopes: ["read", "write"],
|
||||
},
|
||||
}
|
||||
mockTokens.mockReturnValue({ accessToken: "stored-access-token" })
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.Authorization).toBe("Bearer stored-access-token")
|
||||
})
|
||||
|
||||
it("does not inject Authorization header when no stored tokens exist and login fails", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "oauth-no-token",
|
||||
skillName: "oauth-skill",
|
||||
sessionID: "session-oauth-2",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
},
|
||||
}
|
||||
mockTokens.mockReturnValue(null)
|
||||
mockLogin.mockRejectedValue(new Error("Login failed"))
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.Authorization).toBeUndefined()
|
||||
})
|
||||
|
||||
it("preserves existing static headers alongside OAuth token", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "oauth-with-headers",
|
||||
skillName: "oauth-skill",
|
||||
sessionID: "session-oauth-3",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
headers: {
|
||||
"X-Custom": "custom-value",
|
||||
},
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
},
|
||||
}
|
||||
mockTokens.mockReturnValue({ accessToken: "oauth-token" })
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.["X-Custom"]).toBe("custom-value")
|
||||
expect(headers?.Authorization).toBe("Bearer oauth-token")
|
||||
})
|
||||
|
||||
it("does not create auth provider when oauth config is absent", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "no-oauth-server",
|
||||
skillName: "test-skill",
|
||||
sessionID: "session-no-oauth",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
headers: {
|
||||
Authorization: "Bearer static-token",
|
||||
},
|
||||
}
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.Authorization).toBe("Bearer static-token")
|
||||
expect(mockTokens).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it("handles step-up auth by triggering re-login on 403 with scope", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "stepup-server",
|
||||
skillName: "stepup-skill",
|
||||
sessionID: "session-stepup-1",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
scopes: ["read"],
|
||||
},
|
||||
}
|
||||
const context: SkillMcpServerContext = {
|
||||
config,
|
||||
skillName: "stepup-skill",
|
||||
}
|
||||
|
||||
mockTokens.mockReturnValue({ accessToken: "initial-token" })
|
||||
mockLogin.mockResolvedValue({ accessToken: "upgraded-token" })
|
||||
|
||||
let callCount = 0
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
throw new Error('403 WWW-Authenticate: Bearer scope="admin write"')
|
||||
}
|
||||
return { content: [{ type: "text", text: "success" }] }
|
||||
}),
|
||||
close: mock(() => Promise.resolve()),
|
||||
}
|
||||
|
||||
const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry")
|
||||
getOrCreateSpy.mockResolvedValue(mockClient)
|
||||
|
||||
// #when
|
||||
const result = await manager.callTool(info, context, "test-tool", {})
|
||||
|
||||
// #then
|
||||
expect(result).toEqual([{ type: "text", text: "success" }])
|
||||
expect(mockLogin).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it("does not attempt step-up when oauth config is absent", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "no-stepup-server",
|
||||
skillName: "no-stepup-skill",
|
||||
sessionID: "session-no-stepup",
|
||||
}
|
||||
const context: SkillMcpServerContext = {
|
||||
config: {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
},
|
||||
skillName: "no-stepup-skill",
|
||||
}
|
||||
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
throw new Error('403 WWW-Authenticate: Bearer scope="admin"')
|
||||
}),
|
||||
close: mock(() => Promise.resolve()),
|
||||
}
|
||||
|
||||
const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry")
|
||||
getOrCreateSpy.mockResolvedValue(mockClient)
|
||||
|
||||
// #when / #then
|
||||
await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow(/403/)
|
||||
expect(mockLogin).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,6 +4,8 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/
|
||||
import type { Tool, Resource, Prompt } from "@modelcontextprotocol/sdk/types.js"
|
||||
import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types"
|
||||
import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander"
|
||||
import { McpOAuthProvider } from "../mcp-oauth/provider"
|
||||
import { isStepUpRequired, mergeScopes } from "../mcp-oauth/step-up"
|
||||
import { createCleanMcpEnvironment } from "./env-cleaner"
|
||||
import type { SkillMcpClientInfo, SkillMcpServerContext } from "./types"
|
||||
|
||||
@@ -60,6 +62,7 @@ function getConnectionType(config: ClaudeCodeMcpServer): ConnectionType | null {
|
||||
export class SkillMcpManager {
|
||||
private clients: Map<string, ManagedClient> = new Map()
|
||||
private pendingConnections: Map<string, Promise<Client>> = new Map()
|
||||
private authProviders: Map<string, McpOAuthProvider> = new Map()
|
||||
private cleanupRegistered = false
|
||||
private cleanupInterval: ReturnType<typeof setInterval> | null = null
|
||||
private readonly IDLE_TIMEOUT = 5 * 60 * 1000
|
||||
@@ -68,6 +71,28 @@ export class SkillMcpManager {
|
||||
return `${info.sessionID}:${info.skillName}:${info.serverName}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create an McpOAuthProvider for a given server URL + oauth config.
|
||||
* Providers are cached by server URL to reuse tokens across reconnections.
|
||||
*/
|
||||
private getOrCreateAuthProvider(
|
||||
serverUrl: string,
|
||||
oauth: NonNullable<ClaudeCodeMcpServer["oauth"]>
|
||||
): McpOAuthProvider {
|
||||
const existing = this.authProviders.get(serverUrl)
|
||||
if (existing) {
|
||||
return existing
|
||||
}
|
||||
|
||||
const provider = new McpOAuthProvider({
|
||||
serverUrl,
|
||||
clientId: oauth.clientId,
|
||||
scopes: oauth.scopes,
|
||||
})
|
||||
this.authProviders.set(serverUrl, provider)
|
||||
return provider
|
||||
}
|
||||
|
||||
private registerProcessCleanup(): void {
|
||||
if (this.cleanupRegistered) return
|
||||
this.cleanupRegistered = true
|
||||
@@ -204,7 +229,30 @@ export class SkillMcpManager {
|
||||
// Build request init with headers if provided
|
||||
const requestInit: RequestInit = {}
|
||||
if (config.headers && Object.keys(config.headers).length > 0) {
|
||||
requestInit.headers = config.headers
|
||||
requestInit.headers = { ...config.headers }
|
||||
}
|
||||
|
||||
let authProvider: McpOAuthProvider | undefined
|
||||
if (config.oauth) {
|
||||
authProvider = this.getOrCreateAuthProvider(config.url, config.oauth)
|
||||
let tokenData = authProvider.tokens()
|
||||
|
||||
const isExpired = tokenData?.expiresAt != null && tokenData.expiresAt < Math.floor(Date.now() / 1000)
|
||||
if (!tokenData || isExpired) {
|
||||
try {
|
||||
tokenData = await authProvider.login()
|
||||
} catch {
|
||||
// Login failed — proceed without auth header
|
||||
}
|
||||
}
|
||||
|
||||
if (tokenData) {
|
||||
const existingHeaders = (requestInit.headers ?? {}) as Record<string, string>
|
||||
requestInit.headers = {
|
||||
...existingHeaders,
|
||||
Authorization: `Bearer ${tokenData.accessToken}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(url, {
|
||||
@@ -460,6 +508,12 @@ export class SkillMcpManager {
|
||||
lastError = error instanceof Error ? error : new Error(String(error))
|
||||
const errorMessage = lastError.message.toLowerCase()
|
||||
|
||||
const stepUpHandled = await this.handleStepUpIfNeeded(lastError, config)
|
||||
if (stepUpHandled) {
|
||||
await this.forceReconnect(info)
|
||||
continue
|
||||
}
|
||||
|
||||
if (!errorMessage.includes("not connected")) {
|
||||
throw lastError
|
||||
}
|
||||
@@ -470,23 +524,66 @@ export class SkillMcpManager {
|
||||
)
|
||||
}
|
||||
|
||||
const key = this.getClientKey(info)
|
||||
const existing = this.clients.get(key)
|
||||
if (existing) {
|
||||
this.clients.delete(key)
|
||||
try {
|
||||
await existing.client.close()
|
||||
} catch { /* process may already be terminated */ }
|
||||
try {
|
||||
await existing.transport.close()
|
||||
} catch { /* transport may already be terminated */ }
|
||||
}
|
||||
await this.forceReconnect(info)
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError || new Error("Operation failed with unknown error")
|
||||
}
|
||||
|
||||
private async handleStepUpIfNeeded(
|
||||
error: Error,
|
||||
config: ClaudeCodeMcpServer
|
||||
): Promise<boolean> {
|
||||
if (!config.oauth || !config.url) {
|
||||
return false
|
||||
}
|
||||
|
||||
const statusMatch = /\b403\b/.exec(error.message)
|
||||
if (!statusMatch) {
|
||||
return false
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {}
|
||||
const wwwAuthMatch = /WWW-Authenticate:\s*(.+)/i.exec(error.message)
|
||||
if (wwwAuthMatch?.[1]) {
|
||||
headers["www-authenticate"] = wwwAuthMatch[1]
|
||||
}
|
||||
|
||||
const stepUp = isStepUpRequired(403, headers)
|
||||
if (!stepUp) {
|
||||
return false
|
||||
}
|
||||
|
||||
const currentScopes = config.oauth.scopes ?? []
|
||||
const merged = mergeScopes(currentScopes, stepUp.requiredScopes)
|
||||
config.oauth.scopes = merged
|
||||
|
||||
this.authProviders.delete(config.url)
|
||||
const provider = this.getOrCreateAuthProvider(config.url, config.oauth)
|
||||
|
||||
try {
|
||||
await provider.login()
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private async forceReconnect(info: SkillMcpClientInfo): Promise<void> {
|
||||
const key = this.getClientKey(info)
|
||||
const existing = this.clients.get(key)
|
||||
if (existing) {
|
||||
this.clients.delete(key)
|
||||
try {
|
||||
await existing.client.close()
|
||||
} catch { /* process may already be terminated */ }
|
||||
try {
|
||||
await existing.transport.close()
|
||||
} catch { /* transport may already be terminated */ }
|
||||
}
|
||||
}
|
||||
|
||||
private async getOrCreateClientWithRetry(
|
||||
info: SkillMcpClientInfo,
|
||||
config: ClaudeCodeMcpServer
|
||||
|
||||
@@ -2,6 +2,7 @@ import { describe, test, expect, mock, beforeEach } from 'bun:test'
|
||||
import type { TmuxConfig } from '../../config/schema'
|
||||
import type { WindowState, PaneAction } from './types'
|
||||
import type { ActionResult, ExecuteContext } from './action-executor'
|
||||
import type { TmuxUtilDeps } from './manager'
|
||||
|
||||
type ExecuteActionsResult = {
|
||||
success: boolean
|
||||
@@ -33,6 +34,11 @@ const mockExecuteAction = mock<(
|
||||
const mockIsInsideTmux = mock<() => boolean>(() => true)
|
||||
const mockGetCurrentPaneId = mock<() => string | undefined>(() => '%0')
|
||||
|
||||
const mockTmuxDeps: TmuxUtilDeps = {
|
||||
isInsideTmux: mockIsInsideTmux,
|
||||
getCurrentPaneId: mockGetCurrentPaneId,
|
||||
}
|
||||
|
||||
mock.module('./pane-state-querier', () => ({
|
||||
queryWindowState: mockQueryWindowState,
|
||||
paneExists: mockPaneExists,
|
||||
@@ -51,15 +57,19 @@ mock.module('./action-executor', () => ({
|
||||
executeAction: mockExecuteAction,
|
||||
}))
|
||||
|
||||
mock.module('../../shared/tmux', () => ({
|
||||
isInsideTmux: mockIsInsideTmux,
|
||||
getCurrentPaneId: mockGetCurrentPaneId,
|
||||
POLL_INTERVAL_BACKGROUND_MS: 2000,
|
||||
SESSION_TIMEOUT_MS: 600000,
|
||||
SESSION_MISSING_GRACE_MS: 6000,
|
||||
SESSION_READY_POLL_INTERVAL_MS: 100,
|
||||
SESSION_READY_TIMEOUT_MS: 500,
|
||||
}))
|
||||
mock.module('../../shared/tmux', () => {
|
||||
const { isInsideTmux, getCurrentPaneId } = require('../../shared/tmux/tmux-utils')
|
||||
const { POLL_INTERVAL_BACKGROUND_MS, SESSION_TIMEOUT_MS, SESSION_MISSING_GRACE_MS } = require('../../shared/tmux/constants')
|
||||
return {
|
||||
isInsideTmux,
|
||||
getCurrentPaneId,
|
||||
POLL_INTERVAL_BACKGROUND_MS,
|
||||
SESSION_TIMEOUT_MS,
|
||||
SESSION_MISSING_GRACE_MS,
|
||||
SESSION_READY_POLL_INTERVAL_MS: 100,
|
||||
SESSION_READY_TIMEOUT_MS: 500,
|
||||
}
|
||||
})
|
||||
|
||||
const trackedSessions = new Set<string>()
|
||||
|
||||
@@ -148,7 +158,7 @@ describe('TmuxSessionManager', () => {
|
||||
}
|
||||
|
||||
//#when
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
//#then
|
||||
expect(manager).toBeDefined()
|
||||
@@ -168,7 +178,7 @@ describe('TmuxSessionManager', () => {
|
||||
}
|
||||
|
||||
//#when
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
//#then
|
||||
expect(manager).toBeDefined()
|
||||
@@ -188,7 +198,7 @@ describe('TmuxSessionManager', () => {
|
||||
}
|
||||
|
||||
//#when
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
//#then
|
||||
expect(manager).toBeDefined()
|
||||
@@ -210,7 +220,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
const event = createSessionCreatedEvent(
|
||||
'ses_child',
|
||||
'ses_parent',
|
||||
@@ -271,7 +281,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
//#when - first agent
|
||||
await manager.onSessionCreated(
|
||||
@@ -305,7 +315,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
const event = createSessionCreatedEvent('ses_root', undefined, 'Root Session')
|
||||
|
||||
//#when
|
||||
@@ -327,7 +337,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
const event = createSessionCreatedEvent(
|
||||
'ses_child',
|
||||
'ses_parent',
|
||||
@@ -353,7 +363,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
const event = {
|
||||
type: 'session.deleted',
|
||||
properties: {
|
||||
@@ -398,7 +408,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 120,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
//#when
|
||||
await manager.onSessionCreated(
|
||||
@@ -450,7 +460,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
await manager.onSessionCreated(
|
||||
createSessionCreatedEvent(
|
||||
@@ -487,7 +497,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
//#when
|
||||
await manager.onSessionDeleted({ sessionID: 'ses_unknown' })
|
||||
@@ -521,7 +531,7 @@ describe('TmuxSessionManager', () => {
|
||||
main_pane_min_width: 80,
|
||||
agent_pane_min_width: 40,
|
||||
}
|
||||
const manager = new TmuxSessionManager(ctx, config)
|
||||
const manager = new TmuxSessionManager(ctx, config, mockTmuxDeps)
|
||||
|
||||
await manager.onSessionCreated(
|
||||
createSessionCreatedEvent('ses_1', 'ses_parent', 'Task 1')
|
||||
|
||||
@@ -2,8 +2,8 @@ import type { PluginInput } from "@opencode-ai/plugin"
|
||||
import type { TmuxConfig } from "../../config/schema"
|
||||
import type { TrackedSession, CapacityConfig } from "./types"
|
||||
import {
|
||||
isInsideTmux,
|
||||
getCurrentPaneId,
|
||||
isInsideTmux as defaultIsInsideTmux,
|
||||
getCurrentPaneId as defaultGetCurrentPaneId,
|
||||
POLL_INTERVAL_BACKGROUND_MS,
|
||||
SESSION_MISSING_GRACE_MS,
|
||||
SESSION_READY_POLL_INTERVAL_MS,
|
||||
@@ -21,6 +21,16 @@ interface SessionCreatedEvent {
|
||||
properties?: { info?: { id?: string; parentID?: string; title?: string } }
|
||||
}
|
||||
|
||||
export interface TmuxUtilDeps {
|
||||
isInsideTmux: () => boolean
|
||||
getCurrentPaneId: () => string | undefined
|
||||
}
|
||||
|
||||
const defaultTmuxDeps: TmuxUtilDeps = {
|
||||
isInsideTmux: defaultIsInsideTmux,
|
||||
getCurrentPaneId: defaultGetCurrentPaneId,
|
||||
}
|
||||
|
||||
const SESSION_TIMEOUT_MS = 10 * 60 * 1000
|
||||
|
||||
/**
|
||||
@@ -43,13 +53,15 @@ export class TmuxSessionManager {
|
||||
private sessions = new Map<string, TrackedSession>()
|
||||
private pendingSessions = new Set<string>()
|
||||
private pollInterval?: ReturnType<typeof setInterval>
|
||||
private deps: TmuxUtilDeps
|
||||
|
||||
constructor(ctx: PluginInput, tmuxConfig: TmuxConfig) {
|
||||
constructor(ctx: PluginInput, tmuxConfig: TmuxConfig, deps: TmuxUtilDeps = defaultTmuxDeps) {
|
||||
this.client = ctx.client
|
||||
this.tmuxConfig = tmuxConfig
|
||||
this.deps = deps
|
||||
const defaultPort = process.env.OPENCODE_PORT ?? "4096"
|
||||
this.serverUrl = ctx.serverUrl?.toString() ?? `http://localhost:${defaultPort}`
|
||||
this.sourcePaneId = getCurrentPaneId()
|
||||
this.sourcePaneId = deps.getCurrentPaneId()
|
||||
|
||||
log("[tmux-session-manager] initialized", {
|
||||
configEnabled: this.tmuxConfig.enabled,
|
||||
@@ -60,7 +72,7 @@ export class TmuxSessionManager {
|
||||
}
|
||||
|
||||
private isEnabled(): boolean {
|
||||
return this.tmuxConfig.enabled && isInsideTmux()
|
||||
return this.tmuxConfig.enabled && this.deps.isInsideTmux()
|
||||
}
|
||||
|
||||
private getCapacityConfig(): CapacityConfig {
|
||||
@@ -113,7 +125,7 @@ export class TmuxSessionManager {
|
||||
log("[tmux-session-manager] onSessionCreated called", {
|
||||
enabled,
|
||||
tmuxConfigEnabled: this.tmuxConfig.enabled,
|
||||
isInsideTmux: isInsideTmux(),
|
||||
isInsideTmux: this.deps.isInsideTmux(),
|
||||
eventType: event.type,
|
||||
infoId: event.properties?.info?.id,
|
||||
infoParentID: event.properties?.info?.parentID,
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
} from "./storage";
|
||||
import { OMO_SESSION_PREFIX, buildSessionReminderMessage } from "./constants";
|
||||
import type { InteractiveBashSessionState } from "./types";
|
||||
import { subagentSessions } from "../../features/claude-code-session-state";
|
||||
|
||||
interface ToolExecuteInput {
|
||||
tool: string;
|
||||
@@ -146,7 +147,7 @@ function findSubcommand(tokens: string[]): string {
|
||||
return ""
|
||||
}
|
||||
|
||||
export function createInteractiveBashSessionHook(_ctx: PluginInput) {
|
||||
export function createInteractiveBashSessionHook(ctx: PluginInput) {
|
||||
const sessionStates = new Map<string, InteractiveBashSessionState>();
|
||||
|
||||
function getOrCreateState(sessionID: string): InteractiveBashSessionState {
|
||||
@@ -178,6 +179,10 @@ export function createInteractiveBashSessionHook(_ctx: PluginInput) {
|
||||
await proc.exited;
|
||||
} catch {}
|
||||
}
|
||||
|
||||
for (const sessionId of subagentSessions) {
|
||||
ctx.client.session.abort({ path: { id: sessionId } }).catch(() => {})
|
||||
}
|
||||
}
|
||||
|
||||
const toolExecuteAfter = async (
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
clearBoulderState,
|
||||
} from "../../features/boulder-state"
|
||||
import { log } from "../../shared/logger"
|
||||
import { updateSessionAgent } from "../../features/claude-code-session-state"
|
||||
import { getSessionAgent, updateSessionAgent } from "../../features/claude-code-session-state"
|
||||
|
||||
export const HOOK_NAME = "start-work"
|
||||
|
||||
@@ -71,7 +71,10 @@ export function createStartWorkHook(ctx: PluginInput) {
|
||||
sessionID: input.sessionID,
|
||||
})
|
||||
|
||||
updateSessionAgent(input.sessionID, "atlas")
|
||||
const currentAgent = getSessionAgent(input.sessionID)
|
||||
if (!currentAgent) {
|
||||
updateSessionAgent(input.sessionID, "atlas")
|
||||
}
|
||||
|
||||
const existingState = readBoulderState(ctx.directory)
|
||||
const sessionId = input.sessionID
|
||||
|
||||
@@ -205,7 +205,11 @@ export function createTodoContinuationEnforcer(
|
||||
const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null
|
||||
agentName = agentName ?? prevMessage?.agent
|
||||
model = model ?? (prevMessage?.model?.providerID && prevMessage?.model?.modelID
|
||||
? { providerID: prevMessage.model.providerID, modelID: prevMessage.model.modelID }
|
||||
? {
|
||||
providerID: prevMessage.model.providerID,
|
||||
modelID: prevMessage.model.modelID,
|
||||
...(prevMessage.model.variant ? { variant: prevMessage.model.variant } : {})
|
||||
}
|
||||
: undefined)
|
||||
tools = tools ?? prevMessage?.tools
|
||||
}
|
||||
|
||||
@@ -144,10 +144,6 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
isOpenCodeVersionAtLeast(OPENCODE_NATIVE_AGENTS_INJECTION_VERSION);
|
||||
|
||||
if (hasNativeSupport) {
|
||||
console.warn(
|
||||
`[oh-my-opencode] directory-agents-injector hook auto-disabled: ` +
|
||||
`OpenCode ${currentVersion} has native AGENTS.md support (>= ${OPENCODE_NATIVE_AGENTS_INJECTION_VERSION})`
|
||||
);
|
||||
log("directory-agents-injector auto-disabled due to native OpenCode support", {
|
||||
currentVersion,
|
||||
nativeVersion: OPENCODE_NATIVE_AGENTS_INJECTION_VERSION,
|
||||
@@ -268,6 +264,11 @@ const OhMyOpenCodePlugin: Plugin = async (ctx) => {
|
||||
});
|
||||
log("[index] onSubagentSessionCreated callback completed");
|
||||
},
|
||||
onShutdown: () => {
|
||||
tmuxSessionManager.cleanup().catch((error) => {
|
||||
log("[index] tmux cleanup error during shutdown:", error)
|
||||
})
|
||||
},
|
||||
});
|
||||
|
||||
const atlasHook = isHookEnabled("atlas")
|
||||
|
||||
@@ -1,56 +1,56 @@
|
||||
import { describe, test, expect, mock, beforeEach } from "bun:test"
|
||||
import { describe, test, expect, spyOn, beforeEach, afterEach } from "bun:test"
|
||||
import { resolveCategoryConfig, createConfigHandler } from "./config-handler"
|
||||
import type { CategoryConfig } from "../config/schema"
|
||||
import type { OhMyOpenCodeConfig } from "../config"
|
||||
|
||||
mock.module("../agents", () => ({
|
||||
createBuiltinAgents: async () => ({
|
||||
import * as agents from "../agents"
|
||||
import * as sisyphusJunior from "../agents/sisyphus-junior"
|
||||
import * as commandLoader from "../features/claude-code-command-loader"
|
||||
import * as builtinCommands from "../features/builtin-commands"
|
||||
import * as skillLoader from "../features/opencode-skill-loader"
|
||||
import * as agentLoader from "../features/claude-code-agent-loader"
|
||||
import * as mcpLoader from "../features/claude-code-mcp-loader"
|
||||
import * as pluginLoader from "../features/claude-code-plugin-loader"
|
||||
import * as mcpModule from "../mcp"
|
||||
import * as shared from "../shared"
|
||||
import * as configDir from "../shared/opencode-config-dir"
|
||||
import * as permissionCompat from "../shared/permission-compat"
|
||||
import * as modelResolver from "../shared/model-resolver"
|
||||
|
||||
beforeEach(() => {
|
||||
spyOn(agents, "createBuiltinAgents" as any).mockResolvedValue({
|
||||
sisyphus: { name: "sisyphus", prompt: "test", mode: "primary" },
|
||||
oracle: { name: "oracle", prompt: "test", mode: "subagent" },
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
mock.module("../agents/sisyphus-junior", () => ({
|
||||
createSisyphusJuniorAgentWithOverrides: () => ({
|
||||
spyOn(sisyphusJunior, "createSisyphusJuniorAgentWithOverrides" as any).mockReturnValue({
|
||||
name: "sisyphus-junior",
|
||||
prompt: "test",
|
||||
mode: "subagent",
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
mock.module("../features/claude-code-command-loader", () => ({
|
||||
loadUserCommands: async () => ({}),
|
||||
loadProjectCommands: async () => ({}),
|
||||
loadOpencodeGlobalCommands: async () => ({}),
|
||||
loadOpencodeProjectCommands: async () => ({}),
|
||||
}))
|
||||
spyOn(commandLoader, "loadUserCommands" as any).mockResolvedValue({})
|
||||
spyOn(commandLoader, "loadProjectCommands" as any).mockResolvedValue({})
|
||||
spyOn(commandLoader, "loadOpencodeGlobalCommands" as any).mockResolvedValue({})
|
||||
spyOn(commandLoader, "loadOpencodeProjectCommands" as any).mockResolvedValue({})
|
||||
|
||||
mock.module("../features/builtin-commands", () => ({
|
||||
loadBuiltinCommands: () => ({}),
|
||||
}))
|
||||
spyOn(builtinCommands, "loadBuiltinCommands" as any).mockReturnValue({})
|
||||
|
||||
mock.module("../features/opencode-skill-loader", () => ({
|
||||
loadUserSkills: async () => ({}),
|
||||
loadProjectSkills: async () => ({}),
|
||||
loadOpencodeGlobalSkills: async () => ({}),
|
||||
loadOpencodeProjectSkills: async () => ({}),
|
||||
discoverUserClaudeSkills: async () => [],
|
||||
discoverProjectClaudeSkills: async () => [],
|
||||
discoverOpencodeGlobalSkills: async () => [],
|
||||
discoverOpencodeProjectSkills: async () => [],
|
||||
}))
|
||||
spyOn(skillLoader, "loadUserSkills" as any).mockResolvedValue({})
|
||||
spyOn(skillLoader, "loadProjectSkills" as any).mockResolvedValue({})
|
||||
spyOn(skillLoader, "loadOpencodeGlobalSkills" as any).mockResolvedValue({})
|
||||
spyOn(skillLoader, "loadOpencodeProjectSkills" as any).mockResolvedValue({})
|
||||
spyOn(skillLoader, "discoverUserClaudeSkills" as any).mockResolvedValue([])
|
||||
spyOn(skillLoader, "discoverProjectClaudeSkills" as any).mockResolvedValue([])
|
||||
spyOn(skillLoader, "discoverOpencodeGlobalSkills" as any).mockResolvedValue([])
|
||||
spyOn(skillLoader, "discoverOpencodeProjectSkills" as any).mockResolvedValue([])
|
||||
|
||||
mock.module("../features/claude-code-agent-loader", () => ({
|
||||
loadUserAgents: () => ({}),
|
||||
loadProjectAgents: () => ({}),
|
||||
}))
|
||||
spyOn(agentLoader, "loadUserAgents" as any).mockReturnValue({})
|
||||
spyOn(agentLoader, "loadProjectAgents" as any).mockReturnValue({})
|
||||
|
||||
mock.module("../features/claude-code-mcp-loader", () => ({
|
||||
loadMcpConfigs: async () => ({ servers: {} }),
|
||||
}))
|
||||
spyOn(mcpLoader, "loadMcpConfigs" as any).mockResolvedValue({ servers: {} })
|
||||
|
||||
mock.module("../features/claude-code-plugin-loader", () => ({
|
||||
loadAllPluginComponents: async () => ({
|
||||
spyOn(pluginLoader, "loadAllPluginComponents" as any).mockResolvedValue({
|
||||
commands: {},
|
||||
skills: {},
|
||||
agents: {},
|
||||
@@ -58,60 +58,52 @@ mock.module("../features/claude-code-plugin-loader", () => ({
|
||||
hooksConfigs: [],
|
||||
plugins: [],
|
||||
errors: [],
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
mock.module("../mcp", () => ({
|
||||
createBuiltinMcps: () => ({}),
|
||||
}))
|
||||
spyOn(mcpModule, "createBuiltinMcps" as any).mockReturnValue({})
|
||||
|
||||
mock.module("../shared", () => ({
|
||||
log: () => {},
|
||||
fetchAvailableModels: async () => new Set(["anthropic/claude-opus-4-5"]),
|
||||
readConnectedProvidersCache: () => null,
|
||||
}))
|
||||
spyOn(shared, "log" as any).mockImplementation(() => {})
|
||||
spyOn(shared, "fetchAvailableModels" as any).mockResolvedValue(new Set(["anthropic/claude-opus-4-5"]))
|
||||
spyOn(shared, "readConnectedProvidersCache" as any).mockReturnValue(null)
|
||||
|
||||
mock.module("../shared/opencode-config-dir", () => ({
|
||||
getOpenCodeConfigPaths: () => ({
|
||||
spyOn(configDir, "getOpenCodeConfigPaths" as any).mockReturnValue({
|
||||
global: "/tmp/.config/opencode",
|
||||
project: "/tmp/.opencode",
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
mock.module("../shared/permission-compat", () => ({
|
||||
migrateAgentConfig: (config: Record<string, unknown>) => config,
|
||||
}))
|
||||
spyOn(permissionCompat, "migrateAgentConfig" as any).mockImplementation((config: Record<string, unknown>) => config)
|
||||
|
||||
mock.module("../shared/migration", () => ({
|
||||
AGENT_NAME_MAP: {},
|
||||
}))
|
||||
spyOn(modelResolver, "resolveModelWithFallback" as any).mockReturnValue({ model: "anthropic/claude-opus-4-5" })
|
||||
})
|
||||
|
||||
mock.module("../shared/model-resolver", () => ({
|
||||
resolveModelWithFallback: () => ({ model: "anthropic/claude-opus-4-5" }),
|
||||
}))
|
||||
|
||||
mock.module("../shared/model-requirements", () => ({
|
||||
AGENT_MODEL_REQUIREMENTS: {
|
||||
sisyphus: { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-opus-4-5" }] },
|
||||
oracle: { fallbackChain: [{ providers: ["openai", "github-copilot", "opencode"], model: "gpt-5.2" }] },
|
||||
librarian: { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-sonnet-4-5" }] },
|
||||
explore: { fallbackChain: [{ providers: ["anthropic", "opencode"], model: "claude-haiku-4-5" }] },
|
||||
"multimodal-looker": { fallbackChain: [{ providers: ["google", "github-copilot", "opencode"], model: "gemini-3-flash" }] },
|
||||
prometheus: { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-opus-4-5" }] },
|
||||
metis: { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-opus-4-5" }] },
|
||||
momus: { fallbackChain: [{ providers: ["openai", "github-copilot", "opencode"], model: "gpt-5.2" }] },
|
||||
atlas: { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-sonnet-4-5" }] },
|
||||
},
|
||||
CATEGORY_MODEL_REQUIREMENTS: {
|
||||
"visual-engineering": { fallbackChain: [{ providers: ["google", "github-copilot", "opencode"], model: "gemini-3-pro" }] },
|
||||
ultrabrain: { fallbackChain: [{ providers: ["openai", "github-copilot", "opencode"], model: "gpt-5.2-codex" }] },
|
||||
artistry: { fallbackChain: [{ providers: ["google", "github-copilot", "opencode"], model: "gemini-3-pro" }] },
|
||||
quick: { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-haiku-4-5" }] },
|
||||
"unspecified-low": { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-sonnet-4-5" }] },
|
||||
"unspecified-high": { fallbackChain: [{ providers: ["anthropic", "github-copilot", "opencode"], model: "claude-opus-4-5" }] },
|
||||
writing: { fallbackChain: [{ providers: ["google", "github-copilot", "opencode"], model: "gemini-3-flash" }] },
|
||||
},
|
||||
}))
|
||||
afterEach(() => {
|
||||
(agents.createBuiltinAgents as any)?.mockRestore?.()
|
||||
;(sisyphusJunior.createSisyphusJuniorAgentWithOverrides as any)?.mockRestore?.()
|
||||
;(commandLoader.loadUserCommands as any)?.mockRestore?.()
|
||||
;(commandLoader.loadProjectCommands as any)?.mockRestore?.()
|
||||
;(commandLoader.loadOpencodeGlobalCommands as any)?.mockRestore?.()
|
||||
;(commandLoader.loadOpencodeProjectCommands as any)?.mockRestore?.()
|
||||
;(builtinCommands.loadBuiltinCommands as any)?.mockRestore?.()
|
||||
;(skillLoader.loadUserSkills as any)?.mockRestore?.()
|
||||
;(skillLoader.loadProjectSkills as any)?.mockRestore?.()
|
||||
;(skillLoader.loadOpencodeGlobalSkills as any)?.mockRestore?.()
|
||||
;(skillLoader.loadOpencodeProjectSkills as any)?.mockRestore?.()
|
||||
;(skillLoader.discoverUserClaudeSkills as any)?.mockRestore?.()
|
||||
;(skillLoader.discoverProjectClaudeSkills as any)?.mockRestore?.()
|
||||
;(skillLoader.discoverOpencodeGlobalSkills as any)?.mockRestore?.()
|
||||
;(skillLoader.discoverOpencodeProjectSkills as any)?.mockRestore?.()
|
||||
;(agentLoader.loadUserAgents as any)?.mockRestore?.()
|
||||
;(agentLoader.loadProjectAgents as any)?.mockRestore?.()
|
||||
;(mcpLoader.loadMcpConfigs as any)?.mockRestore?.()
|
||||
;(pluginLoader.loadAllPluginComponents as any)?.mockRestore?.()
|
||||
;(mcpModule.createBuiltinMcps as any)?.mockRestore?.()
|
||||
;(shared.log as any)?.mockRestore?.()
|
||||
;(shared.fetchAvailableModels as any)?.mockRestore?.()
|
||||
;(shared.readConnectedProvidersCache as any)?.mockRestore?.()
|
||||
;(configDir.getOpenCodeConfigPaths as any)?.mockRestore?.()
|
||||
;(permissionCompat.migrateAgentConfig as any)?.mockRestore?.()
|
||||
;(modelResolver.resolveModelWithFallback as any)?.mockRestore?.()
|
||||
})
|
||||
|
||||
describe("Plan agent demote behavior", () => {
|
||||
test("plan agent should be demoted to subagent mode when replacePlan is true", async () => {
|
||||
@@ -280,3 +272,127 @@ describe("Prometheus category config resolution", () => {
|
||||
expect(config?.tools).toEqual({ tool1: true, tool2: false })
|
||||
})
|
||||
})
|
||||
|
||||
describe("Prometheus direct override priority over category", () => {
|
||||
test("direct reasoningEffort takes priority over category reasoningEffort", async () => {
|
||||
// #given - category has reasoningEffort=xhigh, direct override says "low"
|
||||
const pluginConfig: OhMyOpenCodeConfig = {
|
||||
sisyphus_agent: {
|
||||
planner_enabled: true,
|
||||
},
|
||||
categories: {
|
||||
"test-planning": {
|
||||
model: "openai/gpt-5.2",
|
||||
reasoningEffort: "xhigh",
|
||||
},
|
||||
},
|
||||
agents: {
|
||||
prometheus: {
|
||||
category: "test-planning",
|
||||
reasoningEffort: "low",
|
||||
},
|
||||
},
|
||||
}
|
||||
const config: Record<string, unknown> = {
|
||||
model: "anthropic/claude-opus-4-5",
|
||||
agent: {},
|
||||
}
|
||||
const handler = createConfigHandler({
|
||||
ctx: { directory: "/tmp" },
|
||||
pluginConfig,
|
||||
modelCacheState: {
|
||||
anthropicContext1MEnabled: false,
|
||||
modelContextLimitsCache: new Map(),
|
||||
},
|
||||
})
|
||||
|
||||
// #when
|
||||
await handler(config)
|
||||
|
||||
// #then - direct override's reasoningEffort wins
|
||||
const agents = config.agent as Record<string, { reasoningEffort?: string }>
|
||||
expect(agents.prometheus).toBeDefined()
|
||||
expect(agents.prometheus.reasoningEffort).toBe("low")
|
||||
})
|
||||
|
||||
test("category reasoningEffort applied when no direct override", async () => {
|
||||
// #given - category has reasoningEffort but no direct override
|
||||
const pluginConfig: OhMyOpenCodeConfig = {
|
||||
sisyphus_agent: {
|
||||
planner_enabled: true,
|
||||
},
|
||||
categories: {
|
||||
"reasoning-cat": {
|
||||
model: "openai/gpt-5.2",
|
||||
reasoningEffort: "high",
|
||||
},
|
||||
},
|
||||
agents: {
|
||||
prometheus: {
|
||||
category: "reasoning-cat",
|
||||
},
|
||||
},
|
||||
}
|
||||
const config: Record<string, unknown> = {
|
||||
model: "anthropic/claude-opus-4-5",
|
||||
agent: {},
|
||||
}
|
||||
const handler = createConfigHandler({
|
||||
ctx: { directory: "/tmp" },
|
||||
pluginConfig,
|
||||
modelCacheState: {
|
||||
anthropicContext1MEnabled: false,
|
||||
modelContextLimitsCache: new Map(),
|
||||
},
|
||||
})
|
||||
|
||||
// #when
|
||||
await handler(config)
|
||||
|
||||
// #then - category's reasoningEffort is applied
|
||||
const agents = config.agent as Record<string, { reasoningEffort?: string }>
|
||||
expect(agents.prometheus).toBeDefined()
|
||||
expect(agents.prometheus.reasoningEffort).toBe("high")
|
||||
})
|
||||
|
||||
test("direct temperature takes priority over category temperature", async () => {
|
||||
// #given
|
||||
const pluginConfig: OhMyOpenCodeConfig = {
|
||||
sisyphus_agent: {
|
||||
planner_enabled: true,
|
||||
},
|
||||
categories: {
|
||||
"temp-cat": {
|
||||
model: "openai/gpt-5.2",
|
||||
temperature: 0.8,
|
||||
},
|
||||
},
|
||||
agents: {
|
||||
prometheus: {
|
||||
category: "temp-cat",
|
||||
temperature: 0.1,
|
||||
},
|
||||
},
|
||||
}
|
||||
const config: Record<string, unknown> = {
|
||||
model: "anthropic/claude-opus-4-5",
|
||||
agent: {},
|
||||
}
|
||||
const handler = createConfigHandler({
|
||||
ctx: { directory: "/tmp" },
|
||||
pluginConfig,
|
||||
modelCacheState: {
|
||||
anthropicContext1MEnabled: false,
|
||||
modelContextLimitsCache: new Map(),
|
||||
},
|
||||
})
|
||||
|
||||
// #when
|
||||
await handler(config)
|
||||
|
||||
// #then - direct temperature wins over category
|
||||
const agents = config.agent as Record<string, { temperature?: number }>
|
||||
expect(agents.prometheus).toBeDefined()
|
||||
expect(agents.prometheus.temperature).toBe(0.1)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -133,16 +133,20 @@ export function createConfigHandler(deps: ConfigHandlerDeps) {
|
||||
];
|
||||
|
||||
const browserProvider = pluginConfig.browser_automation_engine?.provider ?? "playwright";
|
||||
// config.model represents the currently active model in OpenCode (including UI selection)
|
||||
// Pass it as uiSelectedModel so it takes highest priority in model resolution
|
||||
const currentModel = config.model as string | undefined;
|
||||
const builtinAgents = await createBuiltinAgents(
|
||||
migratedDisabledAgents,
|
||||
pluginConfig.agents,
|
||||
ctx.directory,
|
||||
config.model as string | undefined,
|
||||
undefined, // systemDefaultModel - let fallback chain handle this
|
||||
pluginConfig.categories,
|
||||
pluginConfig.git_master,
|
||||
allDiscoveredSkills,
|
||||
ctx.client,
|
||||
browserProvider
|
||||
browserProvider,
|
||||
currentModel // uiSelectedModel - takes highest priority
|
||||
);
|
||||
|
||||
// Claude Code agents: Do NOT apply permission migration
|
||||
@@ -223,9 +227,18 @@ export function createConfigHandler(deps: ConfigHandlerDeps) {
|
||||
);
|
||||
const prometheusOverride =
|
||||
pluginConfig.agents?.["prometheus"] as
|
||||
| (Record<string, unknown> & { category?: string; model?: string; variant?: string })
|
||||
| (Record<string, unknown> & {
|
||||
category?: string
|
||||
model?: string
|
||||
variant?: string
|
||||
reasoningEffort?: string
|
||||
textVerbosity?: string
|
||||
thinking?: { type: string; budgetTokens?: number }
|
||||
temperature?: number
|
||||
top_p?: number
|
||||
maxTokens?: number
|
||||
})
|
||||
| undefined;
|
||||
const defaultModel = config.model as string | undefined;
|
||||
|
||||
const categoryConfig = prometheusOverride?.category
|
||||
? resolveCategoryConfig(
|
||||
@@ -241,15 +254,22 @@ export function createConfigHandler(deps: ConfigHandlerDeps) {
|
||||
: new Set<string>();
|
||||
|
||||
const modelResolution = resolveModelWithFallback({
|
||||
uiSelectedModel: currentModel,
|
||||
userModel: prometheusOverride?.model ?? categoryConfig?.model,
|
||||
fallbackChain: prometheusRequirement?.fallbackChain,
|
||||
availableModels,
|
||||
systemDefaultModel: defaultModel ?? "",
|
||||
systemDefaultModel: undefined,
|
||||
});
|
||||
const resolvedModel = modelResolution?.model;
|
||||
const resolvedVariant = modelResolution?.variant;
|
||||
|
||||
const variantToUse = prometheusOverride?.variant ?? resolvedVariant;
|
||||
const reasoningEffortToUse = prometheusOverride?.reasoningEffort ?? categoryConfig?.reasoningEffort;
|
||||
const textVerbosityToUse = prometheusOverride?.textVerbosity ?? categoryConfig?.textVerbosity;
|
||||
const thinkingToUse = prometheusOverride?.thinking ?? categoryConfig?.thinking;
|
||||
const temperatureToUse = prometheusOverride?.temperature ?? categoryConfig?.temperature;
|
||||
const topPToUse = prometheusOverride?.top_p ?? categoryConfig?.top_p;
|
||||
const maxTokensToUse = prometheusOverride?.maxTokens ?? categoryConfig?.maxTokens;
|
||||
const prometheusBase = {
|
||||
name: "prometheus",
|
||||
...(resolvedModel ? { model: resolvedModel } : {}),
|
||||
@@ -259,22 +279,16 @@ export function createConfigHandler(deps: ConfigHandlerDeps) {
|
||||
permission: PROMETHEUS_PERMISSION,
|
||||
description: `${configAgent?.plan?.description ?? "Plan agent"} (Prometheus - OhMyOpenCode)`,
|
||||
color: (configAgent?.plan?.color as string) ?? "#FF6347",
|
||||
...(categoryConfig?.temperature !== undefined
|
||||
? { temperature: categoryConfig.temperature }
|
||||
: {}),
|
||||
...(categoryConfig?.top_p !== undefined
|
||||
? { top_p: categoryConfig.top_p }
|
||||
: {}),
|
||||
...(categoryConfig?.maxTokens !== undefined
|
||||
? { maxTokens: categoryConfig.maxTokens }
|
||||
: {}),
|
||||
...(temperatureToUse !== undefined ? { temperature: temperatureToUse } : {}),
|
||||
...(topPToUse !== undefined ? { top_p: topPToUse } : {}),
|
||||
...(maxTokensToUse !== undefined ? { maxTokens: maxTokensToUse } : {}),
|
||||
...(categoryConfig?.tools ? { tools: categoryConfig.tools } : {}),
|
||||
...(categoryConfig?.thinking ? { thinking: categoryConfig.thinking } : {}),
|
||||
...(categoryConfig?.reasoningEffort !== undefined
|
||||
? { reasoningEffort: categoryConfig.reasoningEffort }
|
||||
...(thinkingToUse ? { thinking: thinkingToUse } : {}),
|
||||
...(reasoningEffortToUse !== undefined
|
||||
? { reasoningEffort: reasoningEffortToUse }
|
||||
: {}),
|
||||
...(categoryConfig?.textVerbosity !== undefined
|
||||
? { textVerbosity: categoryConfig.textVerbosity }
|
||||
...(textVerbosityToUse !== undefined
|
||||
? { textVerbosity: textVerbosityToUse }
|
||||
: {}),
|
||||
};
|
||||
|
||||
|
||||
@@ -113,7 +113,80 @@ describe("resolveModelWithFallback", () => {
|
||||
logSpy.mockRestore()
|
||||
})
|
||||
|
||||
describe("Step 1: Override", () => {
|
||||
describe("Step 1: UI Selection (highest priority)", () => {
|
||||
test("returns uiSelectedModel with override source when provided", () => {
|
||||
// #given
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
uiSelectedModel: "opencode/glm-4.7-free",
|
||||
userModel: "anthropic/claude-opus-4-5",
|
||||
fallbackChain: [
|
||||
{ providers: ["anthropic", "github-copilot"], model: "claude-opus-4-5" },
|
||||
],
|
||||
availableModels: new Set(["anthropic/claude-opus-4-5", "github-copilot/claude-opus-4-5-preview"]),
|
||||
systemDefaultModel: "google/gemini-3-pro",
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = resolveModelWithFallback(input)
|
||||
|
||||
// #then
|
||||
expect(result!.model).toBe("opencode/glm-4.7-free")
|
||||
expect(result!.source).toBe("override")
|
||||
expect(logSpy).toHaveBeenCalledWith("Model resolved via UI selection", { model: "opencode/glm-4.7-free" })
|
||||
})
|
||||
|
||||
test("UI selection takes priority over config override", () => {
|
||||
// #given
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
uiSelectedModel: "opencode/glm-4.7-free",
|
||||
userModel: "anthropic/claude-opus-4-5",
|
||||
availableModels: new Set(["anthropic/claude-opus-4-5"]),
|
||||
systemDefaultModel: "google/gemini-3-pro",
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = resolveModelWithFallback(input)
|
||||
|
||||
// #then
|
||||
expect(result!.model).toBe("opencode/glm-4.7-free")
|
||||
expect(result!.source).toBe("override")
|
||||
})
|
||||
|
||||
test("whitespace-only uiSelectedModel is treated as not provided", () => {
|
||||
// #given
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
uiSelectedModel: " ",
|
||||
userModel: "anthropic/claude-opus-4-5",
|
||||
availableModels: new Set(["anthropic/claude-opus-4-5"]),
|
||||
systemDefaultModel: "google/gemini-3-pro",
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = resolveModelWithFallback(input)
|
||||
|
||||
// #then
|
||||
expect(result!.model).toBe("anthropic/claude-opus-4-5")
|
||||
expect(logSpy).toHaveBeenCalledWith("Model resolved via config override", { model: "anthropic/claude-opus-4-5" })
|
||||
})
|
||||
|
||||
test("empty string uiSelectedModel falls through to config override", () => {
|
||||
// #given
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
uiSelectedModel: "",
|
||||
userModel: "anthropic/claude-opus-4-5",
|
||||
availableModels: new Set(["anthropic/claude-opus-4-5"]),
|
||||
systemDefaultModel: "google/gemini-3-pro",
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = resolveModelWithFallback(input)
|
||||
|
||||
// #then
|
||||
expect(result!.model).toBe("anthropic/claude-opus-4-5")
|
||||
})
|
||||
})
|
||||
|
||||
describe("Step 2: Config Override", () => {
|
||||
test("returns userModel with override source when userModel is provided", () => {
|
||||
// #given
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
@@ -131,7 +204,7 @@ describe("resolveModelWithFallback", () => {
|
||||
// #then
|
||||
expect(result!.model).toBe("anthropic/claude-opus-4-5")
|
||||
expect(result!.source).toBe("override")
|
||||
expect(logSpy).toHaveBeenCalledWith("Model resolved via override", { model: "anthropic/claude-opus-4-5" })
|
||||
expect(logSpy).toHaveBeenCalledWith("Model resolved via config override", { model: "anthropic/claude-opus-4-5" })
|
||||
})
|
||||
|
||||
test("override takes priority even if model not in availableModels", () => {
|
||||
@@ -190,7 +263,7 @@ describe("resolveModelWithFallback", () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe("Step 2: Provider fallback chain", () => {
|
||||
describe("Step 3: Provider fallback chain", () => {
|
||||
test("tries providers in order within entry and returns first match", () => {
|
||||
// #given
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
@@ -317,7 +390,7 @@ describe("resolveModelWithFallback", () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe("Step 3: System default fallback (no availability match)", () => {
|
||||
describe("Step 4: System default fallback (no availability match)", () => {
|
||||
test("returns system default when no availability match found in fallback chain", () => {
|
||||
// #given
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
@@ -356,7 +429,7 @@ describe("resolveModelWithFallback", () => {
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
|
||||
test("uses connected provider when availableModels empty but connected providers cache exists", () => {
|
||||
test("uses connected provider from fallback when availableModels empty but cache exists", () => {
|
||||
// #given - model cache missing but connected-providers cache exists
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["openai", "google"])
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
@@ -370,12 +443,52 @@ describe("resolveModelWithFallback", () => {
|
||||
// #when
|
||||
const result = resolveModelWithFallback(input)
|
||||
|
||||
// #then - should use openai (second provider) since anthropic not in connected cache
|
||||
// #then - should use connected provider (openai) from fallback chain
|
||||
expect(result!.model).toBe("openai/claude-opus-4-5")
|
||||
expect(result!.source).toBe("provider-fallback")
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
|
||||
test("uses github-copilot when google not connected (visual-engineering scenario)", () => {
|
||||
// #given - user has github-copilot but not google connected
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["github-copilot"])
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
fallbackChain: [
|
||||
{ providers: ["google", "github-copilot", "opencode"], model: "gemini-3-pro" },
|
||||
],
|
||||
availableModels: new Set(),
|
||||
systemDefaultModel: "anthropic/claude-sonnet-4-5",
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = resolveModelWithFallback(input)
|
||||
|
||||
// #then - should use github-copilot (second provider) since google not connected
|
||||
expect(result!.model).toBe("github-copilot/gemini-3-pro")
|
||||
expect(result!.source).toBe("provider-fallback")
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
|
||||
test("falls through to system default when no provider in fallback is connected", () => {
|
||||
// #given - user only has quotio connected, but fallback chain has anthropic/opencode
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(["quotio"])
|
||||
const input: ExtendedModelResolutionInput = {
|
||||
fallbackChain: [
|
||||
{ providers: ["anthropic", "opencode"], model: "claude-haiku-4-5" },
|
||||
],
|
||||
availableModels: new Set(),
|
||||
systemDefaultModel: "quotio/claude-opus-4-5-20251101",
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = resolveModelWithFallback(input)
|
||||
|
||||
// #then - no provider in fallback is connected, fall through to system default
|
||||
expect(result!.model).toBe("quotio/claude-opus-4-5-20251101")
|
||||
expect(result!.source).toBe("system-default")
|
||||
cacheSpy.mockRestore()
|
||||
})
|
||||
|
||||
test("falls through to system default when no cache and systemDefaultModel is provided", () => {
|
||||
// #given - no cache but system default is configured
|
||||
const cacheSpy = spyOn(connectedProvidersCache, "readConnectedProvidersCache").mockReturnValue(null)
|
||||
|
||||
@@ -21,6 +21,7 @@ export type ModelResolutionResult = {
|
||||
}
|
||||
|
||||
export type ExtendedModelResolutionInput = {
|
||||
uiSelectedModel?: string
|
||||
userModel?: string
|
||||
fallbackChain?: FallbackEntry[]
|
||||
availableModels: Set<string>
|
||||
@@ -43,25 +44,30 @@ export function resolveModel(input: ModelResolutionInput): string | undefined {
|
||||
export function resolveModelWithFallback(
|
||||
input: ExtendedModelResolutionInput,
|
||||
): ModelResolutionResult | undefined {
|
||||
const { userModel, fallbackChain, availableModels, systemDefaultModel } = input
|
||||
const { uiSelectedModel, userModel, fallbackChain, availableModels, systemDefaultModel } = input
|
||||
|
||||
// Step 1: Override
|
||||
// Step 1: UI Selection (highest priority - respects user's model choice in OpenCode UI)
|
||||
const normalizedUiModel = normalizeModel(uiSelectedModel)
|
||||
if (normalizedUiModel) {
|
||||
log("Model resolved via UI selection", { model: normalizedUiModel })
|
||||
return { model: normalizedUiModel, source: "override" }
|
||||
}
|
||||
|
||||
// Step 2: Config Override (from oh-my-opencode.json)
|
||||
const normalizedUserModel = normalizeModel(userModel)
|
||||
if (normalizedUserModel) {
|
||||
log("Model resolved via override", { model: normalizedUserModel })
|
||||
log("Model resolved via config override", { model: normalizedUserModel })
|
||||
return { model: normalizedUserModel, source: "override" }
|
||||
}
|
||||
|
||||
// Step 2: Provider fallback chain (with availability check)
|
||||
// Step 3: Provider fallback chain (with availability check)
|
||||
if (fallbackChain && fallbackChain.length > 0) {
|
||||
if (availableModels.size === 0) {
|
||||
const connectedProviders = readConnectedProvidersCache()
|
||||
const connectedSet = connectedProviders ? new Set(connectedProviders) : null
|
||||
|
||||
// When no cache exists at all, skip fallback chain and fall through to system default
|
||||
// This allows OpenCode to use Provider.defaultModel() as the final fallback
|
||||
if (connectedSet === null) {
|
||||
log("No cache available, skipping fallback chain to use system default")
|
||||
log("Model fallback chain skipped (no connected providers cache) - falling through to system default")
|
||||
} else {
|
||||
for (const entry of fallbackChain) {
|
||||
for (const provider of entry.providers) {
|
||||
@@ -76,7 +82,7 @@ export function resolveModelWithFallback(
|
||||
}
|
||||
}
|
||||
}
|
||||
log("No matching provider in connected cache, falling through to system default")
|
||||
log("No connected provider found in fallback chain, falling through to system default")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +99,7 @@ export function resolveModelWithFallback(
|
||||
log("No available model found in fallback chain, falling through to system default")
|
||||
}
|
||||
|
||||
// Step 3: System default (if provided)
|
||||
// Step 4: System default (if provided)
|
||||
if (systemDefaultModel === undefined) {
|
||||
log("No model resolved - systemDefaultModel not configured")
|
||||
return undefined
|
||||
|
||||
@@ -80,7 +80,11 @@ export function createBackgroundTask(manager: BackgroundManager): ToolDefinition
|
||||
})
|
||||
|
||||
const parentModel = prevMessage?.model?.providerID && prevMessage?.model?.modelID
|
||||
? { providerID: prevMessage.model.providerID, modelID: prevMessage.model.modelID }
|
||||
? {
|
||||
providerID: prevMessage.model.providerID,
|
||||
modelID: prevMessage.model.modelID,
|
||||
...(prevMessage.model.variant ? { variant: prevMessage.model.variant } : {})
|
||||
}
|
||||
: undefined
|
||||
|
||||
const task = await manager.launch({
|
||||
|
||||
@@ -1474,6 +1474,73 @@ describe("sisyphus-task", () => {
|
||||
}, { timeout: 20000 })
|
||||
})
|
||||
|
||||
describe("category model resolution fallback", () => {
|
||||
test("category uses resolved.model when connectedProvidersCache is null and availableModels is empty", async () => {
|
||||
// #given - connectedProvidersCache returns null (simulates missing cache file)
|
||||
// This is a regression test for PR #1227 which removed resolved.model from userModel chain
|
||||
cacheSpy.mockReturnValue(null)
|
||||
|
||||
const { createDelegateTask } = require("./tools")
|
||||
let launchInput: any
|
||||
|
||||
const mockManager = {
|
||||
launch: async (input: any) => {
|
||||
launchInput = input
|
||||
return {
|
||||
id: "task-fallback",
|
||||
sessionID: "ses_fallback_test",
|
||||
description: "Fallback test task",
|
||||
agent: "sisyphus-junior",
|
||||
status: "running",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockClient = {
|
||||
app: { agents: async () => ({ data: [] }) },
|
||||
config: { get: async () => ({ data: { model: SYSTEM_DEFAULT_MODEL } }) },
|
||||
model: { list: async () => [] },
|
||||
session: {
|
||||
create: async () => ({ data: { id: "test-session" } }),
|
||||
prompt: async () => ({ data: {} }),
|
||||
messages: async () => ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
// NO userCategories override, NO sisyphusJuniorModel
|
||||
const tool = createDelegateTask({
|
||||
manager: mockManager,
|
||||
client: mockClient,
|
||||
// userCategories: undefined - use DEFAULT_CATEGORIES only
|
||||
// sisyphusJuniorModel: undefined
|
||||
})
|
||||
|
||||
const toolContext = {
|
||||
sessionID: "parent-session",
|
||||
messageID: "parent-message",
|
||||
agent: "sisyphus",
|
||||
abort: new AbortController().signal,
|
||||
}
|
||||
|
||||
// #when - using "quick" category which should use "anthropic/claude-haiku-4-5"
|
||||
await tool.execute(
|
||||
{
|
||||
description: "Test category fallback",
|
||||
prompt: "Do something quick",
|
||||
category: "quick",
|
||||
run_in_background: true,
|
||||
load_skills: [],
|
||||
},
|
||||
toolContext
|
||||
)
|
||||
|
||||
// #then - model should be anthropic/claude-haiku-4-5 from DEFAULT_CATEGORIES
|
||||
// NOT anthropic/claude-sonnet-4-5 (system default)
|
||||
expect(launchInput.model.providerID).toBe("anthropic")
|
||||
expect(launchInput.model.modelID).toBe("claude-haiku-4-5")
|
||||
})
|
||||
})
|
||||
|
||||
describe("browserProvider propagation", () => {
|
||||
test("should resolve agent-browser skill when browserProvider is passed", async () => {
|
||||
// #given - delegate_task configured with browserProvider: "agent-browser"
|
||||
@@ -2035,6 +2102,192 @@ describe("sisyphus-task", () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe("subagent_type model extraction (issue #1225)", () => {
|
||||
test("background mode passes matched agent model to manager.launch", async () => {
|
||||
// #given - agent with model registered, using subagent_type with run_in_background=true
|
||||
const { createDelegateTask } = require("./tools")
|
||||
let launchInput: any
|
||||
|
||||
const mockManager = {
|
||||
launch: async (input: any) => {
|
||||
launchInput = input
|
||||
return {
|
||||
id: "task-explore",
|
||||
sessionID: "ses_explore_model",
|
||||
description: "Explore task",
|
||||
agent: "explore",
|
||||
status: "running",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
const mockClient = {
|
||||
app: {
|
||||
agents: async () => ({
|
||||
data: [
|
||||
{ name: "explore", mode: "subagent", model: { providerID: "anthropic", modelID: "claude-haiku-4-5" } },
|
||||
],
|
||||
}),
|
||||
},
|
||||
config: { get: async () => ({ data: { model: SYSTEM_DEFAULT_MODEL } }) },
|
||||
session: {
|
||||
create: async () => ({ data: { id: "ses_explore_model" } }),
|
||||
prompt: async () => ({ data: {} }),
|
||||
messages: async () => ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
const tool = createDelegateTask({
|
||||
manager: mockManager,
|
||||
client: mockClient,
|
||||
})
|
||||
|
||||
const toolContext = {
|
||||
sessionID: "parent-session",
|
||||
messageID: "parent-message",
|
||||
agent: "sisyphus",
|
||||
abort: new AbortController().signal,
|
||||
}
|
||||
|
||||
// #when - delegating to explore agent via subagent_type
|
||||
await tool.execute(
|
||||
{
|
||||
description: "Explore codebase",
|
||||
prompt: "Find auth patterns",
|
||||
subagent_type: "explore",
|
||||
run_in_background: true,
|
||||
load_skills: [],
|
||||
},
|
||||
toolContext
|
||||
)
|
||||
|
||||
// #then - matched agent's model should be passed to manager.launch
|
||||
expect(launchInput.model).toEqual({
|
||||
providerID: "anthropic",
|
||||
modelID: "claude-haiku-4-5",
|
||||
})
|
||||
})
|
||||
|
||||
test("sync mode passes matched agent model to session.prompt", async () => {
|
||||
// #given - agent with model registered, using subagent_type with run_in_background=false
|
||||
const { createDelegateTask } = require("./tools")
|
||||
let promptBody: any
|
||||
|
||||
const mockManager = { launch: async () => ({}) }
|
||||
|
||||
const mockClient = {
|
||||
app: {
|
||||
agents: async () => ({
|
||||
data: [
|
||||
{ name: "oracle", mode: "subagent", model: { providerID: "anthropic", modelID: "claude-opus-4-5" } },
|
||||
],
|
||||
}),
|
||||
},
|
||||
config: { get: async () => ({ data: { model: SYSTEM_DEFAULT_MODEL } }) },
|
||||
session: {
|
||||
get: async () => ({ data: { directory: "/project" } }),
|
||||
create: async () => ({ data: { id: "ses_oracle_model" } }),
|
||||
prompt: async (input: any) => {
|
||||
promptBody = input.body
|
||||
return { data: {} }
|
||||
},
|
||||
messages: async () => ({
|
||||
data: [{ info: { role: "assistant" }, parts: [{ type: "text", text: "Consultation done" }] }],
|
||||
}),
|
||||
status: async () => ({ data: { "ses_oracle_model": { type: "idle" } } }),
|
||||
},
|
||||
}
|
||||
|
||||
const tool = createDelegateTask({
|
||||
manager: mockManager,
|
||||
client: mockClient,
|
||||
})
|
||||
|
||||
const toolContext = {
|
||||
sessionID: "parent-session",
|
||||
messageID: "parent-message",
|
||||
agent: "sisyphus",
|
||||
abort: new AbortController().signal,
|
||||
}
|
||||
|
||||
// #when - delegating to oracle agent via subagent_type in sync mode
|
||||
await tool.execute(
|
||||
{
|
||||
description: "Consult oracle",
|
||||
prompt: "Review architecture",
|
||||
subagent_type: "oracle",
|
||||
run_in_background: false,
|
||||
load_skills: [],
|
||||
},
|
||||
toolContext
|
||||
)
|
||||
|
||||
// #then - matched agent's model should be passed to session.prompt
|
||||
expect(promptBody.model).toEqual({
|
||||
providerID: "anthropic",
|
||||
modelID: "claude-opus-4-5",
|
||||
})
|
||||
}, { timeout: 20000 })
|
||||
|
||||
test("agent without model does not override categoryModel", async () => {
|
||||
// #given - agent registered without model field
|
||||
const { createDelegateTask } = require("./tools")
|
||||
let promptBody: any
|
||||
|
||||
const mockManager = { launch: async () => ({}) }
|
||||
|
||||
const mockClient = {
|
||||
app: {
|
||||
agents: async () => ({
|
||||
data: [
|
||||
{ name: "explore", mode: "subagent" }, // no model field
|
||||
],
|
||||
}),
|
||||
},
|
||||
config: { get: async () => ({ data: { model: SYSTEM_DEFAULT_MODEL } }) },
|
||||
session: {
|
||||
get: async () => ({ data: { directory: "/project" } }),
|
||||
create: async () => ({ data: { id: "ses_no_model_agent" } }),
|
||||
prompt: async (input: any) => {
|
||||
promptBody = input.body
|
||||
return { data: {} }
|
||||
},
|
||||
messages: async () => ({
|
||||
data: [{ info: { role: "assistant" }, parts: [{ type: "text", text: "Done" }] }],
|
||||
}),
|
||||
status: async () => ({ data: { "ses_no_model_agent": { type: "idle" } } }),
|
||||
},
|
||||
}
|
||||
|
||||
const tool = createDelegateTask({
|
||||
manager: mockManager,
|
||||
client: mockClient,
|
||||
})
|
||||
|
||||
const toolContext = {
|
||||
sessionID: "parent-session",
|
||||
messageID: "parent-message",
|
||||
agent: "sisyphus",
|
||||
abort: new AbortController().signal,
|
||||
}
|
||||
|
||||
// #when - delegating to agent without model
|
||||
await tool.execute(
|
||||
{
|
||||
description: "Explore without model",
|
||||
prompt: "Find something",
|
||||
subagent_type: "explore",
|
||||
run_in_background: false,
|
||||
load_skills: [],
|
||||
},
|
||||
toolContext
|
||||
)
|
||||
|
||||
// #then - no model should be passed to session.prompt
|
||||
expect(promptBody.model).toBeUndefined()
|
||||
}, { timeout: 20000 })
|
||||
})
|
||||
|
||||
describe("prometheus subagent delegate_task permission", () => {
|
||||
test("prometheus subagent should have delegate_task permission enabled", async () => {
|
||||
// #given - sisyphus delegates to prometheus
|
||||
|
||||
@@ -287,7 +287,11 @@ Prompts MUST be in English.`
|
||||
resolvedParentAgent: parentAgent,
|
||||
})
|
||||
const parentModel = prevMessage?.model?.providerID && prevMessage?.model?.modelID
|
||||
? { providerID: prevMessage.model.providerID, modelID: prevMessage.model.modelID }
|
||||
? {
|
||||
providerID: prevMessage.model.providerID,
|
||||
modelID: prevMessage.model.modelID,
|
||||
...(prevMessage.model.variant ? { variant: prevMessage.model.variant } : {})
|
||||
}
|
||||
: undefined
|
||||
|
||||
if (args.session_id) {
|
||||
@@ -537,7 +541,7 @@ To continue this session: session_id="${args.session_id}"`
|
||||
}
|
||||
} else {
|
||||
const resolution = resolveModelWithFallback({
|
||||
userModel: userCategories?.[args.category]?.model ?? sisyphusJuniorModel,
|
||||
userModel: userCategories?.[args.category]?.model ?? resolved.model ?? sisyphusJuniorModel,
|
||||
fallbackChain: requirement.fallbackChain,
|
||||
availableModels,
|
||||
systemDefaultModel,
|
||||
@@ -567,7 +571,7 @@ To continue this session: session_id="${args.session_id}"`
|
||||
modelInfo = { model: actualModel, type, source }
|
||||
|
||||
const parsedModel = parseModelString(actualModel)
|
||||
const variantToUse = userCategories?.[args.category]?.variant ?? resolvedVariant
|
||||
const variantToUse = userCategories?.[args.category]?.variant ?? resolvedVariant ?? resolved.config.variant
|
||||
categoryModel = parsedModel
|
||||
? (variantToUse ? { ...parsedModel, variant: variantToUse } : parsedModel)
|
||||
: undefined
|
||||
@@ -780,7 +784,7 @@ Create the work plan directly - that's your job as the planning agent.`
|
||||
// Uses case-insensitive matching to allow "Oracle", "oracle", "ORACLE" etc.
|
||||
try {
|
||||
const agentsResult = await client.app.agents()
|
||||
type AgentInfo = { name: string; mode?: "subagent" | "primary" | "all" }
|
||||
type AgentInfo = { name: string; mode?: "subagent" | "primary" | "all"; model?: { providerID: string; modelID: string } }
|
||||
const agents = (agentsResult as { data?: AgentInfo[] }).data ?? agentsResult as unknown as AgentInfo[]
|
||||
|
||||
const callableAgents = agents.filter((a) => a.mode !== "primary")
|
||||
@@ -803,9 +807,24 @@ Create the work plan directly - that's your job as the planning agent.`
|
||||
}
|
||||
// Use the canonical agent name from registration
|
||||
agentToUse = matchedAgent.name
|
||||
|
||||
// Extract registered agent's model to pass explicitly to session.prompt.
|
||||
// This ensures the model is always in the correct object format ({providerID, modelID})
|
||||
// regardless of how OpenCode handles string→object conversion for plugin-registered agents.
|
||||
// See: https://github.com/code-yeongyu/oh-my-opencode/issues/1225
|
||||
if (matchedAgent.model) {
|
||||
categoryModel = matchedAgent.model
|
||||
}
|
||||
} catch {
|
||||
// If we can't fetch agents, proceed anyway - the session.prompt will fail with a clearer error
|
||||
}
|
||||
|
||||
// When using subagent_type directly, inherit parent model so agents don't default
|
||||
// to their hardcoded models (like grok-code) which may not be available
|
||||
if (parentModel) {
|
||||
categoryModel = parentModel
|
||||
modelInfo = { model: `${parentModel.providerID}/${parentModel.modelID}`, type: "inherited" }
|
||||
}
|
||||
}
|
||||
|
||||
const systemContent = buildSystemContent({ skillContent, categoryPromptAppend, agentName: agentToUse })
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { describe, expect, test } from "bun:test"
|
||||
import { normalizeArgs, validateArgs } from "./tools"
|
||||
import { normalizeArgs, validateArgs, createLookAt } from "./tools"
|
||||
|
||||
describe("look-at tool", () => {
|
||||
describe("normalizeArgs", () => {
|
||||
@@ -70,4 +70,80 @@ describe("look-at tool", () => {
|
||||
expect(error).toContain("file_path")
|
||||
})
|
||||
})
|
||||
|
||||
describe("createLookAt error handling", () => {
|
||||
// #given session.prompt에서 JSON parse 에러 발생
|
||||
// #when LookAt 도구 실행
|
||||
// #then 사용자 친화적 에러 메시지 반환
|
||||
test("handles JSON parse error from session.prompt gracefully", async () => {
|
||||
const mockClient = {
|
||||
session: {
|
||||
get: async () => ({ data: { directory: "/project" } }),
|
||||
create: async () => ({ data: { id: "ses_test_json_error" } }),
|
||||
prompt: async () => {
|
||||
throw new Error("JSON Parse error: Unexpected EOF")
|
||||
},
|
||||
messages: async () => ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
const tool = createLookAt({
|
||||
client: mockClient,
|
||||
directory: "/project",
|
||||
} as any)
|
||||
|
||||
const toolContext = {
|
||||
sessionID: "parent-session",
|
||||
messageID: "parent-message",
|
||||
agent: "sisyphus",
|
||||
abort: new AbortController().signal,
|
||||
}
|
||||
|
||||
const result = await tool.execute(
|
||||
{ file_path: "/test/file.png", goal: "analyze image" },
|
||||
toolContext
|
||||
)
|
||||
|
||||
expect(result).toContain("Error: Failed to analyze file")
|
||||
expect(result).toContain("malformed response")
|
||||
expect(result).toContain("multimodal-looker")
|
||||
expect(result).toContain("image/png")
|
||||
})
|
||||
|
||||
// #given session.prompt에서 일반 에러 발생
|
||||
// #when LookAt 도구 실행
|
||||
// #then 원본 에러 메시지 포함한 에러 반환
|
||||
test("handles generic prompt error gracefully", async () => {
|
||||
const mockClient = {
|
||||
session: {
|
||||
get: async () => ({ data: { directory: "/project" } }),
|
||||
create: async () => ({ data: { id: "ses_test_generic_error" } }),
|
||||
prompt: async () => {
|
||||
throw new Error("Network connection failed")
|
||||
},
|
||||
messages: async () => ({ data: [] }),
|
||||
},
|
||||
}
|
||||
|
||||
const tool = createLookAt({
|
||||
client: mockClient,
|
||||
directory: "/project",
|
||||
} as any)
|
||||
|
||||
const toolContext = {
|
||||
sessionID: "parent-session",
|
||||
messageID: "parent-message",
|
||||
agent: "sisyphus",
|
||||
abort: new AbortController().signal,
|
||||
}
|
||||
|
||||
const result = await tool.execute(
|
||||
{ file_path: "/test/file.pdf", goal: "extract text" },
|
||||
toolContext
|
||||
)
|
||||
|
||||
expect(result).toContain("Error: Failed to send prompt")
|
||||
expect(result).toContain("Network connection failed")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -131,22 +131,49 @@ Original error: ${createResult.error}`
|
||||
log(`[look_at] Created session: ${sessionID}`)
|
||||
|
||||
log(`[look_at] Sending prompt with file passthrough to session ${sessionID}`)
|
||||
await ctx.client.session.prompt({
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
agent: MULTIMODAL_LOOKER_AGENT,
|
||||
tools: {
|
||||
task: false,
|
||||
call_omo_agent: false,
|
||||
look_at: false,
|
||||
read: false,
|
||||
try {
|
||||
await ctx.client.session.prompt({
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
agent: MULTIMODAL_LOOKER_AGENT,
|
||||
tools: {
|
||||
task: false,
|
||||
call_omo_agent: false,
|
||||
look_at: false,
|
||||
read: false,
|
||||
},
|
||||
parts: [
|
||||
{ type: "text", text: prompt },
|
||||
{ type: "file", mime: mimeType, url: pathToFileURL(args.file_path).href, filename },
|
||||
],
|
||||
},
|
||||
parts: [
|
||||
{ type: "text", text: prompt },
|
||||
{ type: "file", mime: mimeType, url: pathToFileURL(args.file_path).href, filename },
|
||||
],
|
||||
},
|
||||
})
|
||||
})
|
||||
} catch (promptError) {
|
||||
const errorMessage = promptError instanceof Error ? promptError.message : String(promptError)
|
||||
log(`[look_at] Prompt error:`, promptError)
|
||||
|
||||
const isJsonParseError = errorMessage.includes("JSON") && (errorMessage.includes("EOF") || errorMessage.includes("parse"))
|
||||
if (isJsonParseError) {
|
||||
return `Error: Failed to analyze file - received malformed response from multimodal-looker agent.
|
||||
|
||||
This typically occurs when:
|
||||
1. The multimodal-looker model is not available or not connected
|
||||
2. The model does not support this file type (${mimeType})
|
||||
3. The API returned an empty or truncated response
|
||||
|
||||
File: ${args.file_path}
|
||||
MIME type: ${mimeType}
|
||||
|
||||
Try:
|
||||
- Ensure a vision-capable model (e.g., gemini-3-flash, gpt-5.2) is available
|
||||
- Check provider connections in opencode settings
|
||||
- For text files like .md, .txt, use the Read tool instead
|
||||
|
||||
Original error: ${errorMessage}`
|
||||
}
|
||||
|
||||
return `Error: Failed to send prompt to multimodal-looker agent: ${errorMessage}`
|
||||
}
|
||||
|
||||
log(`[look_at] Prompt sent, fetching messages...`)
|
||||
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
import { spawn, type Subprocess } from "bun"
|
||||
import { Readable, Writable } from "node:stream"
|
||||
import { readFileSync } from "fs"
|
||||
import { extname, resolve } from "path"
|
||||
import { pathToFileURL } from "node:url"
|
||||
import {
|
||||
createMessageConnection,
|
||||
StreamMessageReader,
|
||||
StreamMessageWriter,
|
||||
type MessageConnection,
|
||||
} from "vscode-jsonrpc/node"
|
||||
import { getLanguageId } from "./config"
|
||||
import type { Diagnostic, ResolvedServer } from "./types"
|
||||
|
||||
@@ -38,22 +45,18 @@ class LSPServerManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Works on all platforms
|
||||
process.on("exit", cleanup)
|
||||
|
||||
// Ctrl+C - works on all platforms
|
||||
process.on("SIGINT", () => {
|
||||
cleanup()
|
||||
process.exit(0)
|
||||
})
|
||||
|
||||
// Kill signal - Unix/macOS
|
||||
process.on("SIGTERM", () => {
|
||||
cleanup()
|
||||
process.exit(0)
|
||||
})
|
||||
|
||||
// Ctrl+Break - Windows specific
|
||||
if (process.platform === "win32") {
|
||||
process.on("SIGBREAK", () => {
|
||||
cleanup()
|
||||
@@ -209,13 +212,12 @@ export const lspManager = LSPServerManager.getInstance()
|
||||
|
||||
export class LSPClient {
|
||||
private proc: Subprocess<"pipe", "pipe", "pipe"> | null = null
|
||||
private buffer: Uint8Array = new Uint8Array(0)
|
||||
private pending = new Map<number, { resolve: (value: unknown) => void; reject: (error: Error) => void }>()
|
||||
private requestIdCounter = 0
|
||||
private connection: MessageConnection | null = null
|
||||
private openedFiles = new Set<string>()
|
||||
private stderrBuffer: string[] = []
|
||||
private processExited = false
|
||||
private diagnosticsStore = new Map<string, Diagnostic[]>()
|
||||
private readonly REQUEST_TIMEOUT = 15000
|
||||
|
||||
constructor(
|
||||
private root: string,
|
||||
@@ -238,7 +240,6 @@ export class LSPClient {
|
||||
throw new Error(`Failed to spawn LSP server: ${this.server.command.join(" ")}`)
|
||||
}
|
||||
|
||||
this.startReading()
|
||||
this.startStderrReading()
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100))
|
||||
@@ -249,33 +250,66 @@ export class LSPClient {
|
||||
`LSP server exited immediately with code ${this.proc.exitCode}` + (stderr ? `\nstderr: ${stderr}` : "")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private startReading(): void {
|
||||
if (!this.proc) return
|
||||
|
||||
const reader = this.proc.stdout.getReader()
|
||||
const read = async () => {
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
const stdoutReader = this.proc.stdout.getReader()
|
||||
const nodeReadable = new Readable({
|
||||
async read() {
|
||||
try {
|
||||
const { done, value } = await stdoutReader.read()
|
||||
if (done) {
|
||||
this.processExited = true
|
||||
this.rejectAllPending("LSP server stdout closed")
|
||||
break
|
||||
this.push(null)
|
||||
} else {
|
||||
this.push(Buffer.from(value))
|
||||
}
|
||||
const newBuf = new Uint8Array(this.buffer.length + value.length)
|
||||
newBuf.set(this.buffer)
|
||||
newBuf.set(value, this.buffer.length)
|
||||
this.buffer = newBuf
|
||||
this.processBuffer()
|
||||
} catch {
|
||||
this.push(null)
|
||||
}
|
||||
} catch (err) {
|
||||
this.processExited = true
|
||||
this.rejectAllPending(`LSP stdout read error: ${err}`)
|
||||
},
|
||||
})
|
||||
|
||||
const stdin = this.proc.stdin
|
||||
const nodeWritable = new Writable({
|
||||
write(chunk, _encoding, callback) {
|
||||
try {
|
||||
stdin.write(chunk)
|
||||
callback()
|
||||
} catch (err) {
|
||||
callback(err as Error)
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
this.connection = createMessageConnection(
|
||||
new StreamMessageReader(nodeReadable),
|
||||
new StreamMessageWriter(nodeWritable)
|
||||
)
|
||||
|
||||
this.connection.onNotification("textDocument/publishDiagnostics", (params: { uri?: string; diagnostics?: Diagnostic[] }) => {
|
||||
if (params.uri) {
|
||||
this.diagnosticsStore.set(params.uri, params.diagnostics ?? [])
|
||||
}
|
||||
}
|
||||
read()
|
||||
})
|
||||
|
||||
this.connection.onRequest("workspace/configuration", (params: { items?: Array<{ section?: string }> }) => {
|
||||
const items = params?.items ?? []
|
||||
return items.map((item) => {
|
||||
if (item.section === "json") return { validate: { enable: true } }
|
||||
return {}
|
||||
})
|
||||
})
|
||||
|
||||
this.connection.onRequest("client/registerCapability", () => null)
|
||||
this.connection.onRequest("window/workDoneProgress/create", () => null)
|
||||
|
||||
this.connection.onClose(() => {
|
||||
this.processExited = true
|
||||
})
|
||||
|
||||
this.connection.onError((error) => {
|
||||
console.error("LSP connection error:", error)
|
||||
})
|
||||
|
||||
this.connection.listen()
|
||||
}
|
||||
|
||||
private startStderrReading(): void {
|
||||
@@ -294,142 +328,49 @@ export class LSPClient {
|
||||
this.stderrBuffer.shift()
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
read()
|
||||
}
|
||||
|
||||
private rejectAllPending(reason: string): void {
|
||||
for (const [id, handler] of this.pending) {
|
||||
handler.reject(new Error(reason))
|
||||
this.pending.delete(id)
|
||||
}
|
||||
}
|
||||
private async sendRequest<T>(method: string, params?: unknown): Promise<T> {
|
||||
if (!this.connection) throw new Error("LSP client not started")
|
||||
|
||||
private findSequence(haystack: Uint8Array, needle: number[]): number {
|
||||
outer: for (let i = 0; i <= haystack.length - needle.length; i++) {
|
||||
for (let j = 0; j < needle.length; j++) {
|
||||
if (haystack[i + j] !== needle[j]) continue outer
|
||||
}
|
||||
return i
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
private processBuffer(): void {
|
||||
const decoder = new TextDecoder()
|
||||
const CONTENT_LENGTH = [67, 111, 110, 116, 101, 110, 116, 45, 76, 101, 110, 103, 116, 104, 58]
|
||||
const CRLF_CRLF = [13, 10, 13, 10]
|
||||
const LF_LF = [10, 10]
|
||||
|
||||
while (true) {
|
||||
const headerStart = this.findSequence(this.buffer, CONTENT_LENGTH)
|
||||
if (headerStart === -1) break
|
||||
if (headerStart > 0) this.buffer = this.buffer.slice(headerStart)
|
||||
|
||||
let headerEnd = this.findSequence(this.buffer, CRLF_CRLF)
|
||||
let sepLen = 4
|
||||
if (headerEnd === -1) {
|
||||
headerEnd = this.findSequence(this.buffer, LF_LF)
|
||||
sepLen = 2
|
||||
}
|
||||
if (headerEnd === -1) break
|
||||
|
||||
const header = decoder.decode(this.buffer.slice(0, headerEnd))
|
||||
const match = header.match(/Content-Length:\s*(\d+)/i)
|
||||
if (!match) break
|
||||
|
||||
const len = parseInt(match[1], 10)
|
||||
const start = headerEnd + sepLen
|
||||
const end = start + len
|
||||
if (this.buffer.length < end) break
|
||||
|
||||
const content = decoder.decode(this.buffer.slice(start, end))
|
||||
this.buffer = this.buffer.slice(end)
|
||||
|
||||
try {
|
||||
const msg = JSON.parse(content)
|
||||
|
||||
if ("method" in msg && !("id" in msg)) {
|
||||
if (msg.method === "textDocument/publishDiagnostics" && msg.params?.uri) {
|
||||
this.diagnosticsStore.set(msg.params.uri, msg.params.diagnostics ?? [])
|
||||
}
|
||||
} else if ("id" in msg && "method" in msg) {
|
||||
this.handleServerRequest(msg.id, msg.method, msg.params)
|
||||
} else if ("id" in msg && this.pending.has(msg.id)) {
|
||||
const handler = this.pending.get(msg.id)!
|
||||
this.pending.delete(msg.id)
|
||||
if ("error" in msg) {
|
||||
handler.reject(new Error(msg.error.message))
|
||||
} else {
|
||||
handler.resolve(msg.result)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private send(method: string, params?: unknown): Promise<unknown> {
|
||||
if (!this.proc) throw new Error("LSP client not started")
|
||||
|
||||
if (this.processExited || this.proc.exitCode !== null) {
|
||||
if (this.processExited || (this.proc && this.proc.exitCode !== null)) {
|
||||
const stderr = this.stderrBuffer.slice(-10).join("\n")
|
||||
throw new Error(`LSP server already exited (code: ${this.proc.exitCode})` + (stderr ? `\nstderr: ${stderr}` : ""))
|
||||
throw new Error(`LSP server already exited (code: ${this.proc?.exitCode})` + (stderr ? `\nstderr: ${stderr}` : ""))
|
||||
}
|
||||
|
||||
const id = ++this.requestIdCounter
|
||||
const msg = JSON.stringify({ jsonrpc: "2.0", id, method, params })
|
||||
const header = `Content-Length: ${Buffer.byteLength(msg)}\r\n\r\n`
|
||||
this.proc.stdin.write(header + msg)
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
this.pending.set(id, { resolve, reject })
|
||||
setTimeout(() => {
|
||||
if (this.pending.has(id)) {
|
||||
this.pending.delete(id)
|
||||
const stderr = this.stderrBuffer.slice(-5).join("\n")
|
||||
reject(new Error(`LSP request timeout (method: ${method})` + (stderr ? `\nrecent stderr: ${stderr}` : "")))
|
||||
}
|
||||
}, 15000)
|
||||
let timeoutId: ReturnType<typeof setTimeout>
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(() => {
|
||||
const stderr = this.stderrBuffer.slice(-5).join("\n")
|
||||
reject(new Error(`LSP request timeout (method: ${method})` + (stderr ? `\nrecent stderr: ${stderr}` : "")))
|
||||
}, this.REQUEST_TIMEOUT)
|
||||
})
|
||||
}
|
||||
|
||||
private notify(method: string, params?: unknown): void {
|
||||
if (!this.proc) return
|
||||
if (this.processExited || this.proc.exitCode !== null) return
|
||||
const requestPromise = this.connection.sendRequest(method, params) as Promise<T>
|
||||
|
||||
const msg = JSON.stringify({ jsonrpc: "2.0", method, params })
|
||||
this.proc.stdin.write(`Content-Length: ${Buffer.byteLength(msg)}\r\n\r\n${msg}`)
|
||||
}
|
||||
|
||||
private respond(id: number | string, result: unknown): void {
|
||||
if (!this.proc) return
|
||||
if (this.processExited || this.proc.exitCode !== null) return
|
||||
|
||||
const msg = JSON.stringify({ jsonrpc: "2.0", id, result })
|
||||
this.proc.stdin.write(`Content-Length: ${Buffer.byteLength(msg)}\r\n\r\n${msg}`)
|
||||
}
|
||||
|
||||
private handleServerRequest(id: number | string, method: string, params?: unknown): void {
|
||||
if (method === "workspace/configuration") {
|
||||
const items = (params as { items?: Array<{ section?: string }> })?.items ?? []
|
||||
const result = items.map((item) => {
|
||||
if (item.section === "json") return { validate: { enable: true } }
|
||||
return {}
|
||||
})
|
||||
this.respond(id, result)
|
||||
} else if (method === "client/registerCapability") {
|
||||
this.respond(id, null)
|
||||
} else if (method === "window/workDoneProgress/create") {
|
||||
this.respond(id, null)
|
||||
try {
|
||||
const result = await Promise.race([requestPromise, timeoutPromise])
|
||||
clearTimeout(timeoutId!)
|
||||
return result
|
||||
} catch (error) {
|
||||
clearTimeout(timeoutId!)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private sendNotification(method: string, params?: unknown): void {
|
||||
if (!this.connection) return
|
||||
if (this.processExited || (this.proc && this.proc.exitCode !== null)) return
|
||||
|
||||
this.connection.sendNotification(method, params)
|
||||
}
|
||||
|
||||
async initialize(): Promise<void> {
|
||||
const rootUri = pathToFileURL(this.root).href
|
||||
await this.send("initialize", {
|
||||
await this.sendRequest("initialize", {
|
||||
processId: process.pid,
|
||||
rootUri,
|
||||
rootPath: this.root,
|
||||
@@ -481,8 +422,8 @@ export class LSPClient {
|
||||
},
|
||||
...this.server.initialization,
|
||||
})
|
||||
this.notify("initialized")
|
||||
this.notify("workspace/didChangeConfiguration", {
|
||||
this.sendNotification("initialized")
|
||||
this.sendNotification("workspace/didChangeConfiguration", {
|
||||
settings: { json: { validate: { enable: true } } },
|
||||
})
|
||||
await new Promise((r) => setTimeout(r, 300))
|
||||
@@ -496,7 +437,7 @@ export class LSPClient {
|
||||
const ext = extname(absPath)
|
||||
const languageId = getLanguageId(ext)
|
||||
|
||||
this.notify("textDocument/didOpen", {
|
||||
this.sendNotification("textDocument/didOpen", {
|
||||
textDocument: {
|
||||
uri: pathToFileURL(absPath).href,
|
||||
languageId,
|
||||
@@ -512,7 +453,7 @@ export class LSPClient {
|
||||
async definition(filePath: string, line: number, character: number): Promise<unknown> {
|
||||
const absPath = resolve(filePath)
|
||||
await this.openFile(absPath)
|
||||
return this.send("textDocument/definition", {
|
||||
return this.sendRequest("textDocument/definition", {
|
||||
textDocument: { uri: pathToFileURL(absPath).href },
|
||||
position: { line: line - 1, character },
|
||||
})
|
||||
@@ -521,7 +462,7 @@ export class LSPClient {
|
||||
async references(filePath: string, line: number, character: number, includeDeclaration = true): Promise<unknown> {
|
||||
const absPath = resolve(filePath)
|
||||
await this.openFile(absPath)
|
||||
return this.send("textDocument/references", {
|
||||
return this.sendRequest("textDocument/references", {
|
||||
textDocument: { uri: pathToFileURL(absPath).href },
|
||||
position: { line: line - 1, character },
|
||||
context: { includeDeclaration },
|
||||
@@ -531,13 +472,13 @@ export class LSPClient {
|
||||
async documentSymbols(filePath: string): Promise<unknown> {
|
||||
const absPath = resolve(filePath)
|
||||
await this.openFile(absPath)
|
||||
return this.send("textDocument/documentSymbol", {
|
||||
return this.sendRequest("textDocument/documentSymbol", {
|
||||
textDocument: { uri: pathToFileURL(absPath).href },
|
||||
})
|
||||
}
|
||||
|
||||
async workspaceSymbols(query: string): Promise<unknown> {
|
||||
return this.send("workspace/symbol", { query })
|
||||
return this.sendRequest("workspace/symbol", { query })
|
||||
}
|
||||
|
||||
async diagnostics(filePath: string): Promise<{ items: Diagnostic[] }> {
|
||||
@@ -547,14 +488,13 @@ export class LSPClient {
|
||||
await new Promise((r) => setTimeout(r, 500))
|
||||
|
||||
try {
|
||||
const result = await this.send("textDocument/diagnostic", {
|
||||
const result = await this.sendRequest<{ items?: Diagnostic[] }>("textDocument/diagnostic", {
|
||||
textDocument: { uri },
|
||||
})
|
||||
if (result && typeof result === "object" && "items" in result) {
|
||||
return result as { items: Diagnostic[] }
|
||||
}
|
||||
} catch {
|
||||
}
|
||||
} catch {}
|
||||
|
||||
return { items: this.diagnosticsStore.get(uri) ?? [] }
|
||||
}
|
||||
@@ -562,7 +502,7 @@ export class LSPClient {
|
||||
async prepareRename(filePath: string, line: number, character: number): Promise<unknown> {
|
||||
const absPath = resolve(filePath)
|
||||
await this.openFile(absPath)
|
||||
return this.send("textDocument/prepareRename", {
|
||||
return this.sendRequest("textDocument/prepareRename", {
|
||||
textDocument: { uri: pathToFileURL(absPath).href },
|
||||
position: { line: line - 1, character },
|
||||
})
|
||||
@@ -571,7 +511,7 @@ export class LSPClient {
|
||||
async rename(filePath: string, line: number, character: number, newName: string): Promise<unknown> {
|
||||
const absPath = resolve(filePath)
|
||||
await this.openFile(absPath)
|
||||
return this.send("textDocument/rename", {
|
||||
return this.sendRequest("textDocument/rename", {
|
||||
textDocument: { uri: pathToFileURL(absPath).href },
|
||||
position: { line: line - 1, character },
|
||||
newName,
|
||||
@@ -583,10 +523,13 @@ export class LSPClient {
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
try {
|
||||
this.notify("shutdown", {})
|
||||
this.notify("exit")
|
||||
} catch {
|
||||
if (this.connection) {
|
||||
try {
|
||||
this.sendNotification("shutdown", {})
|
||||
this.sendNotification("exit")
|
||||
} catch {}
|
||||
this.connection.dispose()
|
||||
this.connection = null
|
||||
}
|
||||
this.proc?.kill()
|
||||
this.proc = null
|
||||
|
||||
5
src/types/request-info.d.ts
vendored
Normal file
5
src/types/request-info.d.ts
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
declare global {
|
||||
type RequestInfo = string | URL
|
||||
}
|
||||
|
||||
export {}
|
||||
Reference in New Issue
Block a user