mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge remote-tracking branch 'origin/main' into refactor/model-to-model-name
This commit is contained in:
@@ -97,6 +97,10 @@ func registerSharedTools(
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
|
||||
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
|
||||
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
|
||||
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
|
||||
@@ -428,6 +428,13 @@ type BraveConfig struct {
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type TavilyConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type DuckDuckGoConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
|
||||
@@ -441,6 +448,7 @@ type PerplexityConfig struct {
|
||||
|
||||
type WebToolsConfig struct {
|
||||
Brave BraveConfig `json:"brave"`
|
||||
Tavily TavilyConfig `json:"tavily"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,350 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// --- mock types ---
|
||||
|
||||
type mockRegistryTool struct {
|
||||
name string
|
||||
desc string
|
||||
params map[string]any
|
||||
result *ToolResult
|
||||
}
|
||||
|
||||
func (m *mockRegistryTool) Name() string { return m.name }
|
||||
func (m *mockRegistryTool) Description() string { return m.desc }
|
||||
func (m *mockRegistryTool) Parameters() map[string]any { return m.params }
|
||||
func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolResult {
|
||||
return m.result
|
||||
}
|
||||
|
||||
type mockCtxTool struct {
|
||||
mockRegistryTool
|
||||
channel string
|
||||
chatID string
|
||||
}
|
||||
|
||||
func (m *mockCtxTool) SetContext(channel, chatID string) {
|
||||
m.channel = channel
|
||||
m.chatID = chatID
|
||||
}
|
||||
|
||||
type mockAsyncRegistryTool struct {
|
||||
mockRegistryTool
|
||||
cb AsyncCallback
|
||||
}
|
||||
|
||||
func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
|
||||
m.cb = cb
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func newMockTool(name, desc string) *mockRegistryTool {
|
||||
return &mockRegistryTool{
|
||||
name: name,
|
||||
desc: desc,
|
||||
params: map[string]any{"type": "object"},
|
||||
result: SilentResult("ok"),
|
||||
}
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
|
||||
func TestNewToolRegistry(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected empty registry, got count %d", r.Count())
|
||||
}
|
||||
if len(r.List()) != 0 {
|
||||
t.Errorf("expected empty list, got %v", r.List())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_RegisterAndGet(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
tool := newMockTool("echo", "echoes input")
|
||||
r.Register(tool)
|
||||
|
||||
got, ok := r.Get("echo")
|
||||
if !ok {
|
||||
t.Fatal("expected to find registered tool")
|
||||
}
|
||||
if got.Name() != "echo" {
|
||||
t.Errorf("expected name 'echo', got %q", got.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Get_NotFound(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
_, ok := r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected ok=false for unregistered tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_RegisterOverwrite(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("dup", "first"))
|
||||
r.Register(newMockTool("dup", "second"))
|
||||
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected count 1 after overwrite, got %d", r.Count())
|
||||
}
|
||||
tool, _ := r.Get("dup")
|
||||
if tool.Description() != "second" {
|
||||
t.Errorf("expected overwritten description 'second', got %q", tool.Description())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_Success(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "greet",
|
||||
desc: "says hello",
|
||||
params: map[string]any{},
|
||||
result: SilentResult("hello"),
|
||||
})
|
||||
|
||||
result := r.Execute(context.Background(), "greet", nil)
|
||||
if result.IsError {
|
||||
t.Errorf("expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "hello" {
|
||||
t.Errorf("expected ForLLM 'hello', got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Execute_NotFound(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
result := r.Execute(context.Background(), "missing", nil)
|
||||
if !result.IsError {
|
||||
t.Error("expected error for missing tool")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "not found") {
|
||||
t.Errorf("expected 'not found' in error, got %q", result.ForLLM)
|
||||
}
|
||||
if result.Err == nil {
|
||||
t.Error("expected Err to be set via WithError")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
|
||||
|
||||
if ct.channel != "telegram" {
|
||||
t.Errorf("expected channel 'telegram', got %q", ct.channel)
|
||||
}
|
||||
if ct.chatID != "chat-42" {
|
||||
t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
ct := &mockCtxTool{
|
||||
mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
|
||||
}
|
||||
r.Register(ct)
|
||||
|
||||
r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
|
||||
|
||||
if ct.channel != "" || ct.chatID != "" {
|
||||
t.Error("SetContext should not be called with empty channel/chatID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
at := &mockAsyncRegistryTool{
|
||||
mockRegistryTool: *newMockTool("async_tool", "async work"),
|
||||
}
|
||||
at.result = AsyncResult("started")
|
||||
r.Register(at)
|
||||
|
||||
called := false
|
||||
cb := func(_ context.Context, _ *ToolResult) { called = true }
|
||||
|
||||
result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
|
||||
if at.cb == nil {
|
||||
t.Error("expected SetCallback to have been called")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("expected async result")
|
||||
}
|
||||
|
||||
at.cb(context.Background(), SilentResult("done"))
|
||||
if !called {
|
||||
t.Error("expected callback to be invoked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("alpha", "tool A"))
|
||||
|
||||
defs := r.GetDefinitions()
|
||||
if len(defs) != 1 {
|
||||
t.Fatalf("expected 1 definition, got %d", len(defs))
|
||||
}
|
||||
if defs[0]["type"] != "function" {
|
||||
t.Errorf("expected type 'function', got %v", defs[0]["type"])
|
||||
}
|
||||
fn, ok := defs[0]["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'function' key to be a map")
|
||||
}
|
||||
if fn["name"] != "alpha" {
|
||||
t.Errorf("expected name 'alpha', got %v", fn["name"])
|
||||
}
|
||||
if fn["description"] != "tool A" {
|
||||
t.Errorf("expected description 'tool A', got %v", fn["description"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ToProviderDefs(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
params := map[string]any{"type": "object", "properties": map[string]any{}}
|
||||
r.Register(&mockRegistryTool{
|
||||
name: "beta",
|
||||
desc: "tool B",
|
||||
params: params,
|
||||
result: SilentResult("ok"),
|
||||
})
|
||||
|
||||
defs := r.ToProviderDefs()
|
||||
if len(defs) != 1 {
|
||||
t.Fatalf("expected 1 provider def, got %d", len(defs))
|
||||
}
|
||||
|
||||
want := providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "beta",
|
||||
Description: "tool B",
|
||||
Parameters: params,
|
||||
},
|
||||
}
|
||||
got := defs[0]
|
||||
if got.Type != want.Type {
|
||||
t.Errorf("Type: want %q, got %q", want.Type, got.Type)
|
||||
}
|
||||
if got.Function.Name != want.Function.Name {
|
||||
t.Errorf("Name: want %q, got %q", want.Function.Name, got.Function.Name)
|
||||
}
|
||||
if got.Function.Description != want.Function.Description {
|
||||
t.Errorf("Description: want %q, got %q", want.Function.Description, got.Function.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_List(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("x", ""))
|
||||
r.Register(newMockTool("y", ""))
|
||||
|
||||
names := r.List()
|
||||
if len(names) != 2 {
|
||||
t.Fatalf("expected 2 names, got %d", len(names))
|
||||
}
|
||||
|
||||
nameSet := map[string]bool{}
|
||||
for _, n := range names {
|
||||
nameSet[n] = true
|
||||
}
|
||||
if !nameSet["x"] || !nameSet["y"] {
|
||||
t.Errorf("expected names {x, y}, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Count(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected 0, got %d", r.Count())
|
||||
}
|
||||
|
||||
r.Register(newMockTool("a", ""))
|
||||
r.Register(newMockTool("b", ""))
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2, got %d", r.Count())
|
||||
}
|
||||
|
||||
r.Register(newMockTool("a", "replaced"))
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2 after overwrite, got %d", r.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_GetSummaries(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
r.Register(newMockTool("read_file", "Reads a file"))
|
||||
|
||||
summaries := r.GetSummaries()
|
||||
if len(summaries) != 1 {
|
||||
t.Fatalf("expected 1 summary, got %d", len(summaries))
|
||||
}
|
||||
if !strings.Contains(summaries[0], "`read_file`") {
|
||||
t.Errorf("expected backtick-quoted name in summary, got %q", summaries[0])
|
||||
}
|
||||
if !strings.Contains(summaries[0], "Reads a file") {
|
||||
t.Errorf("expected description in summary, got %q", summaries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolToSchema(t *testing.T) {
|
||||
tool := newMockTool("demo", "demo tool")
|
||||
schema := ToolToSchema(tool)
|
||||
|
||||
if schema["type"] != "function" {
|
||||
t.Errorf("expected type 'function', got %v", schema["type"])
|
||||
}
|
||||
fn, ok := schema["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'function' to be a map")
|
||||
}
|
||||
if fn["name"] != "demo" {
|
||||
t.Errorf("expected name 'demo', got %v", fn["name"])
|
||||
}
|
||||
if fn["description"] != "demo tool" {
|
||||
t.Errorf("expected description 'demo tool', got %v", fn["description"])
|
||||
}
|
||||
if fn["parameters"] == nil {
|
||||
t.Error("expected parameters to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_ConcurrentAccess(t *testing.T) {
|
||||
r := NewToolRegistry()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
name := string(rune('A' + n%26))
|
||||
r.Register(newMockTool(name, "concurrent"))
|
||||
r.Get(name)
|
||||
r.Count()
|
||||
r.List()
|
||||
r.GetDefinitions()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if r.Count() == 0 {
|
||||
t.Error("expected tools to be registered after concurrent access")
|
||||
}
|
||||
}
|
||||
+96
-1
@@ -1,6 +1,7 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -84,6 +85,88 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := p.baseURL
|
||||
if searchURL == "" {
|
||||
searchURL = "https://api.tavily.com/search"
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"api_key": p.apiKey,
|
||||
"query": query,
|
||||
"search_depth": "advanced",
|
||||
"include_answer": false,
|
||||
"include_images": false,
|
||||
"include_raw_content": false,
|
||||
"max_results": count,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct{}
|
||||
|
||||
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
@@ -256,6 +339,10 @@ type WebSearchToolOptions struct {
|
||||
BraveAPIKey string
|
||||
BraveMaxResults int
|
||||
BraveEnabled bool
|
||||
TavilyAPIKey string
|
||||
TavilyBaseURL string
|
||||
TavilyMaxResults int
|
||||
TavilyEnabled bool
|
||||
DuckDuckGoMaxResults int
|
||||
DuckDuckGoEnabled bool
|
||||
PerplexityAPIKey string
|
||||
@@ -267,7 +354,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > DuckDuckGo
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
@@ -278,6 +365,14 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
}
|
||||
if opts.TavilyMaxResults > 0 {
|
||||
maxResults = opts.TavilyMaxResults
|
||||
}
|
||||
} else if opts.DuckDuckGoEnabled {
|
||||
provider = &DuckDuckGoSearchProvider{}
|
||||
if opts.DuckDuckGoMaxResults > 0 {
|
||||
|
||||
@@ -333,3 +333,75 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_TavilySearch_Success verifies successful Tavily search
|
||||
func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// Verify payload
|
||||
var payload map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&payload)
|
||||
if payload["api_key"] != "test-key" {
|
||||
t.Errorf("Expected api_key test-key, got %v", payload["api_key"])
|
||||
}
|
||||
if payload["query"] != "test query" {
|
||||
t.Errorf("Expected query 'test query', got %v", payload["query"])
|
||||
}
|
||||
|
||||
// Return mock response
|
||||
response := map[string]any{
|
||||
"results": []map[string]any{
|
||||
{
|
||||
"title": "Test Result 1",
|
||||
"url": "https://example.com/1",
|
||||
"content": "Content for result 1",
|
||||
},
|
||||
{
|
||||
"title": "Test Result 2",
|
||||
"url": "https://example.com/2",
|
||||
"content": "Content for result 2",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "test query",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain result titles and URLs
|
||||
if !strings.Contains(result.ForUser, "Test Result 1") ||
|
||||
!strings.Contains(result.ForUser, "https://example.com/1") {
|
||||
t.Errorf("Expected results in output, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Should mention via Tavily
|
||||
if !strings.Contains(result.ForUser, "via Tavily") {
|
||||
t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@ package utils
|
||||
// Handles multi-byte Unicode characters properly.
|
||||
// If the string is truncated, "..." is appended to indicate truncation.
|
||||
func Truncate(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "short string unchanged",
|
||||
input: "hi",
|
||||
maxLen: 10,
|
||||
want: "hi",
|
||||
},
|
||||
{
|
||||
name: "exact length unchanged",
|
||||
input: "hello",
|
||||
maxLen: 5,
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "long string truncated with ellipsis",
|
||||
input: "hello world",
|
||||
maxLen: 8,
|
||||
want: "hello...",
|
||||
},
|
||||
{
|
||||
name: "maxLen equals 4 leaves 1 char plus ellipsis",
|
||||
input: "abcdef",
|
||||
maxLen: 4,
|
||||
want: "a...",
|
||||
},
|
||||
{
|
||||
name: "maxLen 3 returns first 3 chars without ellipsis",
|
||||
input: "abcdef",
|
||||
maxLen: 3,
|
||||
want: "abc",
|
||||
},
|
||||
{
|
||||
name: "maxLen 2 returns first 2 chars",
|
||||
input: "abcdef",
|
||||
maxLen: 2,
|
||||
want: "ab",
|
||||
},
|
||||
{
|
||||
name: "maxLen 1 returns first char",
|
||||
input: "abcdef",
|
||||
maxLen: 1,
|
||||
want: "a",
|
||||
},
|
||||
{
|
||||
name: "maxLen 0 returns empty",
|
||||
input: "hello",
|
||||
maxLen: 0,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "negative maxLen returns empty",
|
||||
input: "hello",
|
||||
maxLen: -1,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty string unchanged",
|
||||
input: "",
|
||||
maxLen: 5,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty string with zero maxLen",
|
||||
input: "",
|
||||
maxLen: 0,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "unicode truncated correctly",
|
||||
input: "\U0001f600\U0001f601\U0001f602\U0001f603\U0001f604",
|
||||
maxLen: 4,
|
||||
want: "\U0001f600...",
|
||||
},
|
||||
{
|
||||
name: "unicode short enough",
|
||||
input: "\u00e9\u00e8",
|
||||
maxLen: 5,
|
||||
want: "\u00e9\u00e8",
|
||||
},
|
||||
{
|
||||
name: "mixed ascii and unicode",
|
||||
input: "Go\U0001f680\U0001f525\U0001f4a5\U0001f30d",
|
||||
maxLen: 5,
|
||||
want: "Go...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := Truncate(tt.input, tt.maxLen)
|
||||
if got != tt.want {
|
||||
t.Errorf("Truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user