From cb0c8703fb9d5ce373bc0c4f770177ba66508b25 Mon Sep 17 00:00:00 2001 From: King Tai <109292982+CrisisAlpha@users.noreply.github.com> Date: Sun, 22 Feb 2026 18:40:59 +0800 Subject: [PATCH] test(tools,utils): add ToolRegistry unit tests and fix Truncate panic on negative maxLen (#517) Add comprehensive unit tests for the ToolRegistry covering registration, lookup, execution, context injection, async callbacks, schema generation, provider definition conversion, and concurrent access. Fix a defensive edge case in Truncate where a negative maxLen would cause a slice bounds panic, and add table-driven tests covering boundary conditions, zero/negative lengths, and Unicode handling. Co-authored-by: Cursor --- pkg/tools/registry_test.go | 350 +++++++++++++++++++++++++++++++++++++ pkg/utils/string.go | 3 + pkg/utils/string_test.go | 106 +++++++++++ 3 files changed, 459 insertions(+) create mode 100644 pkg/tools/registry_test.go create mode 100644 pkg/utils/string_test.go diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go new file mode 100644 index 000000000..33978e543 --- /dev/null +++ b/pkg/tools/registry_test.go @@ -0,0 +1,350 @@ +package tools + +import ( + "context" + "strings" + "sync" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// --- mock types --- + +type mockRegistryTool struct { + name string + desc string + params map[string]interface{} + result *ToolResult +} + +func (m *mockRegistryTool) Name() string { return m.name } +func (m *mockRegistryTool) Description() string { return m.desc } +func (m *mockRegistryTool) Parameters() map[string]interface{} { return m.params } +func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]interface{}) *ToolResult { + return m.result +} + +type mockCtxTool struct { + mockRegistryTool + channel string + chatID string +} + +func (m *mockCtxTool) SetContext(channel, chatID string) { + m.channel = channel + m.chatID = chatID +} + +type mockAsyncRegistryTool struct { + mockRegistryTool + cb AsyncCallback +} + +func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) { + m.cb = cb +} + +// --- helpers --- + +func newMockTool(name, desc string) *mockRegistryTool { + return &mockRegistryTool{ + name: name, + desc: desc, + params: map[string]interface{}{"type": "object"}, + result: SilentResult("ok"), + } +} + +// --- tests --- + +func TestNewToolRegistry(t *testing.T) { + r := NewToolRegistry() + if r.Count() != 0 { + t.Errorf("expected empty registry, got count %d", r.Count()) + } + if len(r.List()) != 0 { + t.Errorf("expected empty list, got %v", r.List()) + } +} + +func TestToolRegistry_RegisterAndGet(t *testing.T) { + r := NewToolRegistry() + tool := newMockTool("echo", "echoes input") + r.Register(tool) + + got, ok := r.Get("echo") + if !ok { + t.Fatal("expected to find registered tool") + } + if got.Name() != "echo" { + t.Errorf("expected name 'echo', got %q", got.Name()) + } +} + +func TestToolRegistry_Get_NotFound(t *testing.T) { + r := NewToolRegistry() + _, ok := r.Get("nonexistent") + if ok { + t.Error("expected ok=false for unregistered tool") + } +} + +func TestToolRegistry_RegisterOverwrite(t *testing.T) { + r := NewToolRegistry() + r.Register(newMockTool("dup", "first")) + r.Register(newMockTool("dup", "second")) + + if r.Count() != 1 { + t.Errorf("expected count 1 after overwrite, got %d", r.Count()) + } + tool, _ := r.Get("dup") + if tool.Description() != "second" { + t.Errorf("expected overwritten description 'second', got %q", tool.Description()) + } +} + +func TestToolRegistry_Execute_Success(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockRegistryTool{ + name: "greet", + desc: "says hello", + params: map[string]interface{}{}, + result: SilentResult("hello"), + }) + + result := r.Execute(context.Background(), "greet", nil) + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + if result.ForLLM != "hello" { + t.Errorf("expected ForLLM 'hello', got %q", result.ForLLM) + } +} + +func TestToolRegistry_Execute_NotFound(t *testing.T) { + r := NewToolRegistry() + result := r.Execute(context.Background(), "missing", nil) + if !result.IsError { + t.Error("expected error for missing tool") + } + if !strings.Contains(result.ForLLM, "not found") { + t.Errorf("expected 'not found' in error, got %q", result.ForLLM) + } + if result.Err == nil { + t.Error("expected Err to be set via WithError") + } +} + +func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) { + r := NewToolRegistry() + ct := &mockCtxTool{ + mockRegistryTool: *newMockTool("ctx_tool", "needs context"), + } + r.Register(ct) + + r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil) + + if ct.channel != "telegram" { + t.Errorf("expected channel 'telegram', got %q", ct.channel) + } + if ct.chatID != "chat-42" { + t.Errorf("expected chatID 'chat-42', got %q", ct.chatID) + } +} + +func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) { + r := NewToolRegistry() + ct := &mockCtxTool{ + mockRegistryTool: *newMockTool("ctx_tool", "needs context"), + } + r.Register(ct) + + r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil) + + if ct.channel != "" || ct.chatID != "" { + t.Error("SetContext should not be called with empty channel/chatID") + } +} + +func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) { + r := NewToolRegistry() + at := &mockAsyncRegistryTool{ + mockRegistryTool: *newMockTool("async_tool", "async work"), + } + at.result = AsyncResult("started") + r.Register(at) + + called := false + cb := func(_ context.Context, _ *ToolResult) { called = true } + + result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb) + if at.cb == nil { + t.Error("expected SetCallback to have been called") + } + if !result.Async { + t.Error("expected async result") + } + + at.cb(context.Background(), SilentResult("done")) + if !called { + t.Error("expected callback to be invoked") + } +} + +func TestToolRegistry_GetDefinitions(t *testing.T) { + r := NewToolRegistry() + r.Register(newMockTool("alpha", "tool A")) + + defs := r.GetDefinitions() + if len(defs) != 1 { + t.Fatalf("expected 1 definition, got %d", len(defs)) + } + if defs[0]["type"] != "function" { + t.Errorf("expected type 'function', got %v", defs[0]["type"]) + } + fn, ok := defs[0]["function"].(map[string]interface{}) + if !ok { + t.Fatal("expected 'function' key to be a map") + } + if fn["name"] != "alpha" { + t.Errorf("expected name 'alpha', got %v", fn["name"]) + } + if fn["description"] != "tool A" { + t.Errorf("expected description 'tool A', got %v", fn["description"]) + } +} + +func TestToolRegistry_ToProviderDefs(t *testing.T) { + r := NewToolRegistry() + params := map[string]interface{}{"type": "object", "properties": map[string]interface{}{}} + r.Register(&mockRegistryTool{ + name: "beta", + desc: "tool B", + params: params, + result: SilentResult("ok"), + }) + + defs := r.ToProviderDefs() + if len(defs) != 1 { + t.Fatalf("expected 1 provider def, got %d", len(defs)) + } + + want := providers.ToolDefinition{ + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "beta", + Description: "tool B", + Parameters: params, + }, + } + got := defs[0] + if got.Type != want.Type { + t.Errorf("Type: want %q, got %q", want.Type, got.Type) + } + if got.Function.Name != want.Function.Name { + t.Errorf("Name: want %q, got %q", want.Function.Name, got.Function.Name) + } + if got.Function.Description != want.Function.Description { + t.Errorf("Description: want %q, got %q", want.Function.Description, got.Function.Description) + } +} + +func TestToolRegistry_List(t *testing.T) { + r := NewToolRegistry() + r.Register(newMockTool("x", "")) + r.Register(newMockTool("y", "")) + + names := r.List() + if len(names) != 2 { + t.Fatalf("expected 2 names, got %d", len(names)) + } + + nameSet := map[string]bool{} + for _, n := range names { + nameSet[n] = true + } + if !nameSet["x"] || !nameSet["y"] { + t.Errorf("expected names {x, y}, got %v", names) + } +} + +func TestToolRegistry_Count(t *testing.T) { + r := NewToolRegistry() + if r.Count() != 0 { + t.Errorf("expected 0, got %d", r.Count()) + } + + r.Register(newMockTool("a", "")) + r.Register(newMockTool("b", "")) + if r.Count() != 2 { + t.Errorf("expected 2, got %d", r.Count()) + } + + r.Register(newMockTool("a", "replaced")) + if r.Count() != 2 { + t.Errorf("expected 2 after overwrite, got %d", r.Count()) + } +} + +func TestToolRegistry_GetSummaries(t *testing.T) { + r := NewToolRegistry() + r.Register(newMockTool("read_file", "Reads a file")) + + summaries := r.GetSummaries() + if len(summaries) != 1 { + t.Fatalf("expected 1 summary, got %d", len(summaries)) + } + if !strings.Contains(summaries[0], "`read_file`") { + t.Errorf("expected backtick-quoted name in summary, got %q", summaries[0]) + } + if !strings.Contains(summaries[0], "Reads a file") { + t.Errorf("expected description in summary, got %q", summaries[0]) + } +} + +func TestToolToSchema(t *testing.T) { + tool := newMockTool("demo", "demo tool") + schema := ToolToSchema(tool) + + if schema["type"] != "function" { + t.Errorf("expected type 'function', got %v", schema["type"]) + } + fn, ok := schema["function"].(map[string]interface{}) + if !ok { + t.Fatal("expected 'function' to be a map") + } + if fn["name"] != "demo" { + t.Errorf("expected name 'demo', got %v", fn["name"]) + } + if fn["description"] != "demo tool" { + t.Errorf("expected description 'demo tool', got %v", fn["description"]) + } + if fn["parameters"] == nil { + t.Error("expected parameters to be set") + } +} + +func TestToolRegistry_ConcurrentAccess(t *testing.T) { + r := NewToolRegistry() + var wg sync.WaitGroup + + for i := 0; i < 50; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + name := string(rune('A' + n%26)) + r.Register(newMockTool(name, "concurrent")) + r.Get(name) + r.Count() + r.List() + r.GetDefinitions() + }(i) + } + + wg.Wait() + + if r.Count() == 0 { + t.Error("expected tools to be registered after concurrent access") + } +} diff --git a/pkg/utils/string.go b/pkg/utils/string.go index 7a6aa37cc..62d9beee0 100644 --- a/pkg/utils/string.go +++ b/pkg/utils/string.go @@ -4,6 +4,9 @@ package utils // Handles multi-byte Unicode characters properly. // If the string is truncated, "..." is appended to indicate truncation. func Truncate(s string, maxLen int) string { + if maxLen <= 0 { + return "" + } runes := []rune(s) if len(runes) <= maxLen { return s diff --git a/pkg/utils/string_test.go b/pkg/utils/string_test.go new file mode 100644 index 000000000..a44ead228 --- /dev/null +++ b/pkg/utils/string_test.go @@ -0,0 +1,106 @@ +package utils + +import "testing" + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + want string + }{ + { + name: "short string unchanged", + input: "hi", + maxLen: 10, + want: "hi", + }, + { + name: "exact length unchanged", + input: "hello", + maxLen: 5, + want: "hello", + }, + { + name: "long string truncated with ellipsis", + input: "hello world", + maxLen: 8, + want: "hello...", + }, + { + name: "maxLen equals 4 leaves 1 char plus ellipsis", + input: "abcdef", + maxLen: 4, + want: "a...", + }, + { + name: "maxLen 3 returns first 3 chars without ellipsis", + input: "abcdef", + maxLen: 3, + want: "abc", + }, + { + name: "maxLen 2 returns first 2 chars", + input: "abcdef", + maxLen: 2, + want: "ab", + }, + { + name: "maxLen 1 returns first char", + input: "abcdef", + maxLen: 1, + want: "a", + }, + { + name: "maxLen 0 returns empty", + input: "hello", + maxLen: 0, + want: "", + }, + { + name: "negative maxLen returns empty", + input: "hello", + maxLen: -1, + want: "", + }, + { + name: "empty string unchanged", + input: "", + maxLen: 5, + want: "", + }, + { + name: "empty string with zero maxLen", + input: "", + maxLen: 0, + want: "", + }, + { + name: "unicode truncated correctly", + input: "\U0001f600\U0001f601\U0001f602\U0001f603\U0001f604", + maxLen: 4, + want: "\U0001f600...", + }, + { + name: "unicode short enough", + input: "\u00e9\u00e8", + maxLen: 5, + want: "\u00e9\u00e8", + }, + { + name: "mixed ascii and unicode", + input: "Go\U0001f680\U0001f525\U0001f4a5\U0001f30d", + maxLen: 5, + want: "Go...", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Truncate(tt.input, tt.maxLen) + if got != tt.want { + t.Errorf("Truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want) + } + }) + } +}