From 847218ef29a6bae4d44bceb6a1d95044ab2f9c57 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 29 Mar 2026 23:22:47 +0200 Subject: [PATCH] refactor(agent): added mcp allowlist --- pkg/agent/instance.go | 11 +++++++++ pkg/agent/instance_test.go | 10 ++++++++ pkg/agent/loop_mcp.go | 9 +++++++ pkg/agent/loop_mcp_test.go | 49 +++++++++++++++++++++++++++++++++++++ pkg/agent/tool_allowlist.go | 17 +++++++++++++ 5 files changed, 96 insertions(+) diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 89bf0416a..f95a165af 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -39,6 +39,7 @@ type AgentInstance struct { Tools *tools.ToolRegistry Subagents *config.SubagentsConfig SkillsFilter []string + MCPServerAllowlist map[string]struct{} Candidates []providers.FallbackCandidate // Router is non-nil when model routing is configured and the light model @@ -75,6 +76,7 @@ func NewAgentInstance( allowReadPaths := buildAllowReadPatterns(cfg) allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths) agentToolAllowlist := resolveAgentToolAllowlist(definition) + agentMCPServerAllowlist := resolveAgentMCPServerAllowlist(definition) toolsRegistry := tools.NewToolRegistry() toolsRegistry.SetAllowlist(agentToolAllowlist) @@ -237,6 +239,7 @@ func NewAgentInstance( Tools: toolsRegistry, Subagents: subagents, SkillsFilter: skillsFilter, + MCPServerAllowlist: agentMCPServerAllowlist, Candidates: candidates, Router: router, LightCandidates: lightCandidates, @@ -295,6 +298,14 @@ func resolveAgentSkillsFilter( return append([]string(nil), agentCfg.Skills...) } +func (a *AgentInstance) AllowsMCPServer(serverName string) bool { + if a == nil || a.MCPServerAllowlist == nil { + return true + } + _, ok := a.MCPServerAllowlist[strings.ToLower(strings.TrimSpace(serverName))] + return ok +} + func compilePatterns(patterns []string) []*regexp.Regexp { compiled := make([]*regexp.Regexp, 0, len(patterns)) for _, p := range patterns { diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index aedda1c32..869e5fbc7 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -287,6 +287,7 @@ func TestNewAgentInstance_UsesFrontmatterModelAndSkills(t *testing.T) { "AGENT.md": `--- model: frontmatter-model skills: [frontmatter-skill] +mcpServers: [GitHub, filesystem] --- # Agent @@ -319,4 +320,13 @@ Use frontmatter identity. if len(agent.SkillsFilter) != 1 || agent.SkillsFilter[0] != "frontmatter-skill" { t.Fatalf("agent.SkillsFilter = %v, want [frontmatter-skill]", agent.SkillsFilter) } + if !agent.AllowsMCPServer("github") { + t.Fatal("expected github MCP server to be allowed from frontmatter") + } + if !agent.AllowsMCPServer("FILESYSTEM") { + t.Fatal("expected filesystem MCP server matching to be case-insensitive") + } + if agent.AllowsMCPServer("slack") { + t.Fatal("expected slack MCP server to be blocked by frontmatter allowlist") + } } diff --git a/pkg/agent/loop_mcp.go b/pkg/agent/loop_mcp.go index 97debbc33..1fad059a4 100644 --- a/pkg/agent/loop_mcp.go +++ b/pkg/agent/loop_mcp.go @@ -124,6 +124,15 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error { if !ok { continue } + if !agent.AllowsMCPServer(serverName) { + logger.DebugCF("agent", "Skipped MCP tool registration by agent mcpServers allowlist", + map[string]any{ + "agent_id": agentID, + "server": serverName, + "tool": tool.Name, + }) + continue + } mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) diff --git a/pkg/agent/loop_mcp_test.go b/pkg/agent/loop_mcp_test.go index 35c3e49c8..baf126bd1 100644 --- a/pkg/agent/loop_mcp_test.go +++ b/pkg/agent/loop_mcp_test.go @@ -7,6 +7,8 @@ package agent import ( + "os" + "path/filepath" "testing" "github.com/sipeed/picoclaw/pkg/config" @@ -73,3 +75,50 @@ func TestServerIsDeferred(t *testing.T) { }) } } + +func TestResolveAgentMCPServerAllowlist(t *testing.T) { + workspace := t.TempDir() + agentPath := filepath.Join(workspace, "AGENT.md") + content := `--- +mcpServers: [GitHub, filesystem, github] +--- +# Agent +` + if err := os.WriteFile(agentPath, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile(AGENT.md) error = %v", err) + } + + allowlist := resolveAgentMCPServerAllowlist(loadAgentDefinition(workspace)) + if len(allowlist) != 2 { + t.Fatalf("len(allowlist) = %d, want 2", len(allowlist)) + } + if _, ok := allowlist["github"]; !ok { + t.Fatal("expected github to be present in MCP allowlist") + } + if _, ok := allowlist["filesystem"]; !ok { + t.Fatal("expected filesystem to be present in MCP allowlist") + } +} + +func TestAgentInstance_AllowsMCPServer(t *testing.T) { + t.Run("nil allowlist allows all", func(t *testing.T) { + agent := &AgentInstance{} + if !agent.AllowsMCPServer("github") { + t.Fatal("expected nil MCP allowlist to allow all servers") + } + }) + + t.Run("explicit allowlist filters servers", func(t *testing.T) { + agent := &AgentInstance{ + MCPServerAllowlist: map[string]struct{}{ + "github": {}, + }, + } + if !agent.AllowsMCPServer("GitHub") { + t.Fatal("expected MCP server matching to be case-insensitive") + } + if agent.AllowsMCPServer("filesystem") { + t.Fatal("expected filesystem to be blocked by MCP allowlist") + } + }) +} diff --git a/pkg/agent/tool_allowlist.go b/pkg/agent/tool_allowlist.go index 899c84b89..de68352ad 100644 --- a/pkg/agent/tool_allowlist.go +++ b/pkg/agent/tool_allowlist.go @@ -26,3 +26,20 @@ func resolveAgentToolAllowlist(definition AgentContextDefinition) []string { sort.Strings(result) return result } + +func resolveAgentMCPServerAllowlist(definition AgentContextDefinition) map[string]struct{} { + if definition.Agent == nil || definition.Agent.Frontmatter.MCPServers == nil { + return nil + } + + allowlist := make(map[string]struct{}, len(definition.Agent.Frontmatter.MCPServers)) + for _, raw := range definition.Agent.Frontmatter.MCPServers { + trimmed := strings.ToLower(strings.TrimSpace(raw)) + if trimmed == "" { + continue + } + allowlist[trimmed] = struct{}{} + } + + return allowlist +}