feat(mcp): tool search tools (#1243)

* feat(mcp): tool search tools

* removed unused call_discovered_tool

* improvements and optimizations

* fix gate mcp enabled

* fix TOCTOU race BM25 cache version check

* fix encapsulation bypass on registry internals

* safety comment on TickTTL

* added more unit tests

* enhanced logs
This commit is contained in:
Mauro
2026-03-09 18:21:49 +01:00
committed by GitHub
parent c45c5073c0
commit b89f6445d1
12 changed files with 1481 additions and 48 deletions
+137 -11
View File
@@ -5,20 +5,28 @@ import (
"fmt"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
type ToolEntry struct {
Tool Tool
IsCore bool
TTL int
}
type ToolRegistry struct {
tools map[string]Tool
mu sync.RWMutex
tools map[string]*ToolEntry
mu sync.RWMutex
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
}
func NewToolRegistry() *ToolRegistry {
return &ToolRegistry{
tools: make(map[string]Tool),
tools: make(map[string]*ToolEntry),
}
}
@@ -30,14 +38,116 @@ func (r *ToolRegistry) Register(tool Tool) {
logger.WarnCF("tools", "Tool registration overwrites existing tool",
map[string]any{"name": name})
}
r.tools[name] = tool
r.tools[name] = &ToolEntry{
Tool: tool,
IsCore: true,
TTL: 0, // Core tools do not use TTL
}
r.version.Add(1)
logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name})
}
// RegisterHidden saves hidden tools (visible only via TTL)
func (r *ToolRegistry) RegisterHidden(tool Tool) {
r.mu.Lock()
defer r.mu.Unlock()
name := tool.Name()
if _, exists := r.tools[name]; exists {
logger.WarnCF("tools", "Hidden tool registration overwrites existing tool",
map[string]any{"name": name})
}
r.tools[name] = &ToolEntry{
Tool: tool,
IsCore: false,
TTL: 0,
}
r.version.Add(1)
logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name})
}
// PromoteTools atomically sets the TTL for multiple non-core tools.
// This prevents a concurrent TickTTL from decrementing between promotions.
func (r *ToolRegistry) PromoteTools(names []string, ttl int) {
r.mu.Lock()
defer r.mu.Unlock()
promoted := 0
for _, name := range names {
if entry, exists := r.tools[name]; exists {
if !entry.IsCore {
entry.TTL = ttl
promoted++
}
}
}
logger.DebugCF(
"tools",
"PromoteTools completed",
map[string]any{"requested": len(names), "promoted": promoted, "ttl": ttl},
)
}
// TickTTL decreases TTL only for non-core tools
func (r *ToolRegistry) TickTTL() {
r.mu.Lock()
defer r.mu.Unlock()
for _, entry := range r.tools {
if !entry.IsCore && entry.TTL > 0 {
entry.TTL--
}
}
}
// Version returns the current registry version (atomically).
func (r *ToolRegistry) Version() uint64 {
return r.version.Load()
}
// HiddenToolSnapshot holds a consistent snapshot of hidden tools and the
// registry version at which it was taken. Used by BM25SearchTool cache.
type HiddenToolSnapshot struct {
Docs []HiddenToolDoc
Version uint64
}
// HiddenToolDoc is a lightweight representation of a hidden tool for search indexing.
type HiddenToolDoc struct {
Name string
Description string
}
// SnapshotHiddenTools returns all non-core tools and the current registry
// version under a single read-lock, guaranteeing consistency between the
// two values.
func (r *ToolRegistry) SnapshotHiddenTools() HiddenToolSnapshot {
r.mu.RLock()
defer r.mu.RUnlock()
docs := make([]HiddenToolDoc, 0, len(r.tools))
for name, entry := range r.tools {
if !entry.IsCore {
docs = append(docs, HiddenToolDoc{
Name: name,
Description: entry.Tool.Description(),
})
}
}
return HiddenToolSnapshot{
Docs: docs,
Version: r.version.Load(),
}
}
func (r *ToolRegistry) Get(name string) (Tool, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
tool, ok := r.tools[name]
return tool, ok
entry, ok := r.tools[name]
if !ok {
return nil, false
}
// Hidden tools with expired TTL are not callable.
if !entry.IsCore && entry.TTL <= 0 {
return nil, false
}
return entry.Tool, true
}
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
@@ -135,7 +245,13 @@ func (r *ToolRegistry) GetDefinitions() []map[string]any {
sorted := r.sortedToolNames()
definitions := make([]map[string]any, 0, len(sorted))
for _, name := range sorted {
definitions = append(definitions, ToolToSchema(r.tools[name]))
entry := r.tools[name]
if !entry.IsCore && entry.TTL <= 0 {
continue
}
definitions = append(definitions, ToolToSchema(r.tools[name].Tool))
}
return definitions
}
@@ -149,8 +265,13 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
sorted := r.sortedToolNames()
definitions := make([]providers.ToolDefinition, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
schema := ToolToSchema(tool)
entry := r.tools[name]
if !entry.IsCore && entry.TTL <= 0 {
continue
}
schema := ToolToSchema(entry.Tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]any)
@@ -198,8 +319,13 @@ func (r *ToolRegistry) GetSummaries() []string {
sorted := r.sortedToolNames()
summaries := make([]string, 0, len(sorted))
for _, name := range sorted {
tool := r.tools[name]
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
entry := r.tools[name]
if !entry.IsCore && entry.TTL <= 0 {
continue
}
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", entry.Tool.Name(), entry.Tool.Description()))
}
return summaries
}
+304
View File
@@ -0,0 +1,304 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
const (
MaxRegexPatternLength = 200
)
type RegexSearchTool struct {
registry *ToolRegistry
ttl int
maxSearchResults int
}
func NewRegexSearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *RegexSearchTool {
return &RegexSearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults}
}
func (t *RegexSearchTool) Name() string {
return "tool_search_tool_regex"
}
func (t *RegexSearchTool) Description() string {
return "Search available hidden tools on-demand using a regex pattern. Returns JSON schemas of discovered tools."
}
func (t *RegexSearchTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "Regex pattern to match tool name or description",
},
},
"required": []string{"pattern"},
}
}
func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
pattern, ok := args["pattern"].(string)
if !ok || strings.TrimSpace(pattern) == "" {
// An empty string regex (?i) will match every hidden tool,
// dumping massive payloads into the context and burning tokens.
return ErrorResult("Missing or invalid 'pattern' argument. Must be a non-empty string.")
}
if len(pattern) > MaxRegexPatternLength {
logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": len(pattern)})
return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength))
}
logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern})
res, err := t.registry.SearchRegex(pattern, t.maxSearchResults)
if err != nil {
logger.WarnCF("discovery", "Invalid regex pattern", map[string]any{"pattern": pattern, "error": err.Error()})
return ErrorResult(fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", err))
}
logger.InfoCF("discovery", "Regex search completed", map[string]any{"pattern": pattern, "results": len(res)})
return formatDiscoveryResponse(t.registry, res, t.ttl)
}
type BM25SearchTool struct {
registry *ToolRegistry
ttl int
maxSearchResults int
// Cache: rebuilt only when the registry version changes.
cacheMu sync.Mutex
cachedEngine *bm25CachedEngine
cacheVersion uint64
}
func NewBM25SearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *BM25SearchTool {
return &BM25SearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults}
}
func (t *BM25SearchTool) Name() string {
return "tool_search_tool_bm25"
}
func (t *BM25SearchTool) Description() string {
return "Search available hidden tools on-demand using natural language query describing the action you need to perform. Returns JSON schemas of discovered tools."
}
func (t *BM25SearchTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query",
},
},
"required": []string{"query"},
}
}
func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
query, ok := args["query"].(string)
if !ok || strings.TrimSpace(query) == "" {
// An empty string query will match every hidden tool,
// dumping massive payloads into the context and burning tokens.
return ErrorResult("Missing or invalid 'query' argument. Must be a non-empty string.")
}
logger.DebugCF("discovery", "BM25 search", map[string]any{"query": query})
cached := t.getOrBuildEngine()
if cached == nil {
logger.DebugCF("discovery", "BM25 search: no hidden tools available", nil)
return SilentResult("No tools found matching the query.")
}
ranked := cached.engine.Search(query, t.maxSearchResults)
if len(ranked) == 0 {
logger.DebugCF("discovery", "BM25 search: no matches", map[string]any{"query": query})
return SilentResult("No tools found matching the query.")
}
results := make([]ToolSearchResult, len(ranked))
for i, r := range ranked {
results[i] = ToolSearchResult{
Name: r.Document.Name,
Description: r.Document.Description,
}
}
logger.InfoCF("discovery", "BM25 search completed", map[string]any{"query": query, "results": len(results)})
return formatDiscoveryResponse(t.registry, results, t.ttl)
}
// ToolSearchResult represents the result returned to the LLM.
// Parameters are omitted from the JSON response to save context tokens;
// the LLM will see full schemas via ToProviderDefs after promotion.
type ToolSearchResult struct {
Name string `json:"name"`
Description string `json:"description"`
}
func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) {
if maxSearchResults <= 0 {
return nil, nil
}
regex, err := regexp.Compile("(?i)" + pattern)
if err != nil {
return nil, fmt.Errorf("failed to compile regex pattern %q: %w", pattern, err)
}
r.mu.RLock()
defer r.mu.RUnlock()
var results []ToolSearchResult
// Iterate in sorted order for deterministic results across calls.
for _, name := range r.sortedToolNames() {
entry := r.tools[name]
// Search only among the hidden tools (Core tools are already visible)
if !entry.IsCore {
// Directly call interface methods! No reflection/unmarshalling needed.
desc := entry.Tool.Description()
if regex.MatchString(name) || regex.MatchString(desc) {
results = append(results, ToolSearchResult{
Name: name,
Description: desc,
})
if len(results) >= maxSearchResults {
break // Stop searching once we hit the max! Saves CPU.
}
}
}
}
return results, nil
}
func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult {
if len(results) == 0 {
return SilentResult("No tools found matching the query.")
}
names := make([]string, len(results))
for i, r := range results {
names[i] = r.Name
}
registry.PromoteTools(names, ttl)
logger.InfoCF("discovery", "Promoted tools", map[string]any{"tools": names, "ttl": ttl})
b, err := json.Marshal(results)
if err != nil {
return ErrorResult("Failed to format search results: " + err.Error())
}
msg := fmt.Sprintf(
"Found %d tools:\n%s\n\nSUCCESS: These tools have been temporarily UNLOCKED as native tools! In your next response, you can call them directly just like any normal tool",
len(results),
string(b),
)
return SilentResult(msg)
}
// Lightweight internal type used as corpus document for BM25.
type searchDoc struct {
Name string
Description string
}
// bm25CachedEngine wraps a BM25Engine with its corpus snapshot.
type bm25CachedEngine struct {
engine *utils.BM25Engine[searchDoc]
}
// snapshotToSearchDocs converts a HiddenToolSnapshot to BM25 searchDoc slice.
func snapshotToSearchDocs(snap HiddenToolSnapshot) []searchDoc {
docs := make([]searchDoc, len(snap.Docs))
for i, d := range snap.Docs {
docs[i] = searchDoc{Name: d.Name, Description: d.Description}
}
return docs
}
// buildBM25Engine creates a BM25Engine from a slice of searchDocs.
func buildBM25Engine(docs []searchDoc) *utils.BM25Engine[searchDoc] {
return utils.NewBM25Engine(
docs,
func(doc searchDoc) string {
return doc.Name + " " + doc.Description
},
)
}
// getOrBuildEngine returns a cached BM25 engine, rebuilding it only when
// the registry version has changed (new tools registered).
func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine {
// Fast path: optimistic check without locking.
if t.cachedEngine != nil && t.cacheVersion == t.registry.Version() {
return t.cachedEngine
}
t.cacheMu.Lock()
defer t.cacheMu.Unlock()
// Snapshot + version are read under a single registry RLock,
// guaranteeing consistency (no TOCTOU).
snap := t.registry.SnapshotHiddenTools()
// Re-check: another goroutine may have rebuilt while we waited for cacheMu.
if t.cachedEngine != nil && t.cacheVersion == snap.Version {
return t.cachedEngine
}
docs := snapshotToSearchDocs(snap)
if len(docs) == 0 {
t.cachedEngine = nil
t.cacheVersion = snap.Version
return nil
}
cached := &bm25CachedEngine{engine: buildBM25Engine(docs)}
t.cachedEngine = cached
t.cacheVersion = snap.Version
logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version})
return cached
}
// SearchBM25 ranks hidden tools against query using BM25 via utils.BM25Engine.
// This non-cached variant rebuilds the engine on every call. Used by tests
// and any code that doesn't hold a BM25SearchTool instance.
func (r *ToolRegistry) SearchBM25(query string, maxSearchResults int) []ToolSearchResult {
snap := r.SnapshotHiddenTools()
docs := snapshotToSearchDocs(snap)
if len(docs) == 0 {
return nil
}
ranked := buildBM25Engine(docs).Search(query, maxSearchResults)
if len(ranked) == 0 {
return nil
}
out := make([]ToolSearchResult, len(ranked))
for i, r := range ranked {
out[i] = ToolSearchResult{
Name: r.Document.Name,
Description: r.Document.Description,
}
}
return out
}
+339
View File
@@ -0,0 +1,339 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
)
// Dummy tool to fill the registry in our tests.
type mockSearchableTool struct {
name string
desc string
}
func (m *mockSearchableTool) Name() string { return m.name }
func (m *mockSearchableTool) Description() string { return m.desc }
func (m *mockSearchableTool) Parameters() map[string]any {
return map[string]any{"type": "object"}
}
func (m *mockSearchableTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
return SilentResult("mock executed: " + m.name)
}
// Helper to initialize a populated ToolRegistry
func setupPopulatedRegistry() *ToolRegistry {
reg := NewToolRegistry()
// A core tool (NOT to be found by searches)
reg.Register(&mockSearchableTool{
name: "core_search",
desc: "I am a visible core tool for searching files",
})
// Hidden tools (must be found by searches)
reg.RegisterHidden(&mockSearchableTool{
name: "mcp_read_file",
desc: "Read the contents of a system file",
})
reg.RegisterHidden(&mockSearchableTool{
name: "mcp_list_dir",
desc: "List directories and files in the system",
})
reg.RegisterHidden(&mockSearchableTool{
name: "mcp_fetch_net",
desc: "Fetch data from a network database",
})
return reg
}
func TestRegexSearchTool_Execute(t *testing.T) {
reg := setupPopulatedRegistry()
tool := NewRegexSearchTool(reg, 5, 10)
ctx := context.Background()
t.Run("Empty Pattern Error", func(t *testing.T) {
res := tool.Execute(ctx, map[string]any{})
if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'pattern'") {
t.Errorf("Expected missing pattern error, got: %v", res.ForLLM)
}
})
t.Run("Invalid Regex Syntax", func(t *testing.T) {
res := tool.Execute(ctx, map[string]any{"pattern": "[unclosed"})
if !res.IsError || !strings.Contains(res.ForLLM, "Invalid regex pattern syntax") {
t.Errorf("Expected regex syntax error, got: %v", res.ForLLM)
}
})
t.Run("No Match Found", func(t *testing.T) {
res := tool.Execute(ctx, map[string]any{"pattern": "alien"})
if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") {
t.Errorf("Expected 'no tools found' message, got: %v", res.ForLLM)
}
})
t.Run("Successful Match & Promotion", func(t *testing.T) {
res := tool.Execute(ctx, map[string]any{"pattern": "system"})
if res.IsError {
t.Fatalf("Unexpected error: %v", res.ForLLM)
}
if !strings.Contains(res.ForLLM, "SUCCESS: These tools have been temporarily UNLOCKED") {
t.Errorf("Expected success string, got: %v", res.ForLLM)
}
if !strings.Contains(res.ForLLM, "mcp_read_file") {
t.Errorf("Expected 'mcp_read_file' in results")
}
// Verify that the TTL has been updated for the tools found
reg.mu.RLock()
defer reg.mu.RUnlock()
if reg.tools["mcp_read_file"].TTL != 5 {
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", reg.tools["mcp_read_file"].TTL)
}
if reg.tools["mcp_fetch_net"].TTL != 0 {
t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)")
}
})
}
func TestBM25SearchTool_Execute(t *testing.T) {
reg := setupPopulatedRegistry()
tool := NewBM25SearchTool(reg, 3, 10)
ctx := context.Background()
t.Run("Empty Query Error", func(t *testing.T) {
res := tool.Execute(ctx, map[string]any{"query": " "})
if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'query'") {
t.Errorf("Expected missing query error, got: %v", res.ForLLM)
}
})
t.Run("No Match Found", func(t *testing.T) {
res := tool.Execute(ctx, map[string]any{"query": "aliens spaceships"})
if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") {
t.Errorf("Expected 'no tools found', got: %v", res.ForLLM)
}
})
t.Run("Successful Match & Promotion", func(t *testing.T) {
res := tool.Execute(ctx, map[string]any{"query": "read files"})
if res.IsError {
t.Fatalf("Unexpected error: %v", res.ForLLM)
}
if !strings.Contains(res.ForLLM, "mcp_read_file") {
t.Errorf("Expected 'mcp_read_file' in BM25 results")
}
reg.mu.RLock()
defer reg.mu.RUnlock()
if reg.tools["mcp_read_file"].TTL != 3 {
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 3")
}
})
}
func TestRegexSearchTool_PatternTooLong(t *testing.T) {
reg := setupPopulatedRegistry()
tool := NewRegexSearchTool(reg, 5, 10)
ctx := context.Background()
longPattern := strings.Repeat("a", MaxRegexPatternLength+1)
res := tool.Execute(ctx, map[string]any{"pattern": longPattern})
if !res.IsError || !strings.Contains(res.ForLLM, "Pattern too long") {
t.Errorf("Expected pattern too long error, got: %v", res.ForLLM)
}
}
func TestSearchRegex_ZeroMaxResults(t *testing.T) {
reg := setupPopulatedRegistry()
res, err := reg.SearchRegex("mcp", 0)
if err != nil {
t.Fatalf("SearchRegex failed: %v", err)
}
if len(res) != 0 {
t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res))
}
}
func TestSearchBM25_ZeroMaxResults(t *testing.T) {
reg := setupPopulatedRegistry()
res := reg.SearchBM25("read file", 0)
if len(res) != 0 {
t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res))
}
}
func TestSearchRegex_DeterministicOrder(t *testing.T) {
reg := NewToolRegistry()
for i := 0; i < 20; i++ {
reg.RegisterHidden(&mockSearchableTool{
name: fmt.Sprintf("tool_%02d", i),
desc: "searchable tool",
})
}
// Run the same search multiple times and verify order is stable
var firstRun []string
for attempt := 0; attempt < 10; attempt++ {
res, err := reg.SearchRegex("searchable", 20)
if err != nil {
t.Fatalf("SearchRegex failed: %v", err)
}
names := make([]string, len(res))
for i, r := range res {
names[i] = r.Name
}
if attempt == 0 {
firstRun = names
} else {
for i, name := range names {
if name != firstRun[i] {
t.Fatalf("Non-deterministic order at attempt %d, index %d: got %q, want %q",
attempt, i, name, firstRun[i])
}
}
}
}
}
func TestToolRegistry_SearchLimitsAndCoreFiltering(t *testing.T) {
reg := NewToolRegistry()
// Add 1 Core and 10 Hidden, all containing the word "match"
reg.Register(&mockSearchableTool{"core_match", "I am core with match"})
for i := 0; i < 10; i++ {
reg.RegisterHidden(&mockSearchableTool{
name: fmt.Sprintf("hidden_match_%d", i),
desc: "this has a match",
})
}
t.Run("Regex limits and core filtering", func(t *testing.T) {
// Search with Regex and a limit of maxSearchResults = 4
res, err := reg.SearchRegex("match", 4)
if err != nil {
t.Fatalf("SearchRegex failed: %v", err)
}
if len(res) != 4 {
t.Errorf("Expected exactly 4 results due to limit, got %d", len(res))
}
for _, r := range res {
if r.Name == "core_match" {
t.Errorf("SearchRegex returned a Core tool, which should be excluded")
}
}
})
t.Run("BM25 limits and core filtering", func(t *testing.T) {
// Search with BM25 and a limit of maxSearchResults = 3
res := reg.SearchBM25("match", 3)
if len(res) != 3 {
t.Errorf("Expected exactly 3 results due to limit, got %d", len(res))
}
for _, r := range res {
if r.Name == "core_match" {
t.Errorf("SearchBM25 returned a Core tool, which should be excluded")
}
}
})
}
func TestGet_HiddenToolTTLLifecycle(t *testing.T) {
reg := NewToolRegistry()
reg.RegisterHidden(&mockSearchableTool{name: "hidden_tool", desc: "test"})
// TTL=0 at registration → not gettable
_, ok := reg.Get("hidden_tool")
if ok {
t.Error("Expected hidden tool with TTL=0 to NOT be gettable")
}
// Promote → gettable
reg.PromoteTools([]string{"hidden_tool"}, 3)
_, ok = reg.Get("hidden_tool")
if !ok {
t.Error("Expected promoted hidden tool to be gettable")
}
// Tick down to 0 → not gettable again
reg.TickTTL() // 3→2
reg.TickTTL() // 2→1
reg.TickTTL() // 1→0
_, ok = reg.Get("hidden_tool")
if ok {
t.Error("Expected hidden tool with TTL ticked to 0 to NOT be gettable")
}
// Core tools remain always gettable
reg.Register(&mockSearchableTool{name: "core_tool", desc: "core"})
_, ok = reg.Get("core_tool")
if !ok {
t.Error("Expected core tool to always be gettable")
}
}
func TestBM25CacheInvalidation(t *testing.T) {
reg := NewToolRegistry()
reg.RegisterHidden(&mockSearchableTool{name: "tool_alpha", desc: "alpha functionality"})
tool := NewBM25SearchTool(reg, 5, 10)
ctx := context.Background()
// First search should find tool_alpha
res := tool.Execute(ctx, map[string]any{"query": "alpha"})
if !strings.Contains(res.ForLLM, "tool_alpha") {
t.Fatalf("Expected 'tool_alpha' in first search, got: %v", res.ForLLM)
}
// Register a new hidden tool
reg.RegisterHidden(&mockSearchableTool{name: "tool_beta", desc: "beta functionality"})
// Cache should be invalidated; new tool should be findable
res = tool.Execute(ctx, map[string]any{"query": "beta"})
if !strings.Contains(res.ForLLM, "tool_beta") {
t.Errorf("Expected 'tool_beta' after cache invalidation, got: %v", res.ForLLM)
}
}
func TestPromoteTools_ConcurrentWithTickTTL(t *testing.T) {
reg := NewToolRegistry()
for i := 0; i < 20; i++ {
reg.RegisterHidden(&mockSearchableTool{
name: fmt.Sprintf("concurrent_tool_%d", i),
desc: "concurrent test tool",
})
}
names := make([]string, 20)
for i := 0; i < 20; i++ {
names[i] = fmt.Sprintf("concurrent_tool_%d", i)
}
// Hammer PromoteTools and TickTTL concurrently to detect races
done := make(chan struct{})
go func() {
for i := 0; i < 1000; i++ {
reg.PromoteTools(names, 5)
}
close(done)
}()
for i := 0; i < 1000; i++ {
reg.TickTTL()
}
<-done
}