diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go new file mode 100644 index 000000000..5eb0630b5 --- /dev/null +++ b/pkg/agent/instance.go @@ -0,0 +1,144 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// AgentInstance represents a fully configured agent with its own workspace, +// session manager, context builder, and tool registry. +type AgentInstance struct { + ID string + Name string + Model string + Fallbacks []string + Workspace string + MaxIterations int + ContextWindow int + Provider providers.LLMProvider + Sessions *session.SessionManager + ContextBuilder *ContextBuilder + Tools *tools.ToolRegistry + Subagents *config.SubagentsConfig + SkillsFilter []string + Candidates []providers.FallbackCandidate +} + +// NewAgentInstance creates an agent instance from config. +func NewAgentInstance( + agentCfg *config.AgentConfig, + defaults *config.AgentDefaults, + provider providers.LLMProvider, +) *AgentInstance { + workspace := resolveAgentWorkspace(agentCfg, defaults) + os.MkdirAll(workspace, 0755) + + model := resolveAgentModel(agentCfg, defaults) + fallbacks := resolveAgentFallbacks(agentCfg, defaults) + + restrict := defaults.RestrictToWorkspace + toolsRegistry := tools.NewToolRegistry() + toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) + toolsRegistry.Register(tools.NewExecTool(workspace, restrict)) + toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) + + sessionsDir := filepath.Join(workspace, "sessions") + sessionsManager := session.NewSessionManager(sessionsDir) + + contextBuilder := NewContextBuilder(workspace) + contextBuilder.SetToolsRegistry(toolsRegistry) + + agentID := routing.DefaultAgentID + agentName := "" + var subagents *config.SubagentsConfig + var skillsFilter []string + + if agentCfg != nil { + agentID = routing.NormalizeAgentID(agentCfg.ID) + agentName = agentCfg.Name + subagents = agentCfg.Subagents + skillsFilter = agentCfg.Skills + } + + maxIter := defaults.MaxToolIterations + if maxIter == 0 { + maxIter = 20 + } + + // Resolve fallback candidates + modelCfg := providers.ModelConfig{ + Primary: model, + Fallbacks: fallbacks, + } + candidates := providers.ResolveCandidates(modelCfg, defaults.Provider) + + return &AgentInstance{ + ID: agentID, + Name: agentName, + Model: model, + Fallbacks: fallbacks, + Workspace: workspace, + MaxIterations: maxIter, + ContextWindow: defaults.MaxTokens, + Provider: provider, + Sessions: sessionsManager, + ContextBuilder: contextBuilder, + Tools: toolsRegistry, + Subagents: subagents, + SkillsFilter: skillsFilter, + Candidates: candidates, + } +} + +// resolveAgentWorkspace determines the workspace directory for an agent. +func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" { + return expandHome(strings.TrimSpace(agentCfg.Workspace)) + } + if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" { + return expandHome(defaults.Workspace) + } + home, _ := os.UserHomeDir() + id := routing.NormalizeAgentID(agentCfg.ID) + return filepath.Join(home, ".picoclaw", "workspace-"+id) +} + +// resolveAgentModel resolves the primary model for an agent. +func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" { + return strings.TrimSpace(agentCfg.Model.Primary) + } + return defaults.Model +} + +// resolveAgentFallbacks resolves the fallback models for an agent. +func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string { + if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil { + return agentCfg.Model.Fallbacks + } + return defaults.ModelFallbacks +} + +func expandHome(path string) string { + if path == "" { + return path + } + if path[0] == '~' { + home, _ := os.UserHomeDir() + if len(path) > 1 && path[1] == '/' { + return home + path[1:] + } + return home + } + return path +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index fac2856e9..ffc2191e3 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,8 +10,6 @@ import ( "context" "encoding/json" "fmt" - "os" - "path/filepath" "strings" "sync" "sync/atomic" @@ -21,23 +19,18 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" - "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" ) type AgentLoop struct { - bus *bus.MessageBus - provider providers.LLMProvider - workspace string - model string - contextWindow int // Maximum context window size in tokens - maxIterations int - sessions *session.SessionManager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry - running atomic.Bool - summarizing sync.Map // Tracks which sessions are currently being summarized + bus *bus.MessageBus + cfg *config.Config + registry *AgentRegistry + running atomic.Bool + summarizing sync.Map + fallback *providers.FallbackChain } // processOptions configures how a message is processed @@ -52,60 +45,61 @@ type processOptions struct { } func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) + registry := NewAgentRegistry(cfg, provider) - restrict := cfg.Agents.Defaults.RestrictToWorkspace + // Register shared tools to all agents + registerSharedTools(cfg, msgBus, registry, provider) - toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) - toolsRegistry.Register(tools.NewExecTool(workspace, restrict)) - - braveAPIKey := cfg.Tools.Web.Search.APIKey - toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) - toolsRegistry.Register(tools.NewWebFetchTool(50000)) - - // Register message tool - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, - }) - return nil - }) - toolsRegistry.Register(messageTool) - - // Register spawn tool - subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) - spawnTool := tools.NewSpawnTool(subagentManager) - toolsRegistry.Register(spawnTool) - - // Register edit file tool - editFileTool := tools.NewEditFileTool(workspace, restrict) - toolsRegistry.Register(editFileTool) - toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) - - sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) - - // Create context builder and set tools registry - contextBuilder := NewContextBuilder(workspace) - contextBuilder.SetToolsRegistry(toolsRegistry) + // Set up shared fallback chain + cooldown := providers.NewCooldownTracker() + fallbackChain := providers.NewFallbackChain(cooldown) return &AgentLoop{ - bus: msgBus, - provider: provider, - workspace: workspace, - model: cfg.Agents.Defaults.Model, - contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - contextBuilder: contextBuilder, - tools: toolsRegistry, - summarizing: sync.Map{}, + bus: msgBus, + cfg: cfg, + registry: registry, + summarizing: sync.Map{}, + fallback: fallbackChain, + } +} + +// registerSharedTools registers tools that are shared across all agents (web, message, spawn). +func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider) { + braveAPIKey := cfg.Tools.Web.Search.APIKey + + for _, agentID := range registry.ListAgentIDs() { + agent, ok := registry.GetAgent(agentID) + if !ok { + continue + } + + // Web tools + agent.Tools.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) + agent.Tools.Register(tools.NewWebFetchTool(50000)) + + // Message tool + messageTool := tools.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + return nil + }) + agent.Tools.Register(messageTool) + + // Spawn tool with allowlist checker + subagentManager := tools.NewSubagentManager(provider, agent.Workspace, msgBus) + spawnTool := tools.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) + + // Update context builder with the complete tools registry + agent.ContextBuilder.SetToolsRegistry(agent.Tools) } } @@ -145,7 +139,11 @@ func (al *AgentLoop) Stop() { } func (al *AgentLoop) RegisterTool(tool tools.Tool) { - al.tools.Register(tool) + for _, agentID := range al.registry.ListAgentIDs() { + if agent, ok := al.registry.GetAgent(agentID); ok { + agent.Tools.Register(tool) + } + } } func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { @@ -165,7 +163,6 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess } func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Add message preview to log preview := utils.Truncate(msg.Content, 80) logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), map[string]interface{}{ @@ -180,9 +177,36 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } - // Process as user message - return al.runAgentLoop(ctx, processOptions{ - SessionKey: msg.SessionKey, + // Route to determine agent and session key + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + AccountID: msg.Metadata["account_id"], + Peer: extractPeer(msg), + ParentPeer: extractParentPeer(msg), + GuildID: msg.Metadata["guild_id"], + TeamID: msg.Metadata["team_id"], + }) + + agent, ok := al.registry.GetAgent(route.AgentID) + if !ok { + agent = al.registry.GetDefaultAgent() + } + + // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) + sessionKey := route.SessionKey + if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { + sessionKey = msg.SessionKey + } + + logger.InfoCF("agent", "Routed message", + map[string]interface{}{ + "agent_id": agent.ID, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + }) + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, @@ -193,7 +217,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Verify this is a system message if msg.Channel != "system" { return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel) } @@ -210,36 +233,36 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe originChannel = msg.ChatID[:idx] originChatID = msg.ChatID[idx+1:] } else { - // Fallback originChannel = "cli" originChatID = msg.ChatID } - // Use the origin session for context - sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) + // Use default agent for system messages + agent := al.registry.GetDefaultAgent() - // Process as system message with routing back to origin - return al.runAgentLoop(ctx, processOptions{ + // Use the origin session for context + sessionKey := routing.BuildAgentMainSessionKey(agent.ID) + + return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: sessionKey, Channel: originChannel, ChatID: originChatID, UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), DefaultResponse: "Background task completed.", EnableSummary: false, - SendResponse: true, // Send response back to original channel + SendResponse: true, }) } // runAgentLoop is the core message processing logic. -// It handles context building, LLM calls, tool execution, and response handling. -func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { +func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) { // 1. Update tool contexts - al.updateToolContexts(opts.Channel, opts.ChatID) + al.updateToolContexts(agent, opts.Channel, opts.ChatID) // 2. Build messages - history := al.sessions.GetHistory(opts.SessionKey) - summary := al.sessions.GetSummary(opts.SessionKey) - messages := al.contextBuilder.BuildMessages( + history := agent.Sessions.GetHistory(opts.SessionKey) + summary := agent.Sessions.GetSummary(opts.SessionKey) + messages := agent.ContextBuilder.BuildMessages( history, summary, opts.UserMessage, @@ -249,10 +272,10 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str ) // 3. Save user message to session - al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) // 4. Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts) + finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err } @@ -263,12 +286,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // 6. Save final assistant message to session - al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey)) + agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) + agent.Sessions.Save(agent.Sessions.GetOrCreate(opts.SessionKey)) // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(opts.SessionKey) + al.maybeSummarize(agent, opts.SessionKey) } // 8. Optional: send response via bus @@ -284,6 +307,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str responsePreview := utils.Truncate(finalContent, 120) logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), map[string]interface{}{ + "agent_id": agent.ID, "session_key": opts.SessionKey, "iterations": iteration, "final_length": len(finalContent), @@ -293,22 +317,22 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // runLLMIteration executes the LLM call loop with tool handling. -// Returns the final content, iteration count, and any error. -func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) { +func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions) (string, int, error) { iteration := 0 var finalContent string - for iteration < al.maxIterations { + for iteration < agent.MaxIterations { iteration++ logger.DebugCF("agent", "LLM iteration", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, - "max": al.maxIterations, + "max": agent.MaxIterations, }) // Build tool definitions - toolDefs := al.tools.GetDefinitions() + toolDefs := agent.Tools.GetDefinitions() providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) for _, td := range toolDefs { providerToolDefs = append(providerToolDefs, providers.ToolDefinition{ @@ -324,8 +348,9 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M // Log LLM request details logger.DebugCF("agent", "LLM request", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, - "model": al.model, + "model": agent.Model, "messages_count": len(messages), "tools_count": len(providerToolDefs), "max_tokens": 8192, @@ -341,15 +366,40 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM - response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, - }) + // Call LLM with fallback chain if candidates are configured. + var response *providers.LLMResponse + var err error + + if len(agent.Candidates) > 1 && al.fallback != nil { + fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, + func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { + return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + }, + ) + if fbErr != nil { + err = fbErr + } else { + response = fbResult.Response + if fbResult.Provider != "" && len(fbResult.Attempts) > 0 { + logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", + fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), + map[string]interface{}{"agent_id": agent.ID, "iteration": iteration}) + } + } + } else { + response, err = agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + } if err != nil { logger.ErrorCF("agent", "LLM call failed", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, "error": err.Error(), }) @@ -361,6 +411,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M finalContent = response.Content logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, "content_chars": len(finalContent), }) @@ -374,6 +425,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } logger.InfoCF("agent", "LLM requested tool calls", map[string]interface{}{ + "agent_id": agent.ID, "tools": toolNames, "count": len(toolNames), "iteration": iteration, @@ -398,20 +450,20 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M messages = append(messages, assistantMsg) // Save assistant message with tool calls to session - al.sessions.AddFullMessage(opts.SessionKey, assistantMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) // Execute tool calls for _, tc := range response.ToolCalls { - // Log tool call with arguments preview argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), map[string]interface{}{ + "agent_id": agent.ID, "tool": tc.Name, "iteration": iteration, }) - result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) + result, err := agent.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) if err != nil { result = fmt.Sprintf("Error: %v", err) } @@ -424,7 +476,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M messages = append(messages, toolResultMsg) // Save tool result message to session - al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } } @@ -432,13 +484,13 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } // updateToolContexts updates the context for tools that need channel/chatID info. -func (al *AgentLoop) updateToolContexts(channel, chatID string) { - if tool, ok := al.tools.Get("message"); ok { +func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) { + if tool, ok := agent.Tools.Get("message"); ok { if mt, ok := tool.(*tools.MessageTool); ok { mt.SetContext(channel, chatID) } } - if tool, ok := al.tools.Get("spawn"); ok { + if tool, ok := agent.Tools.Get("spawn"); ok { if st, ok := tool.(*tools.SpawnTool); ok { st.SetContext(channel, chatID) } @@ -446,16 +498,17 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) { } // maybeSummarize triggers summarization if the session history exceeds thresholds. -func (al *AgentLoop) maybeSummarize(sessionKey string) { - newHistory := al.sessions.GetHistory(sessionKey) +func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string) { + newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) - threshold := al.contextWindow * 75 / 100 + threshold := agent.ContextWindow * 75 / 100 if len(newHistory) > 20 || tokenEstimate > threshold { - if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading { + summarizeKey := agent.ID + ":" + sessionKey + if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { - defer al.summarizing.Delete(sessionKey) - al.summarizeSession(sessionKey) + defer al.summarizing.Delete(summarizeKey) + al.summarizeSession(agent, sessionKey) }() } } @@ -465,15 +518,26 @@ func (al *AgentLoop) maybeSummarize(sessionKey string) { func (al *AgentLoop) GetStartupInfo() map[string]interface{} { info := make(map[string]interface{}) + agent := al.registry.GetDefaultAgent() + if agent == nil { + return info + } + // Tools info - tools := al.tools.List() + toolsList := agent.Tools.List() info["tools"] = map[string]interface{}{ - "count": len(tools), - "names": tools, + "count": len(toolsList), + "names": toolsList, } // Skills info - info["skills"] = al.contextBuilder.GetSkillsInfo() + info["skills"] = agent.ContextBuilder.GetSkillsInfo() + + // Agents info + info["agents"] = map[string]interface{}{ + "count": len(al.registry.ListAgentIDs()), + "ids": al.registry.ListAgentIDs(), + } return info } @@ -530,12 +594,12 @@ func formatToolsForLog(tools []providers.ToolDefinition) string { } // summarizeSession summarizes the conversation history for a session. -func (al *AgentLoop) summarizeSession(sessionKey string) { +func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - history := al.sessions.GetHistory(sessionKey) - summary := al.sessions.GetSummary(sessionKey) + history := agent.Sessions.GetHistory(sessionKey) + summary := agent.Sessions.GetSummary(sessionKey) // Keep last 4 messages for continuity if len(history) <= 4 { @@ -545,8 +609,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { toSummarize := history[:len(history)-4] // Oversized Message Guard - // Skip messages larger than 50% of context window to prevent summarizer overflow - maxMessageTokens := al.contextWindow / 2 + maxMessageTokens := agent.ContextWindow / 2 validMessages := make([]providers.Message, 0) omitted := false @@ -554,7 +617,6 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { if m.Role != "user" && m.Role != "assistant" { continue } - // Estimate tokens for this message msgTokens := len(m.Content) / 4 if msgTokens > maxMessageTokens { omitted = true @@ -568,19 +630,17 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } // Multi-Part Summarization - // Split into two parts if history is significant var finalSummary string if len(validMessages) > 10 { mid := len(validMessages) / 2 part1 := validMessages[:mid] part2 := validMessages[mid:] - s1, _ := al.summarizeBatch(ctx, part1, "") - s2, _ := al.summarizeBatch(ctx, part2, "") + s1, _ := al.summarizeBatch(ctx, agent, part1, "") + s2, _ := al.summarizeBatch(ctx, agent, part2, "") - // Merge them mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2) - resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{ + resp, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, agent.Model, map[string]interface{}{ "max_tokens": 1024, "temperature": 0.3, }) @@ -590,7 +650,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { finalSummary = s1 + " " + s2 } } else { - finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary) + finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary) } if omitted && finalSummary != "" { @@ -598,14 +658,14 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } if finalSummary != "" { - al.sessions.SetSummary(sessionKey, finalSummary) - al.sessions.TruncateHistory(sessionKey, 4) - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + agent.Sessions.SetSummary(sessionKey, finalSummary) + agent.Sessions.TruncateHistory(sessionKey, 4) + agent.Sessions.Save(agent.Sessions.GetOrCreate(sessionKey)) } } // summarizeBatch summarizes a batch of messages. -func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) { +func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, batch []providers.Message, existingSummary string) (string, error) { prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n" if existingSummary != "" { prompt += "Existing context: " + existingSummary + "\n" @@ -615,7 +675,7 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content) } - response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{ + response, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, agent.Model, map[string]interface{}{ "max_tokens": 1024, "temperature": 0.3, }) @@ -629,7 +689,34 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa func (al *AgentLoop) estimateTokens(messages []providers.Message) int { total := 0 for _, m := range messages { - total += len(m.Content) / 4 // Simple heuristic: 4 chars per token + total += len(m.Content) / 4 } return total } + +// extractPeer extracts the routing peer from inbound message metadata. +func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { + peerKind := msg.Metadata["peer_kind"] + if peerKind == "" { + return nil + } + peerID := msg.Metadata["peer_id"] + if peerID == "" { + if peerKind == "direct" { + peerID = msg.SenderID + } else { + peerID = msg.ChatID + } + } + return &routing.RoutePeer{Kind: peerKind, ID: peerID} +} + +// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. +func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { + parentKind := msg.Metadata["parent_peer_kind"] + parentID := msg.Metadata["parent_peer_id"] + if parentKind == "" || parentID == "" { + return nil + } + return &routing.RoutePeer{Kind: parentKind, ID: parentID} +} diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go new file mode 100644 index 000000000..e37149c31 --- /dev/null +++ b/pkg/agent/registry.go @@ -0,0 +1,114 @@ +package agent + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" +) + +// AgentRegistry manages multiple agent instances and routes messages to them. +type AgentRegistry struct { + agents map[string]*AgentInstance + resolver *routing.RouteResolver + mu sync.RWMutex +} + +// NewAgentRegistry creates a registry from config, instantiating all agents. +func NewAgentRegistry( + cfg *config.Config, + provider providers.LLMProvider, +) *AgentRegistry { + registry := &AgentRegistry{ + agents: make(map[string]*AgentInstance), + resolver: routing.NewRouteResolver(cfg), + } + + agentConfigs := cfg.Agents.List + if len(agentConfigs) == 0 { + implicitAgent := &config.AgentConfig{ + ID: "main", + Default: true, + } + instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, provider) + registry.agents["main"] = instance + logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil) + } else { + for i := range agentConfigs { + ac := &agentConfigs[i] + id := routing.NormalizeAgentID(ac.ID) + instance := NewAgentInstance(ac, &cfg.Agents.Defaults, provider) + registry.agents[id] = instance + logger.InfoCF("agent", "Registered agent", + map[string]interface{}{ + "agent_id": id, + "name": ac.Name, + "workspace": instance.Workspace, + "model": instance.Model, + }) + } + } + + return registry +} + +// GetAgent returns the agent instance for a given ID. +func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + id := routing.NormalizeAgentID(agentID) + agent, ok := r.agents[id] + return agent, ok +} + +// ResolveRoute determines which agent handles the message. +func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute { + return r.resolver.ResolveRoute(input) +} + +// ListAgentIDs returns all registered agent IDs. +func (r *AgentRegistry) ListAgentIDs() []string { + r.mu.RLock() + defer r.mu.RUnlock() + ids := make([]string, 0, len(r.agents)) + for id := range r.agents { + ids = append(ids, id) + } + return ids +} + +// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID. +func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool { + parent, ok := r.GetAgent(parentAgentID) + if !ok { + return false + } + if parent.Subagents == nil || parent.Subagents.AllowAgents == nil { + return false + } + targetNorm := routing.NormalizeAgentID(targetAgentID) + for _, allowed := range parent.Subagents.AllowAgents { + if allowed == "*" { + return true + } + if routing.NormalizeAgentID(allowed) == targetNorm { + return true + } + } + return false +} + +// GetDefaultAgent returns the default agent instance. +func (r *AgentRegistry) GetDefaultAgent() *AgentInstance { + r.mu.RLock() + defer r.mu.RUnlock() + if agent, ok := r.agents["main"]; ok { + return agent + } + for _, agent := range r.agents { + return agent + } + return nil +} diff --git a/pkg/agent/registry_test.go b/pkg/agent/registry_test.go new file mode 100644 index 000000000..d4ccc064d --- /dev/null +++ b/pkg/agent/registry_test.go @@ -0,0 +1,199 @@ +package agent + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type mockProvider struct{} + +func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil +} + +func (m *mockProvider) GetDefaultModel() string { + return "mock-model" +} + +func testCfg(agents []config.AgentConfig) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: "/tmp/picoclaw-test-registry", + Model: "gpt-4", + MaxTokens: 8192, + MaxToolIterations: 10, + }, + List: agents, + }, + } +} + +func TestNewAgentRegistry_ImplicitMain(t *testing.T) { + cfg := testCfg(nil) + registry := NewAgentRegistry(cfg, &mockProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 1 || ids[0] != "main" { + t.Errorf("expected implicit main agent, got %v", ids) + } + + agent, ok := registry.GetAgent("main") + if !ok || agent == nil { + t.Fatal("expected to find 'main' agent") + } + if agent.ID != "main" { + t.Errorf("agent.ID = %q, want 'main'", agent.ID) + } +} + +func TestNewAgentRegistry_ExplicitAgents(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "sales", Default: true, Name: "Sales Bot"}, + {ID: "support", Name: "Support Bot"}, + }) + registry := NewAgentRegistry(cfg, &mockProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 2 { + t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids) + } + + sales, ok := registry.GetAgent("sales") + if !ok || sales == nil { + t.Fatal("expected to find 'sales' agent") + } + if sales.Name != "Sales Bot" { + t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name) + } + + support, ok := registry.GetAgent("support") + if !ok || support == nil { + t.Fatal("expected to find 'support' agent") + } +} + +func TestAgentRegistry_GetAgent_Normalize(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "my-agent", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockProvider{}) + + agent, ok := registry.GetAgent("My-Agent") + if !ok || agent == nil { + t.Fatal("expected to find agent with normalized ID") + } + if agent.ID != "my-agent" { + t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID) + } +} + +func TestAgentRegistry_GetDefaultAgent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockProvider{}) + + // GetDefaultAgent first checks for "main", then returns any + agent := registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected a default agent") + } +} + +func TestAgentRegistry_CanSpawnSubagent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "parent", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: []string{"child1", "child2"}, + }, + }, + {ID: "child1"}, + {ID: "child2"}, + {ID: "restricted"}, + }) + registry := NewAgentRegistry(cfg, &mockProvider{}) + + if !registry.CanSpawnSubagent("parent", "child1") { + t.Error("expected parent to be allowed to spawn child1") + } + if !registry.CanSpawnSubagent("parent", "child2") { + t.Error("expected parent to be allowed to spawn child2") + } + if registry.CanSpawnSubagent("parent", "restricted") { + t.Error("expected parent to NOT be allowed to spawn restricted") + } + if registry.CanSpawnSubagent("child1", "child2") { + t.Error("expected child1 to NOT be allowed to spawn (no subagents config)") + } +} + +func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "admin", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: []string{"*"}, + }, + }, + {ID: "any-agent"}, + }) + registry := NewAgentRegistry(cfg, &mockProvider{}) + + if !registry.CanSpawnSubagent("admin", "any-agent") { + t.Error("expected wildcard to allow spawning any agent") + } + if !registry.CanSpawnSubagent("admin", "nonexistent") { + t.Error("expected wildcard to allow spawning even nonexistent agents") + } +} + +func TestAgentInstance_Model(t *testing.T) { + model := &config.AgentModelConfig{Primary: "claude-opus"} + cfg := testCfg([]config.AgentConfig{ + {ID: "custom", Default: true, Model: model}, + }) + registry := NewAgentRegistry(cfg, &mockProvider{}) + + agent, _ := registry.GetAgent("custom") + if agent.Model != "claude-opus" { + t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model) + } +} + +func TestAgentInstance_FallbackInheritance(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "inherit", Default: true}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"} + registry := NewAgentRegistry(cfg, &mockProvider{}) + + agent, _ := registry.GetAgent("inherit") + if len(agent.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks)) + } +} + +func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) { + model := &config.AgentModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{}, // explicitly empty = disable + } + cfg := testCfg([]config.AgentConfig{ + {ID: "no-fallback", Default: true, Model: model}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"} + registry := NewAgentRegistry(cfg, &mockProvider{}) + + agent, _ := registry.GetAgent("no-fallback") + if len(agent.Fallbacks) != 0 { + t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks) + } +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index fabec1a86..c1d3085ec 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -2,7 +2,6 @@ package channels import ( "context" - "fmt" "strings" "github.com/sipeed/picoclaw/pkg/bus" @@ -72,17 +71,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st return } - // Build session key: channel:chatID - sessionKey := fmt.Sprintf("%s:%s", c.name, chatID) - msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - SessionKey: sessionKey, - Metadata: metadata, + Channel: c.name, + SenderID: senderID, + ChatID: chatID, + Content: content, + Media: media, + Metadata: metadata, } c.bus.PublishInbound(msg) diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index e65c99eec..af4a01b35 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -228,6 +228,13 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "preview": utils.Truncate(content, 50), }) + peerKind := "channel" + peerID := m.ChannelID + if m.GuildID == "" { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_id": m.ID, "user_id": senderID, @@ -236,6 +243,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index b3ac12e01..58dc7824c 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -25,6 +25,7 @@ type SlackChannel struct { api *slack.Client socketClient *socketmode.Client botUserID string + teamID string transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc @@ -72,6 +73,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { return fmt.Errorf("slack auth test failed: %w", err) } c.botUserID = authResp.UserID + c.teamID = authResp.TeamID logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ "bot_user_id": c.botUserID, @@ -274,11 +276,21 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { return } + peerKind := "channel" + peerID := channelID + if strings.HasPrefix(channelID, "D") { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", + "peer_kind": peerKind, + "peer_id": peerID, + "team_id": c.teamID, } logger.DebugCF("slack", "Received message", map[string]interface{}{ @@ -324,12 +336,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } + mentionPeerKind := "channel" + mentionPeerID := channelID + if strings.HasPrefix(channelID, "D") { + mentionPeerKind = "direct" + mentionPeerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", "is_mention": "true", + "peer_kind": mentionPeerKind, + "peer_id": mentionPeerID, + "team_id": c.teamID, } c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -359,6 +381,9 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "platform": "slack", "is_command": "true", "trigger_id": cmd.TriggerID, + "peer_kind": "channel", + "peer_id": channelID, + "team_id": c.teamID, } logger.DebugCF("slack", "Slash command received", map[string]interface{}{ diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 3ad4818c3..32924206f 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -351,12 +351,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat }(chatID, pID) } + peerKind := "direct" + peerID := fmt.Sprintf("%d", user.ID) + if message.Chat.Type != "private" { + peerKind = "group" + peerID = fmt.Sprintf("%d", chatID) + } + metadata := map[string]string{ "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) diff --git a/pkg/config/config.go b/pkg/config/config.go index 56f1e1958..accccc583 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -45,6 +45,8 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { type Config struct { Agents AgentsConfig `json:"agents"` + Bindings []AgentBinding `json:"bindings,omitempty"` + Session SessionConfig `json:"session,omitempty"` Channels ChannelsConfig `json:"channels"` Providers ProvidersConfig `json:"providers"` Gateway GatewayConfig `json:"gateway"` @@ -54,16 +56,97 @@ type Config struct { type AgentsConfig struct { Defaults AgentDefaults `json:"defaults"` + List []AgentConfig `json:"list,omitempty"` +} + +// AgentModelConfig supports both string and structured model config. +// String format: "gpt-4" (just primary, no fallbacks) +// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]} +type AgentModelConfig struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` +} + +func (m *AgentModelConfig) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + m.Primary = s + m.Fallbacks = nil + return nil + } + type raw struct { + Primary string `json:"primary"` + Fallbacks []string `json:"fallbacks"` + } + var r raw + if err := json.Unmarshal(data, &r); err != nil { + return err + } + m.Primary = r.Primary + m.Fallbacks = r.Fallbacks + return nil +} + +func (m AgentModelConfig) MarshalJSON() ([]byte, error) { + if len(m.Fallbacks) == 0 && m.Primary != "" { + return json.Marshal(m.Primary) + } + type raw struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` + } + return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks}) +} + +type AgentConfig struct { + ID string `json:"id"` + Default bool `json:"default,omitempty"` + Name string `json:"name,omitempty"` + Workspace string `json:"workspace,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` + Skills []string `json:"skills,omitempty"` + Subagents *SubagentsConfig `json:"subagents,omitempty"` +} + +type SubagentsConfig struct { + AllowAgents []string `json:"allow_agents,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` +} + +type PeerMatch struct { + Kind string `json:"kind"` + ID string `json:"id"` +} + +type BindingMatch struct { + Channel string `json:"channel"` + AccountID string `json:"account_id,omitempty"` + Peer *PeerMatch `json:"peer,omitempty"` + GuildID string `json:"guild_id,omitempty"` + TeamID string `json:"team_id,omitempty"` +} + +type AgentBinding struct { + AgentID string `json:"agent_id"` + Match BindingMatch `json:"match"` +} + +type SessionConfig struct { + DMScope string `json:"dm_scope,omitempty"` + IdentityLinks map[string][]string `json:"identity_links,omitempty"` } type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } type ChannelsConfig struct { @@ -348,6 +431,32 @@ func (c *Config) GetAPIBase() string { return "" } +// ModelConfig holds primary model and fallback list. +type ModelConfig struct { + Primary string + Fallbacks []string +} + +// GetModelConfig returns the text model configuration with fallbacks. +func (c *Config) GetModelConfig() ModelConfig { + c.mu.RLock() + defer c.mu.RUnlock() + return ModelConfig{ + Primary: c.Agents.Defaults.Model, + Fallbacks: c.Agents.Defaults.ModelFallbacks, + } +} + +// GetImageModelConfig returns the image model configuration with fallbacks. +func (c *Config) GetImageModelConfig() ModelConfig { + c.mu.RLock() + defer c.mu.RUnlock() + return ModelConfig{ + Primary: c.Agents.Defaults.ImageModel, + Fallbacks: c.Agents.Defaults.ImageModelFallbacks, + } +} + func expandHome(path string) string { if path == "" { return path diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 000000000..e99c4f0aa --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,186 @@ +package config + +import ( + "encoding/json" + "testing" +) + +func TestAgentModelConfig_UnmarshalString(t *testing.T) { + var m AgentModelConfig + if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil { + t.Fatalf("unmarshal string: %v", err) + } + if m.Primary != "gpt-4" { + t.Errorf("Primary = %q, want 'gpt-4'", m.Primary) + } + if m.Fallbacks != nil { + t.Errorf("Fallbacks = %v, want nil", m.Fallbacks) + } +} + +func TestAgentModelConfig_UnmarshalObject(t *testing.T) { + var m AgentModelConfig + data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}` + if err := json.Unmarshal([]byte(data), &m); err != nil { + t.Fatalf("unmarshal object: %v", err) + } + if m.Primary != "claude-opus" { + t.Errorf("Primary = %q, want 'claude-opus'", m.Primary) + } + if len(m.Fallbacks) != 2 { + t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks)) + } + if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" { + t.Errorf("Fallbacks = %v", m.Fallbacks) + } +} + +func TestAgentModelConfig_MarshalString(t *testing.T) { + m := AgentModelConfig{Primary: "gpt-4"} + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(data) != `"gpt-4"` { + t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data)) + } +} + +func TestAgentModelConfig_MarshalObject(t *testing.T) { + m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}} + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var result map[string]interface{} + json.Unmarshal(data, &result) + if result["primary"] != "claude-opus" { + t.Errorf("primary = %v", result["primary"]) + } +} + +func TestAgentConfig_FullParse(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + }, + "list": [ + { + "id": "sales", + "default": true, + "name": "Sales Bot", + "model": "gpt-4" + }, + { + "id": "support", + "name": "Support Bot", + "model": { + "primary": "claude-opus", + "fallbacks": ["haiku"] + }, + "subagents": { + "allow_agents": ["sales"] + } + } + ] + }, + "bindings": [ + { + "agent_id": "support", + "match": { + "channel": "telegram", + "account_id": "*", + "peer": {"kind": "direct", "id": "user123"} + } + } + ], + "session": { + "dm_scope": "per-peer", + "identity_links": { + "john": ["telegram:123", "discord:john#1234"] + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 2 { + t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List)) + } + + sales := cfg.Agents.List[0] + if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" { + t.Errorf("sales = %+v", sales) + } + if sales.Model == nil || sales.Model.Primary != "gpt-4" { + t.Errorf("sales.Model = %+v", sales.Model) + } + + support := cfg.Agents.List[1] + if support.ID != "support" || support.Name != "Support Bot" { + t.Errorf("support = %+v", support) + } + if support.Model == nil || support.Model.Primary != "claude-opus" { + t.Errorf("support.Model = %+v", support.Model) + } + if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" { + t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks) + } + if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 { + t.Errorf("support.Subagents = %+v", support.Subagents) + } + + if len(cfg.Bindings) != 1 { + t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings)) + } + binding := cfg.Bindings[0] + if binding.AgentID != "support" || binding.Match.Channel != "telegram" { + t.Errorf("binding = %+v", binding) + } + if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" { + t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer) + } + + if cfg.Session.DMScope != "per-peer" { + t.Errorf("Session.DMScope = %q", cfg.Session.DMScope) + } + if len(cfg.Session.IdentityLinks) != 1 { + t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks) + } + links := cfg.Session.IdentityLinks["john"] + if len(links) != 2 { + t.Errorf("john links = %v", links) + } +} + +func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 0 { + t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List)) + } + if len(cfg.Bindings) != 0 { + t.Errorf("bindings should be empty, got %d", len(cfg.Bindings)) + } +} diff --git a/pkg/routing/agent_id.go b/pkg/routing/agent_id.go new file mode 100644 index 000000000..bcf2f0dc0 --- /dev/null +++ b/pkg/routing/agent_id.go @@ -0,0 +1,66 @@ +package routing + +import ( + "regexp" + "strings" +) + +const ( + DefaultAgentID = "main" + DefaultMainKey = "main" + DefaultAccountID = "default" + MaxAgentIDLength = 64 +) + +var ( + validIDRe = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`) + invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`) + leadingDashRe = regexp.MustCompile(`^-+`) + trailingDashRe = regexp.MustCompile(`-+$`) +) + +// NormalizeAgentID sanitizes an agent ID to [a-z0-9][a-z0-9_-]{0,63}. +// Invalid characters are collapsed to "-". Leading/trailing dashes stripped. +// Empty input returns DefaultAgentID ("main"). +func NormalizeAgentID(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return DefaultAgentID + } + lower := strings.ToLower(trimmed) + if validIDRe.MatchString(lower) { + return lower + } + result := invalidCharsRe.ReplaceAllString(lower, "-") + result = leadingDashRe.ReplaceAllString(result, "") + result = trailingDashRe.ReplaceAllString(result, "") + if len(result) > MaxAgentIDLength { + result = result[:MaxAgentIDLength] + } + if result == "" { + return DefaultAgentID + } + return result +} + +// NormalizeAccountID sanitizes an account ID. Empty returns DefaultAccountID. +func NormalizeAccountID(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return DefaultAccountID + } + lower := strings.ToLower(trimmed) + if validIDRe.MatchString(lower) { + return lower + } + result := invalidCharsRe.ReplaceAllString(lower, "-") + result = leadingDashRe.ReplaceAllString(result, "") + result = trailingDashRe.ReplaceAllString(result, "") + if len(result) > MaxAgentIDLength { + result = result[:MaxAgentIDLength] + } + if result == "" { + return DefaultAccountID + } + return result +} diff --git a/pkg/routing/agent_id_test.go b/pkg/routing/agent_id_test.go new file mode 100644 index 000000000..050fe0645 --- /dev/null +++ b/pkg/routing/agent_id_test.go @@ -0,0 +1,86 @@ +package routing + +import "testing" + +func TestNormalizeAgentID_Empty(t *testing.T) { + if got := NormalizeAgentID(""); got != DefaultAgentID { + t.Errorf("NormalizeAgentID('') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_Whitespace(t *testing.T) { + if got := NormalizeAgentID(" "); got != DefaultAgentID { + t.Errorf("NormalizeAgentID(' ') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_Valid(t *testing.T) { + tests := []struct { + input, want string + }{ + {"main", "main"}, + {"Main", "main"}, + {"SALES", "sales"}, + {"support-bot", "support-bot"}, + {"agent_1", "agent_1"}, + {"a", "a"}, + {"0test", "0test"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_InvalidChars(t *testing.T) { + tests := []struct { + input, want string + }{ + {"Hello World", "hello-world"}, + {"agent@123", "agent-123"}, + {"foo.bar.baz", "foo-bar-baz"}, + {"--leading", "leading"}, + {"--both--", "both"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_AllInvalid(t *testing.T) { + if got := NormalizeAgentID("@@@"); got != DefaultAgentID { + t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_TruncatesAt64(t *testing.T) { + long := "" + for i := 0; i < 100; i++ { + long += "a" + } + got := NormalizeAgentID(long) + if len(got) > MaxAgentIDLength { + t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength) + } +} + +func TestNormalizeAccountID_Empty(t *testing.T) { + if got := NormalizeAccountID(""); got != DefaultAccountID { + t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID) + } +} + +func TestNormalizeAccountID_Valid(t *testing.T) { + if got := NormalizeAccountID("MyBot"); got != "mybot" { + t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got) + } +} + +func TestNormalizeAccountID_InvalidChars(t *testing.T) { + if got := NormalizeAccountID("bot@home"); got != "bot-home" { + t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got) + } +} diff --git a/pkg/routing/route.go b/pkg/routing/route.go new file mode 100644 index 000000000..9eb060c53 --- /dev/null +++ b/pkg/routing/route.go @@ -0,0 +1,252 @@ +package routing + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// RouteInput contains the routing context from an inbound message. +type RouteInput struct { + Channel string + AccountID string + Peer *RoutePeer + ParentPeer *RoutePeer + GuildID string + TeamID string +} + +// ResolvedRoute is the result of agent routing. +type ResolvedRoute struct { + AgentID string + Channel string + AccountID string + SessionKey string + MainSessionKey string + MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" +} + +// RouteResolver determines which agent handles a message based on config bindings. +type RouteResolver struct { + cfg *config.Config +} + +// NewRouteResolver creates a new route resolver. +func NewRouteResolver(cfg *config.Config) *RouteResolver { + return &RouteResolver{cfg: cfg} +} + +// ResolveRoute determines which agent handles the message and constructs session keys. +// Implements the 7-level priority cascade: +// peer > parent_peer > guild > team > account > channel_wildcard > default +func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { + channel := strings.ToLower(strings.TrimSpace(input.Channel)) + accountID := NormalizeAccountID(input.AccountID) + peer := input.Peer + + dmScope := DMScope(r.cfg.Session.DMScope) + if dmScope == "" { + dmScope = DMScopeMain + } + identityLinks := r.cfg.Session.IdentityLinks + + bindings := r.filterBindings(channel, accountID) + + choose := func(agentID string, matchedBy string) ResolvedRoute { + resolvedAgentID := r.pickAgentID(agentID) + sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + Peer: peer, + DMScope: dmScope, + IdentityLinks: identityLinks, + })) + mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID)) + return ResolvedRoute{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + SessionKey: sessionKey, + MainSessionKey: mainSessionKey, + MatchedBy: matchedBy, + } + } + + // Priority 1: Peer binding + if peer != nil && strings.TrimSpace(peer.ID) != "" { + if match := r.findPeerMatch(bindings, peer); match != nil { + return choose(match.AgentID, "binding.peer") + } + } + + // Priority 2: Parent peer binding + parentPeer := input.ParentPeer + if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" { + if match := r.findPeerMatch(bindings, parentPeer); match != nil { + return choose(match.AgentID, "binding.peer.parent") + } + } + + // Priority 3: Guild binding + guildID := strings.TrimSpace(input.GuildID) + if guildID != "" { + if match := r.findGuildMatch(bindings, guildID); match != nil { + return choose(match.AgentID, "binding.guild") + } + } + + // Priority 4: Team binding + teamID := strings.TrimSpace(input.TeamID) + if teamID != "" { + if match := r.findTeamMatch(bindings, teamID); match != nil { + return choose(match.AgentID, "binding.team") + } + } + + // Priority 5: Account binding + if match := r.findAccountMatch(bindings); match != nil { + return choose(match.AgentID, "binding.account") + } + + // Priority 6: Channel wildcard binding + if match := r.findChannelWildcardMatch(bindings); match != nil { + return choose(match.AgentID, "binding.channel") + } + + // Priority 7: Default agent + return choose(r.resolveDefaultAgentID(), "default") +} + +func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding { + var filtered []config.AgentBinding + for _, b := range r.cfg.Bindings { + matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel)) + if matchChannel == "" || matchChannel != channel { + continue + } + if !matchesAccountID(b.Match.AccountID, accountID) { + continue + } + filtered = append(filtered, b) + } + return filtered +} + +func matchesAccountID(matchAccountID, actual string) bool { + trimmed := strings.TrimSpace(matchAccountID) + if trimmed == "" { + return actual == DefaultAccountID + } + if trimmed == "*" { + return true + } + return strings.ToLower(trimmed) == strings.ToLower(actual) +} + +func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + if b.Match.Peer == nil { + continue + } + peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind)) + peerID := strings.TrimSpace(b.Match.Peer.ID) + if peerKind == "" || peerID == "" { + continue + } + if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID { + return b + } + } + return nil +} + +func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchGuild := strings.TrimSpace(b.Match.GuildID) + if matchGuild != "" && matchGuild == guildID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchTeam := strings.TrimSpace(b.Match.TeamID) + if matchTeam != "" && matchTeam == teamID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID == "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID != "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) pickAgentID(agentID string) string { + trimmed := strings.TrimSpace(agentID) + if trimmed == "" { + return NormalizeAgentID(r.resolveDefaultAgentID()) + } + normalized := NormalizeAgentID(trimmed) + agents := r.cfg.Agents.List + if len(agents) == 0 { + return normalized + } + for _, a := range agents { + if NormalizeAgentID(a.ID) == normalized { + return normalized + } + } + return NormalizeAgentID(r.resolveDefaultAgentID()) +} + +func (r *RouteResolver) resolveDefaultAgentID() string { + agents := r.cfg.Agents.List + if len(agents) == 0 { + return DefaultAgentID + } + for _, a := range agents { + if a.Default { + id := strings.TrimSpace(a.ID) + if id != "" { + return NormalizeAgentID(id) + } + } + } + if id := strings.TrimSpace(agents[0].ID); id != "" { + return NormalizeAgentID(id) + } + return DefaultAgentID +} diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go new file mode 100644 index 000000000..8255db5f9 --- /dev/null +++ b/pkg/routing/route_test.go @@ -0,0 +1,297 @@ +package routing + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: "/tmp/picoclaw-test", + Model: "gpt-4", + }, + List: agents, + }, + Bindings: bindings, + Session: config.SessionConfig{ + DMScope: "per-peer", + }, + } +} + +func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { + cfg := testConfig(nil, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != DefaultAgentID { + t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID) + } + if route.MatchedBy != "default" { + t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy) + } +} + +func TestResolveRoute_PeerBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "sales", Default: true}, + {ID: "support"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "support", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user123"}, + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + }) + + if route.AgentID != "support" { + t.Errorf("AgentID = %q, want 'support'", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_GuildBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-abc", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-abc", + Peer: &RoutePeer{Kind: "channel", ID: "ch1"}, + }) + + if route.AgentID != "gaming" { + t.Errorf("AgentID = %q, want 'gaming'", route.AgentID) + } + if route.MatchedBy != "binding.guild" { + t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy) + } +} + +func TestResolveRoute_TeamBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "work"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "work", + Match: config.BindingMatch{ + Channel: "slack", + AccountID: "*", + TeamID: "T12345", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "slack", + TeamID: "T12345", + Peer: &RoutePeer{Kind: "channel", ID: "C001"}, + }) + + if route.AgentID != "work" { + t.Errorf("AgentID = %q, want 'work'", route.AgentID) + } + if route.MatchedBy != "binding.team" { + t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy) + } +} + +func TestResolveRoute_AccountBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "default-agent", Default: true}, + {ID: "premium"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "premium", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "bot2", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + AccountID: "bot2", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "premium" { + t.Errorf("AgentID = %q, want 'premium'", route.AgentID) + } + if route.MatchedBy != "binding.account" { + t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy) + } +} + +func TestResolveRoute_ChannelWildcard(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "telegram-bot"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "telegram-bot", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "telegram-bot" { + t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID) + } + if route.MatchedBy != "binding.channel" { + t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy) + } +} + +func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "vip"}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "vip", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"}, + }, + }, + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-1", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-1", + Peer: &RoutePeer{Kind: "direct", ID: "user-vip"}, + }) + + if route.AgentID != "vip" { + t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + } + bindings := []config.AgentBinding{ + { + AgentID: "nonexistent", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + }) + + if route.AgentID != "main" { + t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID) + } +} + +func TestResolveRoute_DefaultAgentSelection(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + {ID: "gamma"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "beta" { + t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID) + } +} + +func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "alpha" { + t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID) + } +} diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go new file mode 100644 index 000000000..e12f0d1d8 --- /dev/null +++ b/pkg/routing/session_key.go @@ -0,0 +1,183 @@ +package routing + +import ( + "fmt" + "strings" +) + +// DMScope controls DM session isolation granularity. +type DMScope string + +const ( + DMScopeMain DMScope = "main" + DMScopePerPeer DMScope = "per-peer" + DMScopePerChannelPeer DMScope = "per-channel-peer" + DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer" +) + +// RoutePeer represents a chat peer with kind and ID. +type RoutePeer struct { + Kind string // "direct", "group", "channel" + ID string +} + +// SessionKeyParams holds all inputs for session key construction. +type SessionKeyParams struct { + AgentID string + Channel string + AccountID string + Peer *RoutePeer + DMScope DMScope + IdentityLinks map[string][]string +} + +// ParsedSessionKey is the result of parsing an agent-scoped session key. +type ParsedSessionKey struct { + AgentID string + Rest string +} + +// BuildAgentMainSessionKey returns "agent::main". +func BuildAgentMainSessionKey(agentID string) string { + return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey) +} + +// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope. +func BuildAgentPeerSessionKey(params SessionKeyParams) string { + agentID := NormalizeAgentID(params.AgentID) + + peer := params.Peer + if peer == nil { + peer = &RoutePeer{Kind: "direct"} + } + peerKind := strings.TrimSpace(peer.Kind) + if peerKind == "" { + peerKind = "direct" + } + + if peerKind == "direct" { + dmScope := params.DMScope + if dmScope == "" { + dmScope = DMScopeMain + } + peerID := strings.TrimSpace(peer.ID) + + // Resolve identity links (cross-platform collapse) + if dmScope != DMScopeMain && peerID != "" { + if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" { + peerID = linked + } + } + peerID = strings.ToLower(peerID) + + switch dmScope { + case DMScopePerAccountChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + accountID := NormalizeAccountID(params.AccountID) + return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID) + } + case DMScopePerChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID) + } + case DMScopePerPeer: + if peerID != "" { + return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID) + } + } + return BuildAgentMainSessionKey(agentID) + } + + // Group/channel peers always get per-peer sessions + channel := normalizeChannel(params.Channel) + peerID := strings.ToLower(strings.TrimSpace(peer.ID)) + if peerID == "" { + peerID = "unknown" + } + return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) +} + +// ParseAgentSessionKey extracts agentId and rest from "agent::". +func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return nil + } + parts := strings.SplitN(raw, ":", 3) + if len(parts) < 3 { + return nil + } + if parts[0] != "agent" { + return nil + } + agentID := strings.TrimSpace(parts[1]) + rest := parts[2] + if agentID == "" || rest == "" { + return nil + } + return &ParsedSessionKey{AgentID: agentID, Rest: rest} +} + +// IsSubagentSessionKey returns true if the session key represents a subagent. +func IsSubagentSessionKey(sessionKey string) bool { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return false + } + if strings.HasPrefix(strings.ToLower(raw), "subagent:") { + return true + } + parsed := ParseAgentSessionKey(raw) + if parsed == nil { + return false + } + return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:") +} + +func normalizeChannel(channel string) string { + c := strings.TrimSpace(strings.ToLower(channel)) + if c == "" { + return "unknown" + } + return c +} + +func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]bool) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = true + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) + candidates[scopedCandidate] = true + } + if len(candidates) == 0 { + return "" + } + + for canonical, ids := range identityLinks { + canonicalName := strings.TrimSpace(canonical) + if canonicalName == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized != "" && candidates[normalized] { + return canonicalName + } + } + } + return "" +} diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go new file mode 100644 index 000000000..81e4ce018 --- /dev/null +++ b/pkg/routing/session_key_test.go @@ -0,0 +1,162 @@ +package routing + +import "testing" + +func TestBuildAgentMainSessionKey(t *testing.T) { + got := BuildAgentMainSessionKey("sales") + want := "agent:sales:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want) + } +} + +func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) { + got := BuildAgentMainSessionKey("Sales Bot") + want := "agent:sales-bot:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopeMain, + }) + want := "agent:main:main" + if got != want { + t.Errorf("DMScopeMain = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:direct:user123" + if got != want { + t.Errorf("DMScopePerPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerChannelPeer, + }) + want := "agent:main:telegram:direct:user123" + if got != want { + t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + AccountID: "bot1", + Peer: &RoutePeer{Kind: "direct", ID: "User123"}, + DMScope: DMScopePerAccountChannelPeer, + }) + want := "agent:main:telegram:bot1:direct:user123" + if got != want { + t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "group", ID: "chat456"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:telegram:group:chat456" + if got != want { + t.Errorf("GroupPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: nil, + DMScope: DMScopePerPeer, + }) + // nil peer defaults to direct with empty ID, falls to main + want := "agent:main:main" + if got != want { + t.Errorf("NilPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { + links := map[string][]string{ + "john": {"telegram:user123", "discord:john#1234"}, + } + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + IdentityLinks: links, + }) + want := "agent:main:direct:john" + if got != want { + t.Errorf("IdentityLink = %q, want %q", got, want) + } +} + +func TestParseAgentSessionKey_Valid(t *testing.T) { + parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") + if parsed == nil { + t.Fatal("expected non-nil result") + } + if parsed.AgentID != "sales" { + t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID) + } + if parsed.Rest != "telegram:direct:user123" { + t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest) + } +} + +func TestParseAgentSessionKey_Invalid(t *testing.T) { + tests := []string{ + "", + "foo:bar", + "notprefix:sales:main", + "agent::main", + "agent:sales:", + } + for _, input := range tests { + if got := ParseAgentSessionKey(input); got != nil { + t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got) + } + } +} + +func TestIsSubagentSessionKey(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"subagent:task-1", true}, + {"agent:main:subagent:task-1", true}, + {"agent:main:main", false}, + {"agent:main:telegram:direct:user123", false}, + {"", false}, + } + for _, tt := range tests { + if got := IsSubagentSessionKey(tt.input); got != tt.want { + t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 1bd7ac432..c449769de 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -6,9 +6,10 @@ import ( ) type SpawnTool struct { - manager *SubagentManager - originChannel string - originChatID string + manager *SubagentManager + originChannel string + originChatID string + allowlistCheck func(targetAgentID string) bool } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -39,6 +40,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} { "type": "string", "description": "Optional short label for the task (for display)", }, + "agent_id": map[string]interface{}{ + "type": "string", + "description": "Optional target agent ID to delegate the task to", + }, }, "required": []string{"task"}, } @@ -49,6 +54,10 @@ func (t *SpawnTool) SetContext(channel, chatID string) { t.originChatID = chatID } +func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { + t.allowlistCheck = check +} + func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { task, ok := args["task"].(string) if !ok { @@ -56,12 +65,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s } label, _ := args["label"].(string) + agentID, _ := args["agent_id"].(string) + + // Check allowlist if targeting a specific agent + if agentID != "" && t.allowlistCheck != nil { + if !t.allowlistCheck(agentID) { + return fmt.Sprintf("Error: not allowed to spawn agent '%s'", agentID), nil + } + } if t.manager == nil { return "Error: Subagent manager not configured", nil } - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID) + result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID) if err != nil { return "", fmt.Errorf("failed to spawn subagent: %w", err) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 0c05097f0..d45ab3433 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -14,6 +14,7 @@ type SubagentTask struct { ID string Task string Label string + AgentID string OriginChannel string OriginChatID string Status string @@ -40,7 +41,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b } } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) { +func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -51,6 +52,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel ID: taskID, Task: task, Label: label, + AgentID: agentID, OriginChannel: originChannel, OriginChatID: originChatID, Status: "running",