fix(agent): load only allowed MCP servers

This commit is contained in:
afjcjsbx
2026-03-29 23:43:35 +02:00
parent 409251e69d
commit f5f1dc9808
3 changed files with 112 additions and 3 deletions
+32 -3
View File
@@ -69,8 +69,18 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
return nil
}
mcpCfg := filterMCPConfigServers(al.cfg.Tools.MCP, al.registry.allowedMCPServers())
if mcpCfg.Servers == nil || len(mcpCfg.Servers) == 0 {
logger.InfoCF(
"agent",
"No MCP servers selected after applying per-agent mcpServers allowlists",
nil,
)
return nil
}
findValidServer := false
for _, serverCfg := range al.cfg.Tools.MCP.Servers {
for _, serverCfg := range mcpCfg.Servers {
if serverCfg.Enabled {
findValidServer = true
}
@@ -89,7 +99,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
workspacePath = defaultAgent.Workspace
}
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
if err := mcpManager.LoadFromMCPConfig(ctx, mcpCfg, workspacePath); err != nil {
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
map[string]any{
"error": err.Error(),
@@ -115,7 +125,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
// Determine whether this server's tools should be deferred (hidden).
// Per-server "deferred" field takes precedence over the global Discovery.Enabled.
serverCfg := al.cfg.Tools.MCP.Servers[serverName]
serverCfg := mcpCfg.Servers[serverName]
registerAsHidden := serverIsDeferred(al.cfg.Tools.MCP.Discovery.Enabled, serverCfg)
for _, tool := range conn.Tools {
@@ -216,6 +226,25 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
return al.mcp.getInitErr()
}
func filterMCPConfigServers(
mcpCfg config.MCPConfig,
allowed map[string]struct{},
) config.MCPConfig {
if allowed == nil {
return mcpCfg
}
filtered := mcpCfg
filtered.Servers = make(map[string]config.MCPServerConfig)
for serverName, serverCfg := range mcpCfg.Servers {
if _, ok := allowed[serverName]; ok {
filtered.Servers[serverName] = serverCfg
}
}
return filtered
}
// serverIsDeferred reports whether an MCP server's tools should be registered
// as hidden (deferred/discovery mode).
//
+56
View File
@@ -122,3 +122,59 @@ func TestAgentInstance_AllowsMCPServer(t *testing.T) {
}
})
}
func TestAgentRegistry_AllowedMCPServers(t *testing.T) {
t.Run("returns nil when any agent allows all servers", func(t *testing.T) {
registry := &AgentRegistry{
agents: map[string]*AgentInstance{
"main": {ID: "main", MCPServerAllowlist: nil},
"research": {ID: "research", MCPServerAllowlist: map[string]struct{}{"github": {}}},
},
}
if allowed := registry.allowedMCPServers(); allowed != nil {
t.Fatalf("expected nil union when one agent allows all, got %v", allowed)
}
})
t.Run("returns union of explicit allowlists", func(t *testing.T) {
registry := &AgentRegistry{
agents: map[string]*AgentInstance{
"main": {ID: "main", MCPServerAllowlist: map[string]struct{}{"github": {}}},
"research": {ID: "research", MCPServerAllowlist: map[string]struct{}{"filesystem": {}}},
},
}
allowed := registry.allowedMCPServers()
if len(allowed) != 2 {
t.Fatalf("len(allowed) = %d, want 2", len(allowed))
}
if _, ok := allowed["github"]; !ok {
t.Fatal("expected github in allowed MCP server union")
}
if _, ok := allowed["filesystem"]; !ok {
t.Fatal("expected filesystem in allowed MCP server union")
}
})
}
func TestFilterMCPConfigServers(t *testing.T) {
mcpCfg := config.MCPConfig{
ToolConfig: config.ToolConfig{Enabled: true},
Servers: map[string]config.MCPServerConfig{
"github": {Enabled: true},
"filesystem": {Enabled: true},
},
}
filtered := filterMCPConfigServers(mcpCfg, map[string]struct{}{"github": {}})
if len(filtered.Servers) != 1 {
t.Fatalf("len(filtered.Servers) = %d, want 1", len(filtered.Servers))
}
if _, ok := filtered.Servers["github"]; !ok {
t.Fatal("expected github server to remain after filtering")
}
if _, ok := filtered.Servers["filesystem"]; ok {
t.Fatal("expected filesystem server to be removed by filtering")
}
}
+24
View File
@@ -88,6 +88,30 @@ func (r *AgentRegistry) ListAgentIDs() []string {
return ids
}
func (r *AgentRegistry) allowedMCPServers() map[string]struct{} {
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.agents) == 0 {
return nil
}
union := make(map[string]struct{})
for _, agent := range r.agents {
if agent == nil {
continue
}
if agent.MCPServerAllowlist == nil {
return nil
}
for serverName := range agent.MCPServerAllowlist {
union[serverName] = struct{}{}
}
}
return union
}
// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID.
func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool {
parent, ok := r.GetAgent(parentAgentID)