mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(mcp): tool search tools (#1243)
* feat(mcp): tool search tools * removed unused call_discovered_tool * improvements and optimizations * fix gate mcp enabled * fix TOCTOU race BM25 cache version check * fix encapsulation bypass on registry internals * safety comment on TickTTL * added more unit tests * enhanced logs
This commit is contained in:
@@ -194,8 +194,13 @@
|
||||
"nickserv_password": "",
|
||||
"sasl_user": "",
|
||||
"sasl_password": "",
|
||||
"channels": ["#mychannel"],
|
||||
"request_caps": ["server-time", "message-tags"],
|
||||
"channels": [
|
||||
"#mychannel"
|
||||
],
|
||||
"request_caps": [
|
||||
"server-time",
|
||||
"message-tags"
|
||||
],
|
||||
"allow_from": [],
|
||||
"group_trigger": {
|
||||
"mention_only": true
|
||||
@@ -316,6 +321,13 @@
|
||||
},
|
||||
"mcp": {
|
||||
"enabled": false,
|
||||
"discovery": {
|
||||
"enabled": false,
|
||||
"ttl": 5,
|
||||
"max_search_results": 5,
|
||||
"use_bm25": true,
|
||||
"use_regex": false
|
||||
},
|
||||
"servers": {
|
||||
"context7": {
|
||||
"enabled": false,
|
||||
|
||||
+124
-27
@@ -7,11 +7,21 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`.
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": { ... },
|
||||
"mcp": { ... },
|
||||
"exec": { ... },
|
||||
"cron": { ... },
|
||||
"skills": { ... }
|
||||
"web": {
|
||||
...
|
||||
},
|
||||
"mcp": {
|
||||
...
|
||||
},
|
||||
"exec": {
|
||||
...
|
||||
},
|
||||
"cron": {
|
||||
...
|
||||
},
|
||||
"skills": {
|
||||
...
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -23,7 +33,7 @@ Web tools are used for web search and fetching.
|
||||
### Brave
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| ------------- | ------ | ------- | ------------------------- |
|
||||
|---------------|--------|---------|---------------------------|
|
||||
| `enabled` | bool | false | Enable Brave search |
|
||||
| `api_key` | string | - | Brave Search API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
@@ -31,14 +41,14 @@ Web tools are used for web search and fetching.
|
||||
### DuckDuckGo
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| ------------- | ---- | ------- | ------------------------- |
|
||||
|---------------|------|---------|---------------------------|
|
||||
| `enabled` | bool | true | Enable DuckDuckGo search |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
|
||||
### Perplexity
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| ------------- | ------ | ------- | ------------------------- |
|
||||
|---------------|--------|---------|---------------------------|
|
||||
| `enabled` | bool | false | Enable Perplexity search |
|
||||
| `api_key` | string | - | Perplexity API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
@@ -48,7 +58,7 @@ Web tools are used for web search and fetching.
|
||||
The exec tool is used to execute shell commands.
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| ---------------------- | ----- | ------- | ------------------------------------------ |
|
||||
|------------------------|-------|---------|--------------------------------------------|
|
||||
| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
|
||||
| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
|
||||
|
||||
@@ -81,7 +91,10 @@ By default, PicoClaw blocks the following dangerous commands:
|
||||
"tools": {
|
||||
"exec": {
|
||||
"enable_deny_patterns": true,
|
||||
"custom_deny_patterns": ["\\brm\\s+-r\\b", "\\bkillall\\s+python"]
|
||||
"custom_deny_patterns": [
|
||||
"\\brm\\s+-r\\b",
|
||||
"\\bkillall\\s+python"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -92,24 +105,47 @@ By default, PicoClaw blocks the following dangerous commands:
|
||||
The cron tool is used for scheduling periodic tasks.
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| ---------------------- | ---- | ------- | ---------------------------------------------- |
|
||||
|------------------------|------|---------|------------------------------------------------|
|
||||
| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
|
||||
|
||||
## MCP Tool
|
||||
|
||||
The MCP tool enables integration with external Model Context Protocol servers.
|
||||
|
||||
### Tool Discovery (Lazy Loading)
|
||||
|
||||
When connecting to multiple MCP servers, exposing hundreds of tools simultaneously can exhaust the LLM's context window
|
||||
and increase API costs. The **Discovery** feature solves this by keeping MCP tools *hidden* by default.
|
||||
|
||||
Instead of loading all tools, the LLM is provided with a lightweight search tool (using BM25 keyword matching or Regex).
|
||||
When the LLM needs a specific capability, it searches the hidden library. Matching tools are then temporarily "unlocked"
|
||||
and injected into the context for a configured number of turns (`ttl`).
|
||||
|
||||
### Global Config
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| --------- | ------ | ------- | ----------------------------------- |
|
||||
| `enabled` | bool | false | Enable MCP integration globally |
|
||||
| `servers` | object | `{}` | Map of server name to server config |
|
||||
| Config | Type | Default | Description |
|
||||
|-------------|--------|---------|----------------------------------------------|
|
||||
| `enabled` | bool | false | Enable MCP integration globally |
|
||||
| `discovery` | object | `{}` | Configuration for Tool Discovery (see below) |
|
||||
| `servers` | object | `{}` | Map of server name to server config |
|
||||
|
||||
### Discovery Config (`discovery`)
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|----------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `enabled` | bool | false | If true, MCP tools are hidden and loaded on-demand via search. If false, all tools are loaded |
|
||||
| `ttl` | int | 5 | Number of conversational turns a discovered tool remains unlocked |
|
||||
| `max_search_results` | int | 5 | Maximum number of tools returned per search query |
|
||||
| `use_bm25` | bool | true | Enable the natural language/keyword search tool (`tool_search_tool_bm25`). **Warning**: consumes more resources than regex search |
|
||||
| `use_regex` | bool | false | Enable the regex pattern search tool (`tool_search_tool_regex`) |
|
||||
|
||||
> **Note:** If `discovery.enabled` is `true`, you MUST enable at least one search engine (`use_bm25` or `use_regex`),
|
||||
> otherwise the application will fail to start.
|
||||
|
||||
### Per-Server Config
|
||||
|
||||
| Config | Type | Required | Description |
|
||||
| ---------- | ------ | -------- | ------------------------------------------ |
|
||||
|------------|--------|----------|--------------------------------------------|
|
||||
| `enabled` | bool | yes | Enable this MCP server |
|
||||
| `type` | string | no | Transport type: `stdio`, `sse`, `http` |
|
||||
| `command` | string | stdio | Executable command for stdio transport |
|
||||
@@ -122,8 +158,8 @@ The MCP tool enables integration with external Model Context Protocol servers.
|
||||
### Transport Behavior
|
||||
|
||||
- If `type` is omitted, transport is auto-detected:
|
||||
- `url` is set → `sse`
|
||||
- `command` is set → `stdio`
|
||||
- `url` is set → `sse`
|
||||
- `command` is set → `stdio`
|
||||
- `http` and `sse` both use `url` + optional `headers`.
|
||||
- `env` and `env_file` are only applied to `stdio` servers.
|
||||
|
||||
@@ -140,7 +176,11 @@ The MCP tool enables integration with external Model Context Protocol servers.
|
||||
"filesystem": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/tmp"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -170,20 +210,76 @@ The MCP tool enables integration with external Model Context Protocol servers.
|
||||
}
|
||||
```
|
||||
|
||||
#### 3) Massive MCP setup with Tool Discovery enabled
|
||||
|
||||
*In this example, the LLM will only see the `tool_search_tool_bm25`. It will search and unlock Github or Postgres tools
|
||||
dynamically only when requested by the user.*
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcp": {
|
||||
"enabled": true,
|
||||
"discovery": {
|
||||
"enabled": true,
|
||||
"ttl": 5,
|
||||
"max_search_results": 5,
|
||||
"use_bm25": true,
|
||||
"use_regex": false
|
||||
},
|
||||
"servers": {
|
||||
"github": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-github"
|
||||
],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
|
||||
}
|
||||
},
|
||||
"postgres": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-postgres",
|
||||
"postgresql://user:password@localhost/dbname"
|
||||
]
|
||||
},
|
||||
"slack": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-slack"
|
||||
],
|
||||
"env": {
|
||||
"SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN",
|
||||
"SLACK_TEAM_ID": "YOUR_SLACK_TEAM_ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Skills Tool
|
||||
|
||||
The skills tool configures skill discovery and installation via registries like ClawHub.
|
||||
|
||||
### Registries
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| ---------------------------------- | ------ | -------------------- | ----------------------- |
|
||||
| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
|
||||
| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
|
||||
| Config | Type | Default | Description |
|
||||
|------------------------------------|--------|----------------------|----------------------------------------------|
|
||||
| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
|
||||
| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
|
||||
| `registries.clawhub.auth_token` | string | `""` | Optional Bearer token for higher rate limits |
|
||||
| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
|
||||
| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
|
||||
| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
|
||||
| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
|
||||
| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
|
||||
| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
|
||||
|
||||
### Configuration Example
|
||||
|
||||
@@ -217,4 +313,5 @@ For example:
|
||||
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
|
||||
- `PICOCLAW_TOOLS_MCP_ENABLED=true`
|
||||
|
||||
Note: Nested map-style config (for example `tools.mcp.servers.<name>.*`) is configured in `config.json` rather than environment variables.
|
||||
Note: Nested map-style config (for example `tools.mcp.servers.<name>.*`) is configured in `config.json` rather than
|
||||
environment variables.
|
||||
|
||||
+35
-5
@@ -18,9 +18,11 @@ import (
|
||||
)
|
||||
|
||||
type ContextBuilder struct {
|
||||
workspace string
|
||||
skillsLoader *skills.SkillsLoader
|
||||
memory *MemoryStore
|
||||
workspace string
|
||||
skillsLoader *skills.SkillsLoader
|
||||
memory *MemoryStore
|
||||
toolDiscoveryBM25 bool
|
||||
toolDiscoveryRegex bool
|
||||
|
||||
// Cache for system prompt to avoid rebuilding on every call.
|
||||
// This fixes issue #607: repeated reprocessing of the entire context.
|
||||
@@ -41,6 +43,12 @@ type ContextBuilder struct {
|
||||
skillFilesAtCache map[string]time.Time
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuilder {
|
||||
cb.toolDiscoveryBM25 = useBM25
|
||||
cb.toolDiscoveryRegex = useRegex
|
||||
return cb
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return home
|
||||
@@ -71,6 +79,7 @@ func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
|
||||
func (cb *ContextBuilder) getIdentity() string {
|
||||
workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace))
|
||||
toolDiscovery := cb.getDiscoveryRule()
|
||||
|
||||
return fmt.Sprintf(`# picoclaw 🦞
|
||||
|
||||
@@ -90,8 +99,29 @@ Your workspace is at: %s
|
||||
|
||||
3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md
|
||||
|
||||
4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.`,
|
||||
workspacePath, workspacePath, workspacePath, workspacePath, workspacePath)
|
||||
4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.
|
||||
|
||||
%s`,
|
||||
workspacePath, workspacePath, workspacePath, workspacePath, workspacePath, toolDiscovery)
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) getDiscoveryRule() string {
|
||||
if !cb.toolDiscoveryBM25 && !cb.toolDiscoveryRegex {
|
||||
return ""
|
||||
}
|
||||
|
||||
var toolNames []string
|
||||
if cb.toolDiscoveryBM25 {
|
||||
toolNames = append(toolNames, `"tool_search_tool_bm25"`)
|
||||
}
|
||||
if cb.toolDiscoveryRegex {
|
||||
toolNames = append(toolNames, `"tool_search_tool_regex"`)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`5. **Tool Discovery** - Your visible tools are limited to save memory, but a vast hidden library exists. If you lack the right tool for a task, BEFORE giving up, you MUST search using the %s tool. Do not refuse a request unless the search returns nothing. Found tools will temporarily unlock for your next turn.`,
|
||||
strings.Join(toolNames, " or "),
|
||||
)
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) BuildSystemPrompt() string {
|
||||
|
||||
@@ -97,7 +97,11 @@ func NewAgentInstance(
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
|
||||
contextBuilder := NewContextBuilder(workspace)
|
||||
mcpDiscoveryActive := cfg.Tools.MCP.Enabled && cfg.Tools.MCP.Discovery.Enabled
|
||||
contextBuilder := NewContextBuilder(workspace).WithToolDiscovery(
|
||||
mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseBM25,
|
||||
mcpDiscoveryActive && cfg.Tools.MCP.Discovery.UseRegex,
|
||||
)
|
||||
|
||||
agentID := routing.DefaultAgentID
|
||||
agentName := ""
|
||||
|
||||
+59
-1
@@ -283,7 +283,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
|
||||
if al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
} else {
|
||||
agent.Tools.Register(mcpTool)
|
||||
}
|
||||
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
@@ -302,6 +308,47 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
|
||||
// Initializes Discovery Tools only if enabled by configuration
|
||||
if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled {
|
||||
useBM25 := al.cfg.Tools.MCP.Discovery.UseBM25
|
||||
useRegex := al.cfg.Tools.MCP.Discovery.UseRegex
|
||||
|
||||
// Fail fast: If discovery is enabled but no search method is turned on
|
||||
if !useBM25 && !useRegex {
|
||||
return fmt.Errorf(
|
||||
"tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration",
|
||||
)
|
||||
}
|
||||
|
||||
ttl := al.cfg.Tools.MCP.Discovery.TTL
|
||||
if ttl <= 0 {
|
||||
ttl = 5 // Default value
|
||||
}
|
||||
|
||||
maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults
|
||||
if maxSearchResults <= 0 {
|
||||
maxSearchResults = 5 // Default value
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Initializing tool discovery", map[string]any{
|
||||
"bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults,
|
||||
})
|
||||
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
if useBM25 {
|
||||
agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1254,6 +1301,17 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Save tool result message to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
}
|
||||
|
||||
// Tick down TTL of discovered tools after processing tool results.
|
||||
// Only reached when tool calls were made (the loop continues);
|
||||
// the break on no-tool-call responses skips this.
|
||||
// NOTE: This is safe because processMessage is sequential per agent.
|
||||
// If per-agent concurrency is added, TTL consistency between
|
||||
// ToProviderDefs and Get must be re-evaluated.
|
||||
agent.Tools.TickTTL()
|
||||
logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{
|
||||
"agent_id": agent.ID, "iteration": iteration,
|
||||
})
|
||||
}
|
||||
|
||||
return finalContent, iteration, nil
|
||||
|
||||
+10
-1
@@ -578,6 +578,14 @@ type GatewayConfig struct {
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
}
|
||||
|
||||
type ToolDiscoveryConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_DISCOVERY_ENABLED"`
|
||||
TTL int `json:"ttl" env:"PICOCLAW_TOOLS_DISCOVERY_TTL"`
|
||||
MaxSearchResults int `json:"max_search_results" env:"PICOCLAW_MAX_SEARCH_RESULTS"`
|
||||
UseBM25 bool `json:"use_bm25" env:"PICOCLAW_TOOLS_DISCOVERY_USE_BM25"`
|
||||
UseRegex bool `json:"use_regex" env:"PICOCLAW_TOOLS_DISCOVERY_USE_REGEX"`
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
}
|
||||
@@ -735,7 +743,8 @@ type MCPServerConfig struct {
|
||||
|
||||
// MCPConfig defines configuration for all MCP servers
|
||||
type MCPConfig struct {
|
||||
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_MCP_"`
|
||||
Discovery ToolDiscoveryConfig ` json:"discovery"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
@@ -443,6 +443,13 @@ func DefaultConfig() *Config {
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
Discovery: ToolDiscoveryConfig{
|
||||
Enabled: false,
|
||||
TTL: 5,
|
||||
MaxSearchResults: 5,
|
||||
UseBM25: true,
|
||||
UseRegex: false,
|
||||
},
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
AppendFile: ToolConfig{
|
||||
|
||||
+137
-11
@@ -5,20 +5,28 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type ToolEntry struct {
|
||||
Tool Tool
|
||||
IsCore bool
|
||||
TTL int
|
||||
}
|
||||
|
||||
type ToolRegistry struct {
|
||||
tools map[string]Tool
|
||||
mu sync.RWMutex
|
||||
tools map[string]*ToolEntry
|
||||
mu sync.RWMutex
|
||||
version atomic.Uint64 // incremented on Register/RegisterHidden for cache invalidation
|
||||
}
|
||||
|
||||
func NewToolRegistry() *ToolRegistry {
|
||||
return &ToolRegistry{
|
||||
tools: make(map[string]Tool),
|
||||
tools: make(map[string]*ToolEntry),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,14 +38,116 @@ func (r *ToolRegistry) Register(tool Tool) {
|
||||
logger.WarnCF("tools", "Tool registration overwrites existing tool",
|
||||
map[string]any{"name": name})
|
||||
}
|
||||
r.tools[name] = tool
|
||||
r.tools[name] = &ToolEntry{
|
||||
Tool: tool,
|
||||
IsCore: true,
|
||||
TTL: 0, // Core tools do not use TTL
|
||||
}
|
||||
r.version.Add(1)
|
||||
logger.DebugCF("tools", "Registered core tool", map[string]any{"name": name})
|
||||
}
|
||||
|
||||
// RegisterHidden saves hidden tools (visible only via TTL)
|
||||
func (r *ToolRegistry) RegisterHidden(tool Tool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
name := tool.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
logger.WarnCF("tools", "Hidden tool registration overwrites existing tool",
|
||||
map[string]any{"name": name})
|
||||
}
|
||||
r.tools[name] = &ToolEntry{
|
||||
Tool: tool,
|
||||
IsCore: false,
|
||||
TTL: 0,
|
||||
}
|
||||
r.version.Add(1)
|
||||
logger.DebugCF("tools", "Registered hidden tool", map[string]any{"name": name})
|
||||
}
|
||||
|
||||
// PromoteTools atomically sets the TTL for multiple non-core tools.
|
||||
// This prevents a concurrent TickTTL from decrementing between promotions.
|
||||
func (r *ToolRegistry) PromoteTools(names []string, ttl int) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
promoted := 0
|
||||
for _, name := range names {
|
||||
if entry, exists := r.tools[name]; exists {
|
||||
if !entry.IsCore {
|
||||
entry.TTL = ttl
|
||||
promoted++
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.DebugCF(
|
||||
"tools",
|
||||
"PromoteTools completed",
|
||||
map[string]any{"requested": len(names), "promoted": promoted, "ttl": ttl},
|
||||
)
|
||||
}
|
||||
|
||||
// TickTTL decreases TTL only for non-core tools
|
||||
func (r *ToolRegistry) TickTTL() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, entry := range r.tools {
|
||||
if !entry.IsCore && entry.TTL > 0 {
|
||||
entry.TTL--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Version returns the current registry version (atomically).
|
||||
func (r *ToolRegistry) Version() uint64 {
|
||||
return r.version.Load()
|
||||
}
|
||||
|
||||
// HiddenToolSnapshot holds a consistent snapshot of hidden tools and the
|
||||
// registry version at which it was taken. Used by BM25SearchTool cache.
|
||||
type HiddenToolSnapshot struct {
|
||||
Docs []HiddenToolDoc
|
||||
Version uint64
|
||||
}
|
||||
|
||||
// HiddenToolDoc is a lightweight representation of a hidden tool for search indexing.
|
||||
type HiddenToolDoc struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// SnapshotHiddenTools returns all non-core tools and the current registry
|
||||
// version under a single read-lock, guaranteeing consistency between the
|
||||
// two values.
|
||||
func (r *ToolRegistry) SnapshotHiddenTools() HiddenToolSnapshot {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
docs := make([]HiddenToolDoc, 0, len(r.tools))
|
||||
for name, entry := range r.tools {
|
||||
if !entry.IsCore {
|
||||
docs = append(docs, HiddenToolDoc{
|
||||
Name: name,
|
||||
Description: entry.Tool.Description(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return HiddenToolSnapshot{
|
||||
Docs: docs,
|
||||
Version: r.version.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
tool, ok := r.tools[name]
|
||||
return tool, ok
|
||||
entry, ok := r.tools[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Hidden tools with expired TTL are not callable.
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
return entry.Tool, true
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
|
||||
@@ -135,7 +245,13 @@ func (r *ToolRegistry) GetDefinitions() []map[string]any {
|
||||
sorted := r.sortedToolNames()
|
||||
definitions := make([]map[string]any, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
definitions = append(definitions, ToolToSchema(r.tools[name]))
|
||||
entry := r.tools[name]
|
||||
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
definitions = append(definitions, ToolToSchema(r.tools[name].Tool))
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
@@ -149,8 +265,13 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
sorted := r.sortedToolNames()
|
||||
definitions := make([]providers.ToolDefinition, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
tool := r.tools[name]
|
||||
schema := ToolToSchema(tool)
|
||||
entry := r.tools[name]
|
||||
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
schema := ToolToSchema(entry.Tool)
|
||||
|
||||
// Safely extract nested values with type checks
|
||||
fn, ok := schema["function"].(map[string]any)
|
||||
@@ -198,8 +319,13 @@ func (r *ToolRegistry) GetSummaries() []string {
|
||||
sorted := r.sortedToolNames()
|
||||
summaries := make([]string, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
tool := r.tools[name]
|
||||
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
|
||||
entry := r.tools[name]
|
||||
|
||||
if !entry.IsCore && entry.TTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", entry.Tool.Name(), entry.Tool.Description()))
|
||||
}
|
||||
return summaries
|
||||
}
|
||||
|
||||
@@ -0,0 +1,304 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxRegexPatternLength = 200
|
||||
)
|
||||
|
||||
type RegexSearchTool struct {
|
||||
registry *ToolRegistry
|
||||
ttl int
|
||||
maxSearchResults int
|
||||
}
|
||||
|
||||
func NewRegexSearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *RegexSearchTool {
|
||||
return &RegexSearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults}
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Name() string {
|
||||
return "tool_search_tool_regex"
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Description() string {
|
||||
return "Search available hidden tools on-demand using a regex pattern. Returns JSON schemas of discovered tools."
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Regex pattern to match tool name or description",
|
||||
},
|
||||
},
|
||||
"required": []string{"pattern"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RegexSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
pattern, ok := args["pattern"].(string)
|
||||
if !ok || strings.TrimSpace(pattern) == "" {
|
||||
// An empty string regex (?i) will match every hidden tool,
|
||||
// dumping massive payloads into the context and burning tokens.
|
||||
return ErrorResult("Missing or invalid 'pattern' argument. Must be a non-empty string.")
|
||||
}
|
||||
|
||||
if len(pattern) > MaxRegexPatternLength {
|
||||
logger.WarnCF("discovery", "Regex pattern rejected (too long)", map[string]any{"len": len(pattern)})
|
||||
return ErrorResult(fmt.Sprintf("Pattern too long: max %d characters allowed", MaxRegexPatternLength))
|
||||
}
|
||||
|
||||
logger.DebugCF("discovery", "Regex search", map[string]any{"pattern": pattern})
|
||||
|
||||
res, err := t.registry.SearchRegex(pattern, t.maxSearchResults)
|
||||
if err != nil {
|
||||
logger.WarnCF("discovery", "Invalid regex pattern", map[string]any{"pattern": pattern, "error": err.Error()})
|
||||
return ErrorResult(fmt.Sprintf("Invalid regex pattern syntax: %v. Please fix your regex and try again.", err))
|
||||
}
|
||||
|
||||
logger.InfoCF("discovery", "Regex search completed", map[string]any{"pattern": pattern, "results": len(res)})
|
||||
return formatDiscoveryResponse(t.registry, res, t.ttl)
|
||||
}
|
||||
|
||||
type BM25SearchTool struct {
|
||||
registry *ToolRegistry
|
||||
ttl int
|
||||
maxSearchResults int
|
||||
|
||||
// Cache: rebuilt only when the registry version changes.
|
||||
cacheMu sync.Mutex
|
||||
cachedEngine *bm25CachedEngine
|
||||
cacheVersion uint64
|
||||
}
|
||||
|
||||
func NewBM25SearchTool(r *ToolRegistry, ttl int, maxSearchResults int) *BM25SearchTool {
|
||||
return &BM25SearchTool{registry: r, ttl: ttl, maxSearchResults: maxSearchResults}
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Name() string {
|
||||
return "tool_search_tool_bm25"
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Description() string {
|
||||
return "Search available hidden tools on-demand using natural language query describing the action you need to perform. Returns JSON schemas of discovered tools."
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *BM25SearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || strings.TrimSpace(query) == "" {
|
||||
// An empty string query will match every hidden tool,
|
||||
// dumping massive payloads into the context and burning tokens.
|
||||
return ErrorResult("Missing or invalid 'query' argument. Must be a non-empty string.")
|
||||
}
|
||||
|
||||
logger.DebugCF("discovery", "BM25 search", map[string]any{"query": query})
|
||||
|
||||
cached := t.getOrBuildEngine()
|
||||
if cached == nil {
|
||||
logger.DebugCF("discovery", "BM25 search: no hidden tools available", nil)
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
|
||||
ranked := cached.engine.Search(query, t.maxSearchResults)
|
||||
if len(ranked) == 0 {
|
||||
logger.DebugCF("discovery", "BM25 search: no matches", map[string]any{"query": query})
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
|
||||
results := make([]ToolSearchResult, len(ranked))
|
||||
for i, r := range ranked {
|
||||
results[i] = ToolSearchResult{
|
||||
Name: r.Document.Name,
|
||||
Description: r.Document.Description,
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("discovery", "BM25 search completed", map[string]any{"query": query, "results": len(results)})
|
||||
return formatDiscoveryResponse(t.registry, results, t.ttl)
|
||||
}
|
||||
|
||||
// ToolSearchResult represents the result returned to the LLM.
|
||||
// Parameters are omitted from the JSON response to save context tokens;
|
||||
// the LLM will see full schemas via ToProviderDefs after promotion.
|
||||
type ToolSearchResult struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) SearchRegex(pattern string, maxSearchResults int) ([]ToolSearchResult, error) {
|
||||
if maxSearchResults <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
regex, err := regexp.Compile("(?i)" + pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile regex pattern %q: %w", pattern, err)
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var results []ToolSearchResult
|
||||
|
||||
// Iterate in sorted order for deterministic results across calls.
|
||||
for _, name := range r.sortedToolNames() {
|
||||
entry := r.tools[name]
|
||||
// Search only among the hidden tools (Core tools are already visible)
|
||||
if !entry.IsCore {
|
||||
// Directly call interface methods! No reflection/unmarshalling needed.
|
||||
desc := entry.Tool.Description()
|
||||
|
||||
if regex.MatchString(name) || regex.MatchString(desc) {
|
||||
results = append(results, ToolSearchResult{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
})
|
||||
if len(results) >= maxSearchResults {
|
||||
break // Stop searching once we hit the max! Saves CPU.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func formatDiscoveryResponse(registry *ToolRegistry, results []ToolSearchResult, ttl int) *ToolResult {
|
||||
if len(results) == 0 {
|
||||
return SilentResult("No tools found matching the query.")
|
||||
}
|
||||
|
||||
names := make([]string, len(results))
|
||||
for i, r := range results {
|
||||
names[i] = r.Name
|
||||
}
|
||||
registry.PromoteTools(names, ttl)
|
||||
logger.InfoCF("discovery", "Promoted tools", map[string]any{"tools": names, "ttl": ttl})
|
||||
|
||||
b, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return ErrorResult("Failed to format search results: " + err.Error())
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(
|
||||
"Found %d tools:\n%s\n\nSUCCESS: These tools have been temporarily UNLOCKED as native tools! In your next response, you can call them directly just like any normal tool",
|
||||
len(results),
|
||||
string(b),
|
||||
)
|
||||
|
||||
return SilentResult(msg)
|
||||
}
|
||||
|
||||
// Lightweight internal type used as corpus document for BM25.
|
||||
type searchDoc struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// bm25CachedEngine wraps a BM25Engine with its corpus snapshot.
|
||||
type bm25CachedEngine struct {
|
||||
engine *utils.BM25Engine[searchDoc]
|
||||
}
|
||||
|
||||
// snapshotToSearchDocs converts a HiddenToolSnapshot to BM25 searchDoc slice.
|
||||
func snapshotToSearchDocs(snap HiddenToolSnapshot) []searchDoc {
|
||||
docs := make([]searchDoc, len(snap.Docs))
|
||||
for i, d := range snap.Docs {
|
||||
docs[i] = searchDoc{Name: d.Name, Description: d.Description}
|
||||
}
|
||||
return docs
|
||||
}
|
||||
|
||||
// buildBM25Engine creates a BM25Engine from a slice of searchDocs.
|
||||
func buildBM25Engine(docs []searchDoc) *utils.BM25Engine[searchDoc] {
|
||||
return utils.NewBM25Engine(
|
||||
docs,
|
||||
func(doc searchDoc) string {
|
||||
return doc.Name + " " + doc.Description
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// getOrBuildEngine returns a cached BM25 engine, rebuilding it only when
|
||||
// the registry version has changed (new tools registered).
|
||||
func (t *BM25SearchTool) getOrBuildEngine() *bm25CachedEngine {
|
||||
// Fast path: optimistic check without locking.
|
||||
if t.cachedEngine != nil && t.cacheVersion == t.registry.Version() {
|
||||
return t.cachedEngine
|
||||
}
|
||||
|
||||
t.cacheMu.Lock()
|
||||
defer t.cacheMu.Unlock()
|
||||
|
||||
// Snapshot + version are read under a single registry RLock,
|
||||
// guaranteeing consistency (no TOCTOU).
|
||||
snap := t.registry.SnapshotHiddenTools()
|
||||
|
||||
// Re-check: another goroutine may have rebuilt while we waited for cacheMu.
|
||||
if t.cachedEngine != nil && t.cacheVersion == snap.Version {
|
||||
return t.cachedEngine
|
||||
}
|
||||
|
||||
docs := snapshotToSearchDocs(snap)
|
||||
if len(docs) == 0 {
|
||||
t.cachedEngine = nil
|
||||
t.cacheVersion = snap.Version
|
||||
return nil
|
||||
}
|
||||
|
||||
cached := &bm25CachedEngine{engine: buildBM25Engine(docs)}
|
||||
t.cachedEngine = cached
|
||||
t.cacheVersion = snap.Version
|
||||
logger.DebugCF("discovery", "BM25 engine rebuilt", map[string]any{"docs": len(docs), "version": snap.Version})
|
||||
return cached
|
||||
}
|
||||
|
||||
// SearchBM25 ranks hidden tools against query using BM25 via utils.BM25Engine.
|
||||
// This non-cached variant rebuilds the engine on every call. Used by tests
|
||||
// and any code that doesn't hold a BM25SearchTool instance.
|
||||
func (r *ToolRegistry) SearchBM25(query string, maxSearchResults int) []ToolSearchResult {
|
||||
snap := r.SnapshotHiddenTools()
|
||||
docs := snapshotToSearchDocs(snap)
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ranked := buildBM25Engine(docs).Search(query, maxSearchResults)
|
||||
if len(ranked) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]ToolSearchResult, len(ranked))
|
||||
for i, r := range ranked {
|
||||
out[i] = ToolSearchResult{
|
||||
Name: r.Document.Name,
|
||||
Description: r.Document.Description,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Dummy tool to fill the registry in our tests.
|
||||
type mockSearchableTool struct {
|
||||
name string
|
||||
desc string
|
||||
}
|
||||
|
||||
func (m *mockSearchableTool) Name() string { return m.name }
|
||||
func (m *mockSearchableTool) Description() string { return m.desc }
|
||||
func (m *mockSearchableTool) Parameters() map[string]any {
|
||||
return map[string]any{"type": "object"}
|
||||
}
|
||||
|
||||
func (m *mockSearchableTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
return SilentResult("mock executed: " + m.name)
|
||||
}
|
||||
|
||||
// Helper to initialize a populated ToolRegistry
|
||||
func setupPopulatedRegistry() *ToolRegistry {
|
||||
reg := NewToolRegistry()
|
||||
|
||||
// A core tool (NOT to be found by searches)
|
||||
reg.Register(&mockSearchableTool{
|
||||
name: "core_search",
|
||||
desc: "I am a visible core tool for searching files",
|
||||
})
|
||||
|
||||
// Hidden tools (must be found by searches)
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: "mcp_read_file",
|
||||
desc: "Read the contents of a system file",
|
||||
})
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: "mcp_list_dir",
|
||||
desc: "List directories and files in the system",
|
||||
})
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: "mcp_fetch_net",
|
||||
desc: "Fetch data from a network database",
|
||||
})
|
||||
|
||||
return reg
|
||||
}
|
||||
|
||||
func TestRegexSearchTool_Execute(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
tool := NewRegexSearchTool(reg, 5, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Empty Pattern Error", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'pattern'") {
|
||||
t.Errorf("Expected missing pattern error, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid Regex Syntax", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": "[unclosed"})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Invalid regex pattern syntax") {
|
||||
t.Errorf("Expected regex syntax error, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("No Match Found", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": "alien"})
|
||||
if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") {
|
||||
t.Errorf("Expected 'no tools found' message, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Successful Match & Promotion", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": "system"})
|
||||
|
||||
if res.IsError {
|
||||
t.Fatalf("Unexpected error: %v", res.ForLLM)
|
||||
}
|
||||
if !strings.Contains(res.ForLLM, "SUCCESS: These tools have been temporarily UNLOCKED") {
|
||||
t.Errorf("Expected success string, got: %v", res.ForLLM)
|
||||
}
|
||||
if !strings.Contains(res.ForLLM, "mcp_read_file") {
|
||||
t.Errorf("Expected 'mcp_read_file' in results")
|
||||
}
|
||||
|
||||
// Verify that the TTL has been updated for the tools found
|
||||
reg.mu.RLock()
|
||||
defer reg.mu.RUnlock()
|
||||
if reg.tools["mcp_read_file"].TTL != 5 {
|
||||
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 5, got %d", reg.tools["mcp_read_file"].TTL)
|
||||
}
|
||||
if reg.tools["mcp_fetch_net"].TTL != 0 {
|
||||
t.Errorf("Expected 'mcp_fetch_net' to NOT be promoted (TTL=0)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBM25SearchTool_Execute(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
tool := NewBM25SearchTool(reg, 3, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Empty Query Error", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"query": " "})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Missing or invalid 'query'") {
|
||||
t.Errorf("Expected missing query error, got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("No Match Found", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"query": "aliens spaceships"})
|
||||
if res.IsError || !strings.Contains(res.ForLLM, "No tools found matching") {
|
||||
t.Errorf("Expected 'no tools found', got: %v", res.ForLLM)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Successful Match & Promotion", func(t *testing.T) {
|
||||
res := tool.Execute(ctx, map[string]any{"query": "read files"})
|
||||
|
||||
if res.IsError {
|
||||
t.Fatalf("Unexpected error: %v", res.ForLLM)
|
||||
}
|
||||
if !strings.Contains(res.ForLLM, "mcp_read_file") {
|
||||
t.Errorf("Expected 'mcp_read_file' in BM25 results")
|
||||
}
|
||||
|
||||
reg.mu.RLock()
|
||||
defer reg.mu.RUnlock()
|
||||
if reg.tools["mcp_read_file"].TTL != 3 {
|
||||
t.Errorf("Expected TTL of 'mcp_read_file' to be promoted to 3")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegexSearchTool_PatternTooLong(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
tool := NewRegexSearchTool(reg, 5, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
longPattern := strings.Repeat("a", MaxRegexPatternLength+1)
|
||||
res := tool.Execute(ctx, map[string]any{"pattern": longPattern})
|
||||
if !res.IsError || !strings.Contains(res.ForLLM, "Pattern too long") {
|
||||
t.Errorf("Expected pattern too long error, got: %v", res.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchRegex_ZeroMaxResults(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
|
||||
res, err := reg.SearchRegex("mcp", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchRegex failed: %v", err)
|
||||
}
|
||||
if len(res) != 0 {
|
||||
t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBM25_ZeroMaxResults(t *testing.T) {
|
||||
reg := setupPopulatedRegistry()
|
||||
|
||||
res := reg.SearchBM25("read file", 0)
|
||||
if len(res) != 0 {
|
||||
t.Errorf("Expected 0 results with maxSearchResults=0, got %d", len(res))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchRegex_DeterministicOrder(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
for i := 0; i < 20; i++ {
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: fmt.Sprintf("tool_%02d", i),
|
||||
desc: "searchable tool",
|
||||
})
|
||||
}
|
||||
|
||||
// Run the same search multiple times and verify order is stable
|
||||
var firstRun []string
|
||||
for attempt := 0; attempt < 10; attempt++ {
|
||||
res, err := reg.SearchRegex("searchable", 20)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchRegex failed: %v", err)
|
||||
}
|
||||
|
||||
names := make([]string, len(res))
|
||||
for i, r := range res {
|
||||
names[i] = r.Name
|
||||
}
|
||||
|
||||
if attempt == 0 {
|
||||
firstRun = names
|
||||
} else {
|
||||
for i, name := range names {
|
||||
if name != firstRun[i] {
|
||||
t.Fatalf("Non-deterministic order at attempt %d, index %d: got %q, want %q",
|
||||
attempt, i, name, firstRun[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_SearchLimitsAndCoreFiltering(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
|
||||
// Add 1 Core and 10 Hidden, all containing the word "match"
|
||||
reg.Register(&mockSearchableTool{"core_match", "I am core with match"})
|
||||
for i := 0; i < 10; i++ {
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: fmt.Sprintf("hidden_match_%d", i),
|
||||
desc: "this has a match",
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Regex limits and core filtering", func(t *testing.T) {
|
||||
// Search with Regex and a limit of maxSearchResults = 4
|
||||
res, err := reg.SearchRegex("match", 4)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchRegex failed: %v", err)
|
||||
}
|
||||
|
||||
if len(res) != 4 {
|
||||
t.Errorf("Expected exactly 4 results due to limit, got %d", len(res))
|
||||
}
|
||||
|
||||
for _, r := range res {
|
||||
if r.Name == "core_match" {
|
||||
t.Errorf("SearchRegex returned a Core tool, which should be excluded")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BM25 limits and core filtering", func(t *testing.T) {
|
||||
// Search with BM25 and a limit of maxSearchResults = 3
|
||||
res := reg.SearchBM25("match", 3)
|
||||
|
||||
if len(res) != 3 {
|
||||
t.Errorf("Expected exactly 3 results due to limit, got %d", len(res))
|
||||
}
|
||||
|
||||
for _, r := range res {
|
||||
if r.Name == "core_match" {
|
||||
t.Errorf("SearchBM25 returned a Core tool, which should be excluded")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGet_HiddenToolTTLLifecycle(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
reg.RegisterHidden(&mockSearchableTool{name: "hidden_tool", desc: "test"})
|
||||
|
||||
// TTL=0 at registration → not gettable
|
||||
_, ok := reg.Get("hidden_tool")
|
||||
if ok {
|
||||
t.Error("Expected hidden tool with TTL=0 to NOT be gettable")
|
||||
}
|
||||
|
||||
// Promote → gettable
|
||||
reg.PromoteTools([]string{"hidden_tool"}, 3)
|
||||
_, ok = reg.Get("hidden_tool")
|
||||
if !ok {
|
||||
t.Error("Expected promoted hidden tool to be gettable")
|
||||
}
|
||||
|
||||
// Tick down to 0 → not gettable again
|
||||
reg.TickTTL() // 3→2
|
||||
reg.TickTTL() // 2→1
|
||||
reg.TickTTL() // 1→0
|
||||
_, ok = reg.Get("hidden_tool")
|
||||
if ok {
|
||||
t.Error("Expected hidden tool with TTL ticked to 0 to NOT be gettable")
|
||||
}
|
||||
|
||||
// Core tools remain always gettable
|
||||
reg.Register(&mockSearchableTool{name: "core_tool", desc: "core"})
|
||||
_, ok = reg.Get("core_tool")
|
||||
if !ok {
|
||||
t.Error("Expected core tool to always be gettable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25CacheInvalidation(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
reg.RegisterHidden(&mockSearchableTool{name: "tool_alpha", desc: "alpha functionality"})
|
||||
|
||||
tool := NewBM25SearchTool(reg, 5, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
// First search should find tool_alpha
|
||||
res := tool.Execute(ctx, map[string]any{"query": "alpha"})
|
||||
if !strings.Contains(res.ForLLM, "tool_alpha") {
|
||||
t.Fatalf("Expected 'tool_alpha' in first search, got: %v", res.ForLLM)
|
||||
}
|
||||
|
||||
// Register a new hidden tool
|
||||
reg.RegisterHidden(&mockSearchableTool{name: "tool_beta", desc: "beta functionality"})
|
||||
|
||||
// Cache should be invalidated; new tool should be findable
|
||||
res = tool.Execute(ctx, map[string]any{"query": "beta"})
|
||||
if !strings.Contains(res.ForLLM, "tool_beta") {
|
||||
t.Errorf("Expected 'tool_beta' after cache invalidation, got: %v", res.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromoteTools_ConcurrentWithTickTTL(t *testing.T) {
|
||||
reg := NewToolRegistry()
|
||||
for i := 0; i < 20; i++ {
|
||||
reg.RegisterHidden(&mockSearchableTool{
|
||||
name: fmt.Sprintf("concurrent_tool_%d", i),
|
||||
desc: "concurrent test tool",
|
||||
})
|
||||
}
|
||||
|
||||
names := make([]string, 20)
|
||||
for i := 0; i < 20; i++ {
|
||||
names[i] = fmt.Sprintf("concurrent_tool_%d", i)
|
||||
}
|
||||
|
||||
// Hammer PromoteTools and TickTTL concurrently to detect races
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for i := 0; i < 1000; i++ {
|
||||
reg.PromoteTools(names, 5)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
reg.TickTTL()
|
||||
}
|
||||
<-done
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
// Package utils provides shared, reusable algorithms.
|
||||
// This file implements a generic BM25 search engine.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// type MyDoc struct { ID string; Body string }
|
||||
//
|
||||
// corpus := []MyDoc{...}
|
||||
// engine := bm25.New(corpus, func(d MyDoc) string {
|
||||
// return d.ID + " " + d.Body
|
||||
// })
|
||||
// results := engine.Search("my query", 5)
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ── Tuning defaults ───────────────────────────────────────────────────────────
|
||||
|
||||
const (
|
||||
// DefaultBM25K1 is the term-frequency saturation factor (typical range 1.2–2.0).
|
||||
// Higher values give more weight to repeated terms.
|
||||
DefaultBM25K1 = 1.2
|
||||
|
||||
// DefaultBM25B is the document-length normalization factor (0 = none, 1 = full).
|
||||
DefaultBM25B = 0.75
|
||||
)
|
||||
|
||||
// BM25Engine is a query-time BM25 search engine over a generic corpus.
|
||||
// T is the document type; the caller supplies a TextFunc that extracts the
|
||||
// searchable text from each document.
|
||||
//
|
||||
// The engine is stateless between queries: no caching, no invalidation logic.
|
||||
// All indexing work is performed inside Search() on every call, making it
|
||||
// safe to use on corpora that change frequently.
|
||||
type BM25Engine[T any] struct {
|
||||
corpus []T
|
||||
textFunc func(T) string
|
||||
k1 float64
|
||||
b float64
|
||||
}
|
||||
|
||||
// BM25Option is a functional option to configure a BM25Engine.
|
||||
type BM25Option func(*bm25Config)
|
||||
|
||||
type bm25Config struct {
|
||||
k1 float64
|
||||
b float64
|
||||
}
|
||||
|
||||
// WithK1 overrides the term-frequency saturation constant (default 1.2).
|
||||
func WithK1(k1 float64) BM25Option {
|
||||
return func(c *bm25Config) { c.k1 = k1 }
|
||||
}
|
||||
|
||||
// WithB overrides the document-length normalization factor (default 0.75).
|
||||
func WithB(b float64) BM25Option {
|
||||
return func(c *bm25Config) { c.b = b }
|
||||
}
|
||||
|
||||
// NewBM25Engine creates a BM25Engine for the given corpus.
|
||||
//
|
||||
// - corpus : slice of documents of any type T.
|
||||
// - textFunc : function that returns the searchable text for a document.
|
||||
// - opts : optional tuning (WithK1, WithB).
|
||||
//
|
||||
// The corpus slice is referenced, not copied. Callers must not mutate it
|
||||
// concurrently with Search().
|
||||
func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Option) *BM25Engine[T] {
|
||||
cfg := bm25Config{k1: DefaultBM25K1, b: DefaultBM25B}
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
return &BM25Engine[T]{
|
||||
corpus: corpus,
|
||||
textFunc: textFunc,
|
||||
k1: cfg.k1,
|
||||
b: cfg.b,
|
||||
}
|
||||
}
|
||||
|
||||
// BM25Result is a single ranked result from a Search call.
|
||||
type BM25Result[T any] struct {
|
||||
Document T
|
||||
Score float32
|
||||
}
|
||||
|
||||
// Search ranks the corpus against query and returns the top-k results.
|
||||
// Returns an empty slice (not nil) when there are no matches.
|
||||
//
|
||||
// Complexity: O(N×L) for indexing + O(|Q|×avgPostingLen) for scoring,
|
||||
// where N = corpus size, L = average document length, Q = query terms.
|
||||
// Top-k extraction uses a fixed-size min-heap: O(candidates × log k).
|
||||
func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
|
||||
if topK <= 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
queryTerms := bm25Tokenize(query)
|
||||
if len(queryTerms) == 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
N := len(e.corpus)
|
||||
if N == 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
// Step 1: build per-document tf + raw doc lengths
|
||||
type docEntry struct {
|
||||
tf map[string]uint32
|
||||
rawLen int
|
||||
}
|
||||
|
||||
entries := make([]docEntry, N)
|
||||
df := make(map[string]int, 64)
|
||||
totalLen := 0
|
||||
|
||||
for i, doc := range e.corpus {
|
||||
tokens := bm25Tokenize(e.textFunc(doc))
|
||||
totalLen += len(tokens)
|
||||
|
||||
tf := make(map[string]uint32, len(tokens))
|
||||
for _, t := range tokens {
|
||||
tf[t]++
|
||||
}
|
||||
// df: each term counts once per document (iterate the map, keys are unique)
|
||||
for t := range tf {
|
||||
df[t]++
|
||||
}
|
||||
|
||||
entries[i] = docEntry{tf: tf, rawLen: len(tokens)}
|
||||
}
|
||||
|
||||
avgDocLen := float64(totalLen) / float64(N)
|
||||
|
||||
// Step 2: pre-compute IDF and per-doc length normalization
|
||||
// IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 )
|
||||
idf := make(map[string]float32, len(df))
|
||||
for term, freq := range df {
|
||||
idf[term] = float32(math.Log(
|
||||
(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
|
||||
))
|
||||
}
|
||||
|
||||
// docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen)
|
||||
// Stored as float32 — sufficient precision for ranking.
|
||||
docLenNorm := make([]float32, N)
|
||||
for i, entry := range entries {
|
||||
docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen))
|
||||
}
|
||||
|
||||
// Step 3: build inverted index (posting lists)
|
||||
// Iterate the tf map directly — map keys are already unique, no seen-set needed.
|
||||
posting := make(map[string][]int32, len(df))
|
||||
for i, entry := range entries {
|
||||
for term := range entry.tf {
|
||||
posting[term] = append(posting[term], int32(i))
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: score via posting lists
|
||||
// Deduplicate query terms to avoid double-weighting the same term.
|
||||
unique := bm25Dedupe(queryTerms)
|
||||
|
||||
scores := make(map[int32]float32)
|
||||
for _, term := range unique {
|
||||
termIDF, ok := idf[term]
|
||||
if !ok {
|
||||
continue // term not in vocabulary → zero contribution
|
||||
}
|
||||
for _, docID := range posting[term] {
|
||||
freq := float32(entries[docID].tf[term])
|
||||
// TF_norm = freq * (k1+1) / (freq + docLenNorm)
|
||||
tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID])
|
||||
scores[docID] += termIDF * tfNorm
|
||||
}
|
||||
}
|
||||
|
||||
if len(scores) == 0 {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
// Step 5: top-K via fixed-size min-heap
|
||||
heap := make([]bm25ScoredDoc, 0, topK)
|
||||
|
||||
for docID, sc := range scores {
|
||||
switch {
|
||||
case len(heap) < topK:
|
||||
heap = append(heap, bm25ScoredDoc{docID: docID, score: sc})
|
||||
if len(heap) == topK {
|
||||
bm25MinHeapify(heap)
|
||||
}
|
||||
case sc > heap[0].score:
|
||||
heap[0] = bm25ScoredDoc{docID: docID, score: sc}
|
||||
bm25SiftDown(heap, 0)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(heap, func(i, j int) bool { return heap[i].score > heap[j].score })
|
||||
|
||||
out := make([]BM25Result[T], len(heap))
|
||||
for i, h := range heap {
|
||||
out[i] = BM25Result[T]{
|
||||
Document: e.corpus[h.docID],
|
||||
Score: h.score,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// bm25Tokenize splits s into lowercase tokens, stripping edge punctuation.
|
||||
func bm25Tokenize(s string) []string {
|
||||
raw := strings.Fields(strings.ToLower(s))
|
||||
out := raw[:0] // reuse backing array to avoid extra allocation
|
||||
for _, t := range raw {
|
||||
t = strings.Trim(t, ".,;:!?\"'()/\\-_")
|
||||
if t != "" {
|
||||
out = append(out, t)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// bm25Dedupe returns a new slice with duplicate tokens removed,
|
||||
// preserving first-occurrence order.
|
||||
func bm25Dedupe(tokens []string) []string {
|
||||
seen := make(map[string]struct{}, len(tokens))
|
||||
out := make([]string, 0, len(tokens))
|
||||
for _, t := range tokens {
|
||||
if _, ok := seen[t]; !ok {
|
||||
seen[t] = struct{}{}
|
||||
out = append(out, t)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type bm25ScoredDoc struct {
|
||||
docID int32
|
||||
score float32
|
||||
}
|
||||
|
||||
// bm25MinHeapify builds a min-heap in-place using Floyd's algorithm: O(k).
|
||||
func bm25MinHeapify(h []bm25ScoredDoc) {
|
||||
for i := len(h)/2 - 1; i >= 0; i-- {
|
||||
bm25SiftDown(h, i)
|
||||
}
|
||||
}
|
||||
|
||||
// bm25SiftDown restores the min-heap property starting at node i: O(log k).
|
||||
func bm25SiftDown(h []bm25ScoredDoc, i int) {
|
||||
n := len(h)
|
||||
for {
|
||||
smallest := i
|
||||
l, r := 2*i+1, 2*i+2
|
||||
if l < n && h[l].score < h[smallest].score {
|
||||
smallest = l
|
||||
}
|
||||
if r < n && h[r].score < h[smallest].score {
|
||||
smallest = r
|
||||
}
|
||||
if smallest == i {
|
||||
break
|
||||
}
|
||||
h[i], h[smallest] = h[smallest], h[i]
|
||||
i = smallest
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,175 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// testDoc is a generic structure for use in tests.
|
||||
type testDoc struct {
|
||||
ID int
|
||||
Text string
|
||||
}
|
||||
|
||||
func extractText(d testDoc) string {
|
||||
return d.Text
|
||||
}
|
||||
|
||||
func TestBM25Search_EdgeCases(t *testing.T) {
|
||||
corpus := []testDoc{
|
||||
{1, "hello world"},
|
||||
{2, "foo bar"},
|
||||
}
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
topK int
|
||||
}{
|
||||
{"Zero topK", "hello", 0},
|
||||
{"Negative topK", "hello", -1},
|
||||
{"Empty query", "", 5},
|
||||
{"Query with only punctuation", "...,,,!!!", 5},
|
||||
{"No matches found", "golang", 5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := engine.Search(tt.query, tt.topK)
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 results, got %d", len(results))
|
||||
}
|
||||
// Check that it never returns nil, but an empty slice
|
||||
if results == nil {
|
||||
t.Errorf("expected empty slice, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Search_EmptyCorpus(t *testing.T) {
|
||||
engine := NewBM25Engine([]testDoc{}, extractText)
|
||||
results := engine.Search("hello", 5)
|
||||
if len(results) != 0 || results == nil {
|
||||
t.Errorf("expected empty slice from empty corpus, got %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Search_RankingLogic(t *testing.T) {
|
||||
corpus := []testDoc{
|
||||
{1, "the quick brown fox jumps over the lazy dog"},
|
||||
{2, "quick fox"},
|
||||
{3, "quick quick quick fox"}, // High Term Frequency (TF)
|
||||
{4, "completely irrelevant document here"},
|
||||
}
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
|
||||
t.Run("Term Frequency (TF) boosts score", func(t *testing.T) {
|
||||
results := engine.Search("quick", 5)
|
||||
if len(results) < 3 {
|
||||
t.Fatalf("expected at least 3 results, got %d", len(results))
|
||||
}
|
||||
// Doc 3 has the word "quick" repeated 3 times, it should beat Doc 2
|
||||
if results[0].Document.ID != 3 {
|
||||
t.Errorf("expected doc 3 to rank first due to high TF, got doc %d", results[0].Document.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Document Length penalty", func(t *testing.T) {
|
||||
results := engine.Search("fox", 5)
|
||||
if len(results) < 3 {
|
||||
t.Fatalf("expected at least 3 results, got %d", len(results))
|
||||
}
|
||||
// Doc 2 ("quick fox") is much shorter than Doc 1 ("the quick brown fox..."),
|
||||
// so, with equal Term Frequency for the word "fox" (1 time), Doc 2 wins.
|
||||
if results[0].Document.ID != 2 {
|
||||
t.Errorf("expected doc 2 to rank first due to shorter length, got doc %d", results[0].Document.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TopK limits results", func(t *testing.T) {
|
||||
results := engine.Search("quick", 2)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected exactly 2 results, got %d", len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBM25Tokenize(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello World", []string{"hello", "world"}},
|
||||
{" spaces everywhere ", []string{"spaces", "everywhere"}},
|
||||
{"punctuation... test!!!", []string{"punctuation", "test"}},
|
||||
{"(parentheses) and-hyphens", []string{"parentheses", "and-hyphens"}}, // hyphens trimmed from edges
|
||||
{"internal-hyphen is kept", []string{"internal-hyphen", "is", "kept"}},
|
||||
{".,;?!", []string{}}, // Becomes empty after trim
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := bm25Tokenize(tt.input)
|
||||
if len(got) == 0 && len(tt.expected) == 0 {
|
||||
return // Both empty
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.expected) {
|
||||
t.Errorf("bm25Tokenize(%q) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Dedupe(t *testing.T) {
|
||||
input := []string{"apple", "banana", "apple", "orange", "banana"}
|
||||
expected := []string{"apple", "banana", "orange"}
|
||||
|
||||
got := bm25Dedupe(input)
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("bm25Dedupe() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Options(t *testing.T) {
|
||||
corpus := []testDoc{{1, "test"}}
|
||||
|
||||
engine := NewBM25Engine(
|
||||
corpus,
|
||||
extractText,
|
||||
WithK1(2.5),
|
||||
WithB(0.9),
|
||||
)
|
||||
|
||||
if engine.k1 != 2.5 {
|
||||
t.Errorf("expected k1 to be 2.5, got %v", engine.k1)
|
||||
}
|
||||
if engine.b != 0.9 {
|
||||
t.Errorf("expected b to be 0.9, got %v", engine.b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBM25Search_SortingStability(t *testing.T) {
|
||||
// Ensure that sorting by heap returns in correct descending order
|
||||
corpus := []testDoc{
|
||||
{1, "golang is good"},
|
||||
{2, "golang golang"},
|
||||
{3, "golang golang golang"},
|
||||
{4, "golang golang golang golang"},
|
||||
}
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
results := engine.Search("golang", 10)
|
||||
|
||||
if len(results) != 4 {
|
||||
t.Fatalf("expected 4 results, got %d", len(results))
|
||||
}
|
||||
|
||||
// Score should be strictly decreasing
|
||||
for i := 1; i < len(results); i++ {
|
||||
if results[i].Score > results[i-1].Score {
|
||||
t.Errorf("results not sorted correctly: result %d score (%v) > result %d score (%v)",
|
||||
i, results[i].Score, i-1, results[i-1].Score)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user