diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go new file mode 100644 index 000000000..54a5396e7 --- /dev/null +++ b/pkg/agent/instance.go @@ -0,0 +1,145 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// AgentInstance represents a fully configured agent with its own workspace, +// session manager, context builder, and tool registry. +type AgentInstance struct { + ID string + Name string + Model string + Fallbacks []string + Workspace string + MaxIterations int + ContextWindow int + Provider providers.LLMProvider + Sessions *session.SessionManager + ContextBuilder *ContextBuilder + Tools *tools.ToolRegistry + Subagents *config.SubagentsConfig + SkillsFilter []string + Candidates []providers.FallbackCandidate +} + +// NewAgentInstance creates an agent instance from config. +func NewAgentInstance( + agentCfg *config.AgentConfig, + defaults *config.AgentDefaults, + cfg *config.Config, + provider providers.LLMProvider, +) *AgentInstance { + workspace := resolveAgentWorkspace(agentCfg, defaults) + os.MkdirAll(workspace, 0755) + + model := resolveAgentModel(agentCfg, defaults) + fallbacks := resolveAgentFallbacks(agentCfg, defaults) + + restrict := defaults.RestrictToWorkspace + toolsRegistry := tools.NewToolRegistry() + toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) + toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg)) + toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) + + sessionsDir := filepath.Join(workspace, "sessions") + sessionsManager := session.NewSessionManager(sessionsDir) + + contextBuilder := NewContextBuilder(workspace) + contextBuilder.SetToolsRegistry(toolsRegistry) + + agentID := routing.DefaultAgentID + agentName := "" + var subagents *config.SubagentsConfig + var skillsFilter []string + + if agentCfg != nil { + agentID = routing.NormalizeAgentID(agentCfg.ID) + agentName = agentCfg.Name + subagents = agentCfg.Subagents + skillsFilter = agentCfg.Skills + } + + maxIter := defaults.MaxToolIterations + if maxIter == 0 { + maxIter = 20 + } + + // Resolve fallback candidates + modelCfg := providers.ModelConfig{ + Primary: model, + Fallbacks: fallbacks, + } + candidates := providers.ResolveCandidates(modelCfg, defaults.Provider) + + return &AgentInstance{ + ID: agentID, + Name: agentName, + Model: model, + Fallbacks: fallbacks, + Workspace: workspace, + MaxIterations: maxIter, + ContextWindow: defaults.MaxTokens, + Provider: provider, + Sessions: sessionsManager, + ContextBuilder: contextBuilder, + Tools: toolsRegistry, + Subagents: subagents, + SkillsFilter: skillsFilter, + Candidates: candidates, + } +} + +// resolveAgentWorkspace determines the workspace directory for an agent. +func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" { + return expandHome(strings.TrimSpace(agentCfg.Workspace)) + } + if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" { + return expandHome(defaults.Workspace) + } + home, _ := os.UserHomeDir() + id := routing.NormalizeAgentID(agentCfg.ID) + return filepath.Join(home, ".picoclaw", "workspace-"+id) +} + +// resolveAgentModel resolves the primary model for an agent. +func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" { + return strings.TrimSpace(agentCfg.Model.Primary) + } + return defaults.Model +} + +// resolveAgentFallbacks resolves the fallback models for an agent. +func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string { + if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil { + return agentCfg.Model.Fallbacks + } + return defaults.ModelFallbacks +} + +func expandHome(path string) string { + if path == "" { + return path + } + if path[0] == '~' { + home, _ := os.UserHomeDir() + if len(path) > 1 && path[1] == '/' { + return home + path[1:] + } + return home + } + return path +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8c6c58c96..ed69712ff 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,8 +10,6 @@ import ( "context" "encoding/json" "fmt" - "os" - "path/filepath" "strings" "sync" "sync/atomic" @@ -24,7 +22,7 @@ import ( "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" - "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" @@ -32,17 +30,12 @@ import ( type AgentLoop struct { bus *bus.MessageBus - provider providers.LLMProvider - workspace string - model string - contextWindow int // Maximum context window size in tokens - maxIterations int - sessions *session.SessionManager + cfg *config.Config + registry *AgentRegistry state *state.Manager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry running atomic.Bool - summarizing sync.Map // Tracks which sessions are currently being summarized + summarizing sync.Map + fallback *providers.FallbackChain channelManager *channels.Manager } @@ -58,99 +51,83 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -// createToolRegistry creates a tool registry with common tools. -// This is shared between main agent and subagents. -func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { - registry := tools.NewToolRegistry() - - // File system tools - registry.Register(tools.NewReadFileTool(workspace, restrict)) - registry.Register(tools.NewWriteFileTool(workspace, restrict)) - registry.Register(tools.NewListDirTool(workspace, restrict)) - registry.Register(tools.NewEditFileTool(workspace, restrict)) - registry.Register(tools.NewAppendFileTool(workspace, restrict)) - - // Shell execution - registry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg)) - - if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKey: cfg.Tools.Web.Brave.APIKey, - BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, - BraveEnabled: cfg.Tools.Web.Brave.Enabled, - DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, - DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, - PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, - PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, - PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, - }); searchTool != nil { - registry.Register(searchTool) - } - registry.Register(tools.NewWebFetchTool(50000)) - - // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms - registry.Register(tools.NewI2CTool()) - registry.Register(tools.NewSPITool()) - - // Message tool - available to both agent and subagent - // Subagent uses it to communicate directly with user - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, - }) - return nil - }) - registry.Register(messageTool) - - return registry -} - func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) + registry := NewAgentRegistry(cfg, provider) - restrict := cfg.Agents.Defaults.RestrictToWorkspace + // Register shared tools to all agents + registerSharedTools(cfg, msgBus, registry, provider) - // Create tool registry for main agent - toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + // Set up shared fallback chain + cooldown := providers.NewCooldownTracker() + fallbackChain := providers.NewFallbackChain(cooldown) - // Create subagent manager with its own tool registry - subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) - subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) - // Subagent doesn't need spawn/subagent tools to avoid recursion - subagentManager.SetTools(subagentTools) - - // Register spawn tool (for main agent) - spawnTool := tools.NewSpawnTool(subagentManager) - toolsRegistry.Register(spawnTool) - - // Register subagent tool (synchronous execution) - subagentTool := tools.NewSubagentTool(subagentManager) - toolsRegistry.Register(subagentTool) - - sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) - - // Create state manager for atomic state persistence - stateManager := state.NewManager(workspace) - - // Create context builder and set tools registry - contextBuilder := NewContextBuilder(workspace) - contextBuilder.SetToolsRegistry(toolsRegistry) + // Create state manager using default agent's workspace for channel recording + defaultAgent := registry.GetDefaultAgent() + var stateManager *state.Manager + if defaultAgent != nil { + stateManager = state.NewManager(defaultAgent.Workspace) + } return &AgentLoop{ - bus: msgBus, - provider: provider, - workspace: workspace, - model: cfg.Agents.Defaults.Model, - contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - state: stateManager, - contextBuilder: contextBuilder, - tools: toolsRegistry, - summarizing: sync.Map{}, + bus: msgBus, + cfg: cfg, + registry: registry, + state: stateManager, + summarizing: sync.Map{}, + fallback: fallbackChain, + } +} + +// registerSharedTools registers tools that are shared across all agents (web, message, spawn). +func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider) { + for _, agentID := range registry.ListAgentIDs() { + agent, ok := registry.GetAgent(agentID) + if !ok { + continue + } + + // Web tools + if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, + PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + }); searchTool != nil { + agent.Tools.Register(searchTool) + } + agent.Tools.Register(tools.NewWebFetchTool(50000)) + + // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms + agent.Tools.Register(tools.NewI2CTool()) + agent.Tools.Register(tools.NewSPITool()) + + // Message tool + messageTool := tools.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + return nil + }) + agent.Tools.Register(messageTool) + + // Spawn tool with allowlist checker + subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + spawnTool := tools.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) + + // Update context builder with the complete tools registry + agent.ContextBuilder.SetToolsRegistry(agent.Tools) } } @@ -175,10 +152,14 @@ func (al *AgentLoop) Run(ctx context.Context) error { if response != "" { // Check if the message tool already sent a response during this round. // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). alreadySent := false - if tool, ok := al.tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } } } @@ -201,7 +182,11 @@ func (al *AgentLoop) Stop() { } func (al *AgentLoop) RegisterTool(tool tools.Tool) { - al.tools.Register(tool) + for _, agentID := range al.registry.ListAgentIDs() { + if agent, ok := al.registry.GetAgent(agentID); ok { + agent.Tools.Register(tool) + } + } } func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { @@ -211,12 +196,18 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { + if al.state == nil { + return nil + } return al.state.SetLastChannel(channel) } // RecordLastChatID records the last active chat ID for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChatID(chatID string) error { + if al.state == nil { + return nil + } return al.state.SetLastChatID(chatID) } @@ -239,7 +230,8 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess // ProcessHeartbeat processes a heartbeat request without session history. // Each heartbeat is independent and doesn't accumulate context. func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { - return al.runAgentLoop(ctx, processOptions{ + agent := al.registry.GetDefaultAgent() + return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: "heartbeat", Channel: channel, ChatID: chatID, @@ -277,9 +269,36 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return response, nil } - // Process as user message - return al.runAgentLoop(ctx, processOptions{ - SessionKey: msg.SessionKey, + // Route to determine agent and session key + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + AccountID: msg.Metadata["account_id"], + Peer: extractPeer(msg), + ParentPeer: extractParentPeer(msg), + GuildID: msg.Metadata["guild_id"], + TeamID: msg.Metadata["team_id"], + }) + + agent, ok := al.registry.GetAgent(route.AgentID) + if !ok { + agent = al.registry.GetDefaultAgent() + } + + // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) + sessionKey := route.SessionKey + if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { + sessionKey = msg.SessionKey + } + + logger.InfoCF("agent", "Routed message", + map[string]interface{}{ + "agent_id": agent.ID, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + }) + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, @@ -290,7 +309,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Verify this is a system message if msg.Channel != "system" { return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel) } @@ -302,12 +320,13 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe }) // Parse origin channel from chat_id (format: "channel:chat_id") - var originChannel string + var originChannel, originChatID string if idx := strings.Index(msg.ChatID, ":"); idx > 0 { originChannel = msg.ChatID[:idx] + originChatID = msg.ChatID[idx+1:] } else { - // Fallback originChannel = "cli" + originChatID = msg.ChatID } // Extract subagent result from message content @@ -328,44 +347,47 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe return "", nil } - // Agent acts as dispatcher only - subagent handles user interaction via message tool - // Don't forward result here, subagent should use message tool to communicate with user - logger.InfoCF("agent", "Subagent completed", - map[string]interface{}{ - "sender_id": msg.SenderID, - "channel": originChannel, - "content_len": len(content), - }) + // Use default agent for system messages + agent := al.registry.GetDefaultAgent() - // Agent only logs, does not respond to user - return "", nil + // Use the origin session for context + sessionKey := routing.BuildAgentMainSessionKey(agent.ID) + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, + Channel: originChannel, + ChatID: originChatID, + UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), + DefaultResponse: "Background task completed.", + EnableSummary: false, + SendResponse: true, + }) } // runAgentLoop is the core message processing logic. -// It handles context building, LLM calls, tool execution, and response handling. -func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { +func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) { // 0. Record last channel for heartbeat notifications (skip internal channels) if opts.Channel != "" && opts.ChatID != "" { // Don't record internal channels (cli, system, subagent) if !constants.IsInternalChannel(opts.Channel) { channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) if err := al.RecordLastChannel(channelKey); err != nil { - logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()}) + logger.WarnCF("agent", "Failed to record last channel", map[string]interface{}{"error": err.Error()}) } } } // 1. Update tool contexts - al.updateToolContexts(opts.Channel, opts.ChatID) + al.updateToolContexts(agent, opts.Channel, opts.ChatID) // 2. Build messages (skip history for heartbeat) var history []providers.Message var summary string if !opts.NoHistory { - history = al.sessions.GetHistory(opts.SessionKey) - summary = al.sessions.GetSummary(opts.SessionKey) + history = agent.Sessions.GetHistory(opts.SessionKey) + summary = agent.Sessions.GetSummary(opts.SessionKey) } - messages := al.contextBuilder.BuildMessages( + messages := agent.ContextBuilder.BuildMessages( history, summary, opts.UserMessage, @@ -375,10 +397,10 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str ) // 3. Save user message to session - al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) // 4. Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts) + finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err } @@ -392,12 +414,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // 6. Save final assistant message to session - al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - al.sessions.Save(opts.SessionKey) + agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) + agent.Sessions.Save(opts.SessionKey) // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID) + al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) } // 8. Optional: send response via bus @@ -413,6 +435,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str responsePreview := utils.Truncate(finalContent, 120) logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), map[string]interface{}{ + "agent_id": agent.ID, "session_key": opts.SessionKey, "iterations": iteration, "final_length": len(finalContent), @@ -422,28 +445,29 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // runLLMIteration executes the LLM call loop with tool handling. -// Returns the final content, iteration count, and any error. -func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) { +func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions) (string, int, error) { iteration := 0 var finalContent string - for iteration < al.maxIterations { + for iteration < agent.MaxIterations { iteration++ logger.DebugCF("agent", "LLM iteration", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, - "max": al.maxIterations, + "max": agent.MaxIterations, }) // Build tool definitions - providerToolDefs := al.tools.ToProviderDefs() + providerToolDefs := agent.Tools.ToProviderDefs() // Log LLM request details logger.DebugCF("agent", "LLM request", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, - "model": al.model, + "model": agent.Model, "messages_count": len(messages), "tools_count": len(providerToolDefs), "max_tokens": 8192, @@ -459,23 +483,45 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "tools_json": formatToolsForLog(providerToolDefs), }) + // Call LLM with fallback chain if candidates are configured. var response *providers.LLMResponse var err error + callLLM := func() (*providers.LLMResponse, error) { + if len(agent.Candidates) > 1 && al.fallback != nil { + fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, + func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { + return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + }, + ) + if fbErr != nil { + return nil, fbErr + } + if fbResult.Provider != "" && len(fbResult.Attempts) > 0 { + logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", + fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), + map[string]interface{}{"agent_id": agent.ID, "iteration": iteration}) + } + return fbResult.Response, nil + } + return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + } + // Retry loop for context/token errors maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { - response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, - }) - + response, err = callLLM() if err == nil { - break // Success + break } errMsg := strings.ToLower(err.Error()) - // Check for context window errors (provider specific, but usually contain "token" or "invalid") isContextError := strings.Contains(errMsg, "token") || strings.Contains(errMsg, "context") || strings.Contains(errMsg, "invalidparameter") || @@ -487,107 +533,30 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "retry": retry, }) - // Notify user on first retry only - if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse { + if retry == 0 && !constants.IsInternalChannel(opts.Channel) { al.bus.PublishOutbound(bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - Content: "⚠️ Context window exceeded. Compressing history and retrying...", + Content: "Context window exceeded. Compressing history and retrying...", }) } - // Force compression - al.forceCompression(opts.SessionKey) - - // Rebuild messages with compressed history - // Note: We need to reload history from session manager because forceCompression changed it - newHistory := al.sessions.GetHistory(opts.SessionKey) - newSummary := al.sessions.GetSummary(opts.SessionKey) - - // Re-create messages for the next attempt - // We keep the current user message (opts.UserMessage) effectively - messages = al.contextBuilder.BuildMessages( - newHistory, - newSummary, - opts.UserMessage, - nil, - opts.Channel, - opts.ChatID, + al.forceCompression(agent, opts.SessionKey) + newHistory := agent.Sessions.GetHistory(opts.SessionKey) + newSummary := agent.Sessions.GetSummary(opts.SessionKey) + messages = agent.ContextBuilder.BuildMessages( + newHistory, newSummary, "", + nil, opts.Channel, opts.ChatID, ) - - // Important: If we are in the middle of a tool loop (iteration > 1), - // rebuilding messages from session history might duplicate the flow or miss context - // if intermediate steps weren't saved correctly. - // However, al.sessions.AddFullMessage is called after every tool execution, - // so GetHistory should reflect the current state including partial tool execution. - // But we need to ensure we don't duplicate the user message which is appended in BuildMessages. - // BuildMessages(history...) takes the stored history and appends the *current* user message. - // If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop. - // So if we pass opts.UserMessage again, we might duplicate it? - // Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - // So GetHistory ALREADY contains the user message! - - // CORRECTION: - // BuildMessages combines: [System] + [History] + [CurrentMessage] - // But Step 3 added CurrentMessage to History. - // So if we use GetHistory now, it has the user message. - // If we pass opts.UserMessage to BuildMessages, it adds it AGAIN. - - // For retry in the middle of a loop, we should rely on what's in the session. - // BUT checking BuildMessages implementation: - // It appends history... then appends currentMessage. - - // Logic fix for retry: - // If iteration == 1, opts.UserMessage corresponds to the user input. - // If iteration > 1, we are processing tool results. The "messages" passed to Chat - // already accumulated tool outputs. - // Rebuilding from session history is safest because it persists state. - // Start fresh with rebuilt history. - - // Special case: standard BuildMessages appends "currentMessage". - // If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed. - // However, the "messages" argument passed to runLLMIteration is constructed by the caller. - // If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history. - - // In runAgentLoop: - // 3. sessions.AddMessage(userMsg) - // 4. runLLMIteration(..., UserMessage) - - // So History contains the user message. - // BuildMessages typically appends the user message as a *new* pending message. - // Wait, standard BuildMessages usage in runAgentLoop: - // messages := BuildMessages(history (has old), UserMessage) - // THEN AddMessage(UserMessage). - // So "history" passed to BuildMessages does NOT contain the current UserMessage yet. - - // But here, inside the loop, we have already saved it. - // So GetHistory() includes the current user message. - // If we call BuildMessages(GetHistory(), UserMessage), we get duplicates. - - // Hack/Fix: - // If we are retrying, we rebuild from Session History ONLY. - // We pass empty string as "currentMessage" to BuildMessages - // because the "current message" is already saved in history (step 3). - - messages = al.contextBuilder.BuildMessages( - newHistory, - newSummary, - "", // Empty because history already contains the relevant messages - nil, - opts.Channel, - opts.ChatID, - ) - continue } - - // Real error or success, break loop break } if err != nil { logger.ErrorCF("agent", "LLM call failed", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, "error": err.Error(), }) @@ -599,6 +568,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M finalContent = response.Content logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, "content_chars": len(finalContent), }) @@ -612,6 +582,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } logger.InfoCF("agent", "LLM requested tool calls", map[string]interface{}{ + "agent_id": agent.ID, "tools": toolNames, "count": len(response.ToolCalls), "iteration": iteration, @@ -636,15 +607,15 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M messages = append(messages, assistantMsg) // Save assistant message with tool calls to session - al.sessions.AddFullMessage(opts.SessionKey, assistantMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) // Execute tool calls for _, tc := range response.ToolCalls { - // Log tool call with arguments preview argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), map[string]interface{}{ + "agent_id": agent.ID, "tool": tc.Name, "iteration": iteration, }) @@ -665,7 +636,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } } - toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) + toolResult := agent.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { @@ -695,7 +666,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M messages = append(messages, toolResultMsg) // Save tool result message to session - al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } } @@ -703,19 +674,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } // updateToolContexts updates the context for tools that need channel/chatID info. -func (al *AgentLoop) updateToolContexts(channel, chatID string) { +func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) { // Use ContextualTool interface instead of type assertions - if tool, ok := al.tools.Get("message"); ok { + if tool, ok := agent.Tools.Get("message"); ok { if mt, ok := tool.(tools.ContextualTool); ok { mt.SetContext(channel, chatID) } } - if tool, ok := al.tools.Get("spawn"); ok { + if tool, ok := agent.Tools.Get("spawn"); ok { if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } } - if tool, ok := al.tools.Get("subagent"); ok { + if tool, ok := agent.Tools.Get("subagent"); ok { if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } @@ -723,24 +694,24 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) { } // maybeSummarize triggers summarization if the session history exceeds thresholds. -func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) { - newHistory := al.sessions.GetHistory(sessionKey) +func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { + newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) - threshold := al.contextWindow * 75 / 100 + threshold := agent.ContextWindow * 75 / 100 if len(newHistory) > 20 || tokenEstimate > threshold { - if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading { + summarizeKey := agent.ID + ":" + sessionKey + if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { - defer al.summarizing.Delete(sessionKey) - // Notify user about optimization if not an internal channel + defer al.summarizing.Delete(summarizeKey) if !constants.IsInternalChannel(channel) { al.bus.PublishOutbound(bus.OutboundMessage{ Channel: channel, ChatID: chatID, - Content: "⚠️ Memory threshold reached. Optimizing conversation history...", + Content: "Memory threshold reached. Optimizing conversation history...", }) } - al.summarizeSession(sessionKey) + al.summarizeSession(agent, sessionKey) }() } } @@ -748,8 +719,8 @@ func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) { // forceCompression aggressively reduces context when the limit is hit. // It drops the oldest 50% of messages (keeping system prompt and last user message). -func (al *AgentLoop) forceCompression(sessionKey string) { - history := al.sessions.GetHistory(sessionKey) +func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { + history := agent.Sessions.GetHistory(sessionKey) if len(history) <= 4 { return } @@ -796,8 +767,8 @@ func (al *AgentLoop) forceCompression(sessionKey string) { newHistory = append(newHistory, history[len(history)-1]) // Last message // Update session - al.sessions.SetHistory(sessionKey, newHistory) - al.sessions.Save(sessionKey) + agent.Sessions.SetHistory(sessionKey, newHistory) + agent.Sessions.Save(sessionKey) logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{ "session_key": sessionKey, @@ -810,15 +781,26 @@ func (al *AgentLoop) forceCompression(sessionKey string) { func (al *AgentLoop) GetStartupInfo() map[string]interface{} { info := make(map[string]interface{}) + agent := al.registry.GetDefaultAgent() + if agent == nil { + return info + } + // Tools info - tools := al.tools.List() + toolsList := agent.Tools.List() info["tools"] = map[string]interface{}{ - "count": len(tools), - "names": tools, + "count": len(toolsList), + "names": toolsList, } // Skills info - info["skills"] = al.contextBuilder.GetSkillsInfo() + info["skills"] = agent.ContextBuilder.GetSkillsInfo() + + // Agents info + info["agents"] = map[string]interface{}{ + "count": len(al.registry.ListAgentIDs()), + "ids": al.registry.ListAgentIDs(), + } return info } @@ -875,12 +857,12 @@ func formatToolsForLog(tools []providers.ToolDefinition) string { } // summarizeSession summarizes the conversation history for a session. -func (al *AgentLoop) summarizeSession(sessionKey string) { +func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - history := al.sessions.GetHistory(sessionKey) - summary := al.sessions.GetSummary(sessionKey) + history := agent.Sessions.GetHistory(sessionKey) + summary := agent.Sessions.GetSummary(sessionKey) // Keep last 4 messages for continuity if len(history) <= 4 { @@ -890,8 +872,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { toSummarize := history[:len(history)-4] // Oversized Message Guard - // Skip messages larger than 50% of context window to prevent summarizer overflow - maxMessageTokens := al.contextWindow / 2 + maxMessageTokens := agent.ContextWindow / 2 validMessages := make([]providers.Message, 0) omitted := false @@ -899,8 +880,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { if m.Role != "user" && m.Role != "assistant" { continue } - // Estimate tokens for this message - msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety) + msgTokens := len(m.Content) / 2 if msgTokens > maxMessageTokens { omitted = true continue @@ -913,19 +893,17 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } // Multi-Part Summarization - // Split into two parts if history is significant var finalSummary string if len(validMessages) > 10 { mid := len(validMessages) / 2 part1 := validMessages[:mid] part2 := validMessages[mid:] - s1, _ := al.summarizeBatch(ctx, part1, "") - s2, _ := al.summarizeBatch(ctx, part2, "") + s1, _ := al.summarizeBatch(ctx, agent, part1, "") + s2, _ := al.summarizeBatch(ctx, agent, part2, "") - // Merge them mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2) - resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{ + resp, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, agent.Model, map[string]interface{}{ "max_tokens": 1024, "temperature": 0.3, }) @@ -935,7 +913,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { finalSummary = s1 + " " + s2 } } else { - finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary) + finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary) } if omitted && finalSummary != "" { @@ -943,14 +921,14 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } if finalSummary != "" { - al.sessions.SetSummary(sessionKey, finalSummary) - al.sessions.TruncateHistory(sessionKey, 4) - al.sessions.Save(sessionKey) + agent.Sessions.SetSummary(sessionKey, finalSummary) + agent.Sessions.TruncateHistory(sessionKey, 4) + agent.Sessions.Save(sessionKey) } } // summarizeBatch summarizes a batch of messages. -func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) { +func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, batch []providers.Message, existingSummary string) (string, error) { prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n" if existingSummary != "" { prompt += "Existing context: " + existingSummary + "\n" @@ -960,7 +938,7 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content) } - response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{ + response, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, agent.Model, map[string]interface{}{ "max_tokens": 1024, "temperature": 0.3, }) @@ -999,25 +977,31 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) switch cmd { case "/show": if len(args) < 1 { - return "Usage: /show [model|channel]", true + return "Usage: /show [model|channel|agents]", true } switch args[0] { case "model": - return fmt.Sprintf("Current model: %s", al.model), true + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + return "No default agent configured", true + } + return fmt.Sprintf("Current model: %s", defaultAgent.Model), true case "channel": return fmt.Sprintf("Current channel: %s", msg.Channel), true + case "agents": + agentIDs := al.registry.ListAgentIDs() + return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true default: return fmt.Sprintf("Unknown show target: %s", args[0]), true } case "/list": if len(args) < 1 { - return "Usage: /list [models|channels]", true + return "Usage: /list [models|channels|agents]", true } switch args[0] { case "models": - // TODO: Fetch available models dynamically if possible - return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true + return "Available models: configured in config.json per agent", true case "channels": if al.channelManager == nil { return "Channel manager not initialized", true @@ -1027,6 +1011,9 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "No channels enabled", true } return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true + case "agents": + agentIDs := al.registry.ListAgentIDs() + return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true default: return fmt.Sprintf("Unknown list target: %s", args[0]), true } @@ -1040,23 +1027,21 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) switch target { case "model": - oldModel := al.model - al.model = value + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + return "No default agent configured", true + } + oldModel := defaultAgent.Model + defaultAgent.Model = value return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true case "channel": - // This changes the 'default' channel for some operations, or effectively redirects output? - // For now, let's just validate if the channel exists if al.channelManager == nil { return "Channel manager not initialized", true } if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" { return fmt.Sprintf("Channel '%s' not found or not enabled", value), true } - - // If message came from CLI, maybe we want to redirect CLI output to this channel? - // That would require state persistence about "redirected channel" - // For now, just acknowledged. - return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true + return fmt.Sprintf("Switched target channel to %s", value), true default: return fmt.Sprintf("Unknown switch target: %s", target), true } @@ -1064,3 +1049,30 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } + +// extractPeer extracts the routing peer from inbound message metadata. +func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { + peerKind := msg.Metadata["peer_kind"] + if peerKind == "" { + return nil + } + peerID := msg.Metadata["peer_id"] + if peerID == "" { + if peerKind == "direct" { + peerID = msg.SenderID + } else { + peerID = msg.ChatID + } + } + return &routing.RoutePeer{Kind: peerKind, ID: peerID} +} + +// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. +func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { + parentKind := msg.Metadata["parent_peer_kind"] + parentID := msg.Metadata["parent_peer_id"] + if parentKind == "" || parentID == "" { + return nil + } + return &routing.RoutePeer{Kind: parentKind, ID: parentID} +} diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 0bd38abf4..f2257973c 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -594,7 +594,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { {Role: "assistant", Content: "Old response 2"}, {Role: "user", Content: "Trigger message"}, } - al.sessions.SetHistory(sessionKey, history) + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + defaultAgent.Sessions.SetHistory(sessionKey, history) // Call ProcessDirectWithChannel // Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration @@ -614,7 +618,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { } // Check final history length - finalHistory := al.sessions.GetHistory(sessionKey) + finalHistory := defaultAgent.Sessions.GetHistory(sessionKey) // We verify that the history has been modified (compressed) // Original length: 6 // Expected behavior: compression drops ~50% of history (mid slice) diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go new file mode 100644 index 000000000..4cf5a6fca --- /dev/null +++ b/pkg/agent/registry.go @@ -0,0 +1,114 @@ +package agent + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" +) + +// AgentRegistry manages multiple agent instances and routes messages to them. +type AgentRegistry struct { + agents map[string]*AgentInstance + resolver *routing.RouteResolver + mu sync.RWMutex +} + +// NewAgentRegistry creates a registry from config, instantiating all agents. +func NewAgentRegistry( + cfg *config.Config, + provider providers.LLMProvider, +) *AgentRegistry { + registry := &AgentRegistry{ + agents: make(map[string]*AgentInstance), + resolver: routing.NewRouteResolver(cfg), + } + + agentConfigs := cfg.Agents.List + if len(agentConfigs) == 0 { + implicitAgent := &config.AgentConfig{ + ID: "main", + Default: true, + } + instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, cfg, provider) + registry.agents["main"] = instance + logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil) + } else { + for i := range agentConfigs { + ac := &agentConfigs[i] + id := routing.NormalizeAgentID(ac.ID) + instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider) + registry.agents[id] = instance + logger.InfoCF("agent", "Registered agent", + map[string]interface{}{ + "agent_id": id, + "name": ac.Name, + "workspace": instance.Workspace, + "model": instance.Model, + }) + } + } + + return registry +} + +// GetAgent returns the agent instance for a given ID. +func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + id := routing.NormalizeAgentID(agentID) + agent, ok := r.agents[id] + return agent, ok +} + +// ResolveRoute determines which agent handles the message. +func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute { + return r.resolver.ResolveRoute(input) +} + +// ListAgentIDs returns all registered agent IDs. +func (r *AgentRegistry) ListAgentIDs() []string { + r.mu.RLock() + defer r.mu.RUnlock() + ids := make([]string, 0, len(r.agents)) + for id := range r.agents { + ids = append(ids, id) + } + return ids +} + +// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID. +func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool { + parent, ok := r.GetAgent(parentAgentID) + if !ok { + return false + } + if parent.Subagents == nil || parent.Subagents.AllowAgents == nil { + return false + } + targetNorm := routing.NormalizeAgentID(targetAgentID) + for _, allowed := range parent.Subagents.AllowAgents { + if allowed == "*" { + return true + } + if routing.NormalizeAgentID(allowed) == targetNorm { + return true + } + } + return false +} + +// GetDefaultAgent returns the default agent instance. +func (r *AgentRegistry) GetDefaultAgent() *AgentInstance { + r.mu.RLock() + defer r.mu.RUnlock() + if agent, ok := r.agents["main"]; ok { + return agent + } + for _, agent := range r.agents { + return agent + } + return nil +} diff --git a/pkg/agent/registry_test.go b/pkg/agent/registry_test.go new file mode 100644 index 000000000..f196d7fb7 --- /dev/null +++ b/pkg/agent/registry_test.go @@ -0,0 +1,199 @@ +package agent + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type mockRegistryProvider struct{} + +func (m *mockRegistryProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil +} + +func (m *mockRegistryProvider) GetDefaultModel() string { + return "mock-model" +} + +func testCfg(agents []config.AgentConfig) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: "/tmp/picoclaw-test-registry", + Model: "gpt-4", + MaxTokens: 8192, + MaxToolIterations: 10, + }, + List: agents, + }, + } +} + +func TestNewAgentRegistry_ImplicitMain(t *testing.T) { + cfg := testCfg(nil) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 1 || ids[0] != "main" { + t.Errorf("expected implicit main agent, got %v", ids) + } + + agent, ok := registry.GetAgent("main") + if !ok || agent == nil { + t.Fatal("expected to find 'main' agent") + } + if agent.ID != "main" { + t.Errorf("agent.ID = %q, want 'main'", agent.ID) + } +} + +func TestNewAgentRegistry_ExplicitAgents(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "sales", Default: true, Name: "Sales Bot"}, + {ID: "support", Name: "Support Bot"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 2 { + t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids) + } + + sales, ok := registry.GetAgent("sales") + if !ok || sales == nil { + t.Fatal("expected to find 'sales' agent") + } + if sales.Name != "Sales Bot" { + t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name) + } + + support, ok := registry.GetAgent("support") + if !ok || support == nil { + t.Fatal("expected to find 'support' agent") + } +} + +func TestAgentRegistry_GetAgent_Normalize(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "my-agent", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, ok := registry.GetAgent("My-Agent") + if !ok || agent == nil { + t.Fatal("expected to find agent with normalized ID") + } + if agent.ID != "my-agent" { + t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID) + } +} + +func TestAgentRegistry_GetDefaultAgent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + // GetDefaultAgent first checks for "main", then returns any + agent := registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected a default agent") + } +} + +func TestAgentRegistry_CanSpawnSubagent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "parent", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: []string{"child1", "child2"}, + }, + }, + {ID: "child1"}, + {ID: "child2"}, + {ID: "restricted"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + if !registry.CanSpawnSubagent("parent", "child1") { + t.Error("expected parent to be allowed to spawn child1") + } + if !registry.CanSpawnSubagent("parent", "child2") { + t.Error("expected parent to be allowed to spawn child2") + } + if registry.CanSpawnSubagent("parent", "restricted") { + t.Error("expected parent to NOT be allowed to spawn restricted") + } + if registry.CanSpawnSubagent("child1", "child2") { + t.Error("expected child1 to NOT be allowed to spawn (no subagents config)") + } +} + +func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "admin", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: []string{"*"}, + }, + }, + {ID: "any-agent"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + if !registry.CanSpawnSubagent("admin", "any-agent") { + t.Error("expected wildcard to allow spawning any agent") + } + if !registry.CanSpawnSubagent("admin", "nonexistent") { + t.Error("expected wildcard to allow spawning even nonexistent agents") + } +} + +func TestAgentInstance_Model(t *testing.T) { + model := &config.AgentModelConfig{Primary: "claude-opus"} + cfg := testCfg([]config.AgentConfig{ + {ID: "custom", Default: true, Model: model}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("custom") + if agent.Model != "claude-opus" { + t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model) + } +} + +func TestAgentInstance_FallbackInheritance(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "inherit", Default: true}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"} + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("inherit") + if len(agent.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks)) + } +} + +func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) { + model := &config.AgentModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{}, // explicitly empty = disable + } + cfg := testCfg([]config.AgentConfig{ + {ID: "no-fallback", Default: true, Model: model}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"} + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("no-fallback") + if len(agent.Fallbacks) != 0 { + t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks) + } +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 8d2d9a65b..4925099a3 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -2,7 +2,6 @@ package channels import ( "context" - "fmt" "strings" "github.com/sipeed/picoclaw/pkg/bus" @@ -87,17 +86,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st return } - // Build session key: channel:chatID - sessionKey := fmt.Sprintf("%s:%s", c.name, chatID) - msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - SessionKey: sessionKey, - Metadata: metadata, + Channel: c.name, + SenderID: senderID, + ChatID: chatID, + Content: content, + Media: media, + Metadata: metadata, } c.bus.PublishInbound(msg) diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index 00aa8ab4d..f360c75ef 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -376,6 +376,13 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "preview": utils.Truncate(content, 50), }) + peerKind := "channel" + peerID := m.ChannelID + if m.GuildID == "" { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_id": m.ID, "user_id": senderID, @@ -384,6 +391,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index 5387e9213..0060972ed 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -25,6 +25,7 @@ type SlackChannel struct { api *slack.Client socketClient *socketmode.Client botUserID string + teamID string transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc @@ -72,6 +73,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { return fmt.Errorf("slack auth test failed: %w", err) } c.botUserID = authResp.UserID + c.teamID = authResp.TeamID logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ "bot_user_id": c.botUserID, @@ -274,11 +276,21 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { return } + peerKind := "channel" + peerID := channelID + if strings.HasPrefix(channelID, "D") { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", + "peer_kind": peerKind, + "peer_id": peerID, + "team_id": c.teamID, } logger.DebugCF("slack", "Received message", map[string]interface{}{ @@ -331,12 +343,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } + mentionPeerKind := "channel" + mentionPeerID := channelID + if strings.HasPrefix(channelID, "D") { + mentionPeerKind = "direct" + mentionPeerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", "is_mention": "true", + "peer_kind": mentionPeerKind, + "peer_id": mentionPeerID, + "team_id": c.teamID, } c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -373,6 +395,9 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "platform": "slack", "is_command": "true", "trigger_id": cmd.TriggerID, + "peer_kind": "channel", + "peer_id": channelID, + "team_id": c.teamID, } logger.DebugCF("slack", "Slash command received", map[string]interface{}{ diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 5601d508c..24b82b557 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -347,12 +347,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes c.placeholders.Store(chatIDStr, pID) } + peerKind := "direct" + peerID := fmt.Sprintf("%d", user.ID) + if message.Chat.Type != "private" { + peerKind = "group" + peerID = fmt.Sprintf("%d", chatID) + } + metadata := map[string]string{ "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) diff --git a/pkg/config/config.go b/pkg/config/config.go index a1cc978b6..682996bd6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -45,6 +45,8 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { type Config struct { Agents AgentsConfig `json:"agents"` + Bindings []AgentBinding `json:"bindings,omitempty"` + Session SessionConfig `json:"session,omitempty"` Channels ChannelsConfig `json:"channels"` Providers ProvidersConfig `json:"providers"` Gateway GatewayConfig `json:"gateway"` @@ -56,16 +58,97 @@ type Config struct { type AgentsConfig struct { Defaults AgentDefaults `json:"defaults"` + List []AgentConfig `json:"list,omitempty"` +} + +// AgentModelConfig supports both string and structured model config. +// String format: "gpt-4" (just primary, no fallbacks) +// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]} +type AgentModelConfig struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` +} + +func (m *AgentModelConfig) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + m.Primary = s + m.Fallbacks = nil + return nil + } + type raw struct { + Primary string `json:"primary"` + Fallbacks []string `json:"fallbacks"` + } + var r raw + if err := json.Unmarshal(data, &r); err != nil { + return err + } + m.Primary = r.Primary + m.Fallbacks = r.Fallbacks + return nil +} + +func (m AgentModelConfig) MarshalJSON() ([]byte, error) { + if len(m.Fallbacks) == 0 && m.Primary != "" { + return json.Marshal(m.Primary) + } + type raw struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` + } + return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks}) +} + +type AgentConfig struct { + ID string `json:"id"` + Default bool `json:"default,omitempty"` + Name string `json:"name,omitempty"` + Workspace string `json:"workspace,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` + Skills []string `json:"skills,omitempty"` + Subagents *SubagentsConfig `json:"subagents,omitempty"` +} + +type SubagentsConfig struct { + AllowAgents []string `json:"allow_agents,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` +} + +type PeerMatch struct { + Kind string `json:"kind"` + ID string `json:"id"` +} + +type BindingMatch struct { + Channel string `json:"channel"` + AccountID string `json:"account_id,omitempty"` + Peer *PeerMatch `json:"peer,omitempty"` + GuildID string `json:"guild_id,omitempty"` + TeamID string `json:"team_id,omitempty"` +} + +type AgentBinding struct { + AgentID string `json:"agent_id"` + Match BindingMatch `json:"match"` +} + +type SessionConfig struct { + DMScope string `json:"dm_scope,omitempty"` + IdentityLinks map[string][]string `json:"identity_links,omitempty"` } type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } type ChannelsConfig struct { @@ -461,6 +544,32 @@ func (c *Config) GetAPIBase() string { return "" } +// ModelConfig holds primary model and fallback list. +type ModelConfig struct { + Primary string + Fallbacks []string +} + +// GetModelConfig returns the text model configuration with fallbacks. +func (c *Config) GetModelConfig() ModelConfig { + c.mu.RLock() + defer c.mu.RUnlock() + return ModelConfig{ + Primary: c.Agents.Defaults.Model, + Fallbacks: c.Agents.Defaults.ModelFallbacks, + } +} + +// GetImageModelConfig returns the image model configuration with fallbacks. +func (c *Config) GetImageModelConfig() ModelConfig { + c.mu.RLock() + defer c.mu.RUnlock() + return ModelConfig{ + Primary: c.Agents.Defaults.ImageModel, + Fallbacks: c.Agents.Defaults.ImageModelFallbacks, + } +} + func expandHome(path string) string { if path == "" { return path diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index a1f73f0b3..47916d155 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1,12 +1,193 @@ package config import ( + "encoding/json" "os" "path/filepath" "runtime" "testing" ) +func TestAgentModelConfig_UnmarshalString(t *testing.T) { + var m AgentModelConfig + if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil { + t.Fatalf("unmarshal string: %v", err) + } + if m.Primary != "gpt-4" { + t.Errorf("Primary = %q, want 'gpt-4'", m.Primary) + } + if m.Fallbacks != nil { + t.Errorf("Fallbacks = %v, want nil", m.Fallbacks) + } +} + +func TestAgentModelConfig_UnmarshalObject(t *testing.T) { + var m AgentModelConfig + data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}` + if err := json.Unmarshal([]byte(data), &m); err != nil { + t.Fatalf("unmarshal object: %v", err) + } + if m.Primary != "claude-opus" { + t.Errorf("Primary = %q, want 'claude-opus'", m.Primary) + } + if len(m.Fallbacks) != 2 { + t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks)) + } + if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" { + t.Errorf("Fallbacks = %v", m.Fallbacks) + } +} + +func TestAgentModelConfig_MarshalString(t *testing.T) { + m := AgentModelConfig{Primary: "gpt-4"} + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(data) != `"gpt-4"` { + t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data)) + } +} + +func TestAgentModelConfig_MarshalObject(t *testing.T) { + m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}} + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var result map[string]interface{} + json.Unmarshal(data, &result) + if result["primary"] != "claude-opus" { + t.Errorf("primary = %v", result["primary"]) + } +} + +func TestAgentConfig_FullParse(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + }, + "list": [ + { + "id": "sales", + "default": true, + "name": "Sales Bot", + "model": "gpt-4" + }, + { + "id": "support", + "name": "Support Bot", + "model": { + "primary": "claude-opus", + "fallbacks": ["haiku"] + }, + "subagents": { + "allow_agents": ["sales"] + } + } + ] + }, + "bindings": [ + { + "agent_id": "support", + "match": { + "channel": "telegram", + "account_id": "*", + "peer": {"kind": "direct", "id": "user123"} + } + } + ], + "session": { + "dm_scope": "per-peer", + "identity_links": { + "john": ["telegram:123", "discord:john#1234"] + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 2 { + t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List)) + } + + sales := cfg.Agents.List[0] + if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" { + t.Errorf("sales = %+v", sales) + } + if sales.Model == nil || sales.Model.Primary != "gpt-4" { + t.Errorf("sales.Model = %+v", sales.Model) + } + + support := cfg.Agents.List[1] + if support.ID != "support" || support.Name != "Support Bot" { + t.Errorf("support = %+v", support) + } + if support.Model == nil || support.Model.Primary != "claude-opus" { + t.Errorf("support.Model = %+v", support.Model) + } + if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" { + t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks) + } + if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 { + t.Errorf("support.Subagents = %+v", support.Subagents) + } + + if len(cfg.Bindings) != 1 { + t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings)) + } + binding := cfg.Bindings[0] + if binding.AgentID != "support" || binding.Match.Channel != "telegram" { + t.Errorf("binding = %+v", binding) + } + if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" { + t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer) + } + + if cfg.Session.DMScope != "per-peer" { + t.Errorf("Session.DMScope = %q", cfg.Session.DMScope) + } + if len(cfg.Session.IdentityLinks) != 1 { + t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks) + } + links := cfg.Session.IdentityLinks["john"] + if len(links) != 2 { + t.Errorf("john links = %v", links) + } +} + +func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 0 { + t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List)) + } + if len(cfg.Bindings) != 0 { + t.Errorf("bindings should be empty, got %d", len(cfg.Bindings)) + } +} + // TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { cfg := DefaultConfig() @@ -20,8 +201,6 @@ func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { func TestDefaultConfig_WorkspacePath(t *testing.T) { cfg := DefaultConfig() - // Just verify the workspace is set, don't compare exact paths - // since expandHome behavior may differ based on environment if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } @@ -79,7 +258,6 @@ func TestDefaultConfig_Gateway(t *testing.T) { func TestDefaultConfig_Providers(t *testing.T) { cfg := DefaultConfig() - // Verify all providers are empty by default if cfg.Providers.Anthropic.APIKey != "" { t.Error("Anthropic API key should be empty by default") } @@ -89,46 +267,18 @@ func TestDefaultConfig_Providers(t *testing.T) { if cfg.Providers.OpenRouter.APIKey != "" { t.Error("OpenRouter API key should be empty by default") } - if cfg.Providers.Groq.APIKey != "" { - t.Error("Groq API key should be empty by default") - } - if cfg.Providers.Zhipu.APIKey != "" { - t.Error("Zhipu API key should be empty by default") - } - if cfg.Providers.VLLM.APIKey != "" { - t.Error("VLLM API key should be empty by default") - } - if cfg.Providers.Gemini.APIKey != "" { - t.Error("Gemini API key should be empty by default") - } } // TestDefaultConfig_Channels verifies channels are disabled by default func TestDefaultConfig_Channels(t *testing.T) { cfg := DefaultConfig() - // Verify all channels are disabled by default - if cfg.Channels.WhatsApp.Enabled { - t.Error("WhatsApp should be disabled by default") - } if cfg.Channels.Telegram.Enabled { t.Error("Telegram should be disabled by default") } - if cfg.Channels.Feishu.Enabled { - t.Error("Feishu should be disabled by default") - } if cfg.Channels.Discord.Enabled { t.Error("Discord should be disabled by default") } - if cfg.Channels.MaixCam.Enabled { - t.Error("MaixCam should be disabled by default") - } - if cfg.Channels.QQ.Enabled { - t.Error("QQ should be disabled by default") - } - if cfg.Channels.DingTalk.Enabled { - t.Error("DingTalk should be disabled by default") - } if cfg.Channels.Slack.Enabled { t.Error("Slack should be disabled by default") } @@ -178,7 +328,6 @@ func TestSaveConfig_FilePermissions(t *testing.T) { func TestConfig_Complete(t *testing.T) { cfg := DefaultConfig() - // Verify complete config structure if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } diff --git a/pkg/providers/codex_cli_provider_integration_test.go b/pkg/providers/codex_cli_provider_integration_test.go new file mode 100644 index 000000000..0267c730f --- /dev/null +++ b/pkg/providers/codex_cli_provider_integration_test.go @@ -0,0 +1,119 @@ +//go:build integration + +package providers + +import ( + "context" + exec "os/exec" + "strings" + "testing" + "time" +) + +// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI. +// Run with: go test -tags=integration ./pkg/providers/... +func TestIntegration_RealCodexCLI(t *testing.T) { + path, err := exec.LookPath("codex") + if err != nil { + t.Skip("codex CLI not found in PATH, skipping integration test") + } + t.Logf("Using codex CLI at: %s", path) + + p := NewCodexCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with real CLI error = %v", err) + } + + if resp.Content == "" { + t.Error("Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage != nil { + t.Logf("Usage: prompt=%d, completion=%d, total=%d", + resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) + } + + t.Logf("Response content: %q", resp.Content) + + if !strings.Contains(strings.ToLower(resp.Content), "pong") { + t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) + } +} + +func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) { + if _, err := exec.LookPath("codex"); err != nil { + t.Skip("codex CLI not found in PATH") + } + + p := NewCodexCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "system", Content: "You are a calculator. Only respond with numbers. No text."}, + {Role: "user", Content: "What is 2+2?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + t.Logf("Response: %q", resp.Content) + + if !strings.Contains(resp.Content, "4") { + t.Errorf("Content = %q, expected to contain '4'", resp.Content) + } +} + +func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) { + if _, err := exec.LookPath("codex"); err != nil { + t.Skip("codex CLI not found in PATH") + } + + // Run codex directly and verify our parser handles real output + cmd := exec.Command("codex", "exec", + "--json", + "--dangerously-bypass-approvals-and-sandbox", + "--skip-git-repo-check", + "--color", "never", + "-C", t.TempDir(), + "-") + cmd.Stdin = strings.NewReader("Say hi") + + output, err := cmd.Output() + if err != nil { + // codex may write diagnostic noise to stderr but still produce valid output + if len(output) == 0 { + t.Fatalf("codex CLI failed: %v", err) + } + } + + t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)])) + + // Verify our parser can handle real output + p := NewCodexCliProvider("") + resp, err := p.parseJSONLEvents(string(output)) + if err != nil { + t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err) + } + + if resp.Content == "" { + t.Error("parsed Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", resp.FinishReason) + } + + t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) +} diff --git a/pkg/providers/cooldown.go b/pkg/providers/cooldown.go new file mode 100644 index 000000000..b0d8608dc --- /dev/null +++ b/pkg/providers/cooldown.go @@ -0,0 +1,207 @@ +package providers + +import ( + "math" + "sync" + "time" +) + +const ( + defaultFailureWindow = 24 * time.Hour +) + +// CooldownTracker manages per-provider cooldown state for the fallback chain. +// Thread-safe via sync.RWMutex. In-memory only (resets on restart). +type CooldownTracker struct { + mu sync.RWMutex + entries map[string]*cooldownEntry + failureWindow time.Duration + nowFunc func() time.Time // for testing +} + +type cooldownEntry struct { + ErrorCount int + FailureCounts map[FailoverReason]int + CooldownEnd time.Time // standard cooldown expiry + DisabledUntil time.Time // billing-specific disable expiry + DisabledReason FailoverReason // reason for disable (billing) + LastFailure time.Time +} + +// NewCooldownTracker creates a tracker with default 24h failure window. +func NewCooldownTracker() *CooldownTracker { + return &CooldownTracker{ + entries: make(map[string]*cooldownEntry), + failureWindow: defaultFailureWindow, + nowFunc: time.Now, + } +} + +// MarkFailure records a failure for a provider and sets appropriate cooldown. +// Resets error counts if last failure was more than failureWindow ago. +func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) { + ct.mu.Lock() + defer ct.mu.Unlock() + + now := ct.nowFunc() + entry := ct.getOrCreate(provider) + + // 24h failure window reset: if no failure in failureWindow, reset counters. + if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow { + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + } + + entry.ErrorCount++ + entry.FailureCounts[reason]++ + entry.LastFailure = now + + if reason == FailoverBilling { + billingCount := entry.FailureCounts[FailoverBilling] + entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount)) + entry.DisabledReason = FailoverBilling + } else { + entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount)) + } +} + +// MarkSuccess resets all counters and cooldowns for a provider. +func (ct *CooldownTracker) MarkSuccess(provider string) { + ct.mu.Lock() + defer ct.mu.Unlock() + + entry := ct.entries[provider] + if entry == nil { + return + } + + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + entry.CooldownEnd = time.Time{} + entry.DisabledUntil = time.Time{} + entry.DisabledReason = "" +} + +// IsAvailable returns true if the provider is not in cooldown or disabled. +func (ct *CooldownTracker) IsAvailable(provider string) bool { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return true + } + + now := ct.nowFunc() + + // Billing disable takes precedence (longer cooldown). + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + return false + } + + // Standard cooldown. + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + return false + } + + return true +} + +// CooldownRemaining returns how long until the provider becomes available. +// Returns 0 if already available. +func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + + now := ct.nowFunc() + var remaining time.Duration + + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + d := entry.DisabledUntil.Sub(now) + if d > remaining { + remaining = d + } + } + + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + d := entry.CooldownEnd.Sub(now) + if d > remaining { + remaining = d + } + } + + return remaining +} + +// ErrorCount returns the current error count for a provider. +func (ct *CooldownTracker) ErrorCount(provider string) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.ErrorCount +} + +// FailureCount returns the failure count for a specific reason. +func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.FailureCounts[reason] +} + +func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry { + entry := ct.entries[provider] + if entry == nil { + entry = &cooldownEntry{ + FailureCounts: make(map[FailoverReason]int), + } + ct.entries[provider] = entry + } + return entry +} + +// calculateStandardCooldown computes standard exponential backoff. +// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3)) +// +// 1 error → 1 min +// 2 errors → 5 min +// 3 errors → 25 min +// 4+ errors → 1 hour (cap) +func calculateStandardCooldown(errorCount int) time.Duration { + n := max(1, errorCount) + exp := min(n-1, 3) + ms := 60_000 * int(math.Pow(5, float64(exp))) + ms = min(3_600_000, ms) // cap at 1 hour + return time.Duration(ms) * time.Millisecond +} + +// calculateBillingCooldown computes billing-specific exponential backoff. +// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10)) +// +// 1 error → 5 hours +// 2 errors → 10 hours +// 3 errors → 20 hours +// 4+ errors → 24 hours (cap) +func calculateBillingCooldown(billingErrorCount int) time.Duration { + const baseMs = 5 * 60 * 60 * 1000 // 5 hours + const maxMs = 24 * 60 * 60 * 1000 // 24 hours + + n := max(1, billingErrorCount) + exp := min(n-1, 10) + raw := float64(baseMs) * math.Pow(2, float64(exp)) + ms := int(math.Min(float64(maxMs), raw)) + return time.Duration(ms) * time.Millisecond +} diff --git a/pkg/providers/cooldown_test.go b/pkg/providers/cooldown_test.go new file mode 100644 index 000000000..47f43ad5c --- /dev/null +++ b/pkg/providers/cooldown_test.go @@ -0,0 +1,269 @@ +package providers + +import ( + "sync" + "testing" + "time" +) + +func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) { + current := now + ct := NewCooldownTracker() + ct.nowFunc = func() time.Time { return current } + return ct, ¤t +} + +func TestCooldown_InitiallyAvailable(t *testing.T) { + ct := NewCooldownTracker() + if !ct.IsAvailable("openai") { + t.Error("new provider should be available") + } + if ct.ErrorCount("openai") != 0 { + t.Error("new provider should have 0 errors") + } +} + +func TestCooldown_StandardEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st error → 1 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown after 1st error") + } + + // Advance 61 seconds → available + *current = now.Add(61 * time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 1 min cooldown") + } + + // 2nd error → 5 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + *current = now.Add(61*time.Second + 4*time.Minute) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown (5 min) after 2nd error") + } + *current = now.Add(61*time.Second + 6*time.Minute) + if !ct.IsAvailable("openai") { + t.Error("should be available after 5 min cooldown") + } +} + +func TestCooldown_StandardCap(t *testing.T) { + // Verify formula: 1m, 5m, 25m, 1h, 1h, 1h... + expected := []time.Duration{ + 1 * time.Minute, + 5 * time.Minute, + 25 * time.Minute, + 1 * time.Hour, + 1 * time.Hour, + } + + for i, want := range expected { + got := calculateStandardCooldown(i + 1) + if got != want { + t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_BillingEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st billing error → 5h cooldown + ct.MarkFailure("openai", FailoverBilling) + if ct.IsAvailable("openai") { + t.Error("should be disabled after billing error") + } + + // Advance 4h → still disabled + *current = now.Add(4 * time.Hour) + if ct.IsAvailable("openai") { + t.Error("should still be disabled (5h cooldown)") + } + + // Advance 5h + 1s → available + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 5h billing cooldown") + } +} + +func TestCooldown_BillingCap(t *testing.T) { + expected := []time.Duration{ + 5 * time.Hour, + 10 * time.Hour, + 20 * time.Hour, + 24 * time.Hour, + 24 * time.Hour, + } + + for i, want := range expected { + got := calculateBillingCooldown(i + 1) + if got != want { + t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_SuccessReset(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + if ct.ErrorCount("openai") != 2 { + t.Errorf("error count = %d, want 2", ct.ErrorCount("openai")) + } + + ct.MarkSuccess("openai") + if ct.ErrorCount("openai") != 0 { + t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai")) + } + if !ct.IsAvailable("openai") { + t.Error("should be available after success") + } + if ct.FailureCount("openai", FailoverRateLimit) != 0 { + t.Error("failure counts should be reset after success") + } + if ct.FailureCount("openai", FailoverBilling) != 0 { + t.Error("billing failure count should be reset after success") + } +} + +func TestCooldown_FailureWindowReset(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 4 errors → 1h cooldown + for i := 0; i < 4; i++ { + ct.MarkFailure("openai", FailoverRateLimit) + *current = current.Add(2 * time.Second) // small advance between errors + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("error count = %d, want 4", ct.ErrorCount("openai")) + } + + // Advance 25 hours (past 24h failure window) + *current = now.Add(25 * time.Hour) + + // Next error should reset counters first, then increment to 1 + ct.MarkFailure("openai", FailoverRateLimit) + if ct.ErrorCount("openai") != 1 { + t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai")) + } +} + +func TestCooldown_PerReasonTracking(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + ct.MarkFailure("openai", FailoverAuth) + + if ct.FailureCount("openai", FailoverRateLimit) != 2 { + t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit)) + } + if ct.FailureCount("openai", FailoverBilling) != 1 { + t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling)) + } + if ct.FailureCount("openai", FailoverAuth) != 1 { + t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth)) + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai")) + } +} + +func TestCooldown_BillingTakesPrecedence(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // Standard cooldown (1 min) + billing disable (5h) + ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown + ct.MarkFailure("openai", FailoverBilling) // 5h disable + + // After 2 min: standard cooldown expired but billing still active + *current = now.Add(2 * time.Minute) + if ct.IsAvailable("openai") { + t.Error("billing disable should take precedence over standard cooldown") + } + + // After 5h + 1s: both expired + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after all cooldowns expire") + } +} + +func TestCooldown_CooldownRemaining(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // No failures → 0 remaining + if ct.CooldownRemaining("openai") != 0 { + t.Error("expected 0 remaining for new provider") + } + + ct.MarkFailure("openai", FailoverRateLimit) + + *current = now.Add(30 * time.Second) + remaining := ct.CooldownRemaining("openai") + if remaining <= 0 || remaining > 1*time.Minute { + t.Errorf("remaining = %v, expected ~30s", remaining) + } +} + +func TestCooldown_SuccessOnUnknownProvider(t *testing.T) { + ct := NewCooldownTracker() + // Should not panic + ct.MarkSuccess("nonexistent") + if !ct.IsAvailable("nonexistent") { + t.Error("nonexistent provider should be available") + } +} + +func TestCooldown_ConcurrentAccess(t *testing.T) { + ct := NewCooldownTracker() + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(3) + go func() { + defer wg.Done() + ct.MarkFailure("openai", FailoverRateLimit) + }() + go func() { + defer wg.Done() + ct.IsAvailable("openai") + }() + go func() { + defer wg.Done() + ct.MarkSuccess("openai") + }() + } + + wg.Wait() + // If we got here without panic, concurrent access is safe +} + +func TestCooldown_MultipleProviders(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + if ct.IsAvailable("openai") { + t.Error("openai should be in cooldown") + } + if ct.IsAvailable("anthropic") { + t.Error("anthropic should be in cooldown") + } + // groq was never touched + if !ct.IsAvailable("groq") { + t.Error("groq should be available") + } +} diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go new file mode 100644 index 000000000..a0f003006 --- /dev/null +++ b/pkg/providers/error_classifier.go @@ -0,0 +1,253 @@ +package providers + +import ( + "context" + "regexp" + "strings" +) + +// errorPattern defines a single pattern (string or regex) for error classification. +type errorPattern struct { + substring string + regex *regexp.Regexp +} + +func substr(s string) errorPattern { return errorPattern{substring: s} } +func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} } + +// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns). +var ( + rateLimitPatterns = []errorPattern{ + rxp(`rate[_ ]limit`), + substr("too many requests"), + substr("429"), + substr("exceeded your current quota"), + rxp(`exceeded.*quota`), + rxp(`resource has been exhausted`), + rxp(`resource.*exhausted`), + substr("resource_exhausted"), + substr("quota exceeded"), + substr("usage limit"), + } + + overloadedPatterns = []errorPattern{ + rxp(`overloaded_error`), + rxp(`"type"\s*:\s*"overloaded_error"`), + substr("overloaded"), + } + + timeoutPatterns = []errorPattern{ + substr("timeout"), + substr("timed out"), + substr("deadline exceeded"), + substr("context deadline exceeded"), + } + + billingPatterns = []errorPattern{ + rxp(`\b402\b`), + substr("payment required"), + substr("insufficient credits"), + substr("credit balance"), + substr("plans & billing"), + substr("insufficient balance"), + } + + authPatterns = []errorPattern{ + rxp(`invalid[_ ]?api[_ ]?key`), + substr("incorrect api key"), + substr("invalid token"), + substr("authentication"), + substr("re-authenticate"), + substr("oauth token refresh failed"), + substr("unauthorized"), + substr("forbidden"), + substr("access denied"), + substr("expired"), + substr("token has expired"), + rxp(`\b401\b`), + rxp(`\b403\b`), + substr("no credentials found"), + substr("no api key found"), + } + + formatPatterns = []errorPattern{ + substr("string should match pattern"), + substr("tool_use.id"), + substr("tool_use_id"), + substr("messages.1.content.1.tool_use.id"), + substr("invalid request format"), + } + + imageDimensionPatterns = []errorPattern{ + rxp(`image dimensions exceed max`), + } + + imageSizePatterns = []errorPattern{ + rxp(`image exceeds.*mb`), + } + + // Transient HTTP status codes that map to timeout (server-side failures). + transientStatusCodes = map[int]bool{ + 500: true, 502: true, 503: true, + 521: true, 522: true, 523: true, 524: true, + 529: true, + } +) + +// ClassifyError classifies an error into a FailoverError with reason. +// Returns nil if the error is not classifiable (unknown errors should not trigger fallback). +func ClassifyError(err error, provider, model string) *FailoverError { + if err == nil { + return nil + } + + // Context cancellation: user abort, never fallback. + if err == context.Canceled { + return nil + } + + // Context deadline exceeded: treat as timeout, always fallback. + if err == context.DeadlineExceeded { + return &FailoverError{ + Reason: FailoverTimeout, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + msg := strings.ToLower(err.Error()) + + // Image dimension/size errors: non-retriable, non-fallback. + if IsImageDimensionError(msg) || IsImageSizeError(msg) { + return &FailoverError{ + Reason: FailoverFormat, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + // Try HTTP status code extraction first. + if status := extractHTTPStatus(msg); status > 0 { + if reason := classifyByStatus(status); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Status: status, + Wrapped: err, + } + } + } + + // Message pattern matching (priority order from OpenClaw). + if reason := classifyByMessage(msg); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + return nil +} + +// classifyByStatus maps HTTP status codes to FailoverReason. +func classifyByStatus(status int) FailoverReason { + switch { + case status == 401 || status == 403: + return FailoverAuth + case status == 402: + return FailoverBilling + case status == 408: + return FailoverTimeout + case status == 429: + return FailoverRateLimit + case status == 400: + return FailoverFormat + case transientStatusCodes[status]: + return FailoverTimeout + } + return "" +} + +// classifyByMessage matches error messages against patterns. +// Priority order matters (from OpenClaw classifyFailoverReason). +func classifyByMessage(msg string) FailoverReason { + if matchesAny(msg, rateLimitPatterns) { + return FailoverRateLimit + } + if matchesAny(msg, overloadedPatterns) { + return FailoverRateLimit // Overloaded treated as rate_limit + } + if matchesAny(msg, billingPatterns) { + return FailoverBilling + } + if matchesAny(msg, timeoutPatterns) { + return FailoverTimeout + } + if matchesAny(msg, authPatterns) { + return FailoverAuth + } + if matchesAny(msg, formatPatterns) { + return FailoverFormat + } + return "" +} + +// extractHTTPStatus extracts an HTTP status code from an error message. +// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429". +func extractHTTPStatus(msg string) int { + // Common patterns in Go HTTP error messages + patterns := []*regexp.Regexp{ + regexp.MustCompile(`status[:\s]+(\d{3})`), + regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`), + } + + for _, p := range patterns { + if m := p.FindStringSubmatch(msg); len(m) > 1 { + return parseDigits(m[1]) + } + } + + return 0 +} + +// IsImageDimensionError returns true if the message indicates an image dimension error. +func IsImageDimensionError(msg string) bool { + return matchesAny(msg, imageDimensionPatterns) +} + +// IsImageSizeError returns true if the message indicates an image file size error. +func IsImageSizeError(msg string) bool { + return matchesAny(msg, imageSizePatterns) +} + +// matchesAny checks if msg matches any of the patterns. +func matchesAny(msg string, patterns []errorPattern) bool { + for _, p := range patterns { + if p.regex != nil { + if p.regex.MatchString(msg) { + return true + } + } else if p.substring != "" { + if strings.Contains(msg, p.substring) { + return true + } + } + } + return false +} + +// parseDigits converts a string of digits to an int. +func parseDigits(s string) int { + n := 0 + for _, c := range s { + if c >= '0' && c <= '9' { + n = n*10 + int(c-'0') + } + } + return n +} diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go new file mode 100644 index 000000000..865aea57a --- /dev/null +++ b/pkg/providers/error_classifier_test.go @@ -0,0 +1,337 @@ +package providers + +import ( + "context" + "errors" + "fmt" + "testing" +) + +func TestClassifyError_Nil(t *testing.T) { + result := ClassifyError(nil, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for nil error, got %+v", result) + } +} + +func TestClassifyError_ContextCanceled(t *testing.T) { + result := ClassifyError(context.Canceled, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for context.Canceled (user abort), got %+v", result) + } +} + +func TestClassifyError_ContextDeadlineExceeded(t *testing.T) { + result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil for deadline exceeded") + } + if result.Reason != FailoverTimeout { + t.Errorf("reason = %q, want timeout", result.Reason) + } +} + +func TestClassifyError_StatusCodes(t *testing.T) { + tests := []struct { + status int + reason FailoverReason + }{ + {401, FailoverAuth}, + {403, FailoverAuth}, + {402, FailoverBilling}, + {408, FailoverTimeout}, + {429, FailoverRateLimit}, + {400, FailoverFormat}, + {500, FailoverTimeout}, + {502, FailoverTimeout}, + {503, FailoverTimeout}, + {521, FailoverTimeout}, + {522, FailoverTimeout}, + {523, FailoverTimeout}, + {524, FailoverTimeout}, + {529, FailoverTimeout}, + } + + for _, tt := range tests { + err := fmt.Errorf("API error: status: %d something went wrong", tt.status) + result := ClassifyError(err, "test", "model") + if result == nil { + t.Errorf("status %d: expected non-nil", tt.status) + continue + } + if result.Reason != tt.reason { + t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason) + } + } +} + +func TestClassifyError_RateLimitPatterns(t *testing.T) { + patterns := []string{ + "rate limit exceeded", + "rate_limit reached", + "too many requests", + "exceeded your current quota", + "resource has been exhausted", + "resource_exhausted", + "quota exceeded", + "usage limit reached", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_OverloadedPatterns(t *testing.T) { + patterns := []string{ + "overloaded_error", + `{"type": "overloaded_error"}`, + "server is overloaded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + // Overloaded is treated as rate_limit + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_BillingPatterns(t *testing.T) { + patterns := []string{ + "payment required", + "insufficient credits", + "credit balance too low", + "plans & billing page", + "insufficient balance", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverBilling { + t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason) + } + } +} + +func TestClassifyError_TimeoutPatterns(t *testing.T) { + patterns := []string{ + "request timeout", + "connection timed out", + "deadline exceeded", + "context deadline exceeded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverTimeout { + t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason) + } + } +} + +func TestClassifyError_AuthPatterns(t *testing.T) { + patterns := []string{ + "invalid api key", + "invalid_api_key", + "incorrect api key", + "invalid token", + "authentication failed", + "re-authenticate", + "oauth token refresh failed", + "unauthorized access", + "forbidden", + "access denied", + "expired", + "token has expired", + "no credentials found", + "no api key found", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverAuth { + t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason) + } + } +} + +func TestClassifyError_FormatPatterns(t *testing.T) { + patterns := []string{ + "string should match pattern", + "tool_use.id is required", + "invalid tool_use_id", + "messages.1.content.1.tool_use.id must be valid", + "invalid request format", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverFormat { + t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason) + } + } +} + +func TestClassifyError_ImageDimensionError(t *testing.T) { + err := errors.New("image dimensions exceed max allowed 2048x2048") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image dimension error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } + if result.IsRetriable() { + t.Error("image dimension error should not be retriable") + } +} + +func TestClassifyError_ImageSizeError(t *testing.T) { + err := errors.New("image exceeds 20 mb limit") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image size error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } +} + +func TestClassifyError_UnknownError(t *testing.T) { + err := errors.New("some completely random error") + result := ClassifyError(err, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for unknown error, got %+v", result) + } +} + +func TestClassifyError_ProviderModelPropagation(t *testing.T) { + err := errors.New("rate limit exceeded") + result := ClassifyError(err, "my-provider", "my-model") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Provider != "my-provider" { + t.Errorf("provider = %q, want my-provider", result.Provider) + } + if result.Model != "my-model" { + t.Errorf("model = %q, want my-model", result.Model) + } +} + +func TestFailoverError_IsRetriable(t *testing.T) { + tests := []struct { + reason FailoverReason + retriable bool + }{ + {FailoverAuth, true}, + {FailoverRateLimit, true}, + {FailoverBilling, true}, + {FailoverTimeout, true}, + {FailoverOverloaded, true}, + {FailoverFormat, false}, + {FailoverUnknown, true}, + } + + for _, tt := range tests { + fe := &FailoverError{Reason: tt.reason} + if fe.IsRetriable() != tt.retriable { + t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable) + } + } +} + +func TestFailoverError_ErrorString(t *testing.T) { + fe := &FailoverError{ + Reason: FailoverRateLimit, + Provider: "openai", + Model: "gpt-4", + Status: 429, + Wrapped: errors.New("too many requests"), + } + s := fe.Error() + if s == "" { + t.Error("expected non-empty error string") + } +} + +func TestFailoverError_Unwrap(t *testing.T) { + inner := errors.New("inner error") + fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner} + if fe.Unwrap() != inner { + t.Error("Unwrap should return wrapped error") + } +} + +func TestExtractHTTPStatus(t *testing.T) { + tests := []struct { + msg string + want int + }{ + {"status: 429 rate limited", 429}, + {"status 401 unauthorized", 401}, + {"HTTP/1.1 502 Bad Gateway", 502}, + {"no status code here", 0}, + {"random number 12345", 0}, + } + + for _, tt := range tests { + got := extractHTTPStatus(tt.msg) + if got != tt.want { + t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want) + } + } +} + +func TestIsImageDimensionError(t *testing.T) { + if !IsImageDimensionError("image dimensions exceed max 4096x4096") { + t.Error("should match image dimensions exceed max") + } + if IsImageDimensionError("normal error message") { + t.Error("should not match normal error") + } +} + +func TestIsImageSizeError(t *testing.T) { + if !IsImageSizeError("image exceeds 20 mb") { + t.Error("should match image exceeds mb") + } + if IsImageSizeError("normal error message") { + t.Error("should not match normal error") + } +} diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go new file mode 100644 index 000000000..9b07f9153 --- /dev/null +++ b/pkg/providers/fallback.go @@ -0,0 +1,283 @@ +package providers + +import ( + "context" + "fmt" + "strings" + "time" +) + +// FallbackChain orchestrates model fallback across multiple candidates. +type FallbackChain struct { + cooldown *CooldownTracker +} + +// FallbackCandidate represents one model/provider to try. +type FallbackCandidate struct { + Provider string + Model string +} + +// FallbackResult contains the successful response and metadata about all attempts. +type FallbackResult struct { + Response *LLMResponse + Provider string + Model string + Attempts []FallbackAttempt +} + +// FallbackAttempt records one attempt in the fallback chain. +type FallbackAttempt struct { + Provider string + Model string + Error error + Reason FailoverReason + Duration time.Duration + Skipped bool // true if skipped due to cooldown +} + +// NewFallbackChain creates a new fallback chain with the given cooldown tracker. +func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain { + return &FallbackChain{cooldown: cooldown} +} + +// ResolveCandidates parses model config into a deduplicated candidate list. +func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate { + seen := make(map[string]bool) + var candidates []FallbackCandidate + + addCandidate := func(raw string) { + ref := ParseModelRef(raw, defaultProvider) + if ref == nil { + return + } + key := ModelKey(ref.Provider, ref.Model) + if seen[key] { + return + } + seen[key] = true + candidates = append(candidates, FallbackCandidate{ + Provider: ref.Provider, + Model: ref.Model, + }) + } + + // Primary first. + addCandidate(cfg.Primary) + + // Then fallbacks. + for _, fb := range cfg.Fallbacks { + addCandidate(fb) + } + + return candidates +} + +// Execute runs the fallback chain for text/chat requests. +// It tries each candidate in order, respecting cooldowns and error classification. +// +// Behavior: +// - Candidates in cooldown are skipped (logged as skipped attempt). +// - context.Canceled aborts immediately (user abort, no fallback). +// - Non-retriable errors (format) abort immediately. +// - Retriable errors trigger fallback to next candidate. +// - Success marks provider as good (resets cooldown). +// - If all fail, returns aggregate error with all attempts. +func (fc *FallbackChain) Execute( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + // Check context before each attempt. + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + // Check cooldown. + if !fc.cooldown.IsAvailable(candidate.Provider) { + remaining := fc.cooldown.CooldownRemaining(candidate.Provider) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)), + }) + continue + } + + // Execute the run function. + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + // Success. + fc.cooldown.MarkSuccess(candidate.Provider) + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + // Context cancellation: abort immediately, no fallback. + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Classify the error. + failErr := ClassifyError(err, candidate.Provider, candidate.Model) + + if failErr == nil { + // Unclassifiable error: do not fallback, return immediately. + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w", + candidate.Provider, candidate.Model, err) + } + + // Non-retriable error: abort immediately. + if !failErr.IsRetriable() { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + return nil, failErr + } + + // Retriable error: mark failure and continue to next candidate. + fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + + // If this was the last candidate, return aggregate error. + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + // All candidates were skipped (all in cooldown). + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// ExecuteImage runs the fallback chain for image/vision requests. +// Simpler than Execute: no cooldown checks (image endpoints have different rate limits). +// Image dimension/size errors abort immediately (non-retriable). +func (fc *FallbackChain) ExecuteImage( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("image fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Image dimension/size errors are non-retriable. + errMsg := strings.ToLower(err.Error()) + if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Reason: FailoverFormat, + Duration: elapsed, + }) + return nil, &FailoverError{ + Reason: FailoverFormat, + Provider: candidate.Provider, + Model: candidate.Model, + Wrapped: err, + } + } + + // Any other error: record and try next. + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// FallbackExhaustedError indicates all fallback candidates were tried and failed. +type FallbackExhaustedError struct { + Attempts []FallbackAttempt +} + +func (e *FallbackExhaustedError) Error() string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts))) + for i, a := range e.Attempts { + if a.Skipped { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model)) + } else { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)", + i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond))) + } + } + return sb.String() +} diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go new file mode 100644 index 000000000..ea81e0d48 --- /dev/null +++ b/pkg/providers/fallback_test.go @@ -0,0 +1,473 @@ +package providers + +import ( + "context" + "errors" + "testing" + "time" +) + +func makeCandidate(provider, model string) FallbackCandidate { + return FallbackCandidate{Provider: provider, Model: model} +} + +func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return &LLMResponse{Content: content, FinishReason: "stop"}, nil + } +} + +func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, err + } +} + +func TestFallback_SingleCandidate_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("hello")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "hello" { + t.Errorf("content = %q, want hello", result.Response.Content) + } + if result.Provider != "openai" || result.Model != "gpt-4" { + t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model) + } +} + +func TestFallback_SecondCandidateSuccess(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude-opus"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + if result.Response.Content != "from claude" { + t.Errorf("content = %q, want 'from claude'", result.Response.Content) + } + if len(result.Attempts) != 1 { + t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts)) + } +} + +func TestFallback_AllFail(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + makeCandidate("groq", "llama"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, errors.New("rate limit exceeded") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error when all candidates fail") + } + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err) + } + if len(exhausted.Attempts) != 3 { + t.Errorf("attempts = %d, want 3", len(exhausted.Attempts)) + } +} + +func TestFallback_ContextCanceled(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + ctx, cancel := context.WithCancel(context.Background()) + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + cancel() // cancel context + return nil, context.Canceled + } + t.Error("should not reach second candidate after cancel") + return nil, nil + } + + _, err := fc.Execute(ctx, candidates, run) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) + } +} + +func TestFallback_NonRetriableError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("string should match pattern") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for non-retriable") + } + var fe *FailoverError + if !errors.As(err, &fe) { + t.Fatalf("expected FailoverError, got %T", err) + } + if fe.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", fe.Reason) + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt) + } +} + +func TestFallback_CooldownSkip(t *testing.T) { + now := time.Now() + ct, _ := newTestTracker(now) + fc := NewFallbackChain(ct) + + // Put openai in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + if provider == "openai" { + t.Error("should not call openai (in cooldown)") + } + return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + // Should have 1 skipped attempt + skipped := 0 + for _, a := range result.Attempts { + if a.Skipped { + skipped++ + } + } + if skipped != 1 { + t.Errorf("skipped = %d, want 1", skipped) + } +} + +func TestFallback_AllInCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + // Put all providers in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + _, err := fc.Execute(context.Background(), candidates, + func(ctx context.Context, provider, model string) (*LLMResponse, error) { + t.Error("should not call any provider (all in cooldown)") + return nil, nil + }) + + if err == nil { + t.Fatal("expected error when all in cooldown") + } + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Fatalf("expected FallbackExhaustedError, got %T", err) + } +} + +func TestFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.Execute(context.Background(), nil, successRun("ok")) + if err == nil { + t.Error("expected error for empty candidates") + } +} + +func TestFallback_EmptyFallbacks(t *testing.T) { + // Single primary, no fallbacks: should work like direct call + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("ok")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "ok" { + t.Error("expected success with single candidate") + } +} + +func TestFallback_UnclassifiedError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("completely unknown internal error") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for unclassified error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt) + } +} + +func TestFallback_SuccessResetsCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere + } + return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ct.IsAvailable("openai") { + t.Error("success should reset cooldown") + } +} + +// --- Image Fallback Tests --- + +func TestImageFallback_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")} + result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "image result" { + t.Error("expected image result") + } +} + +func TestImageFallback_DimensionError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image dimensions exceed max 4096x4096") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image dimension error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt) + } +} + +func TestImageFallback_SizeError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image exceeds 20 mb") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image size error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt) + } +} + +func TestImageFallback_RetryOnOtherErrors(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude-sonnet"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil + } + + result, err := fc.ExecuteImage(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } +} + +func TestImageFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.ExecuteImage(context.Background(), nil, successRun("ok")) + if err == nil { + t.Error("expected error for empty candidates") + } +} + +// --- ResolveCandidates Tests --- + +func TestResolveCandidates_Simple(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 3 { + t.Fatalf("candidates = %d, want 3", len(candidates)) + } + + if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" { + t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model) + } + if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" { + t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model) + } + if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" { + t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model) + } +} + +func TestResolveCandidates_Deduplication(t *testing.T) { + cfg := ModelConfig{ + Primary: "openai/gpt-4", + Fallbacks: []string{"openai/gpt-4", "anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "default") + if len(candidates) != 2 { + t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates)) + } +} + +func TestResolveCandidates_EmptyFallbacks(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: nil, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func TestResolveCandidates_EmptyPrimary(t *testing.T) { + cfg := ModelConfig{ + Primary: "", + Fallbacks: []string{"anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func TestFallbackExhaustedError_Message(t *testing.T) { + e := &FallbackExhaustedError{ + Attempts: []FallbackAttempt{ + {Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond}, + {Provider: "anthropic", Model: "claude", Skipped: true}, + }, + } + msg := e.Error() + if msg == "" { + t.Error("expected non-empty error message") + } +} diff --git a/pkg/providers/model_ref.go b/pkg/providers/model_ref.go new file mode 100644 index 000000000..0d1b02d16 --- /dev/null +++ b/pkg/providers/model_ref.go @@ -0,0 +1,64 @@ +package providers + +import "strings" + +// ModelRef represents a parsed model reference with provider and model name. +type ModelRef struct { + Provider string + Model string +} + +// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}. +// If no slash present, uses defaultProvider. +// Returns nil for empty input. +func ParseModelRef(raw string, defaultProvider string) *ModelRef { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + if idx := strings.Index(raw, "/"); idx > 0 { + provider := NormalizeProvider(raw[:idx]) + model := strings.TrimSpace(raw[idx+1:]) + if model == "" { + return nil + } + return &ModelRef{Provider: provider, Model: model} + } + + return &ModelRef{ + Provider: NormalizeProvider(defaultProvider), + Model: raw, + } +} + +// NormalizeProvider normalizes provider identifiers to canonical form. +func NormalizeProvider(provider string) string { + p := strings.ToLower(strings.TrimSpace(provider)) + + switch p { + case "z.ai", "z-ai": + return "zai" + case "opencode-zen": + return "opencode" + case "qwen": + return "qwen-portal" + case "kimi-code": + return "kimi-coding" + case "gpt": + return "openai" + case "claude": + return "anthropic" + case "glm": + return "zhipu" + case "google": + return "gemini" + } + + return p +} + +// ModelKey returns a canonical "provider/model" key for deduplication. +func ModelKey(provider, model string) string { + return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model)) +} diff --git a/pkg/providers/model_ref_test.go b/pkg/providers/model_ref_test.go new file mode 100644 index 000000000..6dd25167f --- /dev/null +++ b/pkg/providers/model_ref_test.go @@ -0,0 +1,125 @@ +package providers + +import "testing" + +func TestParseModelRef_WithSlash(t *testing.T) { + ref := ParseModelRef("anthropic/claude-opus", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", ref.Provider) + } + if ref.Model != "claude-opus" { + t.Errorf("model = %q, want claude-opus", ref.Model) + } +} + +func TestParseModelRef_WithoutSlash(t *testing.T) { + ref := ParseModelRef("gpt-4", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "openai" { + t.Errorf("provider = %q, want openai", ref.Provider) + } + if ref.Model != "gpt-4" { + t.Errorf("model = %q, want gpt-4", ref.Model) + } +} + +func TestParseModelRef_Empty(t *testing.T) { + ref := ParseModelRef("", "openai") + if ref != nil { + t.Errorf("expected nil for empty string, got %+v", ref) + } +} + +func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) { + ref := ParseModelRef("openai/", "default") + if ref != nil { + t.Errorf("expected nil for empty model, got %+v", ref) + } +} + +func TestParseModelRef_WhitespaceHandling(t *testing.T) { + ref := ParseModelRef(" anthropic / claude-opus ", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", ref.Provider) + } + if ref.Model != "claude-opus" { + t.Errorf("model = %q, want claude-opus", ref.Model) + } +} + +func TestNormalizeProvider(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"OpenAI", "openai"}, + {"ANTHROPIC", "anthropic"}, + {"z.ai", "zai"}, + {"z-ai", "zai"}, + {"Z.AI", "zai"}, + {"opencode-zen", "opencode"}, + {"qwen", "qwen-portal"}, + {"kimi-code", "kimi-coding"}, + {"gpt", "openai"}, + {"claude", "anthropic"}, + {"glm", "zhipu"}, + {"google", "gemini"}, + {"groq", "groq"}, + {"", ""}, + } + + for _, tt := range tests { + got := NormalizeProvider(tt.input) + if got != tt.want { + t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestModelKey(t *testing.T) { + tests := []struct { + provider string + model string + want string + }{ + {"openai", "gpt-4", "openai/gpt-4"}, + {"Anthropic", "Claude-Opus", "anthropic/claude-opus"}, + {"claude", "sonnet", "anthropic/sonnet"}, + {"z.ai", "Model-X", "zai/model-x"}, + } + + for _, tt := range tests { + got := ModelKey(tt.provider, tt.model) + if got != tt.want { + t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want) + } + } +} + +func TestParseModelRef_ProviderNormalization(t *testing.T) { + ref := ParseModelRef("Z.AI/model-x", "default") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "zai" { + t.Errorf("provider = %q, want zai", ref.Provider) + } +} + +func TestParseModelRef_DefaultProviderNormalization(t *testing.T) { + ref := ParseModelRef("gpt-4o", "GPT") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "openai" { + t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider) + } +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 221a842fa..c4a9de58a 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -2,6 +2,7 @@ package providers import ( "context" + "fmt" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -18,3 +19,46 @@ type LLMProvider interface { Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) GetDefaultModel() string } + +// FailoverReason classifies why an LLM request failed for fallback decisions. +type FailoverReason string + +const ( + FailoverAuth FailoverReason = "auth" + FailoverRateLimit FailoverReason = "rate_limit" + FailoverBilling FailoverReason = "billing" + FailoverTimeout FailoverReason = "timeout" + FailoverFormat FailoverReason = "format" + FailoverOverloaded FailoverReason = "overloaded" + FailoverUnknown FailoverReason = "unknown" +) + +// FailoverError wraps an LLM provider error with classification metadata. +type FailoverError struct { + Reason FailoverReason + Provider string + Model string + Status int + Wrapped error +} + +func (e *FailoverError) Error() string { + return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v", + e.Reason, e.Provider, e.Model, e.Status, e.Wrapped) +} + +func (e *FailoverError) Unwrap() error { + return e.Wrapped +} + +// IsRetriable returns true if this error should trigger fallback to next candidate. +// Non-retriable: Format errors (bad request structure, image dimension/size). +func (e *FailoverError) IsRetriable() bool { + return e.Reason != FailoverFormat +} + +// ModelConfig holds primary model and fallback list. +type ModelConfig struct { + Primary string + Fallbacks []string +} diff --git a/pkg/routing/agent_id.go b/pkg/routing/agent_id.go new file mode 100644 index 000000000..bcf2f0dc0 --- /dev/null +++ b/pkg/routing/agent_id.go @@ -0,0 +1,66 @@ +package routing + +import ( + "regexp" + "strings" +) + +const ( + DefaultAgentID = "main" + DefaultMainKey = "main" + DefaultAccountID = "default" + MaxAgentIDLength = 64 +) + +var ( + validIDRe = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`) + invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`) + leadingDashRe = regexp.MustCompile(`^-+`) + trailingDashRe = regexp.MustCompile(`-+$`) +) + +// NormalizeAgentID sanitizes an agent ID to [a-z0-9][a-z0-9_-]{0,63}. +// Invalid characters are collapsed to "-". Leading/trailing dashes stripped. +// Empty input returns DefaultAgentID ("main"). +func NormalizeAgentID(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return DefaultAgentID + } + lower := strings.ToLower(trimmed) + if validIDRe.MatchString(lower) { + return lower + } + result := invalidCharsRe.ReplaceAllString(lower, "-") + result = leadingDashRe.ReplaceAllString(result, "") + result = trailingDashRe.ReplaceAllString(result, "") + if len(result) > MaxAgentIDLength { + result = result[:MaxAgentIDLength] + } + if result == "" { + return DefaultAgentID + } + return result +} + +// NormalizeAccountID sanitizes an account ID. Empty returns DefaultAccountID. +func NormalizeAccountID(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return DefaultAccountID + } + lower := strings.ToLower(trimmed) + if validIDRe.MatchString(lower) { + return lower + } + result := invalidCharsRe.ReplaceAllString(lower, "-") + result = leadingDashRe.ReplaceAllString(result, "") + result = trailingDashRe.ReplaceAllString(result, "") + if len(result) > MaxAgentIDLength { + result = result[:MaxAgentIDLength] + } + if result == "" { + return DefaultAccountID + } + return result +} diff --git a/pkg/routing/agent_id_test.go b/pkg/routing/agent_id_test.go new file mode 100644 index 000000000..050fe0645 --- /dev/null +++ b/pkg/routing/agent_id_test.go @@ -0,0 +1,86 @@ +package routing + +import "testing" + +func TestNormalizeAgentID_Empty(t *testing.T) { + if got := NormalizeAgentID(""); got != DefaultAgentID { + t.Errorf("NormalizeAgentID('') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_Whitespace(t *testing.T) { + if got := NormalizeAgentID(" "); got != DefaultAgentID { + t.Errorf("NormalizeAgentID(' ') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_Valid(t *testing.T) { + tests := []struct { + input, want string + }{ + {"main", "main"}, + {"Main", "main"}, + {"SALES", "sales"}, + {"support-bot", "support-bot"}, + {"agent_1", "agent_1"}, + {"a", "a"}, + {"0test", "0test"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_InvalidChars(t *testing.T) { + tests := []struct { + input, want string + }{ + {"Hello World", "hello-world"}, + {"agent@123", "agent-123"}, + {"foo.bar.baz", "foo-bar-baz"}, + {"--leading", "leading"}, + {"--both--", "both"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_AllInvalid(t *testing.T) { + if got := NormalizeAgentID("@@@"); got != DefaultAgentID { + t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_TruncatesAt64(t *testing.T) { + long := "" + for i := 0; i < 100; i++ { + long += "a" + } + got := NormalizeAgentID(long) + if len(got) > MaxAgentIDLength { + t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength) + } +} + +func TestNormalizeAccountID_Empty(t *testing.T) { + if got := NormalizeAccountID(""); got != DefaultAccountID { + t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID) + } +} + +func TestNormalizeAccountID_Valid(t *testing.T) { + if got := NormalizeAccountID("MyBot"); got != "mybot" { + t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got) + } +} + +func TestNormalizeAccountID_InvalidChars(t *testing.T) { + if got := NormalizeAccountID("bot@home"); got != "bot-home" { + t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got) + } +} diff --git a/pkg/routing/route.go b/pkg/routing/route.go new file mode 100644 index 000000000..9eb060c53 --- /dev/null +++ b/pkg/routing/route.go @@ -0,0 +1,252 @@ +package routing + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// RouteInput contains the routing context from an inbound message. +type RouteInput struct { + Channel string + AccountID string + Peer *RoutePeer + ParentPeer *RoutePeer + GuildID string + TeamID string +} + +// ResolvedRoute is the result of agent routing. +type ResolvedRoute struct { + AgentID string + Channel string + AccountID string + SessionKey string + MainSessionKey string + MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" +} + +// RouteResolver determines which agent handles a message based on config bindings. +type RouteResolver struct { + cfg *config.Config +} + +// NewRouteResolver creates a new route resolver. +func NewRouteResolver(cfg *config.Config) *RouteResolver { + return &RouteResolver{cfg: cfg} +} + +// ResolveRoute determines which agent handles the message and constructs session keys. +// Implements the 7-level priority cascade: +// peer > parent_peer > guild > team > account > channel_wildcard > default +func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { + channel := strings.ToLower(strings.TrimSpace(input.Channel)) + accountID := NormalizeAccountID(input.AccountID) + peer := input.Peer + + dmScope := DMScope(r.cfg.Session.DMScope) + if dmScope == "" { + dmScope = DMScopeMain + } + identityLinks := r.cfg.Session.IdentityLinks + + bindings := r.filterBindings(channel, accountID) + + choose := func(agentID string, matchedBy string) ResolvedRoute { + resolvedAgentID := r.pickAgentID(agentID) + sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + Peer: peer, + DMScope: dmScope, + IdentityLinks: identityLinks, + })) + mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID)) + return ResolvedRoute{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + SessionKey: sessionKey, + MainSessionKey: mainSessionKey, + MatchedBy: matchedBy, + } + } + + // Priority 1: Peer binding + if peer != nil && strings.TrimSpace(peer.ID) != "" { + if match := r.findPeerMatch(bindings, peer); match != nil { + return choose(match.AgentID, "binding.peer") + } + } + + // Priority 2: Parent peer binding + parentPeer := input.ParentPeer + if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" { + if match := r.findPeerMatch(bindings, parentPeer); match != nil { + return choose(match.AgentID, "binding.peer.parent") + } + } + + // Priority 3: Guild binding + guildID := strings.TrimSpace(input.GuildID) + if guildID != "" { + if match := r.findGuildMatch(bindings, guildID); match != nil { + return choose(match.AgentID, "binding.guild") + } + } + + // Priority 4: Team binding + teamID := strings.TrimSpace(input.TeamID) + if teamID != "" { + if match := r.findTeamMatch(bindings, teamID); match != nil { + return choose(match.AgentID, "binding.team") + } + } + + // Priority 5: Account binding + if match := r.findAccountMatch(bindings); match != nil { + return choose(match.AgentID, "binding.account") + } + + // Priority 6: Channel wildcard binding + if match := r.findChannelWildcardMatch(bindings); match != nil { + return choose(match.AgentID, "binding.channel") + } + + // Priority 7: Default agent + return choose(r.resolveDefaultAgentID(), "default") +} + +func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding { + var filtered []config.AgentBinding + for _, b := range r.cfg.Bindings { + matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel)) + if matchChannel == "" || matchChannel != channel { + continue + } + if !matchesAccountID(b.Match.AccountID, accountID) { + continue + } + filtered = append(filtered, b) + } + return filtered +} + +func matchesAccountID(matchAccountID, actual string) bool { + trimmed := strings.TrimSpace(matchAccountID) + if trimmed == "" { + return actual == DefaultAccountID + } + if trimmed == "*" { + return true + } + return strings.ToLower(trimmed) == strings.ToLower(actual) +} + +func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + if b.Match.Peer == nil { + continue + } + peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind)) + peerID := strings.TrimSpace(b.Match.Peer.ID) + if peerKind == "" || peerID == "" { + continue + } + if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID { + return b + } + } + return nil +} + +func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchGuild := strings.TrimSpace(b.Match.GuildID) + if matchGuild != "" && matchGuild == guildID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchTeam := strings.TrimSpace(b.Match.TeamID) + if matchTeam != "" && matchTeam == teamID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID == "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID != "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) pickAgentID(agentID string) string { + trimmed := strings.TrimSpace(agentID) + if trimmed == "" { + return NormalizeAgentID(r.resolveDefaultAgentID()) + } + normalized := NormalizeAgentID(trimmed) + agents := r.cfg.Agents.List + if len(agents) == 0 { + return normalized + } + for _, a := range agents { + if NormalizeAgentID(a.ID) == normalized { + return normalized + } + } + return NormalizeAgentID(r.resolveDefaultAgentID()) +} + +func (r *RouteResolver) resolveDefaultAgentID() string { + agents := r.cfg.Agents.List + if len(agents) == 0 { + return DefaultAgentID + } + for _, a := range agents { + if a.Default { + id := strings.TrimSpace(a.ID) + if id != "" { + return NormalizeAgentID(id) + } + } + } + if id := strings.TrimSpace(agents[0].ID); id != "" { + return NormalizeAgentID(id) + } + return DefaultAgentID +} diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go new file mode 100644 index 000000000..8255db5f9 --- /dev/null +++ b/pkg/routing/route_test.go @@ -0,0 +1,297 @@ +package routing + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: "/tmp/picoclaw-test", + Model: "gpt-4", + }, + List: agents, + }, + Bindings: bindings, + Session: config.SessionConfig{ + DMScope: "per-peer", + }, + } +} + +func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { + cfg := testConfig(nil, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != DefaultAgentID { + t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID) + } + if route.MatchedBy != "default" { + t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy) + } +} + +func TestResolveRoute_PeerBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "sales", Default: true}, + {ID: "support"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "support", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user123"}, + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + }) + + if route.AgentID != "support" { + t.Errorf("AgentID = %q, want 'support'", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_GuildBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-abc", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-abc", + Peer: &RoutePeer{Kind: "channel", ID: "ch1"}, + }) + + if route.AgentID != "gaming" { + t.Errorf("AgentID = %q, want 'gaming'", route.AgentID) + } + if route.MatchedBy != "binding.guild" { + t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy) + } +} + +func TestResolveRoute_TeamBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "work"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "work", + Match: config.BindingMatch{ + Channel: "slack", + AccountID: "*", + TeamID: "T12345", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "slack", + TeamID: "T12345", + Peer: &RoutePeer{Kind: "channel", ID: "C001"}, + }) + + if route.AgentID != "work" { + t.Errorf("AgentID = %q, want 'work'", route.AgentID) + } + if route.MatchedBy != "binding.team" { + t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy) + } +} + +func TestResolveRoute_AccountBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "default-agent", Default: true}, + {ID: "premium"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "premium", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "bot2", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + AccountID: "bot2", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "premium" { + t.Errorf("AgentID = %q, want 'premium'", route.AgentID) + } + if route.MatchedBy != "binding.account" { + t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy) + } +} + +func TestResolveRoute_ChannelWildcard(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "telegram-bot"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "telegram-bot", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "telegram-bot" { + t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID) + } + if route.MatchedBy != "binding.channel" { + t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy) + } +} + +func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "vip"}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "vip", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"}, + }, + }, + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-1", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-1", + Peer: &RoutePeer{Kind: "direct", ID: "user-vip"}, + }) + + if route.AgentID != "vip" { + t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + } + bindings := []config.AgentBinding{ + { + AgentID: "nonexistent", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + }) + + if route.AgentID != "main" { + t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID) + } +} + +func TestResolveRoute_DefaultAgentSelection(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + {ID: "gamma"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "beta" { + t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID) + } +} + +func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "alpha" { + t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID) + } +} diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go new file mode 100644 index 000000000..e12f0d1d8 --- /dev/null +++ b/pkg/routing/session_key.go @@ -0,0 +1,183 @@ +package routing + +import ( + "fmt" + "strings" +) + +// DMScope controls DM session isolation granularity. +type DMScope string + +const ( + DMScopeMain DMScope = "main" + DMScopePerPeer DMScope = "per-peer" + DMScopePerChannelPeer DMScope = "per-channel-peer" + DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer" +) + +// RoutePeer represents a chat peer with kind and ID. +type RoutePeer struct { + Kind string // "direct", "group", "channel" + ID string +} + +// SessionKeyParams holds all inputs for session key construction. +type SessionKeyParams struct { + AgentID string + Channel string + AccountID string + Peer *RoutePeer + DMScope DMScope + IdentityLinks map[string][]string +} + +// ParsedSessionKey is the result of parsing an agent-scoped session key. +type ParsedSessionKey struct { + AgentID string + Rest string +} + +// BuildAgentMainSessionKey returns "agent::main". +func BuildAgentMainSessionKey(agentID string) string { + return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey) +} + +// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope. +func BuildAgentPeerSessionKey(params SessionKeyParams) string { + agentID := NormalizeAgentID(params.AgentID) + + peer := params.Peer + if peer == nil { + peer = &RoutePeer{Kind: "direct"} + } + peerKind := strings.TrimSpace(peer.Kind) + if peerKind == "" { + peerKind = "direct" + } + + if peerKind == "direct" { + dmScope := params.DMScope + if dmScope == "" { + dmScope = DMScopeMain + } + peerID := strings.TrimSpace(peer.ID) + + // Resolve identity links (cross-platform collapse) + if dmScope != DMScopeMain && peerID != "" { + if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" { + peerID = linked + } + } + peerID = strings.ToLower(peerID) + + switch dmScope { + case DMScopePerAccountChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + accountID := NormalizeAccountID(params.AccountID) + return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID) + } + case DMScopePerChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID) + } + case DMScopePerPeer: + if peerID != "" { + return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID) + } + } + return BuildAgentMainSessionKey(agentID) + } + + // Group/channel peers always get per-peer sessions + channel := normalizeChannel(params.Channel) + peerID := strings.ToLower(strings.TrimSpace(peer.ID)) + if peerID == "" { + peerID = "unknown" + } + return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) +} + +// ParseAgentSessionKey extracts agentId and rest from "agent::". +func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return nil + } + parts := strings.SplitN(raw, ":", 3) + if len(parts) < 3 { + return nil + } + if parts[0] != "agent" { + return nil + } + agentID := strings.TrimSpace(parts[1]) + rest := parts[2] + if agentID == "" || rest == "" { + return nil + } + return &ParsedSessionKey{AgentID: agentID, Rest: rest} +} + +// IsSubagentSessionKey returns true if the session key represents a subagent. +func IsSubagentSessionKey(sessionKey string) bool { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return false + } + if strings.HasPrefix(strings.ToLower(raw), "subagent:") { + return true + } + parsed := ParseAgentSessionKey(raw) + if parsed == nil { + return false + } + return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:") +} + +func normalizeChannel(channel string) string { + c := strings.TrimSpace(strings.ToLower(channel)) + if c == "" { + return "unknown" + } + return c +} + +func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]bool) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = true + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) + candidates[scopedCandidate] = true + } + if len(candidates) == 0 { + return "" + } + + for canonical, ids := range identityLinks { + canonicalName := strings.TrimSpace(canonical) + if canonicalName == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized != "" && candidates[normalized] { + return canonicalName + } + } + } + return "" +} diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go new file mode 100644 index 000000000..81e4ce018 --- /dev/null +++ b/pkg/routing/session_key_test.go @@ -0,0 +1,162 @@ +package routing + +import "testing" + +func TestBuildAgentMainSessionKey(t *testing.T) { + got := BuildAgentMainSessionKey("sales") + want := "agent:sales:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want) + } +} + +func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) { + got := BuildAgentMainSessionKey("Sales Bot") + want := "agent:sales-bot:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopeMain, + }) + want := "agent:main:main" + if got != want { + t.Errorf("DMScopeMain = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:direct:user123" + if got != want { + t.Errorf("DMScopePerPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerChannelPeer, + }) + want := "agent:main:telegram:direct:user123" + if got != want { + t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + AccountID: "bot1", + Peer: &RoutePeer{Kind: "direct", ID: "User123"}, + DMScope: DMScopePerAccountChannelPeer, + }) + want := "agent:main:telegram:bot1:direct:user123" + if got != want { + t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "group", ID: "chat456"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:telegram:group:chat456" + if got != want { + t.Errorf("GroupPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: nil, + DMScope: DMScopePerPeer, + }) + // nil peer defaults to direct with empty ID, falls to main + want := "agent:main:main" + if got != want { + t.Errorf("NilPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { + links := map[string][]string{ + "john": {"telegram:user123", "discord:john#1234"}, + } + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + IdentityLinks: links, + }) + want := "agent:main:direct:john" + if got != want { + t.Errorf("IdentityLink = %q, want %q", got, want) + } +} + +func TestParseAgentSessionKey_Valid(t *testing.T) { + parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") + if parsed == nil { + t.Fatal("expected non-nil result") + } + if parsed.AgentID != "sales" { + t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID) + } + if parsed.Rest != "telegram:direct:user123" { + t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest) + } +} + +func TestParseAgentSessionKey_Invalid(t *testing.T) { + tests := []string{ + "", + "foo:bar", + "notprefix:sales:main", + "agent::main", + "agent:sales:", + } + for _, input := range tests { + if got := ParseAgentSessionKey(input); got != nil { + t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got) + } + } +} + +func TestIsSubagentSessionKey(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"subagent:task-1", true}, + {"agent:main:subagent:task-1", true}, + {"agent:main:main", false}, + {"agent:main:telegram:direct:user123", false}, + {"", false}, + } + for _, tt := range tests { + if got := IsSubagentSessionKey(tt.input); got != tt.want { + t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 42dd36a33..f01372467 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -6,10 +6,11 @@ import ( ) type SpawnTool struct { - manager *SubagentManager - originChannel string - originChatID string - callback AsyncCallback // For async completion notification + manager *SubagentManager + originChannel string + originChatID string + allowlistCheck func(targetAgentID string) bool + callback AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -45,6 +46,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} { "type": "string", "description": "Optional short label for the task (for display)", }, + "agent_id": map[string]interface{}{ + "type": "string", + "description": "Optional target agent ID to delegate the task to", + }, }, "required": []string{"task"}, } @@ -55,6 +60,10 @@ func (t *SpawnTool) SetContext(channel, chatID string) { t.originChatID = chatID } +func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { + t.allowlistCheck = check +} + func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { task, ok := args["task"].(string) if !ok { @@ -62,13 +71,21 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T } label, _ := args["label"].(string) + agentID, _ := args["agent_id"].(string) + + // Check allowlist if targeting a specific agent + if agentID != "" && t.allowlistCheck != nil { + if !t.allowlistCheck(agentID) { + return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID)) + } + } if t.manager == nil { return ErrorResult("Subagent manager not configured") } // Pass callback to manager for async completion notification - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback) + result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback) if err != nil { return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index efa1d33aa..2fc7162d0 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -14,6 +14,7 @@ type SubagentTask struct { ID string Task string Label string + AgentID string OriginChannel string OriginChatID string Status string @@ -61,7 +62,7 @@ func (sm *SubagentManager) RegisterTool(tool Tool) { sm.tools.Register(tool) } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) { +func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -72,6 +73,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel ID: taskID, Task: task, Label: label, + AgentID: agentID, OriginChannel: originChannel, OriginChatID: originChatID, Status: "running",