diff --git a/pkg/agent/agent_mcp.go b/pkg/agent/agent_mcp.go index 251d32b58..fcb57a5d4 100644 --- a/pkg/agent/agent_mcp.go +++ b/pkg/agent/agent_mcp.go @@ -135,6 +135,25 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error { serverCfg := al.cfg.Tools.MCP.Servers[serverName] registerAsHidden := serverIsDeferred(al.cfg.Tools.MCP.Discovery.Enabled, serverCfg) + for _, agentID := range agentIDs { + agent, ok := al.registry.GetAgent(agentID) + if !ok || agent.ContextBuilder == nil { + continue + } + if err := agent.ContextBuilder.RegisterPromptContributor(mcpServerPromptContributor{ + serverName: serverName, + toolCount: len(conn.Tools), + deferred: registerAsHidden, + }); err != nil { + logger.WarnCF("agent", "Failed to register MCP prompt contributor", + map[string]any{ + "agent_id": agentID, + "server": serverName, + "error": err.Error(), + }) + } + } + for _, tool := range conn.Tools { for _, agentID := range agentIDs { agent, ok := al.registry.GetAgent(agentID) diff --git a/pkg/agent/context.go b/pkg/agent/context.go index feb04e347..ecde7c33e 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -22,13 +22,11 @@ import ( ) type ContextBuilder struct { - workspace string - skillsLoader *skills.SkillsLoader - memory *MemoryStore - toolDiscoveryBM25 bool - toolDiscoveryRegex bool - splitOnMarker bool - promptRegistry *PromptRegistry + workspace string + skillsLoader *skills.SkillsLoader + memory *MemoryStore + splitOnMarker bool + promptRegistry *PromptRegistry // Cache for system prompt to avoid rebuilding on every call. // This fixes issue #607: repeated reprocessing of the entire context. @@ -50,8 +48,16 @@ type ContextBuilder struct { } func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuilder { - cb.toolDiscoveryBM25 = useBM25 - cb.toolDiscoveryRegex = useRegex + if useBM25 || useRegex { + if err := cb.RegisterPromptContributor(toolDiscoveryPromptContributor{ + useBM25: useBM25, + useRegex: useRegex, + }); err != nil { + logger.WarnCF("agent", "Failed to register tool discovery prompt contributor", map[string]any{ + "error": err.Error(), + }) + } + } return cb } @@ -83,11 +89,19 @@ func NewContextBuilder(workspace string) *ContextBuilder { } func (cb *ContextBuilder) RegisterPromptSource(desc PromptSourceDescriptor) error { - return cb.promptRegistryOrDefault().RegisterSource(desc) + err := cb.promptRegistryOrDefault().RegisterSource(desc) + if err == nil { + cb.InvalidateCache() + } + return err } func (cb *ContextBuilder) RegisterPromptContributor(contributor PromptContributor) error { - return cb.promptRegistryOrDefault().RegisterContributor(contributor) + err := cb.promptRegistryOrDefault().RegisterContributor(contributor) + if err == nil { + cb.InvalidateCache() + } + return err } func (cb *ContextBuilder) promptRegistryOrDefault() *PromptRegistry { @@ -124,16 +138,16 @@ Your workspace is at: %s version, workspacePath, workspacePath, workspacePath, workspacePath, workspacePath) } -func (cb *ContextBuilder) getDiscoveryRule() string { - if !cb.toolDiscoveryBM25 && !cb.toolDiscoveryRegex { +func formatToolDiscoveryRule(useBM25, useRegex bool) string { + if !useBM25 && !useRegex { return "" } var toolNames []string - if cb.toolDiscoveryBM25 { + if useBM25 { toolNames = append(toolNames, `"tool_search_tool_bm25"`) } - if cb.toolDiscoveryRegex { + if useRegex { toolNames = append(toolNames, `"tool_search_tool_regex"`) } @@ -173,19 +187,6 @@ func (cb *ContextBuilder) BuildSystemPromptParts() []PromptPart { Cache: PromptCacheEphemeral, }) - if toolDiscovery := cb.getDiscoveryRule(); toolDiscovery != "" { - add(PromptPart{ - ID: "capability.tool_discovery", - Layer: PromptLayerCapability, - Slot: PromptSlotTooling, - Source: PromptSource{ID: PromptSourceToolDiscovery, Name: "tool_registry:discovery"}, - Title: "tool discovery", - Content: toolDiscovery, - Stable: true, - Cache: PromptCacheEphemeral, - }) - } - // Bootstrap files bootstrapContent := cb.LoadBootstrapFiles() if bootstrapContent != "" { @@ -318,6 +319,19 @@ func (cb *ContextBuilder) EstimateSystemTokens(summary string, activeSkills []st totalChars += 7 // separator \n\n---\n\n } + if contributedParts, err := cb.promptRegistryOrDefault().Collect(context.Background(), PromptBuildRequest{ + Summary: summary, + ActiveSkills: append([]string(nil), activeSkills...), + }); err == nil { + for _, part := range contributedParts { + if strings.TrimSpace(part.Content) == "" { + continue + } + totalChars += utf8.RuneCountInString(part.Content) + totalChars += 7 // separator + } + } + if summary != "" { // Matches the CONTEXT_SUMMARY: prefix added in BuildMessages const summaryPrefix = "CONTEXT_SUMMARY: The following is an approximate summary of prior conversation " + diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go index e9863c4f6..69336b89f 100644 --- a/pkg/agent/hooks.go +++ b/pkg/agent/hooks.go @@ -377,14 +377,18 @@ func (hm *HookManager) applyBeforeLLMControls( if next == nil || current == nil { return next } - if llmHookSystemMessagesUnchanged(current.Messages, next.Messages) { - return next + if !llmHookSystemMessagesUnchanged(current.Messages, next.Messages) { + logger.WarnCF("hooks", "Hook attempted to modify system prompt; preserving original messages", map[string]any{ + "hook": hookName, + }) + next.Messages = cloneProviderMessages(current.Messages) + } + if !llmHookToolDefinitionsUnchanged(current.Tools, next.Tools) { + logger.WarnCF("hooks", "Hook attempted to modify tool definitions; preserving original tools", map[string]any{ + "hook": hookName, + }) + next.Tools = cloneToolDefinitions(current.Tools) } - - logger.WarnCF("hooks", "Hook attempted to modify system prompt; preserving original messages", map[string]any{ - "hook": hookName, - }) - next.Messages = cloneProviderMessages(current.Messages) return next } @@ -413,6 +417,10 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing return fingerprints } +func llmHookToolDefinitionsUnchanged(before, after []providers.ToolDefinition) bool { + return reflect.DeepEqual(cloneToolDefinitions(before), cloneToolDefinitions(after)) +} + func (hm *HookManager) BeforeTool( ctx context.Context, call *ToolCallHookRequest, diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go index 419752e5b..a936296dd 100644 --- a/pkg/agent/hooks_test.go +++ b/pkg/agent/hooks_test.go @@ -186,6 +186,36 @@ func (h *llmUserAppendHook) AfterLLM( return resp.Clone(), HookDecision{Action: HookActionContinue}, nil } +type llmToolRewriteHook struct{} + +func (h *llmToolRewriteHook) BeforeLLM( + ctx context.Context, + req *LLMHookRequest, +) (*LLMHookRequest, HookDecision, error) { + next := req.Clone() + next.Model = "changed-model" + next.Tools[0].Function.Description = "rewritten tool" + next.Tools = append(next.Tools, providers.ToolDefinition{ + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "hook_tool", + Description: "hook tool", + Parameters: map[string]any{"type": "object"}, + }, + PromptLayer: string(PromptLayerCapability), + PromptSlot: string(PromptSlotTooling), + PromptSource: "hook:test", + }) + return next, HookDecision{Action: HookActionModify}, nil +} + +func (h *llmToolRewriteHook) AfterLLM( + ctx context.Context, + resp *LLMHookResponse, +) (*LLMHookResponse, HookDecision, error) { + return resp.Clone(), HookDecision{Action: HookActionContinue}, nil +} + func TestHookManager_BeforeLLMControlsSystemPromptMutation(t *testing.T) { hm := NewHookManager(nil) if err := hm.Mount(NamedHook("rewrite-system", &llmSystemRewriteHook{})); err != nil { @@ -244,6 +274,51 @@ func TestHookManager_BeforeLLMAllowsNonSystemMessageMutation(t *testing.T) { } } +func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) { + hm := NewHookManager(nil) + if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil { + t.Fatalf("Mount() error = %v", err) + } + + req := &LLMHookRequest{ + Model: "original-model", + Messages: []providers.Message{ + {Role: "system", Content: "system"}, + {Role: "user", Content: "hello"}, + }, + Tools: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "mcp_github_create_issue", + Description: "create issue", + Parameters: map[string]any{"type": "object"}, + }, + PromptLayer: string(PromptLayerCapability), + PromptSlot: string(PromptSlotMCP), + PromptSource: "mcp:github", + }, + }, + } + + got, decision := hm.BeforeLLM(context.Background(), req) + if decision.normalizedAction() != HookActionContinue { + t.Fatalf("decision = %v, want continue", decision) + } + if got.Model != "changed-model" { + t.Fatalf("model = %q, want changed-model", got.Model) + } + if len(got.Tools) != 1 { + t.Fatalf("tools len = %d, want original 1", len(got.Tools)) + } + if got.Tools[0].Function.Description != "create issue" { + t.Fatalf("tool description = %q, want original", got.Tools[0].Function.Description) + } + if got.Tools[0].PromptSource != "mcp:github" || got.Tools[0].PromptSlot != string(PromptSlotMCP) { + t.Fatalf("tool prompt metadata = %#v, want original mcp metadata", got.Tools[0]) + } +} + func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) { provider := &llmHookTestProvider{} al, agent, cleanup := newHookTestLoop(t, provider) diff --git a/pkg/agent/prompt.go b/pkg/agent/prompt.go index aa06bba46..be5ccddf2 100644 --- a/pkg/agent/prompt.go +++ b/pkg/agent/prompt.go @@ -52,6 +52,7 @@ const ( PromptSourceMemory PromptSourceID = "memory:workspace" PromptSourceSkillCatalog PromptSourceID = "skill:index" PromptSourceActiveSkills PromptSourceID = "skill:active" + PromptSourceToolRegistry PromptSourceID = "tool_registry:native" PromptSourceToolDiscovery PromptSourceID = "tool_registry:discovery" PromptSourceOutputPolicy PromptSourceID = "runtime.output" PromptSourceSubTurnProfile PromptSourceID = "subturn.profile" @@ -173,6 +174,13 @@ func builtinPromptSources() []PromptSourceDescriptor { Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotTooling}}, StableByDefault: true, }, + { + ID: PromptSourceToolRegistry, + Owner: "tools", + Description: "Native provider tool definitions", + Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotTooling}}, + StableByDefault: true, + }, { ID: PromptSourceSkillCatalog, Owner: "skills", @@ -278,12 +286,17 @@ func (r *PromptRegistry) RegisterContributor(contributor PromptContributor) erro if contributor == nil { return fmt.Errorf("prompt contributor is nil") } - if err := r.RegisterSource(contributor.PromptSource()); err != nil { + desc := contributor.PromptSource() + desc.ID = PromptSourceID(strings.TrimSpace(string(desc.ID))) + if err := r.RegisterSource(desc); err != nil { return err } r.mu.Lock() defer r.mu.Unlock() + r.contributors = slices.DeleteFunc(r.contributors, func(existing PromptContributor) bool { + return PromptSourceID(strings.TrimSpace(string(existing.PromptSource().ID))) == desc.ID + }) r.contributors = append(r.contributors, contributor) return nil } diff --git a/pkg/agent/prompt_contributors.go b/pkg/agent/prompt_contributors.go new file mode 100644 index 000000000..960572e03 --- /dev/null +++ b/pkg/agent/prompt_contributors.go @@ -0,0 +1,139 @@ +package agent + +import ( + "context" + "fmt" + "strings" +) + +type toolDiscoveryPromptContributor struct { + useBM25 bool + useRegex bool +} + +func (c toolDiscoveryPromptContributor) PromptSource() PromptSourceDescriptor { + return PromptSourceDescriptor{ + ID: PromptSourceToolDiscovery, + Owner: "tools", + Description: "Tool discovery instructions", + Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotTooling}}, + StableByDefault: true, + } +} + +func (c toolDiscoveryPromptContributor) ContributePrompt( + _ context.Context, + _ PromptBuildRequest, +) ([]PromptPart, error) { + content := formatToolDiscoveryRule(c.useBM25, c.useRegex) + if strings.TrimSpace(content) == "" { + return nil, nil + } + + return []PromptPart{ + { + ID: "capability.tool_discovery", + Layer: PromptLayerCapability, + Slot: PromptSlotTooling, + Source: PromptSource{ID: PromptSourceToolDiscovery, Name: "tool_registry:discovery"}, + Title: "tool discovery", + Content: content, + Stable: true, + Cache: PromptCacheEphemeral, + }, + }, nil +} + +type mcpServerPromptContributor struct { + serverName string + toolCount int + deferred bool +} + +func (c mcpServerPromptContributor) PromptSource() PromptSourceDescriptor { + return PromptSourceDescriptor{ + ID: mcpPromptSourceID(c.serverName), + Owner: "mcp", + Description: fmt.Sprintf("MCP server %q capability prompt", c.serverName), + Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotMCP}}, + StableByDefault: true, + } +} + +func (c mcpServerPromptContributor) ContributePrompt( + _ context.Context, + _ PromptBuildRequest, +) ([]PromptPart, error) { + serverName := strings.TrimSpace(c.serverName) + if serverName == "" || c.toolCount <= 0 { + return nil, nil + } + + availability := "available as native tools" + if c.deferred { + availability = "hidden behind tool discovery until unlocked" + } + + return []PromptPart{ + { + ID: "capability.mcp." + promptSourceComponent(serverName), + Layer: PromptLayerCapability, + Slot: PromptSlotMCP, + Source: PromptSource{ID: mcpPromptSourceID(serverName), Name: "mcp:" + serverName}, + Title: "MCP server capability", + Content: fmt.Sprintf( + "MCP server `%s` is connected. It contributes %d tool(s), currently %s.", + serverName, + c.toolCount, + availability, + ), + Stable: true, + Cache: PromptCacheEphemeral, + }, + }, nil +} + +func mcpPromptSourceID(serverName string) PromptSourceID { + return PromptSourceID("mcp:" + promptSourceComponent(serverName)) +} + +func promptSourceComponent(value string) string { + const maxLen = 64 + + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" { + return "unnamed" + } + + var b strings.Builder + lastWasSep := false + for _, r := range value { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r) + lastWasSep = false + case r >= '0' && r <= '9': + b.WriteRune(r) + lastWasSep = false + case r == '-' || r == '_': + if !lastWasSep && b.Len() > 0 { + b.WriteRune(r) + lastWasSep = true + } + default: + if !lastWasSep && b.Len() > 0 { + b.WriteRune('_') + lastWasSep = true + } + } + } + + result := strings.Trim(b.String(), "_") + if result == "" { + return "unnamed" + } + if len(result) > maxLen { + return result[:maxLen] + } + return result +} diff --git a/pkg/agent/prompt_test.go b/pkg/agent/prompt_test.go index af46eae7a..b3f609610 100644 --- a/pkg/agent/prompt_test.go +++ b/pkg/agent/prompt_test.go @@ -164,6 +164,62 @@ func TestBuildMessagesFromPrompt_AttachesInternalPromptMetadata(t *testing.T) { } } +func TestContextBuilder_CollectsToolDiscoveryContributor(t *testing.T) { + t.Setenv("PICOCLAW_BUILTIN_SKILLS", t.TempDir()) + cb := NewContextBuilder(t.TempDir()).WithToolDiscovery(true, false) + + messages := cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"}) + system := messages[0] + if !strings.Contains(system.Content, "tool_search_tool_bm25") { + t.Fatalf("system prompt missing tool discovery rule: %q", system.Content) + } + + var found bool + for _, part := range system.SystemParts { + if part.PromptSource == string(PromptSourceToolDiscovery) { + found = true + if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotTooling) { + t.Fatalf("tool discovery metadata = %#v, want capability/tooling", part) + } + } + } + if !found { + t.Fatal("system parts missing tool discovery prompt metadata") + } +} + +func TestContextBuilder_CollectsMCPServerContributor(t *testing.T) { + t.Setenv("PICOCLAW_BUILTIN_SKILLS", t.TempDir()) + cb := NewContextBuilder(t.TempDir()) + err := cb.RegisterPromptContributor(mcpServerPromptContributor{ + serverName: "GitHub Server", + toolCount: 3, + deferred: true, + }) + if err != nil { + t.Fatalf("RegisterPromptContributor() error = %v", err) + } + + messages := cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"}) + system := messages[0] + if !strings.Contains(system.Content, "MCP server `GitHub Server` is connected") { + t.Fatalf("system prompt missing MCP contributor content: %q", system.Content) + } + + var found bool + for _, part := range system.SystemParts { + if part.PromptSource == "mcp:github_server" { + found = true + if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotMCP) { + t.Fatalf("mcp metadata = %#v, want capability/mcp", part) + } + } + } + if !found { + t.Fatal("system parts missing MCP prompt metadata") + } +} + type testPromptContributor struct { desc PromptSourceDescriptor part PromptPart diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index ca123a2b2..bab4433e7 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -98,6 +98,13 @@ type Message struct { type ToolDefinition struct { Type string `json:"type"` Function ToolFunctionDefinition `json:"function"` + + // Prompt metadata is internal to the agent runtime. Tool definitions are + // model-visible capability prompts even though providers send them outside + // the system message. + PromptLayer string `json:"-"` + PromptSlot string `json:"-"` + PromptSource string `json:"-"` } type ToolFunctionDefinition struct { diff --git a/pkg/tools/integration/mcp_tool.go b/pkg/tools/integration/mcp_tool.go index 340bb9e8e..78c348316 100644 --- a/pkg/tools/integration/mcp_tool.go +++ b/pkg/tools/integration/mcp_tool.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + toolshared "github.com/sipeed/picoclaw/pkg/tools/shared" ) // MCPManager defines the interface for MCP manager operations @@ -161,6 +162,14 @@ func (t *MCPTool) Description() string { return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc) } +func (t *MCPTool) PromptMetadata() toolshared.PromptMetadata { + return toolshared.PromptMetadata{ + Layer: toolshared.ToolPromptLayerCapability, + Slot: toolshared.ToolPromptSlotMCP, + Source: "mcp:" + sanitizeIdentifierComponent(t.serverName), + } +} + // Parameters returns the tool parameters schema func (t *MCPTool) Parameters() map[string]any { // The InputSchema is already a JSON Schema object diff --git a/pkg/tools/integration/mcp_tool_test.go b/pkg/tools/integration/mcp_tool_test.go index e5c54abb6..7b0b2cd5a 100644 --- a/pkg/tools/integration/mcp_tool_test.go +++ b/pkg/tools/integration/mcp_tool_test.go @@ -11,6 +11,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/sipeed/picoclaw/pkg/media" + toolshared "github.com/sipeed/picoclaw/pkg/tools/shared" ) // MockMCPManager is a mock implementation of MCPManager interface for testing @@ -104,6 +105,22 @@ func TestMCPTool_Name(t *testing.T) { } } +func TestMCPTool_PromptMetadata(t *testing.T) { + manager := &MockMCPManager{} + tool := NewMCPTool(manager, "GitHub Server", &mcp.Tool{Name: "create_issue"}) + + metadata := tool.PromptMetadata() + if metadata.Layer != toolshared.ToolPromptLayerCapability { + t.Fatalf("metadata.Layer = %q, want %q", metadata.Layer, toolshared.ToolPromptLayerCapability) + } + if metadata.Slot != toolshared.ToolPromptSlotMCP { + t.Fatalf("metadata.Slot = %q, want %q", metadata.Slot, toolshared.ToolPromptSlotMCP) + } + if metadata.Source != "mcp:github_server" { + t.Fatalf("metadata.Source = %q, want mcp:github_server", metadata.Source) + } +} + // TestMCPTool_Description verifies tool description generation func TestMCPTool_Description(t *testing.T) { tests := []struct { diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index e51dff71a..0ff9293a3 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -352,6 +352,7 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { name, _ := fn["name"].(string) desc, _ := fn["description"].(string) params, _ := fn["parameters"].(map[string]any) + metadata := promptMetadataForTool(entry.Tool) definitions = append(definitions, providers.ToolDefinition{ Type: "function", @@ -360,11 +361,35 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { Description: desc, Parameters: params, }, + PromptLayer: metadata.Layer, + PromptSlot: metadata.Slot, + PromptSource: metadata.Source, }) } return definitions } +func promptMetadataForTool(tool Tool) PromptMetadata { + metadata := PromptMetadata{ + Layer: ToolPromptLayerCapability, + Slot: ToolPromptSlotTooling, + Source: ToolPromptSourceRegistry, + } + if provider, ok := tool.(PromptMetadataProvider); ok { + provided := provider.PromptMetadata() + if provided.Layer != "" { + metadata.Layer = provided.Layer + } + if provided.Slot != "" { + metadata.Slot = provided.Slot + } + if provided.Source != "" { + metadata.Source = provided.Source + } + } + return metadata +} + // List returns a list of all registered tool names. func (r *ToolRegistry) List() []string { r.mu.RLock() diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 16bd30928..eac96382f 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -39,6 +39,15 @@ func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *T return m.result } +type mockPromptMetadataTool struct { + mockRegistryTool + metadata PromptMetadata +} + +func (m *mockPromptMetadataTool) PromptMetadata() PromptMetadata { + return m.metadata +} + type mockAsyncRegistryTool struct { mockRegistryTool lastCB AsyncCallback @@ -375,6 +384,47 @@ func TestToolToSchema(t *testing.T) { } } +func TestToolRegistry_ToProviderDefsAttachesPromptMetadata(t *testing.T) { + r := NewToolRegistry() + r.Register(newMockTool("native", "native tool")) + r.Register(&mockPromptMetadataTool{ + mockRegistryTool: mockRegistryTool{ + name: "mcp_demo", + desc: "mcp tool", + params: map[string]any{"type": "object"}, + }, + metadata: PromptMetadata{ + Layer: ToolPromptLayerCapability, + Slot: ToolPromptSlotMCP, + Source: "mcp:demo", + }, + }) + + defs := r.ToProviderDefs() + if len(defs) != 2 { + t.Fatalf("ToProviderDefs() len = %d, want 2", len(defs)) + } + + byName := make(map[string]providers.ToolDefinition, len(defs)) + for _, def := range defs { + byName[def.Function.Name] = def + } + + native := byName["native"] + if native.PromptLayer != ToolPromptLayerCapability || + native.PromptSlot != ToolPromptSlotTooling || + native.PromptSource != ToolPromptSourceRegistry { + t.Fatalf("native prompt metadata = %#v, want default tooling source", native) + } + + mcp := byName["mcp_demo"] + if mcp.PromptLayer != ToolPromptLayerCapability || + mcp.PromptSlot != ToolPromptSlotMCP || + mcp.PromptSource != "mcp:demo" { + t.Fatalf("mcp prompt metadata = %#v, want mcp source", mcp) + } +} + func TestToolRegistry_Clone(t *testing.T) { r := NewToolRegistry() r.Register(newMockTool("read_file", "reads files")) diff --git a/pkg/tools/search_tool.go b/pkg/tools/search_tool.go index f41c80d90..c5884c9de 100644 --- a/pkg/tools/search_tool.go +++ b/pkg/tools/search_tool.go @@ -34,6 +34,14 @@ func (t *RegexSearchTool) Description() string { return "Search available hidden tools on-demand using a regex pattern. Returns JSON schemas of discovered tools." } +func (t *RegexSearchTool) PromptMetadata() PromptMetadata { + return PromptMetadata{ + Layer: ToolPromptLayerCapability, + Slot: ToolPromptSlotTooling, + Source: ToolPromptSourceDiscovery, + } +} + func (t *RegexSearchTool) Parameters() map[string]any { return map[string]any{ "type": "object", @@ -95,6 +103,14 @@ func (t *BM25SearchTool) Description() string { return "Search available hidden tools on-demand using natural language query describing the action you need to perform. Returns JSON schemas of discovered tools." } +func (t *BM25SearchTool) PromptMetadata() PromptMetadata { + return PromptMetadata{ + Layer: ToolPromptLayerCapability, + Slot: ToolPromptSlotTooling, + Source: ToolPromptSourceDiscovery, + } +} + func (t *BM25SearchTool) Parameters() map[string]any { return map[string]any{ "type": "object", diff --git a/pkg/tools/shared/base.go b/pkg/tools/shared/base.go index 5498d24ab..298e1b478 100644 --- a/pkg/tools/shared/base.go +++ b/pkg/tools/shared/base.go @@ -14,6 +14,24 @@ type Tool interface { Execute(ctx context.Context, args map[string]any) *ToolResult } +const ( + ToolPromptLayerCapability = "capability" + ToolPromptSlotTooling = "tooling" + ToolPromptSlotMCP = "mcp" + ToolPromptSourceRegistry = "tool_registry:native" + ToolPromptSourceDiscovery = "tool_registry:discovery" +) + +type PromptMetadata struct { + Layer string + Slot string + Source string +} + +type PromptMetadataProvider interface { + PromptMetadata() PromptMetadata +} + // --- Request-scoped tool context (channel / chatID) --- // // Carried via context.Value so that concurrent tool calls each receive diff --git a/pkg/tools/shared_facade.go b/pkg/tools/shared_facade.go index 6e40e4e3a..8409ea060 100644 --- a/pkg/tools/shared_facade.go +++ b/pkg/tools/shared_facade.go @@ -22,12 +22,20 @@ type ( Tool = toolshared.Tool AsyncCallback = toolshared.AsyncCallback AsyncExecutor = toolshared.AsyncExecutor + PromptMetadata = toolshared.PromptMetadata + PromptMetadataProvider = toolshared.PromptMetadataProvider ToolResult = toolshared.ToolResult ) const ( handledToolLLMNote = toolshared.HandledToolLLMNote artifactPathsLLMNote = toolshared.ArtifactPathsLLMNote + + ToolPromptLayerCapability = toolshared.ToolPromptLayerCapability + ToolPromptSlotTooling = toolshared.ToolPromptSlotTooling + ToolPromptSlotMCP = toolshared.ToolPromptSlotMCP + ToolPromptSourceRegistry = toolshared.ToolPromptSourceRegistry + ToolPromptSourceDiscovery = toolshared.ToolPromptSourceDiscovery ) func WithToolContext(ctx context.Context, channel, chatID string) context.Context {