Files
picoclaw/pkg/tools/registry.go
T
Boris Bliznioukov aef1e8e8c4 fix: eliminate data races on shared tool instances (#1080)
* fix: eliminate data races on shared tool instances

Signed-off-by: Boris Bliznioukov <blib@mail.com>

* fix: remove unused indirect dependency on github.com/gdamore/tcell/v2

Signed-off-by: Boris Bliznioukov <blib@mail.com>

* fix: address reviewer comments — improve context handling for tool execution and ensure defaults for non-conversation callers

Signed-off-by: Boris Bliznioukov <blib@mail.com>

---------

Signed-off-by: Boris Bliznioukov <blib@mail.com>
2026-03-05 09:57:33 +08:00

206 lines
5.4 KiB
Go

package tools
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ToolRegistry maps tool names to their Tool implementations.
// The methods in this file that read or write tools do so while holding mu.
// Because the struct embeds a mutex by value, use it through a pointer
// (NewToolRegistry returns one).
type ToolRegistry struct {
// tools holds the registered tools, keyed by the name each tool reports
// via Tool.Name() at registration time.
tools map[string]Tool
// mu guards tools.
mu sync.RWMutex
}
// NewToolRegistry returns an empty registry whose tool map is already
// allocated, so Register can be called on it immediately.
func NewToolRegistry() *ToolRegistry {
	registry := new(ToolRegistry)
	registry.tools = map[string]Tool{}
	return registry
}
// Register stores tool under the name returned by tool.Name().
// If a tool with that name is already present it is replaced, and a warning
// is logged before the replacement happens. The write lock is held for the
// whole operation.
func (r *ToolRegistry) Register(tool Tool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	key := tool.Name()
	_, alreadyRegistered := r.tools[key]
	if alreadyRegistered {
		logger.WarnCF(
			"tools",
			"Tool registration overwrites existing tool",
			map[string]any{"name": key},
		)
	}
	r.tools[key] = tool
}
// Get looks up the tool registered under name.
// The boolean result reports whether such a tool exists; when it is false
// the returned Tool is the zero value.
func (r *ToolRegistry) Get(name string) (Tool, bool) {
	r.mu.RLock()
	found, present := r.tools[name]
	r.mu.RUnlock()
	return found, present
}
// Execute runs the named tool with args, delegating to ExecuteWithContext
// with an empty channel, an empty chat ID, and no async callback.
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
	const noChannel, noChatID = "", ""
	var noCallback AsyncCallback // zero value: no async callback supplied
	return r.ExecuteWithContext(ctx, name, args, noChannel, noChatID, noCallback)
}
// ExecuteWithContext looks up the tool registered under name and runs it,
// logging the start, the outcome, and the elapsed time in milliseconds.
//
// channel and chatID are attached to ctx via WithToolContext before the tool
// runs, so tools read them from the context rather than from mutable state.
// They are always injected, even when empty; tools validate what they need.
//
// If the tool implements AsyncExecutor and asyncCallback is non-nil,
// ExecuteAsync is called with the callback as a call parameter; otherwise
// the tool's plain Execute is used.
//
// Returns the tool's result. An error result built with ErrorResult is
// returned instead when no tool is registered under name, or when the tool
// hands back a nil result (which would otherwise cause a nil-pointer panic
// while inspecting the result for logging).
func (r *ToolRegistry) ExecuteWithContext(
	ctx context.Context,
	name string,
	args map[string]any,
	channel, chatID string,
	asyncCallback AsyncCallback,
) *ToolResult {
	logger.InfoCF("tool", "Tool execution started",
		map[string]any{
			"tool": name,
			"args": args,
		})

	tool, ok := r.Get(name)
	if !ok {
		logger.ErrorCF("tool", "Tool not found",
			map[string]any{
				"tool": name,
			})
		return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
	}

	// Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx).
	ctx = WithToolContext(ctx, channel, chatID)

	// The async path is taken only when both conditions hold: the tool
	// supports it and the caller supplied a callback.
	var result *ToolResult
	start := time.Now()
	if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil {
		logger.DebugCF("tool", "Executing async tool via ExecuteAsync",
			map[string]any{
				"tool": name,
			})
		result = asyncExec.ExecuteAsync(ctx, args, asyncCallback)
	} else {
		result = tool.Execute(ctx, args)
	}
	duration := time.Since(start)

	// Guard against tools that return nil: the logging below dereferences
	// result, and callers receive it as-is.
	if result == nil {
		logger.ErrorCF("tool", "Tool returned nil result",
			map[string]any{
				"tool":     name,
				"duration": duration.Milliseconds(),
			})
		return ErrorResult(fmt.Sprintf("tool %q returned no result", name)).
			WithError(fmt.Errorf("tool returned nil result"))
	}

	// Log based on result type.
	switch {
	case result.IsError:
		logger.ErrorCF("tool", "Tool execution failed",
			map[string]any{
				"tool":     name,
				"duration": duration.Milliseconds(),
				"error":    result.ForLLM,
			})
	case result.Async:
		logger.InfoCF("tool", "Tool started (async)",
			map[string]any{
				"tool":     name,
				"duration": duration.Milliseconds(),
			})
	default:
		logger.InfoCF("tool", "Tool execution completed",
			map[string]any{
				"tool":          name,
				"duration_ms":   duration.Milliseconds(),
				"result_length": len(result.ForLLM),
			})
	}

	return result
}
// sortedToolNames collects the keys of r.tools and returns them in ascending
// string order. Go's map iteration order is unspecified, so sorting gives
// callers a deterministic sequence; the original author notes this matters
// for keeping generated system prompts and tool definitions stable across
// calls so the LLM's prefix (KV) cache is not invalidated needlessly.
//
// It does not take r.mu itself; the callers in this file hold the read lock.
func (r *ToolRegistry) sortedToolNames() []string {
	keys := make([]string, len(r.tools))
	i := 0
	for key := range r.tools {
		keys[i] = key
		i++
	}
	sort.Strings(keys)
	return keys
}
// GetDefinitions returns the ToolToSchema output for every registered tool,
// ordered by tool name. The read lock is held while the schemas are built.
func (r *ToolRegistry) GetDefinitions() []map[string]any {
	r.mu.RLock()
	defer r.mu.RUnlock()

	names := r.sortedToolNames()
	schemas := make([]map[string]any, len(names))
	for i, toolName := range names {
		schemas[i] = ToolToSchema(r.tools[toolName])
	}
	return schemas
}
// ToProviderDefs builds one providers.ToolDefinition per registered tool,
// ordered by tool name, from the tool's ToolToSchema output.
//
// A tool whose schema has no "function" entry of type map[string]any is
// skipped. Within that entry, "name", "description" and "parameters" are
// read with checked type assertions; a missing or mistyped field becomes
// the zero value rather than an error.
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
	r.mu.RLock()
	defer r.mu.RUnlock()

	keys := r.sortedToolNames()
	defs := make([]providers.ToolDefinition, 0, len(keys))
	for _, key := range keys {
		schema := ToolToSchema(r.tools[key])

		fn, isMap := schema["function"].(map[string]any)
		if !isMap {
			continue
		}

		// Checked assertions: a failed assertion yields the zero value.
		fnName, _ := fn["name"].(string)
		fnDesc, _ := fn["description"].(string)
		fnParams, _ := fn["parameters"].(map[string]any)

		def := providers.ToolDefinition{
			Type: "function",
			Function: providers.ToolFunctionDefinition{
				Name:        fnName,
				Description: fnDesc,
				Parameters:  fnParams,
			},
		}
		defs = append(defs, def)
	}
	return defs
}
// List returns the names of all registered tools in ascending order.
func (r *ToolRegistry) List() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	names := r.sortedToolNames()
	return names
}
// Count reports how many tools are currently registered.
func (r *ToolRegistry) Count() int {
	r.mu.RLock()
	n := len(r.tools)
	r.mu.RUnlock()
	return n
}
// GetSummaries returns one line per registered tool, ordered by tool name.
// Each line has the form "- `<name>` - <description>", using the values
// reported by the tool's Name and Description methods.
func (r *ToolRegistry) GetSummaries() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	keys := r.sortedToolNames()
	lines := make([]string, len(keys))
	for i, key := range keys {
		t := r.tools[key]
		lines[i] = fmt.Sprintf("- `%s` - %s", t.Name(), t.Description())
	}
	return lines
}