refactor(agent): added mcp allowlist

This commit is contained in:
afjcjsbx
2026-03-29 23:22:47 +02:00
parent 0ef25f779e
commit 847218ef29
5 changed files with 96 additions and 0 deletions
+11
View File
@@ -39,6 +39,7 @@ type AgentInstance struct {
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
MCPServerAllowlist map[string]struct{}
Candidates []providers.FallbackCandidate
// Router is non-nil when model routing is configured and the light model
@@ -75,6 +76,7 @@ func NewAgentInstance(
allowReadPaths := buildAllowReadPatterns(cfg)
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
agentToolAllowlist := resolveAgentToolAllowlist(definition)
agentMCPServerAllowlist := resolveAgentMCPServerAllowlist(definition)
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.SetAllowlist(agentToolAllowlist)
@@ -237,6 +239,7 @@ func NewAgentInstance(
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
MCPServerAllowlist: agentMCPServerAllowlist,
Candidates: candidates,
Router: router,
LightCandidates: lightCandidates,
@@ -295,6 +298,14 @@ func resolveAgentSkillsFilter(
return append([]string(nil), agentCfg.Skills...)
}
func (a *AgentInstance) AllowsMCPServer(serverName string) bool {
if a == nil || a.MCPServerAllowlist == nil {
return true
}
_, ok := a.MCPServerAllowlist[strings.ToLower(strings.TrimSpace(serverName))]
return ok
}
func compilePatterns(patterns []string) []*regexp.Regexp {
compiled := make([]*regexp.Regexp, 0, len(patterns))
for _, p := range patterns {
+10
View File
@@ -287,6 +287,7 @@ func TestNewAgentInstance_UsesFrontmatterModelAndSkills(t *testing.T) {
"AGENT.md": `---
model: frontmatter-model
skills: [frontmatter-skill]
mcpServers: [GitHub, filesystem]
---
# Agent
@@ -319,4 +320,13 @@ Use frontmatter identity.
if len(agent.SkillsFilter) != 1 || agent.SkillsFilter[0] != "frontmatter-skill" {
t.Fatalf("agent.SkillsFilter = %v, want [frontmatter-skill]", agent.SkillsFilter)
}
if !agent.AllowsMCPServer("github") {
t.Fatal("expected github MCP server to be allowed from frontmatter")
}
if !agent.AllowsMCPServer("FILESYSTEM") {
t.Fatal("expected filesystem MCP server matching to be case-insensitive")
}
if agent.AllowsMCPServer("slack") {
t.Fatal("expected slack MCP server to be blocked by frontmatter allowlist")
}
}
+9
View File
@@ -124,6 +124,15 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
if !ok {
continue
}
if !agent.AllowsMCPServer(serverName) {
logger.DebugCF("agent", "Skipped MCP tool registration by agent mcpServers allowlist",
map[string]any{
"agent_id": agentID,
"server": serverName,
"tool": tool.Name,
})
continue
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
+49
View File
@@ -7,6 +7,8 @@
package agent
import (
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
@@ -73,3 +75,50 @@ func TestServerIsDeferred(t *testing.T) {
})
}
}
func TestResolveAgentMCPServerAllowlist(t *testing.T) {
workspace := t.TempDir()
agentPath := filepath.Join(workspace, "AGENT.md")
content := `---
mcpServers: [GitHub, filesystem, github]
---
# Agent
`
if err := os.WriteFile(agentPath, []byte(content), 0o644); err != nil {
t.Fatalf("WriteFile(AGENT.md) error = %v", err)
}
allowlist := resolveAgentMCPServerAllowlist(loadAgentDefinition(workspace))
if len(allowlist) != 2 {
t.Fatalf("len(allowlist) = %d, want 2", len(allowlist))
}
if _, ok := allowlist["github"]; !ok {
t.Fatal("expected github to be present in MCP allowlist")
}
if _, ok := allowlist["filesystem"]; !ok {
t.Fatal("expected filesystem to be present in MCP allowlist")
}
}
func TestAgentInstance_AllowsMCPServer(t *testing.T) {
t.Run("nil allowlist allows all", func(t *testing.T) {
agent := &AgentInstance{}
if !agent.AllowsMCPServer("github") {
t.Fatal("expected nil MCP allowlist to allow all servers")
}
})
t.Run("explicit allowlist filters servers", func(t *testing.T) {
agent := &AgentInstance{
MCPServerAllowlist: map[string]struct{}{
"github": {},
},
}
if !agent.AllowsMCPServer("GitHub") {
t.Fatal("expected MCP server matching to be case-insensitive")
}
if agent.AllowsMCPServer("filesystem") {
t.Fatal("expected filesystem to be blocked by MCP allowlist")
}
})
}
+17
View File
@@ -26,3 +26,20 @@ func resolveAgentToolAllowlist(definition AgentContextDefinition) []string {
sort.Strings(result)
return result
}
func resolveAgentMCPServerAllowlist(definition AgentContextDefinition) map[string]struct{} {
if definition.Agent == nil || definition.Agent.Frontmatter.MCPServers == nil {
return nil
}
allowlist := make(map[string]struct{}, len(definition.Agent.Frontmatter.MCPServers))
for _, raw := range definition.Agent.Frontmatter.MCPServers {
trimmed := strings.ToLower(strings.TrimSpace(raw))
if trimmed == "" {
continue
}
allowlist[trimmed] = struct{}{}
}
return allowlist
}