Merge upstream/main into feat/searxng

Resolved conflicts in 3 files:
- config/config.example.json: keep both searxng and glm_search configs
- pkg/agent/loop.go: adopt upstream's IsToolEnabled guard + keep searxng fields
- pkg/config/config.go: adopt upstream's ToolConfig embed + keep SearXNG field

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Truong Vinh Tran
2026-03-05 21:36:05 +01:00
35 changed files with 1260 additions and 430 deletions
+31 -10
View File
@@ -26,6 +26,7 @@ type AgentInstance struct {
MaxIterations int
MaxTokens int
Temperature float64
ThinkingLevel ThinkingLevel
ContextWindow int
SummarizeMessageThreshold int
SummarizeTokenPercent int
@@ -59,17 +60,30 @@ func NewAgentInstance(
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
if cfg.Tools.IsToolEnabled("read_file") {
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("write_file") {
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
}
if cfg.Tools.IsToolEnabled("list_dir") {
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
}
if cfg.Tools.IsToolEnabled("edit_file") {
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
}
if cfg.Tools.IsToolEnabled("append_file") {
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
}
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
@@ -103,6 +117,12 @@ func NewAgentInstance(
temperature = *defaults.Temperature
}
var thinkingLevelStr string
if mc, err := cfg.GetModelConfig(model); err == nil {
thinkingLevelStr = mc.ThinkingLevel
}
thinkingLevel := parseThinkingLevel(thinkingLevelStr)
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
if summarizeMessageThreshold == 0 {
summarizeMessageThreshold = 20
@@ -169,6 +189,7 @@ func NewAgentInstance(
MaxIterations: maxIter,
MaxTokens: maxTokens,
Temperature: temperature,
ThinkingLevel: thinkingLevel,
ContextWindow: maxTokens,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
+119 -113
View File
@@ -108,79 +108,105 @@ func registerSharedTools(
}
// Web tools
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
if cfg.Tools.IsToolEnabled("web") {
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
}
}
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
if cfg.Tools.IsToolEnabled("web_fetch") {
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
}
}
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
agent.Tools.Register(tools.NewSPITool())
if cfg.Tools.IsToolEnabled("i2c") {
agent.Tools.Register(tools.NewI2CTool())
}
if cfg.Tools.IsToolEnabled("spi") {
agent.Tools.Register(tools.NewSPITool())
}
// Message tool
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
if cfg.Tools.IsToolEnabled("message") {
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
})
})
})
agent.Tools.Register(messageTool)
agent.Tools.Register(messageTool)
}
// Skill discovery and installation tools
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
})
searchCache := skills.NewSearchCache(
cfg.Tools.Skills.SearchCache.MaxSize,
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
)
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
install_skills_enable := cfg.Tools.IsToolEnabled("install_skill")
if find_skills_enable || install_skills_enable {
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
})
if find_skills_enable {
searchCache := skills.NewSearchCache(
cfg.Tools.Skills.SearchCache.MaxSize,
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
)
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
}
if install_skills_enable {
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
}
}
// Spawn tool with allowlist checker
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
} else {
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
}
}
}
}
@@ -188,7 +214,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
// Initialize MCP servers for all agents
if al.cfg.Tools.MCP.Enabled {
if al.cfg.Tools.IsToolEnabled("mcp") {
mcpManager := mcp.NewManager()
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
@@ -230,6 +256,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
if !ok {
continue
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
agent.Tools.Register(mcpTool)
totalRegistrations++
@@ -546,8 +573,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
if tool, ok := agent.Tools.Get("message"); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(msg.Channel, msg.ChatID)
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
resetter.ResetSentInRound()
}
}
@@ -662,10 +689,7 @@ func (al *AgentLoop) runAgentLoop(
}
}
// 1. Update tool contexts
al.updateToolContexts(agent, opts.Channel, opts.ChatID)
// 2. Build messages (skip history for heartbeat)
// 1. Build messages (skip history for heartbeat)
var history []providers.Message
var summary string
if !opts.NoHistory {
@@ -685,10 +709,10 @@ func (al *AgentLoop) runAgentLoop(
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
// 3. Save user message to session
// 2. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// 4. Run LLM iteration loop
// 3. Run LLM iteration loop
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
if err != nil {
return "", err
@@ -697,21 +721,21 @@ func (al *AgentLoop) runAgentLoop(
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
// 5. Handle empty response
// 4. Handle empty response
if finalContent == "" {
finalContent = opts.DefaultResponse
}
// 6. Save final assistant message to session
// 5. Save final assistant message to session
agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
agent.Sessions.Save(opts.SessionKey)
// 7. Optional: summarization
// 6. Optional: summarization
if opts.EnableSummary {
al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
}
// 8. Optional: send response via bus
// 7. Optional: send response via bus
if opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
@@ -720,7 +744,7 @@ func (al *AgentLoop) runAgentLoop(
})
}
// 9. Log response
// 8. Log response
responsePreview := utils.Truncate(finalContent, 120)
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
map[string]any{
@@ -837,23 +861,29 @@ func (al *AgentLoop) runLLMIteration(
var response *providers.LLMResponse
var err error
llmOpts := map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
}
// parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
// so checking != ThinkingOff is sufficient.
if agent.ThinkingLevel != ThinkingOff {
if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
llmOpts["thinking_level"] = string(agent.ThinkingLevel)
} else {
logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring",
map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)})
}
}
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
ctx,
agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(
ctx,
messages,
providerToolDefs,
model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
},
)
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
},
)
if fbErr != nil {
@@ -869,11 +899,7 @@ func (al *AgentLoop) runLLMIteration(
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
})
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
}
// Retry loop for context/token errors
@@ -1060,7 +1086,7 @@ func (al *AgentLoop) runLLMIteration(
"iteration": iteration,
})
// Create async callback for tools that implement AsyncTool
// Create async callback for tools that implement AsyncExecutor
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
@@ -1142,26 +1168,6 @@ func (al *AgentLoop) runLLMIteration(
return finalContent, iteration, nil
}
// updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := agent.Tools.Get("message"); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID)
}
}
if tool, ok := agent.Tools.Get("spawn"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := agent.Tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
+16 -65
View File
@@ -164,35 +164,21 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
}
}
// TestToolContext_Updates verifies tool context is updated with channel/chatID
// TestToolContext_Updates verifies tool context helpers work correctly
func TestToolContext_Updates(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
ctx := tools.WithToolContext(context.Background(), "telegram", "chat-42")
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
if got := tools.ToolChannel(ctx); got != "telegram" {
t.Errorf("expected channel 'telegram', got %q", got)
}
if got := tools.ToolChatID(ctx); got != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", got)
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "OK"}
_ = NewAgentLoop(cfg, msgBus, provider)
// Verify that ContextualTool interface is defined and can be implemented
// This test validates the interface contract exists
ctxTool := &mockContextualTool{}
// Verify the tool implements the interface correctly
var _ tools.ContextualTool = ctxTool
// Empty context returns empty strings
if got := tools.ToolChannel(context.Background()); got != "" {
t.Errorf("expected empty channel from bare context, got %q", got)
}
}
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
@@ -241,16 +227,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) {
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = tmpDir
cfg.Agents.Defaults.Model = "test-model"
cfg.Agents.Defaults.MaxTokens = 4096
cfg.Agents.Defaults.MaxToolIterations = 10
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
@@ -359,36 +340,6 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool
return tools.SilentResult("Custom tool executed")
}
// mockContextualTool tracks context updates
type mockContextualTool struct {
lastChannel string
lastChatID string
}
func (m *mockContextualTool) Name() string {
return "mock_contextual"
}
func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
func (m *mockContextualTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
func (m *mockContextualTool) SetContext(channel, chatID string) {
m.lastChannel = channel
m.lastChatID = chatID
}
// testHelper executes a message and returns the response
type testHelper struct {
al *AgentLoop
+39
View File
@@ -0,0 +1,39 @@
package agent
import "strings"
// ThinkingLevel controls how the provider sends thinking parameters.
//
// - "adaptive": sends {thinking: {type: "adaptive"}} + output_config.effort (Claude 4.6+)
// - "low"/"medium"/"high"/"xhigh": sends {thinking: {type: "enabled", budget_tokens: N}} (all models)
// - "off": disables thinking
type ThinkingLevel string
const (
ThinkingOff ThinkingLevel = "off"
ThinkingLow ThinkingLevel = "low"
ThinkingMedium ThinkingLevel = "medium"
ThinkingHigh ThinkingLevel = "high"
ThinkingXHigh ThinkingLevel = "xhigh"
ThinkingAdaptive ThinkingLevel = "adaptive"
)
// parseThinkingLevel normalizes a config string to a ThinkingLevel.
// Case-insensitive and whitespace-tolerant for user-facing config values.
// Returns ThinkingOff for unknown or empty values.
func parseThinkingLevel(level string) ThinkingLevel {
switch strings.ToLower(strings.TrimSpace(level)) {
case "adaptive":
return ThinkingAdaptive
case "low":
return ThinkingLow
case "medium":
return ThinkingMedium
case "high":
return ThinkingHigh
case "xhigh":
return ThinkingXHigh
default:
return ThinkingOff
}
}
+35
View File
@@ -0,0 +1,35 @@
package agent
import "testing"
func TestParseThinkingLevel(t *testing.T) {
tests := []struct {
name string
input string
want ThinkingLevel
}{
{"off", "off", ThinkingOff},
{"empty", "", ThinkingOff},
{"low", "low", ThinkingLow},
{"medium", "medium", ThinkingMedium},
{"high", "high", ThinkingHigh},
{"xhigh", "xhigh", ThinkingXHigh},
{"adaptive", "adaptive", ThinkingAdaptive},
{"unknown", "unknown", ThinkingOff},
// Case-insensitive and whitespace-tolerant
{"upper_Medium", "Medium", ThinkingMedium},
{"upper_HIGH", "HIGH", ThinkingHigh},
{"mixed_Adaptive", "Adaptive", ThinkingAdaptive},
{"leading_space", " high", ThinkingHigh},
{"trailing_space", "low ", ThinkingLow},
{"both_spaces", " medium ", ThinkingMedium},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseThinkingLevel(tt.input); got != tt.want {
t.Errorf("parseThinkingLevel(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
+90 -22
View File
@@ -431,6 +431,7 @@ type ProvidersConfig struct {
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
Mistral ProviderConfig `json:"mistral"`
Avian ProviderConfig `json:"avian"`
}
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
@@ -454,7 +455,8 @@ func (p ProvidersConfig) IsEmpty() bool {
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
p.Mistral.APIKey == "" && p.Mistral.APIBase == "" &&
p.Avian.APIKey == "" && p.Avian.APIBase == ""
}
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
@@ -505,6 +507,7 @@ type ModelConfig struct {
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
}
// Validate checks if the ModelConfig has all required fields.
@@ -523,6 +526,10 @@ type GatewayConfig struct {
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
}
type ToolConfig struct {
Enabled bool `json:"enabled" env:"ENABLED"`
}
type BraveConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
@@ -564,12 +571,13 @@ type GLMSearchConfig struct {
}
type WebToolsConfig struct {
Brave BraveConfig `json:"brave"`
Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
SearXNG SearXNGConfig `json:"searxng"`
GLMSearch GLMSearchConfig `json:"glm_search"`
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"`
Brave BraveConfig ` json:"brave"`
Tavily TavilyConfig ` json:"tavily"`
DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"`
Perplexity PerplexityConfig ` json:"perplexity"`
SearXNG SearXNGConfig ` json:"searxng"`
GLMSearch GLMSearchConfig ` json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
@@ -577,19 +585,28 @@ type WebToolsConfig struct {
}
type CronToolsConfig struct {
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
}
type ExecConfig struct {
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"`
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
}
type SkillsToolsConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"`
Registries SkillsRegistriesConfig ` json:"registries"`
MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"`
SearchCache SearchCacheConfig ` json:"search_cache"`
}
type MediaCleanupConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"`
MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"`
Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"`
ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"`
MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"`
Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"`
}
type ToolsConfig struct {
@@ -601,12 +618,19 @@ type ToolsConfig struct {
Skills SkillsToolsConfig `json:"skills"`
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
MCP MCPConfig `json:"mcp"`
}
type SkillsToolsConfig struct {
Registries SkillsRegistriesConfig `json:"registries"`
MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
SearchCache SearchCacheConfig `json:"search_cache"`
AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"`
EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"`
FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"`
I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"`
InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"`
}
type SearchCacheConfig struct {
@@ -652,8 +676,7 @@ type MCPServerConfig struct {
// MCPConfig defines configuration for all MCP servers
type MCPConfig struct {
// Enabled globally enables/disables MCP integration
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
// Servers is a map of server name to server configuration
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
}
@@ -839,3 +862,48 @@ func (c *Config) ValidateModelList() error {
}
return nil
}
func (t *ToolsConfig) IsToolEnabled(name string) bool {
switch name {
case "web":
return t.Web.Enabled
case "cron":
return t.Cron.Enabled
case "exec":
return t.Exec.Enabled
case "skills":
return t.Skills.Enabled
case "media_cleanup":
return t.MediaCleanup.Enabled
case "append_file":
return t.AppendFile.Enabled
case "edit_file":
return t.EditFile.Enabled
case "find_skills":
return t.FindSkills.Enabled
case "i2c":
return t.I2C.Enabled
case "install_skill":
return t.InstallSkill.Enabled
case "list_dir":
return t.ListDir.Enabled
case "message":
return t.Message.Enabled
case "read_file":
return t.ReadFile.Enabled
case "spawn":
return t.Spawn.Enabled
case "spi":
return t.SPI.Enabled
case "subagent":
return t.Subagent.Enabled
case "web_fetch":
return t.WebFetch.Enabled
case "write_file":
return t.WriteFile.Enabled
case "mcp":
return t.MCP.Enabled
default:
return true
}
}
+71 -2
View File
@@ -308,6 +308,20 @@ func DefaultConfig() *Config {
APIKey: "",
},
// Avian - https://avian.io
{
ModelName: "deepseek-v3.2",
Model: "avian/deepseek/deepseek-v3.2",
APIBase: "https://api.avian.io/v1",
APIKey: "",
},
{
ModelName: "kimi-k2.5",
Model: "avian/moonshotai/kimi-k2.5",
APIBase: "https://api.avian.io/v1",
APIKey: "",
},
// VLLM (local) - http://localhost:8000
{
ModelName: "local-model",
@@ -322,11 +336,16 @@ func DefaultConfig() *Config {
},
Tools: ToolsConfig{
MediaCleanup: MediaCleanupConfig{
Enabled: true,
ToolConfig: ToolConfig{
Enabled: true,
},
MaxAge: 30,
Interval: 5,
},
Web: WebToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
Proxy: "",
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
Brave: BraveConfig{
@@ -357,12 +376,21 @@ func DefaultConfig() *Config {
},
},
Cron: CronToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
ExecTimeoutMinutes: 5,
},
Exec: ExecConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
EnableDenyPatterns: true,
},
Skills: SkillsToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
Registries: SkillsRegistriesConfig{
ClawHub: ClawHubRegistryConfig{
Enabled: true,
@@ -376,9 +404,50 @@ func DefaultConfig() *Config {
},
},
MCP: MCPConfig{
Enabled: false,
ToolConfig: ToolConfig{
Enabled: false,
},
Servers: map[string]MCPServerConfig{},
},
AppendFile: ToolConfig{
Enabled: true,
},
EditFile: ToolConfig{
Enabled: true,
},
FindSkills: ToolConfig{
Enabled: true,
},
I2C: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
InstallSkill: ToolConfig{
Enabled: true,
},
ListDir: ToolConfig{
Enabled: true,
},
Message: ToolConfig{
Enabled: true,
},
ReadFile: ToolConfig{
Enabled: true,
},
Spawn: ToolConfig{
Enabled: true,
},
SPI: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
Subagent: ToolConfig{
Enabled: true,
},
WebFetch: ToolConfig{
Enabled: true,
},
WriteFile: ToolConfig{
Enabled: true,
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
+17
View File
@@ -373,6 +373,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
{
providerNames: []string{"avian"},
protocol: "avian",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Avian.APIKey == "" && p.Avian.APIBase == "" {
return ModelConfig{}, false
}
return ModelConfig{
ModelName: "avian",
Model: "avian/deepseek/deepseek-v3.2",
APIKey: p.Avian.APIKey,
APIBase: p.Avian.APIBase,
Proxy: p.Avian.Proxy,
RequestTimeout: p.Avian.RequestTimeout,
}, true
},
},
}
// Process each provider migration
+4 -3
View File
@@ -160,14 +160,15 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key17"},
Mistral: ProviderConfig{APIKey: "key18"},
Avian: ProviderConfig{APIKey: "key19"},
},
}
result := ConvertProvidersToModelList(cfg)
// All 19 providers should be converted
if len(result) != 19 {
t.Errorf("len(result) = %d, want 19", len(result))
// All 20 providers should be converted
if len(result) != 20 {
t.Errorf("len(result) = %d, want 20", len(result))
}
}
+13 -3
View File
@@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
mgr := NewManager()
mcpCfg := config.MCPConfig{
Enabled: true,
ToolConfig: config.ToolConfig{
Enabled: true,
},
Servers: map[string]config.MCPServerConfig{
"test-server": {
Enabled: true,
@@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) {
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
mgr := NewManager()
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
err := mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
}
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
err = mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when no servers configured, got: %v", err)
}
+79
View File
@@ -31,6 +31,9 @@ type Provider struct {
baseURL string
}
// SupportsThinking implements providers.ThinkingCapable.
func (p *Provider) SupportsThinking() bool { return true }
func NewProvider(token string) *Provider {
return NewProviderWithBaseURL(token, "")
}
@@ -182,9 +185,80 @@ func buildParams(
params.Tools = translateTools(tools)
}
// Extended Thinking / Adaptive Thinking
// The thinking_level value directly determines the API parameter format:
// "adaptive" → {thinking: {type: "adaptive"}} + output_config.effort
// "low/medium/high/xhigh" → {thinking: {type: "enabled", budget_tokens: N}}
if level, ok := options["thinking_level"].(string); ok && level != "" && level != "off" {
applyThinkingConfig(&params, level)
}
return params, nil
}
// applyThinkingConfig sets thinking parameters based on the level value.
// "adaptive" uses the adaptive thinking API (Claude 4.6+).
// All other levels use budget_tokens which is universally supported.
//
// Anthropic API constraint: temperature must not be set when thinking is enabled.
// budget_tokens must be strictly less than max_tokens.
func applyThinkingConfig(params *anthropic.MessageNewParams, level string) {
// Anthropic API rejects requests with temperature set alongside thinking.
// Reset to zero value (omitted from JSON serialization).
if params.Temperature.Valid() {
log.Printf("anthropic: temperature cleared because thinking is enabled (level=%s)", level)
}
params.Temperature = anthropic.MessageNewParams{}.Temperature
if level == "adaptive" {
adaptive := anthropic.NewThinkingConfigAdaptiveParam()
params.Thinking = anthropic.ThinkingConfigParamUnion{OfAdaptive: &adaptive}
params.OutputConfig = anthropic.OutputConfigParam{
Effort: anthropic.OutputConfigEffortHigh,
}
return
}
budget := int64(levelToBudget(level))
if budget <= 0 {
return
}
// budget_tokens must be < max_tokens; clamp to respect user's max_tokens setting.
if budget >= params.MaxTokens {
log.Printf("anthropic: budget_tokens (%d) clamped to %d (max_tokens-1)", budget, params.MaxTokens-1)
budget = params.MaxTokens - 1
} else if budget > params.MaxTokens*80/100 {
log.Printf("anthropic: thinking budget (%d) exceeds 80%% of max_tokens (%d), output may be truncated",
budget, params.MaxTokens)
}
params.Thinking = anthropic.ThinkingConfigParamOfEnabled(budget)
}
// levelToBudget maps a thinking level to budget_tokens.
// Values are based on Anthropic's recommendations and community best practices:
//
// low = 4,096 — simple reasoning, quick debugging (Claude Code "think")
// medium = 16,384 — Anthropic recommended sweet spot for most tasks
// high = 32,000 — complex architecture, deep analysis (diminishing returns above this)
// xhigh = 64,000 — extreme reasoning, research problems, benchmarks
//
// Note: For Claude 4.6+, prefer adaptive thinking over manual budget_tokens.
func levelToBudget(level string) int {
switch level {
case "low":
return 4096
case "medium":
return 16384
case "high":
return 32000
case "xhigh":
return 64000
default:
return 0
}
}
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
@@ -213,10 +287,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
func parseResponse(resp *anthropic.Message) *LLMResponse {
var content strings.Builder
var reasoning strings.Builder
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "thinking":
tb := block.AsThinking()
reasoning.WriteString(tb.Thinking)
case "text":
tb := block.AsText()
content.WriteString(tb.Text)
@@ -247,6 +325,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
return &LLMResponse{
Content: content.String(),
Reasoning: reasoning.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
+212
View File
@@ -0,0 +1,212 @@
package anthropicprovider
import (
"encoding/json"
"testing"
"github.com/anthropics/anthropic-sdk-go"
)
func TestApplyThinkingConfig_Adaptive(t *testing.T) {
params := anthropic.MessageNewParams{
MaxTokens: 16000,
Temperature: anthropic.Float(0.7),
}
applyThinkingConfig(&params, "adaptive")
if params.Thinking.OfAdaptive == nil {
t.Fatal("expected adaptive thinking")
}
if params.Thinking.OfEnabled != nil {
t.Error("should not set enabled thinking in adaptive mode")
}
if params.OutputConfig.Effort != anthropic.OutputConfigEffortHigh {
t.Errorf("effort = %q, want %q", params.OutputConfig.Effort, anthropic.OutputConfigEffortHigh)
}
if params.Temperature.Valid() {
t.Error("temperature should be cleared when thinking is enabled")
}
}
func TestApplyThinkingConfig_BudgetLevels(t *testing.T) {
tests := []struct {
level string
wantBudget int64
}{
{"low", 4096},
{"medium", 16384},
{"high", 32000},
{"xhigh", 64000},
}
for _, tt := range tests {
t.Run(tt.level, func(t *testing.T) {
params := anthropic.MessageNewParams{
MaxTokens: 200000,
Temperature: anthropic.Float(0.5),
}
applyThinkingConfig(&params, tt.level)
if params.Thinking.OfEnabled == nil {
t.Fatal("expected enabled thinking")
}
if params.Thinking.OfAdaptive != nil {
t.Error("should not set adaptive thinking")
}
if params.Thinking.OfEnabled.BudgetTokens != tt.wantBudget {
t.Errorf("budget_tokens = %d, want %d", params.Thinking.OfEnabled.BudgetTokens, tt.wantBudget)
}
if params.OutputConfig.Effort != "" {
t.Errorf("effort = %q, want empty", params.OutputConfig.Effort)
}
if params.Temperature.Valid() {
t.Error("temperature should be cleared when thinking is enabled")
}
})
}
}
func TestApplyThinkingConfig_BudgetClamp(t *testing.T) {
// budget_tokens must be < max_tokens; clamp budget down to respect user's max_tokens.
params := anthropic.MessageNewParams{MaxTokens: 4096}
applyThinkingConfig(&params, "high") // budget=32000 > maxTokens=4096
if params.Thinking.OfEnabled == nil {
t.Fatal("expected enabled thinking")
}
if params.Thinking.OfEnabled.BudgetTokens != 4095 {
t.Errorf("budget_tokens = %d, want 4095 (maxTokens-1)", params.Thinking.OfEnabled.BudgetTokens)
}
if params.MaxTokens != 4096 {
t.Errorf("max_tokens should not be modified, got %d", params.MaxTokens)
}
}
func TestApplyThinkingConfig_UnknownLevel(t *testing.T) {
params := anthropic.MessageNewParams{MaxTokens: 16000}
applyThinkingConfig(&params, "unknown")
if params.Thinking.OfEnabled != nil {
t.Error("should not set enabled thinking for unknown level")
}
if params.Thinking.OfAdaptive != nil {
t.Error("should not set adaptive thinking for unknown level")
}
}
func TestLevelToBudget(t *testing.T) {
tests := []struct {
name string
level string
want int
}{
{"low", "low", 4096},
{"medium", "medium", 16384},
{"high", "high", 32000},
{"xhigh", "xhigh", 64000},
{"off", "off", 0},
{"empty", "", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := levelToBudget(tt.level); got != tt.want {
t.Errorf("levelToBudget(%q) = %d, want %d", tt.level, got, tt.want)
}
})
}
}
func TestBuildParams_ThinkingClearsTemperature(t *testing.T) {
msgs := []Message{{Role: "user", Content: "hello"}}
opts := map[string]any{
"max_tokens": 200000,
"temperature": 0.8,
"thinking_level": "medium",
}
params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts)
if err != nil {
t.Fatal(err)
}
if params.Temperature.Valid() {
t.Error("temperature should be cleared when thinking_level is set")
}
if params.Thinking.OfEnabled == nil {
t.Fatal("expected enabled thinking")
}
if params.Thinking.OfEnabled.BudgetTokens != 16384 {
t.Errorf("budget_tokens = %d, want 16384", params.Thinking.OfEnabled.BudgetTokens)
}
}
// unmarshalBlocks constructs []ContentBlockUnion via JSON round-trip so that
// the internal JSON.raw field is populated (required by AsText/AsThinking).
func unmarshalBlocks(t *testing.T, jsonStr string) []anthropic.ContentBlockUnion {
t.Helper()
var blocks []anthropic.ContentBlockUnion
if err := json.Unmarshal([]byte(jsonStr), &blocks); err != nil {
t.Fatalf("unmarshalBlocks: %v", err)
}
return blocks
}
func TestParseResponse_ThinkingBlock(t *testing.T) {
resp := &anthropic.Message{
Content: unmarshalBlocks(t, `[
{"type":"thinking","thinking":"Let me reason step by step...","signature":"sig"},
{"type":"text","text":"The answer is 42."}
]`),
StopReason: anthropic.StopReasonEndTurn,
}
result := parseResponse(resp)
if result.Reasoning != "Let me reason step by step..." {
t.Errorf("Reasoning = %q, want thinking content", result.Reasoning)
}
if result.Content != "The answer is 42." {
t.Errorf("Content = %q, want text content", result.Content)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", result.FinishReason)
}
}
func TestParseResponse_NoThinkingBlock(t *testing.T) {
resp := &anthropic.Message{
Content: unmarshalBlocks(t, `[
{"type":"text","text":"Just a normal response."}
]`),
StopReason: anthropic.StopReasonEndTurn,
}
result := parseResponse(resp)
if result.Reasoning != "" {
t.Errorf("Reasoning = %q, want empty", result.Reasoning)
}
if result.Content != "Just a normal response." {
t.Errorf("Content = %q, want text content", result.Content)
}
}
func TestBuildParams_NoThinkingKeepsTemperature(t *testing.T) {
msgs := []Message{{Role: "user", Content: "hello"}}
opts := map[string]any{
"temperature": 0.8,
}
params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts)
if err != nil {
t.Fatal(err)
}
if !params.Temperature.Valid() {
t.Error("temperature should be preserved when thinking is not set")
}
if params.Temperature.Value != 0.8 {
t.Errorf("temperature = %f, want 0.8", params.Temperature.Value)
}
}
+16
View File
@@ -181,6 +181,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.model = "deepseek-chat"
}
}
case "avian":
if cfg.Providers.Avian.APIKey != "" {
sel.apiKey = cfg.Providers.Avian.APIKey
sel.apiBase = cfg.Providers.Avian.APIBase
sel.proxy = cfg.Providers.Avian.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.avian.io/v1"
}
}
case "mistral":
if cfg.Providers.Mistral.APIKey != "" {
sel.apiKey = cfg.Providers.Mistral.APIKey
@@ -300,6 +309,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if sel.apiBase == "" {
sel.apiBase = "https://api.mistral.ai/v1"
}
case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "":
sel.apiKey = cfg.Providers.Avian.APIKey
sel.apiBase = cfg.Providers.Avian.APIBase
sel.proxy = cfg.Providers.Avian.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.avian.io/v1"
}
case cfg.Providers.VLLM.APIBase != "":
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
+3 -1
View File
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"volcengine", "vllm", "qwen", "mistral":
"volcengine", "vllm", "qwen", "mistral", "avian":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -208,6 +208,8 @@ func getDefaultAPIBase(protocol string) string {
return "http://localhost:8000/v1"
case "mistral":
return "https://api.mistral.ai/v1"
case "avian":
return "https://api.avian.io/v1"
default:
return ""
}
+8 -6
View File
@@ -323,12 +323,14 @@ func serializeMessages(messages []Message) []any {
})
}
for _, mediaURL := range m.Media {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
if strings.HasPrefix(mediaURL, "data:image/") {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
}
msg := map[string]any{
+7
View File
@@ -37,6 +37,13 @@ type StatefulProvider interface {
Close()
}
// ThinkingCapable is an optional interface for providers that support
// extended thinking (e.g. Anthropic). Used by the agent loop to warn
// when thinking_level is configured but the active provider cannot use it.
type ThinkingCapable interface {
SupportsThinking() bool
}
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string
+64 -16
View File
@@ -259,15 +259,7 @@ func (c *ClawHubRegistry) DownloadAndInstall(
}
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String())
if err != nil {
return nil, fmt.Errorf("download failed: %w", err)
}
@@ -284,17 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall(
// --- HTTP helper ---
func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
req, err := c.newGetRequest(ctx, urlStr, "application/json")
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
resp, err := c.client.Do(req)
resp, err := utils.DoRequestWithRetry(c.client, req)
if err != nil {
return nil, err
}
@@ -312,3 +299,64 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err
return body, nil
}
func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", accept)
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
return req, nil
}
func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) {
req, err := c.newGetRequest(ctx, urlStr, "application/zip")
if err != nil {
return "", err
}
resp, err := utils.DoRequestWithRetry(c.client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
errBody := make([]byte, 512)
n, _ := io.ReadFull(resp.Body, errBody)
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
}
tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
cleanup := func() {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1)
written, err := io.Copy(tmpFile, src)
if err != nil {
cleanup()
return "", fmt.Errorf("download write failed: %w", err)
}
if written > int64(c.maxZipSize) {
cleanup()
return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize)
}
if err := tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Errorf("failed to close temp file: %w", err)
}
return tmpPath, nil
}
+81
View File
@@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) {
assert.Equal(t, "clawhub", results[0].RegistryName)
}
func TestClawHubRegistrySearchRetries429(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
slug := "github"
name := "GitHub Integration"
summary := "Interact with GitHub repos"
version := "1.0.0"
json.NewEncoder(w).Encode(clawhubSearchResponse{
Results: []clawhubSearchResult{
{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
},
})
}))
defer srv.Close()
reg := newTestRegistry(srv.URL, "")
results, err := reg.Search(context.Background(), "github", 5)
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, 2, attempts)
assert.Equal(t, "github", results[0].Slug)
}
func TestClawHubRegistryGetSkillMeta(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
@@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
assert.Contains(t, string(readmeContent), "# Test Skill")
}
func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) {
zipBuf := createTestZip(t, map[string]string{
"SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill",
})
downloadAttempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/skills/retry-skill":
json.NewEncoder(w).Encode(clawhubSkillResponse{
Slug: "retry-skill",
DisplayName: "Retry Skill",
Summary: "A retry test skill",
LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
})
case "/api/v1/download":
downloadAttempts++
if downloadAttempts == 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
assert.Equal(t, "retry-skill", r.URL.Query().Get("slug"))
w.Header().Set("Content-Type", "application/zip")
w.Write(zipBuf)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
tmpDir := t.TempDir()
targetDir := filepath.Join(tmpDir, "retry-skill")
reg := newTestRegistry(srv.URL, "")
result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, "1.0.0", result.Version)
assert.Equal(t, 2, downloadAttempts)
skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
require.NoError(t, err)
assert.Contains(t, string(skillContent), "Hello skill")
}
func TestClawHubRegistryAuthToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
+50 -38
View File
@@ -10,11 +10,38 @@ type Tool interface {
Execute(ctx context.Context, args map[string]any) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
// to receive the current message context (channel, chatID)
type ContextualTool interface {
Tool
SetContext(channel, chatID string)
// --- Request-scoped tool context (channel / chatID) ---
//
// Carried via context.Value so that concurrent tool calls each receive
// their own immutable copy — no mutable state on singleton tool instances.
//
// Keys are unexported pointer-typed vars — guaranteed collision-free,
// and only accessible through the helper functions below.
type toolCtxKey struct{ name string }
var (
ctxKeyChannel = &toolCtxKey{"channel"}
ctxKeyChatID = &toolCtxKey{"chatID"}
)
// WithToolContext returns a child context carrying channel and chatID.
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
ctx = context.WithValue(ctx, ctxKeyChannel, channel)
ctx = context.WithValue(ctx, ctxKeyChatID, chatID)
return ctx
}
// ToolChannel extracts the channel from ctx, or "" if unset.
func ToolChannel(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChannel).(string)
return v
}
// ToolChatID extracts the chatID from ctx, or "" if unset.
func ToolChatID(ctx context.Context) string {
v, _ := ctx.Value(ctxKeyChatID).(string)
return v
}
// AsyncCallback is a function type that async tools use to notify completion.
@@ -22,51 +49,36 @@ type ContextualTool interface {
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// AsyncExecutor is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
// Unlike the old AsyncTool pattern (SetCallback + Execute), AsyncExecutor
// receives the callback as a parameter of ExecuteAsync. This eliminates the
// data race where concurrent calls could overwrite each other's callbacks
// on a shared tool instance.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
// go func() {
// result := t.runSubagent(ctx, args)
// if cb != nil { cb(ctx, result) }
// }()
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
type AsyncExecutor interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
// ExecuteAsync runs the tool asynchronously. The callback cb will be
// invoked (possibly from another goroutine) when the async operation
// completes. cb is guaranteed to be non-nil by the caller (registry).
ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult
}
func ToolToSchema(tool Tool) map[string]any {
+4 -18
View File
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -24,9 +23,6 @@ type CronTool struct {
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
channel string
chatID string
mu sync.RWMutex
}
// NewCronTool creates a new CronTool
@@ -102,14 +98,6 @@ func (t *CronTool) Parameters() map[string]any {
}
}
// SetContext sets the current session context for job creation
func (t *CronTool) SetContext(channel, chatID string) {
t.mu.Lock()
defer t.mu.Unlock()
t.channel = channel
t.chatID = chatID
}
// Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
action, ok := args["action"].(string)
@@ -119,7 +107,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
switch action {
case "add":
return t.addJob(args)
return t.addJob(ctx, args)
case "list":
return t.listJobs()
case "remove":
@@ -133,11 +121,9 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult
}
}
func (t *CronTool) addJob(args map[string]any) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
t.mu.RUnlock()
func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult {
channel := ToolChannel(ctx)
chatID := ToolChatID(ctx)
if channel == "" || chatID == "" {
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
+8 -10
View File
@@ -9,10 +9,8 @@ import (
type SendCallback func(channel, chatID, content string) error
type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
sendCallback SendCallback
sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -48,10 +46,10 @@ func (t *MessageTool) Parameters() map[string]any {
}
}
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound.Store(false) // Reset send tracking for new processing round
// ResetSentInRound resets the per-round send tracker.
// Called by the agent loop at the start of each inbound message processing round.
func (t *MessageTool) ResetSentInRound() {
t.sentInRound.Store(false)
}
// HasSentInRound returns true if the message tool sent a message during the current round.
@@ -73,10 +71,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
chatID, _ := args["chat_id"].(string)
if channel == "" {
channel = t.defaultChannel
channel = ToolChannel(ctx)
}
if chatID == "" {
chatID = t.defaultChatID
chatID = ToolChatID(ctx)
}
if channel == "" || chatID == "" {
+6 -11
View File
@@ -8,7 +8,6 @@ import (
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
@@ -18,7 +17,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
return nil
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Hello, world!",
}
@@ -60,7 +59,6 @@ func TestMessageTool_Execute_Success(t *testing.T) {
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
@@ -69,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
return nil
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id")
args := map[string]any{
"content": "Test message",
"channel": "custom-channel",
@@ -96,14 +94,13 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
@@ -133,9 +130,8 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{} // content missing
result := tool.Execute(ctx, args)
@@ -151,7 +147,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
// No WithToolContext — channel/chatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
@@ -175,10 +171,9 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id")
args := map[string]any{
"content": "Test message",
}
+15 -13
View File
@@ -45,8 +45,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string
}
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
// If the tool implements AsyncExecutor and a non-nil callback is provided,
// ExecuteAsync is called instead of Execute — the callback is a parameter,
// never stored as mutable state on the tool.
func (r *ToolRegistry) ExecuteWithContext(
ctx context.Context,
name string,
@@ -69,22 +70,23 @@ func (r *ToolRegistry) ExecuteWithContext(
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
}
// If tool implements ContextualTool, set context
if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" {
contextualTool.SetContext(channel, chatID)
}
// Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx).
// Always inject — tools validate what they require.
ctx = WithToolContext(ctx, channel, chatID)
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
// If tool implements AsyncExecutor and callback is provided, use ExecuteAsync.
// The callback is a call parameter, not mutable state on the tool instance.
var result *ToolResult
start := time.Now()
if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
map[string]any{
"tool": name,
})
result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
} else {
result = tool.Execute(ctx, args)
}
start := time.Now()
result := tool.Execute(ctx, args)
duration := time.Since(start)
// Log based on result type
+32 -22
View File
@@ -25,24 +25,24 @@ func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolRes
return m.result
}
type mockCtxTool struct {
type mockContextAwareTool struct {
mockRegistryTool
channel string
chatID string
lastCtx context.Context
}
func (m *mockCtxTool) SetContext(channel, chatID string) {
m.channel = channel
m.chatID = chatID
func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *ToolResult {
m.lastCtx = ctx
return m.result
}
type mockAsyncRegistryTool struct {
mockRegistryTool
cb AsyncCallback
lastCB AsyncCallback
}
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
m.cb = cb
func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
m.lastCB = cb
return m.result
}
// --- helpers ---
@@ -136,34 +136,44 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) {
}
}
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
func TestToolRegistry_ExecuteWithContext_InjectsToolContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
ct := &mockContextAwareTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
if ct.channel != "telegram" {
t.Errorf("expected channel 'telegram', got %q", ct.channel)
if ct.lastCtx == nil {
t.Fatal("expected Execute to be called")
}
if ct.chatID != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
if got := ToolChannel(ct.lastCtx); got != "telegram" {
t.Errorf("expected channel 'telegram', got %q", got)
}
if got := ToolChatID(ct.lastCtx); got != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", got)
}
}
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
ct := &mockContextAwareTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
if ct.channel != "" || ct.chatID != "" {
t.Error("SetContext should not be called with empty channel/chatID")
if ct.lastCtx == nil {
t.Fatal("expected Execute to be called")
}
// Empty values are still injected; tools decide what to do with them.
if got := ToolChannel(ct.lastCtx); got != "" {
t.Errorf("expected empty channel, got %q", got)
}
if got := ToolChatID(ct.lastCtx); got != "" {
t.Errorf("expected empty chatID, got %q", got)
}
}
@@ -179,14 +189,14 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
cb := func(_ context.Context, _ *ToolResult) { called = true }
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
if at.cb == nil {
t.Error("expected SetCallback to have been called")
if at.lastCB == nil {
t.Error("expected ExecuteAsync to have received a callback")
}
if !result.Async {
t.Error("expected async result")
}
at.cb(context.Background(), SilentResult("done"))
at.lastCB(context.Background(), SilentResult("done"))
if !called {
t.Error("expected callback to be invoked")
}
+27 -17
View File
@@ -8,25 +8,18 @@ import (
type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
allowlistCheck func(targetAgentID string) bool
callback AsyncCallback // For async completion notification
}
// Compile-time check: SpawnTool implements AsyncExecutor.
var _ AsyncExecutor = (*SpawnTool)(nil)
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
return &SpawnTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
manager: manager,
}
}
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string {
return "spawn"
}
@@ -56,16 +49,21 @@ func (t *SpawnTool) Parameters() map[string]any {
}
}
func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
t.allowlistCheck = check
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
return t.execute(ctx, args, nil)
}
// ExecuteAsync implements AsyncExecutor. The callback is passed through to the
// subagent manager as a call parameter — never stored on the SpawnTool instance.
func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
return t.execute(ctx, args, cb)
}
func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
task, ok := args["task"].(string)
if !ok || strings.TrimSpace(task) == "" {
return ErrorResult("task is required and must be a non-empty string")
@@ -85,8 +83,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul
return ErrorResult("Subagent manager not configured")
}
// Read channel/chatID from context (injected by registry).
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSpawnTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
}
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
+14 -12
View File
@@ -252,16 +252,12 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
manager *SubagentManager
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
manager: manager,
}
}
@@ -290,11 +286,6 @@ func (t *SubagentTool) Parameters() map[string]any {
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
task, ok := args["task"].(string)
if !ok {
@@ -341,13 +332,24 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
}
}
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSubagentTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
}
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: llmOptions,
}, messages, t.originChannel, t.originChatID)
}, messages, channel, chatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
+3 -21
View File
@@ -50,9 +50,8 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager.SetLLMOptions(2048, 0.6)
tool := NewSubagentTool(manager)
tool.SetContext("cli", "direct")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "cli", "direct")
args := map[string]any{"task": "Do something"}
result := tool.Execute(ctx, args)
@@ -147,28 +146,14 @@ func TestSubagentTool_Parameters(t *testing.T) {
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
args := map[string]any{
"task": "Write a haiku about coding",
"label": "haiku-task",
@@ -297,12 +282,9 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
ctx := WithToolContext(context.Background(), channel, chatID)
args := map[string]any{
"task": "Test context passing",
}