mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
fix(agent): align MCP prompt registration with tool allowlist
This commit is contained in:
+62
-20
@@ -144,25 +144,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
// Per-server "deferred" field takes precedence over the global Discovery.Enabled.
|
||||
serverCfg := mcpCfg.Servers[serverName]
|
||||
registerAsHidden := serverIsDeferred(al.cfg.Tools.MCP.Discovery.Enabled, serverCfg)
|
||||
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok || agent.ContextBuilder == nil {
|
||||
continue
|
||||
}
|
||||
if err := agent.ContextBuilder.RegisterPromptContributor(mcpServerPromptContributor{
|
||||
serverName: serverName,
|
||||
toolCount: len(conn.Tools),
|
||||
deferred: registerAsHidden,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Failed to register MCP prompt contributor",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
registeredToolsByAgent := make(map[string]map[string]struct{}, len(agentIDs))
|
||||
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
@@ -181,6 +163,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
toolName := mcpTool.Name()
|
||||
mcpTool.SetWorkspace(agent.Workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
|
||||
mcpTool.SetEventPublisher(al.runtimeEvents)
|
||||
@@ -190,18 +173,36 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
} else {
|
||||
agent.Tools.Register(mcpTool)
|
||||
}
|
||||
if !toolRegistryIncludes(agent.Tools, toolName) {
|
||||
continue
|
||||
}
|
||||
|
||||
recordRegisteredMCPTool(registeredToolsByAgent, agentID, toolName)
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
"name": toolName,
|
||||
"deferred": registerAsHidden,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
registerMCPServerPromptContributor(
|
||||
agentID,
|
||||
agent,
|
||||
serverName,
|
||||
len(registeredToolsByAgent[agentID]),
|
||||
registerAsHidden,
|
||||
)
|
||||
}
|
||||
}
|
||||
logger.InfoCF("agent", "MCP tools registered successfully",
|
||||
map[string]any{
|
||||
@@ -265,6 +266,47 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
return al.mcp.getInitErr()
|
||||
}
|
||||
|
||||
func registerMCPServerPromptContributor(
|
||||
agentID string,
|
||||
agent *AgentInstance,
|
||||
serverName string,
|
||||
toolCount int,
|
||||
registerAsHidden bool,
|
||||
) {
|
||||
if agent == nil || agent.ContextBuilder == nil || toolCount <= 0 {
|
||||
return
|
||||
}
|
||||
if err := agent.ContextBuilder.RegisterPromptContributor(mcpServerPromptContributor{
|
||||
serverName: serverName,
|
||||
toolCount: toolCount,
|
||||
deferred: registerAsHidden,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Failed to register MCP prompt contributor",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func recordRegisteredMCPTool(
|
||||
registeredToolsByAgent map[string]map[string]struct{},
|
||||
agentID, toolName string,
|
||||
) {
|
||||
if registeredToolsByAgent[agentID] == nil {
|
||||
registeredToolsByAgent[agentID] = make(map[string]struct{})
|
||||
}
|
||||
registeredToolsByAgent[agentID][toolName] = struct{}{}
|
||||
}
|
||||
|
||||
func toolRegistryIncludes(registry *tools.ToolRegistry, name string) bool {
|
||||
if registry == nil {
|
||||
return false
|
||||
}
|
||||
return registry.HasRegistered(name)
|
||||
}
|
||||
|
||||
func filterMCPConfigServers(
|
||||
mcpCfg config.MCPConfig,
|
||||
allowed map[string]struct{},
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
agenttools "github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
@@ -135,6 +136,42 @@ func TestServerIsDeferred(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterMCPServerPromptContributorUsesActualRegisteredToolCount(t *testing.T) {
|
||||
cb := NewContextBuilder(t.TempDir())
|
||||
agent := &AgentInstance{ContextBuilder: cb}
|
||||
|
||||
registerMCPServerPromptContributor("research", agent, "github", 0, false)
|
||||
messages := cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"})
|
||||
if prompt := messages[0].Content; strings.Contains(prompt, "MCP server `github`") {
|
||||
t.Fatalf("expected no MCP prompt when no tools were registered, got %q", prompt)
|
||||
}
|
||||
|
||||
registerMCPServerPromptContributor("research", agent, "github", 2, false)
|
||||
messages = cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"})
|
||||
prompt := messages[0].Content
|
||||
if !strings.Contains(prompt, "MCP server `github` is connected") {
|
||||
t.Fatalf("expected MCP prompt for registered tools, got %q", prompt)
|
||||
}
|
||||
if !strings.Contains(prompt, "It contributes 2 tool(s)") {
|
||||
t.Fatalf("expected actual registered tool count in prompt, got %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistryIncludesReportsOnlyRegisteredTools(t *testing.T) {
|
||||
registry := agenttools.NewToolRegistry()
|
||||
registry.SetAllowlist([]string{"mcp_github_search"})
|
||||
|
||||
registry.RegisterHidden(&allowlistTestTool{name: "mcp_github_search"})
|
||||
registry.RegisterHidden(&allowlistTestTool{name: "mcp_github_create_issue"})
|
||||
|
||||
if !toolRegistryIncludes(registry, "mcp_github_search") {
|
||||
t.Fatal("expected hidden registered MCP tool to be included")
|
||||
}
|
||||
if toolRegistryIncludes(registry, "mcp_github_create_issue") {
|
||||
t.Fatal("blocked MCP tool should not be included")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureMCPInitialized_LoadFailureSetsInitErr(t *testing.T) {
|
||||
al, cfg, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -176,6 +176,15 @@ func (r *ToolRegistry) toolAllowedLocked(name string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// HasRegistered reports whether a tool name is present in the registry,
|
||||
// including hidden tools whose TTL is currently zero.
|
||||
func (r *ToolRegistry) HasRegistered(name string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
_, ok := r.tools[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// HiddenToolSnapshot holds a consistent snapshot of hidden tools and the
|
||||
// registry version at which it was taken. Used by BM25SearchTool cache.
|
||||
type HiddenToolSnapshot struct {
|
||||
|
||||
@@ -130,6 +130,28 @@ func TestToolRegistry_AllowlistFiltersRegistrations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_HasRegisteredIncludesHiddenTools(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.SetAllowlist([]string{"visible", "hidden"})
|
||||
|
||||
r.Register(newMockTool("visible", "visible"))
|
||||
r.RegisterHidden(newMockTool("hidden", "hidden"))
|
||||
r.RegisterHidden(newMockTool("blocked", "blocked"))
|
||||
|
||||
if !r.HasRegistered("visible") {
|
||||
t.Fatal("expected visible tool to be registered")
|
||||
}
|
||||
if !r.HasRegistered("hidden") {
|
||||
t.Fatal("expected hidden tool to be reported as registered")
|
||||
}
|
||||
if r.HasRegistered("blocked") {
|
||||
t.Fatal("blocked tool should not be registered")
|
||||
}
|
||||
if _, ok := r.Get("hidden"); ok {
|
||||
t.Fatal("hidden tool with zero TTL should not be callable through Get")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Get_NotFound(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
_, ok := r.Get("nonexistent")
|
||||
|
||||
Reference in New Issue
Block a user