fix(agent): align MCP prompt registration with tool allowlist

This commit is contained in:
afjcjsbx
2026-05-07 14:01:43 +02:00
parent 27bd816b1c
commit dd8e247550
4 changed files with 130 additions and 20 deletions
+62 -20
View File
@@ -144,25 +144,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
// Per-server "deferred" field takes precedence over the global Discovery.Enabled.
serverCfg := mcpCfg.Servers[serverName]
registerAsHidden := serverIsDeferred(al.cfg.Tools.MCP.Discovery.Enabled, serverCfg)
for _, agentID := range agentIDs {
agent, ok := al.registry.GetAgent(agentID)
if !ok || agent.ContextBuilder == nil {
continue
}
if err := agent.ContextBuilder.RegisterPromptContributor(mcpServerPromptContributor{
serverName: serverName,
toolCount: len(conn.Tools),
deferred: registerAsHidden,
}); err != nil {
logger.WarnCF("agent", "Failed to register MCP prompt contributor",
map[string]any{
"agent_id": agentID,
"server": serverName,
"error": err.Error(),
})
}
}
registeredToolsByAgent := make(map[string]map[string]struct{}, len(agentIDs))
for _, tool := range conn.Tools {
for _, agentID := range agentIDs {
@@ -181,6 +163,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
toolName := mcpTool.Name()
mcpTool.SetWorkspace(agent.Workspace)
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
mcpTool.SetEventPublisher(al.runtimeEvents)
@@ -190,18 +173,36 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
} else {
agent.Tools.Register(mcpTool)
}
if !toolRegistryIncludes(agent.Tools, toolName) {
continue
}
recordRegisteredMCPTool(registeredToolsByAgent, agentID, toolName)
totalRegistrations++
logger.DebugCF("agent", "Registered MCP tool",
map[string]any{
"agent_id": agentID,
"server": serverName,
"tool": tool.Name,
"name": mcpTool.Name(),
"name": toolName,
"deferred": registerAsHidden,
})
}
}
for _, agentID := range agentIDs {
agent, ok := al.registry.GetAgent(agentID)
if !ok {
continue
}
registerMCPServerPromptContributor(
agentID,
agent,
serverName,
len(registeredToolsByAgent[agentID]),
registerAsHidden,
)
}
}
logger.InfoCF("agent", "MCP tools registered successfully",
map[string]any{
@@ -265,6 +266,47 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
return al.mcp.getInitErr()
}
func registerMCPServerPromptContributor(
agentID string,
agent *AgentInstance,
serverName string,
toolCount int,
registerAsHidden bool,
) {
if agent == nil || agent.ContextBuilder == nil || toolCount <= 0 {
return
}
if err := agent.ContextBuilder.RegisterPromptContributor(mcpServerPromptContributor{
serverName: serverName,
toolCount: toolCount,
deferred: registerAsHidden,
}); err != nil {
logger.WarnCF("agent", "Failed to register MCP prompt contributor",
map[string]any{
"agent_id": agentID,
"server": serverName,
"error": err.Error(),
})
}
}
func recordRegisteredMCPTool(
registeredToolsByAgent map[string]map[string]struct{},
agentID, toolName string,
) {
if registeredToolsByAgent[agentID] == nil {
registeredToolsByAgent[agentID] = make(map[string]struct{})
}
registeredToolsByAgent[agentID][toolName] = struct{}{}
}
func toolRegistryIncludes(registry *tools.ToolRegistry, name string) bool {
if registry == nil {
return false
}
return registry.HasRegistered(name)
}
func filterMCPConfigServers(
mcpCfg config.MCPConfig,
allowed map[string]struct{},
+37
View File
@@ -14,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/mcp"
agenttools "github.com/sipeed/picoclaw/pkg/tools"
)
func boolPtr(b bool) *bool { return &b }
@@ -135,6 +136,42 @@ func TestServerIsDeferred(t *testing.T) {
}
}
func TestRegisterMCPServerPromptContributorUsesActualRegisteredToolCount(t *testing.T) {
cb := NewContextBuilder(t.TempDir())
agent := &AgentInstance{ContextBuilder: cb}
registerMCPServerPromptContributor("research", agent, "github", 0, false)
messages := cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"})
if prompt := messages[0].Content; strings.Contains(prompt, "MCP server `github`") {
t.Fatalf("expected no MCP prompt when no tools were registered, got %q", prompt)
}
registerMCPServerPromptContributor("research", agent, "github", 2, false)
messages = cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"})
prompt := messages[0].Content
if !strings.Contains(prompt, "MCP server `github` is connected") {
t.Fatalf("expected MCP prompt for registered tools, got %q", prompt)
}
if !strings.Contains(prompt, "It contributes 2 tool(s)") {
t.Fatalf("expected actual registered tool count in prompt, got %q", prompt)
}
}
func TestToolRegistryIncludesReportsOnlyRegisteredTools(t *testing.T) {
registry := agenttools.NewToolRegistry()
registry.SetAllowlist([]string{"mcp_github_search"})
registry.RegisterHidden(&allowlistTestTool{name: "mcp_github_search"})
registry.RegisterHidden(&allowlistTestTool{name: "mcp_github_create_issue"})
if !toolRegistryIncludes(registry, "mcp_github_search") {
t.Fatal("expected hidden registered MCP tool to be included")
}
if toolRegistryIncludes(registry, "mcp_github_create_issue") {
t.Fatal("blocked MCP tool should not be included")
}
}
func TestEnsureMCPInitialized_LoadFailureSetsInitErr(t *testing.T) {
al, cfg, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
+9
View File
@@ -176,6 +176,15 @@ func (r *ToolRegistry) toolAllowedLocked(name string) bool {
return ok
}
// HasRegistered reports whether a tool name is present in the registry,
// including hidden tools whose TTL is currently zero.
func (r *ToolRegistry) HasRegistered(name string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, ok := r.tools[name]
return ok
}
// HiddenToolSnapshot holds a consistent snapshot of hidden tools and the
// registry version at which it was taken. Used by BM25SearchTool cache.
type HiddenToolSnapshot struct {
+22
View File
@@ -130,6 +130,28 @@ func TestToolRegistry_AllowlistFiltersRegistrations(t *testing.T) {
}
}
func TestToolRegistry_HasRegisteredIncludesHiddenTools(t *testing.T) {
r := NewToolRegistry()
r.SetAllowlist([]string{"visible", "hidden"})
r.Register(newMockTool("visible", "visible"))
r.RegisterHidden(newMockTool("hidden", "hidden"))
r.RegisterHidden(newMockTool("blocked", "blocked"))
if !r.HasRegistered("visible") {
t.Fatal("expected visible tool to be registered")
}
if !r.HasRegistered("hidden") {
t.Fatal("expected hidden tool to be reported as registered")
}
if r.HasRegistered("blocked") {
t.Fatal("blocked tool should not be registered")
}
if _, ok := r.Get("hidden"); ok {
t.Fatal("hidden tool with zero TTL should not be callable through Get")
}
}
func TestToolRegistry_Get_NotFound(t *testing.T) {
r := NewToolRegistry()
_, ok := r.Get("nonexistent")