From b89f6445d10545717e5526064472fecbc6340b07 Mon Sep 17 00:00:00 2001 From: Mauro Date: Mon, 9 Mar 2026 18:21:49 +0100 Subject: [PATCH] feat(mcp): tool search tools (#1243) * feat(mcp): tool search tools * removed unused call_discovered_tool * improvements and optimizations * fix gate mcp enabled * fix TOCTOU race BM25 cache version check * fix encapsulation bypass on registry internals * safety comment on TickTTL * added more unit tests * enhanced logs --- config/config.example.json | 16 +- docs/tools_configuration.md | 151 ++++++++++++--- pkg/agent/context.go | 40 +++- pkg/agent/instance.go | 6 +- pkg/agent/loop.go | 60 +++++- pkg/config/config.go | 11 +- pkg/config/defaults.go | 7 + pkg/tools/registry.go | 148 ++++++++++++-- pkg/tools/search_tool.go | 304 +++++++++++++++++++++++++++++ pkg/tools/search_tools_test.go | 339 +++++++++++++++++++++++++++++++++ pkg/utils/bm25.go | 272 ++++++++++++++++++++++++++ pkg/utils/bm25_test.go | 175 +++++++++++++++++ 12 files changed, 1481 insertions(+), 48 deletions(-) create mode 100644 pkg/tools/search_tool.go create mode 100644 pkg/tools/search_tools_test.go create mode 100644 pkg/utils/bm25.go create mode 100644 pkg/utils/bm25_test.go diff --git a/config/config.example.json b/config/config.example.json index 0e2cae8e5..3a33b3caf 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -194,8 +194,13 @@ "nickserv_password": "", "sasl_user": "", "sasl_password": "", - "channels": ["#mychannel"], - "request_caps": ["server-time", "message-tags"], + "channels": [ + "#mychannel" + ], + "request_caps": [ + "server-time", + "message-tags" + ], "allow_from": [], "group_trigger": { "mention_only": true @@ -316,6 +321,13 @@ }, "mcp": { "enabled": false, + "discovery": { + "enabled": false, + "ttl": 5, + "max_search_results": 5, + "use_bm25": true, + "use_regex": false + }, "servers": { "context7": { "enabled": false, diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md index e64a3a107..8c8eb31f0 100644 --- a/docs/tools_configuration.md +++ b/docs/tools_configuration.md @@ -7,11 +7,21 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`. ```json { "tools": { - "web": { ... }, - "mcp": { ... }, - "exec": { ... }, - "cron": { ... }, - "skills": { ... } + "web": { + ... + }, + "mcp": { + ... + }, + "exec": { + ... + }, + "cron": { + ... + }, + "skills": { + ... + } } } ``` @@ -23,7 +33,7 @@ Web tools are used for web search and fetching. ### Brave | Config | Type | Default | Description | -| ------------- | ------ | ------- | ------------------------- | +|---------------|--------|---------|---------------------------| | `enabled` | bool | false | Enable Brave search | | `api_key` | string | - | Brave Search API key | | `max_results` | int | 5 | Maximum number of results | @@ -31,14 +41,14 @@ Web tools are used for web search and fetching. ### DuckDuckGo | Config | Type | Default | Description | -| ------------- | ---- | ------- | ------------------------- | +|---------------|------|---------|---------------------------| | `enabled` | bool | true | Enable DuckDuckGo search | | `max_results` | int | 5 | Maximum number of results | ### Perplexity | Config | Type | Default | Description | -| ------------- | ------ | ------- | ------------------------- | +|---------------|--------|---------|---------------------------| | `enabled` | bool | false | Enable Perplexity search | | `api_key` | string | - | Perplexity API key | | `max_results` | int | 5 | Maximum number of results | @@ -48,7 +58,7 @@ Web tools are used for web search and fetching. The exec tool is used to execute shell commands. | Config | Type | Default | Description | -| ---------------------- | ----- | ------- | ------------------------------------------ | +|------------------------|-------|---------|--------------------------------------------| | `enable_deny_patterns` | bool | true | Enable default dangerous command blocking | | `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) | @@ -81,7 +91,10 @@ By default, PicoClaw blocks the following dangerous commands: "tools": { "exec": { "enable_deny_patterns": true, - "custom_deny_patterns": ["\\brm\\s+-r\\b", "\\bkillall\\s+python"] + "custom_deny_patterns": [ + "\\brm\\s+-r\\b", + "\\bkillall\\s+python" + ] } } } @@ -92,24 +105,47 @@ By default, PicoClaw blocks the following dangerous commands: The cron tool is used for scheduling periodic tasks. | Config | Type | Default | Description | -| ---------------------- | ---- | ------- | ---------------------------------------------- | +|------------------------|------|---------|------------------------------------------------| | `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit | ## MCP Tool The MCP tool enables integration with external Model Context Protocol servers. +### Tool Discovery (Lazy Loading) + +When connecting to multiple MCP servers, exposing hundreds of tools simultaneously can exhaust the LLM's context window +and increase API costs. The **Discovery** feature solves this by keeping MCP tools *hidden* by default. + +Instead of loading all tools, the LLM is provided with a lightweight search tool (using BM25 keyword matching or Regex). +When the LLM needs a specific capability, it searches the hidden library. Matching tools are then temporarily "unlocked" +and injected into the context for a configured number of turns (`ttl`). + ### Global Config -| Config | Type | Default | Description | -| --------- | ------ | ------- | ----------------------------------- | -| `enabled` | bool | false | Enable MCP integration globally | -| `servers` | object | `{}` | Map of server name to server config | +| Config | Type | Default | Description | +|-------------|--------|---------|----------------------------------------------| +| `enabled` | bool | false | Enable MCP integration globally | +| `discovery` | object | `{}` | Configuration for Tool Discovery (see below) | +| `servers` | object | `{}` | Map of server name to server config | + +### Discovery Config (`discovery`) + +| Config | Type | Default | Description | +|----------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------| +| `enabled` | bool | false | If true, MCP tools are hidden and loaded on-demand via search. If false, all tools are loaded | +| `ttl` | int | 5 | Number of conversational turns a discovered tool remains unlocked | +| `max_search_results` | int | 5 | Maximum number of tools returned per search query | +| `use_bm25` | bool | true | Enable the natural language/keyword search tool (`tool_search_tool_bm25`). **Warning**: consumes more resources than regex search | +| `use_regex` | bool | false | Enable the regex pattern search tool (`tool_search_tool_regex`) | + +> **Note:** If `discovery.enabled` is `true`, you MUST enable at least one search engine (`use_bm25` or `use_regex`), +> otherwise the application will fail to start. ### Per-Server Config | Config | Type | Required | Description | -| ---------- | ------ | -------- | ------------------------------------------ | +|------------|--------|----------|--------------------------------------------| | `enabled` | bool | yes | Enable this MCP server | | `type` | string | no | Transport type: `stdio`, `sse`, `http` | | `command` | string | stdio | Executable command for stdio transport | @@ -122,8 +158,8 @@ The MCP tool enables integration with external Model Context Protocol servers. ### Transport Behavior - If `type` is omitted, transport is auto-detected: - - `url` is set β†’ `sse` - - `command` is set β†’ `stdio` + - `url` is set β†’ `sse` + - `command` is set β†’ `stdio` - `http` and `sse` both use `url` + optional `headers`. - `env` and `env_file` are only applied to `stdio` servers. @@ -140,7 +176,11 @@ The MCP tool enables integration with external Model Context Protocol servers. "filesystem": { "enabled": true, "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp" + ] } } } @@ -170,20 +210,76 @@ The MCP tool enables integration with external Model Context Protocol servers. } ``` +#### 3) Massive MCP setup with Tool Discovery enabled + +*In this example, the LLM will only see the `tool_search_tool_bm25`. It will search and unlock Github or Postgres tools +dynamically only when requested by the user.* + +```json +{ + "tools": { + "mcp": { + "enabled": true, + "discovery": { + "enabled": true, + "ttl": 5, + "max_search_results": 5, + "use_bm25": true, + "use_regex": false + }, + "servers": { + "github": { + "enabled": true, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN" + } + }, + "postgres": { + "enabled": true, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-postgres", + "postgresql://user:password@localhost/dbname" + ] + }, + "slack": { + "enabled": true, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-slack" + ], + "env": { + "SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN", + "SLACK_TEAM_ID": "YOUR_SLACK_TEAM_ID" + } + } + } + } + } +} +``` + ## Skills Tool The skills tool configures skill discovery and installation via registries like ClawHub. ### Registries -| Config | Type | Default | Description | -| ---------------------------------- | ------ | -------------------- | ----------------------- | -| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | -| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | +| Config | Type | Default | Description | +|------------------------------------|--------|----------------------|----------------------------------------------| +| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | +| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | | `registries.clawhub.auth_token` | string | `""` | Optional Bearer token for higher rate limits | -| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | -| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | -| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | +| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | +| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | +| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | ### Configuration Example @@ -217,4 +313,5 @@ For example: - `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10` - `PICOCLAW_TOOLS_MCP_ENABLED=true` -Note: Nested map-style config (for example `tools.mcp.servers..*`) is configured in `config.json` rather than environment variables. +Note: Nested map-style config (for example `tools.mcp.servers..*`) is configured in `config.json` rather than +environment variables. diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 719b0cb6d..92663a32d 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -18,9 +18,11 @@ import ( ) type ContextBuilder struct { - workspace string - skillsLoader *skills.SkillsLoader - memory *MemoryStore + workspace string + skillsLoader *skills.SkillsLoader + memory *MemoryStore + toolDiscoveryBM25 bool + toolDiscoveryRegex bool // Cache for system prompt to avoid rebuilding on every call. // This fixes issue #607: repeated reprocessing of the entire context. @@ -41,6 +43,12 @@ type ContextBuilder struct { skillFilesAtCache map[string]time.Time } +func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuilder { + cb.toolDiscoveryBM25 = useBM25 + cb.toolDiscoveryRegex = useRegex + return cb +} + func getGlobalConfigDir() string { if home := os.Getenv("PICOCLAW_HOME"); home != "" { return home @@ -71,6 +79,7 @@ func NewContextBuilder(workspace string) *ContextBuilder { func (cb *ContextBuilder) getIdentity() string { workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace)) + toolDiscovery := cb.getDiscoveryRule() return fmt.Sprintf(`# picoclaw 🦞 @@ -90,8 +99,29 @@ Your workspace is at: %s 3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md -4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.`, - workspacePath, workspacePath, workspacePath, workspacePath, workspacePath) +4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content. + +%s`, + workspacePath, workspacePath, workspacePath, workspacePath, workspacePath, toolDiscovery) +} + +func (cb *ContextBuilder) getDiscoveryRule() string { + if !cb.toolDiscoveryBM25 && !cb.toolDiscoveryRegex { + return "" + } + + var toolNames []string + if cb.toolDiscoveryBM25 { + toolNames = append(toolNames, `"tool_search_tool_bm25"`) + } + if cb.toolDiscoveryRegex { + toolNames = append(toolNames, `"tool_search_tool_regex"`) + } + + return fmt.Sprintf( + `5. **Tool Discovery** - Your visible tools are limited to save memory, but a vast hidden library exists. If you lack the right tool for a task, BEFORE giving up, you MUST search using the %s tool. Do not refuse a request unless the search returns nothing. Found tools will temporarily unlock for your next turn.`, + strings.Join(toolNames, " or "), + ) } func (cb *ContextBuilder) BuildSystemPrompt() string { diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 5a838b67e..b60818875 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -97,7 +97,11 @@ func NewAgentInstance( sessionsDir := filepath.Join(workspace, "sessions") sessionsManager := session.NewSessionManager(sessionsDir) - contextBuilder := NewContextBuilder(workspace) + mcpDiscoveryActive := cfg.Tools.MCP.Enabled && cfg.Tools.MCP.Discovery.Enabled + contextBuilder := NewContextBuilder(workspace).WithToolDiscovery( + mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseBM25, + mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseRegex, + ) agentID := routing.DefaultAgentID agentName := "" diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 3d13071c0..58f53bef8 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -283,7 +283,13 @@ func (al *AgentLoop) Run(ctx context.Context) error { } mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) - agent.Tools.Register(mcpTool) + + if al.cfg.Tools.MCP.Discovery.Enabled { + agent.Tools.RegisterHidden(mcpTool) + } else { + agent.Tools.Register(mcpTool) + } + totalRegistrations++ logger.DebugCF("agent", "Registered MCP tool", map[string]any{ @@ -302,6 +308,47 @@ func (al *AgentLoop) Run(ctx context.Context) error { "total_registrations": totalRegistrations, "agent_count": agentCount, }) + + // Initializes Discovery Tools only if enabled by configuration + if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled { + useBM25 := al.cfg.Tools.MCP.Discovery.UseBM25 + useRegex := al.cfg.Tools.MCP.Discovery.UseRegex + + // Fail fast: If discovery is enabled but no search method is turned on + if !useBM25 && !useRegex { + return fmt.Errorf( + "tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration", + ) + } + + ttl := al.cfg.Tools.MCP.Discovery.TTL + if ttl <= 0 { + ttl = 5 // Default value + } + + maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults + if maxSearchResults <= 0 { + maxSearchResults = 5 // Default value + } + + logger.InfoCF("agent", "Initializing tool discovery", map[string]any{ + "bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults, + }) + + for _, agentID := range agentIDs { + agent, ok := al.registry.GetAgent(agentID) + if !ok { + continue + } + + if useRegex { + agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults)) + } + if useBM25 { + agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults)) + } + } + } } } @@ -1254,6 +1301,17 @@ func (al *AgentLoop) runLLMIteration( // Save tool result message to session agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } + + // Tick down TTL of discovered tools after processing tool results. + // Only reached when tool calls were made (the loop continues); + // the break on no-tool-call responses skips this. + // NOTE: This is safe because processMessage is sequential per agent. + // If per-agent concurrency is added, TTL consistency between + // ToProviderDefs and Get must be re-evaluated. + agent.Tools.TickTTL() + logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{ + "agent_id": agent.ID, "iteration": iteration, + }) } return finalContent, iteration, nil diff --git a/pkg/config/config.go b/pkg/config/config.go index 0558d0e8f..a47ab3091 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -578,6 +578,14 @@ type GatewayConfig struct { Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` } +type ToolDiscoveryConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_DISCOVERY_ENABLED"` + TTL int `json:"ttl" env:"PICOCLAW_TOOLS_DISCOVERY_TTL"` + MaxSearchResults int `json:"max_search_results" env:"PICOCLAW_MAX_SEARCH_RESULTS"` + UseBM25 bool `json:"use_bm25" env:"PICOCLAW_TOOLS_DISCOVERY_USE_BM25"` + UseRegex bool `json:"use_regex" env:"PICOCLAW_TOOLS_DISCOVERY_USE_REGEX"` +} + type ToolConfig struct { Enabled bool `json:"enabled" env:"ENABLED"` } @@ -735,7 +743,8 @@ type MCPServerConfig struct { // MCPConfig defines configuration for all MCP servers type MCPConfig struct { - ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"` + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_MCP_"` + Discovery ToolDiscoveryConfig ` json:"discovery"` // Servers is a map of server name to server configuration Servers map[string]MCPServerConfig `json:"servers,omitempty"` } diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 2fd99c1ba..e64baa720 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -443,6 +443,13 @@ func DefaultConfig() *Config { ToolConfig: ToolConfig{ Enabled: false, }, + Discovery: ToolDiscoveryConfig{ + Enabled: false, + TTL: 5, + MaxSearchResults: 5, + UseBM25: true, + UseRegex: false, + }, Servers: map[string]MCPServerConfig{}, }, AppendFile: ToolConfig{ diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index ca8436c67..0635f47d7 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -5,20 +5,28 @@ import ( "fmt" "sort" "sync" + "sync/atomic" "time" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" ) +type ToolEntry struct { + Tool Tool + IsCore bool + TTL int +} + type ToolRegistry struct { - tools map[string]Tool - mu sync.RWMutex + tools map[string]*ToolEntry + mu sync.RWMutex + version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation } func NewToolRegistry() *ToolRegistry { return &ToolRegistry{ - tools: make(map[string]Tool), + tools: make(map[string]*ToolEntry), } } @@ -30,14 +38,116 @@ func (r *ToolRegistry) Register(tool Tool) { logger.WarnCF("tools", "Tool registration overwrites existing tool", map[string]any{"name": name}) } - r.tools[name] = tool + r.tools[name] = &ToolEntry{ + Tool: tool, + IsCore: true, + TTL: 0, // Core tools do not use TTL + } + r.version.Add(1) + logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name}) +} + +// RegisterHidden saves hidden tools (visible only via TTL) +func (r *ToolRegistry) RegisterHidden(tool Tool) { + r.mu.Lock() + defer r.mu.Unlock() + name := tool.Name() + if _, exists := r.tools[name]; exists { + logger.WarnCF("tools", "Hidden tool registration overwrites existing tool", + map[string]any{"name": name}) + } + r.tools[name] = &ToolEntry{ + Tool: tool, + IsCore: false, + TTL: 0, + } + r.version.Add(1) + logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name}) +} + +// PromoteTools atomically sets the TTL for multiple non-core tools. +// This prevents a concurrent TickTTL from decrementing between promotions. +func (r *ToolRegistry) PromoteTools(names []string, ttl int) { + r.mu.Lock() + defer r.mu.Unlock() + promoted := 0 + for _, name := range names { + if entry, exists := r.tools[name]; exists { + if !entry.IsCore { + entry.TTL = ttl + promoted++ + } + } + } + logger.DebugCF( + "tools", + "PromoteTools completed", + map[string]any{"requested": len(names), "promoted": promoted, "ttl": ttl}, + ) +} + +// TickTTL decreases TTL only for non-core tools +func (r *ToolRegistry) TickTTL() { + r.mu.Lock() + defer r.mu.Unlock() + for _, entry := range r.tools { + if !entry.IsCore && entry.TTL > 0 { + entry.TTL-- + } + } +} + +// Version returns the current registry version (atomically). +func (r *ToolRegistry) Version() uint64 { + return r.version.Load() +} + +// HiddenToolSnapshot holds a consistent snapshot of hidden tools and the +// registry version at which it was taken. Used by BM25SearchTool cache. +type HiddenToolSnapshot struct { + Docs []HiddenToolDoc + Version uint64 +} + +// HiddenToolDoc is a lightweight representation of a hidden tool for search indexing. +type HiddenToolDoc struct { + Name string + Description string +} + +// SnapshotHiddenTools returns all non-core tools and the current registry +// version under a single read-lock, guaranteeing consistency between the +// two values. +func (r *ToolRegistry) SnapshotHiddenTools() HiddenToolSnapshot { + r.mu.RLock() + defer r.mu.RUnlock() + docs := make([]HiddenToolDoc, 0, len(r.tools)) + for name, entry := range r.tools { + if !entry.IsCore { + docs = append(docs, HiddenToolDoc{ + Name: name, + Description: entry.Tool.Description(), + }) + } + } + return HiddenToolSnapshot{ + Docs: docs, + Version: r.version.Load(), + } } func (r *ToolRegistry) Get(name string) (Tool, bool) { r.mu.RLock() defer r.mu.RUnlock() - tool, ok := r.tools[name] - return tool, ok + entry, ok := r.tools[name] + if !ok { + return nil, false + } + // Hidden tools with expired TTL are not callable. + if !entry.IsCore && entry.TTL <= 0 { + return nil, false + } + return entry.Tool, true } func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult { @@ -135,7 +245,13 @@ func (r *ToolRegistry) GetDefinitions() []map[string]any { sorted := r.sortedToolNames() definitions := make([]map[string]any, 0, len(sorted)) for _, name := range sorted { - definitions = append(definitions, ToolToSchema(r.tools[name])) + entry := r.tools[name] + + if !entry.IsCore && entry.TTL <= 0 { + continue + } + + definitions = append(definitions, ToolToSchema(r.tools[name].Tool)) } return definitions } @@ -149,8 +265,13 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { sorted := r.sortedToolNames() definitions := make([]providers.ToolDefinition, 0, len(sorted)) for _, name := range sorted { - tool := r.tools[name] - schema := ToolToSchema(tool) + entry := r.tools[name] + + if !entry.IsCore && entry.TTL <= 0 { + continue + } + + schema := ToolToSchema(entry.Tool) // Safely extract nested values with type checks fn, ok := schema["function"].(map[string]any) @@ -198,8 +319,13 @@ func (r *ToolRegistry) GetSummaries() []string { sorted := r.sortedToolNames() summaries := make([]string, 0, len(sorted)) for _, name := range sorted { - tool := r.tools[name] - summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description())) + entry := r.tools[name] + + if !entry.IsCore && entry.TTL <= 0 { + continue + } + + summaries = append(summaries, fmt.Sprintf("- `%s` - %s", entry.Tool.Name(), entry.Tool.Description())) } return summaries } diff --git a/pkg/tools/search_tool.go b/pkg/tools/search_tool.go new file mode 100644 index 000000000..f41c80d90 --- /dev/null +++ b/pkg/tools/search_tool.go @@ -0,0 +1,304 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "regexp" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + MaxRegexPatternLength = 200 +) + +type RegexSearchTool struct { + registry *ToolRegistry + ttl int + maxSearchResults int +} + +func NewRegexSearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *RegexSearchTool { + return &RegexSearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults} +} + +func (t *RegexSearchTool) Name() string { + return "tool_search_tool_regex" +} + +func (t *RegexSearchTool) Description() string { + return "Search available hidden tools on-demand using a regex pattern. Returns JSON schemas of discovered tools." +} + +func (t *RegexSearchTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "pattern": map[string]any{ + "type": "string", + "description": "Regex pattern to match tool name or description", + }, + }, + "required": []string{"pattern"}, + } +} + +func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + pattern, ok := args["pattern"].(string) + if !ok || strings.TrimSpace(pattern) == "" { + // An empty string regex (?i) will match every hidden tool, + // dumping massive payloads into the context and burning tokens. + return ErrorResult("Missing or invalid 'pattern' argument. Must be a non-empty string.") + } + + if len(pattern) > MaxRegexPatternLength { + logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": len(pattern)}) + return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength)) + } + + logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern}) + + res, err := t.registry.SearchRegex(pattern, t.maxSearchResults) + if err != nil { + logger.WarnCF("discovery", "Invalid regex pattern", map[string]any{"pattern": pattern, "error": err.Error()}) + return ErrorResult(fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", err)) + } + + logger.InfoCF("discovery", "Regex search completed", map[string]any{"pattern": pattern, "results": len(res)}) + return formatDiscoveryResponse(t.registry, res, t.ttl) +} + +type BM25SearchTool struct { + registry *ToolRegistry + ttl int + maxSearchResults int + + // Cache: rebuilt only when the registry version changes. + cacheMu sync.Mutex + cachedEngine *bm25CachedEngine + cacheVersion uint64 +} + +func NewBM25SearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *BM25SearchTool { + return &BM25SearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults} +} + +func (t *BM25SearchTool) Name() string { + return "tool_search_tool_bm25" +} + +func (t *BM25SearchTool) Description() string { + return "Search available hidden tools on-demand using natural language query describing the action you need to perform. Returns JSON schemas of discovered tools." +} + +func (t *BM25SearchTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Search query", + }, + }, + "required": []string{"query"}, + } +} + +func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + query, ok := args["query"].(string) + if !ok || strings.TrimSpace(query) == "" { + // An empty string query will match every hidden tool, + // dumping massive payloads into the context and burning tokens. + return ErrorResult("Missing or invalid 'query' argument. Must be a non-empty string.") + } + + logger.DebugCF("discovery", "BM25 search", map[string]any{"query": query}) + + cached := t.getOrBuildEngine() + if cached == nil { + logger.DebugCF("discovery", "BM25 search: no hidden tools available", nil) + return SilentResult("No tools found matching the query.") + } + + ranked := cached.engine.Search(query, t.maxSearchResults) + if len(ranked) == 0 { + logger.DebugCF("discovery", "BM25 search: no matches", map[string]any{"query": query}) + return SilentResult("No tools found matching the query.") + } + + results := make([]ToolSearchResult, len(ranked)) + for i, r := range ranked { + results[i] = ToolSearchResult{ + Name: r.Document.Name, + Description: r.Document.Description, + } + } + + logger.InfoCF("discovery", "BM25 search completed", map[string]any{"query": query, "results": len(results)}) + return formatDiscoveryResponse(t.registry, results, t.ttl) +} + +// ToolSearchResult represents the result returned to the LLM. +// Parameters are omitted from the JSON response to save context tokens; +// the LLM will see full schemas via ToProviderDefs after promotion. +type ToolSearchResult struct { + Name string `json:"name"` + Description string `json:"description"` +} + +func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) { + if maxSearchResults <= 0 { + return nil, nil + } + + regex, err := regexp.Compile("(?i)" + pattern) + if err != nil { + return nil, fmt.Errorf("failed to compile regex pattern %q: %w", pattern, err) + } + + r.mu.RLock() + defer r.mu.RUnlock() + + var results []ToolSearchResult + + // Iterate in sorted order for deterministic results across calls. + for _, name := range r.sortedToolNames() { + entry := r.tools[name] + // Search only among the hidden tools (Core tools are already visible) + if !entry.IsCore { + // Directly call interface methods! No reflection/unmarshalling needed. + desc := entry.Tool.Description() + + if regex.MatchString(name) || regex.MatchString(desc) { + results = append(results, ToolSearchResult{ + Name: name, + Description: desc, + }) + if len(results) >= maxSearchResults { + break // Stop searching once we hit the max! Saves CPU. + } + } + } + } + + return results, nil +} + +func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult { + if len(results) == 0 { + return SilentResult("No tools found matching the query.") + } + + names := make([]string, len(results)) + for i, r := range results { + names[i] = r.Name + } + registry.PromoteTools(names, ttl) + logger.InfoCF("discovery", "Promoted tools", map[string]any{"tools": names, "ttl": ttl}) + + b, err := json.Marshal(results) + if err != nil { + return ErrorResult("Failed to format search results: " + err.Error()) + } + + msg := fmt.Sprintf( + "Found %d tools:\n%s\n\nSUCCESS: These tools have been temporarily UNLOCKED as native tools! In your next response, you can call them directly just like any normal tool", + len(results), + string(b), + ) + + return SilentResult(msg) +} + +// Lightweight internal type used as corpus document for BM25. +type searchDoc struct { + Name string + Description string +} + +// bm25CachedEngine wraps a BM25Engine with its corpus snapshot. +type bm25CachedEngine struct { + engine *utils.BM25Engine[searchDoc] +} + +// snapshotToSearchDocs converts a HiddenToolSnapshot to BM25 searchDoc slice. +func snapshotToSearchDocs(snap HiddenToolSnapshot) []searchDoc { + docs := make([]searchDoc, len(snap.Docs)) + for i, d := range snap.Docs { + docs[i] = searchDoc{Name: d.Name, Description: d.Description} + } + return docs +} + +// buildBM25Engine creates a BM25Engine from a slice of searchDocs. +func buildBM25Engine(docs []searchDoc) *utils.BM25Engine[searchDoc] { + return utils.NewBM25Engine( + docs, + func(doc searchDoc) string { + return doc.Name + " " + doc.Description + }, + ) +} + +// getOrBuildEngine returns a cached BM25 engine, rebuilding it only when +// the registry version has changed (new tools registered). +func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine { + // Fast path: optimistic check without locking. + if t.cachedEngine != nil && t.cacheVersion == t.registry.Version() { + return t.cachedEngine + } + + t.cacheMu.Lock() + defer t.cacheMu.Unlock() + + // Snapshot + version are read under a single registry RLock, + // guaranteeing consistency (no TOCTOU). + snap := t.registry.SnapshotHiddenTools() + + // Re-check: another goroutine may have rebuilt while we waited for cacheMu. + if t.cachedEngine != nil && t.cacheVersion == snap.Version { + return t.cachedEngine + } + + docs := snapshotToSearchDocs(snap) + if len(docs) == 0 { + t.cachedEngine = nil + t.cacheVersion = snap.Version + return nil + } + + cached := &bm25CachedEngine{engine: buildBM25Engine(docs)} + t.cachedEngine = cached + t.cacheVersion = snap.Version + logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version}) + return cached +} + +// SearchBM25 ranks hidden tools against query using BM25 via utils.BM25Engine. +// This non-cached variant rebuilds the engine on every call. Used by tests +// and any code that doesn't hold a BM25SearchTool instance. +func (r *ToolRegistry) SearchBM25(query string, maxSearchResults int) []ToolSearchResult { + snap := r.SnapshotHiddenTools() + docs := snapshotToSearchDocs(snap) + if len(docs) == 0 { + return nil + } + + ranked := buildBM25Engine(docs).Search(query, maxSearchResults) + if len(ranked) == 0 { + return nil + } + + out := make([]ToolSearchResult, len(ranked)) + for i, r := range ranked { + out[i] = ToolSearchResult{ + Name: r.Document.Name, + Description: r.Document.Description, + } + } + return out +} diff --git a/pkg/tools/search_tools_test.go b/pkg/tools/search_tools_test.go new file mode 100644 index 000000000..3aae941cb --- /dev/null +++ b/pkg/tools/search_tools_test.go @@ -0,0 +1,339 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "testing" +) + +// Dummy tool to fill the registry in our tests. +type mockSearchableTool struct { + name string + desc string +} + +func (m *mockSearchableTool) Name() string { return m.name } +func (m *mockSearchableTool) Description() string { return m.desc } +func (m *mockSearchableTool) Parameters() map[string]any { + return map[string]any{"type": "object"} +} + +func (m *mockSearchableTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + return SilentResult("mock executed: " + m.name) +} + +// Helper to initialize a populated ToolRegistry +func setupPopulatedRegistry() *ToolRegistry { + reg := NewToolRegistry() + + // A core tool (NOT to be found by searches) + reg.Register(&mockSearchableTool{ + name: "core_search", + desc: "I am a visible core tool for searching files", + }) + + // Hidden tools (must be found by searches) + reg.RegisterHidden(&mockSearchableTool{ + name: "mcp_read_file", + desc: "Read the contents of a system file", + }) + reg.RegisterHidden(&mockSearchableTool{ + name: "mcp_list_dir", + desc: "List directories and files in the system", + }) + reg.RegisterHidden(&mockSearchableTool{ + name: "mcp_fetch_net", + desc: "Fetch data from a network database", + }) + + return reg +} + +func TestRegexSearchTool_Execute(t *testing.T) { + reg := setupPopulatedRegistry() + tool := NewRegexSearchTool(reg, 5, 10) + ctx := context.Background() + + t.Run("Empty Pattern Error", func(t *testing.T) { + res := tool.Execute(ctx, map[string]any{}) + if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'pattern'") { + t.Errorf("Expected missing pattern error, got: %v", res.ForLLM) + } + }) + + t.Run("Invalid Regex Syntax", func(t *testing.T) { + res := tool.Execute(ctx, map[string]any{"pattern": "[unclosed"}) + if !res.IsError || !strings.Contains(res.ForLLM, "Invalid regex pattern syntax") { + t.Errorf("Expected regex syntax error, got: %v", res.ForLLM) + } + }) + + t.Run("No Match Found", func(t *testing.T) { + res := tool.Execute(ctx, map[string]any{"pattern": "alien"}) + if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") { + t.Errorf("Expected 'no tools found' message, got: %v", res.ForLLM) + } + }) + + t.Run("Successful Match & Promotion", func(t *testing.T) { + res := tool.Execute(ctx, map[string]any{"pattern": "system"}) + + if res.IsError { + t.Fatalf("Unexpected error: %v", res.ForLLM) + } + if !strings.Contains(res.ForLLM, "SUCCESS: These tools have been temporarily UNLOCKED") { + t.Errorf("Expected success string, got: %v", res.ForLLM) + } + if !strings.Contains(res.ForLLM, "mcp_read_file") { + t.Errorf("Expected 'mcp_read_file' in results") + } + + // Verify that the TTL has been updated for the tools found + reg.mu.RLock() + defer reg.mu.RUnlock() + if reg.tools["mcp_read_file"].TTL != 5 { + t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", reg.tools["mcp_read_file"].TTL) + } + if reg.tools["mcp_fetch_net"].TTL != 0 { + t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)") + } + }) +} + +func TestBM25SearchTool_Execute(t *testing.T) { + reg := setupPopulatedRegistry() + tool := NewBM25SearchTool(reg, 3, 10) + ctx := context.Background() + + t.Run("Empty Query Error", func(t *testing.T) { + res := tool.Execute(ctx, map[string]any{"query": " "}) + if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'query'") { + t.Errorf("Expected missing query error, got: %v", res.ForLLM) + } + }) + + t.Run("No Match Found", func(t *testing.T) { + res := tool.Execute(ctx, map[string]any{"query": "aliens spaceships"}) + if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") { + t.Errorf("Expected 'no tools found', got: %v", res.ForLLM) + } + }) + + t.Run("Successful Match & Promotion", func(t *testing.T) { + res := tool.Execute(ctx, map[string]any{"query": "read files"}) + + if res.IsError { + t.Fatalf("Unexpected error: %v", res.ForLLM) + } + if !strings.Contains(res.ForLLM, "mcp_read_file") { + t.Errorf("Expected 'mcp_read_file' in BM25 results") + } + + reg.mu.RLock() + defer reg.mu.RUnlock() + if reg.tools["mcp_read_file"].TTL != 3 { + t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 3") + } + }) +} + +func TestRegexSearchTool_PatternTooLong(t *testing.T) { + reg := setupPopulatedRegistry() + tool := NewRegexSearchTool(reg, 5, 10) + ctx := context.Background() + + longPattern := strings.Repeat("a", MaxRegexPatternLength+1) + res := tool.Execute(ctx, map[string]any{"pattern": longPattern}) + if !res.IsError || !strings.Contains(res.ForLLM, "Pattern too long") { + t.Errorf("Expected pattern too long error, got: %v", res.ForLLM) + } +} + +func TestSearchRegex_ZeroMaxResults(t *testing.T) { + reg := setupPopulatedRegistry() + + res, err := reg.SearchRegex("mcp", 0) + if err != nil { + t.Fatalf("SearchRegex failed: %v", err) + } + if len(res) != 0 { + t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res)) + } +} + +func TestSearchBM25_ZeroMaxResults(t *testing.T) { + reg := setupPopulatedRegistry() + + res := reg.SearchBM25("read file", 0) + if len(res) != 0 { + t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res)) + } +} + +func TestSearchRegex_DeterministicOrder(t *testing.T) { + reg := NewToolRegistry() + for i := 0; i < 20; i++ { + reg.RegisterHidden(&mockSearchableTool{ + name: fmt.Sprintf("tool_%02d", i), + desc: "searchable tool", + }) + } + + // Run the same search multiple times and verify order is stable + var firstRun []string + for attempt := 0; attempt < 10; attempt++ { + res, err := reg.SearchRegex("searchable", 20) + if err != nil { + t.Fatalf("SearchRegex failed: %v", err) + } + + names := make([]string, len(res)) + for i, r := range res { + names[i] = r.Name + } + + if attempt == 0 { + firstRun = names + } else { + for i, name := range names { + if name != firstRun[i] { + t.Fatalf("Non-deterministic order at attempt %d, index %d: got %q, want %q", + attempt, i, name, firstRun[i]) + } + } + } + } +} + +func TestToolRegistry_SearchLimitsAndCoreFiltering(t *testing.T) { + reg := NewToolRegistry() + + // Add 1 Core and 10 Hidden, all containing the word "match" + reg.Register(&mockSearchableTool{"core_match", "I am core with match"}) + for i := 0; i < 10; i++ { + reg.RegisterHidden(&mockSearchableTool{ + name: fmt.Sprintf("hidden_match_%d", i), + desc: "this has a match", + }) + } + + t.Run("Regex limits and core filtering", func(t *testing.T) { + // Search with Regex and a limit of maxSearchResults = 4 + res, err := reg.SearchRegex("match", 4) + if err != nil { + t.Fatalf("SearchRegex failed: %v", err) + } + + if len(res) != 4 { + t.Errorf("Expected exactly 4 results due to limit, got %d", len(res)) + } + + for _, r := range res { + if r.Name == "core_match" { + t.Errorf("SearchRegex returned a Core tool, which should be excluded") + } + } + }) + + t.Run("BM25 limits and core filtering", func(t *testing.T) { + // Search with BM25 and a limit of maxSearchResults = 3 + res := reg.SearchBM25("match", 3) + + if len(res) != 3 { + t.Errorf("Expected exactly 3 results due to limit, got %d", len(res)) + } + + for _, r := range res { + if r.Name == "core_match" { + t.Errorf("SearchBM25 returned a Core tool, which should be excluded") + } + } + }) +} + +func TestGet_HiddenToolTTLLifecycle(t *testing.T) { + reg := NewToolRegistry() + reg.RegisterHidden(&mockSearchableTool{name: "hidden_tool", desc: "test"}) + + // TTL=0 at registration β†’ not gettable + _, ok := reg.Get("hidden_tool") + if ok { + t.Error("Expected hidden tool with TTL=0 to NOT be gettable") + } + + // Promote β†’ gettable + reg.PromoteTools([]string{"hidden_tool"}, 3) + _, ok = reg.Get("hidden_tool") + if !ok { + t.Error("Expected promoted hidden tool to be gettable") + } + + // Tick down to 0 β†’ not gettable again + reg.TickTTL() // 3β†’2 + reg.TickTTL() // 2β†’1 + reg.TickTTL() // 1β†’0 + _, ok = reg.Get("hidden_tool") + if ok { + t.Error("Expected hidden tool with TTL ticked to 0 to NOT be gettable") + } + + // Core tools remain always gettable + reg.Register(&mockSearchableTool{name: "core_tool", desc: "core"}) + _, ok = reg.Get("core_tool") + if !ok { + t.Error("Expected core tool to always be gettable") + } +} + +func TestBM25CacheInvalidation(t *testing.T) { + reg := NewToolRegistry() + reg.RegisterHidden(&mockSearchableTool{name: "tool_alpha", desc: "alpha functionality"}) + + tool := NewBM25SearchTool(reg, 5, 10) + ctx := context.Background() + + // First search should find tool_alpha + res := tool.Execute(ctx, map[string]any{"query": "alpha"}) + if !strings.Contains(res.ForLLM, "tool_alpha") { + t.Fatalf("Expected 'tool_alpha' in first search, got: %v", res.ForLLM) + } + + // Register a new hidden tool + reg.RegisterHidden(&mockSearchableTool{name: "tool_beta", desc: "beta functionality"}) + + // Cache should be invalidated; new tool should be findable + res = tool.Execute(ctx, map[string]any{"query": "beta"}) + if !strings.Contains(res.ForLLM, "tool_beta") { + t.Errorf("Expected 'tool_beta' after cache invalidation, got: %v", res.ForLLM) + } +} + +func TestPromoteTools_ConcurrentWithTickTTL(t *testing.T) { + reg := NewToolRegistry() + for i := 0; i < 20; i++ { + reg.RegisterHidden(&mockSearchableTool{ + name: fmt.Sprintf("concurrent_tool_%d", i), + desc: "concurrent test tool", + }) + } + + names := make([]string, 20) + for i := 0; i < 20; i++ { + names[i] = fmt.Sprintf("concurrent_tool_%d", i) + } + + // Hammer PromoteTools and TickTTL concurrently to detect races + done := make(chan struct{}) + go func() { + for i := 0; i < 1000; i++ { + reg.PromoteTools(names, 5) + } + close(done) + }() + + for i := 0; i < 1000; i++ { + reg.TickTTL() + } + <-done +} diff --git a/pkg/utils/bm25.go b/pkg/utils/bm25.go new file mode 100644 index 000000000..95c63f0e3 --- /dev/null +++ b/pkg/utils/bm25.go @@ -0,0 +1,272 @@ +// Package utils provides shared, reusable algorithms. +// This file implements a generic BM25 search engine. +// +// Usage: +// +// type MyDoc struct { ID string; Body string } +// +// corpus := []MyDoc{...} +// engine := bm25.New(corpus, func(d MyDoc) string { +// return d.ID + " " + d.Body +// }) +// results := engine.Search("my query", 5) +package utils + +import ( + "math" + "sort" + "strings" +) + +// ── Tuning defaults ─────────────────────────────────────────────────────────── + +const ( + // DefaultBM25K1 is the term-frequency saturation factor (typical range 1.2–2.0). + // Higher values give more weight to repeated terms. + DefaultBM25K1 = 1.2 + + // DefaultBM25B is the document-length normalization factor (0 = none, 1 = full). + DefaultBM25B = 0.75 +) + +// BM25Engine is a query-time BM25 search engine over a generic corpus. +// T is the document type; the caller supplies a TextFunc that extracts the +// searchable text from each document. +// +// The engine is stateless between queries: no caching, no invalidation logic. +// All indexing work is performed inside Search() on every call, making it +// safe to use on corpora that change frequently. +type BM25Engine[T any] struct { + corpus []T + textFunc func(T) string + k1 float64 + b float64 +} + +// BM25Option is a functional option to configure a BM25Engine. +type BM25Option func(*bm25Config) + +type bm25Config struct { + k1 float64 + b float64 +} + +// WithK1 overrides the term-frequency saturation constant (default 1.2). +func WithK1(k1 float64) BM25Option { + return func(c *bm25Config) { c.k1 = k1 } +} + +// WithB overrides the document-length normalization factor (default 0.75). +func WithB(b float64) BM25Option { + return func(c *bm25Config) { c.b = b } +} + +// NewBM25Engine creates a BM25Engine for the given corpus. +// +// - corpus : slice of documents of any type T. +// - textFunc : function that returns the searchable text for a document. +// - opts : optional tuning (WithK1, WithB). +// +// The corpus slice is referenced, not copied. Callers must not mutate it +// concurrently with Search(). +func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Option) *BM25Engine[T] { + cfg := bm25Config{k1: DefaultBM25K1, b: DefaultBM25B} + for _, o := range opts { + o(&cfg) + } + return &BM25Engine[T]{ + corpus: corpus, + textFunc: textFunc, + k1: cfg.k1, + b: cfg.b, + } +} + +// BM25Result is a single ranked result from a Search call. +type BM25Result[T any] struct { + Document T + Score float32 +} + +// Search ranks the corpus against query and returns the top-k results. +// Returns an empty slice (not nil) when there are no matches. +// +// Complexity: O(NΓ—L) for indexing + O(|Q|Γ—avgPostingLen) for scoring, +// where N = corpus size, L = average document length, Q = query terms. +// Top-k extraction uses a fixed-size min-heap: O(candidates Γ— log k). +func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] { + if topK <= 0 { + return []BM25Result[T]{} + } + + queryTerms := bm25Tokenize(query) + if len(queryTerms) == 0 { + return []BM25Result[T]{} + } + + N := len(e.corpus) + if N == 0 { + return []BM25Result[T]{} + } + + // Step 1: build per-document tf + raw doc lengths + type docEntry struct { + tf map[string]uint32 + rawLen int + } + + entries := make([]docEntry, N) + df := make(map[string]int, 64) + totalLen := 0 + + for i, doc := range e.corpus { + tokens := bm25Tokenize(e.textFunc(doc)) + totalLen += len(tokens) + + tf := make(map[string]uint32, len(tokens)) + for _, t := range tokens { + tf[t]++ + } + // df: each term counts once per document (iterate the map, keys are unique) + for t := range tf { + df[t]++ + } + + entries[i] = docEntry{tf: tf, rawLen: len(tokens)} + } + + avgDocLen := float64(totalLen) / float64(N) + + // Step 2: pre-compute IDF and per-doc length normalization + // IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 ) + idf := make(map[string]float32, len(df)) + for term, freq := range df { + idf[term] = float32(math.Log( + (float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1, + )) + } + + // docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen) + // Stored as float32 β€” sufficient precision for ranking. + docLenNorm := make([]float32, N) + for i, entry := range entries { + docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen)) + } + + // Step 3: build inverted index (posting lists) + // Iterate the tf map directly β€” map keys are already unique, no seen-set needed. + posting := make(map[string][]int32, len(df)) + for i, entry := range entries { + for term := range entry.tf { + posting[term] = append(posting[term], int32(i)) + } + } + + // Step 4: score via posting lists + // Deduplicate query terms to avoid double-weighting the same term. + unique := bm25Dedupe(queryTerms) + + scores := make(map[int32]float32) + for _, term := range unique { + termIDF, ok := idf[term] + if !ok { + continue // term not in vocabulary β†’ zero contribution + } + for _, docID := range posting[term] { + freq := float32(entries[docID].tf[term]) + // TF_norm = freq * (k1+1) / (freq + docLenNorm) + tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID]) + scores[docID] += termIDF * tfNorm + } + } + + if len(scores) == 0 { + return []BM25Result[T]{} + } + + // Step 5: top-K via fixed-size min-heap + heap := make([]bm25ScoredDoc, 0, topK) + + for docID, sc := range scores { + switch { + case len(heap) < topK: + heap = append(heap, bm25ScoredDoc{docID: docID, score: sc}) + if len(heap) == topK { + bm25MinHeapify(heap) + } + case sc > heap[0].score: + heap[0] = bm25ScoredDoc{docID: docID, score: sc} + bm25SiftDown(heap, 0) + } + } + + sort.Slice(heap, func(i, j int) bool { return heap[i].score > heap[j].score }) + + out := make([]BM25Result[T], len(heap)) + for i, h := range heap { + out[i] = BM25Result[T]{ + Document: e.corpus[h.docID], + Score: h.score, + } + } + return out +} + +// bm25Tokenize splits s into lowercase tokens, stripping edge punctuation. +func bm25Tokenize(s string) []string { + raw := strings.Fields(strings.ToLower(s)) + out := raw[:0] // reuse backing array to avoid extra allocation + for _, t := range raw { + t = strings.Trim(t, ".,;:!?\"'()/\\-_") + if t != "" { + out = append(out, t) + } + } + return out +} + +// bm25Dedupe returns a new slice with duplicate tokens removed, +// preserving first-occurrence order. +func bm25Dedupe(tokens []string) []string { + seen := make(map[string]struct{}, len(tokens)) + out := make([]string, 0, len(tokens)) + for _, t := range tokens { + if _, ok := seen[t]; !ok { + seen[t] = struct{}{} + out = append(out, t) + } + } + return out +} + +type bm25ScoredDoc struct { + docID int32 + score float32 +} + +// bm25MinHeapify builds a min-heap in-place using Floyd's algorithm: O(k). +func bm25MinHeapify(h []bm25ScoredDoc) { + for i := len(h)/2 - 1; i >= 0; i-- { + bm25SiftDown(h, i) + } +} + +// bm25SiftDown restores the min-heap property starting at node i: O(log k). +func bm25SiftDown(h []bm25ScoredDoc, i int) { + n := len(h) + for { + smallest := i + l, r := 2*i+1, 2*i+2 + if l < n && h[l].score < h[smallest].score { + smallest = l + } + if r < n && h[r].score < h[smallest].score { + smallest = r + } + if smallest == i { + break + } + h[i], h[smallest] = h[smallest], h[i] + i = smallest + } +} diff --git a/pkg/utils/bm25_test.go b/pkg/utils/bm25_test.go new file mode 100644 index 000000000..4bc85b246 --- /dev/null +++ b/pkg/utils/bm25_test.go @@ -0,0 +1,175 @@ +package utils + +import ( + "reflect" + "testing" +) + +// testDoc is a generic structure for use in tests. +type testDoc struct { + ID int + Text string +} + +func extractText(d testDoc) string { + return d.Text +} + +func TestBM25Search_EdgeCases(t *testing.T) { + corpus := []testDoc{ + {1, "hello world"}, + {2, "foo bar"}, + } + engine := NewBM25Engine(corpus, extractText) + + tests := []struct { + name string + query string + topK int + }{ + {"Zero topK", "hello", 0}, + {"Negative topK", "hello", -1}, + {"Empty query", "", 5}, + {"Query with only punctuation", "...,,,!!!", 5}, + {"No matches found", "golang", 5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results := engine.Search(tt.query, tt.topK) + if len(results) != 0 { + t.Errorf("expected 0 results, got %d", len(results)) + } + // Check that it never returns nil, but an empty slice + if results == nil { + t.Errorf("expected empty slice, got nil") + } + }) + } +} + +func TestBM25Search_EmptyCorpus(t *testing.T) { + engine := NewBM25Engine([]testDoc{}, extractText) + results := engine.Search("hello", 5) + if len(results) != 0 || results == nil { + t.Errorf("expected empty slice from empty corpus, got %v", results) + } +} + +func TestBM25Search_RankingLogic(t *testing.T) { + corpus := []testDoc{ + {1, "the quick brown fox jumps over the lazy dog"}, + {2, "quick fox"}, + {3, "quick quick quick fox"}, // High Term Frequency (TF) + {4, "completely irrelevant document here"}, + } + engine := NewBM25Engine(corpus, extractText) + + t.Run("Term Frequency (TF) boosts score", func(t *testing.T) { + results := engine.Search("quick", 5) + if len(results) < 3 { + t.Fatalf("expected at least 3 results, got %d", len(results)) + } + // Doc 3 has the word "quick" repeated 3 times, it should beat Doc 2 + if results[0].Document.ID != 3 { + t.Errorf("expected doc 3 to rank first due to high TF, got doc %d", results[0].Document.ID) + } + }) + + t.Run("Document Length penalty", func(t *testing.T) { + results := engine.Search("fox", 5) + if len(results) < 3 { + t.Fatalf("expected at least 3 results, got %d", len(results)) + } + // Doc 2 ("quick fox") is much shorter than Doc 1 ("the quick brown fox..."), + // so, with equal Term Frequency for the word "fox" (1 time), Doc 2 wins. + if results[0].Document.ID != 2 { + t.Errorf("expected doc 2 to rank first due to shorter length, got doc %d", results[0].Document.ID) + } + }) + + t.Run("TopK limits results", func(t *testing.T) { + results := engine.Search("quick", 2) + if len(results) != 2 { + t.Errorf("expected exactly 2 results, got %d", len(results)) + } + }) +} + +func TestBM25Tokenize(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + {"Hello World", []string{"hello", "world"}}, + {" spaces everywhere ", []string{"spaces", "everywhere"}}, + {"punctuation... test!!!", []string{"punctuation", "test"}}, + {"(parentheses) and-hyphens", []string{"parentheses", "and-hyphens"}}, // hyphens trimmed from edges + {"internal-hyphen is kept", []string{"internal-hyphen", "is", "kept"}}, + {".,;?!", []string{}}, // Becomes empty after trim + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := bm25Tokenize(tt.input) + if len(got) == 0 && len(tt.expected) == 0 { + return // Both empty + } + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("bm25Tokenize(%q) = %v, want %v", tt.input, got, tt.expected) + } + }) + } +} + +func TestBM25Dedupe(t *testing.T) { + input := []string{"apple", "banana", "apple", "orange", "banana"} + expected := []string{"apple", "banana", "orange"} + + got := bm25Dedupe(input) + if !reflect.DeepEqual(got, expected) { + t.Errorf("bm25Dedupe() = %v, want %v", got, expected) + } +} + +func TestBM25Options(t *testing.T) { + corpus := []testDoc{{1, "test"}} + + engine := NewBM25Engine( + corpus, + extractText, + WithK1(2.5), + WithB(0.9), + ) + + if engine.k1 != 2.5 { + t.Errorf("expected k1 to be 2.5, got %v", engine.k1) + } + if engine.b != 0.9 { + t.Errorf("expected b to be 0.9, got %v", engine.b) + } +} + +func TestBM25Search_SortingStability(t *testing.T) { + // Ensure that sorting by heap returns in correct descending order + corpus := []testDoc{ + {1, "golang is good"}, + {2, "golang golang"}, + {3, "golang golang golang"}, + {4, "golang golang golang golang"}, + } + engine := NewBM25Engine(corpus, extractText) + results := engine.Search("golang", 10) + + if len(results) != 4 { + t.Fatalf("expected 4 results, got %d", len(results)) + } + + // Score should be strictly decreasing + for i := 1; i < len(results); i++ { + if results[i].Score > results[i-1].Score { + t.Errorf("results not sorted correctly: result %d score (%v) > result %d score (%v)", + i, results[i].Score, i-1, results[i-1].Score) + } + } +}