Merge upstream main

This commit is contained in:
zihan987
2026-03-05 23:58:59 -08:00
59 changed files with 2617 additions and 479 deletions
+19
View File
@@ -24,6 +24,25 @@ jobs:
with:
version: v2.10.1
vuln_check:
name: Security Check
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Run Govulncheck
uses: golang/govulncheck-action@v1
with:
go-package: ./...
test:
name: Tests
runs-on: ubuntu-latest
+8
View File
@@ -100,3 +100,11 @@ jobs:
gh release edit "${{ inputs.tag }}" \
--draft=${{ inputs.draft }} \
--prerelease=${{ inputs.prerelease }}
upload-tos:
name: Upload to TOS
needs: release
uses: ./.github/workflows/upload-tos.yml
with:
tag: ${{ inputs.tag }}
secrets: inherit
+49
View File
@@ -0,0 +1,49 @@
name: Upload to Volcengine TOS
on:
workflow_dispatch:
inputs:
tag:
description: "Release tag to download and upload (e.g. v0.2.0)"
required: true
type: string
workflow_call:
inputs:
tag:
description: "Release tag to download and upload"
required: true
type: string
jobs:
upload-tos:
name: Upload to Volcengine TOS
runs-on: ubuntu-latest
steps:
- name: Download release assets
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
mkdir -p artifacts
gh release download "${{ inputs.tag }}" \
--repo "${{ github.repository }}" \
--dir artifacts \
--pattern "*.tar.gz" \
--pattern "*.zip" \
--pattern "*.rpm" \
--pattern "*.deb"
- name: Upload to Volcengine TOS
env:
AWS_ACCESS_KEY_ID: ${{ secrets.VOLC_TOS_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.VOLC_TOS_SECRET_KEY }}
AWS_DEFAULT_REGION: cn-beijing
run: |
aws configure set default.s3.addressing_style virtual
TOS_ENDPOINT="https://tos-s3-cn-beijing.volces.com"
# Upload to versioned directory
aws s3 sync artifacts/ "s3://picoclaw-downloads/${{ inputs.tag }}/" \
--endpoint-url "$TOS_ENDPOINT"
# Upload to latest (overwrite)
aws s3 sync artifacts/ "s3://picoclaw-downloads/latest/" \
--endpoint-url "$TOS_ENDPOINT" \
--delete
-4
View File
@@ -19,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---
PicoClaw is heavily inspired by and based on [nanobot](https://github.com/HKUDS/nanobot) by HKUDS.
+107 -12
View File
@@ -216,7 +216,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d
> [!TIP]
> Set your API key in `~/.picoclaw/config.json`.
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback.
> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month), [SearXNG](https://github.com/searxng/searxng) (free, self-hosted) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback.
**1. Initialize**
@@ -265,6 +265,16 @@ picoclaw onboard
"duckduckgo": {
"enabled": true,
"max_results": 5
},
"perplexity": {
"enabled": false,
"api_key": "YOUR_PERPLEXITY_API_KEY",
"max_results": 5
},
"searxng": {
"enabled": false,
"base_url": "http://your-searxng-instance:8888",
"max_results": 5
}
}
}
@@ -277,7 +287,12 @@ picoclaw onboard
**3. Get API Keys**
* **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
* **Web Search** (optional): [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month) · [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month)
* **Web Search** (optional):
* [Brave Search](https://brave.com/search/api) - Paid ($5/1000 queries, ~$5-6/month)
* [Perplexity](https://www.perplexity.ai) - AI-powered search with chat interface
* [SearXNG](https://github.com/searxng/searxng) - Self-hosted metasearch engine (free, no API key needed)
* [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month)
* DuckDuckGo - Built-in fallback (no API key required)
> **Note**: See `config.example.json` for a complete configuration template.
@@ -1243,6 +1258,16 @@ picoclaw agent -m "Hello"
"duckduckgo": {
"enabled": true,
"max_results": 5
},
"perplexity": {
"enabled": false,
"api_key": "",
"max_results": 5
},
"searxng": {
"enabled": false,
"base_url": "http://localhost:8888",
"max_results": 5
}
},
"cron": {
@@ -1300,10 +1325,69 @@ discord: <https://discord.gg/V4sAZ9XWpN>
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
To enable web search:
#### Search Provider Priority
1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results.
2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required).
PicoClaw automatically selects the best available search provider in this order:
1. **Perplexity** (if enabled and API key configured) - AI-powered search with citations
2. **Brave Search** (if enabled and API key configured) - Privacy-focused paid API ($5/1000 queries)
3. **SearXNG** (if enabled and base_url configured) - Self-hosted metasearch aggregating 70+ engines (free)
4. **DuckDuckGo** (if enabled, default fallback) - No API key required (free)
#### Web Search Configuration Options
**Option 1 (Best Results)**: Perplexity AI Search
```json
{
"tools": {
"web": {
"perplexity": {
"enabled": true,
"api_key": "YOUR_PERPLEXITY_API_KEY",
"max_results": 5
}
}
}
}
```
**Option 2 (Paid API)**: Get an API key at [https://brave.com/search/api](https://brave.com/search/api) ($5/1000 queries, ~$5-6/month)
```json
{
"tools": {
"web": {
"brave": {
"enabled": true,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
}
}
```
**Option 3 (Self-Hosted)**: Deploy your own [SearXNG](https://github.com/searxng/searxng) instance
```json
{
"tools": {
"web": {
"searxng": {
"enabled": true,
"base_url": "http://your-server:8888",
"max_results": 5
}
}
}
}
```
Benefits of SearXNG:
- **Zero cost**: No API fees or rate limits
- **Privacy-focused**: Self-hosted, no tracking
- **Aggregate results**: Queries 70+ search engines simultaneously
- **Perfect for cloud VMs**: Solves datacenter IP blocking issues (Oracle Cloud, GCP, AWS, Azure)
- **No API key needed**: Just deploy and configure the base URL
**Option 4 (No Setup Required)**: DuckDuckGo is enabled by default as fallback (no API key needed)
Add the key to `~/.picoclaw/config.json` if using Brave:
@@ -1319,6 +1403,16 @@ Add the key to `~/.picoclaw/config.json` if using Brave:
"duckduckgo": {
"enabled": true,
"max_results": 5
},
"perplexity": {
"enabled": false,
"api_key": "YOUR_PERPLEXITY_API_KEY",
"max_results": 5
},
"searxng": {
"enabled": false,
"base_url": "http://your-searxng-instance:8888",
"max_results": 5
}
}
}
@@ -1337,10 +1431,11 @@ This happens when another instance of the bot is running. Make sure only one `pi
## 📝 API Key Comparison
| Service | Free Tier | Use Case |
| ---------------- | ------------------- | ------------------------------------- |
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
| **Zhipu** | 200K tokens/month | Best for Chinese users |
| **Brave Search** | 2000 queries/month | Web search functionality |
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) |
| Service | Free Tier | Use Case |
| ---------------- | ------------------------ | ------------------------------------- |
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
| **Zhipu** | 200K tokens/month | Best for Chinese users |
| **Brave Search** | Paid ($5/1000 queries) | Web search functionality |
| **SearXNG** | Unlimited (self-hosted) | Privacy-focused metasearch (70+ engines) |
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) |
@@ -335,7 +335,11 @@ func (s *appState) testModel(model *picoclawconfig.ModelConfig) {
s.showMessage("Test OK", resp.Status)
return
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
body, err := io.ReadAll(io.LimitReader(resp.Body, 2048))
if err != nil {
s.showMessage("Test failed", fmt.Sprintf("failed to read response: %v", err))
return
}
s.showMessage(
"Test failed",
fmt.Sprintf("%s: %s", resp.Status, strings.TrimSpace(string(body))),
@@ -297,7 +297,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading userinfo response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("userinfo request failed: %s", string(body))
}
+4 -1
View File
@@ -177,7 +177,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading userinfo response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("userinfo request failed: %s", string(body))
}
+17 -11
View File
@@ -230,19 +230,25 @@ func setupCronTool(
// Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
cronTool, err := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
if err != nil {
log.Fatalf("Critical error during CronTool initialization: %v", err)
// Create and register CronTool if enabled
var cronTool *tools.CronTool
if cfg.Tools.IsToolEnabled("cron") {
var err error
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
if err != nil {
log.Fatalf("Critical error during CronTool initialization: %v", err)
}
agentLoop.RegisterTool(cronTool)
}
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
result := cronTool.ExecuteJob(context.Background(), job)
return result, nil
})
// Set onJob handler
if cronTool != nil {
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
result := cronTool.ExecuteJob(context.Background(), job)
return result, nil
})
}
return cronService
}
+11 -2
View File
@@ -18,12 +18,21 @@ var (
goVersion string
)
// GetPicoclawHome returns the picoclaw home directory.
// Priority: $PICOCLAW_HOME > ~/.picoclaw
func GetPicoclawHome() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
return home
}
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw")
}
func GetConfigPath() string {
if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" {
return configPath
}
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "config.json")
return filepath.Join(GetPicoclawHome(), "config.json")
}
func LoadConfig() (*config.Config, error) {
+21
View File
@@ -19,6 +19,27 @@ func TestGetConfigPath(t *testing.T) {
assert.Equal(t, want, got)
}
func TestGetConfigPath_WithPICOCLAW_HOME(t *testing.T) {
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw")
t.Setenv("HOME", "/tmp/home")
got := GetConfigPath()
want := filepath.Join("/custom/picoclaw", "config.json")
assert.Equal(t, want, got)
}
func TestGetConfigPath_WithPICOCLAW_CONFIG(t *testing.T) {
t.Setenv("PICOCLAW_CONFIG", "/custom/config.json")
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw")
t.Setenv("HOME", "/tmp/home")
got := GetConfigPath()
want := "/custom/config.json"
assert.Equal(t, want, got)
}
func TestFormatVersion_NoGitCommit(t *testing.T) {
oldVersion, oldGit := version, gitCommit
t.Cleanup(func() { version, gitCommit = oldVersion, oldGit })
+3 -3
View File
@@ -21,8 +21,8 @@ picoclaw skills install --registry clawhub github
`,
Args: func(cmd *cobra.Command, args []string) error {
if registry != "" {
if len(args) != 2 {
return fmt.Errorf("when --registry is set, exactly 2 arguments are required: <name> <slug>")
if len(args) != 1 {
return fmt.Errorf("when --registry is set, exactly 1 argument is required: <slug>")
}
return nil
}
@@ -45,7 +45,7 @@ picoclaw skills install --registry clawhub github
return err
}
return skillsInstallFromRegistry(cfg, args[0], args[1])
return skillsInstallFromRegistry(cfg, registry, args[0])
}
return skillsInstallCmd(installer, args[0])
@@ -26,3 +26,72 @@ func TestNewInstallSubcommand(t *testing.T) {
assert.Len(t, cmd.Aliases, 0)
}
func TestInstallCommandArgs(t *testing.T) {
tests := []struct {
name string
args []string
registry string
expectError bool
errorMsg string
}{
{
name: "no registry, one arg",
args: []string{"sipeed/picoclaw-skills/weather"},
registry: "",
expectError: false,
},
{
name: "no registry, no args",
args: []string{},
registry: "",
expectError: true,
errorMsg: "exactly 1 argument is required: <github>",
},
{
name: "no registry, too many args",
args: []string{"arg1", "arg2"},
registry: "",
expectError: true,
errorMsg: "exactly 1 argument is required: <github>",
},
{
name: "with registry, one arg",
args: []string{"weather-skill"},
registry: "clawhub",
expectError: false,
},
{
name: "with registry, no args",
args: []string{},
registry: "clawhub",
expectError: true,
errorMsg: "when --registry is set, exactly 1 argument is required: <slug>",
},
{
name: "with registry, too many args",
args: []string{"arg1", "arg2"},
registry: "clawhub",
expectError: true,
errorMsg: "when --registry is set, exactly 1 argument is required: <slug>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := newInstallCommand(nil)
if tt.registry != "" {
require.NoError(t, cmd.Flags().Set("registry", tt.registry))
}
err := cmd.Args(cmd, tt.args)
if tt.expectError {
require.Error(t, err)
assert.Equal(t, tt.errorMsg, err.Error())
} else {
require.NoError(t, err)
}
})
}
}
+91 -8
View File
@@ -22,7 +22,8 @@
"model_name": "claude-sonnet-4.6",
"model": "anthropic/claude-sonnet-4.6",
"api_key": "sk-ant-your-key",
"api_base": "https://api.anthropic.com/v1"
"api_base": "https://api.anthropic.com/v1",
"thinking_level": "high"
},
{
"model_name": "gemini",
@@ -224,27 +225,53 @@
"mistral": {
"api_key": "",
"api_base": "https://api.mistral.ai/v1"
},
"avian": {
"api_key": "",
"api_base": "https://api.avian.io/v1"
}
},
"tools": {
"allow_read_paths": null,
"allow_write_paths": null,
"web": {
"enabled": true,
"brave": {
"enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"tavily": {
"enabled": false,
"api_key": "",
"base_url": "",
"max_results": 0
},
"duckduckgo": {
"enabled": true,
"max_results": 5
},
"perplexity": {
"enabled": false,
"api_key": "pplx-xxx",
"api_key": "",
"max_results": 5
},
"proxy": ""
"searxng": {
"enabled": false,
"base_url": "http://localhost:8888",
"max_results": 5
},
"glm_search": {
"enabled": false,
"api_key": "",
"base_url": "https://open.bigmodel.cn/api/paas/v4/web_search",
"search_engine": "search_std",
"max_results": 5
},
"fetch_limit_bytes": 10485760
},
"cron": {
"enabled": true,
"exec_timeout_minutes": 5
},
"mcp": {
@@ -313,19 +340,75 @@
}
},
"exec": {
"enable_deny_patterns": false,
"custom_deny_patterns": []
"enabled": true,
"enable_deny_patterns": true,
"custom_deny_patterns": null,
"custom_allow_patterns": null
},
"skills": {
"enabled": true,
"registries": {
"clawhub": {
"enabled": true,
"base_url": "https://clawhub.ai",
"search_path": "/api/v1/search",
"skills_path": "/api/v1/skills",
"download_path": "/api/v1/download"
"auth_token": "",
"search_path": "",
"skills_path": "",
"download_path": "",
"timeout": 0,
"max_zip_size": 0,
"max_response_size": 0
}
},
"max_concurrent_searches": 2,
"search_cache": {
"max_size": 50,
"ttl_seconds": 300
}
},
"media_cleanup": {
"enabled": true,
"max_age_minutes": 30,
"interval_minutes": 5
},
"append_file": {
"enabled": true
},
"edit_file": {
"enabled": true
},
"find_skills": {
"enabled": true
},
"i2c": {
"enabled": false
},
"install_skill": {
"enabled": true
},
"list_dir": {
"enabled": true
},
"message": {
"enabled": true
},
"read_file": {
"enabled": true
},
"spawn": {
"enabled": true
},
"spi": {
"enabled": false
},
"subagent": {
"enabled": true
},
"web_fetch": {
"enabled": true
},
"write_file": {
"enabled": true
}
},
"heartbeat": {
+2
View File
@@ -180,6 +180,7 @@ The skills tool configures skill discovery and installation via registries like
| ---------------------------------- | ------ | -------------------- | ----------------------- |
| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
| `registries.clawhub.auth_token` | string | `""` | Optional Bearer token for higher rate limits |
| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
@@ -194,6 +195,7 @@ The skills tool configures skill discovery and installation via registries like
"clawhub": {
"enabled": true,
"base_url": "https://clawhub.ai",
"auth_token": "",
"search_path": "/api/v1/search",
"skills_path": "/api/v1/skills",
"download_path": "/api/v1/download"
-1
View File
@@ -37,7 +37,6 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
github.com/gdamore/encoding v1.0.1 // indirect
github.com/gdamore/tcell/v2 v2.13.8 // indirect
github.com/h2non/filetype v1.1.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
+3
View File
@@ -42,6 +42,9 @@ type ContextBuilder struct {
}
func getGlobalConfigDir() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
return home
}
home, err := os.UserHomeDir()
if err != nil {
return ""
+63 -12
View File
@@ -26,6 +26,7 @@ type AgentInstance struct {
MaxIterations int
MaxTokens int
Temperature float64
ThinkingLevel ThinkingLevel
ContextWindow int
SummarizeMessageThreshold int
SummarizeTokenPercent int
@@ -36,6 +37,14 @@ type AgentInstance struct {
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
// Router is non-nil when model routing is configured and the light model
// was successfully resolved. It scores each incoming message and decides
// whether to route to LightCandidates or stay with Candidates.
Router *routing.Router
// LightCandidates holds the resolved provider candidates for the light model.
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
LightCandidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -59,17 +68,30 @@ func NewAgentInstance(
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
if cfg.Tools.IsToolEnabled("read_file") {
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("write_file") {
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
}
if cfg.Tools.IsToolEnabled("list_dir") {
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
}
if cfg.Tools.IsToolEnabled("edit_file") {
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
}
if cfg.Tools.IsToolEnabled("append_file") {
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
}
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
@@ -103,6 +125,12 @@ func NewAgentInstance(
temperature = *defaults.Temperature
}
var thinkingLevelStr string
if mc, err := cfg.GetModelConfig(model); err == nil {
thinkingLevelStr = mc.ThinkingLevel
}
thinkingLevel := parseThinkingLevel(thinkingLevelStr)
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
if summarizeMessageThreshold == 0 {
summarizeMessageThreshold = 20
@@ -160,6 +188,25 @@ func NewAgentInstance(
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
var lightCandidates []providers.FallbackCandidate
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
if len(resolved) > 0 {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
Threshold: rc.Threshold,
})
lightCandidates = resolved
} else {
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
rc.LightModel, agentID)
}
}
return &AgentInstance{
ID: agentID,
Name: agentName,
@@ -169,6 +216,7 @@ func NewAgentInstance(
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ThinkingLevel: thinkingLevel,
ContextWindow: maxTokens,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
@@ -179,6 +227,8 @@ func NewAgentInstance(
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
Router: router,
LightCandidates: lightCandidates,
}
}
@@ -187,12 +237,13 @@ func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentD
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
return expandHome(strings.TrimSpace(agentCfg.Workspace))
}
// Use the configured default workspace (respects PICOCLAW_HOME)
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
return expandHome(defaults.Workspace)
}
home, _ := os.UserHomeDir()
// For named agents without explicit workspace, use default workspace with agent ID suffix
id := routing.NormalizeAgentID(agentCfg.ID)
return filepath.Join(home, ".picoclaw", "workspace-"+id)
return filepath.Join(expandHome(defaults.Workspace), "..", "workspace-"+id)
}
// resolveAgentModel resolves the primary model for an agent.
+164 -110
View File
@@ -108,76 +108,106 @@ func registerSharedTools(
}
// Web tools
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
if cfg.Tools.IsToolEnabled("web") {
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
}
}
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
if cfg.Tools.IsToolEnabled("web_fetch") {
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
}
}
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
agent.Tools.Register(tools.NewSPITool())
if cfg.Tools.IsToolEnabled("i2c") {
agent.Tools.Register(tools.NewI2CTool())
}
if cfg.Tools.IsToolEnabled("spi") {
agent.Tools.Register(tools.NewSPITool())
}
// Message tool
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
if cfg.Tools.IsToolEnabled("message") {
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
})
})
})
agent.Tools.Register(messageTool)
agent.Tools.Register(messageTool)
}
// Skill discovery and installation tools
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
})
searchCache := skills.NewSearchCache(
cfg.Tools.Skills.SearchCache.MaxSize,
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
)
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
skills_enabled := cfg.Tools.IsToolEnabled("skills")
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
install_skills_enable := cfg.Tools.IsToolEnabled("install_skill")
if skills_enabled && (find_skills_enable || install_skills_enable) {
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
})
if find_skills_enable {
searchCache := skills.NewSearchCache(
cfg.Tools.Skills.SearchCache.MaxSize,
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
)
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
}
if install_skills_enable {
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
}
}
// Spawn tool with allowlist checker
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
} else {
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
}
}
}
}
@@ -185,7 +215,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
// Initialize MCP servers for all agents
if al.cfg.Tools.MCP.Enabled {
if al.cfg.Tools.IsToolEnabled("mcp") {
mcpManager := mcp.NewManager()
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
@@ -227,6 +257,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
if !ok {
continue
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
agent.Tools.Register(mcpTool)
totalRegistrations++
@@ -543,8 +574,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
if tool, ok := agent.Tools.Get("message"); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(msg.Channel, msg.ChatID)
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
resetter.ResetSentInRound()
}
}
@@ -659,10 +690,7 @@ func (al *AgentLoop) runAgentLoop(
}
}
// 1. Update tool contexts
al.updateToolContexts(agent, opts.Channel, opts.ChatID)
// 2. Build messages (skip history for heartbeat)
// 1. Build messages (skip history for heartbeat)
var history []providers.Message
var summary string
if !opts.NoHistory {
@@ -682,10 +710,10 @@ func (al *AgentLoop) runAgentLoop(
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
// 3. Save user message to session
// 2. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// 4. Run LLM iteration loop
// 3. Run LLM iteration loop
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
if err != nil {
return "", err
@@ -694,21 +722,21 @@ func (al *AgentLoop) runAgentLoop(
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
// 5. Handle empty response
// 4. Handle empty response
if finalContent == "" {
finalContent = opts.DefaultResponse
}
// 6. Save final assistant message to session
// 5. Save final assistant message to session
agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
agent.Sessions.Save(opts.SessionKey)
// 7. Optional: summarization
// 6. Optional: summarization
if opts.EnableSummary {
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
}
// 8. Optional: send response via bus
// 7. Optional: send response via bus
if opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
@@ -717,7 +745,7 @@ func (al *AgentLoop) runAgentLoop(
})
}
// 9. Log response
// 8. Log response
responsePreview := utils.Truncate(finalContent, 120)
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
map[string]any{
@@ -796,6 +824,12 @@ func (al *AgentLoop) runLLMIteration(
iteration := 0
var finalContent string
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
// all tool-follow-up iterations within the same turn so that a multi-step
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
for iteration < agent.MaxIterations {
iteration++
@@ -814,7 +848,7 @@ func (al *AgentLoop) runLLMIteration(
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"model": agent.Model,
"model": activeModel,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"max_tokens": agent.MaxTokens,
@@ -830,27 +864,33 @@ func (al *AgentLoop) runLLMIteration(
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM with fallback chain if candidates are configured.
// Call LLM with fallback chain if multiple candidates are configured.
var response *providers.LLMResponse
var err error
llmOpts := map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
}
// parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
// so checking != ThinkingOff is sufficient.
if agent.ThinkingLevel != ThinkingOff {
if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
llmOpts["thinking_level"] = string(agent.ThinkingLevel)
} else {
logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring",
map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)})
}
}
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
ctx,
agent.Candidates,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(
ctx,
messages,
providerToolDefs,
model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
},
)
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
},
)
if fbErr != nil {
@@ -866,11 +906,7 @@ func (al *AgentLoop) runLLMIteration(
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
})
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
}
// Retry loop for context/token errors
@@ -1057,7 +1093,7 @@ func (al *AgentLoop) runLLMIteration(
"iteration": iteration,
})
// Create async callback for tools that implement AsyncTool
// Create async callback for tools that implement AsyncExecutor
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
@@ -1139,24 +1175,42 @@ func (al *AgentLoop) runLLMIteration(
return finalContent, iteration, nil
}
// updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := agent.Tools.Get("message"); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID)
}
// selectCandidates returns the model candidates and resolved model name to use
// for a conversation turn. When model routing is configured and the incoming
// message scores below the complexity threshold, it returns the light model
// candidates instead of the primary ones.
//
// The returned (candidates, model) pair is used for all LLM calls within one
// turn — tool follow-up iterations use the same tier as the initial call so
// that a multi-step tool chain doesn't switch models mid-way.
func (al *AgentLoop) selectCandidates(
agent *AgentInstance,
userMsg string,
history []providers.Message,
) (candidates []providers.FallbackCandidate, model string) {
if agent.Router == nil || len(agent.LightCandidates) == 0 {
return agent.Candidates, agent.Model
}
if tool, ok := agent.Tools.Get("spawn"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := agent.Tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
if !usedLight {
logger.DebugCF("agent", "Model routing: primary model selected",
map[string]any{
"agent_id": agent.ID,
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.Candidates, agent.Model
}
logger.InfoCF("agent", "Model routing: light model selected",
map[string]any{
"agent_id": agent.ID,
"light_model": agent.Router.LightModel(),
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, agent.Router.LightModel()
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
+16 -65
View File
@@ -164,35 +164,21 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
}
}
// TestToolContext_Updates verifies tool context is updated with channel/chatID
// TestToolContext_Updates verifies tool context helpers work correctly
func TestToolContext_Updates(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
ctx := tools.WithToolContext(context.Background(), "telegram", "chat-42")
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
if got := tools.ToolChannel(ctx); got != "telegram" {
t.Errorf("expected channel 'telegram', got %q", got)
}
if got := tools.ToolChatID(ctx); got != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", got)
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "OK"}
_ = NewAgentLoop(cfg, msgBus, provider)
// Verify that ContextualTool interface is defined and can be implemented
// This test validates the interface contract exists
ctxTool := &mockContextualTool{}
// Verify the tool implements the interface correctly
var _ tools.ContextualTool = ctxTool
// Empty context returns empty strings
if got := tools.ToolChannel(context.Background()); got != "" {
t.Errorf("expected empty channel from bare context, got %q", got)
}
}
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
@@ -241,16 +227,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) {
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = tmpDir
cfg.Agents.Defaults.Model = "test-model"
cfg.Agents.Defaults.MaxTokens = 4096
cfg.Agents.Defaults.MaxToolIterations = 10
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
@@ -359,36 +340,6 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
return tools.SilentResult("Custom tool executed")
}
// mockContextualTool tracks context updates
type mockContextualTool struct {
lastChannel string
lastChatID string
}
func (m *mockContextualTool) Name() string {
return "mock_contextual"
}
func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
func (m *mockContextualTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
func (m *mockContextualTool) SetContext(channel, chatID string) {
m.lastChannel = channel
m.lastChatID = chatID
}
// testHelper executes a message and returns the response
type testHelper struct {
al *AgentLoop
+39
View File
@@ -0,0 +1,39 @@
package agent
import "strings"
// ThinkingLevel controls how the provider sends thinking parameters.
//
// - "adaptive": sends {thinking: {type: "adaptive"}} + output_config.effort (Claude 4.6+)
// - "low"/"medium"/"high"/"xhigh": sends {thinking: {type: "enabled", budget_tokens: N}} (all models)
// - "off": disables thinking
type ThinkingLevel string
const (
ThinkingOff ThinkingLevel = "off"
ThinkingLow ThinkingLevel = "low"
ThinkingMedium ThinkingLevel = "medium"
ThinkingHigh ThinkingLevel = "high"
ThinkingXHigh ThinkingLevel = "xhigh"
ThinkingAdaptive ThinkingLevel = "adaptive"
)
// parseThinkingLevel normalizes a config string to a ThinkingLevel.
// Case-insensitive and whitespace-tolerant for user-facing config values.
// Returns ThinkingOff for unknown or empty values.
func parseThinkingLevel(level string) ThinkingLevel {
switch strings.ToLower(strings.TrimSpace(level)) {
case "adaptive":
return ThinkingAdaptive
case "low":
return ThinkingLow
case "medium":
return ThinkingMedium
case "high":
return ThinkingHigh
case "xhigh":
return ThinkingXHigh
default:
return ThinkingOff
}
}
+35
View File
@@ -0,0 +1,35 @@
package agent
import "testing"
func TestParseThinkingLevel(t *testing.T) {
tests := []struct {
name string
input string
want ThinkingLevel
}{
{"off", "off", ThinkingOff},
{"empty", "", ThinkingOff},
{"low", "low", ThinkingLow},
{"medium", "medium", ThinkingMedium},
{"high", "high", ThinkingHigh},
{"xhigh", "xhigh", ThinkingXHigh},
{"adaptive", "adaptive", ThinkingAdaptive},
{"unknown", "unknown", ThinkingOff},
// Case-insensitive and whitespace-tolerant
{"upper_Medium", "Medium", ThinkingMedium},
{"upper_HIGH", "HIGH", ThinkingHigh},
{"mixed_Adaptive", "Adaptive", ThinkingAdaptive},
{"leading_space", " high", ThinkingHigh},
{"trailing_space", "low ", ThinkingLow},
{"both_spaces", " medium ", ThinkingMedium},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseThinkingLevel(tt.input); got != tt.want {
t.Errorf("parseThinkingLevel(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
+20 -5
View File
@@ -212,7 +212,10 @@ func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading device code response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
@@ -300,7 +303,10 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading device code response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
@@ -360,7 +366,10 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au
return nil, fmt.Errorf("pending")
}
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading device token response: %w", err)
}
var tokenResp struct {
AuthorizationCode string `json:"authorization_code"`
@@ -401,7 +410,10 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading token refresh response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
@@ -494,7 +506,10 @@ func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading token exchange response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", string(body))
}
+3
View File
@@ -39,6 +39,9 @@ func (c *AuthCredential) NeedsRefresh() bool {
}
func authFilePath() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
return filepath.Join(home, "auth.json")
}
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "auth.json")
}
+70
View File
@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"os"
"regexp"
"strings"
"sync"
"time"
@@ -26,6 +27,12 @@ const (
sendTimeout = 10 * time.Second
)
var (
// Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call)
channelRefRe = regexp.MustCompile(`<#(\d+)>`)
msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`)
)
type DiscordChannel struct {
*channels.BaseChannel
session *discordgo.Session
@@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = c.stripBotMention(content)
}
// Resolve Discord refs in main content before concatenation to avoid
// double-expanding links that appear in the referenced message.
content = c.resolveDiscordRefs(s, content, m.GuildID)
// Prepend referenced (quoted) message content if this is a reply
if m.MessageReference != nil && m.ReferencedMessage != nil {
refContent := m.ReferencedMessage.Content
if refContent != "" {
refAuthor := "unknown"
if m.ReferencedMessage.Author != nil {
refAuthor = m.ReferencedMessage.Author.Username
}
refContent = c.resolveDiscordRefs(s, refContent, m.GuildID)
content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s",
refAuthor, refContent, content)
}
}
senderID := m.Author.ID
mediaPaths := make([]string, 0, len(m.Attachments))
@@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
return nil
}
// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and
// expands Discord message links to show the linked message content.
// Only links pointing to the same guild are expanded to prevent cross-guild leakage.
func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string {
// 1. Resolve channel references: <#id> → #channel-name
text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string {
parts := channelRefRe.FindStringSubmatch(match)
if len(parts) < 2 {
return match
}
// Prefer session state cache to avoid API calls
if ch, err := s.State.Channel(parts[1]); err == nil {
return "#" + ch.Name
}
if ch, err := s.Channel(parts[1]); err == nil {
return "#" + ch.Name
}
return match
})
// 2. Expand Discord message links (max 3, same guild only)
matches := msgLinkRe.FindAllStringSubmatch(text, 3)
for _, m := range matches {
if len(m) < 4 {
continue
}
linkGuildID, channelID, messageID := m[1], m[2], m[3]
// Security: only expand links from the same guild
if linkGuildID != guildID {
continue
}
msg, err := s.ChannelMessage(channelID, messageID)
if err != nil || msg == nil || msg.Content == "" {
continue
}
author := "unknown"
if msg.Author != nil {
author = msg.Author.Username
}
text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content)
}
return text
}
// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
@@ -0,0 +1,98 @@
package discord
import (
"testing"
)
func TestChannelRefRegex(t *testing.T) {
tests := []struct {
name string
input string
wantID string
wantOK bool
}{
{"basic channel ref", "<#123456789>", "123456789", true},
{"long id", "<#9876543210123456>", "9876543210123456", true},
{"no match plain text", "hello world", "", false},
{"no match partial", "<#>", "", false},
{"no match letters", "<#abc>", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := channelRefRe.FindStringSubmatch(tt.input)
if tt.wantOK {
if len(matches) < 2 || matches[1] != tt.wantID {
t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID)
}
} else {
if len(matches) >= 2 {
t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches)
}
}
})
}
}
func TestMsgLinkRegex(t *testing.T) {
tests := []struct {
name string
input string
wantGuild string
wantChan string
wantMsg string
wantOK bool
}{
{
"discord.com link",
"https://discord.com/channels/111/222/333",
"111", "222", "333", true,
},
{
"discordapp.com link",
"https://discordapp.com/channels/111/222/333",
"111", "222", "333", true,
},
{
"real world ids",
"check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please",
"9000000000000001", "9000000000000002", "9000000000000003", true,
},
{"no match http", "http://discord.com/channels/1/2/3", "", "", "", false},
{"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false},
{"no match plain text", "hello world", "", "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := msgLinkRe.FindStringSubmatch(tt.input)
if tt.wantOK {
if len(matches) < 4 {
t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s",
tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg)
}
if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg {
t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s",
tt.input, matches[1], matches[2], matches[3],
tt.wantGuild, tt.wantChan, tt.wantMsg)
}
} else {
if len(matches) >= 4 {
t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches)
}
}
})
}
}
func TestMsgLinkRegex_MultipleMatches(t *testing.T) {
input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12"
matches := msgLinkRe.FindAllStringSubmatch(input, 3)
if len(matches) != 3 {
t.Fatalf("expected 3 matches (capped), got %d", len(matches))
}
// Verify the 3rd match is 7/8/9 (not 10/11/12)
if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" {
t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2])
}
}
+4 -1
View File
@@ -654,7 +654,10 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("reading LINE API error response: %w", err))
}
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody)))
}
+4 -1
View File
@@ -793,7 +793,10 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro
return nil
}
respBody, _ := io.ReadAll(resp.Body)
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err)
}
switch {
case resp.StatusCode == http.StatusTooManyRequests:
return fmt.Errorf("response_url rate limited (%d): %s: %w",
+22 -4
View File
@@ -321,8 +321,17 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody)))
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom upload error response: %w", readErr),
)
}
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom upload error: %s", string(respBody)),
)
}
var result struct {
@@ -371,8 +380,17 @@ func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken stri
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody)))
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom_app error response: %w", readErr),
)
}
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom_app API error: %s", string(respBody)),
)
}
respBody, err := io.ReadAll(resp.Body)
+11 -2
View File
@@ -453,8 +453,17 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body)))
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading webhook error response: %w", readErr),
)
}
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("webhook API error: %s", string(body)),
)
}
body, err := io.ReadAll(resp.Body)
+125 -36
View File
@@ -167,22 +167,35 @@ type SessionConfig struct {
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
}
// RoutingConfig controls the intelligent model routing feature.
// When enabled, each incoming message is scored against structural features
// (message length, code blocks, tool call history, conversation depth, attachments).
// Messages scoring below Threshold are sent to LightModel; all others use the
// agent's primary model. This reduces cost and latency for simple tasks without
// requiring any keyword matching — all scoring is language-agnostic.
type RoutingConfig struct {
Enabled bool `json:"enabled"`
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
@@ -432,6 +445,7 @@ type ProvidersConfig struct {
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
Mistral ProviderConfig `json:"mistral"`
Avian ProviderConfig `json:"avian"`
}
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
@@ -456,7 +470,8 @@ func (p ProvidersConfig) IsEmpty() bool {
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
p.Avian.APIKey == "" && p.Avian.APIBase == ""
}
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
@@ -507,6 +522,7 @@ type ModelConfig struct {
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
}
// Validate checks if the ModelConfig has all required fields.
@@ -525,6 +541,10 @@ type GatewayConfig struct {
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
}
type ToolConfig struct {
Enabled bool `json:"enabled" env:"ENABLED"`
}
type BraveConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
@@ -549,6 +569,12 @@ type PerplexityConfig struct {
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type SearXNGConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_SEARXNG_ENABLED"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_SEARXNG_BASE_URL"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARXNG_MAX_RESULTS"`
}
type GLMSearchConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
@@ -560,11 +586,13 @@ type GLMSearchConfig struct {
}
type WebToolsConfig struct {
Brave BraveConfig `json:"brave"`
Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
GLMSearch GLMSearchConfig `json:"glm_search"`
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"`
Brave BraveConfig ` json:"brave"`
Tavily TavilyConfig ` json:"tavily"`
DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"`
Perplexity PerplexityConfig ` json:"perplexity"`
SearXNG SearXNGConfig ` json:"searxng"`
GLMSearch GLMSearchConfig ` json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
@@ -572,19 +600,29 @@ type WebToolsConfig struct {
}
type CronToolsConfig struct {
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
}
type ExecConfig struct {
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"`
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
}
type SkillsToolsConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"`
Registries SkillsRegistriesConfig ` json:"registries"`
MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"`
SearchCache SearchCacheConfig ` json:"search_cache"`
}
type MediaCleanupConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"`
MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"`
Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"`
ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"`
MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"`
Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"`
}
type ToolsConfig struct {
@@ -596,12 +634,19 @@ type ToolsConfig struct {
Skills SkillsToolsConfig `json:"skills"`
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
MCP MCPConfig `json:"mcp"`
}
type SkillsToolsConfig struct {
Registries SkillsRegistriesConfig `json:"registries"`
MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
SearchCache SearchCacheConfig `json:"search_cache"`
AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"`
EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"`
FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"`
I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"`
InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"`
}
type SearchCacheConfig struct {
@@ -647,8 +692,7 @@ type MCPServerConfig struct {
// MCPConfig defines configuration for all MCP servers
type MCPConfig struct {
// Enabled globally enables/disables MCP integration
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
// Servers is a map of server name to server configuration
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
}
@@ -854,3 +898,48 @@ func (c *Config) ValidateModelList() error {
}
return nil
}
func (t *ToolsConfig) IsToolEnabled(name string) bool {
switch name {
case "web":
return t.Web.Enabled
case "cron":
return t.Cron.Enabled
case "exec":
return t.Exec.Enabled
case "skills":
return t.Skills.Enabled
case "media_cleanup":
return t.MediaCleanup.Enabled
case "append_file":
return t.AppendFile.Enabled
case "edit_file":
return t.EditFile.Enabled
case "find_skills":
return t.FindSkills.Enabled
case "i2c":
return t.I2C.Enabled
case "install_skill":
return t.InstallSkill.Enabled
case "list_dir":
return t.ListDir.Enabled
case "message":
return t.Message.Enabled
case "read_file":
return t.ReadFile.Enabled
case "spawn":
return t.Spawn.Enabled
case "spi":
return t.SPI.Enabled
case "subagent":
return t.Subagent.Enabled
case "web_fetch":
return t.WebFetch.Enabled
case "write_file":
return t.WriteFile.Enabled
case "mcp":
return t.MCP.Enabled
default:
return true
}
}
+77 -2
View File
@@ -316,6 +316,20 @@ func DefaultConfig() *Config {
APIKey: "",
},
// Avian - https://avian.io
{
ModelName: "deepseek-v3.2",
Model: "avian/deepseek/deepseek-v3.2",
APIBase: "https://api.avian.io/v1",
APIKey: "",
},
{
ModelName: "kimi-k2.5",
Model: "avian/moonshotai/kimi-k2.5",
APIBase: "https://api.avian.io/v1",
APIKey: "",
},
// VLLM (local) - http://localhost:8000
{
ModelName: "local-model",
@@ -330,11 +344,16 @@ func DefaultConfig() *Config {
},
Tools: ToolsConfig{
MediaCleanup: MediaCleanupConfig{
Enabled: true,
ToolConfig: ToolConfig{
Enabled: true,
},
MaxAge: 30,
Interval: 5,
},
Web: WebToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
Proxy: "",
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
Brave: BraveConfig{
@@ -351,6 +370,11 @@ func DefaultConfig() *Config {
APIKey: "",
MaxResults: 5,
},
SearXNG: SearXNGConfig{
Enabled: false,
BaseURL: "",
MaxResults: 5,
},
GLMSearch: GLMSearchConfig{
Enabled: false,
APIKey: "",
@@ -360,12 +384,22 @@ func DefaultConfig() *Config {
},
},
Cron: CronToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
ExecTimeoutMinutes: 5,
},
Exec: ExecConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
EnableDenyPatterns: true,
TimeoutSeconds: 60,
},
Skills: SkillsToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
Registries: SkillsRegistriesConfig{
ClawHub: ClawHubRegistryConfig{
Enabled: true,
@@ -379,9 +413,50 @@ func DefaultConfig() *Config {
},
},
MCP: MCPConfig{
Enabled: false,
ToolConfig: ToolConfig{
Enabled: false,
},
Servers: map[string]MCPServerConfig{},
},
AppendFile: ToolConfig{
Enabled: true,
},
EditFile: ToolConfig{
Enabled: true,
},
FindSkills: ToolConfig{
Enabled: true,
},
I2C: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
InstallSkill: ToolConfig{
Enabled: true,
},
ListDir: ToolConfig{
Enabled: true,
},
Message: ToolConfig{
Enabled: true,
},
ReadFile: ToolConfig{
Enabled: true,
},
Spawn: ToolConfig{
Enabled: true,
},
SPI: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
Subagent: ToolConfig{
Enabled: true,
},
WebFetch: ToolConfig{
Enabled: true,
},
WriteFile: ToolConfig{
Enabled: true,
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
+17
View File
@@ -390,6 +390,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
{
providerNames: []string{"avian"},
protocol: "avian",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Avian.APIKey == "" && p.Avian.APIBase == "" {
return ModelConfig{}, false
}
return ModelConfig{
ModelName: "avian",
Model: "avian/deepseek/deepseek-v3.2",
APIKey: p.Avian.APIKey,
APIBase: p.Avian.APIBase,
Proxy: p.Avian.Proxy,
RequestTimeout: p.Avian.RequestTimeout,
}, true
},
},
}
// Process each provider migration
+6 -5
View File
@@ -159,16 +159,17 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
VolcEngine: ProviderConfig{APIKey: "key15"},
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key18"},
Mistral: ProviderConfig{APIKey: "key19"},
Qwen: ProviderConfig{APIKey: "key17"},
Mistral: ProviderConfig{APIKey: "key18"},
Avian: ProviderConfig{APIKey: "key19"},
},
}
result := ConvertProvidersToModelList(cfg)
// All 19 providers should be converted
if len(result) != 19 {
t.Errorf("len(result) = %d, want 19", len(result))
// All 21 providers should be converted
if len(result) != 21 {
t.Errorf("len(result) = %d, want 21", len(result))
}
}
+13 -3
View File
@@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
mgr := NewManager()
mcpCfg := config.MCPConfig{
Enabled: true,
ToolConfig: config.ToolConfig{
Enabled: true,
},
Servers: map[string]config.MCPServerConfig{
"test-server": {
Enabled: true,
@@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) {
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
mgr := NewManager()
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
err := mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
}
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
err = mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when no servers configured, got: %v", err)
}
+79
View File
@@ -31,6 +31,9 @@ type Provider struct {
baseURL string
}
// SupportsThinking implements providers.ThinkingCapable.
func (p *Provider) SupportsThinking() bool { return true }
func NewProvider(token string) *Provider {
return NewProviderWithBaseURL(token, "")
}
@@ -182,9 +185,80 @@ func buildParams(
params.Tools = translateTools(tools)
}
// Extended Thinking / Adaptive Thinking
// The thinking_level value directly determines the API parameter format:
// "adaptive" → {thinking: {type: "adaptive"}} + output_config.effort
// "low/medium/high/xhigh" → {thinking: {type: "enabled", budget_tokens: N}}
if level, ok := options["thinking_level"].(string); ok && level != "" && level != "off" {
applyThinkingConfig(&params, level)
}
return params, nil
}
// applyThinkingConfig sets thinking parameters based on the level value.
// "adaptive" uses the adaptive thinking API (Claude 4.6+).
// All other levels use budget_tokens which is universally supported.
//
// Anthropic API constraint: temperature must not be set when thinking is enabled.
// budget_tokens must be strictly less than max_tokens.
func applyThinkingConfig(params *anthropic.MessageNewParams, level string) {
// Anthropic API rejects requests with temperature set alongside thinking.
// Reset to zero value (omitted from JSON serialization).
if params.Temperature.Valid() {
log.Printf("anthropic: temperature cleared because thinking is enabled (level=%s)", level)
}
params.Temperature = anthropic.MessageNewParams{}.Temperature
if level == "adaptive" {
adaptive := anthropic.NewThinkingConfigAdaptiveParam()
params.Thinking = anthropic.ThinkingConfigParamUnion{OfAdaptive: &adaptive}
params.OutputConfig = anthropic.OutputConfigParam{
Effort: anthropic.OutputConfigEffortHigh,
}
return
}
budget := int64(levelToBudget(level))
if budget <= 0 {
return
}
// budget_tokens must be < max_tokens; clamp to respect user's max_tokens setting.
if budget >= params.MaxTokens {
log.Printf("anthropic: budget_tokens (%d) clamped to %d (max_tokens-1)", budget, params.MaxTokens-1)
budget = params.MaxTokens - 1
} else if budget > params.MaxTokens*80/100 {
log.Printf("anthropic: thinking budget (%d) exceeds 80%% of max_tokens (%d), output may be truncated",
budget, params.MaxTokens)
}
params.Thinking = anthropic.ThinkingConfigParamOfEnabled(budget)
}
// levelToBudget maps a thinking level to budget_tokens.
// Values are based on Anthropic's recommendations and community best practices:
//
// low = 4,096 — simple reasoning, quick debugging (Claude Code "think")
// medium = 16,384 — Anthropic recommended sweet spot for most tasks
// high = 32,000 — complex architecture, deep analysis (diminishing returns above this)
// xhigh = 64,000 — extreme reasoning, research problems, benchmarks
//
// Note: For Claude 4.6+, prefer adaptive thinking over manual budget_tokens.
func levelToBudget(level string) int {
switch level {
case "low":
return 4096
case "medium":
return 16384
case "high":
return 32000
case "xhigh":
return 64000
default:
return 0
}
}
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
@@ -213,10 +287,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
func parseResponse(resp *anthropic.Message) *LLMResponse {
var content strings.Builder
var reasoning strings.Builder
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "thinking":
tb := block.AsThinking()
reasoning.WriteString(tb.Thinking)
case "text":
tb := block.AsText()
content.WriteString(tb.Text)
@@ -247,6 +325,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
return &LLMResponse{
Content: content.String(),
Reasoning: reasoning.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
+212
View File
@@ -0,0 +1,212 @@
package anthropicprovider
import (
"encoding/json"
"testing"
"github.com/anthropics/anthropic-sdk-go"
)
func TestApplyThinkingConfig_Adaptive(t *testing.T) {
params := anthropic.MessageNewParams{
MaxTokens: 16000,
Temperature: anthropic.Float(0.7),
}
applyThinkingConfig(&params, "adaptive")
if params.Thinking.OfAdaptive == nil {
t.Fatal("expected adaptive thinking")
}
if params.Thinking.OfEnabled != nil {
t.Error("should not set enabled thinking in adaptive mode")
}
if params.OutputConfig.Effort != anthropic.OutputConfigEffortHigh {
t.Errorf("effort = %q, want %q", params.OutputConfig.Effort, anthropic.OutputConfigEffortHigh)
}
if params.Temperature.Valid() {
t.Error("temperature should be cleared when thinking is enabled")
}
}
func TestApplyThinkingConfig_BudgetLevels(t *testing.T) {
tests := []struct {
level string
wantBudget int64
}{
{"low", 4096},
{"medium", 16384},
{"high", 32000},
{"xhigh", 64000},
}
for _, tt := range tests {
t.Run(tt.level, func(t *testing.T) {
params := anthropic.MessageNewParams{
MaxTokens: 200000,
Temperature: anthropic.Float(0.5),
}
applyThinkingConfig(&params, tt.level)
if params.Thinking.OfEnabled == nil {
t.Fatal("expected enabled thinking")
}
if params.Thinking.OfAdaptive != nil {
t.Error("should not set adaptive thinking")
}
if params.Thinking.OfEnabled.BudgetTokens != tt.wantBudget {
t.Errorf("budget_tokens = %d, want %d", params.Thinking.OfEnabled.BudgetTokens, tt.wantBudget)
}
if params.OutputConfig.Effort != "" {
t.Errorf("effort = %q, want empty", params.OutputConfig.Effort)
}
if params.Temperature.Valid() {
t.Error("temperature should be cleared when thinking is enabled")
}
})
}
}
func TestApplyThinkingConfig_BudgetClamp(t *testing.T) {
// budget_tokens must be < max_tokens; clamp budget down to respect user's max_tokens.
params := anthropic.MessageNewParams{MaxTokens: 4096}
applyThinkingConfig(&params, "high") // budget=32000 > maxTokens=4096
if params.Thinking.OfEnabled == nil {
t.Fatal("expected enabled thinking")
}
if params.Thinking.OfEnabled.BudgetTokens != 4095 {
t.Errorf("budget_tokens = %d, want 4095 (maxTokens-1)", params.Thinking.OfEnabled.BudgetTokens)
}
if params.MaxTokens != 4096 {
t.Errorf("max_tokens should not be modified, got %d", params.MaxTokens)
}
}
func TestApplyThinkingConfig_UnknownLevel(t *testing.T) {
params := anthropic.MessageNewParams{MaxTokens: 16000}
applyThinkingConfig(&params, "unknown")
if params.Thinking.OfEnabled != nil {
t.Error("should not set enabled thinking for unknown level")
}
if params.Thinking.OfAdaptive != nil {
t.Error("should not set adaptive thinking for unknown level")
}
}
func TestLevelToBudget(t *testing.T) {
tests := []struct {
name string
level string
want int
}{
{"low", "low", 4096},
{"medium", "medium", 16384},
{"high", "high", 32000},
{"xhigh", "xhigh", 64000},
{"off", "off", 0},
{"empty", "", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := levelToBudget(tt.level); got != tt.want {
t.Errorf("levelToBudget(%q) = %d, want %d", tt.level, got, tt.want)
}
})
}
}
func TestBuildParams_ThinkingClearsTemperature(t *testing.T) {
msgs := []Message{{Role: "user", Content: "hello"}}
opts := map[string]any{
"max_tokens": 200000,
"temperature": 0.8,
"thinking_level": "medium",
}
params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts)
if err != nil {
t.Fatal(err)
}
if params.Temperature.Valid() {
t.Error("temperature should be cleared when thinking_level is set")
}
if params.Thinking.OfEnabled == nil {
t.Fatal("expected enabled thinking")
}
if params.Thinking.OfEnabled.BudgetTokens != 16384 {
t.Errorf("budget_tokens = %d, want 16384", params.Thinking.OfEnabled.BudgetTokens)
}
}
// unmarshalBlocks constructs []ContentBlockUnion via JSON round-trip so that
// the internal JSON.raw field is populated (required by AsText/AsThinking).
func unmarshalBlocks(t *testing.T, jsonStr string) []anthropic.ContentBlockUnion {
t.Helper()
var blocks []anthropic.ContentBlockUnion
if err := json.Unmarshal([]byte(jsonStr), &blocks); err != nil {
t.Fatalf("unmarshalBlocks: %v", err)
}
return blocks
}
func TestParseResponse_ThinkingBlock(t *testing.T) {
resp := &anthropic.Message{
Content: unmarshalBlocks(t, `[
{"type":"thinking","thinking":"Let me reason step by step...","signature":"sig"},
{"type":"text","text":"The answer is 42."}
]`),
StopReason: anthropic.StopReasonEndTurn,
}
result := parseResponse(resp)
if result.Reasoning != "Let me reason step by step..." {
t.Errorf("Reasoning = %q, want thinking content", result.Reasoning)
}
if result.Content != "The answer is 42." {
t.Errorf("Content = %q, want text content", result.Content)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", result.FinishReason)
}
}
func TestParseResponse_NoThinkingBlock(t *testing.T) {
resp := &anthropic.Message{
Content: unmarshalBlocks(t, `[
{"type":"text","text":"Just a normal response."}
]`),
StopReason: anthropic.StopReasonEndTurn,
}
result := parseResponse(resp)
if result.Reasoning != "" {
t.Errorf("Reasoning = %q, want empty", result.Reasoning)
}
if result.Content != "Just a normal response." {
t.Errorf("Content = %q, want text content", result.Content)
}
}
func TestBuildParams_NoThinkingKeepsTemperature(t *testing.T) {
msgs := []Message{{Role: "user", Content: "hello"}}
opts := map[string]any{
"temperature": 0.8,
}
params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts)
if err != nil {
t.Fatal(err)
}
if !params.Temperature.Valid() {
t.Error("temperature should be preserved when thinking is not set")
}
if params.Temperature.Value != 0.8 {
t.Errorf("temperature = %f, want 0.8", params.Temperature.Value)
}
}
+8 -2
View File
@@ -640,7 +640,10 @@ func FetchAntigravityProjectID(accessToken string) (string, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading loadCodeAssist response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("loadCodeAssist failed: %s", string(body))
}
@@ -681,7 +684,10 @@ func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelIn
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading fetchAvailableModels response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"fetchAvailableModels failed (HTTP %d): %s",
+16
View File
@@ -190,6 +190,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.model = "deepseek-chat"
}
}
case "avian":
if cfg.Providers.Avian.APIKey != "" {
sel.apiKey = cfg.Providers.Avian.APIKey
sel.apiBase = cfg.Providers.Avian.APIBase
sel.proxy = cfg.Providers.Avian.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.avian.io/v1"
}
}
case "mistral":
if cfg.Providers.Mistral.APIKey != "" {
sel.apiKey = cfg.Providers.Mistral.APIKey
@@ -316,6 +325,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if sel.apiBase == "" {
sel.apiBase = "https://api.mistral.ai/v1"
}
case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "":
sel.apiKey = cfg.Providers.Avian.APIKey
sel.apiBase = cfg.Providers.Avian.APIBase
sel.proxy = cfg.Providers.Avian.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.avian.io/v1"
}
case cfg.Providers.VLLM.APIBase != "":
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
+3 -1
View File
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral":
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -210,6 +210,8 @@ func getDefaultAPIBase(protocol string) string {
return "http://localhost:8000/v1"
case "mistral":
return "https://api.mistral.ai/v1"
case "avian":
return "https://api.avian.io/v1"
default:
return ""
}
+8 -6
View File
@@ -323,12 +323,14 @@ func serializeMessages(messages []Message) []any {
})
}
for _, mediaURL := range m.Media {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
if strings.HasPrefix(mediaURL, "data:image/") {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
}
msg := map[string]any{
+7
View File
@@ -37,6 +37,13 @@ type StatefulProvider interface {
Close()
}
// ThinkingCapable is an optional interface for providers that support
// extended thinking (e.g. Anthropic). Used by the agent loop to warn
// when thinking_level is configured but the active provider cannot use it.
type ThinkingCapable interface {
SupportsThinking() bool
}
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string
+80
View File
@@ -0,0 +1,80 @@
package routing
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
// A higher score indicates a more complex task that benefits from a heavy model.
// The score is compared against the configured threshold: score >= threshold selects
// the primary (heavy) model; score < threshold selects the light model.
//
// Classifier is an interface so that future implementations (ML-based, embedding-based,
// or any other approach) can be swapped in without changing routing infrastructure.
type Classifier interface {
Score(f Features) float64
}
// RuleClassifier is the v1 implementation.
// It uses a weighted sum of structural signals with no external dependencies,
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
// that the returned score always falls within the [0, 1] contract.
//
// Individual weights (multiple signals can fire simultaneously):
//
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
// token 50-200: 0.15 — medium length; may or may not be complex
// code block present: 0.40 — coding tasks need the heavy model
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
// tool calls 1-3 (recent): 0.10 — some tool activity
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
//
// Default threshold is 0.35, so:
// - Pure greetings / trivial Q&A: 0.00 → light ✓
// - Medium prose message (50200 tokens): 0.15 → light ✓
// - Message with code block: 0.40 → heavy ✓
// - Long message (>200 tokens): 0.35 → heavy ✓
// - Active tool session + medium message: 0.25 → light (acceptable)
// - Any message with an image/audio attachment: 1.00 → heavy ✓
type RuleClassifier struct{}
// Score computes the complexity score for the given feature set.
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
func (c *RuleClassifier) Score(f Features) float64 {
// Hard gate: multi-modal inputs always require the heavy model.
if f.HasAttachments {
return 1.0
}
var score float64
// Token estimate — primary verbosity signal
switch {
case f.TokenEstimate > 200:
score += 0.35
case f.TokenEstimate > 50:
score += 0.15
}
// Fenced code blocks — strongest indicator of a coding/technical task
if f.CodeBlockCount > 0 {
score += 0.40
}
// Recent tool call density — indicates an ongoing agentic workflow
switch {
case f.RecentToolCalls > 3:
score += 0.25
case f.RecentToolCalls > 0:
score += 0.10
}
// Conversation depth — accumulated context implies compound task
if f.ConversationDepth > 10 {
score += 0.10
}
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
if score > 1.0 {
score = 1.0
}
return score
}
+127
View File
@@ -0,0 +1,127 @@
package routing
import (
"strings"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// lookbackWindow is the number of recent history entries scanned for tool calls.
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
const lookbackWindow = 6
// Features holds the structural signals extracted from a message and its session context.
// Every dimension is language-agnostic by construction — no keyword or pattern matching
// against natural-language content. This ensures consistent routing for all locales.
type Features struct {
// TokenEstimate is a proxy for token count.
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
// This avoids API calls while giving accurate estimates for all scripts.
TokenEstimate int
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
// Coding tasks almost always require the heavy model.
CodeBlockCount int
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
// history entries. A high density indicates an active agentic workflow.
RecentToolCalls int
// ConversationDepth is the total number of messages in the session history.
// Deep sessions tend to carry implicit complexity built up over many turns.
ConversationDepth int
// HasAttachments is true when the message appears to contain media (images,
// audio, video). Multi-modal inputs require vision-capable heavy models.
HasAttachments bool
}
// ExtractFeatures computes the structural feature vector for a message.
// It is a pure function with no side effects and zero allocations beyond
// the returned struct.
func ExtractFeatures(msg string, history []providers.Message) Features {
return Features{
TokenEstimate: estimateTokens(msg),
CodeBlockCount: countCodeBlocks(msg),
RecentToolCalls: countRecentToolCalls(history),
ConversationDepth: len(history),
HasAttachments: hasAttachments(msg),
}
}
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
// CJK runes (U+2E80U+9FFF, U+F900U+FAFF, U+AC00U+D7AF) map to roughly one
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
// for English). Splitting the count this way avoids the 3x underestimation that a
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
func estimateTokens(msg string) int {
total := utf8.RuneCountInString(msg)
if total == 0 {
return 0
}
cjk := 0
for _, r := range msg {
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
cjk++
}
}
return cjk + (total-cjk)/4
}
// countCodeBlocks counts the number of complete fenced code blocks.
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
// An unclosed opening fence (odd count) is treated as zero complete blocks
// since it may just be an inline code span or a typo.
func countCodeBlocks(msg string) int {
n := strings.Count(msg, "```")
return n / 2
}
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
// entries of history. It examines the ToolCalls field rather than parsing
// the content string, so it is robust to any message format.
func countRecentToolCalls(history []providers.Message) int {
start := len(history) - lookbackWindow
if start < 0 {
start = 0
}
count := 0
for _, msg := range history[start:] {
if len(msg.ToolCalls) > 0 {
count += len(msg.ToolCalls)
}
}
return count
}
// hasAttachments returns true when the message content contains embedded media.
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
// common image/audio URL extensions. This is intentionally conservative —
// false negatives (missing an attachment) just mean the routing falls back to
// the primary model anyway.
func hasAttachments(msg string) bool {
lower := strings.ToLower(msg)
// Base64 data URIs embedded directly in the message
if strings.Contains(lower, "data:image/") ||
strings.Contains(lower, "data:audio/") ||
strings.Contains(lower, "data:video/") {
return true
}
// Common image/audio extensions in URLs or file references
mediaExts := []string{
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
".mp3", ".wav", ".ogg", ".m4a", ".flac",
".mp4", ".avi", ".mov", ".webm",
}
for _, ext := range mediaExts {
if strings.Contains(lower, ext) {
return true
}
}
return false
}
+82
View File
@@ -0,0 +1,82 @@
package routing
import (
"github.com/sipeed/picoclaw/pkg/providers"
)
// defaultThreshold is used when the config threshold is zero or negative.
// At 0.35 a message needs at least one strong signal (code block, long text,
// or an attachment) before the heavy model is chosen.
const defaultThreshold = 0.35
// RouterConfig holds the validated model routing settings.
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
type RouterConfig struct {
// LightModel is the model_name (from model_list) used for simple tasks.
LightModel string
// Threshold is the complexity score cutoff in [0, 1].
// score >= Threshold → primary (heavy) model.
// score < Threshold → light model.
Threshold float64
}
// Router selects the appropriate model tier for each incoming message.
// It is safe for concurrent use from multiple goroutines.
type Router struct {
cfg RouterConfig
classifier Classifier
}
// New creates a Router with the given config and the default RuleClassifier.
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
func New(cfg RouterConfig) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{
cfg: cfg,
classifier: &RuleClassifier{},
}
}
// newWithClassifier creates a Router with a custom Classifier.
// Intended for unit tests that need to inject a deterministic scorer.
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{cfg: cfg, classifier: c}
}
// SelectModel returns the model to use for this conversation turn along with
// the computed complexity score (for logging and debugging).
//
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
// - Otherwise: returns (primaryModel, false, score)
//
// The caller is responsible for resolving the returned model name into
// provider candidates (see AgentInstance.LightCandidates).
func (r *Router) SelectModel(
msg string,
history []providers.Message,
primaryModel string,
) (model string, usedLight bool, score float64) {
features := ExtractFeatures(msg, history)
score = r.classifier.Score(features)
if score < r.cfg.Threshold {
return r.cfg.LightModel, true, score
}
return primaryModel, false, score
}
// LightModel returns the configured light model name.
func (r *Router) LightModel() string {
return r.cfg.LightModel
}
// Threshold returns the complexity threshold in use.
func (r *Router) Threshold() float64 {
return r.cfg.Threshold
}
+414
View File
@@ -0,0 +1,414 @@
package routing
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ── ExtractFeatures ──────────────────────────────────────────────────────────
func TestExtractFeatures_EmptyMessage(t *testing.T) {
f := ExtractFeatures("", nil)
if f.TokenEstimate != 0 {
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
}
if f.CodeBlockCount != 0 {
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
}
if f.RecentToolCalls != 0 {
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
}
if f.ConversationDepth != 0 {
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
}
if f.HasAttachments {
t.Error("HasAttachments: got true, want false")
}
}
func TestExtractFeatures_TokenEstimate(t *testing.T) {
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
msg := strings.Repeat("a", 30)
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 7 {
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
// Using a rune slice literal avoids CJK string literals in source.
msg := string([]rune{
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60,
})
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 9 {
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 6 {
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
}
}
func TestExtractFeatures_CodeBlocks(t *testing.T) {
cases := []struct {
msg string
want int
}{
{"no code here", 0},
{"```go\nfmt.Println()\n```", 1},
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.CodeBlockCount != tc.want {
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
}
}
}
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
// History longer than lookbackWindow — only last lookbackWindow entries count.
history := make([]providers.Message, 10)
// Put 2 tool calls at positions 8 and 9 (within the last 6)
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
history[9] = providers.Message{
Role: "assistant",
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
}
// Position 3 is outside the lookback window and must NOT be counted
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
f := ExtractFeatures("test", history)
// 1 (position 8) + 2 (position 9) = 3
if f.RecentToolCalls != 3 {
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
}
}
func TestExtractFeatures_ConversationDepth(t *testing.T) {
history := make([]providers.Message, 7)
f := ExtractFeatures("msg", history)
if f.ConversationDepth != 7 {
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
}
}
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"plain text", false},
{"here is an image: data:image/png;base64,abc123", true},
{"audio: data:audio/mp3;base64,xyz", true},
{"video: data:video/mp4;base64,xyz", true},
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"check out photo.jpg", true},
{"see screenshot.png", true},
{"listen to audio.mp3", true},
{"watch clip.mp4", true},
{"just a .go file", false},
{"document.pdf", false}, // pdf is not in the media list
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
// ── RuleClassifier ───────────────────────────────────────────────────────────
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{})
if score != 0.0 {
t.Errorf("zero features: got %f, want 0.0", score)
}
}
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{HasAttachments: true})
if score != 1.0 {
t.Errorf("attachments: got %f, want 1.0", score)
}
}
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
c := &RuleClassifier{}
// Code block alone = 0.40, above default threshold 0.35
score := c.Score(Features{CodeBlockCount: 1})
if score < 0.35 {
t.Errorf("code block: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_LongMessage(t *testing.T) {
c := &RuleClassifier{}
// >200 tokens = 0.35, exactly at default threshold → heavy
score := c.Score(Features{TokenEstimate: 250})
if score < 0.35 {
t.Errorf("long message: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_MediumMessage(t *testing.T) {
c := &RuleClassifier{}
// 50-200 tokens = 0.15, below threshold → light
score := c.Score(Features{TokenEstimate: 100})
if score >= 0.35 {
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
}
}
func TestRuleClassifier_ShortMessage(t *testing.T) {
c := &RuleClassifier{}
// <50 tokens, no other signals = 0.0 → light
score := c.Score(Features{TokenEstimate: 10})
if score != 0.0 {
t.Errorf("short message: got %f, want 0.0", score)
}
}
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
c := &RuleClassifier{}
scoreNone := c.Score(Features{RecentToolCalls: 0})
scoreLow := c.Score(Features{RecentToolCalls: 2})
scoreHigh := c.Score(Features{RecentToolCalls: 5})
if scoreNone != 0.0 {
t.Errorf("no tools: got %f, want 0.0", scoreNone)
}
if scoreLow <= scoreNone {
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
}
if scoreHigh <= scoreLow {
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
}
}
func TestRuleClassifier_DeepConversation(t *testing.T) {
c := &RuleClassifier{}
shallow := c.Score(Features{ConversationDepth: 5})
deep := c.Score(Features{ConversationDepth: 15})
if deep <= shallow {
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
}
}
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
c := &RuleClassifier{}
// Max all signals simultaneously
f := Features{
TokenEstimate: 500,
CodeBlockCount: 3,
RecentToolCalls: 10,
ConversationDepth: 20,
}
score := c.Score(f)
if score > 1.0 {
t.Errorf("score %f exceeds 1.0", score)
}
}
// ── Router ───────────────────────────────────────────────────────────────────
func TestRouter_DefaultThreshold(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash"})
if r.Threshold() != defaultThreshold {
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
if r.Threshold() != defaultThreshold {
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "hi"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("simple message: expected light model to be selected")
}
if model != "gemini-flash" {
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
}
}
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "```go\nfmt.Println(\"hello\")\n```"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("code block: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "can you analyze this? data:image/png;base64,abc123"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("attachment: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
// >200 token estimate: 210 * 3 = 630 chars
msg := strings.Repeat("word ", 210)
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("long message: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
// Routing is conservative: only promote to heavy when the signal is unambiguous.
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
}
msg := "ok"
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if !usedLight {
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
}
}
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
}},
}
// ~55 tokens * 3 = 165 chars
msg := strings.Repeat("word ", 55)
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if usedLight {
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
}
}
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
// Very low threshold: even a short message triggers heavy model
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("low threshold: medium message should use primary model")
}
}
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
// Very high threshold: even code blocks route to light
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
msg := "```go\nfmt.Println()\n```"
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("very high threshold: code block (0.40) should route to light model")
}
}
func TestRouter_LightModel(t *testing.T) {
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
if r.LightModel() != "my-fast-model" {
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
}
}
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
type fixedScoreClassifier struct{ score float64 }
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.2},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if !usedLight {
t.Error("low score with custom classifier: expected light model")
}
}
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.8},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("high score with custom classifier: expected primary model")
}
}
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
// score == threshold → primary (uses >= comparison)
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.5},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("score == threshold: expected primary model (>= threshold → primary)")
}
}
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.42},
)
_, _, score := r.SelectModel("anything", nil, "heavy")
if score != 0.42 {
t.Errorf("score: got %f, want 0.42", score)
}
}
+64 -16
View File
@@ -259,15 +259,7 @@ func (c *ClawHubRegistry) DownloadAndInstall(
}
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String())
if err != nil {
return nil, fmt.Errorf("download failed: %w", err)
}
@@ -284,17 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall(
// --- HTTP helper ---
func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
req, err := c.newGetRequest(ctx, urlStr, "application/json")
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
resp, err := c.client.Do(req)
resp, err := utils.DoRequestWithRetry(c.client, req)
if err != nil {
return nil, err
}
@@ -312,3 +299,64 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err
return body, nil
}
func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", accept)
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
return req, nil
}
func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) {
req, err := c.newGetRequest(ctx, urlStr, "application/zip")
if err != nil {
return "", err
}
resp, err := utils.DoRequestWithRetry(c.client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
errBody := make([]byte, 512)
n, _ := io.ReadFull(resp.Body, errBody)
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
}
tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
cleanup := func() {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1)
written, err := io.Copy(tmpFile, src)
if err != nil {
cleanup()
return "", fmt.Errorf("download write failed: %w", err)
}
if written > int64(c.maxZipSize) {
cleanup()
return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize)
}
if err := tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Errorf("failed to close temp file: %w", err)
}
return tmpPath, nil
}
+81
View File
@@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) {
assert.Equal(t, "clawhub", results[0].RegistryName)
}
func TestClawHubRegistrySearchRetries429(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
slug := "github"
name := "GitHub Integration"
summary := "Interact with GitHub repos"
version := "1.0.0"
json.NewEncoder(w).Encode(clawhubSearchResponse{
Results: []clawhubSearchResult{
{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
},
})
}))
defer srv.Close()
reg := newTestRegistry(srv.URL, "")
results, err := reg.Search(context.Background(), "github", 5)
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, 2, attempts)
assert.Equal(t, "github", results[0].Slug)
}
func TestClawHubRegistryGetSkillMeta(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
@@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
assert.Contains(t, string(readmeContent), "# Test Skill")
}
func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) {
zipBuf := createTestZip(t, map[string]string{
"SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill",
})
downloadAttempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/skills/retry-skill":
json.NewEncoder(w).Encode(clawhubSkillResponse{
Slug: "retry-skill",
DisplayName: "Retry Skill",
Summary: "A retry test skill",
LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
})
case "/api/v1/download":
downloadAttempts++
if downloadAttempts == 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
assert.Equal(t, "retry-skill", r.URL.Query().Get("slug"))
w.Header().Set("Content-Type", "application/zip")
w.Write(zipBuf)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
tmpDir := t.TempDir()
targetDir := filepath.Join(tmpDir, "retry-skill")
reg := newTestRegistry(srv.URL, "")
result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, "1.0.0", result.Version)
assert.Equal(t, 2, downloadAttempts)
skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
require.NoError(t, err)
assert.Contains(t, string(skillContent), "Hello skill")
}
func TestClawHubRegistryAuthToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
+50 -38
View File
@@ -10,11 +10,38 @@ type Tool interface {
Execute(ctx context.Context, args map[string]any) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
// to receive the current message context (channel, chatID)
type ContextualTool interface {
Tool
SetContext(channel, chatID string)
// --- Request-scoped tool context (channel / chatID) ---
//
// Carried via context.Value so that concurrent tool calls each receive
// their own immutable copy — no mutable state on singleton tool instances.
//
// Keys are unexported pointer-typed vars — guaranteed collision-free,
// and only accessible through the helper functions below.
type toolCtxKey struct{ name string }
var (
ctxKeyChannel = &toolCtxKey{"channel"}
ctxKeyChatID = &toolCtxKey{"chatID"}
)
// WithToolContext returns a child context carrying channel and chatID.
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
ctx = context.WithValue(ctx, ctxKeyChannel, channel)
ctx = context.WithValue(ctx, ctxKeyChatID, chatID)
return ctx
}
// ToolChannel extracts the channel from ctx, or "" if unset.
func ToolChannel(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChannel).(string)
return v
}
// ToolChatID extracts the chatID from ctx, or "" if unset.
func ToolChatID(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChatID).(string)
return v
}
// AsyncCallback is a function type that async tools use to notify completion.
@@ -22,51 +49,36 @@ type ContextualTool interface {
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// AsyncExecutor is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
// Unlike the old AsyncTool pattern (SetCallback + Execute), AsyncExecutor
// receives the callback as a parameter of ExecuteAsync. This eliminates the
// data race where concurrent calls could overwrite each other's callbacks
// on a shared tool instance.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
// go func() {
// result := t.runSubagent(ctx, args)
// if cb != nil { cb(ctx, result) }
// }()
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
type AsyncExecutor interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
// ExecuteAsync runs the tool asynchronously. The callback cb will be
// invoked (possibly from another goroutine) when the async operation
// completes. cb is guaranteed to be non-nil by the caller (registry).
ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult
}
func ToolToSchema(tool Tool) map[string]any {
+4 -18
View File
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -24,9 +23,6 @@ type CronTool struct {
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
channel string
chatID string
mu sync.RWMutex
}
// NewCronTool creates a new CronTool
@@ -102,14 +98,6 @@ func (t *CronTool) Parameters() map[string]any {
}
}
// SetContext sets the current session context for job creation
func (t *CronTool) SetContext(channel, chatID string) {
t.mu.Lock()
defer t.mu.Unlock()
t.channel = channel
t.chatID = chatID
}
// Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
action, ok := args["action"].(string)
@@ -119,7 +107,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
switch action {
case "add":
return t.addJob(args)
return t.addJob(ctx, args)
case "list":
return t.listJobs()
case "remove":
@@ -133,11 +121,9 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
}
}
func (t *CronTool) addJob(args map[string]any) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
t.mu.RUnlock()
func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult {
channel := ToolChannel(ctx)
chatID := ToolChatID(ctx)
if channel == "" || chatID == "" {
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
+8 -10
View File
@@ -9,10 +9,8 @@ import (
type SendCallback func(channel, chatID, content string) error
type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
sendCallback SendCallback
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -48,10 +46,10 @@ func (t *MessageTool) Parameters() map[string]any {
}
}
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound.Store(false) // Reset send tracking for new processing round
// ResetSentInRound resets the per-round send tracker.
// Called by the agent loop at the start of each inbound message processing round.
func (t *MessageTool) ResetSentInRound() {
t.sentInRound.Store(false)
}
// HasSentInRound returns true if the message tool sent a message during the current round.
@@ -73,10 +71,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
chatID, _ := args["chat_id"].(string)
if channel == "" {
channel = t.defaultChannel
channel = ToolChannel(ctx)
}
if chatID == "" {
chatID = t.defaultChatID
chatID = ToolChatID(ctx)
}
if channel == "" || chatID == "" {
+6 -11
View File
@@ -8,7 +8,6 @@ import (
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
@@ -18,7 +17,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
return nil
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Hello, world!",
}
@@ -60,7 +59,6 @@ func TestMessageTool_Execute_Success(t *testing.T) {
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
@@ -69,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
return nil
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id")
args := map[string]any{
"content": "Test message",
"channel": "custom-channel",
@@ -96,14 +94,13 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
@@ -133,9 +130,8 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{} // content missing
result := tool.Execute(ctx, args)
@@ -151,7 +147,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
@@ -175,10 +171,9 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
+15 -13
View File
@@ -45,8 +45,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string
}
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
// If the tool implements AsyncExecutor and a non-nil callback is provided,
// ExecuteAsync is called instead of Execute — the callback is a parameter,
// never stored as mutable state on the tool.
func (r *ToolRegistry) ExecuteWithContext(
ctx context.Context,
name string,
@@ -69,22 +70,23 @@ func (r *ToolRegistry) ExecuteWithContext(
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
}
// If tool implements ContextualTool, set context
if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" {
contextualTool.SetContext(channel, chatID)
}
// Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx).
// Always inject — tools validate what they require.
ctx = WithToolContext(ctx, channel, chatID)
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
// If tool implements AsyncExecutor and callback is provided, use ExecuteAsync.
// The callback is a call parameter, not mutable state on the tool instance.
var result *ToolResult
start := time.Now()
if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
map[string]any{
"tool": name,
})
result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
} else {
result = tool.Execute(ctx, args)
}
start := time.Now()
result := tool.Execute(ctx, args)
duration := time.Since(start)
// Log based on result type
+32 -22
View File
@@ -25,24 +25,24 @@ func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolRes
return m.result
}
type mockCtxTool struct {
type mockContextAwareTool struct {
mockRegistryTool
channel string
chatID string
lastCtx context.Context
}
func (m *mockCtxTool) SetContext(channel, chatID string) {
m.channel = channel
m.chatID = chatID
func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *ToolResult {
m.lastCtx = ctx
return m.result
}
type mockAsyncRegistryTool struct {
mockRegistryTool
cb AsyncCallback
lastCB AsyncCallback
}
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
m.cb = cb
func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
m.lastCB = cb
return m.result
}
// --- helpers ---
@@ -136,34 +136,44 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) {
}
}
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
func TestToolRegistry_ExecuteWithContext_InjectsToolContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
ct := &mockContextAwareTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
if ct.channel != "telegram" {
t.Errorf("expected channel 'telegram', got %q", ct.channel)
if ct.lastCtx == nil {
t.Fatal("expected Execute to be called")
}
if ct.chatID != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
if got := ToolChannel(ct.lastCtx); got != "telegram" {
t.Errorf("expected channel 'telegram', got %q", got)
}
if got := ToolChatID(ct.lastCtx); got != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", got)
}
}
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
ct := &mockContextAwareTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
if ct.channel != "" || ct.chatID != "" {
t.Error("SetContext should not be called with empty channel/chatID")
if ct.lastCtx == nil {
t.Fatal("expected Execute to be called")
}
// Empty values are still injected; tools decide what to do with them.
if got := ToolChannel(ct.lastCtx); got != "" {
t.Errorf("expected empty channel, got %q", got)
}
if got := ToolChatID(ct.lastCtx); got != "" {
t.Errorf("expected empty chatID, got %q", got)
}
}
@@ -179,14 +189,14 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
cb := func(_ context.Context, _ *ToolResult) { called = true }
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
if at.cb == nil {
t.Error("expected SetCallback to have been called")
if at.lastCB == nil {
t.Error("expected ExecuteAsync to have received a callback")
}
if !result.Async {
t.Error("expected async result")
}
at.cb(context.Background(), SilentResult("done"))
at.lastCB(context.Background(), SilentResult("done"))
if !called {
t.Error("expected callback to be invoked")
}
+6 -1
View File
@@ -131,9 +131,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
timeout := 60 * time.Second
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
}
return &ExecTool{
workingDir: workingDir,
timeout: 60 * time.Second,
timeout: timeout,
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
+27 -17
View File
@@ -8,25 +8,18 @@ import (
type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
allowlistCheck func(targetAgentID string) bool
callback AsyncCallback // For async completion notification
}
// Compile-time check: SpawnTool implements AsyncExecutor.
var _ AsyncExecutor = (*SpawnTool)(nil)
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
return &SpawnTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
manager: manager,
}
}
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string {
return "spawn"
}
@@ -56,16 +49,21 @@ func (t *SpawnTool) Parameters() map[string]any {
}
}
func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
t.allowlistCheck = check
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
return t.execute(ctx, args, nil)
}
// ExecuteAsync implements AsyncExecutor. The callback is passed through to the
// subagent manager as a call parameter — never stored on the SpawnTool instance.
func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
return t.execute(ctx, args, cb)
}
func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
task, ok := args["task"].(string)
if !ok || strings.TrimSpace(task) == "" {
return ErrorResult("task is required and must be a non-empty string")
@@ -85,8 +83,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul
return ErrorResult("Subagent manager not configured")
}
// Read channel/chatID from context (injected by registry).
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSpawnTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
}
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
+14 -12
View File
@@ -252,16 +252,12 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
manager *SubagentManager
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
manager: manager,
}
}
@@ -290,11 +286,6 @@ func (t *SubagentTool) Parameters() map[string]any {
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
task, ok := args["task"].(string)
if !ok {
@@ -341,13 +332,24 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
}
}
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSubagentTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
}
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: llmOptions,
}, messages, t.originChannel, t.originChatID)
}, messages, channel, chatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
+3 -21
View File
@@ -50,9 +50,8 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager.SetLLMOptions(2048, 0.6)
tool := NewSubagentTool(manager)
tool.SetContext("cli", "direct")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "cli", "direct")
args := map[string]any{"task": "Do something"}
result := tool.Execute(ctx, args)
@@ -147,28 +146,14 @@ func TestSubagentTool_Parameters(t *testing.T) {
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
args := map[string]any{
"task": "Write a haiku about coding",
"label": "haiku-task",
@@ -297,12 +282,9 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
ctx := WithToolContext(context.Background(), channel, chatID)
args := map[string]any{
"task": "Test context passing",
}
+71 -1
View File
@@ -395,6 +395,68 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
type SearXNGSearchProvider struct {
baseURL string
}
func (p *SearXNGSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("%s/search?q=%s&format=json&categories=general",
strings.TrimSuffix(p.baseURL, "/"),
url.QueryEscape(query))
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("SearXNG returned status %d", resp.StatusCode)
}
var result struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
Engine string `json:"engine"`
Score float64 `json:"score"`
} `json:"results"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(result.Results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
// Limit results to requested count
if len(result.Results) > count {
result.Results = result.Results[:count]
}
// Format results in standard PicoClaw format
var b strings.Builder
b.WriteString(fmt.Sprintf("Results for: %s (via SearXNG)\n", query))
for i, r := range result.Results {
b.WriteString(fmt.Sprintf("%d. %s\n", i+1, r.Title))
b.WriteString(fmt.Sprintf(" %s\n", r.URL))
if r.Content != "" {
b.WriteString(fmt.Sprintf(" %s\n", r.Content))
}
}
return b.String(), nil
}
type GLMSearchProvider struct {
apiKey string
baseURL string
@@ -495,6 +557,9 @@ type WebSearchToolOptions struct {
PerplexityAPIKey string
PerplexityMaxResults int
PerplexityEnabled bool
SearXNGBaseURL string
SearXNGMaxResults int
SearXNGEnabled bool
GLMSearchAPIKey string
GLMSearchBaseURL string
GLMSearchEngine string
@@ -507,7 +572,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
@@ -526,6 +591,11 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.SearXNGEnabled && opts.SearXNGBaseURL != "" {
provider = &SearXNGSearchProvider{baseURL: opts.SearXNGBaseURL}
if opts.SearXNGMaxResults > 0 {
maxResults = opts.SearXNGMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {