mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): migrate tool prompts to capability slots
This commit is contained in:
@@ -135,6 +135,25 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
serverCfg := al.cfg.Tools.MCP.Servers[serverName]
|
||||
registerAsHidden := serverIsDeferred(al.cfg.Tools.MCP.Discovery.Enabled, serverCfg)
|
||||
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok || agent.ContextBuilder == nil {
|
||||
continue
|
||||
}
|
||||
if err := agent.ContextBuilder.RegisterPromptContributor(mcpServerPromptContributor{
|
||||
serverName: serverName,
|
||||
toolCount: len(conn.Tools),
|
||||
deferred: registerAsHidden,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Failed to register MCP prompt contributor",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
|
||||
+42
-28
@@ -22,13 +22,11 @@ import (
|
||||
)
|
||||
|
||||
type ContextBuilder struct {
|
||||
workspace string
|
||||
skillsLoader *skills.SkillsLoader
|
||||
memory *MemoryStore
|
||||
toolDiscoveryBM25 bool
|
||||
toolDiscoveryRegex bool
|
||||
splitOnMarker bool
|
||||
promptRegistry *PromptRegistry
|
||||
workspace string
|
||||
skillsLoader *skills.SkillsLoader
|
||||
memory *MemoryStore
|
||||
splitOnMarker bool
|
||||
promptRegistry *PromptRegistry
|
||||
|
||||
// Cache for system prompt to avoid rebuilding on every call.
|
||||
// This fixes issue #607: repeated reprocessing of the entire context.
|
||||
@@ -50,8 +48,16 @@ type ContextBuilder struct {
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuilder {
|
||||
cb.toolDiscoveryBM25 = useBM25
|
||||
cb.toolDiscoveryRegex = useRegex
|
||||
if useBM25 || useRegex {
|
||||
if err := cb.RegisterPromptContributor(toolDiscoveryPromptContributor{
|
||||
useBM25: useBM25,
|
||||
useRegex: useRegex,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Failed to register tool discovery prompt contributor", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return cb
|
||||
}
|
||||
|
||||
@@ -83,11 +89,19 @@ func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) RegisterPromptSource(desc PromptSourceDescriptor) error {
|
||||
return cb.promptRegistryOrDefault().RegisterSource(desc)
|
||||
err := cb.promptRegistryOrDefault().RegisterSource(desc)
|
||||
if err == nil {
|
||||
cb.InvalidateCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) RegisterPromptContributor(contributor PromptContributor) error {
|
||||
return cb.promptRegistryOrDefault().RegisterContributor(contributor)
|
||||
err := cb.promptRegistryOrDefault().RegisterContributor(contributor)
|
||||
if err == nil {
|
||||
cb.InvalidateCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) promptRegistryOrDefault() *PromptRegistry {
|
||||
@@ -124,16 +138,16 @@ Your workspace is at: %s
|
||||
version, workspacePath, workspacePath, workspacePath, workspacePath, workspacePath)
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) getDiscoveryRule() string {
|
||||
if !cb.toolDiscoveryBM25 && !cb.toolDiscoveryRegex {
|
||||
func formatToolDiscoveryRule(useBM25, useRegex bool) string {
|
||||
if !useBM25 && !useRegex {
|
||||
return ""
|
||||
}
|
||||
|
||||
var toolNames []string
|
||||
if cb.toolDiscoveryBM25 {
|
||||
if useBM25 {
|
||||
toolNames = append(toolNames, `"tool_search_tool_bm25"`)
|
||||
}
|
||||
if cb.toolDiscoveryRegex {
|
||||
if useRegex {
|
||||
toolNames = append(toolNames, `"tool_search_tool_regex"`)
|
||||
}
|
||||
|
||||
@@ -173,19 +187,6 @@ func (cb *ContextBuilder) BuildSystemPromptParts() []PromptPart {
|
||||
Cache: PromptCacheEphemeral,
|
||||
})
|
||||
|
||||
if toolDiscovery := cb.getDiscoveryRule(); toolDiscovery != "" {
|
||||
add(PromptPart{
|
||||
ID: "capability.tool_discovery",
|
||||
Layer: PromptLayerCapability,
|
||||
Slot: PromptSlotTooling,
|
||||
Source: PromptSource{ID: PromptSourceToolDiscovery, Name: "tool_registry:discovery"},
|
||||
Title: "tool discovery",
|
||||
Content: toolDiscovery,
|
||||
Stable: true,
|
||||
Cache: PromptCacheEphemeral,
|
||||
})
|
||||
}
|
||||
|
||||
// Bootstrap files
|
||||
bootstrapContent := cb.LoadBootstrapFiles()
|
||||
if bootstrapContent != "" {
|
||||
@@ -318,6 +319,19 @@ func (cb *ContextBuilder) EstimateSystemTokens(summary string, activeSkills []st
|
||||
totalChars += 7 // separator \n\n---\n\n
|
||||
}
|
||||
|
||||
if contributedParts, err := cb.promptRegistryOrDefault().Collect(context.Background(), PromptBuildRequest{
|
||||
Summary: summary,
|
||||
ActiveSkills: append([]string(nil), activeSkills...),
|
||||
}); err == nil {
|
||||
for _, part := range contributedParts {
|
||||
if strings.TrimSpace(part.Content) == "" {
|
||||
continue
|
||||
}
|
||||
totalChars += utf8.RuneCountInString(part.Content)
|
||||
totalChars += 7 // separator
|
||||
}
|
||||
}
|
||||
|
||||
if summary != "" {
|
||||
// Matches the CONTEXT_SUMMARY: prefix added in BuildMessages
|
||||
const summaryPrefix = "CONTEXT_SUMMARY: The following is an approximate summary of prior conversation " +
|
||||
|
||||
+15
-7
@@ -377,14 +377,18 @@ func (hm *HookManager) applyBeforeLLMControls(
|
||||
if next == nil || current == nil {
|
||||
return next
|
||||
}
|
||||
if llmHookSystemMessagesUnchanged(current.Messages, next.Messages) {
|
||||
return next
|
||||
if !llmHookSystemMessagesUnchanged(current.Messages, next.Messages) {
|
||||
logger.WarnCF("hooks", "Hook attempted to modify system prompt; preserving original messages", map[string]any{
|
||||
"hook": hookName,
|
||||
})
|
||||
next.Messages = cloneProviderMessages(current.Messages)
|
||||
}
|
||||
if !llmHookToolDefinitionsUnchanged(current.Tools, next.Tools) {
|
||||
logger.WarnCF("hooks", "Hook attempted to modify tool definitions; preserving original tools", map[string]any{
|
||||
"hook": hookName,
|
||||
})
|
||||
next.Tools = cloneToolDefinitions(current.Tools)
|
||||
}
|
||||
|
||||
logger.WarnCF("hooks", "Hook attempted to modify system prompt; preserving original messages", map[string]any{
|
||||
"hook": hookName,
|
||||
})
|
||||
next.Messages = cloneProviderMessages(current.Messages)
|
||||
return next
|
||||
}
|
||||
|
||||
@@ -413,6 +417,10 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing
|
||||
return fingerprints
|
||||
}
|
||||
|
||||
func llmHookToolDefinitionsUnchanged(before, after []providers.ToolDefinition) bool {
|
||||
return reflect.DeepEqual(cloneToolDefinitions(before), cloneToolDefinitions(after))
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
|
||||
@@ -186,6 +186,36 @@ func (h *llmUserAppendHook) AfterLLM(
|
||||
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
type llmToolRewriteHook struct{}
|
||||
|
||||
func (h *llmToolRewriteHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = "changed-model"
|
||||
next.Tools[0].Function.Description = "rewritten tool"
|
||||
next.Tools = append(next.Tools, providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "hook_tool",
|
||||
Description: "hook tool",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
PromptLayer: string(PromptLayerCapability),
|
||||
PromptSlot: string(PromptSlotTooling),
|
||||
PromptSource: "hook:test",
|
||||
})
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *llmToolRewriteHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func TestHookManager_BeforeLLMControlsSystemPromptMutation(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
if err := hm.Mount(NamedHook("rewrite-system", &llmSystemRewriteHook{})); err != nil {
|
||||
@@ -244,6 +274,51 @@ func TestHookManager_BeforeLLMAllowsNonSystemMessageMutation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil {
|
||||
t.Fatalf("Mount() error = %v", err)
|
||||
}
|
||||
|
||||
req := &LLMHookRequest{
|
||||
Model: "original-model",
|
||||
Messages: []providers.Message{
|
||||
{Role: "system", Content: "system"},
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
Tools: []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "mcp_github_create_issue",
|
||||
Description: "create issue",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
PromptLayer: string(PromptLayerCapability),
|
||||
PromptSlot: string(PromptSlotMCP),
|
||||
PromptSource: "mcp:github",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, decision := hm.BeforeLLM(context.Background(), req)
|
||||
if decision.normalizedAction() != HookActionContinue {
|
||||
t.Fatalf("decision = %v, want continue", decision)
|
||||
}
|
||||
if got.Model != "changed-model" {
|
||||
t.Fatalf("model = %q, want changed-model", got.Model)
|
||||
}
|
||||
if len(got.Tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want original 1", len(got.Tools))
|
||||
}
|
||||
if got.Tools[0].Function.Description != "create issue" {
|
||||
t.Fatalf("tool description = %q, want original", got.Tools[0].Function.Description)
|
||||
}
|
||||
if got.Tools[0].PromptSource != "mcp:github" || got.Tools[0].PromptSlot != string(PromptSlotMCP) {
|
||||
t.Fatalf("tool prompt metadata = %#v, want original mcp metadata", got.Tools[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
|
||||
+14
-1
@@ -52,6 +52,7 @@ const (
|
||||
PromptSourceMemory PromptSourceID = "memory:workspace"
|
||||
PromptSourceSkillCatalog PromptSourceID = "skill:index"
|
||||
PromptSourceActiveSkills PromptSourceID = "skill:active"
|
||||
PromptSourceToolRegistry PromptSourceID = "tool_registry:native"
|
||||
PromptSourceToolDiscovery PromptSourceID = "tool_registry:discovery"
|
||||
PromptSourceOutputPolicy PromptSourceID = "runtime.output"
|
||||
PromptSourceSubTurnProfile PromptSourceID = "subturn.profile"
|
||||
@@ -173,6 +174,13 @@ func builtinPromptSources() []PromptSourceDescriptor {
|
||||
Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotTooling}},
|
||||
StableByDefault: true,
|
||||
},
|
||||
{
|
||||
ID: PromptSourceToolRegistry,
|
||||
Owner: "tools",
|
||||
Description: "Native provider tool definitions",
|
||||
Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotTooling}},
|
||||
StableByDefault: true,
|
||||
},
|
||||
{
|
||||
ID: PromptSourceSkillCatalog,
|
||||
Owner: "skills",
|
||||
@@ -278,12 +286,17 @@ func (r *PromptRegistry) RegisterContributor(contributor PromptContributor) erro
|
||||
if contributor == nil {
|
||||
return fmt.Errorf("prompt contributor is nil")
|
||||
}
|
||||
if err := r.RegisterSource(contributor.PromptSource()); err != nil {
|
||||
desc := contributor.PromptSource()
|
||||
desc.ID = PromptSourceID(strings.TrimSpace(string(desc.ID)))
|
||||
if err := r.RegisterSource(desc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.contributors = slices.DeleteFunc(r.contributors, func(existing PromptContributor) bool {
|
||||
return PromptSourceID(strings.TrimSpace(string(existing.PromptSource().ID))) == desc.ID
|
||||
})
|
||||
r.contributors = append(r.contributors, contributor)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type toolDiscoveryPromptContributor struct {
|
||||
useBM25 bool
|
||||
useRegex bool
|
||||
}
|
||||
|
||||
func (c toolDiscoveryPromptContributor) PromptSource() PromptSourceDescriptor {
|
||||
return PromptSourceDescriptor{
|
||||
ID: PromptSourceToolDiscovery,
|
||||
Owner: "tools",
|
||||
Description: "Tool discovery instructions",
|
||||
Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotTooling}},
|
||||
StableByDefault: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (c toolDiscoveryPromptContributor) ContributePrompt(
|
||||
_ context.Context,
|
||||
_ PromptBuildRequest,
|
||||
) ([]PromptPart, error) {
|
||||
content := formatToolDiscoveryRule(c.useBM25, c.useRegex)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []PromptPart{
|
||||
{
|
||||
ID: "capability.tool_discovery",
|
||||
Layer: PromptLayerCapability,
|
||||
Slot: PromptSlotTooling,
|
||||
Source: PromptSource{ID: PromptSourceToolDiscovery, Name: "tool_registry:discovery"},
|
||||
Title: "tool discovery",
|
||||
Content: content,
|
||||
Stable: true,
|
||||
Cache: PromptCacheEphemeral,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mcpServerPromptContributor struct {
|
||||
serverName string
|
||||
toolCount int
|
||||
deferred bool
|
||||
}
|
||||
|
||||
func (c mcpServerPromptContributor) PromptSource() PromptSourceDescriptor {
|
||||
return PromptSourceDescriptor{
|
||||
ID: mcpPromptSourceID(c.serverName),
|
||||
Owner: "mcp",
|
||||
Description: fmt.Sprintf("MCP server %q capability prompt", c.serverName),
|
||||
Allowed: []PromptPlacement{{Layer: PromptLayerCapability, Slot: PromptSlotMCP}},
|
||||
StableByDefault: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (c mcpServerPromptContributor) ContributePrompt(
|
||||
_ context.Context,
|
||||
_ PromptBuildRequest,
|
||||
) ([]PromptPart, error) {
|
||||
serverName := strings.TrimSpace(c.serverName)
|
||||
if serverName == "" || c.toolCount <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
availability := "available as native tools"
|
||||
if c.deferred {
|
||||
availability = "hidden behind tool discovery until unlocked"
|
||||
}
|
||||
|
||||
return []PromptPart{
|
||||
{
|
||||
ID: "capability.mcp." + promptSourceComponent(serverName),
|
||||
Layer: PromptLayerCapability,
|
||||
Slot: PromptSlotMCP,
|
||||
Source: PromptSource{ID: mcpPromptSourceID(serverName), Name: "mcp:" + serverName},
|
||||
Title: "MCP server capability",
|
||||
Content: fmt.Sprintf(
|
||||
"MCP server `%s` is connected. It contributes %d tool(s), currently %s.",
|
||||
serverName,
|
||||
c.toolCount,
|
||||
availability,
|
||||
),
|
||||
Stable: true,
|
||||
Cache: PromptCacheEphemeral,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mcpPromptSourceID(serverName string) PromptSourceID {
|
||||
return PromptSourceID("mcp:" + promptSourceComponent(serverName))
|
||||
}
|
||||
|
||||
func promptSourceComponent(value string) string {
|
||||
const maxLen = 64
|
||||
|
||||
value = strings.ToLower(strings.TrimSpace(value))
|
||||
if value == "" {
|
||||
return "unnamed"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
lastWasSep := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
b.WriteRune(r)
|
||||
lastWasSep = false
|
||||
case r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
lastWasSep = false
|
||||
case r == '-' || r == '_':
|
||||
if !lastWasSep && b.Len() > 0 {
|
||||
b.WriteRune(r)
|
||||
lastWasSep = true
|
||||
}
|
||||
default:
|
||||
if !lastWasSep && b.Len() > 0 {
|
||||
b.WriteRune('_')
|
||||
lastWasSep = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Trim(b.String(), "_")
|
||||
if result == "" {
|
||||
return "unnamed"
|
||||
}
|
||||
if len(result) > maxLen {
|
||||
return result[:maxLen]
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -164,6 +164,62 @@ func TestBuildMessagesFromPrompt_AttachesInternalPromptMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextBuilder_CollectsToolDiscoveryContributor(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_BUILTIN_SKILLS", t.TempDir())
|
||||
cb := NewContextBuilder(t.TempDir()).WithToolDiscovery(true, false)
|
||||
|
||||
messages := cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"})
|
||||
system := messages[0]
|
||||
if !strings.Contains(system.Content, "tool_search_tool_bm25") {
|
||||
t.Fatalf("system prompt missing tool discovery rule: %q", system.Content)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for _, part := range system.SystemParts {
|
||||
if part.PromptSource == string(PromptSourceToolDiscovery) {
|
||||
found = true
|
||||
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotTooling) {
|
||||
t.Fatalf("tool discovery metadata = %#v, want capability/tooling", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("system parts missing tool discovery prompt metadata")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextBuilder_CollectsMCPServerContributor(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_BUILTIN_SKILLS", t.TempDir())
|
||||
cb := NewContextBuilder(t.TempDir())
|
||||
err := cb.RegisterPromptContributor(mcpServerPromptContributor{
|
||||
serverName: "GitHub Server",
|
||||
toolCount: 3,
|
||||
deferred: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterPromptContributor() error = %v", err)
|
||||
}
|
||||
|
||||
messages := cb.BuildMessagesFromPrompt(PromptBuildRequest{CurrentMessage: "hello"})
|
||||
system := messages[0]
|
||||
if !strings.Contains(system.Content, "MCP server `GitHub Server` is connected") {
|
||||
t.Fatalf("system prompt missing MCP contributor content: %q", system.Content)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for _, part := range system.SystemParts {
|
||||
if part.PromptSource == "mcp:github_server" {
|
||||
found = true
|
||||
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotMCP) {
|
||||
t.Fatalf("mcp metadata = %#v, want capability/mcp", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("system parts missing MCP prompt metadata")
|
||||
}
|
||||
}
|
||||
|
||||
type testPromptContributor struct {
|
||||
desc PromptSourceDescriptor
|
||||
part PromptPart
|
||||
|
||||
@@ -98,6 +98,13 @@ type Message struct {
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
|
||||
// Prompt metadata is internal to the agent runtime. Tool definitions are
|
||||
// model-visible capability prompts even though providers send them outside
|
||||
// the system message.
|
||||
PromptLayer string `json:"-"`
|
||||
PromptSlot string `json:"-"`
|
||||
PromptSource string `json:"-"`
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
toolshared "github.com/sipeed/picoclaw/pkg/tools/shared"
|
||||
)
|
||||
|
||||
// MCPManager defines the interface for MCP manager operations
|
||||
@@ -161,6 +162,14 @@ func (t *MCPTool) Description() string {
|
||||
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
|
||||
}
|
||||
|
||||
func (t *MCPTool) PromptMetadata() toolshared.PromptMetadata {
|
||||
return toolshared.PromptMetadata{
|
||||
Layer: toolshared.ToolPromptLayerCapability,
|
||||
Slot: toolshared.ToolPromptSlotMCP,
|
||||
Source: "mcp:" + sanitizeIdentifierComponent(t.serverName),
|
||||
}
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *MCPTool) Parameters() map[string]any {
|
||||
// The InputSchema is already a JSON Schema object
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
toolshared "github.com/sipeed/picoclaw/pkg/tools/shared"
|
||||
)
|
||||
|
||||
// MockMCPManager is a mock implementation of MCPManager interface for testing
|
||||
@@ -104,6 +105,22 @@ func TestMCPTool_Name(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_PromptMetadata(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := NewMCPTool(manager, "GitHub Server", &mcp.Tool{Name: "create_issue"})
|
||||
|
||||
metadata := tool.PromptMetadata()
|
||||
if metadata.Layer != toolshared.ToolPromptLayerCapability {
|
||||
t.Fatalf("metadata.Layer = %q, want %q", metadata.Layer, toolshared.ToolPromptLayerCapability)
|
||||
}
|
||||
if metadata.Slot != toolshared.ToolPromptSlotMCP {
|
||||
t.Fatalf("metadata.Slot = %q, want %q", metadata.Slot, toolshared.ToolPromptSlotMCP)
|
||||
}
|
||||
if metadata.Source != "mcp:github_server" {
|
||||
t.Fatalf("metadata.Source = %q, want mcp:github_server", metadata.Source)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Description verifies tool description generation
|
||||
func TestMCPTool_Description(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -352,6 +352,7 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params, _ := fn["parameters"].(map[string]any)
|
||||
metadata := promptMetadataForTool(entry.Tool)
|
||||
|
||||
definitions = append(definitions, providers.ToolDefinition{
|
||||
Type: "function",
|
||||
@@ -360,11 +361,35 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
Description: desc,
|
||||
Parameters: params,
|
||||
},
|
||||
PromptLayer: metadata.Layer,
|
||||
PromptSlot: metadata.Slot,
|
||||
PromptSource: metadata.Source,
|
||||
})
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
func promptMetadataForTool(tool Tool) PromptMetadata {
|
||||
metadata := PromptMetadata{
|
||||
Layer: ToolPromptLayerCapability,
|
||||
Slot: ToolPromptSlotTooling,
|
||||
Source: ToolPromptSourceRegistry,
|
||||
}
|
||||
if provider, ok := tool.(PromptMetadataProvider); ok {
|
||||
provided := provider.PromptMetadata()
|
||||
if provided.Layer != "" {
|
||||
metadata.Layer = provided.Layer
|
||||
}
|
||||
if provided.Slot != "" {
|
||||
metadata.Slot = provided.Slot
|
||||
}
|
||||
if provided.Source != "" {
|
||||
metadata.Source = provided.Source
|
||||
}
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
// List returns a list of all registered tool names.
|
||||
func (r *ToolRegistry) List() []string {
|
||||
r.mu.RLock()
|
||||
|
||||
@@ -39,6 +39,15 @@ func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *T
|
||||
return m.result
|
||||
}
|
||||
|
||||
type mockPromptMetadataTool struct {
|
||||
mockRegistryTool
|
||||
metadata PromptMetadata
|
||||
}
|
||||
|
||||
func (m *mockPromptMetadataTool) PromptMetadata() PromptMetadata {
|
||||
return m.metadata
|
||||
}
|
||||
|
||||
type mockAsyncRegistryTool struct {
|
||||
mockRegistryTool
|
||||
lastCB AsyncCallback
|
||||
@@ -375,6 +384,47 @@ func TestToolToSchema(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ToProviderDefsAttachesPromptMetadata(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("native", "native tool"))
|
||||
r.Register(&mockPromptMetadataTool{
|
||||
mockRegistryTool: mockRegistryTool{
|
||||
name: "mcp_demo",
|
||||
desc: "mcp tool",
|
||||
params: map[string]any{"type": "object"},
|
||||
},
|
||||
metadata: PromptMetadata{
|
||||
Layer: ToolPromptLayerCapability,
|
||||
Slot: ToolPromptSlotMCP,
|
||||
Source: "mcp:demo",
|
||||
},
|
||||
})
|
||||
|
||||
defs := r.ToProviderDefs()
|
||||
if len(defs) != 2 {
|
||||
t.Fatalf("ToProviderDefs() len = %d, want 2", len(defs))
|
||||
}
|
||||
|
||||
byName := make(map[string]providers.ToolDefinition, len(defs))
|
||||
for _, def := range defs {
|
||||
byName[def.Function.Name] = def
|
||||
}
|
||||
|
||||
native := byName["native"]
|
||||
if native.PromptLayer != ToolPromptLayerCapability ||
|
||||
native.PromptSlot != ToolPromptSlotTooling ||
|
||||
native.PromptSource != ToolPromptSourceRegistry {
|
||||
t.Fatalf("native prompt metadata = %#v, want default tooling source", native)
|
||||
}
|
||||
|
||||
mcp := byName["mcp_demo"]
|
||||
if mcp.PromptLayer != ToolPromptLayerCapability ||
|
||||
mcp.PromptSlot != ToolPromptSlotMCP ||
|
||||
mcp.PromptSource != "mcp:demo" {
|
||||
t.Fatalf("mcp prompt metadata = %#v, want mcp source", mcp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Clone(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("read_file", "reads files"))
|
||||
|
||||
@@ -34,6 +34,14 @@ func (t *RegexSearchTool) Description() string {
|
||||
return "Search available hidden tools on-demand using a regex pattern. Returns JSON schemas of discovered tools."
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) PromptMetadata() PromptMetadata {
|
||||
return PromptMetadata{
|
||||
Layer: ToolPromptLayerCapability,
|
||||
Slot: ToolPromptSlotTooling,
|
||||
Source: ToolPromptSourceDiscovery,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
@@ -95,6 +103,14 @@ func (t *BM25SearchTool) Description() string {
|
||||
return "Search available hidden tools on-demand using natural language query describing the action you need to perform. Returns JSON schemas of discovered tools."
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) PromptMetadata() PromptMetadata {
|
||||
return PromptMetadata{
|
||||
Layer: ToolPromptLayerCapability,
|
||||
Slot: ToolPromptSlotTooling,
|
||||
Source: ToolPromptSourceDiscovery,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
|
||||
@@ -14,6 +14,24 @@ type Tool interface {
|
||||
Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
}
|
||||
|
||||
const (
|
||||
ToolPromptLayerCapability = "capability"
|
||||
ToolPromptSlotTooling = "tooling"
|
||||
ToolPromptSlotMCP = "mcp"
|
||||
ToolPromptSourceRegistry = "tool_registry:native"
|
||||
ToolPromptSourceDiscovery = "tool_registry:discovery"
|
||||
)
|
||||
|
||||
type PromptMetadata struct {
|
||||
Layer string
|
||||
Slot string
|
||||
Source string
|
||||
}
|
||||
|
||||
type PromptMetadataProvider interface {
|
||||
PromptMetadata() PromptMetadata
|
||||
}
|
||||
|
||||
// --- Request-scoped tool context (channel / chatID) ---
|
||||
//
|
||||
// Carried via context.Value so that concurrent tool calls each receive
|
||||
|
||||
@@ -22,12 +22,20 @@ type (
|
||||
Tool = toolshared.Tool
|
||||
AsyncCallback = toolshared.AsyncCallback
|
||||
AsyncExecutor = toolshared.AsyncExecutor
|
||||
PromptMetadata = toolshared.PromptMetadata
|
||||
PromptMetadataProvider = toolshared.PromptMetadataProvider
|
||||
ToolResult = toolshared.ToolResult
|
||||
)
|
||||
|
||||
const (
|
||||
handledToolLLMNote = toolshared.HandledToolLLMNote
|
||||
artifactPathsLLMNote = toolshared.ArtifactPathsLLMNote
|
||||
|
||||
ToolPromptLayerCapability = toolshared.ToolPromptLayerCapability
|
||||
ToolPromptSlotTooling = toolshared.ToolPromptSlotTooling
|
||||
ToolPromptSlotMCP = toolshared.ToolPromptSlotMCP
|
||||
ToolPromptSourceRegistry = toolshared.ToolPromptSourceRegistry
|
||||
ToolPromptSourceDiscovery = toolshared.ToolPromptSourceDiscovery
|
||||
)
|
||||
|
||||
func WithToolContext(ctx context.Context, channel, chatID string) context.Context {
|
||||
|
||||
Reference in New Issue
Block a user