From a4546ffb8f208f0b4b9f69262e20b3754c884279 Mon Sep 17 00:00:00 2001 From: Kyle D Date: Fri, 27 Feb 2026 03:02:07 +0000 Subject: [PATCH 01/11] feat: add Avian as a named LLM provider Add Avian (https://avian.io) as an OpenAI-compatible provider with API base https://api.avian.io/v1 and AVIAN_API_KEY env var support. Models: deepseek/deepseek-v3.2, moonshotai/kimi-k2.5, z-ai/glm-5, minimax/minimax-m2.5. Supports chat completions, streaming, and function calling. Changes: - Add Avian to ProvidersConfig struct, IsEmpty(), HasProvidersConfig() - Add avian protocol to factory provider and default API base - Add avian case to legacy provider selection (factory.go) - Add avian migration rule for old config format - Add default model entries to ModelList (deepseek-v3.2, kimi-k2.5) - Add avian to example config - Update AllProviders test count from 18 to 19 --- config/config.example.json | 4 ++++ pkg/config/config.go | 4 +++- pkg/config/defaults.go | 14 ++++++++++++++ pkg/config/migration.go | 17 +++++++++++++++++ pkg/config/migration_test.go | 1 + pkg/providers/factory.go | 16 ++++++++++++++++ pkg/providers/factory_provider.go | 4 +++- 7 files changed, 58 insertions(+), 2 deletions(-) diff --git a/config/config.example.json b/config/config.example.json index adae6f05c..1db97d0bb 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -222,6 +222,10 @@ "mistral": { "api_key": "", "api_base": "https://api.mistral.ai/v1" + }, + "avian": { + "api_key": "", + "api_base": "https://api.avian.io/v1" } }, "tools": { diff --git a/pkg/config/config.go b/pkg/config/config.go index cb2799bba..e50a5c3e8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -429,6 +429,7 @@ type ProvidersConfig struct { Antigravity ProviderConfig `json:"antigravity"` Qwen ProviderConfig `json:"qwen"` Mistral ProviderConfig `json:"mistral"` + Avian ProviderConfig `json:"avian"` } // IsEmpty checks if all provider configs are empty (no API keys or API bases set) @@ -452,7 
+453,8 @@ func (p ProvidersConfig) IsEmpty() bool { p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" && p.Qwen.APIKey == "" && p.Qwen.APIBase == "" && - p.Mistral.APIKey == "" && p.Mistral.APIBase == "" + p.Mistral.APIKey == "" && p.Mistral.APIBase == "" && + p.Avian.APIKey == "" && p.Avian.APIBase == "" } // MarshalJSON implements custom JSON marshaling for ProvidersConfig diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 9fc09c5f1..5472dd94a 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -306,6 +306,20 @@ func DefaultConfig() *Config { APIKey: "", }, + // Avian - https://avian.io + { + ModelName: "deepseek-v3.2", + Model: "avian/deepseek/deepseek-v3.2", + APIBase: "https://api.avian.io/v1", + APIKey: "", + }, + { + ModelName: "kimi-k2.5", + Model: "avian/moonshotai/kimi-k2.5", + APIBase: "https://api.avian.io/v1", + APIKey: "", + }, + // VLLM (local) - http://localhost:8000 { ModelName: "local-model", diff --git a/pkg/config/migration.go b/pkg/config/migration.go index 772f714fd..4a17dd6c9 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -373,6 +373,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, true }, }, + { + providerNames: []string{"avian"}, + protocol: "avian", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Avian.APIKey == "" && p.Avian.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "avian", + Model: "avian/deepseek/deepseek-v3.2", + APIKey: p.Avian.APIKey, + APIBase: p.Avian.APIBase, + Proxy: p.Avian.Proxy, + RequestTimeout: p.Avian.RequestTimeout, + }, true + }, + }, } // Process each provider migration diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index e24e9fa1d..dc86beb41 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -160,6 +160,7 @@ func 
TestConvertProvidersToModelList_AllProviders(t *testing.T) { Antigravity: ProviderConfig{AuthMethod: "oauth"}, Qwen: ProviderConfig{APIKey: "key17"}, Mistral: ProviderConfig{APIKey: "key18"}, + Avian: ProviderConfig{APIKey: "key19"}, }, } diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index 5b3e42b9e..a0d09a835 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -181,6 +181,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.model = "deepseek-chat" } } + case "avian": + if cfg.Providers.Avian.APIKey != "" { + sel.apiKey = cfg.Providers.Avian.APIKey + sel.apiBase = cfg.Providers.Avian.APIBase + sel.proxy = cfg.Providers.Avian.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.avian.io/v1" + } + } case "mistral": if cfg.Providers.Mistral.APIKey != "" { sel.apiKey = cfg.Providers.Mistral.APIKey @@ -300,6 +309,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if sel.apiBase == "" { sel.apiBase = "https://api.mistral.ai/v1" } + case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "": + sel.apiKey = cfg.Providers.Avian.APIKey + sel.apiBase = cfg.Providers.Avian.APIBase + sel.proxy = cfg.Providers.Avian.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.avian.io/v1" + } case cfg.Providers.VLLM.APIBase != "": sel.apiKey = cfg.Providers.VLLM.APIKey sel.apiBase = cfg.Providers.VLLM.APIBase diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 155317a3b..c05fb0ad4 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", - "volcengine", "vllm", "qwen", "mistral": + "volcengine", "vllm", "qwen", "mistral", "avian": // All other OpenAI-compatible 
HTTP providers if cfg.APIKey == "" && cfg.APIBase == "" { return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol) @@ -208,6 +208,8 @@ func getDefaultAPIBase(protocol string) string { return "http://localhost:8000/v1" case "mistral": return "https://api.mistral.ai/v1" + case "avian": + return "https://api.avian.io/v1" default: return "" } From b9ee9b33f50dd460d0bb56d212d15c0c9fe04b03 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Wed, 4 Mar 2026 19:34:08 +0100 Subject: [PATCH 02/11] prevent audio as image url --- pkg/providers/openai_compat/provider.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index ff9109e96..1904ee153 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -323,12 +323,14 @@ func serializeMessages(messages []Message) []any { }) } for _, mediaURL := range m.Media { - parts = append(parts, map[string]any{ - "type": "image_url", - "image_url": map[string]any{ - "url": mediaURL, - }, - }) + if strings.HasPrefix(mediaURL, "data:image/") { + parts = append(parts, map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": mediaURL, + }, + }) + } } msg := map[string]any{ From 0c97cb30d84dfeb6b273c64bf6fe15abeea7a57a Mon Sep 17 00:00:00 2001 From: Kyle D Date: Wed, 4 Mar 2026 20:14:43 +0000 Subject: [PATCH 03/11] fix: update provider count in migration test to include Avian The TestConvertProvidersToModelList_AllProviders test expected 19 providers but adding Avian brings the total to 20. 
Co-Authored-By: Claude Opus 4.6 --- pkg/config/migration_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index dc86beb41..67ad73db9 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -166,9 +166,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) { result := ConvertProvidersToModelList(cfg) - // All 19 providers should be converted - if len(result) != 19 { - t.Errorf("len(result) = %d, want 19", len(result)) + // All 20 providers should be converted + if len(result) != 20 { + t.Errorf("len(result) = %d, want 20", len(result)) } } From 204038ec6022bfb362ef83821bce3192cbca4638 Mon Sep 17 00:00:00 2001 From: Larry Koo Date: Thu, 5 Mar 2026 09:51:18 +0800 Subject: [PATCH 04/11] feat: add extended thinking support for Anthropic models (#1076) * feat: add extended thinking support for Anthropic models Support configurable thinking levels (off/low/medium/high/xhigh/adaptive) via `agents.defaults.thinking_level` config field. 
- "adaptive": uses Anthropic's adaptive thinking API (Claude 4.6+) - "low/medium/high/xhigh": uses budget_tokens (all thinking-capable models) - "off": disables thinking (default) API constraints handled: - Temperature cleared when thinking is enabled - budget_tokens clamped to max_tokens-1 - Thinking response blocks parsed into Reasoning field Relates to #645, #966 * fix: address PR review feedback for thinking support - Add ThinkingCapable interface for provider capability detection - Warn when thinking_level is set but provider doesn't support it - Warn when temperature is cleared due to thinking enabled - Adjust budget values per Anthropic best practices (medium=16K, xhigh=64K) - Add budget clamp warning and 80% threshold warning - Add parseResponse thinking block tests - Add thinking_level field to config.example.json * refactor: move ThinkingLevel from AgentDefaults to ModelConfig Thinking is a model-level capability, not a global agent property. Per-model config avoids silent ignoring on non-Anthropic providers and eliminates spurious warning logs in multi-provider setups. Addresses PR #1076 review feedback from @yinwm. 
--- config/config.example.json | 3 +- pkg/agent/instance.go | 8 + pkg/agent/loop.go | 34 ++-- pkg/agent/thinking.go | 39 +++++ pkg/agent/thinking_test.go | 35 ++++ pkg/config/config.go | 1 + pkg/providers/anthropic/provider.go | 79 +++++++++ pkg/providers/anthropic/thinking_test.go | 212 +++++++++++++++++++++++ pkg/providers/types.go | 7 + 9 files changed, 401 insertions(+), 17 deletions(-) create mode 100644 pkg/agent/thinking.go create mode 100644 pkg/agent/thinking_test.go create mode 100644 pkg/providers/anthropic/thinking_test.go diff --git a/config/config.example.json b/config/config.example.json index f6e7de12a..c59a39885 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -22,7 +22,8 @@ "model_name": "claude-sonnet-4.6", "model": "anthropic/claude-sonnet-4.6", "api_key": "sk-ant-your-key", - "api_base": "https://api.anthropic.com/v1" + "api_base": "https://api.anthropic.com/v1", + "thinking_level": "high" }, { "model_name": "gemini", diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index ed25f537f..1e18b6f64 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -26,6 +26,7 @@ type AgentInstance struct { MaxIterations int MaxTokens int Temperature float64 + ThinkingLevel ThinkingLevel ContextWindow int SummarizeMessageThreshold int SummarizeTokenPercent int @@ -103,6 +104,12 @@ func NewAgentInstance( temperature = *defaults.Temperature } + var thinkingLevelStr string + if mc, err := cfg.GetModelConfig(model); err == nil { + thinkingLevelStr = mc.ThinkingLevel + } + thinkingLevel := parseThinkingLevel(thinkingLevelStr) + summarizeMessageThreshold := defaults.SummarizeMessageThreshold if summarizeMessageThreshold == 0 { summarizeMessageThreshold = 20 @@ -169,6 +176,7 @@ func NewAgentInstance( MaxIterations: maxIter, MaxTokens: maxTokens, Temperature: temperature, + ThinkingLevel: thinkingLevel, ContextWindow: maxTokens, SummarizeMessageThreshold: summarizeMessageThreshold, SummarizeTokenPercent: 
summarizeTokenPercent, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 7ce2a37a6..509f61099 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -834,23 +834,29 @@ func (al *AgentLoop) runLLMIteration( var response *providers.LLMResponse var err error + llmOpts := map[string]any{ + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + "prompt_cache_key": agent.ID, + } + // parseThinkingLevel guarantees ThinkingOff for empty/unknown values, + // so checking != ThinkingOff is sufficient. + if agent.ThinkingLevel != ThinkingOff { + if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() { + llmOpts["thinking_level"] = string(agent.ThinkingLevel) + } else { + logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring", + map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)}) + } + } + callLLM := func() (*providers.LLMResponse, error) { if len(agent.Candidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( ctx, agent.Candidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return agent.Provider.Chat( - ctx, - messages, - providerToolDefs, - model, - map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "prompt_cache_key": agent.ID, - }, - ) + return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts) }, ) if fbErr != nil { @@ -866,11 +872,7 @@ func (al *AgentLoop) runLLMIteration( } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "prompt_cache_key": agent.ID, - }) + return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts) } // Retry loop for context/token errors diff --git a/pkg/agent/thinking.go b/pkg/agent/thinking.go new file mode 100644 index 000000000..015b69282 
--- /dev/null +++ b/pkg/agent/thinking.go @@ -0,0 +1,39 @@ +package agent + +import "strings" + +// ThinkingLevel controls how the provider sends thinking parameters. +// +// - "adaptive": sends {thinking: {type: "adaptive"}} + output_config.effort (Claude 4.6+) +// - "low"/"medium"/"high"/"xhigh": sends {thinking: {type: "enabled", budget_tokens: N}} (all models) +// - "off": disables thinking +type ThinkingLevel string + +const ( + ThinkingOff ThinkingLevel = "off" + ThinkingLow ThinkingLevel = "low" + ThinkingMedium ThinkingLevel = "medium" + ThinkingHigh ThinkingLevel = "high" + ThinkingXHigh ThinkingLevel = "xhigh" + ThinkingAdaptive ThinkingLevel = "adaptive" +) + +// parseThinkingLevel normalizes a config string to a ThinkingLevel. +// Case-insensitive and whitespace-tolerant for user-facing config values. +// Returns ThinkingOff for unknown or empty values. +func parseThinkingLevel(level string) ThinkingLevel { + switch strings.ToLower(strings.TrimSpace(level)) { + case "adaptive": + return ThinkingAdaptive + case "low": + return ThinkingLow + case "medium": + return ThinkingMedium + case "high": + return ThinkingHigh + case "xhigh": + return ThinkingXHigh + default: + return ThinkingOff + } +} diff --git a/pkg/agent/thinking_test.go b/pkg/agent/thinking_test.go new file mode 100644 index 000000000..be3a68c33 --- /dev/null +++ b/pkg/agent/thinking_test.go @@ -0,0 +1,35 @@ +package agent + +import "testing" + +func TestParseThinkingLevel(t *testing.T) { + tests := []struct { + name string + input string + want ThinkingLevel + }{ + {"off", "off", ThinkingOff}, + {"empty", "", ThinkingOff}, + {"low", "low", ThinkingLow}, + {"medium", "medium", ThinkingMedium}, + {"high", "high", ThinkingHigh}, + {"xhigh", "xhigh", ThinkingXHigh}, + {"adaptive", "adaptive", ThinkingAdaptive}, + {"unknown", "unknown", ThinkingOff}, + // Case-insensitive and whitespace-tolerant + {"upper_Medium", "Medium", ThinkingMedium}, + {"upper_HIGH", "HIGH", ThinkingHigh}, + 
{"mixed_Adaptive", "Adaptive", ThinkingAdaptive}, + {"leading_space", " high", ThinkingHigh}, + {"trailing_space", "low ", ThinkingLow}, + {"both_spaces", " medium ", ThinkingMedium}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseThinkingLevel(tt.input); got != tt.want { + t.Errorf("parseThinkingLevel(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 7165df0d0..3cfebf5e8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -507,6 +507,7 @@ type ModelConfig struct { RPM int `json:"rpm,omitempty"` // Requests per minute limit MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") RequestTimeout int `json:"request_timeout,omitempty"` + ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive } // Validate checks if the ModelConfig has all required fields. diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index 1bb15f771..1b250b9b4 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -31,6 +31,9 @@ type Provider struct { baseURL string } +// SupportsThinking implements providers.ThinkingCapable. 
+func (p *Provider) SupportsThinking() bool { return true } + func NewProvider(token string) *Provider { return NewProviderWithBaseURL(token, "") } @@ -182,9 +185,80 @@ func buildParams( params.Tools = translateTools(tools) } + // Extended Thinking / Adaptive Thinking + // The thinking_level value directly determines the API parameter format: + // "adaptive" → {thinking: {type: "adaptive"}} + output_config.effort + // "low/medium/high/xhigh" → {thinking: {type: "enabled", budget_tokens: N}} + if level, ok := options["thinking_level"].(string); ok && level != "" && level != "off" { + applyThinkingConfig(&params, level) + } + return params, nil } +// applyThinkingConfig sets thinking parameters based on the level value. +// "adaptive" uses the adaptive thinking API (Claude 4.6+). +// All other levels use budget_tokens which is universally supported. +// +// Anthropic API constraint: temperature must not be set when thinking is enabled. +// budget_tokens must be strictly less than max_tokens. +func applyThinkingConfig(params *anthropic.MessageNewParams, level string) { + // Anthropic API rejects requests with temperature set alongside thinking. + // Reset to zero value (omitted from JSON serialization). + if params.Temperature.Valid() { + log.Printf("anthropic: temperature cleared because thinking is enabled (level=%s)", level) + } + params.Temperature = anthropic.MessageNewParams{}.Temperature + + if level == "adaptive" { + adaptive := anthropic.NewThinkingConfigAdaptiveParam() + params.Thinking = anthropic.ThinkingConfigParamUnion{OfAdaptive: &adaptive} + params.OutputConfig = anthropic.OutputConfigParam{ + Effort: anthropic.OutputConfigEffortHigh, + } + return + } + + budget := int64(levelToBudget(level)) + if budget <= 0 { + return + } + + // budget_tokens must be < max_tokens; clamp to respect user's max_tokens setting.
+ if budget >= params.MaxTokens { + log.Printf("anthropic: budget_tokens (%d) clamped to %d (max_tokens-1)", budget, params.MaxTokens-1) + budget = params.MaxTokens - 1 + } else if budget > params.MaxTokens*80/100 { + log.Printf("anthropic: thinking budget (%d) exceeds 80%% of max_tokens (%d), output may be truncated", + budget, params.MaxTokens) + } + params.Thinking = anthropic.ThinkingConfigParamOfEnabled(budget) +} + +// levelToBudget maps a thinking level to budget_tokens. +// Values are based on Anthropic's recommendations and community best practices: +// +// low = 4,096 — simple reasoning, quick debugging (Claude Code "think") +// medium = 16,384 — Anthropic recommended sweet spot for most tasks +// high = 32,000 — complex architecture, deep analysis (diminishing returns above this) +// xhigh = 64,000 — extreme reasoning, research problems, benchmarks +// +// Note: For Claude 4.6+, prefer adaptive thinking over manual budget_tokens. +func levelToBudget(level string) int { + switch level { + case "low": + return 4096 + case "medium": + return 16384 + case "high": + return 32000 + case "xhigh": + return 64000 + default: + return 0 + } +} + func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { result := make([]anthropic.ToolUnionParam, 0, len(tools)) for _, t := range tools { @@ -213,10 +287,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { func parseResponse(resp *anthropic.Message) *LLMResponse { var content strings.Builder + var reasoning strings.Builder var toolCalls []ToolCall for _, block := range resp.Content { switch block.Type { + case "thinking": + tb := block.AsThinking() + reasoning.WriteString(tb.Thinking) case "text": tb := block.AsText() content.WriteString(tb.Text) @@ -247,6 +325,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { return &LLMResponse{ Content: content.String(), + Reasoning: reasoning.String(), ToolCalls: toolCalls, FinishReason: finishReason, Usage: &UsageInfo{ diff 
--git a/pkg/providers/anthropic/thinking_test.go b/pkg/providers/anthropic/thinking_test.go new file mode 100644 index 000000000..e69a3869e --- /dev/null +++ b/pkg/providers/anthropic/thinking_test.go @@ -0,0 +1,212 @@ +package anthropicprovider + +import ( + "encoding/json" + "testing" + + "github.com/anthropics/anthropic-sdk-go" +) + +func TestApplyThinkingConfig_Adaptive(t *testing.T) { + params := anthropic.MessageNewParams{ + MaxTokens: 16000, + Temperature: anthropic.Float(0.7), + } + applyThinkingConfig(&params, "adaptive") + + if params.Thinking.OfAdaptive == nil { + t.Fatal("expected adaptive thinking") + } + if params.Thinking.OfEnabled != nil { + t.Error("should not set enabled thinking in adaptive mode") + } + if params.OutputConfig.Effort != anthropic.OutputConfigEffortHigh { + t.Errorf("effort = %q, want %q", params.OutputConfig.Effort, anthropic.OutputConfigEffortHigh) + } + if params.Temperature.Valid() { + t.Error("temperature should be cleared when thinking is enabled") + } +} + +func TestApplyThinkingConfig_BudgetLevels(t *testing.T) { + tests := []struct { + level string + wantBudget int64 + }{ + {"low", 4096}, + {"medium", 16384}, + {"high", 32000}, + {"xhigh", 64000}, + } + + for _, tt := range tests { + t.Run(tt.level, func(t *testing.T) { + params := anthropic.MessageNewParams{ + MaxTokens: 200000, + Temperature: anthropic.Float(0.5), + } + applyThinkingConfig(&params, tt.level) + + if params.Thinking.OfEnabled == nil { + t.Fatal("expected enabled thinking") + } + if params.Thinking.OfAdaptive != nil { + t.Error("should not set adaptive thinking") + } + if params.Thinking.OfEnabled.BudgetTokens != tt.wantBudget { + t.Errorf("budget_tokens = %d, want %d", params.Thinking.OfEnabled.BudgetTokens, tt.wantBudget) + } + if params.OutputConfig.Effort != "" { + t.Errorf("effort = %q, want empty", params.OutputConfig.Effort) + } + if params.Temperature.Valid() { + t.Error("temperature should be cleared when thinking is enabled") + } + }) + } +} + +func
TestApplyThinkingConfig_BudgetClamp(t *testing.T) { + // budget_tokens must be < max_tokens; clamp budget down to respect user's max_tokens. + params := anthropic.MessageNewParams{MaxTokens: 4096} + applyThinkingConfig(&params, "high") // budget=32000 > maxTokens=4096 + + if params.Thinking.OfEnabled == nil { + t.Fatal("expected enabled thinking") + } + if params.Thinking.OfEnabled.BudgetTokens != 4095 { + t.Errorf("budget_tokens = %d, want 4095 (maxTokens-1)", params.Thinking.OfEnabled.BudgetTokens) + } + if params.MaxTokens != 4096 { + t.Errorf("max_tokens should not be modified, got %d", params.MaxTokens) + } +} + +func TestApplyThinkingConfig_UnknownLevel(t *testing.T) { + params := anthropic.MessageNewParams{MaxTokens: 16000} + applyThinkingConfig(&params, "unknown") + + if params.Thinking.OfEnabled != nil { + t.Error("should not set enabled thinking for unknown level") + } + if params.Thinking.OfAdaptive != nil { + t.Error("should not set adaptive thinking for unknown level") + } +} + +func TestLevelToBudget(t *testing.T) { + tests := []struct { + name string + level string + want int + }{ + {"low", "low", 4096}, + {"medium", "medium", 16384}, + {"high", "high", 32000}, + {"xhigh", "xhigh", 64000}, + {"off", "off", 0}, + {"empty", "", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := levelToBudget(tt.level); got != tt.want { + t.Errorf("levelToBudget(%q) = %d, want %d", tt.level, got, tt.want) + } + }) + } +} + +func TestBuildParams_ThinkingClearsTemperature(t *testing.T) { + msgs := []Message{{Role: "user", Content: "hello"}} + opts := map[string]any{ + "max_tokens": 200000, + "temperature": 0.8, + "thinking_level": "medium", + } + + params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts) + if err != nil { + t.Fatal(err) + } + + if params.Temperature.Valid() { + t.Error("temperature should be cleared when thinking_level is set") + } + if params.Thinking.OfEnabled == nil { + t.Fatal("expected enabled thinking") + } + if
params.Thinking.OfEnabled.BudgetTokens != 16384 { + t.Errorf("budget_tokens = %d, want 16384", params.Thinking.OfEnabled.BudgetTokens) + } +} + +// unmarshalBlocks constructs []ContentBlockUnion via JSON round-trip so that +// the internal JSON.raw field is populated (required by AsText/AsThinking). +func unmarshalBlocks(t *testing.T, jsonStr string) []anthropic.ContentBlockUnion { + t.Helper() + var blocks []anthropic.ContentBlockUnion + if err := json.Unmarshal([]byte(jsonStr), &blocks); err != nil { + t.Fatalf("unmarshalBlocks: %v", err) + } + return blocks +} + +func TestParseResponse_ThinkingBlock(t *testing.T) { + resp := &anthropic.Message{ + Content: unmarshalBlocks(t, `[ + {"type":"thinking","thinking":"Let me reason step by step...","signature":"sig"}, + {"type":"text","text":"The answer is 42."} + ]`), + StopReason: anthropic.StopReasonEndTurn, + } + + result := parseResponse(resp) + + if result.Reasoning != "Let me reason step by step..." { + t.Errorf("Reasoning = %q, want thinking content", result.Reasoning) + } + if result.Content != "The answer is 42." { + t.Errorf("Content = %q, want text content", result.Content) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", result.FinishReason) + } +} + +func TestParseResponse_NoThinkingBlock(t *testing.T) { + resp := &anthropic.Message{ + Content: unmarshalBlocks(t, `[ + {"type":"text","text":"Just a normal response."} + ]`), + StopReason: anthropic.StopReasonEndTurn, + } + + result := parseResponse(resp) + + if result.Reasoning != "" { + t.Errorf("Reasoning = %q, want empty", result.Reasoning) + } + if result.Content != "Just a normal response." 
{ + t.Errorf("Content = %q, want text content", result.Content) + } +} + +func TestBuildParams_NoThinkingKeepsTemperature(t *testing.T) { + msgs := []Message{{Role: "user", Content: "hello"}} + opts := map[string]any{ + "temperature": 0.8, + } + + params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts) + if err != nil { + t.Fatal(err) + } + + if !params.Temperature.Valid() { + t.Error("temperature should be preserved when thinking is not set") + } + if params.Temperature.Value != 0.8 { + t.Errorf("temperature = %f, want 0.8", params.Temperature.Value) + } +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index f0c168bc6..68bbd1e65 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -37,6 +37,13 @@ type StatefulProvider interface { Close() } +// ThinkingCapable is an optional interface for providers that support +// extended thinking (e.g. Anthropic). Used by the agent loop to warn +// when thinking_level is configured but the active provider cannot use it. +type ThinkingCapable interface { + SupportsThinking() bool +} + // FailoverReason classifies why an LLM request failed for fallback decisions. 
type FailoverReason string From aef1e8e8c489f427558d2004b1aecae68808d77a Mon Sep 17 00:00:00 2001 From: Boris Bliznioukov Date: Thu, 5 Mar 2026 02:57:33 +0100 Subject: [PATCH 05/11] fix: eliminate data races on shared tool instances (#1080) * fix: eliminate data races on shared tool instances Signed-off-by: Boris Bliznioukov * fix: remove unused indirect dependency on github.com/gdamore/tcell/v2 Signed-off-by: Boris Bliznioukov * fix: reviewer comments improve context handling for tool execution and ensure defaults for non-conversation callers Signed-off-by: Boris Bliznioukov --------- Signed-off-by: Boris Bliznioukov --- go.mod | 1 - pkg/agent/loop.go | 45 +++++------------ pkg/agent/loop_test.go | 66 +++++-------------------- pkg/tools/base.go | 88 +++++++++++++++++++-------------- pkg/tools/cron.go | 22 ++------- pkg/tools/message.go | 18 +++---- pkg/tools/message_test.go | 17 +++---- pkg/tools/registry.go | 28 ++++++----- pkg/tools/registry_test.go | 54 +++++++++++--------- pkg/tools/spawn.go | 44 ++++++++++------- pkg/tools/subagent.go | 26 +++++----- pkg/tools/subagent_tool_test.go | 24 ++------- 12 files changed, 181 insertions(+), 252 deletions(-) diff --git a/go.mod b/go.mod index c1172937c..238bd405c 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect - github.com/gdamore/tcell/v2 v2.13.8 // indirect github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 509f61099..263eeb4dd 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -543,8 +543,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) // Reset message-tool state for this round so we don't skip publishing due to a previous round. 
if tool, ok := agent.Tools.Get("message"); ok { - if mt, ok := tool.(tools.ContextualTool); ok { - mt.SetContext(msg.Channel, msg.ChatID) + if resetter, ok := tool.(interface{ ResetSentInRound() }); ok { + resetter.ResetSentInRound() } } @@ -659,10 +659,7 @@ func (al *AgentLoop) runAgentLoop( } } - // 1. Update tool contexts - al.updateToolContexts(agent, opts.Channel, opts.ChatID) - - // 2. Build messages (skip history for heartbeat) + // 1. Build messages (skip history for heartbeat) var history []providers.Message var summary string if !opts.NoHistory { @@ -682,10 +679,10 @@ func (al *AgentLoop) runAgentLoop( maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize() messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) - // 3. Save user message to session + // 2. Save user message to session agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - // 4. Run LLM iteration loop + // 3. Run LLM iteration loop finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err @@ -694,21 +691,21 @@ func (al *AgentLoop) runAgentLoop( // If last tool had ForUser content and we already sent it, we might not need to send final response // This is controlled by the tool's Silent flag and ForUser content - // 5. Handle empty response + // 4. Handle empty response if finalContent == "" { finalContent = opts.DefaultResponse } - // 6. Save final assistant message to session + // 5. Save final assistant message to session agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) agent.Sessions.Save(opts.SessionKey) - // 7. Optional: summarization + // 6. Optional: summarization if opts.EnableSummary { al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) } - // 8. Optional: send response via bus + // 7. 
Optional: send response via bus if opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, @@ -717,7 +714,7 @@ func (al *AgentLoop) runAgentLoop( }) } - // 9. Log response + // 8. Log response responsePreview := utils.Truncate(finalContent, 120) logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), map[string]any{ @@ -1059,7 +1056,7 @@ func (al *AgentLoop) runLLMIteration( "iteration": iteration, }) - // Create async callback for tools that implement AsyncTool + // Create async callback for tools that implement AsyncExecutor asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { if !result.Silent && result.ForUser != "" { logger.InfoCF("agent", "Async tool completed, agent will handle notification", @@ -1141,26 +1138,6 @@ func (al *AgentLoop) runLLMIteration( return finalContent, iteration, nil } -// updateToolContexts updates the context for tools that need channel/chatID info. -func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) { - // Use ContextualTool interface instead of type assertions - if tool, ok := agent.Tools.Get("message"); ok { - if mt, ok := tool.(tools.ContextualTool); ok { - mt.SetContext(channel, chatID) - } - } - if tool, ok := agent.Tools.Get("spawn"); ok { - if st, ok := tool.(tools.ContextualTool); ok { - st.SetContext(channel, chatID) - } - } - if tool, ok := agent.Tools.Get("subagent"); ok { - if st, ok := tool.(tools.ContextualTool); ok { - st.SetContext(channel, chatID) - } - } -} - // maybeSummarize triggers summarization if the session history exceeds thresholds. 
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 023286f02..4ab6b4542 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -164,35 +164,21 @@ func TestToolRegistry_ToolRegistration(t *testing.T) { } } -// TestToolContext_Updates verifies tool context is updated with channel/chatID +// TestToolContext_Updates verifies tool context helpers work correctly func TestToolContext_Updates(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agent-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) + ctx := tools.WithToolContext(context.Background(), "telegram", "chat-42") - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, + if got := tools.ToolChannel(ctx); got != "telegram" { + t.Errorf("expected channel 'telegram', got %q", got) + } + if got := tools.ToolChatID(ctx); got != "chat-42" { + t.Errorf("expected chatID 'chat-42', got %q", got) } - msgBus := bus.NewMessageBus() - provider := &simpleMockProvider{response: "OK"} - _ = NewAgentLoop(cfg, msgBus, provider) - - // Verify that ContextualTool interface is defined and can be implemented - // This test validates the interface contract exists - ctxTool := &mockContextualTool{} - - // Verify the tool implements the interface correctly - var _ tools.ContextualTool = ctxTool + // Empty context returns empty strings + if got := tools.ToolChannel(context.Background()); got != "" { + t.Errorf("expected empty channel from bare context, got %q", got) + } } // TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved @@ -359,36 +345,6 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool return 
tools.SilentResult("Custom tool executed") } -// mockContextualTool tracks context updates -type mockContextualTool struct { - lastChannel string - lastChatID string -} - -func (m *mockContextualTool) Name() string { - return "mock_contextual" -} - -func (m *mockContextualTool) Description() string { - return "Mock contextual tool" -} - -func (m *mockContextualTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{}, - } -} - -func (m *mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { - return tools.SilentResult("Contextual tool executed") -} - -func (m *mockContextualTool) SetContext(channel, chatID string) { - m.lastChannel = channel - m.lastChatID = chatID -} - // testHelper executes a message and returns the response type testHelper struct { al *AgentLoop diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 770d8cb04..ec743e164 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -10,11 +10,38 @@ type Tool interface { Execute(ctx context.Context, args map[string]any) *ToolResult } -// ContextualTool is an optional interface that tools can implement -// to receive the current message context (channel, chatID) -type ContextualTool interface { - Tool - SetContext(channel, chatID string) +// --- Request-scoped tool context (channel / chatID) --- +// +// Carried via context.Value so that concurrent tool calls each receive +// their own immutable copy — no mutable state on singleton tool instances. +// +// Keys are unexported pointer-typed vars — guaranteed collision-free, +// and only accessible through the helper functions below. + +type toolCtxKey struct{ name string } + +var ( + ctxKeyChannel = &toolCtxKey{"channel"} + ctxKeyChatID = &toolCtxKey{"chatID"} +) + +// WithToolContext returns a child context carrying channel and chatID. 
+func WithToolContext(ctx context.Context, channel, chatID string) context.Context { + ctx = context.WithValue(ctx, ctxKeyChannel, channel) + ctx = context.WithValue(ctx, ctxKeyChatID, chatID) + return ctx +} + +// ToolChannel extracts the channel from ctx, or "" if unset. +func ToolChannel(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyChannel).(string) + return v +} + +// ToolChatID extracts the chatID from ctx, or "" if unset. +func ToolChatID(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyChatID).(string) + return v } // AsyncCallback is a function type that async tools use to notify completion. @@ -22,51 +49,36 @@ type ContextualTool interface { // // The ctx parameter allows the callback to be canceled if the agent is shutting down. // The result parameter contains the tool's execution result. -// -// Example usage in an async tool: -// -// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { -// // Start async work in background -// go func() { -// result := doAsyncWork() -// if t.callback != nil { -// t.callback(ctx, result) -// } -// }() -// return AsyncResult("Async task started") -// } type AsyncCallback func(ctx context.Context, result *ToolResult) -// AsyncTool is an optional interface that tools can implement to support +// AsyncExecutor is an optional interface that tools can implement to support // asynchronous execution with completion callbacks. // -// Async tools return immediately with an AsyncResult, then notify completion -// via the callback set by SetCallback. +// Unlike the old AsyncTool pattern (SetCallback + Execute), AsyncExecutor +// receives the callback as a parameter of ExecuteAsync. This eliminates the +// data race where concurrent calls could overwrite each other's callbacks +// on a shared tool instance. 
// // This is useful for: -// - Long-running operations that shouldn't block the agent loop -// - Subagent spawns that complete independently -// - Background tasks that need to report results later +// - Long-running operations that shouldn't block the agent loop +// - Subagent spawns that complete independently +// - Background tasks that need to report results later // // Example: // -// type SpawnTool struct { -// callback AsyncCallback -// } -// -// func (t *SpawnTool) SetCallback(cb AsyncCallback) { -// t.callback = cb -// } -// -// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { -// go t.runSubagent(ctx, args) +// func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult { +// go func() { +// result := t.runSubagent(ctx, args) +// if cb != nil { cb(ctx, result) } +// }() // return AsyncResult("Subagent spawned, will report back") // } -type AsyncTool interface { +type AsyncExecutor interface { Tool - // SetCallback registers a callback function to be invoked when the async operation completes. - // The callback will be called from a goroutine and should handle thread-safety if needed. - SetCallback(cb AsyncCallback) + // ExecuteAsync runs the tool asynchronously. The callback cb will be + // invoked (possibly from another goroutine) when the async operation + // completes. cb is guaranteed to be non-nil by the caller (registry). 
+ ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult } func ToolToSchema(tool Tool) map[string]any { diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 6888d1326..31ac9ab88 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strings" - "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" @@ -24,9 +23,6 @@ type CronTool struct { executor JobExecutor msgBus *bus.MessageBus execTool *ExecTool - channel string - chatID string - mu sync.RWMutex } // NewCronTool creates a new CronTool @@ -102,14 +98,6 @@ func (t *CronTool) Parameters() map[string]any { } } -// SetContext sets the current session context for job creation -func (t *CronTool) SetContext(channel, chatID string) { - t.mu.Lock() - defer t.mu.Unlock() - t.channel = channel - t.chatID = chatID -} - // Execute runs the tool with the given arguments func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult { action, ok := args["action"].(string) @@ -119,7 +107,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult switch action { case "add": - return t.addJob(args) + return t.addJob(ctx, args) case "list": return t.listJobs() case "remove": @@ -133,11 +121,9 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult } } -func (t *CronTool) addJob(args map[string]any) *ToolResult { - t.mu.RLock() - channel := t.channel - chatID := t.chatID - t.mu.RUnlock() +func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult { + channel := ToolChannel(ctx) + chatID := ToolChatID(ctx) if channel == "" || chatID == "" { return ErrorResult("no session context (channel/chat_id not set). 
Use this tool in an active conversation.") diff --git a/pkg/tools/message.go b/pkg/tools/message.go index d1e4a373e..438ceeddd 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -9,10 +9,8 @@ import ( type SendCallback func(channel, chatID, content string) error type MessageTool struct { - sendCallback SendCallback - defaultChannel string - defaultChatID string - sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round + sendCallback SendCallback + sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round } func NewMessageTool() *MessageTool { @@ -48,10 +46,10 @@ func (t *MessageTool) Parameters() map[string]any { } } -func (t *MessageTool) SetContext(channel, chatID string) { - t.defaultChannel = channel - t.defaultChatID = chatID - t.sentInRound.Store(false) // Reset send tracking for new processing round +// ResetSentInRound resets the per-round send tracker. +// Called by the agent loop at the start of each inbound message processing round. +func (t *MessageTool) ResetSentInRound() { + t.sentInRound.Store(false) } // HasSentInRound returns true if the message tool sent a message during the current round. 
@@ -73,10 +71,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes chatID, _ := args["chat_id"].(string) if channel == "" { - channel = t.defaultChannel + channel = ToolChannel(ctx) } if chatID == "" { - chatID = t.defaultChatID + chatID = ToolChatID(ctx) } if channel == "" || chatID == "" { diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index 717c1117b..05630972e 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -8,7 +8,6 @@ import ( func TestMessageTool_Execute_Success(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") var sentChannel, sentChatID, sentContent string tool.SetSendCallback(func(channel, chatID, content string) error { @@ -18,7 +17,7 @@ func TestMessageTool_Execute_Success(t *testing.T) { return nil }) - ctx := context.Background() + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{ "content": "Hello, world!", } @@ -60,7 +59,6 @@ func TestMessageTool_Execute_Success(t *testing.T) { func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool := NewMessageTool() - tool.SetContext("default-channel", "default-chat-id") var sentChannel, sentChatID string tool.SetSendCallback(func(channel, chatID, content string) error { @@ -69,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { return nil }) - ctx := context.Background() + ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id") args := map[string]any{ "content": "Test message", "channel": "custom-channel", @@ -96,14 +94,13 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") sendErr := errors.New("network error") tool.SetSendCallback(func(channel, chatID, content string) error { return sendErr }) - ctx := context.Background() + ctx := 
WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{ "content": "Test message", } @@ -133,9 +130,8 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { func TestMessageTool_Execute_MissingContent(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") - ctx := context.Background() + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{} // content missing result := tool.Execute(ctx, args) @@ -151,7 +147,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) { func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() - // No SetContext called, so defaultChannel and defaultChatID are empty + // No WithToolContext — channel/chatID are empty tool.SetSendCallback(func(channel, chatID, content string) error { return nil @@ -175,10 +171,9 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { func TestMessageTool_Execute_NotConfigured(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") // No SetSendCallback called - ctx := context.Background() + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{ "content": "Test message", } diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 0ba983e02..ca8436c67 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -45,8 +45,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string } // ExecuteWithContext executes a tool with channel/chatID context and optional async callback. -// If the tool implements AsyncTool and a non-nil callback is provided, -// the callback will be set on the tool before execution. +// If the tool implements AsyncExecutor and a non-nil callback is provided, +// ExecuteAsync is called instead of Execute — the callback is a parameter, +// never stored as mutable state on the tool. 
func (r *ToolRegistry) ExecuteWithContext( ctx context.Context, name string, @@ -69,22 +70,23 @@ func (r *ToolRegistry) ExecuteWithContext( return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } - // If tool implements ContextualTool, set context - if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" { - contextualTool.SetContext(channel, chatID) - } + // Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx). + // Always inject — tools validate what they require. + ctx = WithToolContext(ctx, channel, chatID) - // If tool implements AsyncTool and callback is provided, set callback - if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil { - asyncTool.SetCallback(asyncCallback) - logger.DebugCF("tool", "Async callback injected", + // If tool implements AsyncExecutor and callback is provided, use ExecuteAsync. + // The callback is a call parameter, not mutable state on the tool instance. 
+ var result *ToolResult + start := time.Now() + if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil { + logger.DebugCF("tool", "Executing async tool via ExecuteAsync", map[string]any{ "tool": name, }) + result = asyncExec.ExecuteAsync(ctx, args, asyncCallback) + } else { + result = tool.Execute(ctx, args) } - - start := time.Now() - result := tool.Execute(ctx, args) duration := time.Since(start) // Log based on result type diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 8fe88ca78..92d7d5abd 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -25,24 +25,24 @@ func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolRes return m.result } -type mockCtxTool struct { +type mockContextAwareTool struct { mockRegistryTool - channel string - chatID string + lastCtx context.Context } -func (m *mockCtxTool) SetContext(channel, chatID string) { - m.channel = channel - m.chatID = chatID +func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *ToolResult { + m.lastCtx = ctx + return m.result } type mockAsyncRegistryTool struct { mockRegistryTool - cb AsyncCallback + lastCB AsyncCallback } -func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) { - m.cb = cb +func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]any, cb AsyncCallback) *ToolResult { + m.lastCB = cb + return m.result } // --- helpers --- @@ -136,34 +136,44 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) { } } -func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) { +func TestToolRegistry_ExecuteWithContext_InjectsToolContext(t *testing.T) { r := NewToolRegistry() - ct := &mockCtxTool{ + ct := &mockContextAwareTool{ mockRegistryTool: *newMockTool("ctx_tool", "needs context"), } r.Register(ct) r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil) - if ct.channel != "telegram" { - t.Errorf("expected channel 'telegram', 
got %q", ct.channel) + if ct.lastCtx == nil { + t.Fatal("expected Execute to be called") } - if ct.chatID != "chat-42" { - t.Errorf("expected chatID 'chat-42', got %q", ct.chatID) + if got := ToolChannel(ct.lastCtx); got != "telegram" { + t.Errorf("expected channel 'telegram', got %q", got) + } + if got := ToolChatID(ct.lastCtx); got != "chat-42" { + t.Errorf("expected chatID 'chat-42', got %q", got) } } -func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) { +func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) { r := NewToolRegistry() - ct := &mockCtxTool{ + ct := &mockContextAwareTool{ mockRegistryTool: *newMockTool("ctx_tool", "needs context"), } r.Register(ct) r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil) - if ct.channel != "" || ct.chatID != "" { - t.Error("SetContext should not be called with empty channel/chatID") + if ct.lastCtx == nil { + t.Fatal("expected Execute to be called") + } + // Empty values are still injected; tools decide what to do with them. 
+ if got := ToolChannel(ct.lastCtx); got != "" { + t.Errorf("expected empty channel, got %q", got) + } + if got := ToolChatID(ct.lastCtx); got != "" { + t.Errorf("expected empty chatID, got %q", got) } } @@ -179,14 +189,14 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) { cb := func(_ context.Context, _ *ToolResult) { called = true } result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb) - if at.cb == nil { - t.Error("expected SetCallback to have been called") + if at.lastCB == nil { + t.Error("expected ExecuteAsync to have received a callback") } if !result.Async { t.Error("expected async result") } - at.cb(context.Background(), SilentResult("done")) + at.lastCB(context.Background(), SilentResult("done")) if !called { t.Error("expected callback to be invoked") } diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 8b166b41f..be40ffda2 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -8,25 +8,18 @@ import ( type SpawnTool struct { manager *SubagentManager - originChannel string - originChatID string allowlistCheck func(targetAgentID string) bool - callback AsyncCallback // For async completion notification } +// Compile-time check: SpawnTool implements AsyncExecutor. 
+var _ AsyncExecutor = (*SpawnTool)(nil) + func NewSpawnTool(manager *SubagentManager) *SpawnTool { return &SpawnTool{ - manager: manager, - originChannel: "cli", - originChatID: "direct", + manager: manager, } } -// SetCallback implements AsyncTool interface for async completion notification -func (t *SpawnTool) SetCallback(cb AsyncCallback) { - t.callback = cb -} - func (t *SpawnTool) Name() string { return "spawn" } @@ -56,16 +49,21 @@ func (t *SpawnTool) Parameters() map[string]any { } } -func (t *SpawnTool) SetContext(channel, chatID string) { - t.originChannel = channel - t.originChatID = chatID -} - func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { t.allowlistCheck = check } func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + return t.execute(ctx, args, nil) +} + +// ExecuteAsync implements AsyncExecutor. The callback is passed through to the +// subagent manager as a call parameter — never stored on the SpawnTool instance. +func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult { + return t.execute(ctx, args, cb) +} + +func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult { task, ok := args["task"].(string) if !ok || strings.TrimSpace(task) == "" { return ErrorResult("task is required and must be a non-empty string") @@ -85,8 +83,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul return ErrorResult("Subagent manager not configured") } + // Read channel/chatID from context (injected by registry). + // Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests) + // to preserve the same defaults as the original NewSpawnTool constructor. 
+ channel := ToolChannel(ctx) + if channel == "" { + channel = "cli" + } + chatID := ToolChatID(ctx) + if chatID == "" { + chatID = "direct" + } + // Pass callback to manager for async completion notification - result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback) + result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb) if err != nil { return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 69f1a49a2..429340047 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -252,16 +252,12 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask { // Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion // and returns the result directly in the ToolResult. type SubagentTool struct { - manager *SubagentManager - originChannel string - originChatID string + manager *SubagentManager } func NewSubagentTool(manager *SubagentManager) *SubagentTool { return &SubagentTool{ - manager: manager, - originChannel: "cli", - originChatID: "direct", + manager: manager, } } @@ -290,11 +286,6 @@ func (t *SubagentTool) Parameters() map[string]any { } } -func (t *SubagentTool) SetContext(channel, chatID string) { - t.originChannel = channel - t.originChatID = chatID -} - func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult { task, ok := args["task"].(string) if !ok { @@ -341,13 +332,24 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe } } + // Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests) + // to preserve the same defaults as the original NewSubagentTool constructor. 
+ channel := ToolChannel(ctx) + if channel == "" { + channel = "cli" + } + chatID := ToolChatID(ctx) + if chatID == "" { + chatID = "direct" + } + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, Tools: tools, MaxIterations: maxIter, LLMOptions: llmOptions, - }, messages, t.originChannel, t.originChatID) + }, messages, channel, chatID) if err != nil { return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) } diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go index 59bfdffae..a1450410a 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent_tool_test.go @@ -50,9 +50,8 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) manager.SetLLMOptions(2048, 0.6) tool := NewSubagentTool(manager) - tool.SetContext("cli", "direct") - ctx := context.Background() + ctx := WithToolContext(context.Background(), "cli", "direct") args := map[string]any{"task": "Do something"} result := tool.Execute(ctx, args) @@ -147,28 +146,14 @@ func TestSubagentTool_Parameters(t *testing.T) { } } -// TestSubagentTool_SetContext verifies context setting -func TestSubagentTool_SetContext(t *testing.T) { - provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) - tool := NewSubagentTool(manager) - - tool.SetContext("test-channel", "test-chat") - - // Verify context is set (we can't directly access private fields, - // but we can verify it doesn't crash) - // The actual context usage is tested in Execute tests -} - // TestSubagentTool_Execute_Success tests successful execution func TestSubagentTool_Execute_Success(t *testing.T) { provider := &MockLLMProvider{} msgBus := bus.NewMessageBus() manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) tool := NewSubagentTool(manager) - tool.SetContext("telegram", "chat-123") - 
ctx := context.Background() + ctx := WithToolContext(context.Background(), "telegram", "chat-123") args := map[string]any{ "task": "Write a haiku about coding", "label": "haiku-task", @@ -297,12 +282,9 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) { manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) tool := NewSubagentTool(manager) - // Set context channel := "test-channel" chatID := "test-chat" - tool.SetContext(channel, chatID) - - ctx := context.Background() + ctx := WithToolContext(context.Background(), channel, chatID) args := map[string]any{ "task": "Test context passing", } From 41bb78f5939686235ac2cb4bd4bb50aaa44f11c6 Mon Sep 17 00:00:00 2001 From: Mauro Date: Thu, 5 Mar 2026 04:13:11 +0100 Subject: [PATCH 06/11] feat(ci) govulncheck (#1086) * feat(ci) govulncheck * feat(ci) disable persist-credentials --- .github/workflows/pr.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index be1c10c52..1e9a7919a 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -24,6 +24,25 @@ jobs: with: version: v2.10.1 + vuln_check: + name: Security Check + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Run Govulncheck + uses: golang/govulncheck-action@v1 + with: + go-package: ./... 
+ test: name: Tests runs-on: ubuntu-latest From 10ad9e83f96c81bc7059abb85a3ac68384cfdead Mon Sep 17 00:00:00 2001 From: lxowalle <83055338+lxowalle@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:15:16 +0800 Subject: [PATCH 07/11] docs: update license (#1131) --- LICENSE | 4 ---- 1 file changed, 4 deletions(-) diff --git a/LICENSE b/LICENSE index 410acae26..b38d9340d 100644 --- a/LICENSE +++ b/LICENSE @@ -19,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ---- - -PicoClaw is heavily inspired by and based on [nanobot](https://github.com/HKUDS/nanobot) by HKUDS. From 6f5930624b1cfaea113279f637a45569eafa812b Mon Sep 17 00:00:00 2001 From: lxowalle <83055338+lxowalle@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:53:26 +0800 Subject: [PATCH 08/11] Feat/add tool enable or disable configuration (#1071) * Add tools enable or disable config --- cmd/picoclaw/internal/gateway/helpers.go | 28 +++-- config/config.example.json | 87 ++++++++++++-- pkg/agent/instance.go | 33 +++-- pkg/agent/loop.go | 147 ++++++++++++++--------- pkg/agent/loop_test.go | 15 +-- pkg/config/config.go | 105 +++++++++++++--- pkg/config/defaults.go | 59 ++++++++- pkg/mcp/manager_test.go | 16 ++- 8 files changed, 367 insertions(+), 123 deletions(-) diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 5225340c7..174f5db62 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -230,19 +230,25 @@ func setupCronTool( // Create cron service cronService := cron.NewCronService(cronStorePath, nil) - // Create and register CronTool - cronTool, err := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) - if err != nil { - log.Fatalf("Critical error during CronTool 
initialization: %v", err) + // Create and register CronTool if enabled + var cronTool *tools.CronTool + if cfg.Tools.IsToolEnabled("cron") { + var err error + cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) + if err != nil { + log.Fatalf("Critical error during CronTool initialization: %v", err) + } + + agentLoop.RegisterTool(cronTool) } - agentLoop.RegisterTool(cronTool) - - // Set the onJob handler - cronService.SetOnJob(func(job *cron.CronJob) (string, error) { - result := cronTool.ExecuteJob(context.Background(), job) - return result, nil - }) + // Set onJob handler + if cronTool != nil { + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + } return cronService } diff --git a/config/config.example.json b/config/config.example.json index c59a39885..ef1bf3eda 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -232,24 +232,41 @@ } }, "tools": { + "allow_read_paths": null, + "allow_write_paths": null, "web": { + "enabled": true, "brave": { "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 }, + "tavily": { + "enabled": false, + "api_key": "", + "base_url": "", + "max_results": 0 + }, "duckduckgo": { "enabled": true, "max_results": 5 }, "perplexity": { "enabled": false, - "api_key": "pplx-xxx", + "api_key": "", "max_results": 5 }, - "proxy": "" + "glm_search": { + "enabled": false, + "api_key": "", + "base_url": "https://open.bigmodel.cn/api/paas/v4/web_search", + "search_engine": "search_std", + "max_results": 5 + }, + "fetch_limit_bytes": 10485760 }, "cron": { + "enabled": true, "exec_timeout_minutes": 5 }, "mcp": { @@ -318,19 +335,75 @@ } }, "exec": { - "enable_deny_patterns": false, - "custom_deny_patterns": [] + "enabled": true, + "enable_deny_patterns": true, + "custom_deny_patterns": null, + "custom_allow_patterns": null }, "skills": { + "enabled": true, 
"registries": { "clawhub": { "enabled": true, "base_url": "https://clawhub.ai", - "search_path": "/api/v1/search", - "skills_path": "/api/v1/skills", - "download_path": "/api/v1/download" + "auth_token": "", + "search_path": "", + "skills_path": "", + "download_path": "", + "timeout": 0, + "max_zip_size": 0, + "max_response_size": 0 } + }, + "max_concurrent_searches": 2, + "search_cache": { + "max_size": 50, + "ttl_seconds": 300 } + }, + "media_cleanup": { + "enabled": true, + "max_age_minutes": 30, + "interval_minutes": 5 + }, + "append_file": { + "enabled": true + }, + "edit_file": { + "enabled": true + }, + "find_skills": { + "enabled": true + }, + "i2c": { + "enabled": false + }, + "install_skill": { + "enabled": true + }, + "list_dir": { + "enabled": true + }, + "message": { + "enabled": true + }, + "read_file": { + "enabled": true + }, + "spawn": { + "enabled": true + }, + "spi": { + "enabled": false + }, + "subagent": { + "enabled": true + }, + "web_fetch": { + "enabled": true + }, + "write_file": { + "enabled": true } }, "heartbeat": { diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 1e18b6f64..e14acf06d 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -60,17 +60,30 @@ func NewAgentInstance( allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths) toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths)) - toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths)) - toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths)) - execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg) - if err != nil { - log.Fatalf("Critical error: unable to initialize exec tool: %v", err) - } - toolsRegistry.Register(execTool) - toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths)) - toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths)) + if 
cfg.Tools.IsToolEnabled("read_file") { + toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths)) + } + if cfg.Tools.IsToolEnabled("write_file") { + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths)) + } + if cfg.Tools.IsToolEnabled("list_dir") { + toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths)) + } + if cfg.Tools.IsToolEnabled("exec") { + execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg) + if err != nil { + log.Fatalf("Critical error: unable to initialize exec tool: %v", err) + } + toolsRegistry.Register(execTool) + } + + if cfg.Tools.IsToolEnabled("edit_file") { + toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths)) + } + if cfg.Tools.IsToolEnabled("append_file") { + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths)) + } sessionsDir := filepath.Join(workspace, "sessions") sessionsManager := session.NewSessionManager(sessionsDir) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 263eeb4dd..1ab79f3ca 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -108,76 +108,102 @@ func registerSharedTools( } // Web tools - searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKey: cfg.Tools.Web.Brave.APIKey, - BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, - BraveEnabled: cfg.Tools.Web.Brave.Enabled, - TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, - TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, - TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, - TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, - DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, - DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, - PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, - PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, - PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, - GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey, - GLMSearchBaseURL: 
cfg.Tools.Web.GLMSearch.BaseURL, - GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine, - GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults, - GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled, - Proxy: cfg.Tools.Web.Proxy, - }) - if err != nil { - logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) - } else if searchTool != nil { - agent.Tools.Register(searchTool) + if cfg.Tools.IsToolEnabled("web") { + searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, + TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, + TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, + TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, + PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey, + GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL, + GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine, + GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults, + GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled, + Proxy: cfg.Tools.Web.Proxy, + }) + if err != nil { + logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) + } else if searchTool != nil { + agent.Tools.Register(searchTool) + } } - fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes) - if err != nil { - logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) - } else { - agent.Tools.Register(fetchTool) + if cfg.Tools.IsToolEnabled("web_fetch") { + fetchTool, err := 
tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes) + if err != nil { + logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) + } else { + agent.Tools.Register(fetchTool) + } } // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms - agent.Tools.Register(tools.NewI2CTool()) - agent.Tools.Register(tools.NewSPITool()) + if cfg.Tools.IsToolEnabled("i2c") { + agent.Tools.Register(tools.NewI2CTool()) + } + if cfg.Tools.IsToolEnabled("spi") { + agent.Tools.Register(tools.NewSPITool()) + } // Message tool - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer pubCancel() - return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, + if cfg.Tools.IsToolEnabled("message") { + messageTool := tools.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) }) - }) - agent.Tools.Register(messageTool) + agent.Tools.Register(messageTool) + } // Skill discovery and installation tools - registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ - MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, - ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), - }) - searchCache := skills.NewSearchCache( - cfg.Tools.Skills.SearchCache.MaxSize, - time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, - ) - agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) - agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + find_skills_enable := 
cfg.Tools.IsToolEnabled("find_skills") + install_skills_enable := cfg.Tools.IsToolEnabled("install_skill") + if find_skills_enable || install_skills_enable { + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + + if find_skills_enable { + searchCache := skills.NewSearchCache( + cfg.Tools.Skills.SearchCache.MaxSize, + time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, + ) + agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) + } + + if install_skills_enable { + agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + } + } // Spawn tool with allowlist checker - subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) - subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) - spawnTool := tools.NewSpawnTool(subagentManager) - currentAgentID := agentID - spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { - return registry.CanSpawnSubagent(currentAgentID, targetAgentID) - }) - agent.Tools.Register(spawnTool) + if cfg.Tools.IsToolEnabled("spawn") { + if cfg.Tools.IsToolEnabled("subagent") { + subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) + spawnTool := tools.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) + } else { + logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil) + } + } } } @@ -185,7 +211,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) // Initialize MCP servers for all agents - if al.cfg.Tools.MCP.Enabled { + if 
al.cfg.Tools.IsToolEnabled("mcp") { mcpManager := mcp.NewManager() // Ensure MCP connections are cleaned up on exit, regardless of initialization success // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails @@ -227,6 +253,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { if !ok { continue } + mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) agent.Tools.Register(mcpTool) totalRegistrations++ diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 4ab6b4542..aa7d59b5a 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -227,16 +227,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) { } defer os.RemoveAll(tmpDir) - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, - } + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 msgBus := bus.NewMessageBus() provider := &mockProvider{} diff --git a/pkg/config/config.go b/pkg/config/config.go index 3cfebf5e8..0ee3acfe0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -526,6 +526,10 @@ type GatewayConfig struct { Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` } +type ToolConfig struct { + Enabled bool `json:"enabled" env:"ENABLED"` +} + type BraveConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` @@ -561,11 +565,12 @@ type GLMSearchConfig struct { } type WebToolsConfig struct { - Brave BraveConfig `json:"brave"` - Tavily TavilyConfig `json:"tavily"` - DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` - Perplexity PerplexityConfig `json:"perplexity"` - GLMSearch GLMSearchConfig `json:"glm_search"` + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"` + Brave 
BraveConfig ` json:"brave"` + Tavily TavilyConfig ` json:"tavily"` + DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"` + Perplexity PerplexityConfig ` json:"perplexity"` + GLMSearch GLMSearchConfig ` json:"glm_search"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` @@ -573,19 +578,28 @@ type WebToolsConfig struct { } type CronToolsConfig struct { - ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"` + ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout } type ExecConfig struct { - EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"` - CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` - CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"` + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"` + EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"` + CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"` + CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"` +} + +type SkillsToolsConfig struct { + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"` + Registries SkillsRegistriesConfig ` json:"registries"` + MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"` + SearchCache SearchCacheConfig ` json:"search_cache"` } type MediaCleanupConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"` - 
MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"` - Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"` + ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"` + MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"` + Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"` } type ToolsConfig struct { @@ -597,12 +611,19 @@ type ToolsConfig struct { Skills SkillsToolsConfig `json:"skills"` MediaCleanup MediaCleanupConfig `json:"media_cleanup"` MCP MCPConfig `json:"mcp"` -} - -type SkillsToolsConfig struct { - Registries SkillsRegistriesConfig `json:"registries"` - MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"` - SearchCache SearchCacheConfig `json:"search_cache"` + AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"` + EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"` + FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"` + I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"` + InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"` + ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"` + Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` + ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"` + Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"` + SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"` + Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"` + WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"` + WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"` } type SearchCacheConfig struct { @@ -648,8 +669,7 @@ type MCPServerConfig struct { // MCPConfig defines configuration for all MCP servers type MCPConfig struct { - // Enabled globally 
enables/disables MCP integration - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"` + ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"` // Servers is a map of server name to server configuration Servers map[string]MCPServerConfig `json:"servers,omitempty"` } @@ -835,3 +855,48 @@ func (c *Config) ValidateModelList() error { } return nil } + +func (t *ToolsConfig) IsToolEnabled(name string) bool { + switch name { + case "web": + return t.Web.Enabled + case "cron": + return t.Cron.Enabled + case "exec": + return t.Exec.Enabled + case "skills": + return t.Skills.Enabled + case "media_cleanup": + return t.MediaCleanup.Enabled + case "append_file": + return t.AppendFile.Enabled + case "edit_file": + return t.EditFile.Enabled + case "find_skills": + return t.FindSkills.Enabled + case "i2c": + return t.I2C.Enabled + case "install_skill": + return t.InstallSkill.Enabled + case "list_dir": + return t.ListDir.Enabled + case "message": + return t.Message.Enabled + case "read_file": + return t.ReadFile.Enabled + case "spawn": + return t.Spawn.Enabled + case "spi": + return t.SPI.Enabled + case "subagent": + return t.Subagent.Enabled + case "web_fetch": + return t.WebFetch.Enabled + case "write_file": + return t.WriteFile.Enabled + case "mcp": + return t.MCP.Enabled + default: + return true + } +} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 84fc60435..488590e28 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -336,11 +336,16 @@ func DefaultConfig() *Config { }, Tools: ToolsConfig{ MediaCleanup: MediaCleanupConfig{ - Enabled: true, + ToolConfig: ToolConfig{ + Enabled: true, + }, MaxAge: 30, Interval: 5, }, Web: WebToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, Proxy: "", FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default Brave: BraveConfig{ @@ -366,12 +371,21 @@ func DefaultConfig() *Config { }, }, Cron: CronToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, ExecTimeoutMinutes: 5, }, Exec: ExecConfig{ + 
ToolConfig: ToolConfig{ + Enabled: true, + }, EnableDenyPatterns: true, }, Skills: SkillsToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, Registries: SkillsRegistriesConfig{ ClawHub: ClawHubRegistryConfig{ Enabled: true, @@ -385,9 +399,50 @@ func DefaultConfig() *Config { }, }, MCP: MCPConfig{ - Enabled: false, + ToolConfig: ToolConfig{ + Enabled: false, + }, Servers: map[string]MCPServerConfig{}, }, + AppendFile: ToolConfig{ + Enabled: true, + }, + EditFile: ToolConfig{ + Enabled: true, + }, + FindSkills: ToolConfig{ + Enabled: true, + }, + I2C: ToolConfig{ + Enabled: false, // Hardware tool - Linux only + }, + InstallSkill: ToolConfig{ + Enabled: true, + }, + ListDir: ToolConfig{ + Enabled: true, + }, + Message: ToolConfig{ + Enabled: true, + }, + ReadFile: ToolConfig{ + Enabled: true, + }, + Spawn: ToolConfig{ + Enabled: true, + }, + SPI: ToolConfig{ + Enabled: false, // Hardware tool - Linux only + }, + Subagent: ToolConfig{ + Enabled: true, + }, + WebFetch: ToolConfig{ + Enabled: true, + }, + WriteFile: ToolConfig{ + Enabled: true, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go index 8ce81d09e..f353942ab 100644 --- a/pkg/mcp/manager_test.go +++ b/pkg/mcp/manager_test.go @@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) { mgr := NewManager() mcpCfg := config.MCPConfig{ - Enabled: true, + ToolConfig: config.ToolConfig{ + Enabled: true, + }, Servers: map[string]config.MCPServerConfig{ "test-server": { Enabled: true, @@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) { func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) { mgr := NewManager() - err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp") + err := mgr.LoadFromMCPConfig( + context.Background(), + config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}}, + "/tmp", + ) if err != nil { t.Fatalf("expected nil error 
when MCP disabled, got: %v", err) } - err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp") + err = mgr.LoadFromMCPConfig( + context.Background(), + config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}}, + "/tmp", + ) if err != nil { t.Fatalf("expected nil error when no servers configured, got: %v", err) } From 7a2fdc24dc202012db849722ba92df8238257b3e Mon Sep 17 00:00:00 2001 From: qs3c <2749950753@qq.com> Date: Thu, 5 Mar 2026 15:00:06 +0800 Subject: [PATCH 09/11] fix(skills): retry ClawHub requests on 429 --- docs/tools_configuration.md | 2 + pkg/skills/clawhub_registry.go | 162 ++++++++++++++++++++++++---- pkg/skills/clawhub_registry_test.go | 81 ++++++++++++++ 3 files changed, 225 insertions(+), 20 deletions(-) diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md index 6204fb0c8..e64a3a107 100644 --- a/docs/tools_configuration.md +++ b/docs/tools_configuration.md @@ -180,6 +180,7 @@ The skills tool configures skill discovery and installation via registries like | ---------------------------------- | ------ | -------------------- | ----------------------- | | `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | | `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | +| `registries.clawhub.auth_token` | string | `""` | Optional Bearer token for higher rate limits | | `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | | `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | | `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | @@ -194,6 +195,7 @@ The skills tool configures skill discovery and installation via registries like "clawhub": { "enabled": true, "base_url": "https://clawhub.ai", + "auth_token": "", "search_path": "/api/v1/search", "skills_path": "/api/v1/skills", "download_path": "/api/v1/download" diff --git a/pkg/skills/clawhub_registry.go 
b/pkg/skills/clawhub_registry.go index f78197bbe..b520f3260 100644 --- a/pkg/skills/clawhub_registry.go +++ b/pkg/skills/clawhub_registry.go @@ -8,6 +8,8 @@ import ( "net/http" "net/url" "os" + "strconv" + "strings" "time" "github.com/sipeed/picoclaw/pkg/utils" @@ -17,6 +19,7 @@ const ( defaultClawHubTimeout = 30 * time.Second defaultMaxZipSize = 50 * 1024 * 1024 // 50 MB defaultMaxResponseSize = 2 * 1024 * 1024 // 2 MB + defaultMaxRetries = 3 ) // ClawHubRegistry implements SkillRegistry for the ClawHub platform. @@ -259,15 +262,7 @@ func (c *ClawHubRegistry) DownloadAndInstall( } u.RawQuery = q.Encode() - req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - if c.authToken != "" { - req.Header.Set("Authorization", "Bearer "+c.authToken) - } - - tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize)) + tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String()) if err != nil { return nil, fmt.Errorf("download failed: %w", err) } @@ -284,17 +279,7 @@ func (c *ClawHubRegistry) DownloadAndInstall( // --- HTTP helper --- func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) { - req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) - if err != nil { - return nil, err - } - - req.Header.Set("Accept", "application/json") - if c.authToken != "" { - req.Header.Set("Authorization", "Bearer "+c.authToken) - } - - resp, err := c.client.Do(req) + resp, err := c.doGetWithRetry(ctx, urlStr, "application/json") if err != nil { return nil, err } @@ -312,3 +297,140 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err return body, nil } + +func (c *ClawHubRegistry) doGetWithRetry(ctx context.Context, urlStr, accept string) (*http.Response, error) { + var lastErr error + for attempt := 0; attempt < defaultMaxRetries; attempt++ { + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, 
nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", accept) + if c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+c.authToken) + } + + resp, err := c.client.Do(req) + if err != nil { + lastErr = err + } else { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return resp, nil + } + + if !isRetryableStatus(resp.StatusCode) || attempt == defaultMaxRetries-1 { + return resp, nil + } + + delay := retryDelay(resp.Header.Get("Retry-After"), attempt) + resp.Body.Close() + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + + if attempt == defaultMaxRetries-1 { + return nil, lastErr + } + if err := sleepWithContext(ctx, retryDelay("", attempt)); err != nil { + return nil, err + } + } + return nil, lastErr +} + +func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) { + resp, err := c.doGetWithRetry(ctx, urlStr, "application/zip") + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody := make([]byte, 512) + n, _ := io.ReadFull(resp.Body, errBody) + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n])) + } + + tmpFile, err := os.CreateTemp("", "picoclaw-dl-*") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + cleanup := func() { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + } + + src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1) + written, err := io.Copy(tmpFile, src) + if err != nil { + cleanup() + return "", fmt.Errorf("download write failed: %w", err) + } + + if written > int64(c.maxZipSize) { + cleanup() + return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize) + } + + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("failed to close temp file: %w", err) + } + + return tmpPath, nil +} + +func 
isRetryableStatus(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || statusCode >= http.StatusInternalServerError +} + +func retryDelay(retryAfter string, attempt int) time.Duration { + if d, ok := parseRetryAfter(retryAfter); ok { + return d + } + return time.Duration(attempt+1) * time.Second +} + +func parseRetryAfter(headerValue string) (time.Duration, bool) { + headerValue = strings.TrimSpace(headerValue) + if headerValue == "" { + return 0, false + } + + if sec, err := strconv.Atoi(headerValue); err == nil { + if sec < 0 { + sec = 0 + } + return time.Duration(sec) * time.Second, true + } + + if resetAt, err := http.ParseTime(headerValue); err == nil { + d := time.Until(resetAt) + if d < 0 { + d = 0 + } + return d, true + } + + return 0, false +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/pkg/skills/clawhub_registry_test.go b/pkg/skills/clawhub_registry_test.go index 65ee638da..055da22dc 100644 --- a/pkg/skills/clawhub_registry_test.go +++ b/pkg/skills/clawhub_registry_test.go @@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) { assert.Equal(t, "clawhub", results[0].RegistryName) } +func TestClawHubRegistrySearchRetries429(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + + slug := "github" + name := "GitHub Integration" + summary := "Interact with GitHub repos" + version := "1.0.0" + + json.NewEncoder(w).Encode(clawhubSearchResponse{ + Results: []clawhubSearchResult{ + {Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version}, + }, + }) + 
})) + defer srv.Close() + + reg := newTestRegistry(srv.URL, "") + results, err := reg.Search(context.Background(), "github", 5) + + require.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, 2, attempts) + assert.Equal(t, "github", results[0].Slug) +} + func TestClawHubRegistryGetSkillMeta(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/api/v1/skills/github", r.URL.Path) @@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) { assert.Contains(t, string(readmeContent), "# Test Skill") } +func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) { + zipBuf := createTestZip(t, map[string]string{ + "SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill", + }) + + downloadAttempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/skills/retry-skill": + json.NewEncoder(w).Encode(clawhubSkillResponse{ + Slug: "retry-skill", + DisplayName: "Retry Skill", + Summary: "A retry test skill", + LatestVersion: &clawhubVersionInfo{Version: "1.0.0"}, + }) + case "/api/v1/download": + downloadAttempts++ + if downloadAttempts == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + assert.Equal(t, "retry-skill", r.URL.Query().Get("slug")) + w.Header().Set("Content-Type", "application/zip") + w.Write(zipBuf) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + tmpDir := t.TempDir() + targetDir := filepath.Join(tmpDir, "retry-skill") + + reg := newTestRegistry(srv.URL, "") + result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "1.0.0", result.Version) + assert.Equal(t, 2, downloadAttempts) + + skillContent, err := 
os.ReadFile(filepath.Join(targetDir, "SKILL.md")) + require.NoError(t, err) + assert.Contains(t, string(skillContent), "Hello skill") +} + func TestClawHubRegistryAuthToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") From ab120af64906537409c297c9c56d1bf80492926e Mon Sep 17 00:00:00 2001 From: cornjosh Date: Thu, 5 Mar 2026 17:10:04 +0800 Subject: [PATCH 10/11] fix(skills): use --registry flag value as registry name The --registry flag value was previously ignored and only used as a switch. Now the flag value is properly used as the registry name. Fixes #1104 Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- cmd/picoclaw/internal/skills/install.go | 6 +- cmd/picoclaw/internal/skills/install_test.go | 69 ++++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/cmd/picoclaw/internal/skills/install.go b/cmd/picoclaw/internal/skills/install.go index a30f68632..78bc421db 100644 --- a/cmd/picoclaw/internal/skills/install.go +++ b/cmd/picoclaw/internal/skills/install.go @@ -21,8 +21,8 @@ picoclaw skills install --registry clawhub github `, Args: func(cmd *cobra.Command, args []string) error { if registry != "" { - if len(args) != 2 { - return fmt.Errorf("when --registry is set, exactly 2 arguments are required: ") + if len(args) != 1 { + return fmt.Errorf("when --registry is set, exactly 1 argument is required: ") } return nil } @@ -45,7 +45,7 @@ picoclaw skills install --registry clawhub github return err } - return skillsInstallFromRegistry(cfg, args[0], args[1]) + return skillsInstallFromRegistry(cfg, registry, args[0]) } return skillsInstallCmd(installer, args[0]) diff --git a/cmd/picoclaw/internal/skills/install_test.go b/cmd/picoclaw/internal/skills/install_test.go index 97787a986..6b362822d 100644 --- a/cmd/picoclaw/internal/skills/install_test.go +++ b/cmd/picoclaw/internal/skills/install_test.go 
@@ -26,3 +26,72 @@ func TestNewInstallSubcommand(t *testing.T) { assert.Len(t, cmd.Aliases, 0) } + +func TestInstallCommandArgs(t *testing.T) { + tests := []struct { + name string + args []string + registry string + expectError bool + errorMsg string + }{ + { + name: "no registry, one arg", + args: []string{"sipeed/picoclaw-skills/weather"}, + registry: "", + expectError: false, + }, + { + name: "no registry, no args", + args: []string{}, + registry: "", + expectError: true, + errorMsg: "exactly 1 argument is required: ", + }, + { + name: "no registry, too many args", + args: []string{"arg1", "arg2"}, + registry: "", + expectError: true, + errorMsg: "exactly 1 argument is required: ", + }, + { + name: "with registry, one arg", + args: []string{"weather-skill"}, + registry: "clawhub", + expectError: false, + }, + { + name: "with registry, no args", + args: []string{}, + registry: "clawhub", + expectError: true, + errorMsg: "when --registry is set, exactly 1 argument is required: ", + }, + { + name: "with registry, too many args", + args: []string{"arg1", "arg2"}, + registry: "clawhub", + expectError: true, + errorMsg: "when --registry is set, exactly 1 argument is required: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := newInstallCommand(nil) + + if tt.registry != "" { + require.NoError(t, cmd.Flags().Set("registry", tt.registry)) + } + + err := cmd.Args(cmd, tt.args) + if tt.expectError { + require.Error(t, err) + assert.Equal(t, tt.errorMsg, err.Error()) + } else { + require.NoError(t, err) + } + }) + } +} From 536e9ac9de6aadf97ea23fbe21f7c6126589a625 Mon Sep 17 00:00:00 2001 From: qs3c <2749950753@qq.com> Date: Thu, 5 Mar 2026 19:10:36 +0800 Subject: [PATCH 11/11] refactor(skills): reuse shared HTTP retry helper --- pkg/skills/clawhub_registry.go | 116 ++++++--------------------------- 1 file changed, 21 insertions(+), 95 deletions(-) diff --git a/pkg/skills/clawhub_registry.go b/pkg/skills/clawhub_registry.go index 
b520f3260..bd4bed8fb 100644 --- a/pkg/skills/clawhub_registry.go +++ b/pkg/skills/clawhub_registry.go @@ -8,8 +8,6 @@ import ( "net/http" "net/url" "os" - "strconv" - "strings" "time" "github.com/sipeed/picoclaw/pkg/utils" @@ -19,7 +17,6 @@ const ( defaultClawHubTimeout = 30 * time.Second defaultMaxZipSize = 50 * 1024 * 1024 // 50 MB defaultMaxResponseSize = 2 * 1024 * 1024 // 2 MB - defaultMaxRetries = 3 ) // ClawHubRegistry implements SkillRegistry for the ClawHub platform. @@ -279,7 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall( // --- HTTP helper --- func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) { - resp, err := c.doGetWithRetry(ctx, urlStr, "application/json") + req, err := c.newGetRequest(ctx, urlStr, "application/json") + if err != nil { + return nil, err + } + + resp, err := utils.DoRequestWithRetry(c.client, req) if err != nil { return nil, err } @@ -298,50 +300,25 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err return body, nil } -func (c *ClawHubRegistry) doGetWithRetry(ctx context.Context, urlStr, accept string) (*http.Response, error) { - var lastErr error - for attempt := 0; attempt < defaultMaxRetries; attempt++ { - req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) - if err != nil { - return nil, err - } - req.Header.Set("Accept", accept) - if c.authToken != "" { - req.Header.Set("Authorization", "Bearer "+c.authToken) - } - - resp, err := c.client.Do(req) - if err != nil { - lastErr = err - } else { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - return resp, nil - } - - if !isRetryableStatus(resp.StatusCode) || attempt == defaultMaxRetries-1 { - return resp, nil - } - - delay := retryDelay(resp.Header.Get("Retry-After"), attempt) - resp.Body.Close() - if err := sleepWithContext(ctx, delay); err != nil { - return nil, err - } - continue - } - - if attempt == defaultMaxRetries-1 { - return nil, lastErr - } - if err := sleepWithContext(ctx, 
retryDelay("", attempt)); err != nil { - return nil, err - } +func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil) + if err != nil { + return nil, err } - return nil, lastErr + req.Header.Set("Accept", accept) + if c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+c.authToken) + } + return req, nil } func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) { - resp, err := c.doGetWithRetry(ctx, urlStr, "application/zip") + req, err := c.newGetRequest(ctx, urlStr, "application/zip") + if err != nil { + return "", err + } + + resp, err := utils.DoRequestWithRetry(c.client, req) if err != nil { return "", err } @@ -383,54 +360,3 @@ func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlSt return tmpPath, nil } - -func isRetryableStatus(statusCode int) bool { - return statusCode == http.StatusTooManyRequests || statusCode >= http.StatusInternalServerError -} - -func retryDelay(retryAfter string, attempt int) time.Duration { - if d, ok := parseRetryAfter(retryAfter); ok { - return d - } - return time.Duration(attempt+1) * time.Second -} - -func parseRetryAfter(headerValue string) (time.Duration, bool) { - headerValue = strings.TrimSpace(headerValue) - if headerValue == "" { - return 0, false - } - - if sec, err := strconv.Atoi(headerValue); err == nil { - if sec < 0 { - sec = 0 - } - return time.Duration(sec) * time.Second, true - } - - if resetAt, err := http.ParseTime(headerValue); err == nil { - d := time.Until(resetAt) - if d < 0 { - d = 0 - } - return d, true - } - - return 0, false -} - -func sleepWithContext(ctx context.Context, delay time.Duration) error { - if delay <= 0 { - return nil - } - - timer := time.NewTimer(delay) - defer timer.Stop() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } -}