test(tools,utils): add ToolRegistry unit tests and fix Truncate panic on negative maxLen (#517)

Add comprehensive unit tests for the ToolRegistry covering registration,
lookup, execution, context injection, async callbacks, schema generation,
provider definition conversion, and concurrent access.

Fix a defensive edge case in Truncate where a negative maxLen would cause
a slice bounds panic, and add table-driven tests covering boundary
conditions, zero/negative lengths, and Unicode handling.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
King Tai
2026-02-22 18:40:59 +08:00
committed by GitHub
parent 6b55fb5f1d
commit cb0c8703fb
3 changed files with 459 additions and 0 deletions
+350
View File
@@ -0,0 +1,350 @@
package tools
import (
"context"
"strings"
"sync"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// --- mock types ---
type mockRegistryTool struct {
name string
desc string
params map[string]interface{}
result *ToolResult
}
func (m *mockRegistryTool) Name() string { return m.name }
func (m *mockRegistryTool) Description() string { return m.desc }
func (m *mockRegistryTool) Parameters() map[string]interface{} { return m.params }
func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]interface{}) *ToolResult {
return m.result
}
type mockCtxTool struct {
mockRegistryTool
channel string
chatID string
}
func (m *mockCtxTool) SetContext(channel, chatID string) {
m.channel = channel
m.chatID = chatID
}
type mockAsyncRegistryTool struct {
mockRegistryTool
cb AsyncCallback
}
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
m.cb = cb
}
// --- helpers ---
func newMockTool(name, desc string) *mockRegistryTool {
return &mockRegistryTool{
name: name,
desc: desc,
params: map[string]interface{}{"type": "object"},
result: SilentResult("ok"),
}
}
// --- tests ---
func TestNewToolRegistry(t *testing.T) {
r := NewToolRegistry()
if r.Count() != 0 {
t.Errorf("expected empty registry, got count %d", r.Count())
}
if len(r.List()) != 0 {
t.Errorf("expected empty list, got %v", r.List())
}
}
func TestToolRegistry_RegisterAndGet(t *testing.T) {
r := NewToolRegistry()
tool := newMockTool("echo", "echoes input")
r.Register(tool)
got, ok := r.Get("echo")
if !ok {
t.Fatal("expected to find registered tool")
}
if got.Name() != "echo" {
t.Errorf("expected name 'echo', got %q", got.Name())
}
}
func TestToolRegistry_Get_NotFound(t *testing.T) {
r := NewToolRegistry()
_, ok := r.Get("nonexistent")
if ok {
t.Error("expected ok=false for unregistered tool")
}
}
func TestToolRegistry_RegisterOverwrite(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("dup", "first"))
r.Register(newMockTool("dup", "second"))
if r.Count() != 1 {
t.Errorf("expected count 1 after overwrite, got %d", r.Count())
}
tool, _ := r.Get("dup")
if tool.Description() != "second" {
t.Errorf("expected overwritten description 'second', got %q", tool.Description())
}
}
func TestToolRegistry_Execute_Success(t *testing.T) {
r := NewToolRegistry()
r.Register(&mockRegistryTool{
name: "greet",
desc: "says hello",
params: map[string]interface{}{},
result: SilentResult("hello"),
})
result := r.Execute(context.Background(), "greet", nil)
if result.IsError {
t.Errorf("expected success, got error: %s", result.ForLLM)
}
if result.ForLLM != "hello" {
t.Errorf("expected ForLLM 'hello', got %q", result.ForLLM)
}
}
func TestToolRegistry_Execute_NotFound(t *testing.T) {
r := NewToolRegistry()
result := r.Execute(context.Background(), "missing", nil)
if !result.IsError {
t.Error("expected error for missing tool")
}
if !strings.Contains(result.ForLLM, "not found") {
t.Errorf("expected 'not found' in error, got %q", result.ForLLM)
}
if result.Err == nil {
t.Error("expected Err to be set via WithError")
}
}
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
if ct.channel != "telegram" {
t.Errorf("expected channel 'telegram', got %q", ct.channel)
}
if ct.chatID != "chat-42" {
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
}
}
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
r := NewToolRegistry()
ct := &mockCtxTool{
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
}
r.Register(ct)
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
if ct.channel != "" || ct.chatID != "" {
t.Error("SetContext should not be called with empty channel/chatID")
}
}
func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
r := NewToolRegistry()
at := &mockAsyncRegistryTool{
mockRegistryTool: *newMockTool("async_tool", "async work"),
}
at.result = AsyncResult("started")
r.Register(at)
called := false
cb := func(_ context.Context, _ *ToolResult) { called = true }
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
if at.cb == nil {
t.Error("expected SetCallback to have been called")
}
if !result.Async {
t.Error("expected async result")
}
at.cb(context.Background(), SilentResult("done"))
if !called {
t.Error("expected callback to be invoked")
}
}
func TestToolRegistry_GetDefinitions(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("alpha", "tool A"))
defs := r.GetDefinitions()
if len(defs) != 1 {
t.Fatalf("expected 1 definition, got %d", len(defs))
}
if defs[0]["type"] != "function" {
t.Errorf("expected type 'function', got %v", defs[0]["type"])
}
fn, ok := defs[0]["function"].(map[string]interface{})
if !ok {
t.Fatal("expected 'function' key to be a map")
}
if fn["name"] != "alpha" {
t.Errorf("expected name 'alpha', got %v", fn["name"])
}
if fn["description"] != "tool A" {
t.Errorf("expected description 'tool A', got %v", fn["description"])
}
}
func TestToolRegistry_ToProviderDefs(t *testing.T) {
r := NewToolRegistry()
params := map[string]interface{}{"type": "object", "properties": map[string]interface{}{}}
r.Register(&mockRegistryTool{
name: "beta",
desc: "tool B",
params: params,
result: SilentResult("ok"),
})
defs := r.ToProviderDefs()
if len(defs) != 1 {
t.Fatalf("expected 1 provider def, got %d", len(defs))
}
want := providers.ToolDefinition{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: "beta",
Description: "tool B",
Parameters: params,
},
}
got := defs[0]
if got.Type != want.Type {
t.Errorf("Type: want %q, got %q", want.Type, got.Type)
}
if got.Function.Name != want.Function.Name {
t.Errorf("Name: want %q, got %q", want.Function.Name, got.Function.Name)
}
if got.Function.Description != want.Function.Description {
t.Errorf("Description: want %q, got %q", want.Function.Description, got.Function.Description)
}
}
func TestToolRegistry_List(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("x", ""))
r.Register(newMockTool("y", ""))
names := r.List()
if len(names) != 2 {
t.Fatalf("expected 2 names, got %d", len(names))
}
nameSet := map[string]bool{}
for _, n := range names {
nameSet[n] = true
}
if !nameSet["x"] || !nameSet["y"] {
t.Errorf("expected names {x, y}, got %v", names)
}
}
func TestToolRegistry_Count(t *testing.T) {
r := NewToolRegistry()
if r.Count() != 0 {
t.Errorf("expected 0, got %d", r.Count())
}
r.Register(newMockTool("a", ""))
r.Register(newMockTool("b", ""))
if r.Count() != 2 {
t.Errorf("expected 2, got %d", r.Count())
}
r.Register(newMockTool("a", "replaced"))
if r.Count() != 2 {
t.Errorf("expected 2 after overwrite, got %d", r.Count())
}
}
func TestToolRegistry_GetSummaries(t *testing.T) {
r := NewToolRegistry()
r.Register(newMockTool("read_file", "Reads a file"))
summaries := r.GetSummaries()
if len(summaries) != 1 {
t.Fatalf("expected 1 summary, got %d", len(summaries))
}
if !strings.Contains(summaries[0], "`read_file`") {
t.Errorf("expected backtick-quoted name in summary, got %q", summaries[0])
}
if !strings.Contains(summaries[0], "Reads a file") {
t.Errorf("expected description in summary, got %q", summaries[0])
}
}
func TestToolToSchema(t *testing.T) {
tool := newMockTool("demo", "demo tool")
schema := ToolToSchema(tool)
if schema["type"] != "function" {
t.Errorf("expected type 'function', got %v", schema["type"])
}
fn, ok := schema["function"].(map[string]interface{})
if !ok {
t.Fatal("expected 'function' to be a map")
}
if fn["name"] != "demo" {
t.Errorf("expected name 'demo', got %v", fn["name"])
}
if fn["description"] != "demo tool" {
t.Errorf("expected description 'demo tool', got %v", fn["description"])
}
if fn["parameters"] == nil {
t.Error("expected parameters to be set")
}
}
func TestToolRegistry_ConcurrentAccess(t *testing.T) {
r := NewToolRegistry()
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
name := string(rune('A' + n%26))
r.Register(newMockTool(name, "concurrent"))
r.Get(name)
r.Count()
r.List()
r.GetDefinitions()
}(i)
}
wg.Wait()
if r.Count() == 0 {
t.Error("expected tools to be registered after concurrent access")
}
}
+3
View File
@@ -4,6 +4,9 @@ package utils
// Handles multi-byte Unicode characters properly.
// If the string is truncated, "..." is appended to indicate truncation.
func Truncate(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
runes := []rune(s)
if len(runes) <= maxLen {
return s
+106
View File
@@ -0,0 +1,106 @@
package utils
import "testing"
func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
want string
}{
{
name: "short string unchanged",
input: "hi",
maxLen: 10,
want: "hi",
},
{
name: "exact length unchanged",
input: "hello",
maxLen: 5,
want: "hello",
},
{
name: "long string truncated with ellipsis",
input: "hello world",
maxLen: 8,
want: "hello...",
},
{
name: "maxLen equals 4 leaves 1 char plus ellipsis",
input: "abcdef",
maxLen: 4,
want: "a...",
},
{
name: "maxLen 3 returns first 3 chars without ellipsis",
input: "abcdef",
maxLen: 3,
want: "abc",
},
{
name: "maxLen 2 returns first 2 chars",
input: "abcdef",
maxLen: 2,
want: "ab",
},
{
name: "maxLen 1 returns first char",
input: "abcdef",
maxLen: 1,
want: "a",
},
{
name: "maxLen 0 returns empty",
input: "hello",
maxLen: 0,
want: "",
},
{
name: "negative maxLen returns empty",
input: "hello",
maxLen: -1,
want: "",
},
{
name: "empty string unchanged",
input: "",
maxLen: 5,
want: "",
},
{
name: "empty string with zero maxLen",
input: "",
maxLen: 0,
want: "",
},
{
name: "unicode truncated correctly",
input: "\U0001f600\U0001f601\U0001f602\U0001f603\U0001f604",
maxLen: 4,
want: "\U0001f600...",
},
{
name: "unicode short enough",
input: "\u00e9\u00e8",
maxLen: 5,
want: "\u00e9\u00e8",
},
{
name: "mixed ascii and unicode",
input: "Go\U0001f680\U0001f525\U0001f4a5\U0001f30d",
maxLen: 5,
want: "Go...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Truncate(tt.input, tt.maxLen)
if got != tt.want {
t.Errorf("Truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
}
})
}
}