mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream main
This commit is contained in:
@@ -24,6 +24,25 @@ jobs:
|
||||
with:
|
||||
version: v2.10.1
|
||||
|
||||
vuln_check:
|
||||
name: Security Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run Govulncheck
|
||||
uses: golang/govulncheck-action@v1
|
||||
with:
|
||||
go-package: ./...
|
||||
|
||||
test:
|
||||
name: Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -100,3 +100,11 @@ jobs:
|
||||
gh release edit "${{ inputs.tag }}" \
|
||||
--draft=${{ inputs.draft }} \
|
||||
--prerelease=${{ inputs.prerelease }}
|
||||
|
||||
upload-tos:
|
||||
name: Upload to TOS
|
||||
needs: release
|
||||
uses: ./.github/workflows/upload-tos.yml
|
||||
with:
|
||||
tag: ${{ inputs.tag }}
|
||||
secrets: inherit
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
name: Upload to Volcengine TOS
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Release tag to download and upload (e.g. v0.2.0)"
|
||||
required: true
|
||||
type: string
|
||||
workflow_call:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Release tag to download and upload"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
upload-tos:
|
||||
name: Upload to Volcengine TOS
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Download release assets
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
mkdir -p artifacts
|
||||
gh release download "${{ inputs.tag }}" \
|
||||
--repo "${{ github.repository }}" \
|
||||
--dir artifacts \
|
||||
--pattern "*.tar.gz" \
|
||||
--pattern "*.zip" \
|
||||
--pattern "*.rpm" \
|
||||
--pattern "*.deb"
|
||||
|
||||
- name: Upload to Volcengine TOS
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.VOLC_TOS_ACCESS_KEY }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.VOLC_TOS_SECRET_KEY }}
|
||||
AWS_DEFAULT_REGION: cn-beijing
|
||||
run: |
|
||||
aws configure set default.s3.addressing_style virtual
|
||||
TOS_ENDPOINT="https://tos-s3-cn-beijing.volces.com"
|
||||
# Upload to versioned directory
|
||||
aws s3 sync artifacts/ "s3://picoclaw-downloads/${{ inputs.tag }}/" \
|
||||
--endpoint-url "$TOS_ENDPOINT"
|
||||
# Upload to latest (overwrite)
|
||||
aws s3 sync artifacts/ "s3://picoclaw-downloads/latest/" \
|
||||
--endpoint-url "$TOS_ENDPOINT" \
|
||||
--delete
|
||||
@@ -19,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
---
|
||||
|
||||
PicoClaw is heavily inspired by and based on [nanobot](https://github.com/HKUDS/nanobot) by HKUDS.
|
||||
|
||||
@@ -216,7 +216,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d
|
||||
> [!TIP]
|
||||
> Set your API key in `~/.picoclaw/config.json`.
|
||||
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
|
||||
> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback.
|
||||
> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month), [SearXNG](https://github.com/searxng/searxng) (free, self-hosted) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback.
|
||||
|
||||
**1. Initialize**
|
||||
|
||||
@@ -265,6 +265,16 @@ picoclaw onboard
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_PERPLEXITY_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"searxng": {
|
||||
"enabled": false,
|
||||
"base_url": "http://your-searxng-instance:8888",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -277,7 +287,12 @@ picoclaw onboard
|
||||
**3. Get API Keys**
|
||||
|
||||
* **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
|
||||
* **Web Search** (optional): [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month) · [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month)
|
||||
* **Web Search** (optional):
|
||||
* [Brave Search](https://brave.com/search/api) - Paid ($5/1000 queries, ~$5-6/month)
|
||||
* [Perplexity](https://www.perplexity.ai) - AI-powered search with chat interface
|
||||
* [SearXNG](https://github.com/searxng/searxng) - Self-hosted metasearch engine (free, no API key needed)
|
||||
* [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month)
|
||||
* DuckDuckGo - Built-in fallback (no API key required)
|
||||
|
||||
> **Note**: See `config.example.json` for a complete configuration template.
|
||||
|
||||
@@ -1243,6 +1258,16 @@ picoclaw agent -m "Hello"
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "",
|
||||
"max_results": 5
|
||||
},
|
||||
"searxng": {
|
||||
"enabled": false,
|
||||
"base_url": "http://localhost:8888",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
@@ -1300,10 +1325,69 @@ discord: <https://discord.gg/V4sAZ9XWpN>
|
||||
|
||||
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
|
||||
|
||||
To enable web search:
|
||||
#### Search Provider Priority
|
||||
|
||||
1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results.
|
||||
2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required).
|
||||
PicoClaw automatically selects the best available search provider in this order:
|
||||
1. **Perplexity** (if enabled and API key configured) - AI-powered search with citations
|
||||
2. **Brave Search** (if enabled and API key configured) - Privacy-focused paid API ($5/1000 queries)
|
||||
3. **SearXNG** (if enabled and base_url configured) - Self-hosted metasearch aggregating 70+ engines (free)
|
||||
4. **DuckDuckGo** (if enabled, default fallback) - No API key required (free)
|
||||
|
||||
#### Web Search Configuration Options
|
||||
|
||||
**Option 1 (Best Results)**: Perplexity AI Search
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"perplexity": {
|
||||
"enabled": true,
|
||||
"api_key": "YOUR_PERPLEXITY_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Option 2 (Paid API)**: Get an API key at [https://brave.com/search/api](https://brave.com/search/api) ($5/1000 queries, ~$5-6/month)
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": true,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Option 3 (Self-Hosted)**: Deploy your own [SearXNG](https://github.com/searxng/searxng) instance
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"searxng": {
|
||||
"enabled": true,
|
||||
"base_url": "http://your-server:8888",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Benefits of SearXNG:
|
||||
- **Zero cost**: No API fees or rate limits
|
||||
- **Privacy-focused**: Self-hosted, no tracking
|
||||
- **Aggregate results**: Queries 70+ search engines simultaneously
|
||||
- **Perfect for cloud VMs**: Solves datacenter IP blocking issues (Oracle Cloud, GCP, AWS, Azure)
|
||||
- **No API key needed**: Just deploy and configure the base URL
|
||||
|
||||
**Option 4 (No Setup Required)**: DuckDuckGo is enabled by default as fallback (no API key needed)
|
||||
|
||||
Add the key to `~/.picoclaw/config.json` if using Brave:
|
||||
|
||||
@@ -1319,6 +1403,16 @@ Add the key to `~/.picoclaw/config.json` if using Brave:
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_PERPLEXITY_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"searxng": {
|
||||
"enabled": false,
|
||||
"base_url": "http://your-searxng-instance:8888",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1337,10 +1431,11 @@ This happens when another instance of the bot is running. Make sure only one `pi
|
||||
|
||||
## 📝 API Key Comparison
|
||||
|
||||
| Service | Free Tier | Use Case |
|
||||
| ---------------- | ------------------- | ------------------------------------- |
|
||||
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
|
||||
| **Zhipu** | 200K tokens/month | Best for Chinese users |
|
||||
| **Brave Search** | 2000 queries/month | Web search functionality |
|
||||
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
|
||||
| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) |
|
||||
| Service | Free Tier | Use Case |
|
||||
| ---------------- | ------------------------ | ------------------------------------- |
|
||||
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
|
||||
| **Zhipu** | 200K tokens/month | Best for Chinese users |
|
||||
| **Brave Search** | Paid ($5/1000 queries) | Web search functionality |
|
||||
| **SearXNG** | Unlimited (self-hosted) | Privacy-focused metasearch (70+ engines) |
|
||||
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
|
||||
| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) |
|
||||
|
||||
@@ -335,7 +335,11 @@ func (s *appState) testModel(model *picoclawconfig.ModelConfig) {
|
||||
s.showMessage("Test OK", resp.Status)
|
||||
return
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
if err != nil {
|
||||
s.showMessage("Test failed", fmt.Sprintf("failed to read response: %v", err))
|
||||
return
|
||||
}
|
||||
s.showMessage(
|
||||
"Test failed",
|
||||
fmt.Sprintf("%s: %s", resp.Status, strings.TrimSpace(string(body))),
|
||||
|
||||
@@ -297,7 +297,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading userinfo response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("userinfo request failed: %s", string(body))
|
||||
}
|
||||
|
||||
@@ -177,7 +177,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading userinfo response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("userinfo request failed: %s", string(body))
|
||||
}
|
||||
|
||||
@@ -230,19 +230,25 @@ func setupCronTool(
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool
|
||||
cronTool, err := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error during CronTool initialization: %v", err)
|
||||
// Create and register CronTool if enabled
|
||||
var cronTool *tools.CronTool
|
||||
if cfg.Tools.IsToolEnabled("cron") {
|
||||
var err error
|
||||
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error during CronTool initialization: %v", err)
|
||||
}
|
||||
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
}
|
||||
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
|
||||
// Set the onJob handler
|
||||
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||
result := cronTool.ExecuteJob(context.Background(), job)
|
||||
return result, nil
|
||||
})
|
||||
// Set onJob handler
|
||||
if cronTool != nil {
|
||||
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||
result := cronTool.ExecuteJob(context.Background(), job)
|
||||
return result, nil
|
||||
})
|
||||
}
|
||||
|
||||
return cronService
|
||||
}
|
||||
|
||||
@@ -18,12 +18,21 @@ var (
|
||||
goVersion string
|
||||
)
|
||||
|
||||
// GetPicoclawHome returns the picoclaw home directory.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
func GetPicoclawHome() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return home
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw")
|
||||
}
|
||||
|
||||
func GetConfigPath() string {
|
||||
if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" {
|
||||
return configPath
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
return filepath.Join(GetPicoclawHome(), "config.json")
|
||||
}
|
||||
|
||||
func LoadConfig() (*config.Config, error) {
|
||||
|
||||
@@ -19,6 +19,27 @@ func TestGetConfigPath(t *testing.T) {
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestGetConfigPath_WithPICOCLAW_HOME(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw")
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
got := GetConfigPath()
|
||||
want := filepath.Join("/custom/picoclaw", "config.json")
|
||||
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestGetConfigPath_WithPICOCLAW_CONFIG(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_CONFIG", "/custom/config.json")
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw")
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
got := GetConfigPath()
|
||||
want := "/custom/config.json"
|
||||
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestFormatVersion_NoGitCommit(t *testing.T) {
|
||||
oldVersion, oldGit := version, gitCommit
|
||||
t.Cleanup(func() { version, gitCommit = oldVersion, oldGit })
|
||||
|
||||
@@ -21,8 +21,8 @@ picoclaw skills install --registry clawhub github
|
||||
`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if registry != "" {
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("when --registry is set, exactly 2 arguments are required: <name> <slug>")
|
||||
if len(args) != 1 {
|
||||
return fmt.Errorf("when --registry is set, exactly 1 argument is required: <slug>")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -45,7 +45,7 @@ picoclaw skills install --registry clawhub github
|
||||
return err
|
||||
}
|
||||
|
||||
return skillsInstallFromRegistry(cfg, args[0], args[1])
|
||||
return skillsInstallFromRegistry(cfg, registry, args[0])
|
||||
}
|
||||
|
||||
return skillsInstallCmd(installer, args[0])
|
||||
|
||||
@@ -26,3 +26,72 @@ func TestNewInstallSubcommand(t *testing.T) {
|
||||
|
||||
assert.Len(t, cmd.Aliases, 0)
|
||||
}
|
||||
|
||||
func TestInstallCommandArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
registry string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "no registry, one arg",
|
||||
args: []string{"sipeed/picoclaw-skills/weather"},
|
||||
registry: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "no registry, no args",
|
||||
args: []string{},
|
||||
registry: "",
|
||||
expectError: true,
|
||||
errorMsg: "exactly 1 argument is required: <github>",
|
||||
},
|
||||
{
|
||||
name: "no registry, too many args",
|
||||
args: []string{"arg1", "arg2"},
|
||||
registry: "",
|
||||
expectError: true,
|
||||
errorMsg: "exactly 1 argument is required: <github>",
|
||||
},
|
||||
{
|
||||
name: "with registry, one arg",
|
||||
args: []string{"weather-skill"},
|
||||
registry: "clawhub",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "with registry, no args",
|
||||
args: []string{},
|
||||
registry: "clawhub",
|
||||
expectError: true,
|
||||
errorMsg: "when --registry is set, exactly 1 argument is required: <slug>",
|
||||
},
|
||||
{
|
||||
name: "with registry, too many args",
|
||||
args: []string{"arg1", "arg2"},
|
||||
registry: "clawhub",
|
||||
expectError: true,
|
||||
errorMsg: "when --registry is set, exactly 1 argument is required: <slug>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := newInstallCommand(nil)
|
||||
|
||||
if tt.registry != "" {
|
||||
require.NoError(t, cmd.Flags().Set("registry", tt.registry))
|
||||
}
|
||||
|
||||
err := cmd.Args(cmd, tt.args)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.errorMsg, err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@
|
||||
"model_name": "claude-sonnet-4.6",
|
||||
"model": "anthropic/claude-sonnet-4.6",
|
||||
"api_key": "sk-ant-your-key",
|
||||
"api_base": "https://api.anthropic.com/v1"
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"thinking_level": "high"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini",
|
||||
@@ -224,27 +225,53 @@
|
||||
"mistral": {
|
||||
"api_key": "",
|
||||
"api_base": "https://api.mistral.ai/v1"
|
||||
},
|
||||
"avian": {
|
||||
"api_key": "",
|
||||
"api_base": "https://api.avian.io/v1"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"allow_read_paths": null,
|
||||
"allow_write_paths": null,
|
||||
"web": {
|
||||
"enabled": true,
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"tavily": {
|
||||
"enabled": false,
|
||||
"api_key": "",
|
||||
"base_url": "",
|
||||
"max_results": 0
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "pplx-xxx",
|
||||
"api_key": "",
|
||||
"max_results": 5
|
||||
},
|
||||
"proxy": ""
|
||||
"searxng": {
|
||||
"enabled": false,
|
||||
"base_url": "http://localhost:8888",
|
||||
"max_results": 5
|
||||
},
|
||||
"glm_search": {
|
||||
"enabled": false,
|
||||
"api_key": "",
|
||||
"base_url": "https://open.bigmodel.cn/api/paas/v4/web_search",
|
||||
"search_engine": "search_std",
|
||||
"max_results": 5
|
||||
},
|
||||
"fetch_limit_bytes": 10485760
|
||||
},
|
||||
"cron": {
|
||||
"enabled": true,
|
||||
"exec_timeout_minutes": 5
|
||||
},
|
||||
"mcp": {
|
||||
@@ -313,19 +340,75 @@
|
||||
}
|
||||
},
|
||||
"exec": {
|
||||
"enable_deny_patterns": false,
|
||||
"custom_deny_patterns": []
|
||||
"enabled": true,
|
||||
"enable_deny_patterns": true,
|
||||
"custom_deny_patterns": null,
|
||||
"custom_allow_patterns": null
|
||||
},
|
||||
"skills": {
|
||||
"enabled": true,
|
||||
"registries": {
|
||||
"clawhub": {
|
||||
"enabled": true,
|
||||
"base_url": "https://clawhub.ai",
|
||||
"search_path": "/api/v1/search",
|
||||
"skills_path": "/api/v1/skills",
|
||||
"download_path": "/api/v1/download"
|
||||
"auth_token": "",
|
||||
"search_path": "",
|
||||
"skills_path": "",
|
||||
"download_path": "",
|
||||
"timeout": 0,
|
||||
"max_zip_size": 0,
|
||||
"max_response_size": 0
|
||||
}
|
||||
},
|
||||
"max_concurrent_searches": 2,
|
||||
"search_cache": {
|
||||
"max_size": 50,
|
||||
"ttl_seconds": 300
|
||||
}
|
||||
},
|
||||
"media_cleanup": {
|
||||
"enabled": true,
|
||||
"max_age_minutes": 30,
|
||||
"interval_minutes": 5
|
||||
},
|
||||
"append_file": {
|
||||
"enabled": true
|
||||
},
|
||||
"edit_file": {
|
||||
"enabled": true
|
||||
},
|
||||
"find_skills": {
|
||||
"enabled": true
|
||||
},
|
||||
"i2c": {
|
||||
"enabled": false
|
||||
},
|
||||
"install_skill": {
|
||||
"enabled": true
|
||||
},
|
||||
"list_dir": {
|
||||
"enabled": true
|
||||
},
|
||||
"message": {
|
||||
"enabled": true
|
||||
},
|
||||
"read_file": {
|
||||
"enabled": true
|
||||
},
|
||||
"spawn": {
|
||||
"enabled": true
|
||||
},
|
||||
"spi": {
|
||||
"enabled": false
|
||||
},
|
||||
"subagent": {
|
||||
"enabled": true
|
||||
},
|
||||
"web_fetch": {
|
||||
"enabled": true
|
||||
},
|
||||
"write_file": {
|
||||
"enabled": true
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
@@ -180,6 +180,7 @@ The skills tool configures skill discovery and installation via registries like
|
||||
| ---------------------------------- | ------ | -------------------- | ----------------------- |
|
||||
| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
|
||||
| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
|
||||
| `registries.clawhub.auth_token` | string | `""` | Optional Bearer token for higher rate limits |
|
||||
| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
|
||||
| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
|
||||
| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
|
||||
@@ -194,6 +195,7 @@ The skills tool configures skill discovery and installation via registries like
|
||||
"clawhub": {
|
||||
"enabled": true,
|
||||
"base_url": "https://clawhub.ai",
|
||||
"auth_token": "",
|
||||
"search_path": "/api/v1/search",
|
||||
"skills_path": "/api/v1/skills",
|
||||
"download_path": "/api/v1/download"
|
||||
|
||||
@@ -37,7 +37,6 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/gdamore/tcell/v2 v2.13.8 // indirect
|
||||
github.com/h2non/filetype v1.1.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
|
||||
@@ -42,6 +42,9 @@ type ContextBuilder struct {
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
|
||||
+63
-12
@@ -26,6 +26,7 @@ type AgentInstance struct {
|
||||
MaxIterations int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
ThinkingLevel ThinkingLevel
|
||||
ContextWindow int
|
||||
SummarizeMessageThreshold int
|
||||
SummarizeTokenPercent int
|
||||
@@ -36,6 +37,14 @@ type AgentInstance struct {
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
|
||||
// Router is non-nil when model routing is configured and the light model
|
||||
// was successfully resolved. It scores each incoming message and decides
|
||||
// whether to route to LightCandidates or stay with Candidates.
|
||||
Router *routing.Router
|
||||
// LightCandidates holds the resolved provider candidates for the light model.
|
||||
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
|
||||
LightCandidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -59,17 +68,30 @@ func NewAgentInstance(
|
||||
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
|
||||
if cfg.Tools.IsToolEnabled("read_file") {
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("write_file") {
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("list_dir") {
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("exec") {
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
}
|
||||
|
||||
if cfg.Tools.IsToolEnabled("edit_file") {
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("append_file") {
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
|
||||
}
|
||||
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
@@ -103,6 +125,12 @@ func NewAgentInstance(
|
||||
temperature = *defaults.Temperature
|
||||
}
|
||||
|
||||
var thinkingLevelStr string
|
||||
if mc, err := cfg.GetModelConfig(model); err == nil {
|
||||
thinkingLevelStr = mc.ThinkingLevel
|
||||
}
|
||||
thinkingLevel := parseThinkingLevel(thinkingLevelStr)
|
||||
|
||||
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
|
||||
if summarizeMessageThreshold == 0 {
|
||||
summarizeMessageThreshold = 20
|
||||
@@ -160,6 +188,25 @@ func NewAgentInstance(
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
|
||||
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
} else {
|
||||
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
|
||||
rc.LightModel, agentID)
|
||||
}
|
||||
}
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
@@ -169,6 +216,7 @@ func NewAgentInstance(
|
||||
MaxIterations: maxIter,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
ContextWindow: maxTokens,
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
@@ -179,6 +227,8 @@ func NewAgentInstance(
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,12 +237,13 @@ func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentD
|
||||
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
|
||||
return expandHome(strings.TrimSpace(agentCfg.Workspace))
|
||||
}
|
||||
// Use the configured default workspace (respects PICOCLAW_HOME)
|
||||
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
|
||||
return expandHome(defaults.Workspace)
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
// For named agents without explicit workspace, use default workspace with agent ID suffix
|
||||
id := routing.NormalizeAgentID(agentCfg.ID)
|
||||
return filepath.Join(home, ".picoclaw", "workspace-"+id)
|
||||
return filepath.Join(expandHome(defaults.Workspace), "..", "workspace-"+id)
|
||||
}
|
||||
|
||||
// resolveAgentModel resolves the primary model for an agent.
|
||||
|
||||
+164
-110
@@ -108,76 +108,106 @@ func registerSharedTools(
|
||||
}
|
||||
|
||||
// Web tools
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
|
||||
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
|
||||
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
|
||||
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
|
||||
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
|
||||
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
|
||||
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
|
||||
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
if cfg.Tools.IsToolEnabled("web") {
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
|
||||
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
|
||||
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
|
||||
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
|
||||
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
|
||||
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
|
||||
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
|
||||
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
|
||||
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
|
||||
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
|
||||
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
}
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
if cfg.Tools.IsToolEnabled("web_fetch") {
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
}
|
||||
}
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
if cfg.Tools.IsToolEnabled("i2c") {
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("spi") {
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
}
|
||||
|
||||
// Message tool
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
if cfg.Tools.IsToolEnabled("message") {
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
})
|
||||
})
|
||||
})
|
||||
agent.Tools.Register(messageTool)
|
||||
agent.Tools.Register(messageTool)
|
||||
}
|
||||
|
||||
// Skill discovery and installation tools
|
||||
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
|
||||
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
|
||||
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
|
||||
})
|
||||
searchCache := skills.NewSearchCache(
|
||||
cfg.Tools.Skills.SearchCache.MaxSize,
|
||||
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
|
||||
)
|
||||
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
|
||||
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
|
||||
skills_enabled := cfg.Tools.IsToolEnabled("skills")
|
||||
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
|
||||
install_skills_enable := cfg.Tools.IsToolEnabled("install_skill")
|
||||
if skills_enabled && (find_skills_enable || install_skills_enable) {
|
||||
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
|
||||
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
|
||||
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
|
||||
})
|
||||
|
||||
if find_skills_enable {
|
||||
searchCache := skills.NewSearchCache(
|
||||
cfg.Tools.Skills.SearchCache.MaxSize,
|
||||
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
|
||||
)
|
||||
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
|
||||
}
|
||||
|
||||
if install_skills_enable {
|
||||
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
|
||||
})
|
||||
agent.Tools.Register(spawnTool)
|
||||
if cfg.Tools.IsToolEnabled("spawn") {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
|
||||
})
|
||||
agent.Tools.Register(spawnTool)
|
||||
} else {
|
||||
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,7 +215,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
// Initialize MCP servers for all agents
|
||||
if al.cfg.Tools.MCP.Enabled {
|
||||
if al.cfg.Tools.IsToolEnabled("mcp") {
|
||||
mcpManager := mcp.NewManager()
|
||||
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
|
||||
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
|
||||
@@ -227,6 +257,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
totalRegistrations++
|
||||
@@ -543,8 +574,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
|
||||
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(msg.Channel, msg.ChatID)
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
|
||||
resetter.ResetSentInRound()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -659,10 +690,7 @@ func (al *AgentLoop) runAgentLoop(
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Update tool contexts
|
||||
al.updateToolContexts(agent, opts.Channel, opts.ChatID)
|
||||
|
||||
// 2. Build messages (skip history for heartbeat)
|
||||
// 1. Build messages (skip history for heartbeat)
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !opts.NoHistory {
|
||||
@@ -682,10 +710,10 @@ func (al *AgentLoop) runAgentLoop(
|
||||
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 3. Save user message to session
|
||||
// 2. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
// 4. Run LLM iteration loop
|
||||
// 3. Run LLM iteration loop
|
||||
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -694,21 +722,21 @@ func (al *AgentLoop) runAgentLoop(
|
||||
// If last tool had ForUser content and we already sent it, we might not need to send final response
|
||||
// This is controlled by the tool's Silent flag and ForUser content
|
||||
|
||||
// 5. Handle empty response
|
||||
// 4. Handle empty response
|
||||
if finalContent == "" {
|
||||
finalContent = opts.DefaultResponse
|
||||
}
|
||||
|
||||
// 6. Save final assistant message to session
|
||||
// 5. Save final assistant message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||
agent.Sessions.Save(opts.SessionKey)
|
||||
|
||||
// 7. Optional: summarization
|
||||
// 6. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
// 7. Optional: send response via bus
|
||||
if opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
@@ -717,7 +745,7 @@ func (al *AgentLoop) runAgentLoop(
|
||||
})
|
||||
}
|
||||
|
||||
// 9. Log response
|
||||
// 8. Log response
|
||||
responsePreview := utils.Truncate(finalContent, 120)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
|
||||
map[string]any{
|
||||
@@ -796,6 +824,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
// all tool-follow-up iterations within the same turn so that a multi-step
|
||||
// tool chain doesn't switch models mid-way through.
|
||||
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
|
||||
|
||||
for iteration < agent.MaxIterations {
|
||||
iteration++
|
||||
|
||||
@@ -814,7 +848,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": agent.Model,
|
||||
"model": activeModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": agent.MaxTokens,
|
||||
@@ -830,27 +864,33 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
// Call LLM with fallback chain if candidates are configured.
|
||||
// Call LLM with fallback chain if multiple candidates are configured.
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
llmOpts := map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
}
|
||||
// parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
|
||||
// so checking != ThinkingOff is sufficient.
|
||||
if agent.ThinkingLevel != ThinkingOff {
|
||||
if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
|
||||
llmOpts["thinking_level"] = string(agent.ThinkingLevel)
|
||||
} else {
|
||||
logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring",
|
||||
map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)})
|
||||
}
|
||||
}
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
messages,
|
||||
providerToolDefs,
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
@@ -866,11 +906,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
})
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
|
||||
}
|
||||
|
||||
// Retry loop for context/token errors
|
||||
@@ -1057,7 +1093,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
// Create async callback for tools that implement AsyncExecutor
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
@@ -1139,24 +1175,42 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return finalContent, iteration, nil
|
||||
}
|
||||
|
||||
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||
func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) {
|
||||
// Use ContextualTool interface instead of type assertions
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(channel, chatID)
|
||||
}
|
||||
// selectCandidates returns the model candidates and resolved model name to use
|
||||
// for a conversation turn. When model routing is configured and the incoming
|
||||
// message scores below the complexity threshold, it returns the light model
|
||||
// candidates instead of the primary ones.
|
||||
//
|
||||
// The returned (candidates, model) pair is used for all LLM calls within one
|
||||
// turn — tool follow-up iterations use the same tier as the initial call so
|
||||
// that a multi-step tool chain doesn't switch models mid-way.
|
||||
func (al *AgentLoop) selectCandidates(
|
||||
agent *AgentInstance,
|
||||
userMsg string,
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
if tool, ok := agent.Tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := agent.Tools.Get("subagent"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
if !usedLight {
|
||||
logger.DebugCF("agent", "Model routing: primary model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"light_model": agent.Router.LightModel(),
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, agent.Router.LightModel()
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
|
||||
+16
-65
@@ -164,35 +164,21 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolContext_Updates verifies tool context is updated with channel/chatID
|
||||
// TestToolContext_Updates verifies tool context helpers work correctly
|
||||
func TestToolContext_Updates(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
ctx := tools.WithToolContext(context.Background(), "telegram", "chat-42")
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
if got := tools.ToolChannel(ctx); got != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", got)
|
||||
}
|
||||
if got := tools.ToolChatID(ctx); got != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", got)
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "OK"}
|
||||
_ = NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Verify that ContextualTool interface is defined and can be implemented
|
||||
// This test validates the interface contract exists
|
||||
ctxTool := &mockContextualTool{}
|
||||
|
||||
// Verify the tool implements the interface correctly
|
||||
var _ tools.ContextualTool = ctxTool
|
||||
// Empty context returns empty strings
|
||||
if got := tools.ToolChannel(context.Background()); got != "" {
|
||||
t.Errorf("expected empty channel from bare context, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
|
||||
@@ -241,16 +227,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) {
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = tmpDir
|
||||
cfg.Agents.Defaults.Model = "test-model"
|
||||
cfg.Agents.Defaults.MaxTokens = 4096
|
||||
cfg.Agents.Defaults.MaxToolIterations = 10
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
@@ -359,36 +340,6 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
// mockContextualTool tracks context updates
|
||||
type mockContextualTool struct {
|
||||
lastChannel string
|
||||
lastChatID string
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Name() string {
|
||||
return "mock_contextual"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Description() string {
|
||||
return "Mock contextual tool"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.SilentResult("Contextual tool executed")
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) SetContext(channel, chatID string) {
|
||||
m.lastChannel = channel
|
||||
m.lastChatID = chatID
|
||||
}
|
||||
|
||||
// testHelper executes a message and returns the response
|
||||
type testHelper struct {
|
||||
al *AgentLoop
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package agent
|
||||
|
||||
import "strings"
|
||||
|
||||
// ThinkingLevel controls how the provider sends thinking parameters.
|
||||
//
|
||||
// - "adaptive": sends {thinking: {type: "adaptive"}} + output_config.effort (Claude 4.6+)
|
||||
// - "low"/"medium"/"high"/"xhigh": sends {thinking: {type: "enabled", budget_tokens: N}} (all models)
|
||||
// - "off": disables thinking
|
||||
type ThinkingLevel string
|
||||
|
||||
const (
|
||||
ThinkingOff ThinkingLevel = "off"
|
||||
ThinkingLow ThinkingLevel = "low"
|
||||
ThinkingMedium ThinkingLevel = "medium"
|
||||
ThinkingHigh ThinkingLevel = "high"
|
||||
ThinkingXHigh ThinkingLevel = "xhigh"
|
||||
ThinkingAdaptive ThinkingLevel = "adaptive"
|
||||
)
|
||||
|
||||
// parseThinkingLevel normalizes a config string to a ThinkingLevel.
|
||||
// Case-insensitive and whitespace-tolerant for user-facing config values.
|
||||
// Returns ThinkingOff for unknown or empty values.
|
||||
func parseThinkingLevel(level string) ThinkingLevel {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "adaptive":
|
||||
return ThinkingAdaptive
|
||||
case "low":
|
||||
return ThinkingLow
|
||||
case "medium":
|
||||
return ThinkingMedium
|
||||
case "high":
|
||||
return ThinkingHigh
|
||||
case "xhigh":
|
||||
return ThinkingXHigh
|
||||
default:
|
||||
return ThinkingOff
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseThinkingLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want ThinkingLevel
|
||||
}{
|
||||
{"off", "off", ThinkingOff},
|
||||
{"empty", "", ThinkingOff},
|
||||
{"low", "low", ThinkingLow},
|
||||
{"medium", "medium", ThinkingMedium},
|
||||
{"high", "high", ThinkingHigh},
|
||||
{"xhigh", "xhigh", ThinkingXHigh},
|
||||
{"adaptive", "adaptive", ThinkingAdaptive},
|
||||
{"unknown", "unknown", ThinkingOff},
|
||||
// Case-insensitive and whitespace-tolerant
|
||||
{"upper_Medium", "Medium", ThinkingMedium},
|
||||
{"upper_HIGH", "HIGH", ThinkingHigh},
|
||||
{"mixed_Adaptive", "Adaptive", ThinkingAdaptive},
|
||||
{"leading_space", " high", ThinkingHigh},
|
||||
{"trailing_space", "low ", ThinkingLow},
|
||||
{"both_spaces", " medium ", ThinkingMedium},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseThinkingLevel(tt.input); got != tt.want {
|
||||
t.Errorf("parseThinkingLevel(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+20
-5
@@ -212,7 +212,10 @@ func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading device code response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||
}
|
||||
@@ -300,7 +303,10 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading device code response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||
}
|
||||
@@ -360,7 +366,10 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au
|
||||
return nil, fmt.Errorf("pending")
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading device token response: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AuthorizationCode string `json:"authorization_code"`
|
||||
@@ -401,7 +410,10 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading token refresh response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
@@ -494,7 +506,10 @@ func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading token exchange response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
@@ -39,6 +39,9 @@ func (c *AuthCredential) NeedsRefresh() bool {
|
||||
}
|
||||
|
||||
func authFilePath() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return filepath.Join(home, "auth.json")
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw", "auth.json")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -26,6 +27,12 @@ const (
|
||||
sendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call)
|
||||
channelRefRe = regexp.MustCompile(`<#(\d+)>`)
|
||||
msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`)
|
||||
)
|
||||
|
||||
type DiscordChannel struct {
|
||||
*channels.BaseChannel
|
||||
session *discordgo.Session
|
||||
@@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
content = c.stripBotMention(content)
|
||||
}
|
||||
|
||||
// Resolve Discord refs in main content before concatenation to avoid
|
||||
// double-expanding links that appear in the referenced message.
|
||||
content = c.resolveDiscordRefs(s, content, m.GuildID)
|
||||
|
||||
// Prepend referenced (quoted) message content if this is a reply
|
||||
if m.MessageReference != nil && m.ReferencedMessage != nil {
|
||||
refContent := m.ReferencedMessage.Content
|
||||
if refContent != "" {
|
||||
refAuthor := "unknown"
|
||||
if m.ReferencedMessage.Author != nil {
|
||||
refAuthor = m.ReferencedMessage.Author.Username
|
||||
}
|
||||
refContent = c.resolveDiscordRefs(s, refContent, m.GuildID)
|
||||
content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s",
|
||||
refAuthor, refContent, content)
|
||||
}
|
||||
}
|
||||
|
||||
senderID := m.Author.ID
|
||||
|
||||
mediaPaths := make([]string, 0, len(m.Attachments))
|
||||
@@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and
|
||||
// expands Discord message links to show the linked message content.
|
||||
// Only links pointing to the same guild are expanded to prevent cross-guild leakage.
|
||||
func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string {
|
||||
// 1. Resolve channel references: <#id> → #channel-name
|
||||
text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string {
|
||||
parts := channelRefRe.FindStringSubmatch(match)
|
||||
if len(parts) < 2 {
|
||||
return match
|
||||
}
|
||||
// Prefer session state cache to avoid API calls
|
||||
if ch, err := s.State.Channel(parts[1]); err == nil {
|
||||
return "#" + ch.Name
|
||||
}
|
||||
if ch, err := s.Channel(parts[1]); err == nil {
|
||||
return "#" + ch.Name
|
||||
}
|
||||
return match
|
||||
})
|
||||
|
||||
// 2. Expand Discord message links (max 3, same guild only)
|
||||
matches := msgLinkRe.FindAllStringSubmatch(text, 3)
|
||||
for _, m := range matches {
|
||||
if len(m) < 4 {
|
||||
continue
|
||||
}
|
||||
linkGuildID, channelID, messageID := m[1], m[2], m[3]
|
||||
// Security: only expand links from the same guild
|
||||
if linkGuildID != guildID {
|
||||
continue
|
||||
}
|
||||
msg, err := s.ChannelMessage(channelID, messageID)
|
||||
if err != nil || msg == nil || msg.Content == "" {
|
||||
continue
|
||||
}
|
||||
author := "unknown"
|
||||
if msg.Author != nil {
|
||||
author = msg.Author.Username
|
||||
}
|
||||
text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content)
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// stripBotMention removes the bot mention from the message content.
|
||||
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
|
||||
func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChannelRefRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantID string
|
||||
wantOK bool
|
||||
}{
|
||||
{"basic channel ref", "<#123456789>", "123456789", true},
|
||||
{"long id", "<#9876543210123456>", "9876543210123456", true},
|
||||
{"no match plain text", "hello world", "", false},
|
||||
{"no match partial", "<#>", "", false},
|
||||
{"no match letters", "<#abc>", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := channelRefRe.FindStringSubmatch(tt.input)
|
||||
if tt.wantOK {
|
||||
if len(matches) < 2 || matches[1] != tt.wantID {
|
||||
t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID)
|
||||
}
|
||||
} else {
|
||||
if len(matches) >= 2 {
|
||||
t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgLinkRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantGuild string
|
||||
wantChan string
|
||||
wantMsg string
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
"discord.com link",
|
||||
"https://discord.com/channels/111/222/333",
|
||||
"111", "222", "333", true,
|
||||
},
|
||||
{
|
||||
"discordapp.com link",
|
||||
"https://discordapp.com/channels/111/222/333",
|
||||
"111", "222", "333", true,
|
||||
},
|
||||
{
|
||||
"real world ids",
|
||||
"check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please",
|
||||
"9000000000000001", "9000000000000002", "9000000000000003", true,
|
||||
},
|
||||
{"no match http", "http://discord.com/channels/1/2/3", "", "", "", false},
|
||||
{"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false},
|
||||
{"no match plain text", "hello world", "", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := msgLinkRe.FindStringSubmatch(tt.input)
|
||||
if tt.wantOK {
|
||||
if len(matches) < 4 {
|
||||
t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s",
|
||||
tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg)
|
||||
}
|
||||
if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg {
|
||||
t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s",
|
||||
tt.input, matches[1], matches[2], matches[3],
|
||||
tt.wantGuild, tt.wantChan, tt.wantMsg)
|
||||
}
|
||||
} else {
|
||||
if len(matches) >= 4 {
|
||||
t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgLinkRegex_MultipleMatches(t *testing.T) {
|
||||
input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12"
|
||||
matches := msgLinkRe.FindAllStringSubmatch(input, 3)
|
||||
if len(matches) != 3 {
|
||||
t.Fatalf("expected 3 matches (capped), got %d", len(matches))
|
||||
}
|
||||
// Verify the 3rd match is 7/8/9 (not 10/11/12)
|
||||
if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" {
|
||||
t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2])
|
||||
}
|
||||
}
|
||||
@@ -654,7 +654,10 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("reading LINE API error response: %w", err))
|
||||
}
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody)))
|
||||
}
|
||||
|
||||
|
||||
@@ -793,7 +793,10 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err)
|
||||
}
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusTooManyRequests:
|
||||
return fmt.Errorf("response_url rate limited (%d): %s: %w",
|
||||
|
||||
@@ -321,8 +321,17 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody)))
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom upload error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom upload error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
@@ -371,8 +380,17 @@ func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken stri
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody)))
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom_app error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom_app API error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -453,8 +453,17 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body)))
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading webhook error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("webhook API error: %s", string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
||||
+125
-36
@@ -167,22 +167,35 @@ type SessionConfig struct {
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
// RoutingConfig controls the intelligent model routing feature.
|
||||
// When enabled, each incoming message is scored against structural features
|
||||
// (message length, code blocks, tool call history, conversation depth, attachments).
|
||||
// Messages scoring below Threshold are sent to LightModel; all others use the
|
||||
// agent's primary model. This reduces cost and latency for simple tasks without
|
||||
// requiring any keyword matching — all scoring is language-agnostic.
|
||||
type RoutingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
|
||||
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
@@ -432,6 +445,7 @@ type ProvidersConfig struct {
|
||||
Antigravity ProviderConfig `json:"antigravity"`
|
||||
Qwen ProviderConfig `json:"qwen"`
|
||||
Mistral ProviderConfig `json:"mistral"`
|
||||
Avian ProviderConfig `json:"avian"`
|
||||
}
|
||||
|
||||
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
|
||||
@@ -456,7 +470,8 @@ func (p ProvidersConfig) IsEmpty() bool {
|
||||
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
|
||||
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
|
||||
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
|
||||
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
|
||||
p.Avian.APIKey == "" && p.Avian.APIBase == ""
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
|
||||
@@ -507,6 +522,7 @@ type ModelConfig struct {
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
}
|
||||
|
||||
// Validate checks if the ModelConfig has all required fields.
|
||||
@@ -525,6 +541,10 @@ type GatewayConfig struct {
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
}
|
||||
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
@@ -549,6 +569,12 @@ type PerplexityConfig struct {
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type SearXNGConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_SEARXNG_ENABLED"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_SEARXNG_BASE_URL"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARXNG_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type GLMSearchConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
|
||||
@@ -560,11 +586,13 @@ type GLMSearchConfig struct {
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
Brave BraveConfig `json:"brave"`
|
||||
Tavily TavilyConfig `json:"tavily"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
GLMSearch GLMSearchConfig `json:"glm_search"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"`
|
||||
Brave BraveConfig ` json:"brave"`
|
||||
Tavily TavilyConfig ` json:"tavily"`
|
||||
DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig ` json:"perplexity"`
|
||||
SearXNG SearXNGConfig ` json:"searxng"`
|
||||
GLMSearch GLMSearchConfig ` json:"glm_search"`
|
||||
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
|
||||
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
@@ -572,19 +600,29 @@ type WebToolsConfig struct {
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
}
|
||||
|
||||
type ExecConfig struct {
|
||||
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
|
||||
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
|
||||
CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"`
|
||||
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
|
||||
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
|
||||
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
|
||||
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"`
|
||||
Registries SkillsRegistriesConfig ` json:"registries"`
|
||||
MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"`
|
||||
SearchCache SearchCacheConfig ` json:"search_cache"`
|
||||
}
|
||||
|
||||
type MediaCleanupConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"`
|
||||
MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"`
|
||||
Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"`
|
||||
MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"`
|
||||
Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"`
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
@@ -596,12 +634,19 @@ type ToolsConfig struct {
|
||||
Skills SkillsToolsConfig `json:"skills"`
|
||||
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
Registries SkillsRegistriesConfig `json:"registries"`
|
||||
MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
|
||||
SearchCache SearchCacheConfig `json:"search_cache"`
|
||||
AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"`
|
||||
EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"`
|
||||
FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"`
|
||||
I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"`
|
||||
InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
|
||||
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
|
||||
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
|
||||
WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"`
|
||||
}
|
||||
|
||||
type SearchCacheConfig struct {
|
||||
@@ -647,8 +692,7 @@ type MCPServerConfig struct {
|
||||
|
||||
// MCPConfig defines configuration for all MCP servers
|
||||
type MCPConfig struct {
|
||||
// Enabled globally enables/disables MCP integration
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
@@ -854,3 +898,48 @@ func (c *Config) ValidateModelList() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
switch name {
|
||||
case "web":
|
||||
return t.Web.Enabled
|
||||
case "cron":
|
||||
return t.Cron.Enabled
|
||||
case "exec":
|
||||
return t.Exec.Enabled
|
||||
case "skills":
|
||||
return t.Skills.Enabled
|
||||
case "media_cleanup":
|
||||
return t.MediaCleanup.Enabled
|
||||
case "append_file":
|
||||
return t.AppendFile.Enabled
|
||||
case "edit_file":
|
||||
return t.EditFile.Enabled
|
||||
case "find_skills":
|
||||
return t.FindSkills.Enabled
|
||||
case "i2c":
|
||||
return t.I2C.Enabled
|
||||
case "install_skill":
|
||||
return t.InstallSkill.Enabled
|
||||
case "list_dir":
|
||||
return t.ListDir.Enabled
|
||||
case "message":
|
||||
return t.Message.Enabled
|
||||
case "read_file":
|
||||
return t.ReadFile.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spi":
|
||||
return t.SPI.Enabled
|
||||
case "subagent":
|
||||
return t.Subagent.Enabled
|
||||
case "web_fetch":
|
||||
return t.WebFetch.Enabled
|
||||
case "write_file":
|
||||
return t.WriteFile.Enabled
|
||||
case "mcp":
|
||||
return t.MCP.Enabled
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
+77
-2
@@ -316,6 +316,20 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Avian - https://avian.io
|
||||
{
|
||||
ModelName: "deepseek-v3.2",
|
||||
Model: "avian/deepseek/deepseek-v3.2",
|
||||
APIBase: "https://api.avian.io/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
{
|
||||
ModelName: "kimi-k2.5",
|
||||
Model: "avian/moonshotai/kimi-k2.5",
|
||||
APIBase: "https://api.avian.io/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// VLLM (local) - http://localhost:8000
|
||||
{
|
||||
ModelName: "local-model",
|
||||
@@ -330,11 +344,16 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
MediaCleanup: MediaCleanupConfig{
|
||||
Enabled: true,
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
MaxAge: 30,
|
||||
Interval: 5,
|
||||
},
|
||||
Web: WebToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Proxy: "",
|
||||
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
|
||||
Brave: BraveConfig{
|
||||
@@ -351,6 +370,11 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
SearXNG: SearXNGConfig{
|
||||
Enabled: false,
|
||||
BaseURL: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
GLMSearch: GLMSearchConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
@@ -360,12 +384,22 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
ExecTimeoutMinutes: 5,
|
||||
},
|
||||
Exec: ExecConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
EnableDenyPatterns: true,
|
||||
TimeoutSeconds: 60,
|
||||
},
|
||||
Skills: SkillsToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Registries: SkillsRegistriesConfig{
|
||||
ClawHub: ClawHubRegistryConfig{
|
||||
Enabled: true,
|
||||
@@ -379,9 +413,50 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: false,
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
AppendFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
EditFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
FindSkills: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
I2C: ToolConfig{
|
||||
Enabled: false, // Hardware tool - Linux only
|
||||
},
|
||||
InstallSkill: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
ListDir: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Message: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
ReadFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
SPI: ToolConfig{
|
||||
Enabled: false, // Hardware tool - Linux only
|
||||
},
|
||||
Subagent: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
WebFetch: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
WriteFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -390,6 +390,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"avian"},
|
||||
protocol: "avian",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.Avian.APIKey == "" && p.Avian.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "avian",
|
||||
Model: "avian/deepseek/deepseek-v3.2",
|
||||
APIKey: p.Avian.APIKey,
|
||||
APIBase: p.Avian.APIBase,
|
||||
Proxy: p.Avian.Proxy,
|
||||
RequestTimeout: p.Avian.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Process each provider migration
|
||||
|
||||
@@ -159,16 +159,17 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
VolcEngine: ProviderConfig{APIKey: "key15"},
|
||||
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
|
||||
Antigravity: ProviderConfig{AuthMethod: "oauth"},
|
||||
Qwen: ProviderConfig{APIKey: "key18"},
|
||||
Mistral: ProviderConfig{APIKey: "key19"},
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
Mistral: ProviderConfig{APIKey: "key18"},
|
||||
Avian: ProviderConfig{APIKey: "key19"},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
// All 19 providers should be converted
|
||||
if len(result) != 19 {
|
||||
t.Errorf("len(result) = %d, want 19", len(result))
|
||||
// All 21 providers should be converted
|
||||
if len(result) != 21 {
|
||||
t.Errorf("len(result) = %d, want 21", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+13
-3
@@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
mcpCfg := config.MCPConfig{
|
||||
Enabled: true,
|
||||
ToolConfig: config.ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Servers: map[string]config.MCPServerConfig{
|
||||
"test-server": {
|
||||
Enabled: true,
|
||||
@@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) {
|
||||
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
|
||||
err := mgr.LoadFromMCPConfig(
|
||||
context.Background(),
|
||||
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}},
|
||||
"/tmp",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
|
||||
}
|
||||
|
||||
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
|
||||
err = mgr.LoadFromMCPConfig(
|
||||
context.Background(),
|
||||
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}},
|
||||
"/tmp",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when no servers configured, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -31,6 +31,9 @@ type Provider struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// SupportsThinking implements providers.ThinkingCapable.
|
||||
func (p *Provider) SupportsThinking() bool { return true }
|
||||
|
||||
func NewProvider(token string) *Provider {
|
||||
return NewProviderWithBaseURL(token, "")
|
||||
}
|
||||
@@ -182,9 +185,80 @@ func buildParams(
|
||||
params.Tools = translateTools(tools)
|
||||
}
|
||||
|
||||
// Extended Thinking / Adaptive Thinking
|
||||
// The thinking_level value directly determines the API parameter format:
|
||||
// "adaptive" → {thinking: {type: "adaptive"}} + output_config.effort
|
||||
// "low/medium/high/xhigh" → {thinking: {type: "enabled", budget_tokens: N}}
|
||||
if level, ok := options["thinking_level"].(string); ok && level != "" && level != "off" {
|
||||
applyThinkingConfig(¶ms, level)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// applyThinkingConfig sets thinking parameters based on the level value.
|
||||
// "adaptive" uses the adaptive thinking API (Claude 4.6+).
|
||||
// All other levels use budget_tokens which is universally supported.
|
||||
//
|
||||
// Anthropic API constraint: temperature must not be set when thinking is enabled.
|
||||
// budget_tokens must be strictly less than max_tokens.
|
||||
func applyThinkingConfig(params *anthropic.MessageNewParams, level string) {
|
||||
// Anthropic API rejects requests with temperature set alongside thinking.
|
||||
// Reset to zero value (omitted from JSON serialization).
|
||||
if params.Temperature.Valid() {
|
||||
log.Printf("anthropic: temperature cleared because thinking is enabled (level=%s)", level)
|
||||
}
|
||||
params.Temperature = anthropic.MessageNewParams{}.Temperature
|
||||
|
||||
if level == "adaptive" {
|
||||
adaptive := anthropic.NewThinkingConfigAdaptiveParam()
|
||||
params.Thinking = anthropic.ThinkingConfigParamUnion{OfAdaptive: &adaptive}
|
||||
params.OutputConfig = anthropic.OutputConfigParam{
|
||||
Effort: anthropic.OutputConfigEffortHigh,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
budget := int64(levelToBudget(level))
|
||||
if budget <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// budget_tokens must be < max_tokens; clamp to respect user's max_tokens setting.
|
||||
if budget >= params.MaxTokens {
|
||||
log.Printf("anthropic: budget_tokens (%d) clamped to %d (max_tokens-1)", budget, params.MaxTokens-1)
|
||||
budget = params.MaxTokens - 1
|
||||
} else if budget > params.MaxTokens*80/100 {
|
||||
log.Printf("anthropic: thinking budget (%d) exceeds 80%% of max_tokens (%d), output may be truncated",
|
||||
budget, params.MaxTokens)
|
||||
}
|
||||
params.Thinking = anthropic.ThinkingConfigParamOfEnabled(budget)
|
||||
}
|
||||
|
||||
// levelToBudget maps a thinking level to budget_tokens.
|
||||
// Values are based on Anthropic's recommendations and community best practices:
|
||||
//
|
||||
// low = 4,096 — simple reasoning, quick debugging (Claude Code "think")
|
||||
// medium = 16,384 — Anthropic recommended sweet spot for most tasks
|
||||
// high = 32,000 — complex architecture, deep analysis (diminishing returns above this)
|
||||
// xhigh = 64,000 — extreme reasoning, research problems, benchmarks
|
||||
//
|
||||
// Note: For Claude 4.6+, prefer adaptive thinking over manual budget_tokens.
|
||||
func levelToBudget(level string) int {
|
||||
switch level {
|
||||
case "low":
|
||||
return 4096
|
||||
case "medium":
|
||||
return 16384
|
||||
case "high":
|
||||
return 32000
|
||||
case "xhigh":
|
||||
return 64000
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
@@ -213,10 +287,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
|
||||
func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content strings.Builder
|
||||
var reasoning strings.Builder
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "thinking":
|
||||
tb := block.AsThinking()
|
||||
reasoning.WriteString(tb.Thinking)
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content.WriteString(tb.Text)
|
||||
@@ -247,6 +325,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content.String(),
|
||||
Reasoning: reasoning.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
package anthropicprovider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
)
|
||||
|
||||
func TestApplyThinkingConfig_Adaptive(t *testing.T) {
|
||||
params := anthropic.MessageNewParams{
|
||||
MaxTokens: 16000,
|
||||
Temperature: anthropic.Float(0.7),
|
||||
}
|
||||
applyThinkingConfig(¶ms, "adaptive")
|
||||
|
||||
if params.Thinking.OfAdaptive == nil {
|
||||
t.Fatal("expected adaptive thinking")
|
||||
}
|
||||
if params.Thinking.OfEnabled != nil {
|
||||
t.Error("should not set enabled thinking in adaptive mode")
|
||||
}
|
||||
if params.OutputConfig.Effort != anthropic.OutputConfigEffortHigh {
|
||||
t.Errorf("effort = %q, want %q", params.OutputConfig.Effort, anthropic.OutputConfigEffortHigh)
|
||||
}
|
||||
if params.Temperature.Valid() {
|
||||
t.Error("temperature should be cleared when thinking is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyThinkingConfig_BudgetLevels(t *testing.T) {
|
||||
tests := []struct {
|
||||
level string
|
||||
wantBudget int64
|
||||
}{
|
||||
{"low", 4096},
|
||||
{"medium", 16384},
|
||||
{"high", 32000},
|
||||
{"xhigh", 64000},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.level, func(t *testing.T) {
|
||||
params := anthropic.MessageNewParams{
|
||||
MaxTokens: 200000,
|
||||
Temperature: anthropic.Float(0.5),
|
||||
}
|
||||
applyThinkingConfig(¶ms, tt.level)
|
||||
|
||||
if params.Thinking.OfEnabled == nil {
|
||||
t.Fatal("expected enabled thinking")
|
||||
}
|
||||
if params.Thinking.OfAdaptive != nil {
|
||||
t.Error("should not set adaptive thinking")
|
||||
}
|
||||
if params.Thinking.OfEnabled.BudgetTokens != tt.wantBudget {
|
||||
t.Errorf("budget_tokens = %d, want %d", params.Thinking.OfEnabled.BudgetTokens, tt.wantBudget)
|
||||
}
|
||||
if params.OutputConfig.Effort != "" {
|
||||
t.Errorf("effort = %q, want empty", params.OutputConfig.Effort)
|
||||
}
|
||||
if params.Temperature.Valid() {
|
||||
t.Error("temperature should be cleared when thinking is enabled")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyThinkingConfig_BudgetClamp(t *testing.T) {
|
||||
// budget_tokens must be < max_tokens; clamp budget down to respect user's max_tokens.
|
||||
params := anthropic.MessageNewParams{MaxTokens: 4096}
|
||||
applyThinkingConfig(¶ms, "high") // budget=32000 > maxTokens=4096
|
||||
|
||||
if params.Thinking.OfEnabled == nil {
|
||||
t.Fatal("expected enabled thinking")
|
||||
}
|
||||
if params.Thinking.OfEnabled.BudgetTokens != 4095 {
|
||||
t.Errorf("budget_tokens = %d, want 4095 (maxTokens-1)", params.Thinking.OfEnabled.BudgetTokens)
|
||||
}
|
||||
if params.MaxTokens != 4096 {
|
||||
t.Errorf("max_tokens should not be modified, got %d", params.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyThinkingConfig_UnknownLevel(t *testing.T) {
|
||||
params := anthropic.MessageNewParams{MaxTokens: 16000}
|
||||
applyThinkingConfig(¶ms, "unknown")
|
||||
|
||||
if params.Thinking.OfEnabled != nil {
|
||||
t.Error("should not set enabled thinking for unknown level")
|
||||
}
|
||||
if params.Thinking.OfAdaptive != nil {
|
||||
t.Error("should not set adaptive thinking for unknown level")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLevelToBudget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level string
|
||||
want int
|
||||
}{
|
||||
{"low", "low", 4096},
|
||||
{"medium", "medium", 16384},
|
||||
{"high", "high", 32000},
|
||||
{"xhigh", "xhigh", 64000},
|
||||
{"off", "off", 0},
|
||||
{"empty", "", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := levelToBudget(tt.level); got != tt.want {
|
||||
t.Errorf("levelToBudget(%q) = %d, want %d", tt.level, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_ThinkingClearsTemperature(t *testing.T) {
|
||||
msgs := []Message{{Role: "user", Content: "hello"}}
|
||||
opts := map[string]any{
|
||||
"max_tokens": 200000,
|
||||
"temperature": 0.8,
|
||||
"thinking_level": "medium",
|
||||
}
|
||||
|
||||
params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if params.Temperature.Valid() {
|
||||
t.Error("temperature should be cleared when thinking_level is set")
|
||||
}
|
||||
if params.Thinking.OfEnabled == nil {
|
||||
t.Fatal("expected enabled thinking")
|
||||
}
|
||||
if params.Thinking.OfEnabled.BudgetTokens != 16384 {
|
||||
t.Errorf("budget_tokens = %d, want 16384", params.Thinking.OfEnabled.BudgetTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// unmarshalBlocks constructs []ContentBlockUnion via JSON round-trip so that
|
||||
// the internal JSON.raw field is populated (required by AsText/AsThinking).
|
||||
func unmarshalBlocks(t *testing.T, jsonStr string) []anthropic.ContentBlockUnion {
|
||||
t.Helper()
|
||||
var blocks []anthropic.ContentBlockUnion
|
||||
if err := json.Unmarshal([]byte(jsonStr), &blocks); err != nil {
|
||||
t.Fatalf("unmarshalBlocks: %v", err)
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func TestParseResponse_ThinkingBlock(t *testing.T) {
|
||||
resp := &anthropic.Message{
|
||||
Content: unmarshalBlocks(t, `[
|
||||
{"type":"thinking","thinking":"Let me reason step by step...","signature":"sig"},
|
||||
{"type":"text","text":"The answer is 42."}
|
||||
]`),
|
||||
StopReason: anthropic.StopReasonEndTurn,
|
||||
}
|
||||
|
||||
result := parseResponse(resp)
|
||||
|
||||
if result.Reasoning != "Let me reason step by step..." {
|
||||
t.Errorf("Reasoning = %q, want thinking content", result.Reasoning)
|
||||
}
|
||||
if result.Content != "The answer is 42." {
|
||||
t.Errorf("Content = %q, want text content", result.Content)
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", result.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_NoThinkingBlock(t *testing.T) {
|
||||
resp := &anthropic.Message{
|
||||
Content: unmarshalBlocks(t, `[
|
||||
{"type":"text","text":"Just a normal response."}
|
||||
]`),
|
||||
StopReason: anthropic.StopReasonEndTurn,
|
||||
}
|
||||
|
||||
result := parseResponse(resp)
|
||||
|
||||
if result.Reasoning != "" {
|
||||
t.Errorf("Reasoning = %q, want empty", result.Reasoning)
|
||||
}
|
||||
if result.Content != "Just a normal response." {
|
||||
t.Errorf("Content = %q, want text content", result.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_NoThinkingKeepsTemperature(t *testing.T) {
|
||||
msgs := []Message{{Role: "user", Content: "hello"}}
|
||||
opts := map[string]any{
|
||||
"temperature": 0.8,
|
||||
}
|
||||
|
||||
params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !params.Temperature.Valid() {
|
||||
t.Error("temperature should be preserved when thinking is not set")
|
||||
}
|
||||
if params.Temperature.Value != 0.8 {
|
||||
t.Errorf("temperature = %f, want 0.8", params.Temperature.Value)
|
||||
}
|
||||
}
|
||||
@@ -640,7 +640,10 @@ func FetchAntigravityProjectID(accessToken string) (string, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading loadCodeAssist response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("loadCodeAssist failed: %s", string(body))
|
||||
}
|
||||
@@ -681,7 +684,10 @@ func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelIn
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading fetchAvailableModels response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf(
|
||||
"fetchAvailableModels failed (HTTP %d): %s",
|
||||
|
||||
@@ -190,6 +190,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "avian":
|
||||
if cfg.Providers.Avian.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Avian.APIKey
|
||||
sel.apiBase = cfg.Providers.Avian.APIBase
|
||||
sel.proxy = cfg.Providers.Avian.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.avian.io/v1"
|
||||
}
|
||||
}
|
||||
case "mistral":
|
||||
if cfg.Providers.Mistral.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Mistral.APIKey
|
||||
@@ -316,6 +325,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.mistral.ai/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Avian.APIKey
|
||||
sel.apiBase = cfg.Providers.Avian.APIBase
|
||||
sel.proxy = cfg.Providers.Avian.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.avian.io/v1"
|
||||
}
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
|
||||
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral":
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -210,6 +210,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "http://localhost:8000/v1"
|
||||
case "mistral":
|
||||
return "https://api.mistral.ai/v1"
|
||||
case "avian":
|
||||
return "https://api.avian.io/v1"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -323,12 +323,14 @@ func serializeMessages(messages []Message) []any {
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
|
||||
@@ -37,6 +37,13 @@ type StatefulProvider interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
// ThinkingCapable is an optional interface for providers that support
|
||||
// extended thinking (e.g. Anthropic). Used by the agent loop to warn
|
||||
// when thinking_level is configured but the active provider cannot use it.
|
||||
type ThinkingCapable interface {
|
||||
SupportsThinking() bool
|
||||
}
|
||||
|
||||
// FailoverReason classifies why an LLM request failed for fallback decisions.
|
||||
type FailoverReason string
|
||||
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package routing
|
||||
|
||||
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
|
||||
// A higher score indicates a more complex task that benefits from a heavy model.
|
||||
// The score is compared against the configured threshold: score >= threshold selects
|
||||
// the primary (heavy) model; score < threshold selects the light model.
|
||||
//
|
||||
// Classifier is an interface so that future implementations (ML-based, embedding-based,
|
||||
// or any other approach) can be swapped in without changing routing infrastructure.
|
||||
type Classifier interface {
|
||||
Score(f Features) float64
|
||||
}
|
||||
|
||||
// RuleClassifier is the v1 implementation.
|
||||
// It uses a weighted sum of structural signals with no external dependencies,
|
||||
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
|
||||
// that the returned score always falls within the [0, 1] contract.
|
||||
//
|
||||
// Individual weights (multiple signals can fire simultaneously):
|
||||
//
|
||||
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
|
||||
// token 50-200: 0.15 — medium length; may or may not be complex
|
||||
// code block present: 0.40 — coding tasks need the heavy model
|
||||
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
|
||||
// tool calls 1-3 (recent): 0.10 — some tool activity
|
||||
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
|
||||
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
|
||||
//
|
||||
// Default threshold is 0.35, so:
|
||||
// - Pure greetings / trivial Q&A: 0.00 → light ✓
|
||||
// - Medium prose message (50–200 tokens): 0.15 → light ✓
|
||||
// - Message with code block: 0.40 → heavy ✓
|
||||
// - Long message (>200 tokens): 0.35 → heavy ✓
|
||||
// - Active tool session + medium message: 0.25 → light (acceptable)
|
||||
// - Any message with an image/audio attachment: 1.00 → heavy ✓
|
||||
type RuleClassifier struct{}
|
||||
|
||||
// Score computes the complexity score for the given feature set.
|
||||
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
|
||||
func (c *RuleClassifier) Score(f Features) float64 {
|
||||
// Hard gate: multi-modal inputs always require the heavy model.
|
||||
if f.HasAttachments {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
var score float64
|
||||
|
||||
// Token estimate — primary verbosity signal
|
||||
switch {
|
||||
case f.TokenEstimate > 200:
|
||||
score += 0.35
|
||||
case f.TokenEstimate > 50:
|
||||
score += 0.15
|
||||
}
|
||||
|
||||
// Fenced code blocks — strongest indicator of a coding/technical task
|
||||
if f.CodeBlockCount > 0 {
|
||||
score += 0.40
|
||||
}
|
||||
|
||||
// Recent tool call density — indicates an ongoing agentic workflow
|
||||
switch {
|
||||
case f.RecentToolCalls > 3:
|
||||
score += 0.25
|
||||
case f.RecentToolCalls > 0:
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Conversation depth — accumulated context implies compound task
|
||||
if f.ConversationDepth > 10 {
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
|
||||
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
|
||||
if score > 1.0 {
|
||||
score = 1.0
|
||||
}
|
||||
return score
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// lookbackWindow is the number of recent history entries scanned for tool calls.
|
||||
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
|
||||
const lookbackWindow = 6
|
||||
|
||||
// Features holds the structural signals extracted from a message and its session context.
|
||||
// Every dimension is language-agnostic by construction — no keyword or pattern matching
|
||||
// against natural-language content. This ensures consistent routing for all locales.
|
||||
type Features struct {
|
||||
// TokenEstimate is a proxy for token count.
|
||||
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
|
||||
// This avoids API calls while giving accurate estimates for all scripts.
|
||||
TokenEstimate int
|
||||
|
||||
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
|
||||
// Coding tasks almost always require the heavy model.
|
||||
CodeBlockCount int
|
||||
|
||||
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
|
||||
// history entries. A high density indicates an active agentic workflow.
|
||||
RecentToolCalls int
|
||||
|
||||
// ConversationDepth is the total number of messages in the session history.
|
||||
// Deep sessions tend to carry implicit complexity built up over many turns.
|
||||
ConversationDepth int
|
||||
|
||||
// HasAttachments is true when the message appears to contain media (images,
|
||||
// audio, video). Multi-modal inputs require vision-capable heavy models.
|
||||
HasAttachments bool
|
||||
}
|
||||
|
||||
// ExtractFeatures computes the structural feature vector for a message.
|
||||
// It is a pure function with no side effects and zero allocations beyond
|
||||
// the returned struct.
|
||||
func ExtractFeatures(msg string, history []providers.Message) Features {
|
||||
return Features{
|
||||
TokenEstimate: estimateTokens(msg),
|
||||
CodeBlockCount: countCodeBlocks(msg),
|
||||
RecentToolCalls: countRecentToolCalls(history),
|
||||
ConversationDepth: len(history),
|
||||
HasAttachments: hasAttachments(msg),
|
||||
}
|
||||
}
|
||||
|
||||
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
|
||||
// CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) map to roughly one
|
||||
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
|
||||
// for English). Splitting the count this way avoids the 3x underestimation that a
|
||||
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
|
||||
func estimateTokens(msg string) int {
|
||||
total := utf8.RuneCountInString(msg)
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
cjk := 0
|
||||
for _, r := range msg {
|
||||
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
|
||||
cjk++
|
||||
}
|
||||
}
|
||||
return cjk + (total-cjk)/4
|
||||
}
|
||||
|
||||
// countCodeBlocks counts the number of complete fenced code blocks.
|
||||
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
|
||||
// An unclosed opening fence (odd count) is treated as zero complete blocks
|
||||
// since it may just be an inline code span or a typo.
|
||||
func countCodeBlocks(msg string) int {
|
||||
n := strings.Count(msg, "```")
|
||||
return n / 2
|
||||
}
|
||||
|
||||
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
|
||||
// entries of history. It examines the ToolCalls field rather than parsing
|
||||
// the content string, so it is robust to any message format.
|
||||
func countRecentToolCalls(history []providers.Message) int {
|
||||
start := len(history) - lookbackWindow
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, msg := range history[start:] {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
count += len(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// hasAttachments returns true when the message content contains embedded media.
|
||||
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
|
||||
// common image/audio URL extensions. This is intentionally conservative —
|
||||
// false negatives (missing an attachment) just mean the routing falls back to
|
||||
// the primary model anyway.
|
||||
func hasAttachments(msg string) bool {
|
||||
lower := strings.ToLower(msg)
|
||||
|
||||
// Base64 data URIs embedded directly in the message
|
||||
if strings.Contains(lower, "data:image/") ||
|
||||
strings.Contains(lower, "data:audio/") ||
|
||||
strings.Contains(lower, "data:video/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Common image/audio extensions in URLs or file references
|
||||
mediaExts := []string{
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
|
||||
".mp3", ".wav", ".ogg", ".m4a", ".flac",
|
||||
".mp4", ".avi", ".mov", ".webm",
|
||||
}
|
||||
for _, ext := range mediaExts {
|
||||
if strings.Contains(lower, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// defaultThreshold is used when the config threshold is zero or negative.
|
||||
// At 0.35 a message needs at least one strong signal (code block, long text,
|
||||
// or an attachment) before the heavy model is chosen.
|
||||
const defaultThreshold = 0.35
|
||||
|
||||
// RouterConfig holds the validated model routing settings.
|
||||
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
|
||||
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
|
||||
type RouterConfig struct {
|
||||
// LightModel is the model_name (from model_list) used for simple tasks.
|
||||
LightModel string
|
||||
|
||||
// Threshold is the complexity score cutoff in [0, 1].
|
||||
// score >= Threshold → primary (heavy) model.
|
||||
// score < Threshold → light model.
|
||||
Threshold float64
|
||||
}
|
||||
|
||||
// Router selects the appropriate model tier for each incoming message.
|
||||
// It is safe for concurrent use from multiple goroutines.
|
||||
type Router struct {
|
||||
cfg RouterConfig
|
||||
classifier Classifier
|
||||
}
|
||||
|
||||
// New creates a Router with the given config and the default RuleClassifier.
|
||||
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
|
||||
func New(cfg RouterConfig) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{
|
||||
cfg: cfg,
|
||||
classifier: &RuleClassifier{},
|
||||
}
|
||||
}
|
||||
|
||||
// newWithClassifier creates a Router with a custom Classifier.
|
||||
// Intended for unit tests that need to inject a deterministic scorer.
|
||||
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{cfg: cfg, classifier: c}
|
||||
}
|
||||
|
||||
// SelectModel returns the model to use for this conversation turn along with
|
||||
// the computed complexity score (for logging and debugging).
|
||||
//
|
||||
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
|
||||
// - Otherwise: returns (primaryModel, false, score)
|
||||
//
|
||||
// The caller is responsible for resolving the returned model name into
|
||||
// provider candidates (see AgentInstance.LightCandidates).
|
||||
func (r *Router) SelectModel(
|
||||
msg string,
|
||||
history []providers.Message,
|
||||
primaryModel string,
|
||||
) (model string, usedLight bool, score float64) {
|
||||
features := ExtractFeatures(msg, history)
|
||||
score = r.classifier.Score(features)
|
||||
if score < r.cfg.Threshold {
|
||||
return r.cfg.LightModel, true, score
|
||||
}
|
||||
return primaryModel, false, score
|
||||
}
|
||||
|
||||
// LightModel returns the configured light model name.
|
||||
func (r *Router) LightModel() string {
|
||||
return r.cfg.LightModel
|
||||
}
|
||||
|
||||
// Threshold returns the complexity threshold in use.
|
||||
func (r *Router) Threshold() float64 {
|
||||
return r.cfg.Threshold
|
||||
}
|
||||
@@ -0,0 +1,414 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ── ExtractFeatures ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractFeatures_EmptyMessage(t *testing.T) {
|
||||
f := ExtractFeatures("", nil)
|
||||
if f.TokenEstimate != 0 {
|
||||
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
|
||||
}
|
||||
if f.CodeBlockCount != 0 {
|
||||
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
|
||||
}
|
||||
if f.RecentToolCalls != 0 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
|
||||
}
|
||||
if f.ConversationDepth != 0 {
|
||||
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
|
||||
}
|
||||
if f.HasAttachments {
|
||||
t.Error("HasAttachments: got true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate(t *testing.T) {
|
||||
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
|
||||
msg := strings.Repeat("a", 30)
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 7 {
|
||||
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
|
||||
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
|
||||
// Using a rune slice literal avoids CJK string literals in source.
|
||||
msg := string([]rune{
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60,
|
||||
})
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 9 {
|
||||
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
|
||||
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
|
||||
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 6 {
|
||||
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_CodeBlocks(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"no code here", 0},
|
||||
{"```go\nfmt.Println()\n```", 1},
|
||||
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
|
||||
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.CodeBlockCount != tc.want {
|
||||
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
|
||||
// History longer than lookbackWindow — only last lookbackWindow entries count.
|
||||
history := make([]providers.Message, 10)
|
||||
// Put 2 tool calls at positions 8 and 9 (within the last 6)
|
||||
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
|
||||
history[9] = providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
|
||||
}
|
||||
// Position 3 is outside the lookback window and must NOT be counted
|
||||
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
|
||||
|
||||
f := ExtractFeatures("test", history)
|
||||
// 1 (position 8) + 2 (position 9) = 3
|
||||
if f.RecentToolCalls != 3 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_ConversationDepth(t *testing.T) {
|
||||
history := make([]providers.Message, 7)
|
||||
f := ExtractFeatures("msg", history)
|
||||
if f.ConversationDepth != 7 {
|
||||
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"plain text", false},
|
||||
{"here is an image: data:image/png;base64,abc123", true},
|
||||
{"audio: data:audio/mp3;base64,xyz", true},
|
||||
{"video: data:video/mp4;base64,xyz", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"check out photo.jpg", true},
|
||||
{"see screenshot.png", true},
|
||||
{"listen to audio.mp3", true},
|
||||
{"watch clip.mp4", true},
|
||||
{"just a .go file", false},
|
||||
{"document.pdf", false}, // pdf is not in the media list
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── RuleClassifier ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{})
|
||||
if score != 0.0 {
|
||||
t.Errorf("zero features: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{HasAttachments: true})
|
||||
if score != 1.0 {
|
||||
t.Errorf("attachments: got %f, want 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Code block alone = 0.40, above default threshold 0.35
|
||||
score := c.Score(Features{CodeBlockCount: 1})
|
||||
if score < 0.35 {
|
||||
t.Errorf("code block: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_LongMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// >200 tokens = 0.35, exactly at default threshold → heavy
|
||||
score := c.Score(Features{TokenEstimate: 250})
|
||||
if score < 0.35 {
|
||||
t.Errorf("long message: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_MediumMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// 50-200 tokens = 0.15, below threshold → light
|
||||
score := c.Score(Features{TokenEstimate: 100})
|
||||
if score >= 0.35 {
|
||||
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ShortMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// <50 tokens, no other signals = 0.0 → light
|
||||
score := c.Score(Features{TokenEstimate: 10})
|
||||
if score != 0.0 {
|
||||
t.Errorf("short message: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
|
||||
scoreNone := c.Score(Features{RecentToolCalls: 0})
|
||||
scoreLow := c.Score(Features{RecentToolCalls: 2})
|
||||
scoreHigh := c.Score(Features{RecentToolCalls: 5})
|
||||
|
||||
if scoreNone != 0.0 {
|
||||
t.Errorf("no tools: got %f, want 0.0", scoreNone)
|
||||
}
|
||||
if scoreLow <= scoreNone {
|
||||
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
|
||||
}
|
||||
if scoreHigh <= scoreLow {
|
||||
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_DeepConversation(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
shallow := c.Score(Features{ConversationDepth: 5})
|
||||
deep := c.Score(Features{ConversationDepth: 15})
|
||||
if deep <= shallow {
|
||||
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Max all signals simultaneously
|
||||
f := Features{
|
||||
TokenEstimate: 500,
|
||||
CodeBlockCount: 3,
|
||||
RecentToolCalls: 10,
|
||||
ConversationDepth: 20,
|
||||
}
|
||||
score := c.Score(f)
|
||||
if score > 1.0 {
|
||||
t.Errorf("score %f exceeds 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Router ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestRouter_DefaultThreshold(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash"})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "hi"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("simple message: expected light model to be selected")
|
||||
}
|
||||
if model != "gemini-flash" {
|
||||
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "```go\nfmt.Println(\"hello\")\n```"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("code block: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "can you analyze this? data:image/png;base64,abc123"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("attachment: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
// >200 token estimate: 210 * 3 = 630 chars
|
||||
msg := strings.Repeat("word ", 210)
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("long message: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
|
||||
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
|
||||
// Routing is conservative: only promote to heavy when the signal is unambiguous.
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
|
||||
}
|
||||
msg := "ok"
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
|
||||
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{
|
||||
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
|
||||
}},
|
||||
}
|
||||
// ~55 tokens * 3 = 165 chars
|
||||
msg := strings.Repeat("word ", 55)
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
|
||||
// Very low threshold: even a short message triggers heavy model
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
|
||||
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("low threshold: medium message should use primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
|
||||
// Very high threshold: even code blocks route to light
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
|
||||
msg := "```go\nfmt.Println()\n```"
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("very high threshold: code block (0.40) should route to light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_LightModel(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
|
||||
if r.LightModel() != "my-fast-model" {
|
||||
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
|
||||
}
|
||||
}
|
||||
|
||||
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
|
||||
|
||||
type fixedScoreClassifier struct{ score float64 }
|
||||
|
||||
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
|
||||
|
||||
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.2},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if !usedLight {
|
||||
t.Error("low score with custom classifier: expected light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.8},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("high score with custom classifier: expected primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
|
||||
// score == threshold → primary (uses >= comparison)
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.5},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("score == threshold: expected primary model (>= threshold → primary)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.42},
|
||||
)
|
||||
_, _, score := r.SelectModel("anything", nil, "heavy")
|
||||
if score != 0.42 {
|
||||
t.Errorf("score: got %f, want 0.42", score)
|
||||
}
|
||||
}
|
||||
@@ -259,15 +259,7 @@ func (c *ClawHubRegistry) DownloadAndInstall(
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
|
||||
tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
|
||||
tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
@@ -284,17 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall(
|
||||
// --- HTTP helper ---
|
||||
|
||||
func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
req, err := c.newGetRequest(ctx, urlStr, "application/json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
resp, err := utils.DoRequestWithRetry(c.client, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -312,3 +299,64 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", accept)
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) {
|
||||
req, err := c.newGetRequest(ctx, urlStr, "application/zip")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := utils.DoRequestWithRetry(c.client, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
errBody := make([]byte, 512)
|
||||
n, _ := io.ReadFull(resp.Body, errBody)
|
||||
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
cleanup := func() {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
|
||||
src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1)
|
||||
written, err := io.Copy(tmpFile, src)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", fmt.Errorf("download write failed: %w", err)
|
||||
}
|
||||
|
||||
if written > int64(c.maxZipSize) {
|
||||
cleanup()
|
||||
return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
@@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) {
|
||||
assert.Equal(t, "clawhub", results[0].RegistryName)
|
||||
}
|
||||
|
||||
func TestClawHubRegistrySearchRetries429(t *testing.T) {
|
||||
attempts := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.Header().Set("Retry-After", "0")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte("rate limited"))
|
||||
return
|
||||
}
|
||||
|
||||
slug := "github"
|
||||
name := "GitHub Integration"
|
||||
summary := "Interact with GitHub repos"
|
||||
version := "1.0.0"
|
||||
|
||||
json.NewEncoder(w).Encode(clawhubSearchResponse{
|
||||
Results: []clawhubSearchResult{
|
||||
{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
results, err := reg.Search(context.Background(), "github", 5)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
assert.Equal(t, 2, attempts)
|
||||
assert.Equal(t, "github", results[0].Slug)
|
||||
}
|
||||
|
||||
func TestClawHubRegistryGetSkillMeta(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
|
||||
@@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
|
||||
assert.Contains(t, string(readmeContent), "# Test Skill")
|
||||
}
|
||||
|
||||
func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) {
|
||||
zipBuf := createTestZip(t, map[string]string{
|
||||
"SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill",
|
||||
})
|
||||
|
||||
downloadAttempts := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/skills/retry-skill":
|
||||
json.NewEncoder(w).Encode(clawhubSkillResponse{
|
||||
Slug: "retry-skill",
|
||||
DisplayName: "Retry Skill",
|
||||
Summary: "A retry test skill",
|
||||
LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
|
||||
})
|
||||
case "/api/v1/download":
|
||||
downloadAttempts++
|
||||
if downloadAttempts == 1 {
|
||||
w.Header().Set("Retry-After", "0")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte("rate limited"))
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "retry-skill", r.URL.Query().Get("slug"))
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
w.Write(zipBuf)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
targetDir := filepath.Join(tmpDir, "retry-skill")
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "1.0.0", result.Version)
|
||||
assert.Equal(t, 2, downloadAttempts)
|
||||
|
||||
skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(skillContent), "Hello skill")
|
||||
}
|
||||
|
||||
func TestClawHubRegistryAuthToken(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
|
||||
+50
-38
@@ -10,11 +10,38 @@ type Tool interface {
|
||||
Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
}
|
||||
|
||||
// ContextualTool is an optional interface that tools can implement
|
||||
// to receive the current message context (channel, chatID)
|
||||
type ContextualTool interface {
|
||||
Tool
|
||||
SetContext(channel, chatID string)
|
||||
// --- Request-scoped tool context (channel / chatID) ---
|
||||
//
|
||||
// Carried via context.Value so that concurrent tool calls each receive
|
||||
// their own immutable copy — no mutable state on singleton tool instances.
|
||||
//
|
||||
// Keys are unexported pointer-typed vars — guaranteed collision-free,
|
||||
// and only accessible through the helper functions below.
|
||||
|
||||
type toolCtxKey struct{ name string }
|
||||
|
||||
var (
|
||||
ctxKeyChannel = &toolCtxKey{"channel"}
|
||||
ctxKeyChatID = &toolCtxKey{"chatID"}
|
||||
)
|
||||
|
||||
// WithToolContext returns a child context carrying channel and chatID.
|
||||
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
|
||||
ctx = context.WithValue(ctx, ctxKeyChannel, channel)
|
||||
ctx = context.WithValue(ctx, ctxKeyChatID, chatID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ToolChannel extracts the channel from ctx, or "" if unset.
|
||||
func ToolChannel(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyChannel).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// ToolChatID extracts the chatID from ctx, or "" if unset.
|
||||
func ToolChatID(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ctxKeyChatID).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
@@ -22,51 +49,36 @@ type ContextualTool interface {
|
||||
//
|
||||
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
|
||||
// The result parameter contains the tool's execution result.
|
||||
//
|
||||
// Example usage in an async tool:
|
||||
//
|
||||
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// // Start async work in background
|
||||
// go func() {
|
||||
// result := doAsyncWork()
|
||||
// if t.callback != nil {
|
||||
// t.callback(ctx, result)
|
||||
// }
|
||||
// }()
|
||||
// return AsyncResult("Async task started")
|
||||
// }
|
||||
type AsyncCallback func(ctx context.Context, result *ToolResult)
|
||||
|
||||
// AsyncTool is an optional interface that tools can implement to support
|
||||
// AsyncExecutor is an optional interface that tools can implement to support
|
||||
// asynchronous execution with completion callbacks.
|
||||
//
|
||||
// Async tools return immediately with an AsyncResult, then notify completion
|
||||
// via the callback set by SetCallback.
|
||||
// Unlike the old AsyncTool pattern (SetCallback + Execute), AsyncExecutor
|
||||
// receives the callback as a parameter of ExecuteAsync. This eliminates the
|
||||
// data race where concurrent calls could overwrite each other's callbacks
|
||||
// on a shared tool instance.
|
||||
//
|
||||
// This is useful for:
|
||||
// - Long-running operations that shouldn't block the agent loop
|
||||
// - Subagent spawns that complete independently
|
||||
// - Background tasks that need to report results later
|
||||
// - Long-running operations that shouldn't block the agent loop
|
||||
// - Subagent spawns that complete independently
|
||||
// - Background tasks that need to report results later
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type SpawnTool struct {
|
||||
// callback AsyncCallback
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
// t.callback = cb
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// go t.runSubagent(ctx, args)
|
||||
// func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
// go func() {
|
||||
// result := t.runSubagent(ctx, args)
|
||||
// if cb != nil { cb(ctx, result) }
|
||||
// }()
|
||||
// return AsyncResult("Subagent spawned, will report back")
|
||||
// }
|
||||
type AsyncTool interface {
|
||||
type AsyncExecutor interface {
|
||||
Tool
|
||||
// SetCallback registers a callback function to be invoked when the async operation completes.
|
||||
// The callback will be called from a goroutine and should handle thread-safety if needed.
|
||||
SetCallback(cb AsyncCallback)
|
||||
// ExecuteAsync runs the tool asynchronously. The callback cb will be
|
||||
// invoked (possibly from another goroutine) when the async operation
|
||||
// completes. cb is guaranteed to be non-nil by the caller (registry).
|
||||
ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult
|
||||
}
|
||||
|
||||
func ToolToSchema(tool Tool) map[string]any {
|
||||
|
||||
+4
-18
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -24,9 +23,6 @@ type CronTool struct {
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
channel string
|
||||
chatID string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
@@ -102,14 +98,6 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// SetContext sets the current session context for job creation
|
||||
func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.channel = channel
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with the given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
@@ -119,7 +107,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
|
||||
switch action {
|
||||
case "add":
|
||||
return t.addJob(args)
|
||||
return t.addJob(ctx, args)
|
||||
case "list":
|
||||
return t.listJobs()
|
||||
case "remove":
|
||||
@@ -133,11 +121,9 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]any) *ToolResult {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult {
|
||||
channel := ToolChannel(ctx)
|
||||
chatID := ToolChatID(ctx)
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
|
||||
+8
-10
@@ -9,10 +9,8 @@ import (
|
||||
type SendCallback func(channel, chatID, content string) error
|
||||
|
||||
type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
sendCallback SendCallback
|
||||
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -48,10 +46,10 @@ func (t *MessageTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
t.sentInRound.Store(false) // Reset send tracking for new processing round
|
||||
// ResetSentInRound resets the per-round send tracker.
|
||||
// Called by the agent loop at the start of each inbound message processing round.
|
||||
func (t *MessageTool) ResetSentInRound() {
|
||||
t.sentInRound.Store(false)
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
@@ -73,10 +71,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
|
||||
chatID, _ := args["chat_id"].(string)
|
||||
|
||||
if channel == "" {
|
||||
channel = t.defaultChannel
|
||||
channel = ToolChannel(ctx)
|
||||
}
|
||||
if chatID == "" {
|
||||
chatID = t.defaultChatID
|
||||
chatID = ToolChatID(ctx)
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
@@ -18,7 +17,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
@@ -60,7 +59,6 @@ func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("default-channel", "default-chat-id")
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
@@ -69,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
"channel": "custom-channel",
|
||||
@@ -96,14 +94,13 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
@@ -133,9 +130,8 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{} // content missing
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
@@ -151,7 +147,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No SetContext called, so defaultChannel and defaultChatID are empty
|
||||
// No WithToolContext — channel/chatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return nil
|
||||
@@ -175,10 +171,9 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
|
||||
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
// No SetSendCallback called
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
|
||||
args := map[string]any{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
+15
-13
@@ -45,8 +45,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
|
||||
// If the tool implements AsyncTool and a non-nil callback is provided,
|
||||
// the callback will be set on the tool before execution.
|
||||
// If the tool implements AsyncExecutor and a non-nil callback is provided,
|
||||
// ExecuteAsync is called instead of Execute — the callback is a parameter,
|
||||
// never stored as mutable state on the tool.
|
||||
func (r *ToolRegistry) ExecuteWithContext(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
@@ -69,22 +70,23 @@ func (r *ToolRegistry) ExecuteWithContext(
|
||||
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
|
||||
}
|
||||
|
||||
// If tool implements ContextualTool, set context
|
||||
if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" {
|
||||
contextualTool.SetContext(channel, chatID)
|
||||
}
|
||||
// Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx).
|
||||
// Always inject — tools validate what they require.
|
||||
ctx = WithToolContext(ctx, channel, chatID)
|
||||
|
||||
// If tool implements AsyncTool and callback is provided, set callback
|
||||
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
|
||||
asyncTool.SetCallback(asyncCallback)
|
||||
logger.DebugCF("tool", "Async callback injected",
|
||||
// If tool implements AsyncExecutor and callback is provided, use ExecuteAsync.
|
||||
// The callback is a call parameter, not mutable state on the tool instance.
|
||||
var result *ToolResult
|
||||
start := time.Now()
|
||||
if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
|
||||
logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
|
||||
map[string]any{
|
||||
"tool": name,
|
||||
})
|
||||
result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
|
||||
} else {
|
||||
result = tool.Execute(ctx, args)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result := tool.Execute(ctx, args)
|
||||
duration := time.Since(start)
|
||||
|
||||
// Log based on result type
|
||||
|
||||
+32
-22
@@ -25,24 +25,24 @@ func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolRes
|
||||
return m.result
|
||||
}
|
||||
|
||||
type mockCtxTool struct {
|
||||
type mockContextAwareTool struct {
|
||||
mockRegistryTool
|
||||
channel string
|
||||
chatID string
|
||||
lastCtx context.Context
|
||||
}
|
||||
|
||||
func (m *mockCtxTool) SetContext(channel, chatID string) {
|
||||
m.channel = channel
|
||||
m.chatID = chatID
|
||||
func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *ToolResult {
|
||||
m.lastCtx = ctx
|
||||
return m.result
|
||||
}
|
||||
|
||||
type mockAsyncRegistryTool struct {
|
||||
mockRegistryTool
|
||||
cb AsyncCallback
|
||||
lastCB AsyncCallback
|
||||
}
|
||||
|
||||
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
|
||||
m.cb = cb
|
||||
func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
m.lastCB = cb
|
||||
return m.result
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
@@ -136,34 +136,44 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
|
||||
func TestToolRegistry_ExecuteWithContext_InjectsToolContext(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
ct := &mockContextAwareTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
|
||||
|
||||
if ct.channel != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", ct.channel)
|
||||
if ct.lastCtx == nil {
|
||||
t.Fatal("expected Execute to be called")
|
||||
}
|
||||
if ct.chatID != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
|
||||
if got := ToolChannel(ct.lastCtx); got != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", got)
|
||||
}
|
||||
if got := ToolChatID(ct.lastCtx); got != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
|
||||
func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
ct := &mockContextAwareTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
|
||||
|
||||
if ct.channel != "" || ct.chatID != "" {
|
||||
t.Error("SetContext should not be called with empty channel/chatID")
|
||||
if ct.lastCtx == nil {
|
||||
t.Fatal("expected Execute to be called")
|
||||
}
|
||||
// Empty values are still injected; tools decide what to do with them.
|
||||
if got := ToolChannel(ct.lastCtx); got != "" {
|
||||
t.Errorf("expected empty channel, got %q", got)
|
||||
}
|
||||
if got := ToolChatID(ct.lastCtx); got != "" {
|
||||
t.Errorf("expected empty chatID, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,14 +189,14 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
|
||||
cb := func(_ context.Context, _ *ToolResult) { called = true }
|
||||
|
||||
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
|
||||
if at.cb == nil {
|
||||
t.Error("expected SetCallback to have been called")
|
||||
if at.lastCB == nil {
|
||||
t.Error("expected ExecuteAsync to have received a callback")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("expected async result")
|
||||
}
|
||||
|
||||
at.cb(context.Background(), SilentResult("done"))
|
||||
at.lastCB(context.Background(), SilentResult("done"))
|
||||
if !called {
|
||||
t.Error("expected callback to be invoked")
|
||||
}
|
||||
|
||||
+6
-1
@@ -131,9 +131,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
|
||||
timeout := 60 * time.Second
|
||||
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
|
||||
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
return &ExecTool{
|
||||
workingDir: workingDir,
|
||||
timeout: 60 * time.Second,
|
||||
timeout: timeout,
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
|
||||
+27
-17
@@ -8,25 +8,18 @@ import (
|
||||
|
||||
type SpawnTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
allowlistCheck func(targetAgentID string) bool
|
||||
callback AsyncCallback // For async completion notification
|
||||
}
|
||||
|
||||
// Compile-time check: SpawnTool implements AsyncExecutor.
|
||||
var _ AsyncExecutor = (*SpawnTool)(nil)
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
return &SpawnTool{
|
||||
manager: manager,
|
||||
originChannel: "cli",
|
||||
originChatID: "direct",
|
||||
manager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCallback implements AsyncTool interface for async completion notification
|
||||
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
t.callback = cb
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Name() string {
|
||||
return "spawn"
|
||||
}
|
||||
@@ -56,16 +49,21 @@ func (t *SpawnTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SpawnTool) SetContext(channel, chatID string) {
|
||||
t.originChannel = channel
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
|
||||
t.allowlistCheck = check
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
return t.execute(ctx, args, nil)
|
||||
}
|
||||
|
||||
// ExecuteAsync implements AsyncExecutor. The callback is passed through to the
|
||||
// subagent manager as a call parameter — never stored on the SpawnTool instance.
|
||||
func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
return t.execute(ctx, args, cb)
|
||||
}
|
||||
|
||||
func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok || strings.TrimSpace(task) == "" {
|
||||
return ErrorResult("task is required and must be a non-empty string")
|
||||
@@ -85,8 +83,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Read channel/chatID from context (injected by registry).
|
||||
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
|
||||
// to preserve the same defaults as the original NewSpawnTool constructor.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
// Pass callback to manager for async completion notification
|
||||
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
|
||||
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
+14
-12
@@ -252,16 +252,12 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
|
||||
// and returns the result directly in the ToolResult.
|
||||
type SubagentTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
manager *SubagentManager
|
||||
}
|
||||
|
||||
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
|
||||
return &SubagentTool{
|
||||
manager: manager,
|
||||
originChannel: "cli",
|
||||
originChatID: "direct",
|
||||
manager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,11 +286,6 @@ func (t *SubagentTool) Parameters() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SubagentTool) SetContext(channel, chatID string) {
|
||||
t.originChannel = channel
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
@@ -341,13 +332,24 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
|
||||
// to preserve the same defaults as the original NewSubagentTool constructor.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, t.originChannel, t.originChatID)
|
||||
}, messages, channel, chatID)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
@@ -50,9 +50,8 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager.SetLLMOptions(2048, 0.6)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("cli", "direct")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
args := map[string]any{"task": "Do something"}
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
@@ -147,28 +146,14 @@ func TestSubagentTool_Parameters(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_SetContext verifies context setting
|
||||
func TestSubagentTool_SetContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
tool.SetContext("test-channel", "test-chat")
|
||||
|
||||
// Verify context is set (we can't directly access private fields,
|
||||
// but we can verify it doesn't crash)
|
||||
// The actual context usage is tested in Execute tests
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_Success tests successful execution
|
||||
func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("telegram", "chat-123")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
|
||||
args := map[string]any{
|
||||
"task": "Write a haiku about coding",
|
||||
"label": "haiku-task",
|
||||
@@ -297,12 +282,9 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
// Set context
|
||||
channel := "test-channel"
|
||||
chatID := "test-chat"
|
||||
tool.SetContext(channel, chatID)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx := WithToolContext(context.Background(), channel, chatID)
|
||||
args := map[string]any{
|
||||
"task": "Test context passing",
|
||||
}
|
||||
|
||||
+71
-1
@@ -395,6 +395,68 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
type SearXNGSearchProvider struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (p *SearXNGSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := fmt.Sprintf("%s/search?q=%s&format=json&categories=general",
|
||||
strings.TrimSuffix(p.baseURL, "/"),
|
||||
url.QueryEscape(query))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("SearXNG returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
Engine string `json:"engine"`
|
||||
Score float64 `json:"score"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
// Limit results to requested count
|
||||
if len(result.Results) > count {
|
||||
result.Results = result.Results[:count]
|
||||
}
|
||||
|
||||
// Format results in standard PicoClaw format
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("Results for: %s (via SearXNG)\n", query))
|
||||
for i, r := range result.Results {
|
||||
b.WriteString(fmt.Sprintf("%d. %s\n", i+1, r.Title))
|
||||
b.WriteString(fmt.Sprintf(" %s\n", r.URL))
|
||||
if r.Content != "" {
|
||||
b.WriteString(fmt.Sprintf(" %s\n", r.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
type GLMSearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -495,6 +557,9 @@ type WebSearchToolOptions struct {
|
||||
PerplexityAPIKey string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
SearXNGBaseURL string
|
||||
SearXNGMaxResults int
|
||||
SearXNGEnabled bool
|
||||
GLMSearchAPIKey string
|
||||
GLMSearchBaseURL string
|
||||
GLMSearchEngine string
|
||||
@@ -507,7 +572,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
|
||||
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
@@ -526,6 +591,11 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.SearXNGEnabled && opts.SearXNGBaseURL != "" {
|
||||
provider = &SearXNGSearchProvider{baseURL: opts.SearXNGBaseURL}
|
||||
if opts.SearXNGMaxResults > 0 {
|
||||
maxResults = opts.SearXNGMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user