diff --git a/pkg/agent/loop_mcp.go b/pkg/agent/loop_mcp.go index 1fad059a4..c9f3bc03d 100644 --- a/pkg/agent/loop_mcp.go +++ b/pkg/agent/loop_mcp.go @@ -69,8 +69,18 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error { return nil } + mcpCfg := filterMCPConfigServers(al.cfg.Tools.MCP, al.registry.allowedMCPServers()) + if mcpCfg.Servers == nil || len(mcpCfg.Servers) == 0 { + logger.InfoCF( + "agent", + "No MCP servers selected after applying per-agent mcpServers allowlists", + nil, + ) + return nil + } + findValidServer := false - for _, serverCfg := range al.cfg.Tools.MCP.Servers { + for _, serverCfg := range mcpCfg.Servers { if serverCfg.Enabled { findValidServer = true } @@ -89,7 +99,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error { workspacePath = defaultAgent.Workspace } - if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil { + if err := mcpManager.LoadFromMCPConfig(ctx, mcpCfg, workspacePath); err != nil { logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available", map[string]any{ "error": err.Error(), @@ -115,7 +125,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error { // Determine whether this server's tools should be deferred (hidden). // Per-server "deferred" field takes precedence over the global Discovery.Enabled. - serverCfg := al.cfg.Tools.MCP.Servers[serverName] + serverCfg := mcpCfg.Servers[serverName] registerAsHidden := serverIsDeferred(al.cfg.Tools.MCP.Discovery.Enabled, serverCfg) for _, tool := range conn.Tools { @@ -216,6 +226,25 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error { return al.mcp.getInitErr() } +func filterMCPConfigServers( + mcpCfg config.MCPConfig, + allowed map[string]struct{}, +) config.MCPConfig { + if allowed == nil { + return mcpCfg + } + + filtered := mcpCfg + filtered.Servers = make(map[string]config.MCPServerConfig) + for serverName, serverCfg := range mcpCfg.Servers { + if _, ok := allowed[serverName]; ok { + filtered.Servers[serverName] = serverCfg + } + } + + return filtered +} + // serverIsDeferred reports whether an MCP server's tools should be registered // as hidden (deferred/discovery mode). // diff --git a/pkg/agent/loop_mcp_test.go b/pkg/agent/loop_mcp_test.go index baf126bd1..ee00d22ba 100644 --- a/pkg/agent/loop_mcp_test.go +++ b/pkg/agent/loop_mcp_test.go @@ -122,3 +122,59 @@ func TestAgentInstance_AllowsMCPServer(t *testing.T) { } }) } + +func TestAgentRegistry_AllowedMCPServers(t *testing.T) { + t.Run("returns nil when any agent allows all servers", func(t *testing.T) { + registry := &AgentRegistry{ + agents: map[string]*AgentInstance{ + "main": {ID: "main", MCPServerAllowlist: nil}, + "research": {ID: "research", MCPServerAllowlist: map[string]struct{}{"github": {}}}, + }, + } + + if allowed := registry.allowedMCPServers(); allowed != nil { + t.Fatalf("expected nil union when one agent allows all, got %v", allowed) + } + }) + + t.Run("returns union of explicit allowlists", func(t *testing.T) { + registry := &AgentRegistry{ + agents: map[string]*AgentInstance{ + "main": {ID: "main", MCPServerAllowlist: map[string]struct{}{"github": {}}}, + "research": {ID: "research", MCPServerAllowlist: map[string]struct{}{"filesystem": {}}}, + }, + } + + allowed := registry.allowedMCPServers() + if len(allowed) != 2 { + t.Fatalf("len(allowed) = %d, want 2", len(allowed)) + } + if _, ok := allowed["github"]; !ok { + t.Fatal("expected github in allowed MCP server union") + } + if _, ok := allowed["filesystem"]; !ok { + t.Fatal("expected filesystem in allowed MCP server union") + } + }) +} + +func TestFilterMCPConfigServers(t *testing.T) { + mcpCfg := config.MCPConfig{ + ToolConfig: config.ToolConfig{Enabled: true}, + Servers: map[string]config.MCPServerConfig{ + "github": {Enabled: true}, + "filesystem": {Enabled: true}, + }, + } + + filtered := filterMCPConfigServers(mcpCfg, map[string]struct{}{"github": {}}) + if len(filtered.Servers) != 1 { + t.Fatalf("len(filtered.Servers) = %d, want 1", len(filtered.Servers)) + } + if _, ok := filtered.Servers["github"]; !ok { + t.Fatal("expected github server to remain after filtering") + } + if _, ok := filtered.Servers["filesystem"]; ok { + t.Fatal("expected filesystem server to be removed by filtering") + } +} diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go index ef5645e51..1eba72250 100644 --- a/pkg/agent/registry.go +++ b/pkg/agent/registry.go @@ -88,6 +88,30 @@ func (r *AgentRegistry) ListAgentIDs() []string { return ids } +func (r *AgentRegistry) allowedMCPServers() map[string]struct{} { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.agents) == 0 { + return nil + } + + union := make(map[string]struct{}) + for _, agent := range r.agents { + if agent == nil { + continue + } + if agent.MCPServerAllowlist == nil { + return nil + } + for serverName := range agent.MCPServerAllowlist { + union[serverName] = struct{}{} + } + } + + return union +} + // CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID. func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool { parent, ok := r.GetAgent(parentAgentID)