package tools import ( "context" "fmt" "sort" "sync" "sync/atomic" "time" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" ) type ToolEntry struct { Tool Tool IsCore bool TTL int } type ToolRegistry struct { tools map[string]*ToolEntry mu sync.RWMutex version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation } func NewToolRegistry() *ToolRegistry { return &ToolRegistry{ tools: make(map[string]*ToolEntry), } } func (r *ToolRegistry) Register(tool Tool) { r.mu.Lock() defer r.mu.Unlock() name := tool.Name() if _, exists := r.tools[name]; exists { logger.WarnCF("tools", "Tool registration overwrites existing tool", map[string]any{"name": name}) } r.tools[name] = &ToolEntry{ Tool: tool, IsCore: true, TTL: 0, // Core tools do not use TTL } r.version.Add(1) logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name}) } // RegisterHidden saves hidden tools (visible only via TTL) func (r *ToolRegistry) RegisterHidden(tool Tool) { r.mu.Lock() defer r.mu.Unlock() name := tool.Name() if _, exists := r.tools[name]; exists { logger.WarnCF("tools", "Hidden tool registration overwrites existing tool", map[string]any{"name": name}) } r.tools[name] = &ToolEntry{ Tool: tool, IsCore: false, TTL: 0, } r.version.Add(1) logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name}) } // PromoteTools atomically sets the TTL for multiple non-core tools. // This prevents a concurrent TickTTL from decrementing between promotions. func (r *ToolRegistry) PromoteTools(names []string, ttl int) { r.mu.Lock() defer r.mu.Unlock() promoted := 0 for _, name := range names { if entry, exists := r.tools[name]; exists { if !entry.IsCore { entry.TTL = ttl promoted++ } } } logger.DebugCF( "tools", "PromoteTools completed", map[string]any{"requested": len(names), "promoted": promoted, "ttl": ttl}, ) } // TickTTL decreases TTL only for non-core tools func (r *ToolRegistry) TickTTL() { r.mu.Lock() defer r.mu.Unlock() for _, entry := range r.tools { if !entry.IsCore && entry.TTL > 0 { entry.TTL-- } } } // Version returns the current registry version (atomically). func (r *ToolRegistry) Version() uint64 { return r.version.Load() } // HiddenToolSnapshot holds a consistent snapshot of hidden tools and the // registry version at which it was taken. Used by BM25SearchTool cache. type HiddenToolSnapshot struct { Docs []HiddenToolDoc Version uint64 } // HiddenToolDoc is a lightweight representation of a hidden tool for search indexing. type HiddenToolDoc struct { Name string Description string } // SnapshotHiddenTools returns all non-core tools and the current registry // version under a single read-lock, guaranteeing consistency between the // two values. func (r *ToolRegistry) SnapshotHiddenTools() HiddenToolSnapshot { r.mu.RLock() defer r.mu.RUnlock() docs := make([]HiddenToolDoc, 0, len(r.tools)) for name, entry := range r.tools { if !entry.IsCore { docs = append(docs, HiddenToolDoc{ Name: name, Description: entry.Tool.Description(), }) } } return HiddenToolSnapshot{ Docs: docs, Version: r.version.Load(), } } func (r *ToolRegistry) Get(name string) (Tool, bool) { r.mu.RLock() defer r.mu.RUnlock() entry, ok := r.tools[name] if !ok { return nil, false } // Hidden tools with expired TTL are not callable. if !entry.IsCore && entry.TTL <= 0 { return nil, false } return entry.Tool, true } func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult { return r.ExecuteWithContext(ctx, name, args, "", "", nil) } // ExecuteWithContext executes a tool with channel/chatID context and optional async callback. // If the tool implements AsyncExecutor and a non-nil callback is provided, // ExecuteAsync is called instead of Execute — the callback is a parameter, // never stored as mutable state on the tool. func (r *ToolRegistry) ExecuteWithContext( ctx context.Context, name string, args map[string]any, channel, chatID string, asyncCallback AsyncCallback, ) *ToolResult { logger.InfoCF("tool", "Tool execution started", map[string]any{ "tool": name, "args": args, }) tool, ok := r.Get(name) if !ok { logger.ErrorCF("tool", "Tool not found", map[string]any{ "tool": name, }) return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } // Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx). // Always inject — tools validate what they require. ctx = WithToolContext(ctx, channel, chatID) // If tool implements AsyncExecutor and callback is provided, use ExecuteAsync. // The callback is a call parameter, not mutable state on the tool instance. var result *ToolResult start := time.Now() if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil { logger.DebugCF("tool", "Executing async tool via ExecuteAsync", map[string]any{ "tool": name, }) result = asyncExec.ExecuteAsync(ctx, args, asyncCallback) } else { result = tool.Execute(ctx, args) } duration := time.Since(start) // Log based on result type if result.IsError { logger.ErrorCF("tool", "Tool execution failed", map[string]any{ "tool": name, "duration": duration.Milliseconds(), "error": result.ForLLM, }) } else if result.Async { logger.InfoCF("tool", "Tool started (async)", map[string]any{ "tool": name, "duration": duration.Milliseconds(), }) } else { logger.InfoCF("tool", "Tool execution completed", map[string]any{ "tool": name, "duration_ms": duration.Milliseconds(), "result_length": len(result.ForLLM), }) } return result } // sortedToolNames returns tool names in sorted order for deterministic iteration. // This is critical for KV cache stability: non-deterministic map iteration would // produce different system prompts and tool definitions on each call, invalidating // the LLM's prefix cache even when no tools have changed. func (r *ToolRegistry) sortedToolNames() []string { names := make([]string, 0, len(r.tools)) for name := range r.tools { names = append(names, name) } sort.Strings(names) return names } func (r *ToolRegistry) GetDefinitions() []map[string]any { r.mu.RLock() defer r.mu.RUnlock() sorted := r.sortedToolNames() definitions := make([]map[string]any, 0, len(sorted)) for _, name := range sorted { entry := r.tools[name] if !entry.IsCore && entry.TTL <= 0 { continue } definitions = append(definitions, ToolToSchema(r.tools[name].Tool)) } return definitions } // ToProviderDefs converts tool definitions to provider-compatible format. // This is the format expected by LLM provider APIs. func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { r.mu.RLock() defer r.mu.RUnlock() sorted := r.sortedToolNames() definitions := make([]providers.ToolDefinition, 0, len(sorted)) for _, name := range sorted { entry := r.tools[name] if !entry.IsCore && entry.TTL <= 0 { continue } schema := ToolToSchema(entry.Tool) // Safely extract nested values with type checks fn, ok := schema["function"].(map[string]any) if !ok { continue } name, _ := fn["name"].(string) desc, _ := fn["description"].(string) params, _ := fn["parameters"].(map[string]any) definitions = append(definitions, providers.ToolDefinition{ Type: "function", Function: providers.ToolFunctionDefinition{ Name: name, Description: desc, Parameters: params, }, }) } return definitions } // List returns a list of all registered tool names. func (r *ToolRegistry) List() []string { r.mu.RLock() defer r.mu.RUnlock() return r.sortedToolNames() } // Count returns the number of registered tools. func (r *ToolRegistry) Count() int { r.mu.RLock() defer r.mu.RUnlock() return len(r.tools) } // GetSummaries returns human-readable summaries of all registered tools. // Returns a slice of "name - description" strings. func (r *ToolRegistry) GetSummaries() []string { r.mu.RLock() defer r.mu.RUnlock() sorted := r.sortedToolNames() summaries := make([]string, 0, len(sorted)) for _, name := range sorted { entry := r.tools[name] if !entry.IsCore && entry.TTL <= 0 { continue } summaries = append(summaries, fmt.Sprintf("- `%s` - %s", entry.Tool.Name(), entry.Tool.Description())) } return summaries }